在这篇文章中,我将使用python中的决策树(用于分类)。最近我们被客户要求撰写关于决策树的研究报告,包括一些图形和统计输出。重点将放在基础知识和对最终决策树的理解上。
视频:从决策树到随机森林:R语言信用卡违约分析信贷数据实例
从决策树到随机森林:R语言信用卡违约分析信贷数据实例
,时长10:11
导入
因此,首先我们进行一些导入。
from __future__ import print_function
import os
import subprocess
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_graphviz
数据
接下来,我们需要考虑一些数据。我将使用著名的iris数据集,该数据集可对各种不同的iris类型进行各种测量。pandas和sckit-learn都可以轻松导入这些数据,我将使用pandas编写一个从csv文件导入的函数。这样做的目的是演示如何将scikit-learn与pandas一起使用。因此,我们定义了一个获取iris数据的函数:
def get_iris_data():
"""从本地csv或pandas中获取iris数据。"""
if os.path.exists("iris.csv"):
print("-- iris.csv found locally")
df = pd.read_csv("iris.csv", index_col=0)
else:
print("-- trying to download from github")
fn = "https://raw.githubusercontent.com/pydata/pandas/" + \
"master/pandas/tests/data/iris.csv"
try:
df = pd.read_csv(fn)
except:
exit("-- Unable to download iris.csv")
with open("iris.csv", 'w') as f:
print("-- writing to local iris.csv file")
df.to_csv(f)
return df
- 此函数首先尝试在本地读取数据。利用os.path.exists() 方法。如果在本地目录中找到iris.csv文件,则使用pandas通过pd.read_csv()读取文件。
- 如果本地iris.csv没有发现,抓取URL数据来运行。
下一步是获取数据,并使用head()和tail()方法查看数据的样子。因此,首先获取数据:
df = get_iris_data()
-- iris.csv found locally
然后 :
print("* df.head()", df.head(), sep="\n", end="\n\n")
print("* df.tail()", df.tail(), sep="\n", end="\n\n")
* df.head()
SepalLength SepalWidth PetalLength PetalWidth Name
0 5.1 3.5 1.4 0.2 Iris-setosa
1 4.9 3.0 1.4 0.2 Iris-setosa
2 4.7 3.2 1.3 0.2 Iris-setosa
3 4.6 3.1 1.5 0.2 Iris-setosa
4 5.0 3.6 1.4 0.2 Iris-setosa
* df.tail()
SepalLength SepalWidth PetalLength PetalWidth Name
145 6.7 3.0 5.2 2.3 Iris-virgin