Transformer 详解

Attention is all you need 是一篇将 Attention 思想发挥到极致的论文,出自 Google。这篇论文中提出一个全新的模型,叫 Transformer,抛弃了以往深度学习任务里面使用到的 CNN 和 RNN (其实也不完全是,还是用到了一维卷积)。这个模型广泛应用于 NLP 领域,例如机器翻译,问答系统,文本摘要和语音识别等等方向。

参考资料:


文章目录


1 Transformer 整体框架

在这里插入图片描述

和经典的 seq2seq 模型一样,Transformer 模型中也采用了 encoer-decoder 架构。上图的左半边用 NX 框出来的,就代表一层 encoder,其中论文里面的 encoder 一共有 6 层这样的结构。上图的右半边用 NX 框出来的,则代表一层 decoder,同样也有 6 层。

定义输入序列首先经过 word embedding,再和 positional encoding 相加后,输入到 encoder 中。输出序列经过的处理和输入序列一样,然后输入到 decoder。

最后,decoder 的输出经过一个线性层,再接 Softmax。

于上便是 Transformer 的整体框架,下面先来介绍 encoder 和 decoder。

1.1 Encoder

encoder 由 6 层相同的层组成,每一层分别由两部分组成:

  • 第一部分是 multi-head self-attention
  • 第二部分是 position-wise feed-forward network,是一个全连接层

两个部分,都有一个残差连接 (residual connection),然后接着一个 Layer Normalization。

1.2 Decoder

和 encoder 类似,decoder 也是由 6 个相同的层组成,每一个层包括以下 3 个部分:

  • 第一个部分是 multi-head self-attention mechanism
  • 第二部分是 multi-head context-attention mechanism
  • 第三部分是一个 position-wise feed-forward network

和 encoder 一样,上面三个部分的每一个部分,都有一个残差连接,后接一个 Layer Normalization

decoder 和 encoder 不同的地方在 multi-head context-attention mechanism

1.3 Attention

我在以前的文章中讲过,Attention 如果用一句话来描述,那就是 encoder 层的输出经过加权平均后再输入到 decoder 层中。它主要应用在 seq2seq 模型中,这个加权可以用矩阵来表示,也叫 Attention 矩阵。它表示对于某个时刻的输出 y,它在输入 x 上各个部分的注意力。这个注意力就是我们刚才说到的加权。

Attention 又分为很多种,其中两种比较典型的有加性 Attention 和乘性 Attention。加性 Attention 对于输入的隐状态 直接做 concat 操作,乘性 Attention 则是对输入和输出做 dot 操作。

在 Google 这篇论文中,使用的 Attention 模型是乘性 Attention。

我在之前讲 ESIM 模型的文章里面写过一个 soft-align-attention,大家可以参考体会一下。

1.4 Self-Attention

上面我们说 attention 机制的时候,都会说到两个隐状态 hi 和 St,前者是输入序列第 i 个位置产生的隐状态,后者是输出序列在第 t 个位置产生的隐状态。所谓 self-attention 实际上就是,输出序列就是输入序列。因而自己计算自己的 attention 得分。

1.5 Context-Attention

context-attention 是 encoder 和 decoder 之间的 attention,是两个不同序列之间的 attention,与来源于自身的 self-attention 相区别。

不管是哪种 attention,我们在计算 attention 权重的时候,可以选择很多方式,常用的方法有

  • additive attention
  • local-base
  • general
  • dot-product
  • scaled dot-product

Transformer 模型采用的是最后一种:scaled dot-product attention。

1.6 Scaled Dot-Product Attention

那么什么是 scaled dot-product attention 呢?

Google 在论文中对 Attention 机制这么来描述:

An attention function can be described as a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility of the query with the corresponding key.

通过 query 和 key 的相似性程度来确定 value 的权重分布。论文中的公式长下面这个样子:

在这里插入图片描述

看到 Q,K,V 会不会有点晕,没事,后面会解释。

在这里插入图片描述

scaled dot-product attention 的结构图如下所示。

在这里插入图片描述

现在来说下 K、Q、V 分别代表什么:

在这里插入图片描述

