Reproducing the Pix2Pix image translation project is a classic exercise in image-to-image translation, commonly applied to image inpainting, style transfer, and colorization. Below is a basic framework for reproducing Pix2Pix in PyTorch. Pix2Pix uses a conditional GAN (cGAN), consisting of a Generator and a Discriminator.

1. Install Dependencies

First, make sure the following dependencies are installed in your environment:

pip install torch torchvision matplotlib

If you do not have PyTorch installed, follow the official installation guide.

2. Define the Generator

The generator uses a U-Net architecture, whose skip connections preserve fine detail. (This is a simplified version: it adds encoder features to decoder features and omits the BatchNorm and Dropout layers; the original Pix2Pix concatenates the skip features instead.)

import torch
import torch.nn as nn

# U-Net generator
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        # Encoder
        self.enc1 = self.conv_block(3, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)
        self.enc5 = self.conv_block(512, 512)
        self.enc6 = self.conv_block(512, 512)
        self.enc7 = self.conv_block(512, 512)

        # Decoder
        self.dec1 = self.upconv_block(512, 512)
        self.dec2 = self.upconv_block(512, 512)
        self.dec3 = self.upconv_block(512, 512)
        self.dec4 = self.upconv_block(512, 256)
        self.dec5 = self.upconv_block(256, 128)
        self.dec6 = self.upconv_block(128, 64)

        # Tanh bounds the output to [-1, 1], matching the normalized targets
        self.final = nn.Sequential(
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

    def upconv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(e1)
        e3 = self.enc3(e2)
        e4 = self.enc4(e3)
        e5 = self.enc5(e4)
        e6 = self.enc6(e5)
        e7 = self.enc7(e6)
        
        # Decoder with additive skip connections (the original Pix2Pix
        # concatenates the encoder features along the channel dimension)
        d1 = self.dec1(e7)
        d2 = self.dec2(d1 + e6)
        d3 = self.dec3(d2 + e5)
        d4 = self.dec4(d3 + e4)
        d5 = self.dec5(d4 + e3)
        d6 = self.dec6(d5 + e2)

        return self.final(d6 + e1)
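
As a sanity check on the architecture, note that every encoder block (kernel 4, stride 2, padding 1) halves the spatial resolution, and every decoder block doubles it — which is exactly what makes the skip connections line up. A minimal sketch with one down/up pair:

```python
import torch
import torch.nn as nn

# One encoder-style block halves the spatial size; one decoder-style
# block restores it, so encoder and decoder features can be combined.
x = torch.randn(1, 3, 256, 256)
down = nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1)
up = nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1)

h = down(x)
print(h.shape)  # torch.Size([1, 64, 128, 128])
y = up(h)
print(y.shape)  # torch.Size([1, 3, 256, 256])
```

With seven such downsampling steps, a 256×256 input shrinks to 2×2 at the bottleneck before the decoder rebuilds it.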

3. Define the Discriminator

The discriminator is a standard PatchGAN: a local discriminator that classifies each image patch rather than the whole image, which improves training efficiency.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            # 6 input channels: the input image and the (real or generated)
            # target image are concatenated along the channel dimension
            nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        return self.model(x)
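
Unlike an ordinary GAN discriminator, the PatchGAN above outputs a grid of scores — one per receptive-field patch — rather than a single scalar. A quick shape check (restating the same layers inline so the snippet is self-contained):

```python
import torch
import torch.nn as nn

# Same stack as the Discriminator above, written inline
disc = nn.Sequential(
    nn.Conv2d(6, 64, kernel_size=4, stride=2, padding=1),
    nn.LeakyReLU(0.2),
    nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.LeakyReLU(0.2),
    nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(256),
    nn.LeakyReLU(0.2),
    nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(512),
    nn.LeakyReLU(0.2),
    nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),
)

# Input image and target image concatenated along the channel axis
pair = torch.cat((torch.randn(2, 3, 256, 256), torch.randn(2, 3, 256, 256)), 1)
print(disc(pair).shape)  # torch.Size([2, 1, 8, 8]): one score per patch
```

Each of the 8×8 cells judges one overlapping patch of the 256×256 input pair, so the loss rewards realistic local texture everywhere in the image.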

4. Define the Loss Functions

In Pix2Pix, the loss has two parts:

  • Generator loss: an adversarial loss plus an L1 loss (image reconstruction).
  • Discriminator loss: an adversarial loss used to judge whether an image is real or generated.

import torch.nn.functional as F

# Adversarial loss (MSE here, i.e. the LSGAN variant; the original paper uses BCE)
def adversarial_loss(y_true, y_pred):
    return F.mse_loss(y_pred, y_true)

# L1 loss (image reconstruction term for the generator)
def l1_loss(y_true, y_pred):
    return F.l1_loss(y_pred, y_true)
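
On dummy tensors, the combined generator objective — the weighted sum used during training, with lambda = 100 as in the paper — can be exercised like this; the shapes are illustrative only:

```python
import torch
import torch.nn.functional as F

# Dummy tensors standing in for a generated image, its target, and the
# discriminator's patch scores for the generated image
fake_B = torch.rand(1, 3, 8, 8)
real_B = torch.rand(1, 3, 8, 8)
pred_fake = torch.rand(1, 1, 2, 2)

# Adversarial term: push the patch scores toward "real" (all ones)
loss_gan = F.mse_loss(pred_fake, torch.ones_like(pred_fake))
# Reconstruction term: keep the output close to the target pixel-wise
loss_l1 = F.l1_loss(fake_B, real_B)
# Weighted sum, with lambda = 100 as in the Pix2Pix paper
loss_G = loss_gan + 100 * loss_l1
print(loss_G.item() > 0)  # True (both terms are non-negative)
```

The large L1 weight keeps the output close to the target; the adversarial term only has to supply the high-frequency realism the L1 loss tends to blur away.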

5. Training Loop

During training, the generator and discriminator are updated alternately: the generator tries to produce images realistic enough to fool the discriminator, while the discriminator tries to distinguish real from generated images more accurately.

import torch.optim as optim

# Use the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the generator and discriminator
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Set up the optimizers (lr and betas follow the Pix2Pix paper)
lr = 0.0002
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# Set the loss functions
criterion_gan = adversarial_loss
criterion_l1 = l1_loss

# Training loop (assumes `dataloader` yields (input, target) image pairs)
num_epochs = 200  # adjust to your dataset
for epoch in range(num_epochs):
    for i, (real_A, real_B) in enumerate(dataloader):

        # Move the real image pair to the device
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # Generator forward pass
        fake_B = generator(real_A)

        # Train the discriminator
        optimizer_D.zero_grad()

        # Real pairs should be classified as real
        pred_real = discriminator(torch.cat((real_A, real_B), 1))
        loss_D_real = criterion_gan(pred_real, torch.ones_like(pred_real))

        # Fake pairs (detached so the generator receives no gradient here)
        pred_fake = discriminator(torch.cat((real_A, fake_B.detach()), 1))
        loss_D_fake = criterion_gan(pred_fake, torch.zeros_like(pred_fake))

        # Total discriminator loss
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        optimizer_D.step()

        # Train the generator
        optimizer_G.zero_grad()

        # Generator loss: adversarial loss + weighted L1 loss
        pred_fake = discriminator(torch.cat((real_A, fake_B), 1))
        loss_G_GAN = criterion_gan(pred_fake, torch.ones_like(pred_fake))
        loss_G_L1 = criterion_l1(fake_B, real_B)
        loss_G = loss_G_GAN + 100 * loss_G_L1  # lambda = 100 as in the paper

        loss_G.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch}/{num_epochs}] - Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")
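
The loop above assumes a `dataloader` that yields (input, target) pairs. A minimal sketch of such a loader, using random tensors in place of real images (in practice you would load the side-by-side AB images that datasets like facades or edges2shoes ship as, and split each in half):

```python
import torch
from torch.utils.data import Dataset, DataLoader

# Minimal paired dataset: each item is an (input, target) tensor pair
class PairedDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs  # list of (A, B) tensors, values in [-1, 1]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        return self.pairs[idx]

# Random tensors stand in for real, normalized 256x256 images
data = [(torch.randn(3, 256, 256), torch.randn(3, 256, 256)) for _ in range(4)]
dataloader = DataLoader(PairedDataset(data), batch_size=2, shuffle=True)

real_A, real_B = next(iter(dataloader))
print(real_A.shape)  # torch.Size([2, 3, 256, 256])
```

The DataLoader stacks the individual (C, H, W) tensors into (N, C, H, W) batches, which is the shape the training loop expects.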

6. Generating Images and Viewing Results

After a certain amount of training, you can use the generator to produce new images and display the results with matplotlib or PIL.

import matplotlib.pyplot as plt

# Generate an image (inference only, so no gradients are needed)
generator.eval()
with torch.no_grad():
    generated_image = generator(real_A)

# Move to the CPU and convert to HWC numpy format
generated_image = generated_image.cpu().numpy().transpose(0, 2, 3, 1)[0]

plt.imshow((generated_image + 1) / 2)  # map [-1, 1] -> [0, 1]
plt.show()

7. Summary

The code above shows how to implement Pix2Pix image translation in PyTorch. The key points are the design of the generator and discriminator and how their training is balanced: the generator uses a U-Net to preserve detail in high-resolution output, while the discriminator uses a PatchGAN to classify individual image patches.

To run the complete project, you will also need to set up a suitable data loader and a paired image dataset (e.g. Cityscapes or edges2shoes).