Chapter 09: Proxies and Distribution

When a single Memcached instance can no longer meet demand, a distributed architecture and a proxy layer become the natural next step. This chapter covers sharding strategies, connection pooling, consistent hashing, and load balancing in depth.


9.1 Overview of Distributed Architecture

Why go distributed

Problem                 Description
Memory limit            A single machine's memory is finite and cannot hold all cacheable data
Connection limit        A single machine can only accept a limited number of connections
Bandwidth bottleneck    A single machine's network bandwidth is limited
High availability       A single point of failure can trigger a cache avalanche

Distributed architecture patterns

Pattern 1: client-side sharding (recommended)
┌────────┐
│ Client │──hash(key)──▶┌──────────┐
│        │              │ Server 1 │
└────────┘              ├──────────┤
                        │ Server 2 │
                        ├──────────┤
                        │ Server 3 │
                        └──────────┘

Pattern 2: proxy-layer sharding
┌────────┐    ┌────────┐    ┌──────────┐
│ Client │───▶│ Proxy  │───▶│ Server 1 │
│        │    │        │    ├──────────┤
└────────┘    └────────┘    │ Server 2 │
                            ├──────────┤
                            │ Server 3 │
                            └──────────┘

Pattern 3: server-side clustering (note: Memcached has no built-in server-to-server replication and ketama is a client-side consistent-hashing scheme, so this pattern requires third-party patches such as repcached)
┌────────┐    ┌──────────┐    ┌──────────┐
│ Client │───▶│ Server 1 │◀──▶│ Server 2 │
└────────┘    └──────────┘    └──────────┘
                ▲                    ▲
                │      sync          │
                ▼                    ▼
              ┌──────────┐
              │ Server 3 │
              └──────────┘

9.2 Consistent Hashing

The problem with modulo hashing

The simplest sharding scheme is hash(key) % N (N being the number of servers), but it has a serious flaw:

With N=3:
key="user:1001" → hash=7 → 7%3=1 → Server 2

After adding one server (N=4):
key="user:1001" → hash=7 → 7%4=3 → Server 4

Conclusion: almost every key maps to a different server! The sketch below quantifies this.
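
A minimal sketch of the effect, assuming CRC32 as a stand-in hash (the ConsistentHash class later in this chapter uses MD5; any stable hash shows the same result, roughly 3 out of 4 keys move when N goes from 3 to 4):

import zlib

def modulo_server(key: str, n: int) -> int:
    """Modulo sharding: stable hash of the key, then % n."""
    return zlib.crc32(key.encode()) % n

moved = sum(
    1 for i in range(10000)
    if modulo_server(f"user:{i}", 3) != modulo_server(f"user:{i}", 4)
)
print(f"{moved / 100:.1f}% of 10000 keys map to a different server after adding one node")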

How consistent hashing works

Consistent hashing maps all nodes onto a virtual hash ring:

         0
         │
    ┌────┴────┐
    │  Node A │
    │  (100)  │
    └────┬────┘
         │
  120°   │   240°
    ┌────┴────┐
    │  Node B │
    │  (200)  │
    └────┬────┘
         │
         │
    ┌────┴────┐
    │  Node C │
    │  (300)  │
    └────┬────┘
         │
         360°

Key "user:1001" → hash=150 → 落在 Node A(100) 和 Node B(200) 之间
                            → 顺时针找到 Node B

Python implementation

#!/usr/bin/env python3
"""consistent_hash.py — 一致性哈希实现"""

import hashlib
from bisect import bisect_right
from typing import List, Optional

class ConsistentHash:
    def __init__(self, nodes: List[str], virtual_nodes: int = 150):
        """
        nodes: list of physical nodes (e.g. ['server1:11211', 'server2:11211'])
        virtual_nodes: number of virtual nodes per physical node
        """
        self.virtual_nodes = virtual_nodes
        self.ring = {}         # hash value → node
        self.sorted_keys = []  # sorted hash values on the ring

        for node in nodes:
            self.add_node(node)

    def _hash(self, key: str) -> int:
        """计算哈希值"""
        return int(hashlib.md5(key.encode()).hexdigest(), 16)

    def add_node(self, node: str):
        """添加节点"""
        for i in range(self.virtual_nodes):
            virtual_key = f"{node}#{i}"
            h = self._hash(virtual_key)
            self.ring[h] = node
            self.sorted_keys.append(h)
        self.sorted_keys.sort()

    def remove_node(self, node: str):
        """移除节点"""
        for i in range(self.virtual_nodes):
            virtual_key = f"{node}#{i}"
            h = self._hash(virtual_key)
            if h in self.ring:
                del self.ring[h]
                self.sorted_keys.remove(h)

    def get_node(self, key: str) -> Optional[str]:
        """获取 key 对应的节点"""
        if not self.ring:
            return None

        h = self._hash(key)
        idx = bisect_right(self.sorted_keys, h)

        if idx == len(self.sorted_keys):
            idx = 0  # wrap around to the start of the ring

        return self.ring[self.sorted_keys[idx]]

    def get_nodes(self, key: str, count: int = 1) -> List[str]:
        """获取 key 对应的多个节点(用于副本)"""
        if not self.ring:
            return []

        h = self._hash(key)
        nodes = []
        seen = set()
        idx = bisect_right(self.sorted_keys, h)

        for _ in range(len(self.sorted_keys)):
            if idx >= len(self.sorted_keys):
                idx = 0

            node = self.ring[self.sorted_keys[idx]]
            if node not in seen:
                nodes.append(node)
                seen.add(node)
                if len(nodes) >= count:
                    break

            idx += 1

        return nodes


# Quick test
ring = ConsistentHash([
    "server1:11211",
    "server2:11211",
    "server3:11211"
])

# Inspect the key distribution
distribution = {}
old_mapping = {}
for i in range(1000):
    key = f"user:{i}"
    node = ring.get_node(key)
    old_mapping[key] = node
    distribution[node] = distribution.get(node, 0) + 1

print("Key distribution:")
for node, count in sorted(distribution.items()):
    print(f"  {node}: {count} keys ({count/10:.1f}%)")

# Add a new node and measure how many keys actually move
ring.add_node("server4:11211")
moved = sum(
    1 for key, old_node in old_mapping.items()
    if ring.get_node(key) != old_node
)
print(f"\nAfter adding a node, {moved/10:.1f}% of the keys moved (ideally about 1/4)")

Choosing the number of virtual nodes

Virtual nodes   Uniformity   Memory overhead   Suggested scenario
50              Fair         Low               More than 100 physical nodes
150             Good         Moderate          General use (recommended)
250             Very good    Fairly high       Fewer than 10 physical nodes
500             Excellent    High              Testing / benchmarking
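
To pick a value for a specific cluster, the spread can be measured directly with the ConsistentHash class above. A rough sketch (the node names and key count are arbitrary assumptions); larger counts give a tighter spread at the cost of a bigger ring:

# Rough uniformity check for several virtual-node counts,
# reusing the ConsistentHash class defined above.
for vn in (50, 150, 250, 500):
    ring = ConsistentHash(
        [f"server{i}:11211" for i in range(1, 6)],  # 5 hypothetical nodes
        virtual_nodes=vn,
    )
    counts = {}
    for i in range(20000):
        node = ring.get_node(f"item:{i}")
        counts[node] = counts.get(node, 0) + 1
    shares = [c / 200 for c in counts.values()]  # percentage of keys per node
    print(f"virtual_nodes={vn}: min={min(shares):.1f}%  max={max(shares):.1f}%")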

9.3 Connection Pooling

Why a connection pool

Problem                    Description
TCP connection overhead    Every new connection costs a three-way handshake
File descriptor limit      Each connection consumes one fd
Server connection limit    The server caps the number of concurrent connections
Reuse efficiency           Reusing pooled connections significantly reduces latency

Connection pool implementation

#!/usr/bin/env python3
"""connection_pool.py — Memcached 连接池"""

import socket
import threading
import time
from queue import Queue, Empty
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Optional

@dataclass
class PooledConnection:
    socket: socket.socket
    created_at: float
    last_used: float
    in_use: bool = False

class MemcachedConnectionPool:
    def __init__(self, host: str = '127.0.0.1', port: int = 11211,
                 min_size: int = 5, max_size: int = 20,
                 max_idle_time: int = 300):
        self.host = host
        self.port = port
        self.min_size = min_size
        self.max_size = max_size
        self.max_idle_time = max_idle_time

        self._pool = Queue(maxsize=max_size)
        self._all_connections = []
        # RLock: acquire() may hold this lock while calling _create_connection(),
        # which takes it again, so a plain Lock would deadlock
        self._lock = threading.RLock()
        self._current_size = 0

        # Pre-create the minimum number of connections
        for _ in range(min_size):
            conn = self._create_connection()
            if conn:
                self._pool.put(conn)

    def _create_connection(self) -> Optional[PooledConnection]:
        """创建新连接"""
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((self.host, self.port))
            conn = PooledConnection(
                socket=sock,
                created_at=time.time(),
                last_used=time.time()
            )
            with self._lock:
                self._all_connections.append(conn)
                self._current_size += 1
            return conn
        except Exception as e:
            print(f"连接创建失败: {e}")
            return None

    @contextmanager
    def acquire(self):
        """获取连接(上下文管理器)"""
        conn = None
        try:
            # Try to take an idle connection from the pool
            try:
                conn = self._pool.get_nowait()
                # Discard the connection if it has been idle for too long
                if time.time() - conn.last_used > self.max_idle_time:
                    self._close_connection(conn)
                    conn = None
            except Empty:
                pass

            # Nothing idle in the pool: create a new connection if under max_size
            if conn is None:
                with self._lock:
                    if self._current_size < self.max_size:
                        conn = self._create_connection()

            # Otherwise wait for a connection to be returned
            if conn is None:
                try:
                    conn = self._pool.get(timeout=5)
                except Empty:
                    raise TimeoutError("Timed out waiting for a pooled connection")

            conn.in_use = True
            conn.last_used = time.time()
            yield conn.socket

        finally:
            if conn:
                conn.in_use = False
                conn.last_used = time.time()
                # Verify the connection is still alive before returning it to the pool
                try:
                    conn.socket.sendall(b"version\r\n")
                    resp = conn.socket.recv(1024)
                    if b"VERSION" in resp:
                        self._pool.put(conn)
                    else:
                        self._close_connection(conn)
                except Exception:
                    self._close_connection(conn)

    def _close_connection(self, conn: PooledConnection):
        """关闭连接"""
        try:
            conn.socket.sendall(b"quit\r\n")
            conn.socket.close()
        except Exception:
            pass
        finally:
            with self._lock:
                if conn in self._all_connections:
                    self._all_connections.remove(conn)
                self._current_size -= 1

    def close_all(self):
        """关闭所有连接"""
        while not self._pool.empty():
            try:
                conn = self._pool.get_nowait()
                self._close_connection(conn)
            except Empty:
                break

    @property
    def stats(self) -> dict:
        """连接池统计"""
        return {
            'total': self._current_size,
            'available': self._pool.qsize(),
            'in_use': self._current_size - self._pool.qsize()
        }


# Usage example
pool = MemcachedConnectionPool(
    host='127.0.0.1', port=11211,
    min_size=5, max_size=20
)

# Borrow a connection
with pool.acquire() as sock:
    sock.sendall(b"set pool:test 0 60 5\r\nhello\r\n")
    print(sock.recv(1024).decode())

# The connection is automatically returned to the pool
print(f"Pool state: {pool.stats}")

pool.close_all()
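
The pool only pays off under concurrency. A minimal multi-threaded sketch, assuming a local memcached on 127.0.0.1:11211; it builds its own small pool, since the one above has just been closed:

import threading

burst_pool = MemcachedConnectionPool(min_size=2, max_size=10)

def worker():
    # Each thread borrows a connection, issues a lightweight command, and returns it.
    with burst_pool.acquire() as sock:
        sock.sendall(b"get pool:test\r\n")
        sock.recv(4096)

threads = [threading.Thread(target=worker) for _ in range(50)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(f"Pool state after the burst: {burst_pool.stats}")  # in_use should be back to 0
burst_pool.close_all()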

9.4 A Distributed Client with Consistent Hashing

#!/usr/bin/env python3
"""distributed_client.py — 带一致性哈希的分布式 Memcached 客户端"""

import socket
import json
from typing import List, Optional

from consistent_hash import ConsistentHash  # the ring implementation from section 9.2

class DistributedMemcached:
    def __init__(self, servers: List[str], virtual_nodes: int = 150):
        """
        servers: ['host1:port1', 'host2:port2', ...]
        """
        self.servers = servers
        self.hash_ring = ConsistentHash(servers, virtual_nodes)
        self.connections = {}  # server → socket

        # Open one connection per server
        for server in servers:
            host, port = server.split(':')
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((host, int(port)))
            self.connections[server] = sock

    def _get_connection(self, key: str) -> tuple[socket.socket, str]:
        server = self.hash_ring.get_node(key)
        return self.connections[server], server

    def set(self, key: str, value: bytes, flags: int = 0,
            exptime: int = 0) -> bool:
        conn, _ = self._get_connection(key)
        data = value if isinstance(value, bytes) else value.encode()
        cmd = f"set {key} {flags} {exptime} {len(data)}\r\n"
        conn.sendall(cmd.encode() + data + b"\r\n")
        return b"STORED" in conn.recv(1024)

    def get(self, key: str) -> Optional[bytes]:
        conn, _ = self._get_connection(key)
        conn.sendall(f"get {key}\r\n".encode())
        buffer = b""
        while True:
            chunk = conn.recv(65536)
            buffer += chunk
            if b"END\r\n" in buffer:
                break

        lines = buffer.split(b"\r\n")
        if lines[0].startswith(b"VALUE"):
            return lines[1]
        return None

    def get_multi(self, keys: List[str]) -> dict[str, bytes]:
        """批量获取(自动按服务器分组)"""
        # Group keys by owning server
        server_keys = {}
        for key in keys:
            server = self.hash_ring.get_node(key)
            if server not in server_keys:
                server_keys[server] = []
            server_keys[server].append(key)

        # Query each server in turn (sequentially here; could be parallelized)
        result = {}
        for server, skeys in server_keys.items():
            conn = self.connections[server]
            cmd = f"get {' '.join(skeys)}\r\n"
            conn.sendall(cmd.encode())

            buffer = b""
            while True:
                chunk = conn.recv(65536)
                buffer += chunk
                if b"END\r\n" in buffer:
                    break

            lines = buffer.split(b"\r\n")
            i = 0
            while i < len(lines):
                line = lines[i]
                if line == b"END":
                    break
                if line.startswith(b"VALUE "):
                    key = line.decode().split()[1]
                    i += 1
                    if i < len(lines):
                        result[key] = lines[i]
                i += 1

        return result

    def delete(self, key: str) -> bool:
        conn, _ = self._get_connection(key)
        conn.sendall(f"delete {key}\r\n".encode())
        return b"DELETED" in conn.recv(1024)

    def stats(self) -> dict[str, dict]:
        """获取所有服务器的统计"""
        result = {}
        for server, conn in self.connections.items():
            conn.sendall(b"stats\r\n")
            buffer = b""
            while True:
                chunk = conn.recv(8192)
                buffer += chunk
                if b"END\r\n" in buffer:
                    break

            stats = {}
            for line in buffer.decode().split("\r\n"):
                if line.startswith("STAT"):
                    parts = line.split()
                    stats[parts[1]] = parts[2]
            result[server] = stats
        return result

    def close(self):
        for conn in self.connections.values():
            try:
                conn.sendall(b"quit\r\n")
                conn.close()
            except Exception:
                pass


# Usage example
client = DistributedMemcached([
    '127.0.0.1:11211',
    # '127.0.0.1:11212',
    # '127.0.0.1:11213',
])

# Write
client.set("user:1001", json.dumps({"name": "Bob"}).encode())
client.set("user:1002", json.dumps({"name": "Alice"}).encode())

# Read
data = client.get("user:1001")
if data:
    print(f"user:1001 = {json.loads(data)}")

# Multi-get
results = client.get_multi(["user:1001", "user:1002"])
for key, val in results.items():
    print(f"{key} = {json.loads(val)}")

# Per-server stats
stats = client.stats()
for server, s in stats.items():
    print(f"{server}: items={s.get('curr_items', 'N/A')}, "
          f"hit_rate={int(s.get('get_hits', 0)) / max(1, int(s.get('get_hits', 0)) + int(s.get('get_misses', 0))) * 100:.1f}%")

client.close()
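
get_nodes() from section 9.2 was written with replicas in mind but is unused by the client above, which writes each key to exactly one node. A hedged sketch of how write replication could be layered on top; set_replicated() is a hypothetical helper, not part of DistributedMemcached:

def set_replicated(client: DistributedMemcached, key: str, value: bytes,
                   replicas: int = 2, exptime: int = 0) -> int:
    """Write the same value to the first `replicas` distinct nodes on the ring."""
    stored = 0
    for server in client.hash_ring.get_nodes(key, count=replicas):
        conn = client.connections[server]
        cmd = f"set {key} 0 {exptime} {len(value)}\r\n".encode()
        conn.sendall(cmd + value + b"\r\n")
        if b"STORED" in conn.recv(1024):
            stored += 1
    return stored

# e.g. keep a session on two nodes so that losing one node does not log the user out
# set_replicated(client, "session:abc", b"token123", replicas=2)  # call before client.close()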

9.5 The Memcached Proxy (mcrouter)

About mcrouter

mcrouter is a Memcached proxy developed by Facebook. It supports:

  • Connection pooling and multiplexing
  • Consistent hashing
  • Replication and failover
  • Traffic mirroring (shadowing)
  • Hot/cold data separation

Example mcrouter configuration

{
  "pools": {
    "A": {
      "servers": [
        "127.0.0.1:11211",
        "127.0.0.1:11212"
      ]
    },
    "B": {
      "servers": [
        "127.0.0.1:11213",
        "127.0.0.1:11214"
      ]
    }
  },
  "route": {
    "type": "OperationSelectorRoute",
    "default_policy": {
      "type": "PoolRoute",
      "pool": "A"
    },
    "operation_policies": {
      "get": {
        "type": "FailoverRoute",
        "children": [
          {"type": "PoolRoute", "pool": "A"},
          {"type": "PoolRoute", "pool": "B"}
        ]
      }
    }
  }
}

Starting mcrouter

# Install (mcrouter is usually built from source or installed from
# Facebook's own packages; a distro package may not be available everywhere)
sudo apt install mcrouter

# Start
mcrouter --config-file=/etc/mcrouter/config.json \
         --port=11210 \
         --listen-addr=127.0.0.1

# Talk to the proxy exactly as if it were a plain Memcached server
echo "version" | nc 127.0.0.1 11210

9.6 Load-Balancing Strategies

Strategy comparison

Strategy             Description                                    Suitable for
Consistent hashing   The same key always routes to the same node   General caching (recommended)
Modulo hashing       hash(key) % N                                  Fixed number of nodes
Range sharding       Nodes assigned by key prefix                   Multi-tenant systems
Random               Pick a node at random                          Stateless data
Round-robin          Rotate through the nodes in order              Evenly spread load

Range sharding implementation

import zlib
from typing import List

class RangeSharding:
    """Shard by key prefix."""
    def __init__(self, servers: List[str]):
        self.servers = servers
        self.prefix_map = {}

    def add_rule(self, prefix: str, server: str):
        self.prefix_map[prefix] = server

    def get_server(self, key: str) -> str:
        for prefix, server in self.prefix_map.items():
            if key.startswith(prefix):
                return server
        # Fall back to modulo sharding on a stable hash
        # (Python's built-in hash() is randomized per process, so avoid it for routing)
        return self.servers[zlib.crc32(key.encode()) % len(self.servers)]

# Usage
sharding = RangeSharding(['s1:11211', 's2:11211', 's3:11211'])
sharding.add_rule("user:", "s1:11211")
sharding.add_rule("session:", "s2:11211")
sharding.add_rule("product:", "s3:11211")

print(sharding.get_server("user:1001"))      # s1:11211
print(sharding.get_server("session:abc"))     # s2:11211
print(sharding.get_server("product:2001"))    # s3:11211

9.7 Failover

Automatic failure detection

import socket
import time
from typing import Dict, List, Optional

from consistent_hash import ConsistentHash  # the ring implementation from section 9.2

class HealthChecker:
    def __init__(self, servers: List[str], check_interval: int = 5):
        self.servers = servers
        self.check_interval = check_interval
        self.healthy = {s: True for s in servers}
        self.last_check = {s: 0 for s in servers}

    def is_healthy(self, server: str) -> bool:
        return self.healthy.get(server, False)

    def check(self, server: str) -> bool:
        try:
            host, port = server.split(':')
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.settimeout(2)
            sock.connect((host, int(port)))
            sock.sendall(b"version\r\n")
            resp = sock.recv(1024)
            sock.close()
            is_ok = b"VERSION" in resp
            self.healthy[server] = is_ok
            self.last_check[server] = time.time()
            return is_ok
        except Exception:
            self.healthy[server] = False
            self.last_check[server] = time.time()
            return False

    def check_all(self) -> Dict[str, bool]:
        results = {}
        for server in self.servers:
            results[server] = self.check(server)
        return results

    def get_healthy_servers(self) -> List[str]:
        return [s for s in self.servers if self.healthy[s]]


# Failover client
class FailoverClient:
    def __init__(self, servers: List[str]):
        self.servers = servers
        self.hash_ring = ConsistentHash(servers)
        self.connections: Dict[str, socket.socket] = {}
        self.health_checker = HealthChecker(servers)

        for server in servers:
            self._connect(server)

    def _connect(self, server: str):
        try:
            host, port = server.split(':')
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect((host, int(port)))
            self.connections[server] = sock
        except Exception as e:
            print(f"连接失败 {server}: {e}")

    def _get_connection(self, key: str, exclude: set = None):
        """获取连接,支持排除故障节点"""
        if exclude is None:
            exclude = set()

        # Candidate nodes in ring order
        candidates = self.hash_ring.get_nodes(key, count=len(self.servers))
        for server in candidates:
            if server in exclude:
                continue
            if not self.health_checker.is_healthy(server):
                exclude.add(server)
                continue
            try:
                conn = self.connections.get(server)
                if conn:
                    return conn, server
            except Exception:
                exclude.add(server)

        raise ConnectionError(f"所有候选节点不可用: {key}")

    def get(self, key: str) -> Optional[bytes]:
        excluded = set()
        last_error = None

        for _ in range(min(3, len(self.servers))):
            server = None
            try:
                conn, server = self._get_connection(key, excluded)
                conn.sendall(f"get {key}\r\n".encode())

                buffer = b""
                conn.settimeout(5)
                while True:
                    chunk = conn.recv(65536)
                    if not chunk:
                        raise ConnectionError("connection closed by server")
                    buffer += chunk
                    if b"END\r\n" in buffer:
                        break

                lines = buffer.split(b"\r\n")
                if lines[0].startswith(b"VALUE"):
                    return lines[1]
                return None

            except Exception as e:
                last_error = e
                # Exclude the node that just failed so the next attempt picks another one
                if server is not None:
                    excluded.add(server)
                    self.health_checker.healthy[server] = False
                continue

        raise ConnectionError(f"Failover exhausted all candidates: {last_error}")

9.8 Application Scenarios

Scenario 1: an e-commerce cache architecture

class EcommerceCache:
    def __init__(self):
        # Separate cache clusters per business domain
        self.product_cache = DistributedMemcached([
            'cache-product-1:11211',
            'cache-product-2:11211',
        ])
        self.session_cache = DistributedMemcached([
            'cache-session-1:11211',
            'cache-session-2:11211',
        ])
        self.user_cache = DistributedMemcached([
            'cache-user-1:11211',
            'cache-user-2:11211',
        ])

    def get_product(self, product_id: int) -> dict:
        data = self.product_cache.get(f"product:{product_id}")
        return json.loads(data) if data else None

    def cache_product(self, product_id: int, data: dict, ttl: int = 600):
        self.product_cache.set(
            f"product:{product_id}",
            json.dumps(data).encode(),
            exptime=ttl
        )

    def get_user_session(self, session_id: str) -> dict:
        data = self.session_cache.get(f"session:{session_id}")
        return json.loads(data) if data else None
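
A hedged read-through sketch on top of EcommerceCache; the cache-* host names above and load_product_from_db() are placeholders, not part of the original code:

cache = EcommerceCache()

def get_product_cached(product_id: int) -> dict:
    product = cache.get_product(product_id)
    if product is None:
        product = load_product_from_db(product_id)  # hypothetical database call
        cache.cache_product(product_id, product, ttl=600)
    return product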

Scenario 2: a multi-level cache architecture

class MultiLevelCache:
    """L1(本地) + L2(Memcached) 多级缓存"""
    def __init__(self, servers: List[str]):
        self.l1_cache = {}  # local in-process cache
        self.l1_ttl = {}
        self.l2_client = DistributedMemcached(servers)
        self.l1_max_size = 1000
        self.l1_ttl_seconds = 60

    def get(self, key: str) -> Optional[bytes]:
        # Try L1 first
        if key in self.l1_cache:
            if self.l1_ttl.get(key, 0) > time.time():
                return self.l1_cache[key]
            else:
                del self.l1_cache[key]
                del self.l1_ttl[key]

        # Fall back to L2 (Memcached)
        data = self.l2_client.get(key)
        if data:
            # Backfill L1
            if len(self.l1_cache) >= self.l1_max_size:
                self._evict_l1()
            self.l1_cache[key] = data
            self.l1_ttl[key] = time.time() + self.l1_ttl_seconds

        return data

    def set(self, key: str, value: bytes, ttl: int = 300):
        # Write through to L2
        self.l2_client.set(key, value, exptime=ttl)
        # Update L1
        self.l1_cache[key] = value
        self.l1_ttl[key] = time.time() + self.l1_ttl_seconds

    def _evict_l1(self):
        """LRU 淘汰"""
        now = time.time()
        expired = [k for k, t in self.l1_ttl.items() if t < now]
        for k in expired:
            del self.l1_cache[k]
            del self.l1_ttl[k]
        # If still over the cap, drop entries in insertion order
        while len(self.l1_cache) >= self.l1_max_size:
            self.l1_cache.pop(next(iter(self.l1_cache)))
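
A short usage sketch of the two-level cache, assuming a local memcached on 127.0.0.1:11211; repeated reads are served from the in-process L1 dict and skip the network until the 60-second L1 TTL expires:

mlc = MultiLevelCache(['127.0.0.1:11211'])
mlc.set("config:site", b'{"theme": "dark"}', ttl=300)

print(mlc.get("config:site"))  # served from L1 (set() also fills L1)
print(mlc.get("config:site"))  # still L1: no Memcached round trip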

9.9 Caveats

No.  Caveat                            Notes
1    Data migration on node changes    Consistent hashing remaps only about 1/N of the keys
2    Connection pool sizing            Size to the expected concurrency; avoid too many or too few connections
3    Health-check interval             Too frequent hurts performance; too sparse hurts availability
4    Hot keys                          Some keys receive extreme traffic; consider an additional local cache
5    Routing of batch operations       get_multi must group keys by owning server
6    Clock synchronization             Absolute expiration times (an exptime longer than 30 days is treated as a Unix timestamp) depend on the server clock

9.10 Further Reading

