从零实现自注意力与交叉注意力:PyTorch实战指南

张开发
2026/4/12 20:29:51 15 分钟阅读

分享文章

从零实现自注意力与交叉注意力:PyTorch实战指南
1. 注意力机制基础概念在深度学习领域注意力机制已经成为现代神经网络架构的核心组件之一。简单来说注意力机制允许模型在处理输入序列时动态地关注与当前任务最相关的部分。想象一下你在阅读一本书时会不自觉地对某些关键段落给予更多关注这就是人类注意力机制的表现。自注意力Self-Attention和交叉注意力Cross-Attention是两种最常见的注意力机制变体。它们的主要区别在于信息的来源自注意力中查询(Query)、键(Key)和值(Value)都来自同一个输入序列交叉注意力中查询来自一个序列而键和值来自另一个序列这两种机制都基于相似的计算原理首先计算查询与键的相似度得分然后使用softmax归一化这些得分得到注意力权重最后用这些权重对值进行加权求和。这个过程可以用以下公式表示Attention(Q,K,V) softmax(QK^T/√d_k)V其中d_k是键向量的维度缩放因子√d_k用于防止点积过大导致softmax梯度消失。2. 自注意力机制实现详解2.1 自注意力核心组件让我们从零开始构建一个完整的自注意力模块。首先需要理解几个关键组件查询(Query)、键(Key)、值(Value)投影将输入映射到不同的表示空间注意力得分计算查询和键的点积衡量它们之间的相关性缩放和softmax归一化得分得到注意力权重加权求和用注意力权重对值进行加权组合在PyTorch中我们可以这样实现import torch import torch.nn as nn import torch.nn.functional as F class SelfAttention(nn.Module): def __init__(self, embed_size, heads): super(SelfAttention, self).__init__() self.embed_size embed_size self.heads heads self.head_dim embed_size // heads assert ( self.head_dim * heads embed_size ), Embedding size needs to be divisible by heads self.values nn.Linear(self.head_dim, self.head_dim, biasFalse) self.keys nn.Linear(self.head_dim, self.head_dim, biasFalse) self.queries nn.Linear(self.head_dim, self.head_dim, biasFalse) self.fc_out nn.Linear(heads * self.head_dim, embed_size) def forward(self, values, keys, query, mask): N query.shape[0] value_len, key_len, query_len values.shape[1], keys.shape[1], query.shape[1] # Split the embedding into self.heads pieces values values.reshape(N, value_len, self.heads, self.head_dim) keys keys.reshape(N, key_len, self.heads, self.head_dim) queries query.reshape(N, query_len, self.heads, self.head_dim) values self.values(values) keys self.keys(keys) queries self.queries(queries) energy torch.einsum(nqhd,nkhd-nhqk, [queries, keys]) if mask is not None: energy energy.masked_fill(mask 0, float(-1e20)) attention torch.softmax(energy / (self.embed_size ** (1/2)), dim3) out torch.einsum(nhql,nlhd-nqhd, [attention, values]).reshape( N, query_len, self.heads * self.head_dim ) out self.fc_out(out) return out2.2 多头注意力机制在实际应用中我们通常会使用多头注意力Multi-Head Attention它将注意力机制并行执行多次然后将结果拼接起来。这样做的好处是模型可以同时关注来自不同位置的不同表示子空间的信息。多头注意力的关键参数包括头数(heads)通常设置为4-8个每个头的维度(head_dim)通常是总嵌入维度除以头数缩放因子(scale)1/√d_k用于稳定梯度3. 交叉注意力机制实现3.1 交叉注意力与自注意力的区别交叉注意力与自注意力的主要区别在于信息的来源不同。在交叉注意力中查询(Query)来自一个序列通常是解码器的当前状态键(Key)和值(Value)来自另一个序列通常是编码器的输出这种机制在序列到序列任务如机器翻译中特别有用它允许解码器在生成每个词时有选择地关注编码器输出中最相关的部分。3.2 PyTorch实现下面是交叉注意力的完整实现class CrossAttention(nn.Module): def __init__(self, query_dim, context_dimNone, heads8, dim_head64, dropout0.): super().__init__() inner_dim dim_head * heads context_dim context_dim if context_dim is not None else query_dim self.scale dim_head ** -0.5 self.heads heads self.to_q nn.Linear(query_dim, inner_dim, biasFalse) self.to_k nn.Linear(context_dim, inner_dim, biasFalse) self.to_v nn.Linear(context_dim, inner_dim, biasFalse) self.to_out nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) ) def forward(self, x, contextNone, maskNone): h self.heads context x if context is None else context q self.to_q(x) k self.to_k(context) v self.to_v(context) q, k, v map(lambda t: rearrange(t, b n (h d) - (b h) n d, hh), (q, k, v)) sim torch.einsum(b i d, b j d - b i j, q, k) * self.scale if mask is not None: mask rearrange(mask, b ... - b (...)) max_neg_value -torch.finfo(sim.dtype).max mask repeat(mask, b j - (b h) () j, hh) sim.masked_fill_(~mask, max_neg_value) attn sim.softmax(dim-1) out torch.einsum(b i j, b j d - b i d, attn, v) out rearrange(out, (b h) n d - b n (h d), hh) return self.to_out(out)4. 实际应用与性能优化4.1 两种注意力机制的应用场景自注意力和交叉注意力在实际中有不同的应用场景机制典型应用特点自注意力BERT, GPT等语言模型捕捉序列内部关系交叉注意力机器翻译, 图像描述生成建立不同模态或序列间联系4.2 性能优化技巧在实际实现中有几点可以显著提升注意力机制的性能内存优化使用分块计算处理长序列计算优化利用Flash Attention等优化算法稀疏注意力只计算最重要的注意力对缓存机制在推理时缓存键值对例如我们可以实现一个带缓存的交叉注意力版本class CrossAttentionWithCache(nn.Module): def __init__(self, query_dim, context_dim, heads8, dim_head64): super().__init__() # 初始化代码与之前相同... self.cache_k None self.cache_v None def forward(self, x, contextNone, use_cacheFalse): if use_cache and self.cache_k is not None: # 使用缓存的k和v k, v self.cache_k, self.cache_v else: # 正常计算k和v k self.to_k(context) v self.to_v(context) if use_cache: self.cache_k, self.cache_v k, v # 其余计算与之前相同...5. 调试与常见问题解决实现注意力机制时经常会遇到一些问题这里分享几个调试经验注意力权重全为1/n通常是因为输入值太大导致softmax饱和确保使用了正确的缩放因子梯度消失检查注意力得分的范围确保不是太大或太小内存溢出对于长序列考虑使用内存高效的注意力实现训练不稳定添加层归一化(LayerNorm)通常有帮助一个实用的调试技巧是可视化注意力权重def plot_attention(attention_weights): import matplotlib.pyplot as plt plt.imshow(attention_weights.cpu().detach().numpy(), cmaphot) plt.colorbar() plt.show() # 在forward方法中添加 if debug: plot_attention(attn[0,0]) # 显示第一个batch第一个头的注意力6. 完整模型集成现在我们将实现的注意力模块集成到一个完整的Transformer块中class TransformerBlock(nn.Module): def __init__(self, embed_size, heads, dropout, forward_expansion): super(TransformerBlock, self).__init__() self.attention SelfAttention(embed_size, heads) self.norm1 nn.LayerNorm(embed_size) self.norm2 nn.LayerNorm(embed_size) self.feed_forward nn.Sequential( nn.Linear(embed_size, forward_expansion * embed_size), nn.ReLU(), nn.Linear(forward_expansion * embed_size, embed_size) ) self.dropout nn.Dropout(dropout) def forward(self, value, key, query, mask): attention self.attention(value, key, query, mask) x self.dropout(self.norm1(attention query)) forward self.feed_forward(x) out self.dropout(self.norm2(forward x)) return out在实际项目中我发现正确初始化注意力层的权重非常重要。通常使用Xavier初始化或者更小的标准差如0.02可以带来更好的训练稳定性。另外对于不同的任务可能需要调整注意力的温度参数softmax前的缩放因子以获得最佳性能。

更多文章