文档章节

[scikit-learn 机器学习] 3. K-近邻算法分类和回归

o
 osc_ov4ok6s2
发布于 06/27 11:44
字数 932
阅读 52
收藏 0

「深度学习福利」大神带你进阶工程师,立即查看>>>


本文为 scikit-learn机器学习(第2版)学习笔记

K 近邻法(K-Nearest Neighbor, K-NN) 常用于 搜索和推荐系统。

1. KNN模型

  • 确定距离度量方法(如欧氏距离)
  • 根据 K 个最近的距离的邻居样本,选择策略做出预测
  • 模型假设:距离相近的样本,有接近的响应值

2. KNN分类

根据身高、体重对性别进行分类

import numpy as np
import matplotlib.pyplot as plt

X_train = np.array([
    [158, 64],
    [170, 86],
    [183, 84],
    [191, 80],
    [155, 49],
    [163, 59],
    [180, 67],
    [158, 54],
    [170, 67]
])
y_train = ['male', 'male', 'male', 'male', 'female', 'female', 'female', 'female', 'female']

plt.figure()
plt.title('Human Heights and Weights by Sex')
plt.xlabel('Height in cm')
plt.ylabel('Weight in kg')

for i, x in enumerate(X_train):
    if y_train[i] == 'male':
        c1 = plt.scatter(x[0], x[1], c='k', marker='x')
    else:
        c2 = plt.scatter(x[0], x[1], c='r', marker='o')
plt.grid(True)
plt.legend((c1,c2),('male','female'),loc='lower right')
# plt.show()

在这里插入图片描述

  • 对身高 155cm,体重 70 kg的人进行性别预测
  • 设置 KNN 模型 k = 3
计算距离
x = np.array([[155,70]])
dis = np.sqrt(np.sum((X_train-x)**2 ,axis = 1))
dis
选取最近k个
nearset_k_neighbor = dis.argsort()[0:3]
k_genders = [y_train[i] for i in nearset_k_neighbor]
k_genders  # ['male', 'female', 'female']
计算最近的k个的标签
from collections import Counter
# b = Counter(np.take(y_train, dis.argsort()[0:3]))
b = Counter(k_genders)
b # Counter({'male': 1, 'female': 2})
性别为女性占多数
# help(Counter.most_common)
# most_common(self, n=None)
#     List the n most common elements and their counts from the most
#     common to the least.  If n is None, then list all element counts.
b.most_common(2) # [('female', 2), ('male', 1)]
b.most_common(1)[0][0] # 'female'

3. 使用sklearn KNN分类

标签(male,female)数字化(0,1)

from sklearn.preprocessing import LabelBinarizer
from sklearn.neighbors import KNeighborsClassifier

lb = LabelBinarizer()
y_train_lb = lb.fit_transform(y_train)
y_train_lb
######
array([[1],
       [1],
       [1],
       [1],
       [0],
       [0],
       [0],
       [0],
       [0]])

预测前面的例子的性别

K=3
clf = KNeighborsClassifier(n_neighbors=K)
clf.fit(X_train,y_train_lb.ravel())
pred_gender = clf.predict(x)
pred_gender # array([0])
pred_label_gender = lb.inverse_transform(pred_gender)
pred_label_gender # array(['female'], dtype='<U6')

在test集上验证

X_test = np.array([
    [168, 65],
    [180, 96],
    [160, 52],
    [169, 67]
])
y_test = ['male', 'male', 'female', 'female']
y_test_lb = lb.transform(y_test)

pred_lb = clf.predict(X_test)
print('Predicted labels: %s' % lb.inverse_transform(pred_lb))
# Predicted labels: ['female' 'male' 'female' 'female']

计算评价指标

准确率:预测对了的比例3/4
from sklearn.metrics import accuracy_score
accuracy_score(y_test_lb, pred_lb) # 0.75
精准率:正类为男,男预测为男/(男预测男+女预测男)
from sklearn.metrics import precision_score
precision_score(y_test_lb, pred_lb) # 1.0
召回率: 男预测男/(男预测男+男预测女)
from sklearn.metrics import recall_score
recall_score(y_test_lb, pred_lb) # 0.5

F1 值

F1 得分是:精准率和召回率的均衡
from sklearn.metrics import f1_score
f1_score(y_test_lb, pred_lb) # 0.6667
评价报告
from sklearn.metrics import classification_report
# help(classification_report)
# classification_report(y_true, y_pred, labels=None, target_names=None, s
#        ample_weight=None, digits=2, output_dict=False, zero_division='warn')
print(classification_report(y_test_lb, pred_lb, target_names=['male','female'], labels=[1,0]))

在这里插入图片描述

4. KNN回归

根据身高、性别,预测其体重

from sklearn.neighbors import KNeighborsRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error,r2_score

X_train = np.array([
    [158,  1],
    [170,  1],
    [183,  1],
    [191,  1],
    [155,  0],
    [163,  0],
    [180,  0],
    [158,  0],
    [170,  0]
])
y_train = [64,86,84,80,49,59,67,54,67]

X_test = np.array([
    [168,  1],
    [180,  1],
    [160,  0],
    [169,  0]
])
y_test = [65,96,52,67]

K = 3
clf = KNeighborsRegressor(n_neighbors=K)
clf.fit(X_train, y_train)
predictions = clf.predict(np.array(X_test))
predictions # array([70.66666667, 79.        , 59.        , 70.66666667])

# help(r2_score)
# R^2 (coefficient of determination)
r2_score(y_test, predictions) # 0.6290565226735438

平均绝对值误差
mean_absolute_error(y_test, predictions) # 8.333333333333336

