参考
一文搞定自注意力机制(Self-Attention)

Query(查询):表示当前的关注点或目标,是模型在特定时刻的”需求”或”兴趣点”。Query体现了模型的主观意图,用于主动引导对信息的检索和聚焦。

Key(键):表示被比对的对象,是客观存在的特征向量,作为与Query进行匹配的依据。Key相当于数据的”索引”或”标识符”,帮助模型判断哪些信息与当前Query相关。

Value(值):表示与Key对应的详细信息,是实际的特征向量内容。当Query与某个Key匹配后,模型会提取对应的Value作为最终输出。Value可视为数据的”实际内容”或”特征表示”。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import torch
from torch import nn
import torch.nn.functional as F

# decoder(解码器),利用编码器(encoder)提取的特征表示,在每一个时间步生成目标序列中的一个词
class BahdanauAttention(nn.Module):
def __init__(self, enc_hidden_size, dec_hidden_size):
super().__init__()
self.attn = nn.Linear(enc_hidden_size + dec_hidden_size, dec_hidden_size, bias=False)
self.v = nn.Parameter(torch.rand(dec_hidden_size))

def forward(self, hidden_dec, encoder_outputs):
# hidden_dec: [batch, dec_hidden] (当前解码器隐状态)
# encoder_outputs: [batch, src_len, enc_hidden] (所有时间步编码器输出)

batch_size = encoder_outputs.size(0)
src_len = encoder_outputs.size(1)

# 扩展 decoder hidden
hidden_dec_exp = hidden_dec.unsqueeze(1).repeat(1, src_len, 1)

# 拼接并计算注意力得分
energy = torch.tanh(self.attn(torch.cat((hidden_dec_exp, encoder_outputs), dim=2))) # [batch, src_len, dec_hidden]
v = self.v.unsqueeze(0).unsqueeze(1).repeat(batch_size, 1, 1) # [batch, 1, dec_hidden]

attn_scores = torch.bmm(v, energy.transpose(1, 2)).squeeze(1) # [batch, src_len]
attn_weights = F.softmax(attn_scores, dim=1) # [batch, src_len]

return attn_weights


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class DecoderWithAttention(nn.Module):
def __init__(self, input_size, enc_hidden_size, dec_hidden_size, output_size):
super().__init__()
self.embedding = nn.Embedding(input_size, dec_hidden_size)
self.attention = BahdanauAttention(enc_hidden_size, dec_hidden_size)
self.rnn = nn.GRU(enc_hidden_size + dec_hidden_size, dec_hidden_size, batch_first=True)
self.fc_out = nn.Linear(enc_hidden_size + dec_hidden_size * 2, output_size)

def forward(self, input_token, hidden_dec, encoder_outputs):
# input_token: [batch] 当前输入 token 的索引
# hidden_dec: [batch, dec_hidden]
# encoder_outputs: [batch, src_len, enc_hidden]

embedded = self.embedding(input_token).unsqueeze(1) # [batch, 1, dec_hidden]

# Attention weights: [batch, src_len]
attn_weights = self.attention(hidden_dec, encoder_outputs)

# 上下文向量:context = ∑ attention_weights * encoder_outputs
context = torch.bmm(attn_weights.unsqueeze(1), encoder_outputs) # [batch, 1, enc_hidden]

# 合并 context 和嵌入后输入 RNN
rnn_input = torch.cat((embedded, context), dim=2) # [batch, 1, enc_hidden + dec_hidden]
output, hidden = self.rnn(rnn_input, hidden_dec.unsqueeze(0))

# 拼接 context、rnn output、embedded 作为最终输出的输入
output = output.squeeze(1) # [batch, dec_hidden]
context = context.squeeze(1) # [batch, enc_hidden]
embedded = embedded.squeeze(1) # [batch, dec_hidden]

final_out = self.fc_out(torch.cat((output, context, embedded), dim=1)) # [batch, output_size]

return final_out, hidden.squeeze(0), attn_weights

自注意力机制Self-Attention

是一种将单个序列的不同位置关联起来以计算同一序列的表示的注意机制。是一种将单个序列的不同位置关联起来以计算同一序列的表示的注意机制。可以建立全局的依赖关系,扩大图像的感受野。相比于CNN,其感受野更大,可以获取更多上下文信息。通过筛选重要信息,过滤不重要信息实现的,这就导致其有效信息的抓取能力会比CNN小一些。必须通过大量数据进行学习。这就导致自注意力机制只有在大数据的基础上才能有效地建立准确的全局关系,而在小数据的情况下,其效果不如CNN。