本文所用文件的链接
链接:https://pan.baidu.com/s/1RWNVHuXMQleOrEi5vig_bQ
提取码:p57s
朴素贝叶斯分类
朴素贝叶斯分类是一种依据统计理论而实现的一种分类方式. 观察一组数据:
天气情况 | 穿衣风格 | 约女朋友 | ==> | 心情 |
---|---|---|---|---|
0(晴天) | 0(休闲) | 0(约了) | ==> | 0(高兴) |
0 | 1(风骚) | 1(没约) | ==> | 0 |
1(多云) | 1 | 0 | ==> | 0 |
0 | 2(破旧) | 1 | ==> | 1(郁闷) |
2(下雨) | 2 | 0 | ==> | 0 |
… | … | … | … | … |
0 | 1 | 0 | ==> | ? |
通过上述训练样本如何预测:010的心情? 可以依照决策树的方式找相似输入预测输出. 但是如果在样本空间中没有完全匹配的相似样本该如何预测?
贝叶斯公式:
P(A,B)=P(A)P(B∣A)=P(B)P(A∣B)⇓⇓⇓P(A∣B)=P(A)P(B∣A)P(B)
P(A,B) = P(A)P(B|A) = P(B)P(A|B) \\
\Downarrow \Downarrow \Downarrow \\
P(A|B) = \frac{P(A)P(B|A)}{P(B)}
P(A,B)=P(A)P(B∣A)=P(B)P(A∣B)⇓⇓⇓P(A∣B)=P(B)P(A)P(B∣A)
例如:
假设一个学校中有60%男生和40%女生. 女生穿裤子的人数和穿裙子的人数相等. 所有男生都穿裤子. 一人在远处随机看到了一个穿裤子的学生, 那么这个学生是女生的概率是多少?
P(女) = 0.4
P(裤子|女) = 0.5
P(裤子) = 0.8
P(女|裤子) = P(女)*P(裤子|女)/P(裤子)
= 0.4 * 0.5 / 0.8 = 0.25
根据贝叶斯定理, 如何预测: 晴天并且休闲并且没约并且高兴的概率?
P(晴天,休闲,没约,高兴)
P(晴天|休闲,没约,高兴)P(休闲,没约,高兴)
P(晴天|休闲,没约,高兴)P(休闲|没约,高兴)P(没约,高兴)
P(晴天|休闲,没约,高兴)P(休闲|没约,高兴)P(没约|高兴)P(高兴)
(朴素: 条件独立, 特征值之间没有任何关系)
P(晴天|高兴)P(休闲|高兴)P(没约|高兴)P(高兴)
朴素贝叶斯相关API:
import sklearn.naive_bayes as nb
# 构建高斯朴素贝叶斯
model = nb.GaussianNB()
model.fit(x, y)
pred_test_y = model.predict(test_x)
案例: multiple1.txt
"""
朴素贝叶斯分类
"""
import numpy as np
import sklearn.naive_bayes as nb
import matplotlib.pyplot as mp
data = np.loadtxt('../ml_data/multiple1.txt',
unpack=False, delimiter=',')
print(data.shape, data.dtype)
# 获取输入与输出
x = np.array(data[:, :-1])
y = np.array(data[:, -1])
# 绘制这些点, 点的颜色即是点的类别
mp.figure('Naive Bayes', facecolor='lightgray')
mp.title('Naive Bayes', fontsize=16)
mp.xlabel('X', fontsize=14)
mp.ylabel('Y', fontsize=14)
mp.tick_params(labelsize=10)
# 通过样本数据,训练朴素贝叶斯分类模型
model = nb.GaussianNB()
model.fit(x, y)
# 绘制分类边界线
l, r = x[:, 0].min()-1, x[:, 0].max()+1
b, t = x[:, 1].min()-1, x[:, 1].max()+1
n = 500
grid_x, grid_y = np.meshgrid(
np.linspace(l, r, n),
np.linspace(b, t, n))
test_x = np.column_stack(
(grid_x.ravel(), grid_y.ravel()))
pred_test_y = model.predict(test_x)
grid_z = pred_test_y.reshape(grid_x.shape)
mp.pcolormesh(grid_x,grid_y,grid_z,cmap='gray')
mp.scatter(x[:,0], x[:,1], s=60, c=y,
cmap='jet', label='Points')
mp.legend()
mp.show()