一、朴素贝叶斯模型分类与核心原理
朴素贝叶斯算法的核心是基于 “特征条件独立性假设”,通过贝叶斯公式计算后验概率实现分类。根据特征数据类型的差异,衍生出三大经典模型,分别适用于不同场景,其核心区别在于对 “特征条件概率” 的计算方式不同。
1.1 多项式朴素贝叶斯(MultinomialNB)
适用场景
- 特征为离散型数据,尤其是文本分类(如统计单词出现次数、TF-IDF 值)、物品计数等场景。
- 典型案例:垃圾邮件分类(统计 “优惠”“中奖” 等关键词的出现频次)、新闻主题分类。
核心原理
- 假设特征服从多项式分布,即特征值代表 “事件发生的次数”(如单词在文本中出现的次数)。
- 计算条件概率时,需引入拉普拉斯平滑(通过
alpha
参数控制),避免因某些特征未出现导致概率为 0 的问题。
关键参数(sklearn 实现)
参数 | 类型 | 默认值 | 功能说明 |
---|---|---|---|
alpha | 浮点型 | 1.0 | 拉普拉斯平滑系数:alpha=0 表示不平滑,alpha>0 时,值越大平滑效果越强 |
fit_prior | 布尔型 | True | 是否使用先验概率:True 时基于数据计算先验,False 时假设所有类别先验概率相等 |
class_prior | 数组型 | None | 自定义类别先验概率:若指定,则忽略fit_prior 的设置 |
1.2 高斯朴素贝叶斯(GaussianNB)
适用场景
- 特征为连续型数据(如身高、体重、温度等数值型特征),无法通过 “计数” 描述的场景。
- 典型案例:鸢尾花品种分类(花瓣长度、宽度为连续值)、房价区间预测(面积、楼层为连续值)。
核心原理
- 假设特征服从正态分布(高斯分布),通过计算样本中每个类别下特征的均值和标准差,构建正态分布概率密度函数,进而求解条件概率。
- 无需手动设置平滑参数,模型会自动通过极大似然法估计正态分布的参数(均值、方差)。
关键参数(sklearn 实现)
参数 | 类型 | 默认值 | 功能说明 |
---|---|---|---|
priors | 数组型 | None | 自定义类别先验概率:若为 None,模型通过样本数据自动计算(极大似然法) |
1.3 伯努利朴素贝叶斯(BernoulliNB)
适用场景
- 特征为二值离散型数据(仅取值 0 或 1),即 “特征是否存在” 而非 “特征出现次数” 的场景。
- 典型案例:文本分类(单词是否在文本中出现,1 = 出现,0 = 未出现)、用户行为分析(是否点击某按钮,1 = 点击,0 = 未点击)。
核心原理
- 假设特征服从伯努利分布(0-1 分布),仅关注特征 “是否发生”,不关注发生次数。
- 需通过
binarize
参数将非二值特征转换为二值(若特征已二值化,可设为 None),同时支持拉普拉斯平滑(alpha
参数)。
关键参数(sklearn 实现)
参数 | 类型 | 默认值 | 功能说明 |
---|---|---|---|
alpha | 浮点型 | 1.0 | 拉普拉斯平滑系数,作用同 MultinomialNB |
binarize | 浮点型 / None | 0 | 特征二值化阈值:若为x ,则特征值 > x 设为 1,否则设为 0;None 表示特征已二值化 |
fit_prior | 布尔型 | True | 是否使用先验概率,作用同 MultinomialNB |
class_prior | 数组型 | None | 自定义类别先验概率,作用同 MultinomialNB |
1.4 三大模型对比与选择指南
模型 | 适用特征类型 | 核心假设 | 关键参数 | 典型场景 |
---|---|---|---|---|
多项式朴素贝叶斯 | 离散型(计数数据) | 特征服从多项式分布 | alpha 、fit_prior | 文本分类(单词频次)、商品销量分类 |
高斯朴素贝叶斯 | 连续型数据 | 特征服从正态分布 | priors | 数值型特征分类(鸢尾花、房价) |
伯努利朴素贝叶斯 | 二值离散型(0/1) | 特征服从伯努利分布 | alpha 、binarize | 文本分类(单词存在性)、用户行为分析 |
二、朴素贝叶斯模型通用 API(sklearn)
sklearn 中三种朴素贝叶斯模型的接口完全一致,核心方法如下,便于快速切换模型进行实验:
方法 | 功能描述 |
---|---|
fit(X, y) | 用训练集(X:特征矩阵,y:标签)拟合模型,学习概率参数 |
predict(X) | 对测试集 X 进行分类预测,返回每个样本的类别标签 |
predict_proba(X) | 返回测试集 X 属于每个类别的概率(概率和为 1) |
predict_log_proba(X) | 返回predict_proba(X) 的对数形式(避免概率过小导致的数值下溢) |
score(X, y) | 计算模型在数据集(X,y)上的准确率(正确预测数 / 总样本数) |
三、课后练习:朴素贝叶斯实现手写数字识别
3.1 任务背景
手写数字数据集(load_digits
)包含 8×8 像素的灰度图像(共 1797 个样本),每个样本的特征是 64 个连续值(0-16,代表像素灰度),标签是 0-9 的数字。需选择合适的朴素贝叶斯模型实现分类,并评估性能。
3.2 模型选择依据
- 特征类型:64 个像素值为连续型数据(0-16 的整数,本质是连续区间内的离散化表示)。
- 模型匹配:连续型特征适合使用高斯朴素贝叶斯(GaussianNB),无需手动处理特征分布,直接通过正态分布拟合像素值概率。
3.3 完整代码实现
python
运行
# 1. 导入必要库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.naive_bayes import GaussianNB # 选择高斯朴素贝叶斯
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
# 2. 加载并探索数据集
digits = load_digits()
print("数据集基本信息:")
print(f"样本数量:{digits.data.shape[0]}, 特征数量(像素数):{digits.data.shape[1]}")
print(f"类别数量(数字0-9):{len(digits.target_names)}")
print(f"像素值范围:{digits.data.min()} ~ {digits.data.max()}")
# 可视化前4个样本(验证数据格式)
plt.figure(figsize=(8, 4))
for i in range(4):
plt.subplot(1, 4, i+1)
# 将64维特征重塑为8×8图像
plt.imshow(digits.images[i], cmap=plt.cm.gray_r)
plt.title(f"Label: {digits.target[i]}")
plt.axis("off")
plt.show()
# 3. 数据预处理:划分训练集与测试集
# 随机划分,测试集占比30%,固定随机种子确保结果可复现
X_train, X_test, y_train, y_test = train_test_split(
digits.data, digits.target, test_size=0.3, random_state=42
)
print(f"\n训练集大小:{X_train.shape}, 测试集大小:{X_test.shape}")
# 4. 初始化并训练高斯朴素贝叶斯模型
model = GaussianNB()
model.fit(X_train, y_train) # 拟合模型,学习每个类别下特征的正态分布参数
# 5. 模型预测与性能评估
# 5.1 测试集预测
y_pred = model.predict(X_test) # 类别预测
y_pred_proba = model.predict_proba(X_test) # 类别概率预测(可选)
# 5.2 核心指标:准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"\n模型在测试集上的准确率:{accuracy:.4f}")
# 5.3 详细评估:分类报告(精确率、召回率、F1分数)
print("\n分类报告(精确率/召回率/F1分数):")
print(classification_report(
y_test, y_pred, target_names=[str(i) for i in digits.target_names]
))
# 5.4 混淆矩阵:分析各类别预测错误情况
conf_matrix = confusion_matrix(y_test, y_pred)
print("\n混淆矩阵(行:真实标签,列:预测标签):")
print(conf_matrix)
# 6. 错误案例分析(可选):查看前3个预测错误的样本
error_indices = np.where(y_pred != y_test)[0][:3] # 前3个错误样本的索引
plt.figure(figsize=(8, 3))
for i, idx in enumerate(error_indices):
plt.subplot(1, 3, i+1)
plt.imshow(X_test[idx].reshape(8, 8), cmap=plt.cm.gray_r)
plt.title(f"True: {y_test[idx]}, Pred: {y_pred[idx]}")
plt.axis("off")
plt.show()
3.4 结果分析
1. 基础性能
- 模型在测试集上的准确率约为0.83-0.85(因随机划分可能略有波动,固定
random_state=42
后准确率为 0.8426)。 - 从分类报告可见:数字 “0”“1”“6” 的 F1 分数接近 1.0,分类效果极佳;数字 “8”“9” 的 F1 分数较低(约 0.75),因这两个数字的像素分布更相似,易混淆。
2. 混淆矩阵解读
- 对角线元素表示 “正确预测数”,非对角线元素表示 “错误预测数”。
- 例如:混淆矩阵中
conf_matrix[8,9]
(真实为 8,预测为 9)的值较大,说明模型易将 “8” 误判为 “9”,需后续优化(如增加特征工程、换用其他模型)。
3. 错误案例可视化
- 错误样本的图像显示:“8” 与 “9” 的区别仅在于底部是否有缺口,像素差异小,导致高斯朴素贝叶斯的正态分布假设难以区分,这是模型性能瓶颈的主要原因。
四、学习总结与拓展思考
4.1 核心收获
- 模型选型逻辑:根据特征类型选择朴素贝叶斯模型是关键 —— 离散计数用多项式、连续值用高斯、二值特征用伯努利,避免 “错配” 导致的性能损失。
- 高斯模型特点:无需手动处理连续特征的分布,实现简单、速度快,但对 “特征独立” 假设敏感(如手写数字中相邻像素存在相关性,会影响模型精度)。
- 评估维度:除准确率外,需通过混淆矩阵、分类报告分析 “类别级” 性能,定位易混淆类别,为优化提供方向。
4.2 优化方向
- 特征工程:对像素特征进行预处理(如二值化、边缘检测),增强数字间的区分度,可尝试用伯努利朴素贝叶斯重新实验。
- 模型改进:若需更高精度,可换用非朴素贝叶斯模型(如 SVM、随机森林),或通过 “贝叶斯网络” 放松 “特征独立” 假设。
- 超参数调优:对高斯朴素贝叶斯的
priors
参数进行自定义(如根据训练集中各类别样本占比调整先验概率),可能提升少数类别的召回率。
通过本次实践,不仅掌握了三种朴素贝叶斯模型的应用场景与代码实现,更理解了 “模型假设与数据特性匹配” 的重要性 —— 朴素贝叶斯的 “朴素” 既是其优势(计算快),也是其局限(依赖独立假设),需在实际任务中灵活权衡。