当前位置:首页 > python > 正文内容

基于PyTorch的图像分类任务实战:从数据预处理到模型优化

zhangsir3周前 (07-07)python19

一、数据准备与预处理

1. 数据集选择与加载

以CIFAR-10数据集为例,该数据集包含60,000张32×32彩色图像,分为10个类别(飞机、汽车、鸟等),其中50,000张用于训练,10,000张用于测试。PyTorch通过torchvision.datasets.CIFAR10实现一键加载:

import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.ToTensor(),  # 转换为Tensor并归一化至[0,1]
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))  # 均值方差标准化
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

2. 数据增强技术

为提升模型泛化能力,需对训练数据进行随机变换:

train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),  # 水平翻转
    transforms.RandomRotation(15),          # 随机旋转±15度
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # 亮度/对比度调整
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

效果验证:在Kaggle细胞分类竞赛中,采用数据增强后模型准确率从89%提升至94%。

二、模型构建与优化

1. 基础CNN模型实现

以3层卷积网络为例,展示从输入到输出的完整流程:

import torch.nn as nn
import torch.nn.functional as F

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 输出尺寸: [batch,32,16,16]
        x = self.pool(F.relu(self.conv2(x)))  # 输出尺寸: [batch,64,8,8]
        x = x.view(-1, 64 * 8 * 8)            # 展平操作
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

关键点

  • 卷积层负责特征提取,全连接层完成分类

  • ReLU激活函数引入非线性,避免梯度消失

  • 池化层降低特征维度,减少计算量

2. 迁移学习实战

利用预训练的ResNet50模型进行微调(Fine-tuning):

import torchvision.models as models

model = models.resnet50(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 10)  # 替换最后的全连接层

# 冻结前4个ResNet块参数(可选)
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():
    param.requires_grad = True

优势

  • 在ImageNet上预训练的模型已学习到通用特征(如边缘、纹理)

  • 仅需少量数据即可达到高精度,尤其适合医学影像等标注成本高的领域

三、训练策略与调优

1. 损失函数与优化器选择

  • 交叉熵损失:适用于多分类任务,自动处理Softmax概率分布

  • Adam优化器:结合动量与自适应学习率,收敛速度快于SGD

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)  # L2正则化

2. 学习率调度

采用余弦退火策略动态调整学习率:

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=1e-6)

for epoch in range(100):
    # 训练代码...
    scheduler.step()

效果:在CIFAR-10实验中,该策略使模型在后期训练中跳出局部最优,最终准确率提升2.3%。

四、评估与部署

1. 评估指标

  • Top-1准确率:预测概率最高的类别是否正确

  • 混淆矩阵:分析各类别的误分类情况

from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

def plot_confusion_matrix(y_true, y_pred, classes):
    cm = confusion_matrix(y_true, y_pred)
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.xticks(range(len(classes)), classes, rotation=45)
    plt.yticks(range(len(classes)), classes)
    plt.show()

2. 模型导出与部署

将训练好的模型转换为TorchScript格式,支持C++/Java等语言调用:

traced_script_module = torch.jit.trace(model, torch.rand(1, 3, 32, 32))
traced_script_module.save("model.pt")


zhangsir版权c3防采集https://mianka.xyz

扫描二维码推送至手机访问。

版权声明:本文由zhangsir or zhangmaam发布,如需转载请注明出处。

本文链接:https://www.mianka.xyz/post/190.html

分享给朋友:

“基于PyTorch的图像分类任务实战:从数据预处理到模型优化” 的相关文章

python 爬虫 报错:UnicodeDecodeError: ‘utf-8‘ codec can‘t decode byte 0x8b in position”解决方案

发现报错“UnicodeDecodeError: 'utf-8' codec can't decode byte 0x8b in position 1:invalid start byte”,方法一:根据报错提示,错误原因有一条是这样的:“'Accept-Encodi...

python 实现彩色图转素描图

python可以把彩色图片转化为铅笔素描草图,对人像、景色都有很好的效果。而且只需几行代码就可以一键生成,适合批量操作,非常的快捷。需要的第三方库:Opencv - 计算机视觉工具,可以实现多元化的图像视频处理,有Python接口""" Photo ...

python 给电脑设置闹钟

python会自动触发windows桌面通知,提示重要事项,比如说:您已工作两小时,该休息了我们可以设定固定时间提示,比如隔10分钟、1小时等用到的第三方库:win10toast - 用于发送桌面通知的工具from win10toast import ToastNoti...

python 多线程与多进程的代码实例

一.两者区别多进程和多线程的主要区别是:线程是进程的子集(部分),一个进程可能由多个线程组成。多进程的数据是分开的、共享复杂,需要用IPC;但同步简单。多线程共享进程数据,共享简单;但同步复杂。(1)多进程进程是程序在计算机上的一次执行活动,即正在运行中的应用程序,通常称为进程。当你运行一个程序,你...

Linux系统下使用Python+selenium+谷歌浏览器下载文件

from seleniumwire import webdriver import time ch_options = webdriver.ChromeOptions() ch_options.add_argument("-...

python 实现快速扣背景图功能

一,实现快速扣背景图需要rembg这个三方库#引入rembg库 from rembg import remove #素材 input_path = 'input.jpg' #效果 output_path =&nbs...