文档章节

kNN(k近邻)算法代码实现

o
 osc_y8yehimr
发布于 2019/03/20 20:43
字数 593
阅读 3
收藏 0

目标:预测未知数据(或测试数据)X的分类y
批量kNN算法
1.输入一个待预测的X(一维或多维)给训练数据集,计算出训练集X_train中的每一个样本与其的距离
2.找到前k个距离该数据最近的样本-->所属的分类y_train
3.将前k近的样本进行统计,哪个分类多,则我们将x分类为哪个分类

# 准备阶段:

import numpy as np
# import matplotlib.pyplot as plt

raw_data_X = [[3.393533211, 2.331273381],
              [3.110073483, 1.781539638],
              [1.343808831, 3.368360954],
              [3.582294042, 4.679179110],
              [2.280362439, 2.866990263],
              [7.423436942, 4.696522875],
              [5.745051997, 3.533989803],
              [9.172168622, 2.511101045],
              [7.792783481, 3.424088941],
              [7.939820817, 0.791637231]
             ]
raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

X_train = np.array(raw_data_X)
y_train = np.array(raw_data_y)

x = np.array([8.093607318, 3.365731514])

核心代码:

 目标:预测未知数据(或测试数据)X的分类y  
批量kNN算法  
1.输入一个待预测的X(一维或多维)给训练数据集,计算出训练集X_train中的每一个样本与其的距离  
2.找到前k个距离该数据最近的样本-->所属的分类y_train  
3.将前k近的样本进行统计,哪个分类多,则我们将x分类为哪个分类

from math import sqrt
from collections import Counter

# 已知X_train,y_train
# 预测x的分类
def predict(x, k=5):
    # 计算训练集每个样本与x的距离
    distances = [sqrt(np.sum((x-x_train)**2)) for x_train in X_train]  # 这里用了numpy的fancy方法,np.sum((x-x_train)**2)
    # 获得距离对应的索引,可以通过这些索引找到其所属分类y_train
    nearest = np.argsort(distances)
    # 得到前k近的分类y
    topK_y = [y_train[neighbor] for neighbor in nearest[:k]]
    # 投票的方式,得到一个字典,key是分类,value数个数
    votes = Counter(topK_y)
    # 取出得票第一名的分类
    return votes.most_common(1)[0][0]   # 得到y_predict

predict(x, k=6)

面向对象的方式,模仿sklearn中的方法实现kNN算法:

import numpy as np
from math import sqrt
from collections import Counter


class kNN_classify:
    def __init__(self, n_neighbor=5):
        self.k = n_neighbor
        self._X_train = None
        self._y_train = None

    def fit(self, X_train, y_train):
        self._X_train = X_train
        self._y_train = y_train
        return self

    def predict(self, X):
        '''接收多维数据,返回y_predict也是多维的'''
        y_predict = [self._predict(x) for x in X]
        # return y_predict
        return np.array(y_predict)  # 返回array的格式

    def _predict(self, x):
        '''接收一个待预测的x,返回y_predict'''
        distances = [sqrt(np.sum((x-x_train)**2)) for x_train in self._X_train]
        nearest = np.argsort(distances)
        topK_y = [self._y_train[neighbor] for neighbor in nearest[:self.k]]
        votes = Counter(topK_y)
        return votes.most_common(1)[0][0]

    def __repr__(self):
        return 'kNN_clf(k=%d)' % self.k

 

o
粉丝 0
博文 500
码字总数 0
作品 0
私信 提问
加载中
请先登录后再评论。

暂无文章

Java 获取资源文件路径

1 问题描述 通过源码运行时,一般使用如下方式读取资源文件: String str = "1.jpg"; 资源文件与源码文件放在同一目录下,或者拥有同一父级目录: String str = "a/b/1.jpg"; 这样直接编译...

氷泠
9分钟前
4
0
Linux程序移植到Android上

序言: 由于本人还是比较偏重于先说明原理在说明实际操作步骤,要知其然更要知其所以然,如下图所示: 传统的linux系统中的程序基本都依赖于glibc(至于什么是glibc可以百度去),而右边AOS...

shzwork
21分钟前
17
0
git 为项目设置用户名/邮箱/密码

1.找到项目所在目录下的 .git,进入.git文件夹,然后执行如下命令分别设置用户名和邮箱 git config user.name "Affandi" git config user.email "123333333@qq.com" 然后执行命令查看con......

有时很滑稽
53分钟前
0
0
如何从int转换为String? - How do I convert from int to String?

问题: I'm working on a project where all conversions from int to String are done like this: 我正在一个项目中,所有从int到String转换都是这样完成的: int i = 5;String strI = "" ......

javail
今天
10
0
Vue+Spring Data JPA+MySQL 增查改删

视频讲解: https://www.bilibili.com/video/BV16i4y1G7i2/ 工程概述: 前后端分离,进行简单增查改删(CRUD) 前端使用VUE 后端使用Spring Data JPA 数据库使用MySQL #EmployeeController.jav...

潘文海
今天
13
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部