⚙️ 优化器选择
import torch.optim as optim
# Alternative optimizer configurations for `model` (which must already be
# defined). Each assignment rebinds `optimizer`, so if these lines are run top
# to bottom only the last one (RMSprop) stays in effect; keep the one you need.

# Stochastic gradient descent with momentum.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Adam with explicitly spelled-out beta coefficients.
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
# AdamW, configured here with a weight-decay coefficient of 0.01.
optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
# RMSprop with its remaining hyper-parameters left at library defaults.
optimizer = optim.RMSprop(model.parameters(), lr=0.01)

# Learning-rate schedulers wrap whichever optimizer is currently bound to
# `optimizer`. As with the optimizers above, the second assignment replaces
# the first, so pick one.

# Step decay: scale the learning rate by gamma=0.1 every 10 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
# Cosine annealing with T_max=100 scheduler steps.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
💻 完整训练循环
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def train_model(model, train_loader, val_loader, epochs, device, *,
                lr=0.001, save_path='best_model.pth', show_progress=True):
    """Train a classifier and checkpoint the best-validating weights.

    Each epoch runs one optimisation pass over ``train_loader`` (Adam with
    cross-entropy loss), steps a ``StepLR`` schedule once, evaluates on
    ``val_loader`` under ``torch.no_grad()``, and prints loss/accuracy for
    both splits. Whenever validation accuracy strictly improves on the best
    seen so far, ``model.state_dict()`` is written to ``save_path``.

    Args:
        model: Module to train; it is moved to ``device``.
        train_loader: Iterable of ``(data, targets)`` training batches.
        val_loader: Iterable of ``(data, targets)`` validation batches.
        epochs: Number of passes over ``train_loader``.
        device: Device that the model and every batch are moved to.
        lr: Learning rate handed to Adam.
        save_path: Destination file for the best ``state_dict``.
        show_progress: Wrap the training batches in a ``tqdm`` progress bar.

    Returns:
        The model as it stands after the final epoch. These are not
        necessarily the best weights; those are only stored at ``save_path``.
    """
    # Move the model first so the optimizer is constructed over parameters
    # that already live on the target device.
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

    best_val_acc = 0.0
    for epoch in range(epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0
        train_batches = 0

        batches = train_loader
        if show_progress:
            batches = tqdm(train_loader, desc=f'Epoch {epoch+1}')

        for data, targets in batches:
            data, targets = data.to(device), targets.to(device)

            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            # Batches are counted here instead of calling len(train_loader),
            # so loaders without __len__ work and an empty loader cannot
            # cause a division by zero below.
            train_batches += 1
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

        # The schedule advances once per epoch, not once per batch.
        scheduler.step()

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_correct = 0
        val_total = 0
        val_batches = 0
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(device), targets.to(device)
                outputs = model(data)
                loss = criterion(outputs, targets)

                val_batches += 1
                val_loss += loss.item()
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()

        print(f'Epoch [{epoch+1}/{epochs}]')
        print(f'Train Loss: {_ratio(train_loss, train_batches):.4f}, '
              f'Train Acc: {_ratio(100.*train_correct, train_total):.2f}%')
        print(f'Val Loss: {_ratio(val_loss, val_batches):.4f}, '
              f'Val Acc: {_ratio(100.*val_correct, val_total):.2f}%')

        # An empty validation set yields 0.0 here and therefore never
        # triggers a checkpoint write.
        val_acc = _ratio(val_correct, val_total)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), save_path)

    print(f'Best Validation Accuracy: {best_val_acc*100:.2f}%')
    return model
💾 模型保存与加载
# Option 1 (recommended): persist only the learned parameters. Loading needs
# a `model` instance of the same architecture to already exist.
torch.save(model.state_dict(), 'model_weights.pth')
model.load_state_dict(torch.load('model_weights.pth'))
# Switch to evaluation mode before running inference with the restored weights.
model.eval()

# Option 2: pickle the entire module object. The file is tied to the exact
# class definitions and module paths in use when it was saved.
torch.save(model, 'complete_model.pth')
# SECURITY: a whole pickled module cannot be read with weights_only=True
# (the default in recent PyTorch releases), so the full unpickler must be
# requested explicitly. It can execute arbitrary code, so only load files you
# created yourself or otherwise fully trust.
model = torch.load('complete_model.pth', weights_only=False)
# Save a resumable training checkpoint: besides the model weights it carries
# the optimizer state plus the current epoch and loss. `epoch`, `model`,
# `optimizer` and `loss` are taken from the surrounding training code.
checkpoint = {
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
}
torch.save(checkpoint, 'checkpoint.pth')

# Resume: read the checkpoint back from disk (rather than reusing the dict
# that is still in memory) and restore model and optimizer state from it.
checkpoint = torch.load('checkpoint.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# NOTE(review): whether training should continue at start_epoch or at
# start_epoch + 1 depends on when in the epoch the checkpoint was written.
start_epoch = checkpoint['epoch']