参考
一文搞定自注意力机制(Self-Attention)
Query(查询):表示当前的关注点或目标,体现了模型的主观意图,用于引导信息的聚焦。
Key(键):表示被比对的对象,是客观存在的特征向量,作为与Query进行匹配的依据。Key相当于数据的”索引”或”标识符”,帮助模型判断哪些信息与当前Query相关。
Value(值):表示与Key对应的详细信息,是实际的特征向量内容。当Query与某个Key匹配后,模型会提取对应的Value作为最终输出。Value可视为数据的”实际内容”或”特征表示”。
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 
 | import torchfrom torch import nn
 import torch.nn.functional as F
 
 
 class BahdanauAttention(nn.Module):
 def __init__(self, enc_hidden_size, dec_hidden_size):
 super().__init__()
 self.attn = nn.Linear(enc_hidden_size + dec_hidden_size, dec_hidden_size, bias=False)
 self.v = nn.Parameter(torch.rand(dec_hidden_size))
 
 def forward(self, hidden_dec, encoder_outputs):
 
 
 
 batch_size = encoder_outputs.size(0)
 src_len = encoder_outputs.size(1)
 
 
 hidden_dec_exp = hidden_dec.unsqueeze(1).repeat(1, src_len, 1)
 
 
 energy = torch.tanh(self.attn(torch.cat((hidden_dec_exp, encoder_outputs), dim=2)))
 v = self.v.unsqueeze(0).unsqueeze(1).repeat(batch_size, 1, 1)
 
 attn_scores = torch.bmm(v, energy.transpose(1, 2)).squeeze(1)
 attn_weights = F.softmax(attn_scores, dim=1)
 
 return attn_weights
 
 | 
| 12
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 
 | class DecoderWithAttention(nn.Module):def __init__(self, input_size, enc_hidden_size, dec_hidden_size, output_size):
 super().__init__()
 self.embedding = nn.Embedding(input_size, dec_hidden_size)
 self.attention = BahdanauAttention(enc_hidden_size, dec_hidden_size)
 self.rnn = nn.GRU(enc_hidden_size + dec_hidden_size, dec_hidden_size, batch_first=True)
 self.fc_out = nn.Linear(enc_hidden_size + dec_hidden_size * 2, output_size)
 
 def forward(self, input_token, hidden_dec, encoder_outputs):
 
 
 
 
 embedded = self.embedding(input_token).unsqueeze(1)
 
 
 attn_weights = self.attention(hidden_dec, encoder_outputs)
 
 
 context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs)
 
 
 rnn_input = torch.cat((embedded, context), dim=2)
 output, hidden = self.rnn(rnn_input, hidden_dec.unsqueeze(0))
 
 
 output = output.squeeze(1)
 context = context.squeeze(1)
 embedded = embedded.squeeze(1)
 
 final_out = self.fc_out(torch.cat((output, context, embedded), dim=1))
 
 return final_out, hidden.squeeze(0), attn_weights
 
 
 | 
自注意力机制Self-Attention
是一种将单个序列的不同位置关联起来以计算同一序列的表示的注意机制。通过建立全局依赖关系来扩大图像的感受野。由于在大数据的基础上才能有效建立准确的全局关系,故小数据的情况下效果不如CNN。
自注意力机制是通过筛选重要信息和过滤不重要信息实现的,这导致其有效信息的抓取能力比CNN更小。
相比于CNN,其感受野更大,可以获取更多上下文信息。通过筛选重要信息,过滤不重要信息实现的,这就导致其有效信息的抓取能力会比CNN小一些。必须通过大量数据进行学习。这就导致自注意力机制