Sure thing, 阿杰 👍
Here is a guide, PyTorch in Practice: ResNet and DenseNet Explained, going from the underlying ideas to PyTorch implementations and a hands-on comparison, to help you fully understand these two classic CNN architectures.
PyTorch in Practice: ResNet and DenseNet Explained
I. Background and Motivation
- Traditional CNNs (e.g. VGGNet): as the network gets deeper, accuracy first improves and then degrades, and training suffers from vanishing / exploding gradients.
- ResNet (Residual Network): proposed in 2015, it uses residual connections (skip connections) to make very deep networks trainable.
- DenseNet (Densely Connected Convolutional Network): proposed in 2017, it goes further with dense connections, increasing feature reuse and reducing the parameter count.
👉 Both are milestones of deep CNNs and are widely used in image classification, detection, and segmentation.
II. ResNet in Detail
1. Core Idea
- Introduce the residual block (Residual Block): if $F(x)$ is the transformation computed by the convolutional layers, the block's output becomes
  $y = F(x) + x$
- Advantage: the network only has to learn the residual, so it can easily fall back to an identity mapping, which avoids the degradation seen in very deep plain networks.
2. Residual Block Structure
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
# Project the input on the shortcut when the spatial size or channel count changes
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
return self.relu(out)
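A quick smoke test (my own minimal check, not from the original text): a stride-2 block should halve the spatial resolution and change the channel count, with the 1×1 shortcut projection keeping the addition valid.

import torch

# Illustrative shapes only: 8 images, 64 channels, 32x32 -> 128 channels, 16x16
block = ResidualBlock(in_channels=64, out_channels=128, stride=2)
x = torch.randn(8, 64, 32, 32)
print(block(x).shape)  # torch.Size([8, 128, 16, 16])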
3. Common Variants
- ResNet-18 / 34: basic residual blocks (Basic Block)
- ResNet-50 / 101 / 152: bottleneck residual blocks (Bottleneck Block); a minimal sketch of the bottleneck design follows below
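For reference, here is a minimal sketch of the bottleneck design used in the deeper variants (1×1 reduce → 3×3 → 1×1 expand, with a 4× expansion factor). The class below is my own simplified version, not torchvision's exact implementation:

class BottleneckBlock(nn.Module):
    expansion = 4  # output channels = out_channels * expansion

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # 1x1 conv reduces channels, 3x3 conv does the spatial work, 1x1 conv expands back
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Project the shortcut when the output shape differs from the input
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels * self.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels * self.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels * self.expansion)
            )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        return self.relu(out)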
III. DenseNet in Detail
1. Core Idea
- Introduce dense connections (Dense Connection): each layer takes the concatenation of all preceding layers' outputs as its input, so the output of layer $l$ is
  $x_l = H_l([x_0, x_1, \dots, x_{l-1}])$
  where $[\cdot]$ denotes channel-wise concatenation.
- Advantages:
  - Stronger feature reuse
  - Easier gradient propagation
  - Fewer parameters
2. Dense Block Implementation
import torch  # needed for torch.cat below

class DenseLayer(nn.Module):
def __init__(self, in_channels, growth_rate=32):
super().__init__()
self.bn = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1, bias=False)
def forward(self, x):
out = self.conv(self.relu(self.bn(x)))
return torch.cat([x, out], 1)  # concatenate the input with the new features along the channel dimension
class DenseBlock(nn.Module):
def __init__(self, num_layers, in_channels, growth_rate):
super().__init__()
layers = []
for i in range(num_layers):
layers.append(DenseLayer(in_channels + i * growth_rate, growth_rate))
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
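To see the growth rate in action, here is a quick check (illustrative numbers only, assuming the classes above): with 16 input channels, 4 layers, and growth_rate=12, each layer appends 12 channels, so the block outputs 16 + 4 × 12 = 64 channels.

block = DenseBlock(num_layers=4, in_channels=16, growth_rate=12)
x = torch.randn(2, 16, 32, 32)
print(block(x).shape)  # torch.Size([2, 64, 32, 32]) -- 16 + 4 * 12 channels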
3. Key Properties
- Growth rate: the number of new channels each layer adds
- Transition Layer: compresses the channel count between dense blocks so it does not grow without bound (a sketch follows this list)
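A minimal sketch of a transition layer, assuming the common BN → ReLU → 1×1 conv → 2×2 average-pool layout with a compression factor of 0.5 (torchvision's implementation differs in the details):

class TransitionLayer(nn.Module):
    def __init__(self, in_channels, compression=0.5):
        super().__init__()
        out_channels = int(in_channels * compression)  # compress the channel count
        self.bn = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)  # halve the spatial resolution

    def forward(self, x):
        return self.pool(self.conv(self.relu(self.bn(x))))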
IV. ResNet vs. DenseNet
Feature | ResNet | DenseNet |
---|---|---|
Core idea | Residual connections | Dense connections |
Feature combination | $y = F(x) + x$ (addition) | $[x_0, x_1, \dots, x_{l-1}]$ (concatenation) |
Parameter count | Larger | Smaller |
Feature reuse | Partial (only the previous block's output) | All preceding layers |
Gradient flow | Shortcut addition | Concatenation gives even shorter paths |
Typical use | General vision tasks (classification, detection) | Smaller datasets, where the stronger feature reuse helps against overfitting |
V. Hands-on: CIFAR-10 Classification
1. Data Preparation
import torch
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
2. Choosing a Model
import torchvision.models as models
# ResNet-18
resnet18 = models.resnet18(num_classes=10)
# DenseNet-121
densenet121 = models.densenet121(num_classes=10)
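Both torchvision models were designed for ImageNet-sized inputs (the ResNet-18 stem is a 7×7 stride-2 conv followed by max-pooling). They still run on 32×32 CIFAR images thanks to adaptive pooling, but a common tweak, shown here as an optional sketch rather than part of the original recipe, is to shrink the ResNet stem:

import torch.nn as nn

# Optional: adapt the ResNet stem for small 32x32 inputs
resnet18.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
resnet18.maxpool = nn.Identity()  # drop the initial downsampling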
3. Training and Evaluation
import torch.optim as optim
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet18.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(5):
model.train()
for images, labels in trainloader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")
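The loop above only prints the training loss of the last batch in each epoch; a short evaluation pass over the test set (a minimal sketch using the testloader defined earlier) gives the accuracy:

model.eval()
correct, total = 0, 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f"Test accuracy: {100.0 * correct / total:.2f}%")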
VI. Results
(On CIFAR-10; actual numbers depend on the number of training epochs and other settings.)
- ResNet-18: accuracy around 92%
- DenseNet-121: accuracy around 94%, with fewer parameters
👉 DenseNet often outperforms ResNet on small datasets, while ResNet remains the more common choice for large-scale data such as ImageNet.
VII. Summary
- ResNet: residual connections solved the training difficulties of very deep networks and laid the groundwork for modern deep learning.
- DenseNet: dense connections strengthen feature reuse, improving accuracy while reducing parameters.
- Practical advice:
  - Large datasets + deep networks: prefer ResNet
  - Small datasets + lightweight models: prefer DenseNet
Here is a second walkthrough, a PyTorch tutorial covering ResNet and DenseNet in detail (easier to follow together with the diagrams generated above):
I. ResNet (Residual Network)
1. Motivation
- Plain networks become increasingly hard to train as depth grows, running into vanishing / exploding gradients.
- ResNet introduces the residual block, whose shortcut (skip) connections fix this degradation problem.
2. Residual Block Structure
- Input → conv layer → conv layer → output
- In parallel, a shortcut connection (identity mapping) adds the input directly to the output.
- Formula: $y = F(x) + x$, where $F(x)$ is the transformation computed by the convolutional layers.
3. PyTorch Implementation
import torch
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
# Shortcut projection when the input and output shapes differ
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
return self.relu(out)
II. DenseNet (Densely Connected Network)
1. Motivation
- Like ResNet, DenseNet is designed to counter vanishing gradients.
- The difference: ResNet combines features by addition, whereas DenseNet combines them by concatenation (concat).
2. Dense Block Structure
- Each layer's input is the concatenation of the outputs of all preceding layers.
- Formula: $x_l = H_l([x_0, x_1, \dots, x_{l-1}])$, where $[x_0, \dots, x_{l-1}]$ denotes feature concatenation.
3. PyTorch Implementation
class DenseLayer(nn.Module):
def __init__(self, in_channels, growth_rate):
super().__init__()
self.bn = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1)
def forward(self, x):
out = self.conv(self.relu(self.bn(x)))
return torch.cat([x, out], 1)  # concatenate features along the channel dimension
class DenseBlock(nn.Module):
def __init__(self, num_layers, in_channels, growth_rate):
super().__init__()
layers = []
for i in range(num_layers):
layers.append(DenseLayer(in_channels + i*growth_rate, growth_rate))
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
III. Comparison
Feature | ResNet | DenseNet |
---|---|---|
Connection type | Addition (add) | Concatenation (concat) |
Feature reuse | Weaker (features are only summed) | Stronger (features accumulate) |
Parameter count | Larger | Smaller |
Training cost | Lower memory footprint | Higher, since the concatenated feature maps of every layer are kept in memory |
Sure 👍 here is a complete hands-on PyTorch case study: we train ResNet and DenseNet on the CIFAR-10 dataset and compare their classification performance.
🔥 PyTorch in Practice: ResNet vs DenseNet on CIFAR-10
I. Environment Setup
pip install torch torchvision matplotlib
II. Imports
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
III. Dataset Preparation
# Data preprocessing: augmentation only for training, plain normalization for testing
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),      # augmentation: random horizontal flip
    transforms.RandomCrop(32, padding=4),   # augmentation: random crop with padding
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
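Optionally, you can peek at one augmented training batch to verify the pipeline (a sketch; it undoes the 0.5/0.5 normalization used above before plotting):

import numpy as np
from torchvision.utils import make_grid

images, labels = next(iter(trainloader))
grid = make_grid(images[:16], nrow=8) * 0.5 + 0.5   # un-normalize back to [0, 1]
plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
plt.title(" ".join(classes[int(l)] for l in labels[:8]))
plt.axis("off")
plt.show()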
IV. Defining a Simplified ResNet
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
return self.relu(out)
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super().__init__()
self.in_channels = 16
self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
self.bn1 = nn.BatchNorm2d(16)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(64, num_classes)
def _make_layer(self, block, out_channels, blocks, stride):
strides = [stride] + [1]*(blocks-1)
layers = []
for s in strides:
layers.append(block(self.in_channels, out_channels, s))
self.in_channels = out_channels
return nn.Sequential(*layers)
def forward(self, x):
out = self.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.avgpool(out)
out = out.view(out.size(0), -1)
return self.fc(out)
def ResNet20():
return ResNet(ResidualBlock, [3, 3, 3])
V. Defining a Simplified DenseNet
class DenseLayer(nn.Module):
def __init__(self, in_channels, growth_rate):
super().__init__()
self.bn = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU(inplace=True)
self.conv = nn.Conv2d(in_channels, growth_rate, kernel_size=3, padding=1)
def forward(self, x):
out = self.conv(self.relu(self.bn(x)))
return torch.cat([x, out], 1)
class DenseBlock(nn.Module):
def __init__(self, num_layers, in_channels, growth_rate):
super().__init__()
layers = []
for i in range(num_layers):
layers.append(DenseLayer(in_channels + i*growth_rate, growth_rate))
self.block = nn.Sequential(*layers)
def forward(self, x):
return self.block(x)
class DenseNet(nn.Module):
def __init__(self, growth_rate=12, num_classes=10):
super().__init__()
self.conv1 = nn.Conv2d(3, growth_rate*2, kernel_size=3, stride=1, padding=1)
num_channels = growth_rate*2
self.block1 = DenseBlock(4, num_channels, growth_rate)
num_channels += 4*growth_rate
self.block2 = DenseBlock(4, num_channels, growth_rate)
num_channels += 4*growth_rate
self.bn = nn.BatchNorm2d(num_channels)
self.relu = nn.ReLU(inplace=True)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(num_channels, num_classes)
def forward(self, x):
out = self.conv1(x)
out = self.block1(out)
out = self.block2(out)
out = self.relu(self.bn(out))
out = self.avgpool(out)
out = out.view(out.size(0), -1)
return self.fc(out)
def DenseNetCIFAR():
return DenseNet()
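Before training, a quick sanity check (my own addition) confirms that both models map a 32×32 image to 10 logits and shows how their parameter counts compare:

def count_params(model):
    return sum(p.numel() for p in model.parameters())

dummy = torch.randn(2, 3, 32, 32)
for net in (ResNet20(), DenseNetCIFAR()):
    print(type(net).__name__, net(dummy).shape, f"{count_params(net):,} parameters")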
VI. Training and Testing Functions
def train(model, trainloader, criterion, optimizer, device):
model.train()
total_loss = 0
for inputs, targets in trainloader:
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(trainloader)
def test(model, testloader, criterion, device):
model.eval()
correct, total, test_loss = 0, 0, 0
with torch.no_grad():
for inputs, targets in testloader:
inputs, targets = inputs.to(device), targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
return test_loss / len(testloader), 100. * correct / total
VII. Main Program: Training Comparison
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Models to compare
models = {
"ResNet20": ResNet20().to(device),
"DenseNet": DenseNetCIFAR().to(device)
}
criterion = nn.CrossEntropyLoss()
num_epochs = 10
results = {}
for name, model in models.items():
print(f"\n🔹 Training {name}...")
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
train_losses, test_accs = [], []
for epoch in range(num_epochs):
train_loss = train(model, trainloader, criterion, optimizer, device)
test_loss, test_acc = test(model, testloader, criterion, device)
train_losses.append(train_loss)
test_accs.append(test_acc)
print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Acc: {test_acc:.2f}%")
results[name] = test_accs
# Plot the accuracy curves for comparison
for name, accs in results.items():
plt.plot(accs, label=name)
plt.xlabel("Epoch")
plt.ylabel("Test Accuracy (%)")
plt.title("ResNet vs DenseNet on CIFAR-10")
plt.legend()
plt.show()
✅ After running this, you will see the test-accuracy curves of ResNet and DenseNet on CIFAR-10 side by side.
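If you want to reuse the trained networks later, you can also save their weights at the end (the file names below are just illustrative):

# Optional: save each trained model's weights
for name, model in models.items():
    torch.save(model.state_dict(), f"{name.lower()}_cifar10.pth")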