💻 sklearn实现
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
# 创建决策树
model = DecisionTreeClassifier(
max_depth=5, # 最大深度
min_samples_split=10, # 分裂最小样本数
min_samples_leaf=5, # 叶节点最小样本数
criterion='gini', # 分裂标准: 'gini' 或 'entropy'
random_state=42
)
# 训练
model.fit(X_train, y_train)
# 预测
y_pred = model.predict(X_test)
# 可视化决策树
plt.figure(figsize=(20, 10))
plot_tree(model, filled=True, feature_names=feature_names)
plt.show()
# 特征重要性
print(model.feature_importances_)