这是我的第417篇原创文章。
一、引言
XGBoost + 随机森林结合的优势:
1. 偏差–方差权衡更加优化
-
二者结合后,既保留了随机森林在抑制过拟合上的优势,又继承了 XGBoost 在拟合复杂模式时的高精度。
2. 鲁棒性更强
-
结合后,整体模型对数据分布变化和异常情况的适应性更强。
3. 特征多样性与互补性
-
二者互补,有助于捕捉数据中既“局部”又“全局”的模式。
这里,咱们通过房价预测案例,逐步如何使用XGBoost和随机森林各自训练模型,并通过线性回归融合技术提高整体预测性能。
二、实现过程
2.1 读取数据集
核心代码:
filename = 'data.csv'
names = ['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS',
'RAD', 'TAX', 'PRTATIO', 'B', 'LSTAT', 'MEDV']
dataset = pd.read_csv(filename, names=names, delim_whitespace=True)
print(dataset)
df = pd.DataFrame(dataset)
dataset:
2.2 划分数据集
核心代码:
features = names[:-1]
target = ['MEDV']
X_train, X_test, y_train, y_test = train_test_split(df[features], df[target], test_size=0.2, random_state=0)
2.3 训练XGBoost模型
核心代码:
xgb_model = XGBRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
xgb_model.fit(X_train, y_train)
xgb_preds = xgb_model.predict(X_test)
2.4 训练随机森林模型
核心代码:
rf_model = RandomForestRegressor(n_estimators=100, max_depth=5, random_state=42)
rf_model.fit(X_train, y_train)
rf_preds = rf_model.predict(X_test)
2.5 创建融合模型输入
核心代码:
blend_X = ...
blend_model = LinearRegression()
blend_model.fit(blend_X, y_test)
blend_preds = blend_model.predict(blend_X)
2.6 评估指标
核心代码:
xgb_mse = mean_squared_error(y_test, xgb_preds)
rf_mse = mean_squared_error(y_test, rf_preds)
blend_mse = mean_squared_error(y_test, blend_preds)
xgb_r2 = r2_score(y_test, xgb_preds)
rf_r2 = r2_score(y_test, rf_preds)
blend_r2 = r2_score(y_test, blend_preds)
2.7 可视化分析
绘图1:真实值vs预测值(融合模型)
plt.figure(figsize=(8, 6))
plt.scatter(y_test, blend_preds, c='darkorange', label="Blended Prediction", alpha=0.6)
plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2, label="Ideal")
plt.xlabel("True Values")
plt.ylabel("Predicted Values")
plt.title("图 1:融合模型预测结果 vs 实际值")
plt.legend()
plt.tight_layout()
plt.show()
结果:
-
每个点代表一个样本的真实值与融合预测值。
-
越靠近对角线表示预测越准确。
-
该图说明融合模型基本贴合真实值,预测精度良好。
绘图2:MSE对比条形图
plt.figure(figsize=(8, 6))
mse_vals = [xgb_mse, rf_mse, blend_mse]
r2_vals = [xgb_r2, rf_r2, blend_r2]
models = ['XGBoost', 'Random Forest', 'Blended']
colors = ['#1f77b4', '#2ca02c', '#d62728']
sns.barplot(x=models, y=mse_vals, palette=colors)
plt.ylabel("MSE")
plt.title("图 2:各模型的均方误差(MSE)对比")
plt.tight_layout()
plt.show()
结果:
-
蓝色:XGBoost,绿色:随机森林,红色:融合模型。
-
融合模型的 MSE 最低,说明其误差比单个模型都低。
-
可见融合策略有效降低了预测误差。
绘图3:R²分数对比
plt.figure(figsize=(8, 6))
sns.barplot(x=models, y=r2_vals, palette=colors)
plt.ylabel("R² Score")
plt.title("图 3:各模型的R^2分数对比")
plt.tight_layout()
plt.show()
结果:
-
R²越接近 1 越好,表示模型解释方差能力越强。
-
融合模型明显高于单一模型,增强了解释力。
绘图4:XGBoost特征重要性图
xgb_feat_imp = pd.Series(xgb_model.feature_importances_, index=features).sort_values(ascending=False)
plt.figure(figsize=(10, 6))
xgb_feat_imp.head(10).plot(kind='bar', color='dodgerblue')
plt.title("图 4:XGBoost 模型前10重要特征")
plt.ylabel("Importance Score")
plt.tight_layout()
plt.show()
结果:
-
对 XGBoost 最重要的前10个特征。
-
通过这张图我们可以了解模型关注了哪些特征,对特征选择与后续特征工程有帮助。
三、小结
方法 |
优点 |
缺点 |
---|---|---|
XGBoost |
收敛快、支持正则化、特征选择自动 |
参数多、调参复杂 |
随机森林 |
抗噪声强、不易过拟合、易于实现 |
对异常值敏感、不能很好捕捉复杂非线性关系 |
融合模型 |
折中多个模型的缺陷、泛化能力更强 |
需要更多训练和计算资源 |
融合模型通过线性回归的方式利用了 XGBoost 和 RF 的预测结果,显著提升了精度,并且在各项指标(MSE、R²)上均有优势。
作者简介:
读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。