好的,以下是《TensorFlow 深度学习实战(21)——Transformer 架构详解与实现》的完整教程,结合理论结构与代码实战,适用于希望深入掌握 NLP 模型架构、并动手实现 Transformer 的开发者、研究人员或学生。
🤖 TensorFlow 深度学习实战(21)——Transformer 架构详解与实现
📚 目录
- Transformer 架构概述
- 基本组成结构详解
- 输入嵌入与位置编码(Positional Encoding)
- 多头自注意力机制(Multi-Head Attention)
- 编码器(Encoder)与解码器(Decoder)结构
- 掩码机制(Mask)详解
- TensorFlow 实现核心模块
- 构建完整 Transformer 模型
- 模型训练与测试(以翻译任务为例)
- 总结与进阶建议
1. Transformer 架构概述
Transformer 是 Google 于 2017 年提出的序列建模架构(论文《Attention is All You Need》),摒弃了传统的 RNN/CNN,完全基于 自注意力机制(Self-Attention) 实现。
主要应用:机器翻译、文本生成、语言模型(如 GPT、BERT)
2. Transformer 基本结构图
Input -> Embedding -> Positional Encoding -> N个Encoder层 -> Decoder层 -> Output
主要模块:
- Encoder:多层堆叠(一般6层),每层包含:
- 多头自注意力(Multi-Head Self Attention)
- 前馈网络(Feed Forward)
- Decoder:结构类似 Encoder,但包含:
- 自注意力
- 编码器-解码器注意力
- 前馈网络
3. 输入嵌入与位置编码
Transformer 没有循环结构,因此需要为每个 Token 加上位置信息。
def get_angles(pos, i, d_model):
angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
return pos * angle_rates
def positional_encoding(position, d_model):
angle_rads = get_angles(np.arange(position)[:, np.newaxis],
np.arange(d_model)[np.newaxis, :],
d_model)
# 偶数用 sin,奇数用 cos
angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
pos_encoding = angle_rads[np.newaxis, ...]
return tf.cast(pos_encoding, dtype=tf.float32)
4. 多头注意力机制(Multi-Head Attention)
将查询 Q、键 K、值 V 映射到多个子空间进行并行注意力计算,再进行拼接:
def scaled_dot_product_attention(q, k, v, mask):
matmul_qk = tf.matmul(q, k, transpose_b=True)
dk = tf.cast(tf.shape(k)[-1], tf.float32)
scaled = matmul_qk / tf.math.sqrt(dk)
if mask is not None:
scaled += (mask * -1e9) # 掩码位置填充极小值
attention_weights = tf.nn.softmax(scaled, axis=-1)
output = tf.matmul(attention_weights, v)
return output, attention_weights
完整 Multi-Head 实现:
class MultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads):
...
def split_heads(self, x, batch_size):
...
def call(self, v, k, q, mask):
...
5. 编码器与解码器结构
编码器(Encoder Layer):
class EncoderLayer(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads, dff, rate=0.1):
...
def call(self, x, training, mask):
attn_output, _ = self.mha(x, x, x, mask) # Self-attention
x = self.layernorm1(x + attn_output)
ffn_output = self.ffn(x)
return self.layernorm2(x + ffn_output)
解码器(Decoder Layer):
class DecoderLayer(tf.keras.layers.Layer):
def __init__(self, d_model, num_heads, dff, rate=0.1):
...
def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
attn1, _ = self.mha1(x, x, x, look_ahead_mask)
attn2, _ = self.mha2(enc_output, enc_output, attn1, padding_mask)
ffn_output = self.ffn(attn2)
...
6. 掩码机制详解
Padding Mask
对输入中的 <pad>
部分进行屏蔽:
def create_padding_mask(seq):
mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
return mask[:, tf.newaxis, tf.newaxis, :]
Look-ahead Mask(防止解码时看到未来)
def create_look_ahead_mask(size):
return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
7. 构建核心模块
- 多头注意力:
MultiHeadAttention
- 前馈网络:两层全连接 + ReLU
- 残差连接 + LayerNorm
- Dropout 防止过拟合
def point_wise_feed_forward_network(d_model, dff):
return tf.keras.Sequential([
tf.keras.layers.Dense(dff, activation='relu'),
tf.keras.layers.Dense(d_model)
])
8. 构建完整 Transformer 模型
class Transformer(tf.keras.Model):
def __init__(self, num_layers, d_model, num_heads, dff,
input_vocab_size, target_vocab_size,
pe_input, pe_target, rate=0.1):
...
def call(self, inp, tar, training, enc_padding_mask,
look_ahead_mask, dec_padding_mask):
...
9. 模型训练与测试:翻译任务案例
编译与训练
transformer = Transformer(...)
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, ...)
transformer.compile(optimizer=optimizer, loss=loss_function)
输入输出
- 输入:
英语句子(token 序列)
- 输出:
法语/中文句子(预测 token)
output = transformer(inp_sentence, ...)
你可使用
TensorFlow Datasets
加载例如 TED Talks 多语言语料库。
10. 总结与进阶建议
关键点 | 建议 |
---|---|
理解 Self-Attention | 是 Transformer 的核心 |
熟悉 Positional Encoding | 弥补非循环架构的顺序感知缺陷 |
明确掩码的作用 | 防止未来泄露与 Padding 噪声 |
模块化代码复用 | 多头注意力/编码器/解码器 |
利用 TensorFlow Dataset | 简化数据处理管道 |
🔗 参考资料
- 📄 原始论文:Attention Is All You Need
- 📘 《TensorFlow 2 实战》
- 🧠 教程:TensorFlow 官方 Transformer 教程
- 🌐 实战代码仓库:github.com/tensorflow/models
发表回复