好的,以下是《TensorFlow 深度学习实战(21)——Transformer 架构详解与实现》的完整教程,结合理论结构与代码实战,适用于希望深入掌握 NLP 模型架构、并动手实现 Transformer 的开发者、研究人员或学生。


🤖 TensorFlow 深度学习实战(21)——Transformer 架构详解与实现


📚 目录

  1. Transformer 架构概述
  2. 基本组成结构详解
  3. 输入嵌入与位置编码(Positional Encoding)
  4. 多头自注意力机制(Multi-Head Attention)
  5. 编码器(Encoder)与解码器(Decoder)结构
  6. 掩码机制(Mask)详解
  7. TensorFlow 实现核心模块
  8. 构建完整 Transformer 模型
  9. 模型训练与测试(以翻译任务为例)
  10. 总结与进阶建议

1. Transformer 架构概述

Transformer 是 Google 于 2017 年提出的序列建模架构(论文《Attention is All You Need》),摒弃了传统的 RNN/CNN,完全基于 自注意力机制(Self-Attention) 实现。

主要应用:机器翻译、文本生成、语言模型(如 GPT、BERT)


2. Transformer 基本结构图

Input -> Embedding -> Positional Encoding -> N个Encoder层 -> Decoder层 -> Output

主要模块:

  • Encoder:多层堆叠(一般6层),每层包含:
    • 多头自注意力(Multi-Head Self Attention)
    • 前馈网络(Feed Forward)
  • Decoder:结构类似 Encoder,但包含:
    • 自注意力
    • 编码器-解码器注意力
    • 前馈网络

3. 输入嵌入与位置编码

Transformer 没有循环结构,因此需要为每个 Token 加上位置信息

def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)
    # 偶数用 sin,奇数用 cos
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=tf.float32)

4. 多头注意力机制(Multi-Head Attention)

将查询 Q、键 K、值 V 映射到多个子空间进行并行注意力计算,再进行拼接:

def scaled_dot_product_attention(q, k, v, mask):
    matmul_qk = tf.matmul(q, k, transpose_b=True)
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
        scaled += (mask * -1e9)  # 掩码位置填充极小值

    attention_weights = tf.nn.softmax(scaled, axis=-1)
    output = tf.matmul(attention_weights, v)
    return output, attention_weights

完整 Multi-Head 实现:

class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        ...
    def split_heads(self, x, batch_size):
        ...
    def call(self, v, k, q, mask):
        ...

5. 编码器与解码器结构

编码器(Encoder Layer):

class EncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        ...
    def call(self, x, training, mask):
        attn_output, _ = self.mha(x, x, x, mask)  # Self-attention
        x = self.layernorm1(x + attn_output)
        ffn_output = self.ffn(x)
        return self.layernorm2(x + ffn_output)

解码器(Decoder Layer):

class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        ...
    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        attn1, _ = self.mha1(x, x, x, look_ahead_mask)
        attn2, _ = self.mha2(enc_output, enc_output, attn1, padding_mask)
        ffn_output = self.ffn(attn2)
        ...

6. 掩码机制详解

Padding Mask

对输入中的 <pad> 部分进行屏蔽:

def create_padding_mask(seq):
    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)
    return mask[:, tf.newaxis, tf.newaxis, :]

Look-ahead Mask(防止解码时看到未来)

def create_look_ahead_mask(size):
    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)

7. 构建核心模块

  • 多头注意力:MultiHeadAttention
  • 前馈网络:两层全连接 + ReLU
  • 残差连接 + LayerNorm
  • Dropout 防止过拟合
def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
        tf.keras.layers.Dense(dff, activation='relu'),
        tf.keras.layers.Dense(d_model)
    ])

8. 构建完整 Transformer 模型

class Transformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff,
                 input_vocab_size, target_vocab_size,
                 pe_input, pe_target, rate=0.1):
        ...
    def call(self, inp, tar, training, enc_padding_mask,
             look_ahead_mask, dec_padding_mask):
        ...

9. 模型训练与测试:翻译任务案例

编译与训练

transformer = Transformer(...)
learning_rate = CustomSchedule(d_model)
optimizer = tf.keras.optimizers.Adam(learning_rate, ...)
transformer.compile(optimizer=optimizer, loss=loss_function)

输入输出

  • 输入:英语句子(token 序列)
  • 输出:法语/中文句子(预测 token)
output = transformer(inp_sentence, ...)

你可使用 TensorFlow Datasets 加载例如 TED Talks 多语言语料库。


10. 总结与进阶建议

关键点建议
理解 Self-Attention是 Transformer 的核心
熟悉 Positional Encoding弥补非循环架构的顺序感知缺陷
明确掩码的作用防止未来泄露与 Padding 噪声
模块化代码复用多头注意力/编码器/解码器
利用 TensorFlow Dataset简化数据处理管道

🔗 参考资料