文档章节

kNN分类学习(Tensorflow实现)

Nioacht
 Nioacht
发布于 2017/02/28 00:48
字数 1337
阅读 1596
收藏 10

kNN算法原理

kNN也就是k-NearestNeighbour的缩写。从命名上也可大致了解到这个算法的精髓了。用一句话概括而言,kNN分类算法就是‘近朱者赤,近墨者黑’。说得准确一点就是如果一个样本在特征空间中的k个最相邻的样本大多数属于某一类别,则该样本也属于此类别,并具有相应类别的特征

下面这个例子出现在无数讲解kNN的文章中,可见其的代表性:

我们把数据样本在一个平面上表示出来,相同类别的使用相同颜色和记号。绿色的圆形代表一个新的样本,我们使用kNN来判断它的类别。方法如下:以绿色圆形为圆心,开始做不同半径的同心圆,从实心线的同心圆来看,绿色圆形属于红色三角,从虚线同心圆来看,绿色圆形属于蓝色正方形,以此类推......

从上面的例子中不难发现问题,一方面,不同的半径会有不同的分类结果。可以说从图像上来看未知样本属于红色三角的可能性要比属于蓝色正方形的可能性大,也就是实线的同心圆范围内是合理的结果。反应在算法设计方面就是k值的选择。从另一方面看,距离测试样本近的数据所占的权重要更大,距离远的占的权重应该小,从而可以部分避免k值选取不当而造成的判断错误。

k值的选择

k越小,分类边界曲线越光滑,偏差越小,方差越大;K越大,分类边界曲线越平坦,偏差越大,方差越小。所以即使简单如kNN,同样要考虑偏差和方差的权衡问题,表现为k的选取。

k太小,分类结果易受噪声点影响;

k太大,近邻中又可能包含太多的其它类别的点。(对距离加权,可以降低k值设定的影响)k值通常是采用交叉检验来确定(以k=1为基准)

经验规则:k一般低于训练样本数的平方根

而所谓的交叉验证就是把数据样本分成训练集和测试集,然后k=1开始,使用验证集来更新k的值。

类别判定(权值选择)

投票决定:少数服从多数,近邻中哪个类别的点最多就分为该类。

加权投票法:根据距离的远近,对近邻的投票进行加权,距离越近则权重越大(权重为距离平方的倒数)

距离的定义

距离衡量包括欧式距离、夹角余弦等。

对于文本分类来说,使用余弦(cosine)来计算相似度就比欧式(Euclidean)距离更合适

Tensorflow实现

对minist数据集使用kNN算法 python3.5版本可以运行:

import tensorflow as tf 
import numpy as np

import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)     #下载并加载mnist数据

train_X, train_Y = mnist.train.next_batch(5000) # 5000 for training (nn candidates)
test_X, test_Y = mnist.test.next_batch(100)   # 200 for testing


tra_X = tf.placeholder("float", [None, 784])
te_X = tf.placeholder("float", [784])

# Nearest Neighbor calculation using L1 Distance
# Calculate L1 Distance
distance = tf.reduce_sum(tf.abs(tf.add(tra_X, tf.neg(te_X))), reduction_indices=1)
# Prediction: Get min distance index (Nearest neighbor)
pred = tf.arg_min(distance, 0)

accuracy = 0.

# Initializing the variables
init = tf.initialize_all_variables()

# Launch the graph
with tf.Session() as sess:
	sess.run(init)

	# loop over test data
	for i in range(len(test_X)):
    	# Get nearest neighbor
    	nn_index = sess.run(pred, feed_dict={tra_X: train_X, te_X: test_X[i, :]})
    	# Get nearest neighbor class label and compare it to its true label
    	print("Test", i, "Prediction:", np.argmax(train_Y[nn_index]), \
        	"True Class:", np.argmax(test_Y[i]))
    	# Calculate accuracy
    	if np.argmax(train_Y[nn_index]) == np.argmax(test_Y[i]):
        	accuracy += 1./len(test_X)
	print("Done!")
	print("Accuracy:", accuracy)

kNN分类算法的评价

##优点

  • 1.简单,易于理解,易于实现,无需估计参数,无需训练;
    1. 适合对稀有事件进行分类;
  • 3.特别适合于多分类问题(multi-modal,对象具有多个类别标签), kNN比SVM的表现要好。

缺点

  • 1.样本不平衡时,会使结果出现偏差
  • 2.计算量庞大
  • 3.由于不需要训练,所以算法的可控性比较差

改进策略