目前可能描述有有点抽象,不容易理解。结合一些应用来说,比如,如果是在自动问答任务中的话,Q 可以代表答案的词向量序列,取 K = V 为问题的词向量序列,那么输出就是所谓的 Aligned Question Embedding。

Google 论文的主要贡献之一是它表明了内部注意力在机器翻译 (甚至是一般的 Seq2Seq 任务)的序列编码上是相当重要的,而之前关于 Seq2Seq 的研究基本都只是把注意力机制用在解码端。

1.7 Scaled Dot-Product Attention 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
import torch.functional as F
import numpy as np

class ScaledDotProductAttention(nn.Module):
"""Scaled dot-product attention mechanism."""

def __init__(self, attention_dropout=0.0):
super(ScaledDotProductAttention, self).__init__()
self.dropout = nn.Dropout(attention_dropout)
self.softmax = nn.Softmax(dim=2)

def forward(self, q, k, v, scale=None, attn_mask=None):
"""
前向传播.
Args:
q: Queries张量,形状为[B, L_q, D_q]
k: Keys张量,形状为[B, L_k, D_k]
v: Values张量,形状为[B, L_v, D_v],一般来说就是k
scale: 缩放因子,一个浮点标量
attn_mask: Masking张量,形状为[B, L_q, L_k]

Returns:
上下文张量和attention张量
"""
attention = torch.bmm(q, k.transpose(1, 2))
if scale:
attention = attention * scale
if attn_mask:
# 给需要 mask 的地方设置一个负无穷
attention = attention.masked_fill_(attn_mask, -np.inf)
# 计算softmax
attention = self.softmax(attention)
# 添加dropout
attention = self.dropout(attention)
# 和V做点积
context = torch.bmm(attention, v)
return context, attention

1.8 Multi-head attention

理解了 Scaled dot-product attention,Multi-head attention 也很容易理解啦。论文提到,他们发现将 Q、K、V 通过一个线性映射之后,分成 h 份,对每一份进行 scaled dot-product attention 效果更好。然后,把各个部分的结果合并起来,再次经过线性映射,得到最终的输出。这就是所谓的 multi-head attention。上面的超参数 h 就是 heads 的数量。论文默认是 8。

multi-head attention 的结构图如下所示。

在这里插入图片描述

在这里插入图片描述

1.9 Multi-head attention 实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
class MultiHeadAttention(nn.Module):

def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
super(MultiHeadAttention, self).__init__()

self.dim_per_head = model_dim // num_heads
self.num_heads = num_heads
self.linear_k = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.linear_v = nn.Linear(model_dim, self.dim_per_head * num_heads)
self.linear_q = nn.Linear(model_dim, self.dim_per_head * num_heads)

self.dot_product_attention = ScaledDotProductAttention(dropout)
self.linear_final = nn.Linear(model_dim, model_dim)
self.dropout = nn.Dropout(dropout)

# multi-head attention之后需要做layer norm
self.layer_norm = nn.LayerNorm(model_dim)

def forward(self, key, value, query, attn_mask=None):
# 残差连接
residual = query
dim_per_head = self.dim_per_head
num_heads = self.num_heads
batch_size = key.size(0)

# linear projection
key = self.linear_k(key)
value = self.linear_v(value)
query = self.linear_q(query)

# split by heads
key = key.view(batch_size * num_heads, -1, dim_per_head)
value = value.view(batch_size * num_heads, -1, dim_per_head)
query = query.view(batch_size * num_heads, -1, dim_per_head)

if attn_mask:
attn_mask = attn_mask.repeat(num_heads, 1, 1)

# scaled dot product attention
scale = (key.size(-1)) ** -0.5
context, attention = self.dot_product_attention(
query, key, value, scale, attn_mask)

# concat heads
context = context.view(batch_size, -1, dim_per_head * num_heads)

# final linear projection
output = self.linear_final(context)

# dropout
output = self.dropout(output)

# add residual and norm layer
output = self.layer_norm(residual + output)

return output, attention

上面代码中出现的 Residual connection 我在之前一篇文章中讲过,这里不再赘述,只解释 Layer normalization。

1.10 Layer normalization

Normalization 有很多种,但是它们都有一个共同的目的,那就是把输入转化成均值为 0 方差为 1 的数据。我们在把数据送入激活函数之前进行 normalization(归一化),因为我们不希望输入数据落在激活函数的饱和区。

