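🎯 Learning Objectives

  • Master methods for training models
  • Learn how to use cross-validation
  • Understand hyperparameter optimization
  • Be able to train ML models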
Model Training

Training Machine Learning Models

This section introduces methods and techniques for training ML models.

⚙️ Model Training

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import accuracy_score, roc_auc_score

class MLModelTrainer:
    """
    Model trainer: fits, tunes, and predicts with a RandomForestClassifier.
    """

    def __init__(self):
        self.model = None
        self.best_params = None

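    def train_model(self, X, y):
        """
        Train a random forest and report training and validation accuracy.
        """
        # Split into training and validation sets. shuffle=False keeps the
        # original row order, so the last 20% of rows form the validation set.
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=0.2, shuffle=False
        )

        # Define the model
        model = RandomForestClassifier(
            n_estimators=100,
            max_depth=10,
            random_state=42
        )

        # Train
        model.fit(X_train, y_train)

        # Evaluate on both splits; a large gap suggests overfitting
        train_pred = model.predict(X_train)
        val_pred = model.predict(X_val)

        train_acc = accuracy_score(y_train, train_pred)
        val_acc = accuracy_score(y_val, val_pred)

        print(f"Training accuracy: {train_acc:.4f}")
        print(f"Validation accuracy: {val_acc:.4f}")

        self.model = model
        return model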

    def hyperparameter_tuning(self, X, y):
        """
        Tune hyperparameters with a grid search and 5-fold cross-validation.
        """
        # Candidate values; every combination (3 x 3 x 3 = 27) is evaluated
        param_grid = {
            'n_estimators': [50, 100, 200],
            'max_depth': [5, 10, 15],
            'min_samples_split': [2, 5, 10]
        }

        rf = RandomForestClassifier(random_state=42)
        grid_search = GridSearchCV(
            rf, param_grid, cv=5,
            scoring='accuracy', n_jobs=-1
        )

        grid_search.fit(X, y)

        # Keep the best combination and the model refit with it on all of X, y
        self.best_params = grid_search.best_params_
        self.model = grid_search.best_estimator_

        print(f"Best parameters: {self.best_params}")
        return grid_search.best_estimator_

    def predict(self, X):
        """
        Predict class labels with the trained model.
        """
        return self.model.predict(X)

    def predict_proba(self, X):
        """
        Predict class probabilities with the trained model.
        """
        return self.model.predict_proba(X)
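
The train_model method above scores the model on a single hold-out split, and cross-validation only appears implicitly through cv=5 in the grid search. As a minimal sketch of cross-validation on its own, the snippet below scores a random forest over five folds with cross_val_score. The synthetic dataset from make_classification, the choice of StratifiedKFold, and all parameter values are assumptions made so the example runs standalone; they are not part of the original lesson.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic data, used only so the sketch runs on its own (an assumption).
X, y = make_classification(
    n_samples=300, n_features=8, n_informative=4, random_state=42
)

model = RandomForestClassifier(n_estimators=50, max_depth=10, random_state=42)

# Five stratified folds: each fold keeps roughly the same class balance.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# The model is refit five times, each time scored on the held-out fold.
scores = cross_val_score(model, X, y, cv=cv, scoring="accuracy")

print(f"Fold accuracies: {scores.round(4)}")
print(f"Mean accuracy: {scores.mean():.4f} (std {scores.std():.4f})")
```

The mean and spread across folds give a steadier estimate than one split. If the rows are time-ordered, which the shuffle=False in train_model may suggest, shuffled folds can leak future information, and TimeSeriesSplit from sklearn.model_selection would likely be the safer splitter.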
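Key Points for Training

1) Guard against overfitting by comparing training and validation accuracy; 2) use cross-validation for a more stable performance estimate; 3) check feature importance to see which inputs drive the predictions; 4) keep monitoring the model after it is put to use.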
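
Two of these points, overfitting and feature importance, can be checked in a few lines. The sketch below is one possible way to do it, not the lesson's own method: it measures the gap between training and validation accuracy, then ranks features by the forest's impurity-based feature_importances_. The synthetic data and the feature_0 ... feature_5 names are hypothetical placeholders.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic data and feature names are assumptions for illustration only.
X, y = make_classification(
    n_samples=400, n_features=6, n_informative=3, random_state=0
)
feature_names = [f"feature_{i}" for i in range(X.shape[1])]

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, random_state=0
)

model = RandomForestClassifier(n_estimators=100, max_depth=10, random_state=0)
model.fit(X_train, y_train)

# A large gap between training and validation accuracy hints at overfitting.
train_acc = accuracy_score(y_train, model.predict(X_train))
val_acc = accuracy_score(y_val, model.predict(X_val))
gap = train_acc - val_acc
print(f"train={train_acc:.4f} val={val_acc:.4f} gap={gap:.4f}")

# Impurity-based importances, sorted from most to least important.
importances = model.feature_importances_
order = np.argsort(importances)[::-1]
for idx in order:
    print(f"{feature_names[idx]}: {importances[idx]:.4f}")
```

If the gap is large, lowering max_depth or raising min_samples_split are common first adjustments. Impurity-based importances can favor high-cardinality features, so they are best read as a rough ranking rather than a precise measurement.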

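📝 Section Summary

  • Covered methods for training models
  • Learned how to use cross-validation
  • Understood hyperparameter optimization
  • Able to train ML models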