机器学习KNN算法手写数字识别系统

摘要

本文手写体数字识别系统的工作主要是运用K最邻近算法实现了对手写体数字的识别,支持上传本地图片和调用摄像头进行拍摄并截图识别两种识别的途径,同时有添加完善数据集、查看测试集的识别率的功能,形成了一个比较完整的手写数字识别系统。本文还运用python的GUI编程中的tkinter模块设计了一个简洁友好的用户界面。本文重点阐述了手写数字识别图像处理流程,运用KNN算法进行分类识别,同时运用数理统计的方法对K值的选取进行优化,最后对整个系统的实现结果进行了分析。

KNN算法简介

存在一个样本数据集合,也称为训练样本集,并且样本集中的每一个数据都有标签,即我们知道样本集中的每一个数据的特征和对应的类型。当输入没有标签的新的数据的时候,将新的数据集的每一个特征和样本集中的每一个数据的对应的特征进行比较(计算两个样本的特征之间的距离),然后提取样本集中和输入的新数据特征最相似的数据的类的标签,通常我们只关心前k个最相似的数据,这就是k近算法中的k的出处。一般来说,我们只选择样本数据集中的前k最相似的数据,然后选择k个最相似的数据集中出现次数最多的作为新数据的分类。
在这里插入图片描述

系统模块流程图

手写数字识别系统流程图

主要功能

手写数字识别的系统主要功能是运用KNN算法对手写数字进行识别,支持上传本地图片识别、拍照识别、将未识别数据加入数据集、测试数据集的功能。下面做简要叙述。
(1)测试数据集:测试集testDigits中含有946个数据集,可以对测试集中的数据进行识别,得出识别率。
(2)识别图片:用户可以选择上传本地图片,对本地图片进行截图,或者不做处理直接进行全图识别,并输入真值。系统进行识别后将预测值与真值进行比较,得出识别是否正确。
(3)拍照识别:用户调用摄像头,拍摄图像另存为图片同时待识别样本输入真实值,之后系统对拍摄的照片同样的进行截图或者全图识别,将识别结果与真值比对,得出识别结果。
(4)添加数据集:用户可以将识别错误的测试集中的数字和图片识别错误的二值图添加至数据集,以此完善数据集,提高数据集的识别率,使其能够识别下一次相似的数据。

数字处理格式

在这里插入图片描述

算法关键代码

# KNN算法
def KNN(test_data, train_data, train_label, k):
    # 已知分类的数据集(训练集)的行数
    dataSetSize = train_data.shape[0]
    # 求所有距离:先tile函数将输入点拓展成与训练集相同维数的矩阵,计算测试样本与每一个训练样本的距离
    all_distances = np.sqrt(np.sum(np.square(np.tile(test_data, (dataSetSize, 1)) - train_data), axis=1))
    # print("所有距离:",all_distances)
    # 按all_distances中元素进行升序排序后得到其对应索引的列表
    sort_distance_index = all_distances.argsort()
    # print("文件索引排序:",sort_distance_index)
    # 选择距离最小的k个点
    classCount = {}
    for i in range(k):
        # 返回最小距离的训练集的索引(预测值)
        voteIlabel = train_label[sort_distance_index[i]]
        # print('第',i+1,'次预测值',voteIlabel)
        classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        # print(classCount)
    # 求众数:按classCount字典的第2个元素(即类别出现的次数)从大到小排序
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]
 # 文本向量化 32x32 -> 1x1024
def img2vector(filename):
    returnVect = []
    fr = open(filename)
    for i in range(32):
        lineStr = fr.readline()
        for j in range(32):
            returnVect.append(int(lineStr[j]))
    return returnVect
# 构建训练集数据向量,及对应分类标签向量
def trainingDataSet():
    if (os.path.exists('data_set/train_set_label.pk1') and os.path.exists('data_set/train_set_data.npy')):
        os.remove('data_set/train_set_label.pk1')
        os.remove('data_set/train_set_data.npy')
    train_label = []
    trainingFileList = listdir('trainingDigits')
    m = len(trainingFileList)
    train_data = np.zeros((m, 1024))
    # 获取训练集的标签
    for i in range(m):
        # fileNameStr:所有训练集文件名
        fileNameStr = trainingFileList[i]
        # 得到训练集索引
        train_label.append(classnumCut(fileNameStr))
        train_data[i, :] = img2vector('trainingDigits/%s' % fileNameStr)
    # 永久化储存
    out_label = open('data_set/train_set_label.pk1', 'wb')
    pickle.dump(train_label, out_label)
    out_label.close()
    np.save('data_set/train_set_data.npy', train_data)
    return train_label, train_data 

GUI展示界面关键代码

root = Tk()
root.title('基于KNN算法的手写数字识别系统')
root.geometry('835x710+250+0')      #定义图形界面的大小
root.resizable(False, False)        #设定root的大小不可以改变
lb = Label(root, text='你可以点击 关于->Help 查看帮助文档')
lb.pack()
select_file = Button(root, text="选择文件", command=Open)
select_file.place(x=10, y=20, height=40, width=200)
cama = Button(root, text='摄像头', command=vedio_cut.get_img_from_camera_local)
cama.place(x=220, y=20, height=40, width=200)
result = Button(root, text='开始识别', command=lambda: now(select_back))
result.place(x=430, y=20, height=40, width=200)
label_back = tkinter.Label(root, text='使用备份:')
label_back.place(x=640, y=30, height=20)
#1为使用, 0为不使用
select_back = tkinter.IntVar()
select_back.set(1)
radioBack = tkinter.Radiobutton(root, variable=select_back, value=1, text='是')
radioBack.place(x=700, y=30, width=40, height=20)
radioNew = tkinter.Radiobutton(root, variable=select_back, value=0, text='否')
radioNew.place(x=745, y=30, width=40, height=20)
#创建画布
showimage1 = tkinter.PhotoImage()
canvas1 = tkinter.Canvas(root, bg='white', width=400, height=400)
canvas1.place(x=10, y=80)

showimage2 = tkinter.PhotoImage()
canvas2 = tkinter.Canvas(root, bg='white', width=400, height=400)
canvas2.place(x=420, y=80)
#创建列表框组件和滚动条
scroll = tkinter.Scrollbar(root)
scroll.place(x=810, y=500, width=15, height=200)
show_result = tkinter.Listbox(root, width=300, yscrollcommand=scroll.set)
show_result.place(x=10, y=500, width=800, height=200)
scroll.config(command=show_result.yview)

界面及结果展示图

上传本地图片
上传本地图片
拍照识别
在这里插入图片描述
截图功能
在这里插入图片描述
识别结果
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

Logo

CSDN联合极客时间,共同打造面向开发者的精品内容学习社区,助力成长!

更多推荐