分类效率:事先对样本属性进行约简,删除对分类结果影响较小的属性,快速的得出待分类样本的类别。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分。

分类效果:采用权值的方法(和该样本距离小的邻居权值大)来改进,Han等人于2002年尝试利用贪心法,针对文件分类实做可调整权重的k最近邻居法WAkNN (weighted adjusted k nearest neighbor),以促进分类效果;而Li等人于2004年提出由于不同分类的文件本身有数量上有差异,因此也应该依照训练集合中各种分类的文件数量,选取不同数目的最近邻居,来参与分类

© 著作权归作者所有

Nioacht
粉丝 9
博文 5
码字总数 5863
作品 0
程序员
私信 提问
加载中

评论(2)

Nioacht
Nioacht 博主
我测试了一下,kNN的程序在tensorflow 1.0 版本可能有问题 ,tf.neg要改成tf.negative
然后如果是在python2.7的版本下运行的,只要把print的()删除就可以了,如果还有其他问题,可以再评论处留言,我会再核对。
OSC_sTKuXx
OSC_sTKuXx
1- OpenCV+TensorFlow 入门人工智能图像处理-课程介绍

人工智能最火的两个方向,自然语言处理和计算机视觉 OpenCV的图像处理 TensorFlow的使用 供需关系理论,有需求所以才有提供 招聘网站: 图像算法两万以上 都需要的技能: OpenCV TensorFlow 人...

天涯明月笙
2018/04/04
0
0
史上最全TensorFlow学习资源汇总

来源 悦动智能(公众号ID:aibbtcom) 本篇文章将为大家总结TensorFlow纯干货学习资源,非常适合新手学习,建议大家收藏。 ▌一 、TensorFlow教程资源 1)适合初学者的TensorFlow教程和代码示...

悦动智能
2018/04/12
393
0
Tensorflow 2.0 轻松实现迁移学习

image from unsplash by Gábor Juhász 迁移学习即利用已有的知识来学习新的知识,与人类类似,比如你学会了用笔画画,也就可以学习用笔来画画,并不用从头学习握笔的姿势。对于机器学习来说...

Hongtao洪滔
昨天
0
0
【干货】史上最全的Tensorflow学习资源汇总,速藏!

一 、Tensorflow教程资源: 1)适合初学者的Tensorflow教程和代码示例:(https://github.com/aymericdamien/TensorFlow-Examples)该教程不光提供了一些经典的数据集,更是从实现最简单的“Hel...

技术小能手
2018/04/16
0
0
从GitHub在TensorFlow模型

图像处理/识别 1.PixelCNN&PixelRNN在TensorFlow TensorFlow实施像素回归神经网络。 地址:https://github.com/carpedm20/pixel-rnn-tensorflow 在TensorFlow 2.Simulated +无监督(S + U)......

知行合一1
2017/03/31
532
0

没有更多内容

加载失败,请刷新页面

加载更多

哪些情况下适合使用云服务器?

我们一直在说云服务器价格适中,具备弹性扩展机制,适合部署中小规模的网站或应用。那么云服务器到底适用于哪些情况呢?如果您需要经常原始计算能力,那么使用独立服务器就能满足需求,因为他...

云漫网络Ruan
今天
9
0
Java 中的 String 有没有长度限制

转载: https://juejin.im/post/5d53653f5188257315539f9a String是Java中很重要的一个数据类型,除了基本数据类型以外,String是被使用的最广泛的了,但是,关于String,其实还是有很多东西...

低至一折起
今天
21
0
OpenStack 简介和几种安装方式总结

OpenStack :是一个由NASA和Rackspace合作研发并发起的,以Apache许可证授权的自由软件和开放源代码项目。项目目标是提供实施简单、可大规模扩展、丰富、标准统一的云计算管理平台。OpenSta...

小海bug
昨天
11
0
DDD(五)

1、引言 之前学习了解了DDD中实体这一概念,那么接下来需要了解的就是值对象、唯一标识。值对象,值就是数字1、2、3,字符串“1”,“2”,“3”,值时对象的特征,对象是一个事物的具体描述...

MrYuZixian
昨天
9
0
解决Mac下VSCode打开zsh乱码

1.乱码问题 iTerm2终端使用Zsh,并且配置Zsh主题,该主题主题需要安装字体来支持箭头效果,在iTerm2中设置这个字体,但是VSCode里这个箭头还是显示乱码。 iTerm2展示如下: VSCode展示如下: 2...

HelloDeveloper
昨天
9
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部