当前位置:首页 > python > 正文内容

PyTorch自定义Dataset全解析:从理论到实战的完整指南

zhangsir4周前 (06-30)python21

一、Dataset的核心机制

PyTorch的数据加载体系基于两大核心组件:

  1. Dataset:定义数据集的抽象接口,负责索引到样本的映射。

  2. DataLoader:封装Dataset,提供批量加载、多线程加速等功能。

1.1 数据传递流程

# 典型流程示例
dataset = CustomDataset(...)  # 创建自定义数据集
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)  # 包装为DataLoader
for batch in dataloader:  # 迭代获取批量数据
    inputs, labels = batch

1.2 Map式 vs Iterable式数据集

  • Map式数据集(常用):通过__getitem__实现索引访问,支持随机打乱(shuffle)。

  • Iterable式数据集:适用于流式数据(如实时传感器数据),按顺序迭代。

二、自定义Dataset的实现范式

继承torch.utils.data.Dataset需实现三个核心方法:

2.1 基础模板

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, data_path, transform=None):
        self.data_path = data_path
        self.transform = transform
        self.samples = self._load_data()  # 加载数据列表

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        # 1. 读取原始数据
        sample = self._read_sample(index)
        
        # 2. 应用预处理
        if self.transform:
            sample = self.transform(sample)
            
        # 3. 返回样本(如图像+标签)
        return sample

2.2 关键方法详解

  • __init__:初始化数据路径、预处理变换,并加载数据元信息。

  • __len__:返回数据集大小,用于DataLoader的进度控制。

  • __getitem__:核心方法,需完成数据读取、预处理和返回。

三、实战案例:图像分类数据集

以Kaggle的Dogs vs Cats数据集为例,实现自定义Dataset:

3.1 数据准备

dataset/
├── train/
│   ├── cat.0.jpg
│   ├── dog.0.jpg
│   └── ...
└── annotations.txt  # 格式:filename label

3.2 完整实现

import os
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class DogCatDataset(Dataset):
    def __init__(self, root_dir, ann_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = self._load_annotations(ann_file)

    def _load_annotations(self, ann_file):
        samples = []
        with open(ann_file) as f:
            for line in f:
                filename, label = line.strip().split()
                samples.append((filename, int(label)))
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        filename, label = self.samples[index]
        img_path = os.path.join(self.root_dir, filename)
        image = Image.open(img_path)
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# 使用示例
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

dataset = DogCatDataset(
    root_dir='dataset/train',
    ann_file='dataset/annotations.txt',
    transform=transform
)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)


四、高级技巧与优化

4.1 多线程加速

通过num_workers参数启用多进程加载:

pythondataloader = DataLoader(dataset, batch_size=64, num_workers=8)

4.2 自定义Collate函数

处理变长数据(如NLP序列):

def collate_fn(batch):
    # batch: List[Tuple(image, label)]
    images = [item[0] for item in batch]
    labels = [item[1] for item in batch]
    # 堆叠图像并转换为Tensor
    images = torch.stack(images, dim=0)
    labels = torch.LongTensor(labels)
    return images, labels

dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)


4.3 内存映射优化

对于大型数据集,可使用内存映射文件(如HDF5)减少I/O开销:

import h5py

class HDF5Dataset(Dataset):
    def __init__(self, h5_path):
        self.h5_file = h5py.File(h5_path, 'r')
        self.length = len(self.h5_file['images'])

    def __getitem__(self, index):
        image = self.h5_file['images'][index]
        label = self.h5_file['labels'][index]
        return image, label

五、常见问题与解决方案

5.1 数据路径错误

  • 问题FileNotFoundError

  • 解决:使用os.path.join构建跨平台路径,检查文件权限。

5.2 内存不足

  • 问题:加载大型数据集时OOM

  • 解决

    • 使用生成器(IterableDataset)

    • 分批加载数据

    • 降低batch_size

5.3 数据类型不匹配

  • 问题RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor)

  • 解决:显式调用.to(device)或使用pin_memory=True加速GPU传输。

六、总结

自定义Dataset是PyTorch数据流水线的核心技能,通过继承Dataset类并实现__len____getitem__方法,开发者可以灵活处理各类数据格式。结合DataLoader的多线程加速和自定义collate_fn,可构建高效的数据加载管道。实际应用中需注意路径处理、内存优化和设备兼容性,以确保训练过程的稳定性。


zhangsir版权k3防采集https://mianka.xyz

扫描二维码推送至手机访问。

版权声明:本文由zhangsir or zhangmaam发布,如需转载请注明出处。

本文链接:https://www.mianka.xyz/post/186.html

分享给朋友:

“PyTorch自定义Dataset全解析:从理论到实战的完整指南” 的相关文章

Selenium添加Cookie来实现自动登录

Selenium添加Cookie来实现自动登录第一步获取你登录的cookie,以csdn为例from selenium import webdriver driver = webdriver.Chrome() driver.get('...

Python三方库ddddocr实现验证码识别

Python三方库ddddocr实现验证码识别环境要求python >= 3.8安装三方库pip install ddddocr -i https://pypi.tuna.tsinghua.edu.cn/simple参数说明:参数名参数类型默认值说明us...

python 多线程与多进程的代码实例

一.两者区别多进程和多线程的主要区别是:线程是进程的子集(部分),一个进程可能由多个线程组成。多进程的数据是分开的、共享复杂,需要用IPC;但同步简单。多线程共享进程数据,共享简单;但同步复杂。(1)多进程进程是程序在计算机上的一次执行活动,即正在运行中的应用程序,通常称为进程。当你运行一个程序,你...

python selenium 使用代理ip

代码如下:from selenium import webdriver chromeOptions = webdriver.ChromeOptions() chromeOptions.add_argument("--proxy-serv...

Linux系统下使用Python+selenium+谷歌浏览器下载文件

from seleniumwire import webdriver import time ch_options = webdriver.ChromeOptions() ch_options.add_argument("-...

python 实现快速扣背景图功能

一,实现快速扣背景图需要rembg这个三方库#引入rembg库 from rembg import remove #素材 input_path = 'input.jpg' #效果 output_path =&nbs...