数据挖掘决策树python代码
时间: 2025-07-11 07:00:34 AIGC 浏览: 21
在数据挖掘中,决策树是一种常用的分类算法。它通过递归选择最优划分属性,并根据属性取值将数据集划分为若干子集,从而构建一棵具有决策功能的树结构。以下是一个使用 Python 和 Scikit-learn 实现决策树分类任务的完整示例,该示例基于鸢尾花(Iris)数据集[^1]。
### 使用 Scikit-learn 构建决策树分类器
#### 1. 导入所需库并加载数据集
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data # 特征
y = iris.target # 类别标签
feature_names = iris.feature_names # 特征名称
class_names = iris.target_names # 类别名称
# 划分训练集和测试集(6:4)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=30)
```
#### 2. 创建并训练决策树模型
```python
# 初始化决策树分类器(默认为CART算法,使用Gini不纯度作为划分标准)
clf = DecisionTreeClassifier(random_state=30)
# 训练模型
clf.fit(X_train, y_train)
```
#### 3. 进行预测与评估
```python
# 在测试集上进行预测
y_pred = clf.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率为:{accuracy:.2f}")
```
#### 4. 可视化决策树
```python
plt.figure(figsize=(12, 8))
tree.plot_tree(clf, feature_names=feature_names, class_names=class_names, filled=True)
plt.show()
```
#### 5. 导出树结构为文本格式(可选)
```python
text_representation = tree.export_text(clf, feature_names=feature_names)
print(text_representation)
```
#### 6. 参数调优(以`max_depth`为例)
```python
# 尝试不同最大深度
depth_range = range(1, 10)
accuracies = []
for depth in depth_range:
clf = DecisionTreeClassifier(max_depth=depth, random_state=30)
clf.fit(X_train, y_train)
y_pred = clf.predict(X_test)
accuracies.append(accuracy_score(y_test, y_pred))
# 绘制准确率随深度变化曲线
plt.plot(depth_range, accuracies, marker='o')
plt.xlabel('Max Depth of Tree')
plt.ylabel('Accuracy')
plt.title('Decision Tree Accuracy vs Max Depth')
plt.grid(True)
plt.show()
```
---
###
阅读全文
相关推荐




















