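K-Nearest Neighbors (kNN) Algorithm
I. Algorithm Overview
(1) kNN classifies a sample by measuring the distances between its feature values and those of samples whose class is already known.
Pros: high accuracy, insensitive to outliers, no assumptions about the input data.
Cons: high computational and space complexity. Every prediction computes the distance to all training samples, and the whole training set must be kept in memory.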
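(2) The three elements of a kNN model
A kNN model in effect partitions the feature space. It has three basic elements: the distance metric, the choice of K, and the classification decision rule.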
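Distance metric
For two samples x_i and x_j with n features, where x_i^{(l)} is the l-th feature of x_i, the L_p distance is defined as: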
\[L_p(x_i,x_j)=\left( \sum^n_{l=1} |x_i^{(l)} - x_j^{(l)}|^p \right) ^{\frac{1}{p}}\]
The usual choice is the Euclidean distance, which is the case p = 2:
\[L_2(x_i,x_j)=\left( \sum^n_{l=1} |x_i^{(l)} - x_j^{(l)}|^2 \right) ^{\frac{1}{2}}\]
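The sketch below is a minimal Python version of the L_p distance above. It assumes both points are given as equal-length sequences of numbers, and the function name `minkowski_distance` is hypothetical. Setting p = 2 gives the Euclidean distance, and p = 1 gives the Manhattan distance.

```python
import numpy as np


def minkowski_distance(x_i, x_j, p=2):
    """L_p distance between two feature vectors (p=2: Euclidean, p=1: Manhattan)."""
    x_i = np.asarray(x_i, dtype=float)
    x_j = np.asarray(x_j, dtype=float)
    # sum over features l of |x_i^(l) - x_j^(l)|^p, then take the 1/p power
    return float(np.sum(np.abs(x_i - x_j) ** p) ** (1.0 / p))


print(minkowski_distance([0, 0], [3, 4]))       # 5.0 (Euclidean)
print(minkowski_distance([0, 0], [3, 4], p=1))  # 7.0 (Manhattan)
```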
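Choosing K
K is usually chosen empirically. Several values have to be tried and compared before a suitable one is found.
If K is too small, the model becomes too complex and is prone to overfitting. It is also very sensitive to noisy points: with K = 1, a single mislabeled neighbor decides the prediction.
If K is too large, the model becomes too simple and ignores much of the useful local information. In the extreme case where K equals the number of training samples, every query is assigned the majority class of the whole training set.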
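One common way to compare K values is to measure the error rate on held-out validation data for each candidate and keep the best one. The sketch below does this under some assumptions. It uses NumPy, and the helpers `knn_predict` and `holdout_errors` are hypothetical, as is the two-blob toy data.

```python
import numpy as np


def knn_predict(train_X, train_y, x, k):
    """Majority vote among the k training points nearest to x (Euclidean)."""
    dists = np.sqrt(((train_X - x) ** 2).sum(axis=1))
    nearest = train_y[np.argsort(dists)[:k]]
    # np.unique returns sorted labels, so ties go to the smallest label
    values, counts = np.unique(nearest, return_counts=True)
    return values[np.argmax(counts)]


def holdout_errors(train_X, train_y, val_X, val_y, candidate_ks):
    """Error rate on the validation set for each candidate K."""
    errors = {}
    for k in candidate_ks:
        preds = np.array([knn_predict(train_X, train_y, x, k) for x in val_X])
        errors[k] = float(np.mean(preds != val_y))
    return errors


# hypothetical toy data: two overlapping Gaussian blobs with labels 0 and 1
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, (60, 2)), rng.normal(3.0, 1.0, (60, 2))])
y = np.array([0] * 60 + [1] * 60)
order = rng.permutation(len(y))
X, y = X[order], y[order]

errors = holdout_errors(X[:90], y[:90], X[90:], y[90:], candidate_ks=[1, 3, 5, 15, 45])
best_k = min(errors, key=errors.get)  # on ties, the first (smallest) K listed wins
print(errors, best_k)
```

With a single split, the selected K depends on which points land in the validation set. Repeating the comparison over several splits (cross-validation) gives a steadier choice.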
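Classification decision rule
Majority voting is usually used. Among the K nearest neighbors, the class that occurs most often becomes the predicted class.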
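Majority voting fits in a few lines of Python with `collections.Counter`. The sketch below assumes the neighbors' labels are passed nearest first, and the function name `majority_vote` is hypothetical. `Counter.most_common` orders equal counts by first occurrence, so under that assumption a tie goes to the class whose member is nearest.

```python
from collections import Counter


def majority_vote(neighbor_labels):
    """Return the most frequent label among the K nearest neighbors."""
    # most_common(1) returns [(label, count)]; equal counts keep first-seen order
    return Counter(neighbor_labels).most_common(1)[0][0]


print(majority_vote(['A', 'B', 'A']))  # A
```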
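II. Implementing the kNN Algorithm
2.1 Pseudocode
1. For every point in the data set whose class is known, compute its distance to the current point
2. Sort the points by increasing distance
3. Take the k points closest to the current point
4. Count how often each class occurs among these k points
5. Return the most frequent class among the k points as the predicted class of the current point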
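2.2 Actual code

    from numpy import tile
    import operator

    def classify0(inX, dataSet, labels, k):
        # inX: point to classify; dataSet: training matrix, one sample per row
        # labels: class of each training row; k: number of neighbors that vote
        dataSetSize = dataSet.shape[0]
        # Euclidean distance from inX to every training row
        diffMat = tile(inX, (dataSetSize, 1)) - dataSet
        sqDiffMat = diffMat**2
        sqDistances = sqDiffMat.sum(axis=1)
        distances = sqDistances**0.5
        # indices of the training rows sorted by increasing distance
        sortedDistIndicies = distances.argsort()
        # count the labels of the k nearest neighbors
        classCount = {}
        for i in range(k):
            voteIlabel = labels[sortedDistIndicies[i]]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        # majority vote: return the most frequent label
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]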
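For comparison, here is a minimal sketch of the same steps that uses NumPy broadcasting (`dataSet - inX`) instead of `tile`, and `collections.Counter` for the vote. The function name `classify_broadcast` and the four-point toy data set are hypothetical and only for illustration.

```python
from collections import Counter

import numpy as np


def classify_broadcast(inX, dataSet, labels, k):
    """Same steps as classify0, using NumPy broadcasting instead of tile."""
    # dataSet - inX broadcasts the 1-D query against every training row
    distances = np.sqrt(((dataSet - inX) ** 2).sum(axis=1))
    # labels of the k nearest training points, nearest first
    k_labels = [labels[i] for i in distances.argsort()[:k]]
    return Counter(k_labels).most_common(1)[0][0]


# hypothetical toy data: two points near (1, 1) labelled 'A', two near (0, 0) labelled 'B'
group = np.array([[1.0, 1.1], [1.0, 1.0], [0.0, 0.0], [0.0, 0.1]])
labels = ['A', 'A', 'B', 'B']
print(classify_broadcast(np.array([0.0, 0.0]), group, labels, 3))  # B
```

Broadcasting avoids building the tiled copy of `inX`. The result is the same as `classify0` for the same inputs.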
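III. Case Study: Using kNN to Improve Matches on a Dating Site
My friend A-J has been using an online dating site to find dates. The people he has dated fall into three types:
People he did not like
People he found just average
People he liked a lot
Steps:
Collect the data
Prepare the data: read the data in
Analyze the data: draw 2-D scatter plots with Matplotlib
Train the algorithm (kNN has no separate training phase; it simply keeps the training samples)
Test the algorithm
Use the algorithm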
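3.1 Prepare the data
The data set has 1000 samples with 3 features each, so every row has 4 columns. The last column is the class label (1: did not like; 2: just average; 3: liked a lot).
Features:
Frequent flyer miles earned per year
Percentage of time spent playing video games
Kilograms of ice cream consumed per week
A sample of the data (in the file the columns are tab-separated: three features, then the label):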
    40920  8.326976  0.953952  3
    14488  7.153469  1.673904  2
    26052  1.441871  0.805124  1
    75136  13.147394 0.428964  1
    38344  1.669788  0.134296  1
    72993  10.141740 1.032955  1
    35948  6.830792  1.213192  3
    42666  13.276369 0.543880  3
    67497  8.631577  0.749278  1
    35483  12.273169 1.508053  3
Reading the data (from a txt file):

    from numpy import zeros

    def file2matrix(filename):
        # read every line once; the with-block closes the file afterwards
        with open(filename) as fr:
            lines = fr.readlines()
        numberOfLines = len(lines)               # number of samples in the file
        returnMat = zeros((numberOfLines, 3))    # prepare matrix to return
        classLabelVector = []                    # prepare labels to return
        for index, line in enumerate(lines):
            listFromLine = line.strip().split('\t')
            # columns 0-2 are the features, the last column is the label
            returnMat[index, :] = [float(v) for v in listFromLine[0:3]]
            classLabelVector.append(int(listFromLine[-1]))
        return returnMat, classLabelVector
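Because every column in this file is numeric, `numpy.loadtxt` can do the same parsing in one call. The sketch below is an alternative, not the author's code. It assumes a tab-separated file whose label column holds integers, and it reads the first three sample rows from an in-memory string instead of a file. In real use, `np.loadtxt('datingTestSet2.txt', delimiter='\t')` would read the file directly.

```python
import io

import numpy as np

# first three rows of the sample data above, tab-separated
sample = io.StringIO(
    "40920\t8.326976\t0.953952\t3\n"
    "14488\t7.153469\t1.673904\t2\n"
    "26052\t1.441871\t0.805124\t1\n"
)
data = np.loadtxt(sample, delimiter='\t')
features = data[:, :3]                      # the three feature columns
labels = data[:, -1].astype(int).tolist()   # class labels as Python ints
print(features.shape, labels)  # (3, 3) [3, 2, 1]
```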
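3.2 Analyze the data: create scatter plots with Matplotlib
Initial analysis (datingDataMat and datingLabels are the values returned by file2matrix('datingTestSet2.txt')):

    import matplotlib.pyplot as plt

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # x: column 1 (video-game time %), y: column 2 (ice cream per week)
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2])
    ax.set_xlabel("Percentage of time spent playing video games")
    ax.set_ylabel("Kilograms of ice cream consumed per week")
    plt.show()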
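Because there are three classes, this plot is hard to read. The next version colors each point by its label and scales its size by the label:

    from numpy import array

    fig = plt.figure()
    ax = fig.add_subplot(111)
    # marker size (3rd argument) and color (4th argument) both scale with the label
    ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],
               15.0 * array(datingLabels), 15.0 * array(datingLabels))
    ax.set_xlabel("Percentage of time spent playing video games")
    ax.set_ylabel("Kilograms of ice cream consumed per week")
    plt.show()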
After several attempts, we found that the video-game time and ice cream features show a relatively clear relationship.
The specific steps:
Split the samples with labels 1, 2, and 3 into three groups
Plot the groups with matplotlib, using a different color for each class

    from numpy import array

    # split the samples by label using boolean masks
    labelArr = array(datingLabels)
    datingDataType1 = datingDataMat[labelArr == 1]
    datingDataType2 = datingDataMat[labelArr == 2]
    datingDataType3 = datingDataMat[labelArr == 3]
    # plot column 1 (video-game time) against column 2 (ice cream), matching the axis labels
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    axs[0, 0].scatter(datingDataType1[:, 1], datingDataType1[:, 2], s=20, c='red')
    axs[0, 1].scatter(datingDataType2[:, 1], datingDataType2[:, 2], s=30, c='green')
    axs[1, 0].scatter(datingDataType3[:, 1], datingDataType3[:, 2], s=40, c='blue')
    # bottom-right panel: all three classes together, with a legend
    type1 = axs[1, 1].scatter(datingDataType1[:, 1], datingDataType1[:, 2], s=20, c='red')
    type2 = axs[1, 1].scatter(datingDataType2[:, 1], datingDataType2[:, 2], s=30, c='green')
    type3 = axs[1, 1].scatter(datingDataType3[:, 1], datingDataType3[:, 2], s=40, c='blue')
    axs[1, 1].legend([type1, type2, type3], ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
    axs[1, 1].set_xlabel("Percentage of time spent playing video games")
    axs[1, 1].set_ylabel("Kilograms of ice cream consumed per week")
    plt.show()
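3.3 Prepare the data: normalization
The plots and sample rows above show that the three features have very different ranges. Flyer miles are in the tens of thousands, while the video-game percentage stays around a dozen or less and the ice cream figure is around 1 kg. When kNN computes distances, the feature with the larger values therefore has a much larger influence on the result.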
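To see the effect, the small Python calculation below is an illustration and not part of the original code. It splits the squared Euclidean distance between the first two sample rows shown earlier by feature.

```python
import numpy as np

# first two sample rows: flyer miles, video-game time %, ice cream (kg)
a = np.array([40920, 8.326976, 0.953952])
b = np.array([14488, 7.153469, 1.673904])
sq = (a - b) ** 2               # squared difference per feature
share = sq / sq.sum()           # each feature's share of the squared distance
print(np.sqrt(sq.sum()))        # Euclidean distance, about 26432
print(share.round(6))           # mileage share rounds to 1, the other two to 0
```

Before normalization, the distance between these two rows is essentially just the difference in flyer miles.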
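Normalization converts features with different ranges into the same range, here [0, 1].
Formula, applied to each feature separately with min and max taken over that feature's column: newValue = (oldValue - min) / (max - min)

    from numpy import tile

    def autoNorm(dataSet):
        # column-wise minimum, e.g. array([[1,20,3], [4,5,60], [7,8,9]]).min(0) = [1, 5, 3]
        minVals = dataSet.min(0)
        maxVals = dataSet.max(0)
        ranges = maxVals - minVals
        m = dataSet.shape[0]
        # (oldValue - min) / (max - min) for every row
        normData = (dataSet - tile(minVals, (m, 1))) / tile(ranges, (m, 1))
        return normData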
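`autoNorm` returns only the normalized matrix. A new sample that is classified later must be scaled with the same per-column min and range as the training data. The variant below is a minimal sketch under stated assumptions. The name `auto_norm` is hypothetical, the function also returns `ranges` and `minVals`, and it uses NumPy broadcasting instead of `tile`. It assumes no column is constant, because a range of 0 would divide by zero.

```python
import numpy as np


def auto_norm(dataSet):
    """Min-max scale each column to [0, 1]; also return per-column range and min."""
    minVals = dataSet.min(axis=0)
    ranges = dataSet.max(axis=0) - minVals   # assumes no column has range 0
    # broadcasting applies the column-wise min and range to every row
    return (dataSet - minVals) / ranges, ranges, minVals


data = np.array([[1.0, 20.0, 3.0], [4.0, 5.0, 60.0], [7.0, 8.0, 9.0]])
normData, ranges, minVals = auto_norm(data)
print(normData[1])  # row [4, 5, 60] scales to [0.5, 0, 1]

# a new sample is scaled with the training min and range, not its own
new_sample = np.array([4.0, 12.5, 3.0])
print((new_sample - minVals) / ranges)  # [0.5, 0.5, 0]
```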
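3.4 Test the algorithm
We hold out 20% of the samples as the test set and use the remaining 80% as the training set. The code below takes the first 20% of the rows as the test set, without shuffling.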
    def datingClassTest():
        hoRatio = 0.20                                                    # hold-out ratio
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')   # load data set from file
        normMat = autoNorm(datingDataMat)
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)                                    # rows 0..numTestVecs-1 are the test set
        errorCount = 0.0
        for i in range(numTestVecs):
            # classify test row i against the remaining 80% of the rows with k = 3
            classifierResult = classify0(normMat[i, :], normMat[numTestVecs:, :],
                                         datingLabels[numTestVecs:], 3)
            if classifierResult != datingLabels[i]:
                errorCount += 1.0
        print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
        print(errorCount)
Output:

    the total error rate is: 0.080000
    16.0
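`datingClassTest` always uses the first 20% of the rows as the test set. That split is only representative if the rows in the file are already in random order. The sketch below shuffles before splitting. It assumes NumPy's `default_rng`, and the helper `shuffled_holdout` and the toy data are hypothetical. With a shuffled split, the reported error rate will vary somewhat with the seed.

```python
import numpy as np


def shuffled_holdout(X, y, ho_ratio=0.20, seed=0):
    """Randomly hold out ho_ratio of the rows as a test set (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    idx = rng.permutation(X.shape[0])        # shuffled row indices
    n_test = int(X.shape[0] * ho_ratio)
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    y = np.asarray(y)
    return X[train_idx], y[train_idx], X[test_idx], y[test_idx]


# hypothetical data: 1000 rows, 3 features, labels 1 to 3
X = np.arange(3000, dtype=float).reshape(1000, 3)
y = np.arange(1000) % 3 + 1
X_train, y_train, X_test, y_test = shuffled_holdout(X, y)
print(X_train.shape, X_test.shape)  # (800, 3) (200, 3)
```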
IV. Full Source Code

    from numpy import array, tile, zeros
    import operator
    import matplotlib.pyplot as plt

    ## KNN function
    def classify0(inX, dataSet, labels, k):
        dataSetSize = dataSet.shape[0]
        diffMat = tile(inX, (dataSetSize, 1)) - dataSet
        sqDiffMat = diffMat**2
        sqDistances = sqDiffMat.sum(axis=1)
        distances = sqDistances**0.5
        sortedDistIndicies = distances.argsort()
        classCount = {}
        for i in range(k):
            voteIlabel = labels[sortedDistIndicies[i]]
            classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
        sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
        return sortedClassCount[0][0]

    # read txt data
    def file2matrix(filename):
        with open(filename) as fr:
            lines = fr.readlines()
        returnMat = zeros((len(lines), 3))    # prepare matrix to return
        classLabelVector = []                 # prepare labels to return
        for index, line in enumerate(lines):
            listFromLine = line.strip().split('\t')
            returnMat[index, :] = [float(v) for v in listFromLine[0:3]]
            classLabelVector.append(int(listFromLine[-1]))
        return returnMat, classLabelVector

    def autoNorm(dataSet):
        minVals = dataSet.min(0)    # array([[1,20,3], [4,5,60], [7,8,9]]).min(0) = [1, 5, 3]
        maxVals = dataSet.max(0)
        ranges = maxVals - minVals
        m = dataSet.shape[0]
        normData = (dataSet - tile(minVals, (m, 1))) / tile(ranges, (m, 1))
        return normData

    def drawScatter1(datingDataMat, datingLabels):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2])
        ax.set_xlabel("Percentage of time spent playing video games")
        ax.set_ylabel("Kilograms of ice cream consumed per week")
        plt.show()

    def drawScatter2(datingDataMat, datingLabels):
        fig = plt.figure()
        ax = fig.add_subplot(111)
        # marker size and color both scale with the label
        ax.scatter(datingDataMat[:, 1], datingDataMat[:, 2],
                   15.0 * array(datingLabels), 15.0 * array(datingLabels))
        ax.set_xlabel("Percentage of time spent playing video games")
        ax.set_ylabel("Kilograms of ice cream consumed per week")
        plt.show()

    def drawScatter3(datingDataMat, datingLabels):
        labelArr = array(datingLabels)
        fig, axs = plt.subplots(2, 2, figsize=(15, 10))
        # (label, marker size, color) for each class, drawn in panels [0,0], [0,1], [1,0]
        styles = [(1, 20, 'red'), (2, 30, 'green'), (3, 40, 'blue')]
        handles = []
        for (label, size, color), ax in zip(styles, [axs[0, 0], axs[0, 1], axs[1, 0]]):
            typeData = datingDataMat[labelArr == label]
            ax.scatter(typeData[:, 1], typeData[:, 2], s=size, c=color)
            # the bottom-right panel shows all three classes together
            handles.append(axs[1, 1].scatter(typeData[:, 1], typeData[:, 2], s=size, c=color))
        axs[1, 1].legend(handles, ["Did Not Like", "Liked in Small Doses", "Liked in Large Doses"], loc=2)
        axs[1, 1].set_xlabel("Percentage of time spent playing video games")
        axs[1, 1].set_ylabel("Kilograms of ice cream consumed per week")
        plt.show()

    def datingClassTest():
        hoRatio = 0.20
        datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')   # load data set from file
        normMat = autoNorm(datingDataMat)
        m = normMat.shape[0]
        numTestVecs = int(m * hoRatio)
        errorCount = 0.0
        for i in range(numTestVecs):
            classifierResult = classify0(normMat[i, :], normMat[numTestVecs:, :],
                                         datingLabels[numTestVecs:], 3)
            if classifierResult != datingLabels[i]:
                errorCount += 1.0
        print("the total error rate is: %f" % (errorCount / float(numTestVecs)))
        print(errorCount)

    if __name__ == "__main__":
        datingDataMat, datingLabels = file2matrix("datingTestSet2.txt")
        drawScatter1(datingDataMat, datingLabels)
        drawScatter2(datingDataMat, datingLabels)
        drawScatter3(datingDataMat, datingLabels)
        datingClassTest()
Source: https://www.cnblogs.com/zou107/p/11918852.html