说到 normalization,那就肯定得提到 Batch Normalization。

BN 的主要思想就是:在每一层的每一批数据上进行归一化。我们可能会对输入数据进行归一化,但是经过该网络层的作用后,我们的数据已经不再是归一化的了。随着这种情况的发展,数据的偏差越来越大,我的反向传播需要考虑到这些大的偏差,这就迫使我们只能使用较小的学习率来防止梯度消失或者梯度爆炸。

BN 的具体做法就是对每一小批数据,在批这个方向上做归一化。如下图所示:

在这里插入图片描述

可以看到,右半边求均值是沿着数据 batch N 的方向进行的

Batch normalization 的计算公式如下:

在这里插入图片描述

那么什么是 Layer normalization 呢?它也是归一化数据的一种方式,不过 LN 是在每一个样本上计算均值和方差,而不是 BN 那种在批方向计算均值和方差

下面是 LN 的示意图:

在这里插入图片描述

和上面的 BN 示意图一比较就可以看出二者的区别啦!

下面看一下 LN 的公式:

在这里插入图片描述

1.11 Mask

mask 表示掩码,它对某些值进行掩盖,使其在参数更新时不产生效果。Transformer 模型里面涉及两种 mask,分别是 padding mask 和 sequence mask。

其中,padding mask 在所有的 scaled dot-product attention 里面都需要用到,而 sequence mask 只有在 decoder 的 self-attention 里面用到。

Padding Mask

什么是 padding mask 呢?因为每个批次输入序列长度是不一样的也就是说,我们要对输入序列进行对齐。具体来说,就是给在较短的序列后面填充 0。因为这些填充的位置,其实是没什么意义的,所以我们的 attention 机制不应该把注意力放在这些位置上,所以我们需要进行一些处理。

具体的做法是,把这些位置的值加上一个非常大的负数 (负无穷),这样的话,经过 softmax,这些位置的概率就会接近 0!

而我们的 padding mask 实际上是一个张量,每个值都是一个 Boolean,值为 false 的地方就是我们要进行处理的地方。

实现:

1
2
3
4
5
6
7
def padding_mask(seq_k, seq_q):
# seq_k 和 seq_q 的形状都是 [B,L]
len_q = seq_q.size(1)
# `PAD` is 0
pad_mask = seq_k.eq(0)
pad_mask = pad_mask.unsqueeze(1).expand(-1, len_q, -1) # shape [B, L_q, L_k]
return pad_mask

Sequence mask

文章前面也提到,sequence mask 是为了使得 decoder 不能看见未来的信息。也就是对于一个序列,在 time_step 为 t 的时刻,我们的解码输出应该只能依赖于 t 时刻之前的输出,而不能依赖 t 之后的输出。因此我们需要想一个办法,把 t 之后的信息给隐藏起来。

那么具体怎么做呢?也很简单:产生一个上三角矩阵,上三角的值全为 1,下三角的值权威 0,对角线也是 0。把这个矩阵作用在每一个序列上,就可以达到我们的目的啦。

具体的代码实现如下:

1
2
3
4
5
6
def sequence_mask(seq):
batch_size, seq_len = seq.size()
mask = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.uint8),
diagonal=1)
mask = mask.unsqueeze(0).expand(batch_size, -1, -1) # [B, L, L]
return mask

效果如下,

在这里插入图片描述
  • 对于 decoder 的 self-attention,里面使用到的 scaled dot-product attention,同时需要 padding mask 和 sequence mask 作为 attn_mask,具体实现就是两个 mask 相加作为 attn_mask。
  • 其他情况,attn_mask 一律等于 padding mask。

1.12 Positional Embedding

现在的 Transformer 架构还没有提取序列顺序的信息,这个信息对于序列而言非常重要,如果缺失了这个信息,可能我们的结果就是:所有词语都对了,但是无法组成有意义的语句。

为了解决这个问题。论文使用了 Positional Embedding:对序列中的词语出现的位置进行编码。

在实现的时候使用正余弦函数。公式如下:

在这里插入图片描述

其中,pos 是指词语在序列中的位置。可以看出,在偶数位置,使用正弦编码,在奇数位置,使用余弦编码

