🎯 学习目标

  • 理解类与对象的概念
  • 掌握继承与多态
  • 学会设计可复用的类
  • 了解PyTorch中的OOP模式
面向对象
图:面向对象是组织复杂代码的有效方式

🏗️ 类的定义

class NeuralNetwork: """神经网络基类""" # 类属性 version = "1.0" def __init__(self, input_size, hidden_size, output_size): # 实例属性 self.input_size = input_size self.hidden_size = hidden_size self.output_size = output_size self.weights = None def forward(self, x): """前向传播""" raise NotImplementedError def __repr__(self): return f"NeuralNetwork(input={self.input_size}, hidden={self.hidden_size})" def __len__(self): return self.input_size + self.hidden_size + self.output_size # 使用 model = NeuralNetwork(784, 256, 10) print(model) # NeuralNetwork(input=784, hidden=256) print(len(model)) # 1030

🔄 继承与多态

继承示例

import torch import torch.nn as nn # 继承PyTorch的Module类 class MLP(nn.Module): def __init__(self, input_size, hidden_size, output_size): super().__init__() # 调用父类初始化 self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, output_size) self.relu = nn.ReLU() def forward(self, x): x = self.fc1(x) x = self.relu(x) x = self.fc2(x) return x # 多继承 class Mixin: def save(self, path): torch.save(self.state_dict(), path) class MyModel(MLP, Mixin): pass model = MyModel(784, 256, 10) model.save("model.pt")

🎭 属性装饰器

class Dataset: def __init__(self, data): self._data = data self._size = None @property def size(self): """延迟计算大小""" if self._size is None: self._size = len(self._data) return self._size @size.setter def size(self, value): raise AttributeError("Cannot set size directly") @staticmethod def from_file(path): """静态方法:从文件创建数据集""" with open(path) as f: data = f.readlines() return Dataset(data) @classmethod def empty(cls): """类方法:创建空数据集""" return cls([]) # 使用 ds = Dataset([1, 2, 3, 4, 5]) print(ds.size) # 5 ds2 = Dataset.from_file("data.txt") ds3 = Dataset.empty()

📦 数据类(Dataclass)

Python 3.7+ 简化类定义

from dataclasses import dataclass, field from typing import List @dataclass class TrainingConfig: learning_rate: float = 0.001 batch_size: int = 32 epochs: int = 100 optimizer: str = "adam" layers: List[int] = field(default_factory=lambda: [256, 128]) def __post_init__(self): # 初始化后执行 if self.learning_rate <= 0: raise ValueError("learning_rate must be positive") # 使用 config = TrainingConfig(learning_rate=0.0001, epochs=50) print(config) # TrainingConfig(learning_rate=0.0001, batch_size=32, # epochs=50, optimizer='adam', layers=[256, 128])

📝 本节小结

  • • 类是创建对象的模板,封装数据和行为
  • • 继承实现代码复用,多态实现接口统一
  • • @property实现属性的getter/setter
  • • @dataclass简化数据类的定义