基于PyTorch的图像分类任务实战:从数据预处理到模型优化
一、数据准备与预处理
1. 数据集选择与加载
以CIFAR-10数据集为例,该数据集包含60,000张32×32彩色图像,分为10个类别(飞机、汽车、鸟等),其中50,000张用于训练,10,000张用于测试。PyTorch通过torchvision.datasets.CIFAR10
实现一键加载:
import torchvision.transforms as transforms from torchvision.datasets import CIFAR10 from torch.utils.data import DataLoader transform = transforms.Compose([ transforms.ToTensor(), # 转换为Tensor并归一化至[0,1] transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)) # 均值方差标准化 ]) train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform) test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)
2. 数据增强技术
为提升模型泛化能力,需对训练数据进行随机变换:
train_transform = transforms.Compose([ transforms.RandomHorizontalFlip(p=0.5), # 水平翻转 transforms.RandomRotation(15), # 随机旋转±15度 transforms.ColorJitter(brightness=0.2, contrast=0.2), # 亮度/对比度调整 transforms.ToTensor(), transforms.Normalize(mean, std) ])
效果验证:在Kaggle细胞分类竞赛中,采用数据增强后模型准确率从89%提升至94%。
二、模型构建与优化
1. 基础CNN模型实现
以3层卷积网络为例,展示从输入到输出的完整流程:
import torch.nn as nn import torch.nn.functional as F class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.pool = nn.MaxPool2d(2, 2) self.fc1 = nn.Linear(64 * 8 * 8, 512) self.fc2 = nn.Linear(512, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) # 输出尺寸: [batch,32,16,16] x = self.pool(F.relu(self.conv2(x))) # 输出尺寸: [batch,64,8,8] x = x.view(-1, 64 * 8 * 8) # 展平操作 x = F.relu(self.fc1(x)) x = self.fc2(x) return x
关键点:
卷积层负责特征提取,全连接层完成分类
ReLU激活函数引入非线性,避免梯度消失
池化层降低特征维度,减少计算量
2. 迁移学习实战
利用预训练的ResNet50模型进行微调(Fine-tuning):
import torchvision.models as models model = models.resnet50(pretrained=True) num_ftrs = model.fc.in_features model.fc = nn.Linear(num_ftrs, 10) # 替换最后的全连接层 # 冻结前4个ResNet块参数(可选) for param in model.parameters(): param.requires_grad = False for param in model.fc.parameters(): param.requires_grad = True
优势:
在ImageNet上预训练的模型已学习到通用特征(如边缘、纹理)
仅需少量数据即可达到高精度,尤其适合医学影像等标注成本高的领域
三、训练策略与调优
1. 损失函数与优化器选择
交叉熵损失:适用于多分类任务,自动处理Softmax概率分布
Adam优化器:结合动量与自适应学习率,收敛速度快于SGD
import torch.optim as optim criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5) # L2正则化
2. 学习率调度
采用余弦退火策略动态调整学习率:
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6) for epoch in range(100): # 训练代码... scheduler.step()
效果:在CIFAR-10实验中,该策略使模型在后期训练中跳出局部最优,最终准确率提升2.3%。
四、评估与部署
1. 评估指标
Top-1准确率:预测概率最高的类别是否正确
混淆矩阵:分析各类别的误分类情况
from sklearn.metrics import confusion_matrix import matplotlib.pyplot as plt def plot_confusion_matrix(y_true, y_pred, classes): cm = confusion_matrix(y_true, y_pred) plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues) plt.xticks(range(len(classes)), classes, rotation=45) plt.yticks(range(len(classes)), classes) plt.show()
2. 模型导出与部署
将训练好的模型转换为TorchScript格式,支持C++/Java等语言调用:
traced_script_module = torch.jit.trace(model, torch.rand(1, 3, 32, 32)) traced_script_module.save("model.pt")