Spark MLlib决策树实战:基于KDD Cup 1999的网络攻击检测
决策树算法简介
决策树是机器学习中非常流行的算法,主要优势包括:
- 模型可解释性强,决策过程直观可见
- 天然支持类别型特征处理
- 无需特征缩放
- 能够自动捕捉特征间的非线性关系和交互作用
- 适用于多分类问题
在网络安全领域,决策树特别适合用于网络攻击检测这类分类任务。本文将使用Spark MLlib中的决策树算法,基于著名的KDD Cup 1999数据集构建网络攻击检测模型。
环境准备与数据加载
我们使用以下Spark集群配置:
- 1个主节点和7个工作节点
- 每个节点8GB内存(实际使用6GB)
- 双核2.4GHz Intel处理器
- Spark 1.3.1版本
首先下载并加载完整版的KDD Cup 1999数据集,该数据集包含近500万条网络交互记录:
import urllib
# 下载训练数据
urllib.urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data.gz", "kddcup.data.gz")
# 加载为RDD
data_file = "./kddcup.data.gz"
raw_data = sc.textFile(data_file)
print "训练数据量: {}".format(raw_data.count()) # 输出: 4898431条记录
# 下载并加载测试数据
urllib.urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/corrected.gz", "corrected.gz")
test_data_file = "./corrected.gz"
test_raw_data = sc.textFile(test_data_file)
print "测试数据量: {}".format(test_raw_data.count()) # 输出: 311029条记录
数据预处理
决策树可以直接处理类别型特征,但需要先将它们转换为数值型索引。我们首先提取所有可能的类别值:
from pyspark.mllib.regression import LabeledPoint
from numpy import array
# 解析CSV数据
csv_data = raw_data.map(lambda x: x.split(","))
test_csv_data = test_raw_data.map(lambda x: x.split(","))
# 获取所有类别型特征的可能取值
protocols = csv_data.map(lambda x: x[1]).distinct().collect() # 协议类型
services = csv_data.map(lambda x: x[2]).distinct().collect() # 服务类型
flags = csv_data.map(lambda x: x[3]).distinct().collect() # 连接状态标志
然后定义转换函数,将原始数据转换为LabeledPoint格式,这是MLlib要求的输入格式:
def create_labeled_point(line_split):
clean_line_split = line_split[0:41]
# 转换协议类型为数值索引
try:
clean_line_split[1] = protocols.index(clean_line_split[1])
except:
clean_line_split[1] = len(protocols) # 未知类别设为特殊值
# 转换服务类型为数值索引
try:
clean_line_split[2] = services.index(clean_line_split[2])
except:
clean_line_split[2] = len(services)
# 转换连接标志为数值索引
try:
clean_line_split[3] = flags.index(clean_line_split[3])
except:
clean_line_split[3] = len(flags)
# 将标签转换为二元标签:0表示正常,1表示攻击
attack = 1.0 if line_split[41] != 'normal.' else 0.0
return LabeledPoint(attack, array([float(x) for x in clean_line_split]))
# 应用转换函数
training_data = csv_data.map(create_labeled_point)
test_data = test_csv_data.map(create_labeled_point)
训练决策树模型
使用MLlib的DecisionTree训练分类器,关键参数包括:
numClasses
: 类别数(这里是二元分类)categoricalFeaturesInfo
: 指定哪些特征是类别型及其基数impurity
: 不纯度度量(基尼系数或信息增益)maxDepth
: 树的最大深度maxBins
: 连续特征离散化的最大分箱数
from pyspark.mllib.tree import DecisionTree
from time import time
# 训练模型
t0 = time()
tree_model = DecisionTree.trainClassifier(
training_data,
numClasses=2,
categoricalFeaturesInfo={1: len(protocols), 2: len(services), 3: len(flags)},
impurity='gini',
maxDepth=4,
maxBins=100
)
tt = time() - t0
print "模型训练耗时: {}秒".format(round(tt,3)) # 输出约440秒
模型评估
在测试集上评估模型性能:
# 预测测试数据
predictions = tree_model.predict(test_data.map(lambda p: p.features))
labels_and_preds = test_data.map(lambda p: p.label).zip(predictions)
# 计算准确率
t0 = time()
test_accuracy = labels_and_preds.filter(lambda (v, p): v == p).count() / float(test_data.count())
tt = time() - t0
print "预测耗时: {}秒,测试准确率: {}".format(round(tt,3), round(test_accuracy,4))
# 输出: 预测耗时38.651秒,准确率91.55%
模型解释
决策树最大的优势在于其可解释性。我们可以打印出树的决策规则:
print "学习到的决策树模型:"
print tree_model.toDebugString()
输出结果展示了完整的决策路径。例如,以下特征的网络交互会被分类为攻击:
count
(过去2秒内与同一主机的连接数)> 32dst_bytes
(目标到源的数据字节数)= 0service
既不是"urp_i"也不是"tftp_u"logged_in
=False(未登录状态)
通过分析树结构,我们发现最重要的三个特征是:
count
(连接计数)flag
(连接状态标志)dst_bytes
(目标字节数)
构建精简模型
基于上述发现,我们可以仅使用这三个关键特征构建一个更简单的模型:
# 选择关键特征:count(22), dst_bytes(32), flag(3)
def create_minimal_labeled_point(line_split):
selected_features = [float(line_split[22]), float(line_split[32]), float(line_split[3])]
# 转换flag为数值索引
try:
selected_features[2] = flags.index(line_split[3])
except:
selected_features[2] = len(flags)
attack = 1.0 if line_split[41] != 'normal.' else 0.0
return LabeledPoint(attack, array(selected_features))
# 准备精简数据集
minimal_training_data = csv_data.map(create_minimal_labeled_point)
minimal_test_data = test_csv_data.map(create_minimal_labeled_point)
# 训练精简模型
minimal_tree_model = DecisionTree.trainClassifier(
minimal_training_data,
numClasses=2,
categoricalFeaturesInfo={2: len(flags)}, # 只有flag是类别型
impurity='gini',
maxDepth=4,
maxBins=100
)
# 评估精简模型
minimal_predictions = minimal_tree_model.predict(minimal_test_data.map(lambda p: p.features))
minimal_labels_and_preds = minimal_test_data.map(lambda p: p.label).zip(minimal_predictions)
minimal_test_accuracy = minimal_labels_and_preds.filter(lambda (v, p): v == p).count() / float(minimal_test_data.count())
print "精简模型准确率: {}".format(round(minimal_test_accuracy,4))
精简模型虽然可能损失一些准确率,但具有以下优势:
- 训练和预测速度更快
- 模型更简单,更容易解释和维护
- 减少了特征工程的工作量
总结
本文通过Spark MLlib实现了一个基于决策树的网络攻击检测系统,主要步骤包括:
- 数据加载与预处理
- 类别型特征转换
- 决策树模型训练
- 模型评估与解释
- 关键特征提取与精简模型构建
决策树模型不仅提供了良好的分类性能(91.55%准确率),还能直观地展示攻击检测的决策逻辑,这对网络安全分析人员非常有价值。通过分析树结构,我们可以识别出最关键的检测特征,进而构建更高效的轻量级模型。
在实际应用中,可以进一步优化:
- 调整maxDepth等参数提升准确率
- 尝试集成方法如随机森林
- 对类别型特征进行更精细的编码
- 处理类别不平衡问题
创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考