当前位置:首页 > python > 正文内容

基于PyTorch的图像分类任务实战:从数据预处理到模型优化

zhangsir2个月前 (07-07)python51

一、数据准备与预处理

1. 数据集选择与加载

以CIFAR-10数据集为例,该数据集包含60,000张32×32彩色图像,分为10个类别(飞机、汽车、鸟等),其中50,000张用于训练,10,000张用于测试。PyTorch通过torchvision.datasets.CIFAR10实现一键加载:

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为Tensor并归一化至[0,1]
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))  # 均值方差标准化
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

2. 数据增强技术

为提升模型泛化能力,需对训练数据进行随机变换:

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # 水平翻转
    transforms.RandomRotation(15),          # 随机旋转±15度
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # 亮度/对比度调整
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

效果验证:在Kaggle细胞分类竞赛中,采用数据增强后模型准确率从89%提升至94%。

二、模型构建与优化

1. 基础CNN模型实现

以3层卷积网络为例,展示从输入到输出的完整流程:

import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 输出尺寸: [batch,32,16,16]
        x = self.pool(F.relu(self.conv2(x)))  # 输出尺寸: [batch,64,8,8]
        x = x.view(-1, 64 * 8 * 8)            # 展平操作
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

关键点

  • 卷积层负责特征提取,全连接层完成分类

  • ReLU激活函数引入非线性,避免梯度消失

  • 池化层降低特征维度,减少计算量

2. 迁移学习实战

利用预训练的ResNet50模型进行微调(Fine-tuning):

import torchvision.models as models

model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 替换最后的全连接层

# 冻结前4个ResNet块参数(可选)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

优势

  • 在ImageNet上预训练的模型已学习到通用特征(如边缘、纹理)

  • 仅需少量数据即可达到高精度,尤其适合医学影像等标注成本高的领域

三、训练策略与调优

1. 损失函数与优化器选择

  • 交叉熵损失:适用于多分类任务,自动处理Softmax概率分布

  • Adam优化器:结合动量与自适应学习率,收敛速度快于SGD

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # L2正则化

2. 学习率调度

采用余弦退火策略动态调整学习率:

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

for epoch in range(100):
    # 训练代码...
    scheduler.step()

效果:在CIFAR-10实验中,该策略使模型在后期训练中跳出局部最优,最终准确率提升2.3%。

四、评估与部署

1. 评估指标

  • Top-1准确率:预测概率最高的类别是否正确

  • 混淆矩阵:分析各类别的误分类情况

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, classes):
    cm = confusion_matrix(y_true, y_pred)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.xticks(range(len(classes)), classes, rotation=45)
    plt.yticks(range(len(classes)), classes)
    plt.show()

2. 模型导出与部署

将训练好的模型转换为TorchScript格式,支持C++/Java等语言调用:

traced_script_module = torch.jit.trace(model, torch.rand(1, 3, 32, 32))
traced_script_module.save("model.pt")


zhangsir版权c2防采集https://mianka.xyz

扫描二维码推送至手机访问。

版权声明:本文由zhangsir or zhangmaam发布,如需转载请注明出处。

本文链接:https://www.mianka.xyz/post/190.html

分享给朋友:

“基于PyTorch的图像分类任务实战:从数据预处理到模型优化” 的相关文章

Python怎么获取命令行参数

输入:“ import sys”,导入 sys 模块。插入语句:“print(sys.argv)”,打印获取的命令行参数。...

python+selenium元素定位的8种方法

定位元素,selenium提供了8中元素定位方法:(1)find_element_by_id() :html规定,id在html中必须是唯一的,有点类似于身份证号(2)find_element_by_name() :html规定,name用来指定元素的名称,有点类似于人名(3)find_elemen...

Python三方库ddddocr实现验证码识别

Python三方库ddddocr实现验证码识别环境要求python >= 3.8安装三方库pip install ddddocr -i https://pypi.tuna.tsinghua.edu.cn/simple参数说明:参数名参数类型默认值说明us...

使用pyautogui进行屏幕捕捉实现自动化操作

import pyautogui import time # # 获取基本信息 # # 屏幕大小 # size = pyautogui.size() # print(size) #&nbs...

解决Django的request.POST获取不到请求参数的问题

这个是Django自身的问题:只要在请求头的添加"content-type":'application/x-www-form-urlencoded'就行。...

python 爬虫 报错:UnicodeDecodeError: ‘utf-8‘ codec can‘t decode byte 0x8b in position”解决方案

发现报错“UnicodeDecodeError: 'utf-8' codec can't decode byte 0x8b in position 1:invalid start byte”,方法一:根据报错提示,错误原因有一条是这样的:“'Accept-Encodi...