参考
一文搞定自注意力机制(Self-Attention)
Query(查询):表示当前的关注点或目标,是模型在特定时刻的”需求”或”兴趣点”。Query体现了模型的主观意图,用于主动引导对信息的检索和聚焦。
Key(键):表示被比对的对象,是客观存在的特征向量,作为与Query进行匹配的依据。Key相当于数据的”索引”或”标识符”,帮助模型判断哪些信息与当前Query相关。
Value(值):表示与Key对应的详细信息,是实际的特征向量内容。当Query与某个Key匹配后,模型会提取对应的Value作为最终输出。Value可视为数据的”实际内容”或”特征表示”。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31
| import torch from torch import nn import torch.nn.functional as F
class BahdanauAttention(nn.Module): def __init__(self, enc_hidden_size, dec_hidden_size): super().__init__() self.attn = nn.Linear(enc_hidden_size + dec_hidden_size, dec_hidden_size, bias=False) self.v = nn.Parameter(torch.rand(dec_hidden_size))
def forward(self, hidden_dec, encoder_outputs):
batch_size = encoder_outputs.size(0) src_len = encoder_outputs.size(1)
hidden_dec_exp = hidden_dec.unsqueeze(1).repeat(1, src_len, 1)
energy = torch.tanh(self.attn(torch.cat((hidden_dec_exp, encoder_outputs), dim=2))) v = self.v.unsqueeze(0).unsqueeze(1).repeat(batch_size, 1, 1)
attn_scores = torch.bmm(v, energy.transpose(1, 2)).squeeze(1) attn_weights = F.softmax(attn_scores, dim=1)
return attn_weights
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| class DecoderWithAttention(nn.Module): def __init__(self, input_size, enc_hidden_size, dec_hidden_size, output_size): super().__init__() self.embedding = nn.Embedding(input_size, dec_hidden_size) self.attention = BahdanauAttention(enc_hidden_size, dec_hidden_size) self.rnn = nn.GRU(enc_hidden_size + dec_hidden_size, dec_hidden_size, batch_first=True) self.fc_out = nn.Linear(enc_hidden_size + dec_hidden_size * 2, output_size)
def forward(self, input_token, hidden_dec, encoder_outputs):
embedded = self.embedding(input_token).unsqueeze(1)
attn_weights = self.attention(hidden_dec, encoder_outputs)
context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
rnn_input = torch.cat((embedded, context), dim=2) output, hidden = self.rnn(rnn_input, hidden_dec.unsqueeze(0))
output = output.squeeze(1) context = context.squeeze(1) embedded = embedded.squeeze(1)
final_out = self.fc_out(torch.cat((output, context, embedded), dim=1))
return final_out, hidden.squeeze(0), attn_weights
|
自注意力机制Self-Attention
是一种将单个序列的不同位置关联起来以计算同一序列的表示的注意机制。是一种将单个序列的不同位置关联起来以计算同一序列的表示的注意机制。可以建立全局的依赖关系,扩大图像的感受野。相比于CNN,其感受野更大,可以获取更多上下文信息。通过筛选重要信息,过滤不重要信息实现的,这就导致其有效信息的抓取能力会比CNN小一些。必须通过大量数据进行学习。这就导致自注意力机制只有在大数据的基础上才能有效地建立准确的全局关系,而在小数据的情况下,其效果不如CNN。