在这里插入图片描述

上面的位置编码是绝对位置编码。但是词语的相对位置也非常重要。这就是论文为什么要使用三角函数的原因!

正弦函数能够表达相对位置信息,主要数学依据是以下两个公式:

在这里插入图片描述

在这里插入图片描述

具体实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
class PositionalEncoding(nn.Module):

def __init__(self, d_model, max_seq_len):
"""初始化。
Args:
d_model: 一个标量。模型的维度,论文默认是512
max_seq_len: 一个标量。文本序列的最大长度
"""
super(PositionalEncoding, self).__init__()

# 根据论文给的公式,构造出PE矩阵
position_encoding = np.array([
[pos / np.power(10000, 2.0 * (j // 2) / d_model) for j in range(d_model)]
for pos in range(max_seq_len)])
# 偶数列使用sin,奇数列使用cos
position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])

# 在PE矩阵的第一行,加上一行全是0的向量,代表这`PAD`的positional encoding
# 在word embedding中也经常会加上`UNK`,代表位置单词的word embedding,两者十分类似
# 那么为什么需要这个额外的PAD的编码呢?很简单,因为文本序列的长度不一,我们需要对齐,
# 短的序列我们使用0在结尾补全,我们也需要这些补全位置的编码,也就是`PAD`对应的位置编码
pad_row = torch.zeros([1, d_model])
position_encoding = torch.cat((pad_row, position_encoding))

# 嵌入操作,+1是因为增加了`PAD`这个补全位置的编码,
# Word embedding中如果词典增加`UNK`,我们也需要+1。看吧,两者十分相似
self.position_encoding = nn.Embedding(max_seq_len + 1, d_model)
self.position_encoding.weight = nn.Parameter(position_encoding,
requires_grad=False)
def forward(self, input_len):
"""神经网络的前向传播。

Args:
input_len: 一个张量,形状为[BATCH_SIZE, 1]。每一个张量的值代表这一批文本序列中对应的长度。

Returns:
返回这一批序列的位置编码,进行了对齐。
"""

# 找出这一批序列的最大长度
max_len = torch.max(input_len)
tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor
# 对每一个序列的位置进行对齐,在原序列位置的后面补上0
# 这里range从1开始也是因为要避开PAD(0)的位置
input_pos = tensor(
[list(range(1, len + 1)) + [0] * (max_len - len) for len in input_len])
return self.position_encoding(input_pos)

1.13 Position-wise Feed-Forward network

这是一个全连接网络,包含两个线性变换和一个非线性函数 (实际上就是 ReLU)。公式如下

在这里插入图片描述

这个线性变换在不同的位置都表现地一样,并且在不同的层之间使用不同的参数。

这里实现上用到了两个一维卷积。

实现如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class PositionalWiseFeedForward(nn.Module):

def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
super(PositionalWiseFeedForward, self).__init__()
self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
self.w2 = nn.Conv1d(ffn_dim, model_dim, 1)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(model_dim)

def forward(self, x):
output = x.transpose(1, 2)
output = self.w2(F.relu(self.w1(output)))
output = self.dropout(output.transpose(1, 2))

# add residual and norm layer
output = self.layer_norm(x + output)
return output

2 Transformer 的实现

现在可以开始完成 Transformer 模型的构建了,encoder 端和 decoder 端分别都有 6 层,实现如下,首先是

2.1 Encoder 端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class EncoderLayer(nn.Module):
"""Encoder的一层。"""

def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
super(EncoderLayer, self).__init__()

self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

def forward(self, inputs, attn_mask=None):

# self attention
context, attention = self.attention(inputs, inputs, inputs, padding_mask)

# feed forward network
output = self.feed_forward(context)

return output, attention


class Encoder(nn.Module):
"""多层EncoderLayer组成Encoder。"""

def __init__(self,
vocab_size,
max_seq_len,
num_layers=6,
model_dim=512,
num_heads=8,
ffn_dim=2048,
dropout=0.0):
super(Encoder, self).__init__()

self.encoder_layers = nn.ModuleList(
[EncoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
range(num_layers)])

self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

def forward(self, inputs, inputs_len):
output = self.seq_embedding(inputs)
output += self.pos_embedding(inputs_len)

self_attention_mask = padding_mask(inputs, inputs)

attentions = []
for encoder in self.encoder_layers:
output, attention = encoder(output, self_attention_mask)
attentions.append(attention)

return output, attentions

2.2 Decoder 端

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class DecoderLayer(nn.Module):

def __init__(self, model_dim, num_heads=8, ffn_dim=2048, dropout=0.0):
super(DecoderLayer, self).__init__()

self.attention = MultiHeadAttention(model_dim, num_heads, dropout)
self.feed_forward = PositionalWiseFeedForward(model_dim, ffn_dim, dropout)

def forward(self,
dec_inputs,
enc_outputs,
self_attn_mask=None,
context_attn_mask=None):
# self attention, all inputs are decoder inputs
dec_output, self_attention = self.attention(
dec_inputs, dec_inputs, dec_inputs, self_attn_mask)

# context attention
# query is decoder's outputs, key and value are encoder's inputs
dec_output, context_attention = self.attention(
enc_outputs, enc_outputs, dec_output, context_attn_mask)

# decoder's output, or context
dec_output = self.feed_forward(dec_output)

return dec_output, self_attention, context_attention


class Decoder(nn.Module):

def __init__(self,
vocab_size,
max_seq_len,
num_layers=6,
model_dim=512,
num_heads=8,
ffn_dim=2048,
dropout=0.0):
super(Decoder, self).__init__()

self.num_layers = num_layers

self.decoder_layers = nn.ModuleList(
[DecoderLayer(model_dim, num_heads, ffn_dim, dropout) for _ in
range(num_layers)])

self.seq_embedding = nn.Embedding(vocab_size + 1, model_dim, padding_idx=0)
self.pos_embedding = PositionalEncoding(model_dim, max_seq_len)

def forward(self, inputs, inputs_len, enc_output, context_attn_mask=None):
output = self.seq_embedding(inputs)
output += self.pos_embedding(inputs_len)

self_attention_padding_mask = padding_mask(inputs, inputs)
seq_mask = sequence_mask(inputs)
self_attn_mask = torch.gt((self_attention_padding_mask + seq_mask), 0)

self_attentions = []
context_attentions = []
for decoder in self.decoder_layers:
output, self_attn, context_attn = decoder(
output, enc_output, self_attn_mask, context_attn_mask)
self_attentions.append(self_attn)
context_attentions.append(context_attn)

return output, self_attentions, context_attentions

组合一下

2.3 Transformer 模型

class Transformer(nn.Module):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def __init__(self,
src_vocab_size,
src_max_len,
tgt_vocab_size,
tgt_max_len,
num_layers=6,
model_dim=512,
num_heads=8,
ffn_dim=2048,
dropout=0.2):
super(Transformer, self).__init__()

self.encoder = Encoder(src_vocab_size, src_max_len, num_layers, model_dim,
num_heads, ffn_dim, dropout)
self.decoder = Decoder(tgt_vocab_size, tgt_max_len, num_layers, model_dim,
num_heads, ffn_dim, dropout)

self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
self.softmax = nn.Softmax(dim=2)

def forward(self, src_seq, src_len, tgt_seq, tgt_len):
context_attn_mask = padding_mask(tgt_seq, src_seq)

output, enc_self_attn = self.encoder(src_seq, src_len)

output, dec_self_attn, ctx_attn = self.decoder(
tgt_seq, tgt_len, output, context_attn_mask)

output = self.linear(output)
output = self.softmax(output)

return output, enc_self_attn, dec_self_attn, ctx_attn
num_heads, ffn_dim, dropout)

self.linear = nn.Linear(model_dim, tgt_vocab_size, bias=False)
self.softmax = nn.Softmax(dim=2)

def forward(self, src_seq, src_len, tgt_seq, tgt_len):
context_attn_mask = padding_mask(tgt_seq, src_seq)

output, enc_self_attn = self.encoder(src_seq, src_len)

output, dec_self_attn, ctx_attn = self.decoder(
tgt_seq, tgt_len, output, context_attn_mask)

output = self.linear(output)
output = self.softmax(output)

return output, enc_self_attn, dec_self_attn, ctx_attn
× 请我吃糖~
打赏二维码