- • 类是创建对象的模板,封装数据和行为
- • 继承实现代码复用,多态实现接口统一
- • @property实现属性的getter/setter
- • @dataclass简化数据类的定义
4.2 面向对象编程
构建可复用的AI组件
🎯 学习目标
- 理解类与对象的概念
- 掌握继承与多态
- 学会设计可复用的类
- 了解PyTorch中的OOP模式
图:面向对象是组织复杂代码的有效方式
🏗️ 类的定义
class NeuralNetwork:
"""神经网络基类"""
# 类属性
version = "1.0"
def __init__(self, input_size, hidden_size, output_size):
# 实例属性
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.weights = None
def forward(self, x):
"""前向传播"""
raise NotImplementedError
def __repr__(self):
return f"NeuralNetwork(input={self.input_size}, hidden={self.hidden_size})"
def __len__(self):
return self.input_size + self.hidden_size + self.output_size
# 使用
model = NeuralNetwork(784, 256, 10)
print(model) # NeuralNetwork(input=784, hidden=256)
print(len(model)) # 1030
🔄 继承与多态
继承示例
import torch
import torch.nn as nn
# 继承PyTorch的Module类
class MLP(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__() # 调用父类初始化
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
# 多继承
class Mixin:
def save(self, path):
torch.save(self.state_dict(), path)
class MyModel(MLP, Mixin):
pass
model = MyModel(784, 256, 10)
model.save("model.pt")
🎭 属性装饰器
class Dataset:
def __init__(self, data):
self._data = data
self._size = None
@property
def size(self):
"""延迟计算大小"""
if self._size is None:
self._size = len(self._data)
return self._size
@size.setter
def size(self, value):
raise AttributeError("Cannot set size directly")
@staticmethod
def from_file(path):
"""静态方法:从文件创建数据集"""
with open(path) as f:
data = f.readlines()
return Dataset(data)
@classmethod
def empty(cls):
"""类方法:创建空数据集"""
return cls([])
# 使用
ds = Dataset([1, 2, 3, 4, 5])
print(ds.size) # 5
ds2 = Dataset.from_file("data.txt")
ds3 = Dataset.empty()
📦 数据类(Dataclass)
Python 3.7+ 简化类定义
from dataclasses import dataclass, field
from typing import List
@dataclass
class TrainingConfig:
learning_rate: float = 0.001
batch_size: int = 32
epochs: int = 100
optimizer: str = "adam"
layers: List[int] = field(default_factory=lambda: [256, 128])
def __post_init__(self):
# 初始化后执行
if self.learning_rate <= 0:
raise ValueError("learning_rate must be positive")
# 使用
config = TrainingConfig(learning_rate=0.0001, epochs=50)
print(config)
# TrainingConfig(learning_rate=0.0001, batch_size=32,
# epochs=50, optimizer='adam', layers=[256, 128])
📝 本节小结
✅