"); //-->
利用 Router 分配 token 给不同的 expert
每个 expert 内部只在局部或关键 token 上建立注意力连接
减少 token-to-token 的注意力连接(只对一部分 token 建立 attention)
降低计算复杂度,同时保持关键 token 的交互质量
位置: 0 1 2 3 4 5 6 7 8 strided: ↑ ↑ ↑ ↑
位置: 5 局部窗口: 3 4 5 6 7
Token 位置: 8 局部连接: 6 7 8 跳跃连接: 0 2 4
def build_sparse_attention_mask(seq_len, block_size=64, num_local_blocks=1, stride=2): """ 构造稀疏注意力 mask:局部 + 跳跃 """ mask = torch.zeros(seq_len, seq_len, dtype=torch.bool) for i in range(seq_len): # 添加局部窗口(例如前后1个block) for j in range(-num_local_blocks, num_local_blocks + 1): idx = i + j * block_size if 0 <= idx < seq_len: mask[i, idx] = True # 添加跳跃连接(stride) for j in range(0, seq_len, stride): mask[i, j] = True return mask
attn_scores = torch.matmul(query, key.transpose(-2, -1)) # [B, H, N, N] attn_scores[~mask] = -inf # 掩蔽非连接位置 attn_probs = softmax(attn_scores)
def sparse_attention(query, key, value, mask): attn_scores = torch.matmul(query, key.transpose(-2, -1)) attn_scores = attn_scores.masked_fill(~mask, float('-inf')) attn_probs = torch.softmax(attn_scores, dim=-1) return torch.matmul(attn_probs, value)
Token ID → 0 1 2 3 4 5 6 7 8 ┌────────────────── 0 │● ● ● 1 │ ● ● ● 2 │ ● ● ● 3 │ ● ● ● 4 │ ● ● ● 5 │ ● ● ● 6 │ ● ● ● 7 │ ● ● ● 8 │● ● ●
提升计算效率
在不显著增加推理成本的前提下,扩大模型容量(参数量)
# 伪代码 scores = gate(x) # (batch_size, num_experts) top_k_scores, top_k_indices = torch.topk(scores, k=2)
scores 通常是通过一个线性层获得,表示输入样本对各个专家的偏好程度。
路由器可能带有噪声或正则(如 Switch Transformer 中的 noisy gating)。
class Expert(nn.Module): def __init__(self, hidden_dim): self.ff = nn.Sequential( nn.Linear(hidden_dim, 4*hidden_dim), nn.ReLU(), nn.Linear(4*hidden_dim, hidden_dim) )
for i in range(num_experts): expert_input = input[mask[:, i]] expert_output = experts[i](expert_input) output[mask[:, i]] = expert_output * gate_scores[:, i]
class MoE(nn.Module): def __init__(self, hidden_size, experts=..., ep_size=..., k=1): self.experts = Experts() self.gate = TopKGate() ...
ep_size: 每个专家组的并行数(Expert Parallelism)。
k: top-k gating.
使用了通信优化如 All-to-All 分发样本。
class MoE(nn.Module): def __init__(self, hidden_size: int, expert: nn.Module, num_experts: int = 1, ep_size: int = 1, k: int = 1, capacity_factor: float = 1.0, eval_capacity_factor: float = 1.0, min_capacity: int = 4, use_residual: bool = False, noisy_gate_policy: Optional[str] = None, drop_tokens: bool = True, use_rts: bool = True, use_tutel: bool = False, enable_expert_tensor_parallelism: bool = False, top2_2nd_expert_sampling: bool = True)
hidden_size:输入和输出的维度;
expert:作为子模块传入的专家网络(如 MLP);
num_experts:专家总数;
ep_size:专家并行维度;
k:选用 top‑k 路由;
capacity_factor 和 eval_capacity_factor:训练/评估期间专家最大处理 token 数比例;
min_capacity:每个专家至少能接收的 token 数;
use_residual:是否启用 Residual MoE 结构;
noisy_gate_policy、drop_tokens、use_rts:路由噪声、token 丢弃、随机选择;
enable_expert_tensor_parallelism:专家参数 tensor 切分;
top2_2nd_expert_sampling:top‑2 第二专家采样策略。
def forward(self, hidden_states: Tensor, used_token: Optional[Tensor] = None) -> Tuple[Tensor, Tensor, Tensor]:
def forward(self, hidden_states, used_token=None): # 1️⃣ Gating 阶段:计算每个 token 的专家 logits,并选出 top‑k 专家 gates, load, indices, expert_capacity = self.gate( hidden_states, self.training) # gates: (B, k) 专家权重,load: auxiliary balance loss,indices: 专家索引 # 2️⃣ Capacity 控制:根据 capacity_factor 限制每个专家最多处理的 token 数 # 用 expert_capacity 来计算实际可接收的 token 数 # 3️⃣ Dispatch 阶段:将 token 分发给对应专家 dispatch_mask, combine_mask = create_masks(indices, expert_capacity) # dispatch_mask: 用于提取每个专家的 token # 重塑 hidden_states 方便通信: expert_inputs = torch.einsum("b h, b e -> e b h", hidden_states, dispatch_mask) # 4️⃣ all_to_all 分发 token:跨 GPU 路由 token 到对应专家所在 GPU expert_inputs = all_to_all(expert_inputs, self.expert_parallel_group) # 5️⃣ 专家计算阶段:每个专家在本地 receive 的 token 上执行 forward expert_outputs = self.experts(expert_inputs) # 6️⃣ all_to_all 收集结果:各专家输出回传给原始 GPU expert_outputs = all_to_all(expert_outputs, self.expert_parallel_group) # 7️⃣ 合并输出:将专家输出按照 token-group 重组回 batch 维度 output = torch.einsum("e b h, b e -> b h", expert_outputs, combine_mask) return output, load, indices.bincount(...)
class TopKGate(nn.Module): def forward(self, input): logits = self.w_gating(input) topk_vals, topk_indices = torch.topk(logits, k) ...
有 4 张 GPU,每张 GPU 上部署 2 个专家,共 8 个专家;
假设 batch 里的一部分 token 需要被路由到第 3 个专家(在 GPU 2 上),另一部分需要被送到第 6 个专家(在 GPU 4 上);
那么就必须跨 GPU 发送这些 token —— 所以通信效率就非常关键。
每个进程(GPU)上都有自己的输入 token;
每个进程根据路由器(Gate)的输出,将 token 分为 N 份,分别属于 N 个专家(也就是 N 个目标 GPU);
然后用 all_to_all 一次性把这 N 份数据分发到对应的 GPU 上;
所有专家完成前向计算后,再用 all_to_all 把结果发回。
def all_to_all(input, group): # input shape: (num_local_experts, tokens_per_expert, hidden_size) output = torch.empty_like(input) torch.distributed.all_to_all_single(output, input, group=group) return output
[local_experts, tokens_per_expert, hidden_dim]
tokens 分发阶段:将 token 从本地发送到它所需的专家所在的 GPU;
结果收集阶段:将处理后的输出再收集回原始的 GPU。
[local_experts, tokens_per_expert, hidden_dim]
tokens 分发阶段:将 token 从本地发送到它所需的专家所在的 GPU;
结果收集阶段:将处理后的输出再收集回原始的 GPU。
*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。