🎯 学习目标

  • 理解Dataset和DataLoader的作用与区别
  • 掌握自定义Dataset的创建方法
  • 学会使用DataLoader进行批量数据加载
  • 了解数据增强和预处理技术
数据处理管道

数据加载概述

在深度学习中,高效的数据加载至关重要。PyTorch提供了Dataset和DataLoader两个核心类, Dataset负责定义数据集的访问方式,DataLoader负责批量加载、打乱顺序、并行处理等。

📊 Dataset vs DataLoader

Dataset

数据集的抽象基类,定义如何获取单个样本

  • 实现__len__方法返回数据集大小
  • 实现__getitem__方法获取单个样本
  • 支持索引访问数据

DataLoader

数据加载器,负责批量加载和迭代

  • 批量处理(Batching)
  • 数据打乱(Shuffling)
  • 并行加载(多进程)

💻 自定义Dataset

import torch from torch.utils.data import Dataset, DataLoader import numpy as np class CustomDataset(Dataset): """自定义数据集类""" def __init__(self, data, labels, transform=None): """ Args: data: 数据数组 labels: 标签数组 transform: 可选的数据预处理 """ self.data = torch.tensor(data, dtype=torch.float32) self.labels = torch.tensor(labels, dtype=torch.long) self.transform = transform def __len__(self): """返回数据集大小""" return len(self.data) def __getitem__(self, idx): """获取单个样本""" sample = self.data[idx] label = self.labels[idx] if self.transform: sample = self.transform(sample) return sample, label # 使用示例 data = np.random.randn(1000, 10) # 1000个样本,每个10维 labels = np.random.randint(0, 2, 1000) # 二分类标签 dataset = CustomDataset(data, labels) print(f'数据集大小: {len(dataset)}') print(f'第一个样本: {dataset[0]}')

DataLoader使用

from torch.utils.data import DataLoader # 创建DataLoader dataloader = DataLoader( dataset=dataset, batch_size=32, # 批量大小 shuffle=True, # 是否打乱数据 num_workers=4, # 并行加载进程数 pin_memory=True, # 锁页内存(GPU训练时加速) drop_last=False # 是否丢弃不完整批次 ) # 迭代获取数据 for batch_idx, (data, labels) in enumerate(dataloader): print(f'Batch {batch_idx}: data.shape = {data.shape}') if batch_idx >= 2: break # 获取单个批次 data, labels = next(iter(dataloader))

⚖️ DataLoader参数详解

参数 说明 建议值
batch_size 每批样本数量 32/64/128(根据GPU内存调整)
shuffle 是否每轮打乱数据顺序 训练集True,验证/测试集False
num_workers 数据加载的子进程数 CPU核心数(通常4-8)
pin_memory 使用锁页内存 GPU训练时设为True
drop_last 丢弃最后不完整的批次 需要固定批次大小时True

🖼️ 图像数据增强

from torchvision import transforms # 定义训练数据预处理 train_transform = transforms.Compose([ transforms.Resize((256, 256)), transforms.RandomCrop(224), # 随机裁剪 transforms.RandomHorizontalFlip(), # 随机水平翻转 transforms.RandomRotation(15), # 随机旋转 transforms.ColorJitter( # 颜色抖动 brightness=0.2, contrast=0.2, saturation=0.2 ), transforms.ToTensor(), # 转为Tensor transforms.Normalize( # 标准化 mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ]) # 验证/测试数据预处理(不增强) val_transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ) ])

📁 使用内置数据集

from torchvision.datasets import MNIST, CIFAR10, ImageFolder from torch.utils.data import random_split # MNIST数据集 train_dataset = MNIST( root='./data', train=True, download=True, transform=transforms.ToTensor() ) # 划分训练集和验证集 train_size = int(0.8 * len(train_dataset)) val_size = len(train_dataset) - train_size train_set, val_set = random_split( train_dataset, [train_size, val_size] ) # 创建DataLoader train_loader = DataLoader(train_set, batch_size=64, shuffle=True) val_loader = DataLoader(val_set, batch_size=64, shuffle=False)
⚠️
Windows注意事项

在Windows系统上使用num_workers > 0时,需要将数据加载代码放在if __name__ == '__main__':块中,否则可能产生多进程错误。

数据处理流程
图:数据从磁盘到模型的高效流转

📝 本节小结

  • • Dataset定义数据集的访问方式,需实现__len__和__getitem__
  • • DataLoader负责批量加载、打乱、并行处理数据
  • • 合理设置batch_size和num_workers可提升训练效率
  • • transforms模块提供了丰富的图像预处理和数据增强方法
  • • torchvision.datasets包含常用内置数据集