项目说明
训练之后,即可识别手写数字图片如下(我用windows自带的画板画的,可能有点丑,见谅):
等等
数据集下载
链接:https://pan.baidu.com/s/1wWAuZxWUUhlCzRbIl-NE_g
提取码:nzjk
运行环境说明
运行在自己安装好jupyter的服务器上,安装可以参考:
在CentOS上安装jupyter notebook (支持外网访问)
- python 3
- numpy
- pandas
- pillow
注意事项
一定要保证路径正确,如果不希望做任何修改的话,请保证相对路径如下:
- data
- num
- 下面是云盘下所有文件,包括testdata文件夹,traindata文件夹,若干个检测图片。
- num
- 此python文件
如果遇到环境问题,请自行查资料解决,一般情况下是只需要安装pillow就可以了,
pip install pillow -i https://pypi.tuna.tsinghua.edu.cn/simple
推荐安装jupyter,能够省去很多环境安装问题
具体实现代码
主要包括以下几部分代码:
- 把图片转换为txt文本文件
- 把txt文本文件写入转换为一维数组
- 读取label(即训练集的文件名的前缀)
- KNN 算法
- 训练,即把所有训练集中数据转换为二维数组
- 测试,加载图片,使用上面的函数查看识别结果
代码如下
from PIL import Image,ImageDraw
import operator
from numpy import *
from os import listdir
# 临时文件路径,表示将图片转换成的文本
TEXT_PATH = "data/num/test.txt"
# 训练文件的路径
TRAIN_PATH = "data/num/traindata"
# 图片转换为text文本文件
def image_to_text(img):
# resize
img = img.resize((32,32))
width = img.size[0]
height = img.size[1]
# open 文本文件进行写操作
fh = open(TEXT_PATH,'w')
for j in range(height):
for i in range(width):
rgb = img.getpixel((i,j))
r,g,b = rgb[0],rgb[1],rgb[2]
if(r==255 and g==255 and b==255):
fh.write('0')
elif(r==0 and b==0 and g == 0):
fh.write('1')
# 为了解决既不是白也不是黑的问题
else:
fh.write('0')
fh.write('\n')
fh.close()
# 读数据写入一维数组
def datatoarray(filepath):
ary=[]
fh = open(filepath)
# 把二维数组转换为一维数组
for i in range(32):
line = fh.readline()
for j in range(32):
ary.append(int(line[j]))
return ary
#取文件名前缀,类似于5_45取5
def seplabel(fname):
label=int(fname.split("_")[0])
return label
# KNN算法
def knn(k,testdata,traindata,labels):
# 获得训练数据的列数(即属性数目)
traindatasize=traindata.shape[0]
'''
将测试数据(原本只有一行)复制成为traindatasize行,从而方便进行减法运算
举个例子,测试数据和训练数据都有n列(即n个属性),
测试数据只有一行,但是训练数据却有k行,为了让测试数据和所有的训练数据都进行计算,得到距离
为了方便,扩充一下测试数据就可以了。
列数不变,所以为1
'''
dif = tile(testdata,(traindatasize,1))-traindata
# 计算距离
sqdif = dif**2
sumsqdif = sqdif.sum(axis=1)
distance = sumsqdif**0.5
# 排序
sortdistance = distance.argsort()
# 进行统计
count={}
for i in range(0,k):
# 按照距离排序后,label中第i个最近的类型 + 1
vote = labels[sortdistance[i]]
# 字典中该元素+1
count[vote] = count.get(vote,0)+1
# 排序,得到最近次数最多的
sortcount = sorted(count.items(),key=operator.itemgetter(1),reverse=True)
return sortcount[0][0]
#建立训练数据
def traindata():
labels=[]
trainfile=listdir(TRAIN_PATH)
num=len(trainfile)
#长度1024(列),每一行存储一个文件
#用一个数组存储所有训练数据,行:文件总数,列:1024
trainarr=zeros((num,1024))
for i in range(0,num):
thisfname=trainfile[i]
thislabel=seplabel(thisfname)
labels.append(thislabel)
trainarr[i,:]=datatoarray(TRAIN_PATH+"/"+thisfname)
return trainarr,labels
# 入口
# 训练
trainarr,labels=traindata()
for i in range(10):
IMAGE_PATH = "data/num/"+str(i)+"_1.png"
img = Image.open(IMAGE_PATH,'r')
image_to_text(img)
# 单个文件
testarr=datatoarray(TEXT_PATH)
# KNN算法获得结果
rknn=knn(3,testarr,trainarr,labels)
print(rknn)
'''
# 6 7 9
IMAGE_PATH = "data/num/6_2.png"
img = Image.open(IMAGE_PATH)
image_to_text(img)
# 单个文件
testarr=datatoarray(TEXT_PATH)
# KNN算法获得结果
rknn=knn(3,testarr,trainarr,labels)
print(rknn)
'''
总结
重点放在KNN算法的应用上,KNN算法很容易理解,也很容易应用。
Smileyan
2019年12月10日 10:34