平均平方误差
mean_squared_error(y_test, predictions)  # 95.8888888888889
  • 数据没有标准化的影响
from scipy.spatial.distance import euclidean
# help(euclidean) # 欧氏距离
X_train = np.array([
    [1700,1],
    [1600,0]
])
X_test = np.array([1640,1]).reshape(1,-1)
print(euclidean(X_train[0,:], X_test))
print(euclidean(X_train[1,:], X_test))
# 60.0
# 40.01249804748511

X_train = np.array([
    [1.7,1],
    [1.6,0]
])
X_test = np.array([1.64,1]).reshape(1,-1)
print(euclidean(X_train[0,:], X_test))
print(euclidean(X_train[1,:], X_test))
# 0.06000000000000005
# 1.0007996802557444

可以看出不同单位下的欧式距离差异很大

  • 进行数据标准化
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)

print(X_train)
print(X_train_scaled)
[[158   1]
 [170   1]
 [183   1]
 [191   1]
 [155   0]
 [163   0]
 [180   0]
 [158   0]
 [170   0]]
[[-0.9908706   1.11803399]
 [ 0.01869567  1.11803399]
 [ 1.11239246  1.11803399]
 [ 1.78543664  1.11803399]
 [-1.24326216 -0.89442719]
 [-0.57021798 -0.89442719]
 [ 0.86000089 -0.89442719]
 [-0.9908706  -0.89442719]
 [ 0.01869567 -0.89442719]]
  • 标准化特征后 模型误差更低
pred = clf.predict(X_test_scaled)
pred # array([78.        , 83.33333333, 54.        , 64.33333333])

# R^2 (coefficient of determination)
r2_score(y_test, pred) # 0.6706425961745109

# 平均绝对值误差
mean_absolute_error(y_test, pred) # 7.583333333333336

# 平均平方误差
mean_squared_error(y_test, pred)  # 85.13888888888893

o
粉丝 0
博文 76
码字总数 0
作品 0
私信 提问
加载中
请先登录后再评论。
基于 ThinkPHP 的内容管理系统--歪酷CMS

歪酷网站管理系统(歪酷CMS)是一款基于THINKPHP框架开发的PHP+MYSQL网站建站程序,本程序实现了文章和栏目的批量动态管理,支持栏目无限分类,实现多管理员管理,程序辅助功能也基本实现了常见的文...

鲁大在线
2013/02/19
7.1K
2
硬实时操作系统--Raw OS

Raw-OS 起飞于2012年,Raw-OS志在制作中国人自己的最优秀硬实时操作系统。 Raw-OS 操作系统特性 内核最大关中断时间无限接近0us, s3c2440系统最大关中断时间实测0.8us。 支持idle任务级别的事...

jorya_txj
2013/03/19
6.3K
1
CSS编译工具--Peaches

Peaches是一个基于Node的CSS编译工具,用于自动合成CSS Sprite。 Peaches 追求简单、自然的CSS书写方式! 大致的工作原理如下: 1. 我们在书写样式时,对每个需要使用背景图片的元素,进行单...

sliuqin
2013/04/12
598
0
PHP博客系统--WBlog

Wblog是一个基于thinkphp3.1框架开发的轻量级的简洁实用的PHP博客系统,倡导“大道至简,开发由我”的理念,用最少的代码完成更多的功能。更多功能仍在完善中。。。 目前主要功能:   主博...

匿名
2012/11/02
3K
0
MySQL全文搜索引擎--mysqlcft

MySQL在高并发连接、数据库记录数较多的情况下,SELECT ... WHERE ... LIKE '%...%'的全文搜索方式不仅效率差,而且以通配符%开头作查询时,使用不到索引,需要全表扫描,对数据库的压力也很...

张宴
2012/11/29
1.6W
2

没有更多内容

加载失败,请刷新页面

加载更多

大数据研发学习之路--Hadoop集群搭建

阅读编译文档 准备一个hadoop源码包,我选择的hadoop版本是:hadoop-2.7.7-src.tar.gz,在hadoop-2.7.7的源码 包的根目录下有一个文档叫做BUILDING.txt,这其中说明了编译hadoop所需要的一些...

DSJ-shitou
27分钟前
8
0
OSChina 周五乱弹 —— 特么是别的公司派来的特洛伊木马吧?

Osc乱弹歌单(2020)请戳(这里) 【今日歌曲】 小小编辑推荐:《我会守在这里》- 毛不易 《我会守在这里》- 毛不易 手机党少年们想听歌,请使劲儿戳(这里) @FalconChen :股市连跪了五天,...

小小编辑
29分钟前
32
2
如何在find中排除目录。命令 - How to exclude a directory in find . command

问题: I'm trying to run a find command for all JavaScript files, but how do I exclude a specific directory? 我正在尝试为所有JavaScript文件运行find命令,但是如何排除特定目录? ......

法国红酒甜
今天
69
0
《Java8实战》笔记(02):通过行为参数传递代码

本文源码 应对不断变化的需求 通过筛选苹果阐述通过行为参数传递代码 初试牛刀:筛选绿苹果 public static List<Apple> filterGreenApples(List<Apple> inventory){List<Apple> result = ......

巨輪
今天
19
0
JeeSite 4 架构特点、安全方面、为什么好、工匠精神、不忘初心

1、底层架构 以 Spring Boot 2 为基础,Maven 多项目依赖,模块分项目,松耦合,方便模块升级、增减模块。 模块化的数据库自动升级程序,当模块升级代码需要更新数据库时,自动执行对应版本 ...

ThinkGem
昨天
13
0

没有更多内容

加载失败,请刷新页面

加载更多

返回顶部
顶部