在python上使用knn算法识别mnist。正确率只有27%。求查错,自己看了好几天都找不出来哪出

文章正文
发布时间:2024-10-22 23:12

在python上使用knn算法识别mnist。正确率只有27%。求查错,自己看了好几天都找不出来哪出问题了

# -*- coding: UTF-8 -*- from __future__ import division import os import struct import numpy as np import data import heapq '''knn 求距离公式''' def euc(vec1, vec2): npvec1, npvec2 = np.array(vec1), np.array(vec2) return ((npvec1-npvec2)**2).sum() '''data.image_data是mnist数据集,b是将这个数据集分成60000份''' a=np.array([data.image_data]) b=a.reshape((60000,784)) '''data.image_test_data是mnist测试集,d是将这个数据集分成10000份''' c=np.array([data.image_test_data]) d=c.reshape((10000,784)) '''i是测试次数,y是正确的次数''' i=0 y=0 while i < 10000: list1=[] list2=[] '''计算距离,并放入list1''' for x in b: list1.append(euc(d[i],x)) '''从list1里选3个最小的''' result = map(list1.index, heapq.nsmallest(11, list1)) result.sort() for x in result: x1=data.label_data[x] list2.append(x1) if data.label_test_data[i]==max(set(list2), key=list2.count): '''用百分比显示出正确率''' y=y+1 print("correct",i+1,"%.4f%%" % (y/(i+1)*100)) else: print("not correct",i+1) i=i+1

图片说明