Spark MLlib决策树实战:基于KDD Cup 1999的网络攻击检测

Spark MLlib决策树实战:基于KDD Cup 1999的网络攻击检测

spark-py-notebooks Apache Spark & Python (pySpark) tutorials for Big Data Analysis and Machine Learning as IPython / Jupyter notebooks spark-py-notebooks 项目地址: https://gitcode.com/gh_mirrors/sp/spark-py-notebooks

决策树算法简介

决策树是机器学习中非常流行的算法,主要优势包括:

  • 模型可解释性强,决策过程直观可见
  • 天然支持类别型特征处理
  • 无需特征缩放
  • 能够自动捕捉特征间的非线性关系和交互作用
  • 适用于多分类问题

在网络安全领域,决策树特别适合用于网络攻击检测这类分类任务。本文将使用Spark MLlib中的决策树算法,基于著名的KDD Cup 1999数据集构建网络攻击检测模型。

环境准备与数据加载

我们使用以下Spark集群配置:

  • 1个主节点和7个工作节点
  • 每个节点8GB内存(实际使用6GB)
  • 双核2.4GHz Intel处理器
  • Spark 1.3.1版本

首先下载并加载完整版的KDD Cup 1999数据集,该数据集包含近500万条网络交互记录:

import urllib

# 下载训练数据
urllib.urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/kddcup.data.gz", "kddcup.data.gz")

# 加载为RDD
data_file = "./kddcup.data.gz"
raw_data = sc.textFile(data_file)
print "训练数据量: {}".format(raw_data.count())  # 输出: 4898431条记录

# 下载并加载测试数据
urllib.urlretrieve("http://kdd.ics.uci.edu/databases/kddcup99/corrected.gz", "corrected.gz")
test_data_file = "./corrected.gz"
test_raw_data = sc.textFile(test_data_file)
print "测试数据量: {}".format(test_raw_data.count())  # 输出: 311029条记录

数据预处理

决策树可以直接处理类别型特征,但需要先将它们转换为数值型索引。我们首先提取所有可能的类别值:

from pyspark.mllib.regression import LabeledPoint
from numpy import array

# 解析CSV数据
csv_data = raw_data.map(lambda x: x.split(","))
test_csv_data = test_raw_data.map(lambda x: x.split(","))

# 获取所有类别型特征的可能取值
protocols = csv_data.map(lambda x: x[1]).distinct().collect()  # 协议类型
services = csv_data.map(lambda x: x[2]).distinct().collect()   # 服务类型
flags = csv_data.map(lambda x: x[3]).distinct().collect()      # 连接状态标志

然后定义转换函数,将原始数据转换为LabeledPoint格式,这是MLlib要求的输入格式:

def create_labeled_point(line_split):
    clean_line_split = line_split[0:41]
    
    # 转换协议类型为数值索引
    try: 
        clean_line_split[1] = protocols.index(clean_line_split[1])
    except:
        clean_line_split[1] = len(protocols)  # 未知类别设为特殊值
        
    # 转换服务类型为数值索引
    try:
        clean_line_split[2] = services.index(clean_line_split[2])
    except:
        clean_line_split[2] = len(services)
    
    # 转换连接标志为数值索引
    try:
        clean_line_split[3] = flags.index(clean_line_split[3])
    except:
        clean_line_split[3] = len(flags)
    
    # 将标签转换为二元标签:0表示正常,1表示攻击
    attack = 1.0 if line_split[41] != 'normal.' else 0.0
        
    return LabeledPoint(attack, array([float(x) for x in clean_line_split]))

# 应用转换函数
training_data = csv_data.map(create_labeled_point)
test_data = test_csv_data.map(create_labeled_point)

训练决策树模型

使用MLlib的DecisionTree训练分类器,关键参数包括:

  • numClasses: 类别数(这里是二元分类)
  • categoricalFeaturesInfo: 指定哪些特征是类别型及其基数
  • impurity: 不纯度度量(基尼系数或信息增益)
  • maxDepth: 树的最大深度
  • maxBins: 连续特征离散化的最大分箱数
from pyspark.mllib.tree import DecisionTree
from time import time

