Transformer 详解
Attention is all you need 是一篇将 Attention 思想发挥到极致的论文,出自 Google。这篇论文中提出一个全新的模型,叫 Transformer,抛弃了以往深度学习任务里面使用到的 CNN 和 RNN (其实也不完全是,还是用到了一维卷积)。这个模型广泛应用于 NLP 领域,例如机器翻译,问答系统,文本摘要和语音识别等等方向。
参考资料:
文章目录
1 Transformer 整体框架
和经典的 seq2seq 模型一样,Transformer 模型中也采用了 encoer-decoder 架构。上图的左半边用 NX 框出来的,就代表一层 encoder,其中论文里面的 encoder 一共有 6 层这样的结构。上图的右半边用 NX 框出来的,则代表一层 decoder,同样也有 6 层。
定义输入序列首先经过 word embedding,再和 positional encoding 相加后,输入到 encoder 中。输出序列经过的处理和输入序列一样,然后输入到 decoder。
最后,decoder 的输出经过一个线性层,再接 Softmax。
于上便是 Transformer 的整体框架,下面先来介绍 encoder 和 decoder。
1.1 Encoder
encoder 由 6 层相同的层组成,每一层分别由两部分组成:
第一部分是 multi-head self-attention
第二部分是 position-wise feed-forward network,是一个全连接层
两个部分,都有一个残差连接 (residual connection),然后接着一个 Layer Normalization。
1.2 Decoder
和 encoder 类似,decoder 也是由 6 个相同的层组成,每一个层包括以下 3 个部分:
第一个部分是 multi-head self-attention mechanism
第二部分是 multi-head context-attention mechanism
第三部分是一个 position-wise feed-forward network
和 encoder 一样,上面三个部分的每一个部分,都有一个残差连接,后接一个 Layer Normalization 。
decoder 和 encoder 不同的地方在 multi-head context-attention mechanism
1.3 Attention
我在以前的文章中讲过,Attention 如果用一句话来描述,那就是 encoder 层的输出经过加权平均后再输入到 decoder 层中。它主要应用在 seq2seq 模型中,这个加权可以用矩阵来表示,也叫 Attention 矩阵。它表示对于某个时刻的输出 y,它在输入 x 上各个部分的注意力。这个注意力就是我们刚才说到的加权。
Attention 又分为很多种,其中两种比较典型的有加性 Attention 和乘性 Attention。加性 Attention 对于输入的隐状态 直接做 concat 操作,乘性 Attention 则是对输入和输出做 dot 操作。
在 Google 这篇论文中,使用的 Attention 模型是乘性 Attention。
我在之前讲 ESIM 模型的文章里面写过一个 soft-align-attention,大家可以参考体会一下。
1.4 Self-Attention
上面我们说 attention 机制的时候,都会说到两个隐状态 hi 和 St,前者是输入序列第 i 个位置产生的隐状态,后者是输出序列在第 t 个位置产生的隐状态。所谓 self-attention 实际上就是,输出序列就是输入序列。因而自己计算自己的 attention 得分。
1.5 Context-Attention
context-attention 是 encoder 和 decoder 之间的 attention,是两个不同序列之间的 attention,与来源于自身的 self-attention 相区别。
不管是哪种 attention,我们在计算 attention 权重的时候,可以选择很多方式,常用的方法有
additive attention
local-base
general
dot-product
scaled dot-product
Transformer 模型采用的是最后一种:scaled dot-product attention。
1.6 Scaled Dot-Product Attention
那么什么是 scaled dot-product attention 呢?
Google 在论文中对 Attention 机制这么来描述:
An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.
通过 query 和 key 的相似性程度来确定 value 的权重分布。论文中的公式长下面这个样子:
看到 Q,K,V 会不会有点晕,没事,后面会解释。
scaled dot-product attention 的结构图如下所示。
现在来说下 K、Q、V 分别代表什么:
目前可能描述有有点抽象,不容易理解。结合一些应用来说,比如,如果是在自动问答任务中的话,Q 可以代表答案的词向量序列,取 K = V 为问题的词向量序列,那么输出就是所谓的 Aligned Question Embedding。
Google 论文的主要贡献之一是它表明了内部注意力在机器翻译 (甚至是一般的 Seq2Seq 任务)的序列编码上是相当重要的,而之前关于 Seq2Seq 的研究基本都只是把注意力机制用在解码端。
1.7 Scaled Dot-Product Attention 实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 import torchimport torch.nn as nnimport torch.functional as Fimport numpy as npclass ScaledDotProductAttention (nn.Module ): """Scaled dot-product attention mechanism.""" def __init__ (self, attention_dropout=0.0 ): super (ScaledDotProductAttention, self).__init__() self.dropout = nn.Dropout(attention_dropout) self.softmax = nn.Softmax(dim=2 ) def forward (self, q, k, v, scale=None , attn_mask=None ): """ 前向传播. Args: q: Queries张量,形状为[B, L_q, D_q] k: Keys张量,形状为[B, L_k, D_k] v: Values张量,形状为[B, L_v, D_v],一般来说就是k scale: 缩放因子,一个浮点标量 attn_mask: Masking张量,形状为[B, L_q, L_k] Returns: 上下文张量和attention张量 """ attention = torch.bmm(q, k.transpose(1 , 2 )) if scale: attention = attention * scale if attn_mask: attention = attention.masked_fill_(attn_mask, -np.inf) attention = self.softmax(attention) attention = self.dropout(attention) context = torch.bmm(attention, v) return context, attention
1.8 Multi-head attention
理解了 Scaled dot-product attention,Multi-head attention 也很容易理解啦。论文提到,他们发现将 Q、K、V 通过一个线性映射之后,分成 h 份,对每一份进行 scaled dot-product attention 效果更好。然后,把各个部分的结果合并起来,再次经过线性映射,得到最终的输出。这就是所谓的 multi-head attention。上面的超参数 h 就是 heads 的数量。论文默认是 8。
multi-head attention 的结构图如下所示。
1.9 Multi-head attention 实现
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 class MultiHeadAttention (nn.Module ): def __init__ (self, model_dim=512 , num_heads=8 , dropout=0.0 ): super (MultiHeadAttention, self).__init__() self.dim_per_head = model_dim // num_heads self.num_heads = num_heads self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads) self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads) self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads) self.dot_product_attention = ScaledDotProductAttention(dropout) self.linear_final = nn.Linear(model_dim, model_dim) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(model_dim) def forward (self, key, value, query, attn_mask=None ): residual = query dim_per_head = self.dim_per_head num_heads = self.num_heads batch_size = key.size(0 ) key = self.linear_k(key) value = self.linear_v(value) query = self.linear_q(query) key = key.view(batch_size * num_heads, -1 , dim_per_head) value = value.view(batch_size * num_heads, -1 , dim_per_head) query = query.view(batch_size * num_heads, -1 , dim_per_head) if attn_mask: attn_mask = attn_mask.repeat(num_heads, 1 , 1 ) scale = (key.size(-1 )) ** -0.5 context, attention = self.dot_product_attention( query, key, value, scale, attn_mask) context = context.view(batch_size, -1 , dim_per_head * num_heads) output = self.linear_final(context) output = self.dropout(output) output = self.layer_norm(residual + output) return output, attention
上面代码中出现的 Residual connection 我在之前一篇文章中讲过,这里不再赘述,只解释 Layer normalization。
1.10 Layer normalization
Normalization 有很多种,但是它们都有一个共同的目的,那就是把输入转化成均值为 0 方差为 1 的数据。我们在把数据送入激活函数之前进行 normalization(归一化),因为我们不希望输入数据落在激活函数的饱和区。
说到 normalization,那就肯定得提到 Batch Normalization。
BN 的主要思想就是:在每一层的每一批数据上进行归一化。我们可能会对输入数据进行归一化,但是经过该网络层的作用后,我们的数据已经不再是归一化的了。随着这种情况的发展,数据的偏差越来越大,我的反向传播需要考虑到这些大的偏差,这就迫使我们只能使用较小的学习率来防止梯度消失或者梯度爆炸。
BN 的具体做法就是对每一小批数据,在批这个方向上做归一化。如下图所示:
可以看到,右半边求均值是沿着数据 batch N 的方向进行的 !
Batch normalization 的计算公式如下:
那么什么是 Layer normalization 呢?它也是归一化数据的一种方式,不过 LN 是在每一个样本上计算均值和方差,而不是 BN 那种在批方向计算均值和方差 !
下面是 LN 的示意图:
和上面的 BN 示意图一比较就可以看出二者的区别啦!
下面看一下 LN 的公式:
1.11 Mask
mask 表示掩码,它对某些值进行掩盖,使其在参数更新时不产生效果。Transformer 模型里面涉及两种 mask,分别是 padding mask 和 sequence mask。
其中,padding mask 在所有的 scaled dot-product attention 里面都需要用到,而 sequence mask 只有在 decoder 的 self-attention 里面用到。
Padding Mask
什么是 padding mask 呢?因为每个批次输入序列长度是不一样的也就是说,我们要对输入序列进行对齐。具体来说,就是给在较短的序列后面填充 0。因为这些填充的位置,其实是没什么意义的,所以我们的 attention 机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。
具体的做法是,把这些位置的值加上一个非常大的负数 (负无穷),这样的话,经过 softmax,这些位置的概率就会接近 0!
而我们的 padding mask 实际上是一个张量,每个值都是一个 Boolean,值为 false 的地方就是我们要进行处理的地方。
实现:
1 2 3 4 5 6 7 def padding_mask (seq_k, seq_q ): len_q = seq_q.size(1 ) pad_mask = seq_k.eq(0 ) pad_mask = pad_mask.unsqueeze(1 ).expand(-1 , len_q, -1 ) return pad_mask
Sequence mask
文章前面也提到,sequence mask 是为了使得 decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。
那么具体怎么做呢?也很简单:产生一个上三角矩阵,上三角的值全为 1,下三角的值权威 0,对角线也是 0 。把这个矩阵作用在每一个序列上,就可以达到我们的目的啦。
具体的代码实现如下:
1 2 3 4 5 6 def sequence_mask(seq): batch_size, seq_len = seq.size() mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8), diagonal=1) mask = mask.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, L] return mask
效果如下,
对于 decoder 的 self-attention,里面使用到的 scaled dot-product attention,同时需要 padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个 mask 相加作为 attn_mask。
其他情况,attn_mask 一律等于 padding mask。
1.12 Positional Embedding
现在的 Transformer 架构还没有提取序列顺序的信息,这个信息对于序列而言非常重要,如果缺失了这个信息,可能我们的结果就是:所有词语都对了,但是无法组成有意义的语句。
为了解决这个问题。论文使用了 Positional Embedding:对序列中的词语出现的位置进行编码。
在实现的时候使用正余弦函数。公式如下:
其中,pos 是指词语在序列中的位置。可以看出,在偶数位置,使用正弦编码,在奇数位置,使用余弦编码 。
上面的位置编码是绝对位置编码 。但是词语的相对位置 也非常重要。这就是论文为什么要使用三角函数的原因!
正弦函数能够表达相对位置信息,主要数学依据是以下两个公式:
具体实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 class PositionalEncoding (nn.Module ): def __init__ (self, d_model, max_seq_len ): """初始化。 Args: d_model: 一个标量。模型的维度,论文默认是512 max_seq_len: 一个标量。文本序列的最大长度 """ super (PositionalEncoding, self).__init__() position_encoding = np.array([ [pos / np.power(10000 , 2.0 * (j // 2 ) / d_model) for j in range (d_model)] for pos in range (max_seq_len)]) position_encoding[:, 0 ::2 ] = np.sin(position_encoding[:, 0 ::2 ]) position_encoding[:, 1 ::2 ] = np.cos(position_encoding[:, 1 ::2 ]) pad_row = torch.zeros([1 , d_model]) position_encoding = torch.cat((pad_row, position_encoding)) self.position_encoding = nn.Embedding(max_seq_len + 1 , d_model) self.position_encoding.weight = nn.Parameter(position_encoding, requires_grad=False ) def forward (self, input_len ): """神经网络的前向传播。 Args: input_len: 一个张量,形状为[BATCH_SIZE, 1]。每一个张量的值代表这一批文本序列中对应的长度。 Returns: 返回这一批序列的位置编码,进行了对齐。 """ max_len = torch.max (input_len) tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor input_pos = tensor( [list (range (1 , len + 1 )) + [0 ] * (max_len - len ) for len in input_len]) return self.position_encoding(input_pos)
1.13 Position-wise Feed-Forward network
这是一个全连接网络,包含两个线性变换和一个非线性函数 (实际上就是 ReLU)。公式如下
这个线性变换在不同的位置都表现地一样,并且在不同的层之间使用不同的参数。
这里实现上用到了两个一维卷积。
实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 class PositionalWiseFeedForward (nn.Module ): def __init__ (self, model_dim=512 , ffn_dim=2048 , dropout=0.0 ): super (PositionalWiseFeedForward, self).__init__() self.w1 = nn.Conv1d(model_dim, ffn_dim, 1 ) self.w2 = nn.Conv1d(ffn_dim, model_dim, 1 ) self.dropout = nn.Dropout(dropout) self.layer_norm = nn.LayerNorm(model_dim) def forward (self, x ): output = x.transpose(1 , 2 ) output = self.w2(F.relu(self.w1(output))) output = self.dropout(output.transpose(1 , 2 )) output = self.layer_norm(x + output) return output
2 Transformer 的实现
现在可以开始完成 Transformer 模型的构建了,encoder 端和 decoder 端分别都有 6 层,实现如下,首先是
2.1 Encoder 端
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 class EncoderLayer (nn.Module ): """Encoder的一层。""" def __init__ (self, model_dim=512 , num_heads=8 , ffn_dim=2048 , dropout=0.0 ): super (EncoderLayer, self).__init__() self.attention = MultiHeadAttention(model_dim, num_heads, dropout) self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout) def forward (self, inputs, attn_mask=None ): context, attention = self.attention(inputs, inputs, inputs, padding_mask) output = self.feed_forward(context) return output, attention class Encoder (nn.Module ): """多层EncoderLayer组成Encoder。""" def __init__ (self, vocab_size, max_seq_len, num_layers=6 , model_dim=512 , num_heads=8 , ffn_dim=2048 , dropout=0.0 ): super (Encoder, self).__init__() self.encoder_layers = nn.ModuleList( [EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range (num_layers)]) self.seq_embedding = nn.Embedding(vocab_size + 1 , model_dim, padding_idx=0 ) self.pos_embedding = PositionalEncoding(model_dim, max_seq_len) def forward (self, inputs, inputs_len ): output = self.seq_embedding(inputs) output += self.pos_embedding(inputs_len) self_attention_mask = padding_mask(inputs, inputs) attentions = [] for encoder in self.encoder_layers: output, attention = encoder(output, self_attention_mask) attentions.append(attention) return output, attentions
2.2 Decoder 端
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 class DecoderLayer (nn.Module ): def __init__ (self, model_dim, num_heads=8 , ffn_dim=2048 , dropout=0.0 ): super (DecoderLayer, self).__init__() self.attention = MultiHeadAttention(model_dim, num_heads, dropout) self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout) def forward (self, dec_inputs, enc_outputs, self_attn_mask=None , context_attn_mask=None ): dec_output, self_attention = self.attention( dec_inputs, dec_inputs, dec_inputs, self_attn_mask) dec_output, context_attention = self.attention( enc_outputs, enc_outputs, dec_output, context_attn_mask) dec_output = self.feed_forward(dec_output) return dec_output, self_attention, context_attention class Decoder (nn.Module ): def __init__ (self, vocab_size, max_seq_len, num_layers=6 , model_dim=512 , num_heads=8 , ffn_dim=2048 , dropout=0.0 ): super (Decoder, self).__init__() self.num_layers = num_layers self.decoder_layers = nn.ModuleList( [DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in range (num_layers)]) self.seq_embedding = nn.Embedding(vocab_size + 1 , model_dim, padding_idx=0 ) self.pos_embedding = PositionalEncoding(model_dim, max_seq_len) def forward (self, inputs, inputs_len, enc_output, context_attn_mask=None ): output = self.seq_embedding(inputs) output += self.pos_embedding(inputs_len) self_attention_padding_mask = padding_mask(inputs, inputs) seq_mask = sequence_mask(inputs) self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0 ) self_attentions = [] context_attentions = [] for decoder in self.decoder_layers: output, self_attn, context_attn = decoder( output, enc_output, self_attn_mask, context_attn_mask) self_attentions.append(self_attn) context_attentions.append(context_attn) return output, self_attentions, context_attentions
组合一下
2.3 Transformer 模型
class Transformer(nn.Module):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 def __init__ (self, src_vocab_size, src_max_len, tgt_vocab_size, tgt_max_len, num_layers=6 , model_dim=512 , num_heads=8 , ffn_dim=2048 , dropout=0.2 ): super (Transformer, self).__init__() self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout) self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim, num_heads, ffn_dim, dropout) self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False ) self.softmax = nn.Softmax(dim=2 ) def forward (self, src_seq, src_len, tgt_seq, tgt_len ): context_attn_mask = padding_mask(tgt_seq, src_seq) output, enc_self_attn = self.encoder(src_seq, src_len) output, dec_self_attn, ctx_attn = self.decoder( tgt_seq, tgt_len, output, context_attn_mask) output = self.linear(output) output = self.softmax(output) return output, enc_self_attn, dec_self_attn, ctx_attn num_heads, ffn_dim, dropout) self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False ) self.softmax = nn.Softmax(dim=2 ) def forward (self, src_seq, src_len, tgt_seq, tgt_len ): context_attn_mask = padding_mask(tgt_seq, src_seq) output, enc_self_attn = self.encoder(src_seq, src_len) output, dec_self_attn, ctx_attn = self.decoder( tgt_seq, tgt_len, output, context_attn_mask) output = self.linear(output) output = self.softmax(output) return output, enc_self_attn, dec_self_attn, ctx_attn