# 💻 自定义Dataset
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
class CustomDataset(Dataset):
    """Map-style dataset over in-memory samples and their integer labels.

    Samples are stored as a ``torch.float32`` tensor and labels as a
    ``torch.long`` tensor. Both conversions happen once, at construction.
    """

    def __init__(self, data, labels, transform=None):
        """Convert the inputs to tensors and store the optional transform.

        Args:
            data: Array-like (or tensor) of samples, indexed along its
                first axis.
            labels: Array-like (or tensor) of labels, one per sample.
            transform: Optional callable applied to each sample when it is
                fetched through ``__getitem__``.

        Raises:
            ValueError: If ``data`` and ``labels`` differ in length.
        """
        self.data = self._to_tensor(data, torch.float32)
        self.labels = self._to_tensor(labels, torch.long)
        # Fail early: a mismatch would otherwise surface later as an
        # IndexError on some sample, or as silently ignored labels.
        if len(self.data) != len(self.labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(self.data)} and {len(self.labels)}"
            )
        self.transform = transform

    @staticmethod
    def _to_tensor(values, dtype):
        """Return an independent tensor copy of ``values`` with ``dtype``."""
        if isinstance(values, torch.Tensor):
            # torch.tensor(tensor) emits a copy-construct UserWarning;
            # detach().clone() is the copy PyTorch recommends instead.
            return values.detach().clone().to(dtype)
        return torch.tensor(values, dtype=dtype)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair stored at ``idx``.

        The transform, when one was supplied, is applied to the sample
        only; the label is returned untouched.
        """
        sample = self.data[idx]
        label = self.labels[idx]
        # Compare against None explicitly: some callables (e.g. an empty
        # nn.Sequential) are falsy and would be skipped by a truth test.
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label
# Synthetic toy data: 1000 samples of 10 features, each with a 0/1 label.
data = np.random.randn(1000, 10)
labels = np.random.randint(low=0, high=2, size=1000)

# Wrap the arrays in the dataset (no transform) and inspect it.
dataset = CustomDataset(data=data, labels=labels)
print(f'数据集大小: {len(dataset)}')
print(f'第一个样本: {dataset[0]}')
# ⚡ DataLoader使用
from torch.utils.data import DataLoader
# With num_workers > 0 the loader starts worker processes. On platforms that
# use the "spawn" start method (Windows, macOS) every worker re-imports the
# main module, so code that creates or iterates the loader must sit behind
# this guard or the workers fail during bootstrapping.
if __name__ == '__main__':
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=32,     # samples per batch
        shuffle=True,      # reshuffle the samples every epoch
        num_workers=4,     # number of loader subprocesses
        pin_memory=True,   # request page-locked host memory for batches
        drop_last=False,   # keep the final, possibly smaller, batch
    )

    # Peek at the first three batches only.
    for batch_idx, (data, labels) in enumerate(dataloader):
        print(f'Batch {batch_idx}: data.shape = {data.shape}')
        if batch_idx >= 2:
            break

    # Fetch a single batch without writing a loop.
    data, labels = next(iter(dataloader))
# 🖼️ 图像数据增强
from torchvision import transforms
def _make_normalize():
    """Build the per-channel normalisation step shared by both pipelines.

    A fresh instance (with fresh lists) is returned on every call so the
    two pipelines do not share state. The constants are presumably the
    usual ImageNet channel statistics -- confirm against the model in use.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    return transforms.Normalize(mean=channel_mean, std=channel_std)


# Training pipeline: resize, then random crop/flip/rotation/colour jitter,
# then tensor conversion and normalisation.
_train_steps = [
    transforms.Resize((256, 256)),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    _make_normalize(),
]
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: fixed resize only, with no random augmentation.
_val_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    _make_normalize(),
]
val_transform = transforms.Compose(_val_steps)
# 📁 使用内置数据集
from torchvision.datasets import MNIST, CIFAR10, ImageFolder
from torch.utils.data import random_split
# MNIST training split stored under ./data; download=True lets torchvision
# fetch the files when they are missing. Images are converted to tensors.
train_dataset = MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

# 80/20 train/validation split. The validation size is the remainder, so the
# two sizes always add up to the full dataset length.
train_size = int(len(train_dataset) * 0.8)
val_size = len(train_dataset) - train_size
train_set, val_set = random_split(train_dataset, lengths=[train_size, val_size])

# Shuffle only the training loader; validation keeps its order.
train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_set, batch_size=64, shuffle=False)