网站首页 > 技术文章 正文
上一节我们讲到了决策树这个算法,但是一棵决策树可能会存在过拟合的现象,而且对数据微小的变化也比较敏感,为了解决这些问题,我们可以通过多棵树的方式,也就是今天要介绍的随机森林。
随机森林算法也就是Random Forest,它是一种集成学习算法,所谓集成学习,就是通过多个能力比较弱的机器学习模型的预测结果来得到一个更好的预测结果。随机森林的这一种方式叫做Bagging,它是Bootstrap Aggregate的简写,就是通过自助采样(bootstrap sampling)生成多个训练子集,分别训练后把结果聚合起来。
我们通过一个简单的例子来说一下为什么多个组合效果会好,比如我有四棵决策树,预测准确性都在90%左右,这四棵预测的结果为0、1、0、0,由于有三棵树的预测结果都是0,那么结果大概率就是0。注意这里有个前提,那就是每棵决策树的准确率都还可以,如果单棵树的准确率不足50%,那么还是要先去提升单棵树的准确性。
这里我们直接来看一个例子吧,代码如下:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.tree import plot_tree, DecisionTreeClassifier
from matplotlib.colors import ListedColormap
# 生成自定义分类数据
X, y = make_classification(
n_samples=300, # 样本数量
n_features=2, # 特征数量
n_informative=2, # 有效特征数量
n_redundant=0, # 冗余特征数量
n_clusters_per_class=1, # 每个类别的簇数
random_state=42,
flip_y=0.1 # 类别标签翻转比例,增加噪声
)
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 创建随机森林分类器
rf = RandomForestClassifier(n_estimators=50, random_state=42)
# 使用随机森林训练模型
rf.fit(X_train, y_train)
# 使用随机森林预测测试集
y_pred = rf.predict(X_test)
# 计算随机森林的准确率
rf_accuracy = accuracy_score(y_test, y_pred)
print(f"随机森林模型的准确率: {rf_accuracy:.2f}")
# 创建决策树分类器
clf = DecisionTreeClassifier(random_state=42)
# 使用决策树训练模型
clf.fit(X_train, y_train)
# 使用决策树预测测试集
y_pred = clf.predict(X_test)
# 计算决策树的准确率
clf_accuracy = accuracy_score(y_test, y_pred)
print(f"单棵决策树的准确率: {clf_accuracy:.2f}")
# 定义绘制决策边界的函数
def plot_decision_boundary(model, X, y, ax):
h = 0.02 # 网格步长
cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA'])
cmap_bold = ListedColormap(['#FF0000', '#00FF00'])
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
np.arange(y_min, y_max, h))
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)
ax.pcolormesh(xx, yy, Z, cmap=cmap_light)
ax.scatter(X[:, 0], X[:, 1], c=y, cmap=cmap_bold, edgecolor='k', s=20)
ax.set_xlim(xx.min(), xx.max())
ax.set_ylim(yy.min(), yy.max())
ax.set_title("Random Forest Decision Boundary")
# 创建画布
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
# 绘制决策边界
plot_decision_boundary(rf, X, y, axes[0])
# 可视化随机森林中的第一棵树
plot_tree(rf.estimators_[0],
feature_names=['Feature 1', 'Feature 2'],
class_names=['Class 0', 'Class 1'],
filled=True,
rounded=True,
ax=axes[1])
axes[1].set_title("Random Forest First Decision Tree")
plt.tight_layout()
plt.show()
上面代码虽然比较多,但是做的内容其实就四件事:
第一步就是生成随机数据集并且切分为训练集和测试集,这个数据集有300个样本,共计包含2个特征,都是有效特征。
第二步就是通过随机森林分类器来进行训练并且进行预测,我们这次的随机森林包括了50棵决策树,也就是n_estimators=50这个参数的含义,然后打印出准确率。
第三步就是通过单棵决策树进行训练并且进行预测,然后打印出准确率。
第四步就是可视化,包括绘制了第一棵决策树和决策边界,这里由于决策树过多,就不一一列举了。
来看一下两者的准确率情况,如下所示:
可以看到,单棵决策树的准确率是88%,随机森林的准确性为92%,提升了4%的准确率,然后再看一下这个分类情况,如下:
对于这里的第一棵决策树,也是相对有点复杂的,这里也来个截图看一下,如下所示:
可以看到,相对于使用单棵决策树来说,我们只改动了少数几行代码,就可以提升准确率,不过需要注意的是,随机森林相对于单棵决策树来说,消耗的资源也会增加一些。
- 上一篇: 用Python进行机器学习(8)分类任务的模型评估
- 下一篇: Scikit-Learn最新更新简介
猜你喜欢
- 2025-05-14 专访庾恩利:不完美才是完美
- 2025-05-14 实例 | 教你用python写一个电信客户流失预测模型
- 2025-05-14 分析5万多场英雄联盟比赛,教你如何轻松用python预测胜负
- 2025-05-14 python决策树用于分类和回归问题实际应用案例
- 2025-05-14 梯度提升算法决策过程的逐步可视化
- 2025-05-14 为机器学习模型设置最佳阈值:0.5是二元分类的最佳阈值吗
- 2025-05-14 python创建分类器小结
- 2025-05-14 一文带您了解随机森林分类和回归模型:Python示例
- 2025-05-14 pybaobabdt,一个超强的 Python 库!
- 2025-05-14 Scikit-Learn最新更新简介
- 1507℃桌面软件开发新体验!用 Blazor Hybrid 打造简洁高效的视频处理工具
- 518℃Dify工具使用全场景:dify-sandbox沙盒的原理(源码篇·第2期)
- 489℃MySQL service启动脚本浅析(r12笔记第59天)
- 468℃服务器异常重启,导致mysql启动失败,问题解决过程记录
- 466℃启用MySQL查询缓存(mysql8.0查询缓存)
- 446℃「赵强老师」MySQL的闪回(赵强iso是哪个大学毕业的)
- 426℃mysql服务怎么启动和关闭?(mysql服务怎么启动和关闭)
- 423℃MySQL server PID file could not be found!失败
- 最近发表
- 标签列表
-
- c++中::是什么意思 (83)
- 标签用于 (65)
- 主键只能有一个吗 (66)
- c#console.writeline不显示 (75)
- pythoncase语句 (81)
- es6includes (73)
- windowsscripthost (67)
- apt-getinstall-y (86)
- node_modules怎么生成 (76)
- chromepost (65)
- c++int转char (75)
- static函数和普通函数 (76)
- el-date-picker开始日期早于结束日期 (70)
- js判断是否是json字符串 (67)
- checkout-b (67)
- localstorage.removeitem (74)
- vector线程安全吗 (70)
- & (66)
- java (73)
- js数组插入 (83)
- linux删除一个文件夹 (65)
- mac安装java (72)
- eacces (67)
- 查看mysql是否启动 (70)
- 无效的列索引 (74)