# 训练模型
t0 = time()
tree_model = DecisionTree.trainClassifier(
    training_data, 
    numClasses=2,
    categoricalFeaturesInfo={1: len(protocols), 2: len(services), 3: len(flags)},
    impurity='gini', 
    maxDepth=4, 
    maxBins=100
)
tt = time() - t0

print "模型训练耗时: {}秒".format(round(tt,3))  # 输出约440秒

模型评估

在测试集上评估模型性能:

# 预测测试数据
predictions = tree_model.predict(test_data.map(lambda p: p.features))
labels_and_preds = test_data.map(lambda p: p.label).zip(predictions)

# 计算准确率
t0 = time()
test_accuracy = labels_and_preds.filter(lambda (v, p): v == p).count() / float(test_data.count())
tt = time() - t0

print "预测耗时: {}秒,测试准确率: {}".format(round(tt,3), round(test_accuracy,4)) 
# 输出: 预测耗时38.651秒,准确率91.55%

模型解释

决策树最大的优势在于其可解释性。我们可以打印出树的决策规则:

print "学习到的决策树模型:"
print tree_model.toDebugString()

输出结果展示了完整的决策路径。例如,以下特征的网络交互会被分类为攻击:

  1. count(过去2秒内与同一主机的连接数)> 32
  2. dst_bytes(目标到源的数据字节数)= 0
  3. service既不是"urp_i"也不是"tftp_u"
  4. logged_in=False(未登录状态)

通过分析树结构,我们发现最重要的三个特征是:

  1. count(连接计数)
  2. flag(连接状态标志)
  3. dst_bytes(目标字节数)

构建精简模型

基于上述发现,我们可以仅使用这三个关键特征构建一个更简单的模型:

# 选择关键特征:count(22), dst_bytes(32), flag(3)
def create_minimal_labeled_point(line_split):
    selected_features = [float(line_split[22]), float(line_split[32]), float(line_split[3])]
    
    # 转换flag为数值索引
    try:
        selected_features[2] = flags.index(line_split[3])
    except:
        selected_features[2] = len(flags)
    
    attack = 1.0 if line_split[41] != 'normal.' else 0.0
    
    return LabeledPoint(attack, array(selected_features))

# 准备精简数据集
minimal_training_data = csv_data.map(create_minimal_labeled_point)
minimal_test_data = test_csv_data.map(create_minimal_labeled_point)

# 训练精简模型
minimal_tree_model = DecisionTree.trainClassifier(
    minimal_training_data,
    numClasses=2,
    categoricalFeaturesInfo={2: len(flags)},  # 只有flag是类别型
    impurity='gini',
    maxDepth=4,
    maxBins=100
)

# 评估精简模型
minimal_predictions = minimal_tree_model.predict(minimal_test_data.map(lambda p: p.features))
minimal_labels_and_preds = minimal_test_data.map(lambda p: p.label).zip(minimal_predictions)
minimal_test_accuracy = minimal_labels_and_preds.filter(lambda (v, p): v == p).count() / float(minimal_test_data.count())

print "精简模型准确率: {}".format(round(minimal_test_accuracy,4))

精简模型虽然可能损失一些准确率,但具有以下优势:

  1. 训练和预测速度更快
  2. 模型更简单,更容易解释和维护
  3. 减少了特征工程的工作量

总结

本文通过Spark MLlib实现了一个基于决策树的网络攻击检测系统,主要步骤包括:

  1. 数据加载与预处理
  2. 类别型特征转换
  3. 决策树模型训练
  4. 模型评估与解释
  5. 关键特征提取与精简模型构建

决策树模型不仅提供了良好的分类性能(91.55%准确率),还能直观地展示攻击检测的决策逻辑,这对网络安全分析人员非常有价值。通过分析树结构,我们可以识别出最关键的检测特征,进而构建更高效的轻量级模型。

在实际应用中,可以进一步优化:

  • 调整maxDepth等参数提升准确率
  • 尝试集成方法如随机森林
  • 对类别型特征进行更精细的编码
  • 处理类别不平衡问题

spark-py-notebooks Apache Spark & Python (pySpark) tutorials for Big Data Analysis and Machine Learning as IPython / Jupyter notebooks spark-py-notebooks 项目地址: https://gitcode.com/gh_mirrors/sp/spark-py-notebooks

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

计姗群

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值