"); //-->
需要 重新计算所有之前 token 的 K 和 V,并与当前 token 进行注意力计算。
计算复杂度是 O(n²)(对于长度为 n 的序列)。
只需计算 新 token 的 K 和 V,然后将其与缓存的值结合使用。
计算复杂度下降到 O(n)(每个 token 只与之前缓存的 token 计算注意力)。
初始输入: [t0, t1, t2] 首次计算: K=[K0,K1,K2], V=[V0,V1,V2] → 生成t3 缓存状态: K=[K0,K1,K2], V=[V0,V1,V2] 第二次计算: 新Q=Q3 注意力计算: Attention(Q3, [K0,K1,K2]) → 生成t4 更新缓存: K=[K0,K1,K2,K3], V=[V0,V1,V2,V3] 第三次计算: 新Q=Q4 注意力计算: Attention(Q4, [K0,K1,K2,K3]) → 生成t5 更新缓存: K=[K0,K1,K2,K3,K4], V=[V0,V1,V2,V3,V4] ...
import torch import torch.nn as nn import torch.nn.functional as F class SelfAttention(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads # 定义Q、K、V投影矩阵 self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) def forward(self, x): batch_size, seq_len, embed_dim = x.shape # 计算Q、K、V q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # 计算注意力分数 attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) attn_probs = F.softmax(attn_scores, dim=-1) # 应用注意力权重 output = attn_probs @ v output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) return self.out_proj(output)
class CachedSelfAttention(nn.Module): def __init__(self, embed_dim, num_heads): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads # 定义投影矩阵 self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) # 初始化缓存 self.cache_k = None self.cache_v = None def forward(self, x, use_cache=False): batch_size, seq_len, embed_dim = x.shape # 计算Q、K、V q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # 如果使用缓存且缓存存在,则拼接历史KV if use_cache and self.cache_k is not None: k = torch.cat([self.cache_k, k], dim=-2) v = torch.cat([self.cache_v, v], dim=-2) # 如果使用缓存,更新缓存 if use_cache: self.cache_k = k self.cache_v = v # 计算注意力分数(注意这里的k是包含历史缓存的) attn_scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) attn_probs = F.softmax(attn_scores, dim=-1) # 应用注意力权重 output = attn_probs @ v output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) return self.out_proj(output) def reset_cache(self): """重置缓存,用于新序列的生成""" self.cache_k = None self.cache_v = None
def generate_text(model, input_ids, max_length=50): # 初始化模型缓存 model.reset_cache() # 处理初始输入 output = model(input_ids, use_cache=True) next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True) generated = [next_token] # 生成后续token for _ in range(max_length - 1): # 只输入新生成的token output = model(next_token, use_cache=True) next_token = torch.argmax(output[:, -1, :], dim=-1, keepdim=True) generated.append(next_token) # 如果生成结束符则停止 if next_token.item() == 102: # 假设102是[SEP]的id break return torch.cat(generated, dim=1)
*博客内容为网友个人发布,仅代表博主个人观点,如有侵权请联系工作人员删除。