MindSpore大语言模型实战
上QQ阅读APP看书,第一时间看更新

2.1.1 注意力机制

注意力(Attention)机制是深度学习中一个重要的模型组件,它允许模型集中关注输入序列的特定部分,从而更好地捕捉相关信息。这种机制在处理NLP任务时尤其有用,因为自然语言具有丰富的上下文和依赖关系。

在注意力机制中,有3个关键概念:查询(Query)、键(Key)和值(Value)。假设有一个将英语翻译为法语的机器翻译任务,该任务要将句子“I love cats”翻译为法语。我们使用基于注意力机制的模型进行翻译,其中注意力机制帮助模型关注与当前正在生成的法语单词最相关的英语部分。

①查询:在这个例子中,查询是当前解码器(Decoder)的隐藏状态或正在生成的法语单词的表示。它提供了一个指示模型需要关注哪些信息的信号。

②键:键是输入序列的表示,用于计算查询与输入序列的相似性。键帮助确定在给定上下文中输入序列的哪些部分是相关的。

③值:值是对输入序列的实际信息的表示。当计算注意力权重时,这些值将会被加权求和,从而形成最终的输出。

键和值都来自编码器(Encoder),编码器将英语句子转化为一系列特征向量。每个特征向量既是一个键,又是一个值。它们包含输入句子的语义和上下文信息。在注意力机制中,查询和键主要用于计算注意力权重,而值用于实际的信息传递和输出。

在文本翻译中,通常希望翻译后的句子与原始句子具有相同的意思。因此,在计算注意力权重时,查询一般与目标序列(即翻译后的句子)相关,而键与源序列(即翻译前的原始句子)相关。现在,让我们看看在生成法语单词时,注意力机制是如何工作的。假设要生成法语单词“j'aime”(我喜欢)。在这个时刻,解码器的查询表示正在生成“j'aime”单词。注意力机制会计算解码器的查询与编码器的每个键之间的相似度。相似度高的键对应的值将在注意力机制中得到更高的权重。这意味着模型会更关注与当前生成的法语单词最相关的英语部分。在这个例子中,注意力机制可能会给英语句子中的“love”和“cats”这两个键对应的值分配较高的权重。这意味着模型将更多地关注“love”和“cats”这两个英语单词对于翻译“j'aime”法语单词的贡献。通过对这些具有权重的值进行加权求和,模型得到一个上下文向量(Context Vector),其中包含与当前生成的法语单词相关的英语部分的信息。这个上下文向量将与解码器的其他输入结合,帮助生成下一个法语单词,直到完成整个翻译过程。

计算注意力权重即计算查询和键之间的相似度。常用的计算注意力权重的方法包括加性注意力(Additive Attention)和缩放点积注意力(Scaled Dot-Product Attention),本书主要介绍后者。从几何的角度来看,点积(Dot Product)表示一个向量在另一个向量方向上的投影。换句话说,点积反映了两个向量之间的相似度。为了消除查询()和键()本身的“大小”对相似度计算的影响,需要对点积结果除以进行缩放。

   (2.1)

为了将相似度限制在0~1,注意力机制将对除以后的点积结果进行归一化处理。常见的方法是通过对除以后的点积结果进行softmax操作,使注意力权重符合概率分布。

   (2.2)

代码2.1实现了缩放点积注意力的计算。调用代码2.1中的函数返回加权后的值(output)和注意力权重(attn)。

代码2.1 点积函数

import mindspore
from mindspore import nn
from mindspore import ops
from mindspore import Tensor
from mindspore import dtype as mstype
 
class ScaledDotProductAttention(nn.Cell):
     def __init__(self, dropout_p=0.):
          super().__init__()
          self.softmax = nn.Softmax()
          self.dropout = nn.Dropout(1-dropout_p)
          self.sqrt = ops.Sqrt()
 
     def construct(self, query, key, value, attn_mask=None):
          """scaled dot product attention"""
          # 计算scaling factor
          embed_size = query.shape[-1]
          scaling_factor = self.sqrt(Tensor(embed_size, mstype.float32))
 
          # 注意力权重计算
          # 计算查询和键之间的相似度,并除以scaling factor进行归一化
          attn = ops.matmul(query, key.swapaxes(-2, -1) / scaling_factor)
 
          # 注意力掩码机制
          if attn_mask is not None:
                attn = attn.masked_fill(attn_mask, -1e9)
 
          # softmax保证注意力权重范围为0~1
          attn = self.softmax(attn)
          # dropout
          attn = self.dropout(attn)
          # 对值进行加权
          output = ops.matmul(attn, value)
 
          return (output, attn)

在数据处理过程中,为了统一文本的长度,通常会使用占位符“<pad>”来填充一些较短的文本。例如,对于文本“Hello world!”可以进行填充操作,生成结果<bos> Hello world ! <eos> <pad> <pad>。然而,这些填充占位符<pad>本身并不含有任何信息,因此注意力机制不应考虑它们。为了解决这个问题,注意力机制引入了注意力掩码的概念,用于标识输入序列中的<pad>位置,并确保在计算注意力权重的过程中将这些位置对应的注意力权重设置为0。代码2.2实现了获取注意力掩码功能。

代码2.2 获取注意力掩码功能

def get_attn_pad_mask(seq_q, seq_k, pad_idx):
     """注意力掩码:识别输入序列中的<pad>占位符
 
     Args:
          seq_q (Tensor): query序列,shape = [batch size, query len]
          seq_k (Tensor): key序列,shape = [batch size, key len]
          pad_idx (Tensor): key序列中的<pad>占位符对应的数字索引
     """
     batch_size, len_q = seq_q.shape
     batch_size, len_k = seq_k.shape
 
     # 如果输入序列中元素对应<pad>占位符,则元素所在位置在掩码中对应元素为True
     # pad_attn_mask: [batch size, key len]
     pad_attn_mask = ops.equal(seq_k, pad_idx)
 
     # 增加额外的维度
     # pad_attn_mask: [batch size, 1, key len]
     pad_attn_mask = pad_attn_mask.expand_dims(1)
     # 将掩码广播到[batch size, query len, key len]
     pad_attn_mask = ops.broadcast_to(pad_attn_mask, (batch_size, len_q, len_k))
 
     return pad_attn_mask