🎯 学习目标

  • 理解Scikit-learn的API设计原则
  • 掌握Estimator、Transformer、Predictor接口
  • 学会使用fit/transform/predict方法
  • 了解一致接口的优势
API设计概念图

统一API设计原则

Scikit-learn采用一致的API设计,所有算法遵循相同的接口规范。 这使得学习新算法变得简单,也便于模型替换和组合。

🔧 核心接口

📊

Estimator

基础估计器,fit()方法训练模型

🔄

Transformer

数据变换器,transform()方法转换数据

🎯

Predictor

预测器,predict()方法进行预测

💻 基本用法模式

from sklearn.ensemble import RandomForestClassifier from sklearn.preprocessing import StandardScaler # 1. 创建估计器实例 model = RandomForestClassifier(n_estimators=100) # 2. 训练模型 (fit) model.fit(X_train, y_train) # 3. 预测 (predict) y_pred = model.predict(X_test) # 4. 评估 (score) accuracy = model.score(X_test, y_test) # Transformer模式 scaler = StandardScaler() scaler.fit(X_train) # 学习参数 X_scaled = scaler.transform(X_train) # 转换数据 # 或一步完成 X_scaled = scaler.fit_transform(X_train)

📋 核心方法一览

方法 说明 适用对象
fit(X, y) 训练模型/学习参数 所有Estimator
transform(X) 转换数据 Transformer
fit_transform(X) fit + transform Transformer
predict(X) 预测标签/值 Predictor
predict_proba(X) 预测概率 分类器
score(X, y) 评估得分 Predictor

⚙️ 超参数设置

# 所有超参数在构造函数中设置 model = RandomForestClassifier( n_estimators=100, max_depth=10, min_samples_split=5, random_state=42 ) # 获取参数 print(model.get_params()) # 设置参数(用于网格搜索) model.set_params(max_depth=15) # 访问学习到的参数(带下划线后缀) model.fit(X, y) print(model.feature_importances_) # 学习到的参数 print(model.n_estimators) # 超参数
💡
命名约定

超参数在构造时设置,不带下划线;学习参数在fit后可用,带下划线后缀(如 feature_importances_)。

API设计
图:Scikit-learn统一的API设计简化了机器学习开发

📝 本节小结

  • • Scikit-learn采用统一的API设计,便于学习和使用
  • • 核心接口:Estimator、Transformer、Predictor
  • • 统一方法:fit、transform、predict、score
  • • 超参数在构造时设置,学习参数带下划线后缀
  • • 一致的接口使得模型替换和组合变得简单