# 🏗️ nn.Module基础
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_size: Number of input features (size of the last dim of ``x``).
        hidden_size: Width of the hidden layer.
        num_classes: Number of output scores produced per sample.
    """

    def __init__(self, input_size: int, hidden_size: int, num_classes: int):
        # Zero-argument super() is the Python 3 idiom; it is equivalent to
        # super(SimpleNet, self) but does not break if the class is renamed.
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class scores for ``x`` (no softmax is applied).

        nn.Linear acts on the last dimension, so ``x`` should have
        ``input_size`` as its last dimension; the result keeps the leading
        dimensions and has ``num_classes`` as its last dimension.
        """
        out = self.fc1(x)
        out = self.relu(out)
        return self.fc2(out)
# Network with 784 input features, 128 hidden units and 10 output scores.
model = SimpleNet(num_classes=10, hidden_size=128, input_size=784)
# nn.Module's repr lists every registered submodule.
print(model)
# 🧱 使用nn.Sequential构建模型
# The same kind of MLP expressed as a container: nn.Sequential runs its
# layers in the order given and names them "0", "1", ... automatically.
model = nn.Sequential(
    *[
        nn.Linear(784, 256),
        nn.ReLU(),
        nn.Dropout(0.2),  # zeroes each activation with probability 0.2 in train mode
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Linear(128, 10),
    ]
)
from collections import OrderedDict

# Passing a single OrderedDict gives every layer an explicit name, so layers
# can be reached as attributes (model.fc1) as well as by index (model[0]).
# OrderedDict keeps keyword-argument order (PEP 468), so the layer order
# below is the execution order.
model = nn.Sequential(
    OrderedDict(
        fc1=nn.Linear(784, 256),
        relu1=nn.ReLU(),
        dropout=nn.Dropout(0.2),
        fc2=nn.Linear(256, 10),
    )
)
print(model.fc1)
# 🔧 模型参数管理
# List every learnable tensor under its registered name (e.g. "fc1.weight").
for name, param in model.named_parameters():
    print(f'{name}: {param.shape}')

# len() here counts parameter *tensors* (each weight and bias is one entry),
# not scalar values -- the label says so, to avoid confusion with the
# element-wise total printed below.
params = list(model.parameters())
print(f'参数张量数量: {len(params)}')

# numel() is the number of scalar entries in a tensor; summing it gives the
# total parameter count, optionally restricted to tensors that require grad.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'总参数: {total_params}, 可训练参数: {trainable_params}')

# Freeze the whole model so autograd stops tracking gradients for these
# parameters (model.requires_grad_(False) is an equivalent one-liner).
for param in model.parameters():
    param.requires_grad = False