Machine Learning Algorithms Notes (Part 2)


PCA in scikit-learn

from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":

    # generate 100 two-dimensional points lying roughly along y = 0.75x + 3
    X = np.empty((100, 2))
    X[:, 0] = np.random.uniform(0, 100, size=100)
    X[:, 1] = 0.75 * X[:, 0] + 3 + np.random.normal(0, 10, size=100)
    # project the data onto the first principal component
    pca = PCA(n_components=1)
    pca.fit(X)
    print(pca.components_)
    X_reduction = pca.transform(X)
    print(X_reduction.shape)
    # map the 1-D projection back into the original 2-D space
    X_restore = pca.inverse_transform(X_reduction)
    print(X_restore.shape)
    plt.scatter(X[:, 0], X[:, 1], color='b', alpha=0.5)
    plt.scatter(X_restore[:, 0], X_restore[:, 1], color='r', alpha=0.5)
    plt.show()

Output:

[[-0.78144234 -0.62397746]]
(100, 1)
(100, 2)
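
As a quick sanity check on this output (our own addition, reusing the pca object fitted above): the rows of components_ are unit vectors, and since the data were generated along y = 0.75x + 3 plus noise, the component's slope should land near 0.75 (here about 0.80; the gap comes from the noise):

w = pca.components_[0]
print(np.linalg.norm(w))  # ~1.0: principal components are unit vectors
print(w[1] / w[0])        # ~0.80 here, close to the generating slope of 0.75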

Now let's try scikit-learn's PCA on real data: a handwritten-digit classification task.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # split the dataset into a training set and a test set
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)

Output:

(1347, 64)

The output shows that the training set has 1347 samples, each with 64 features (dimensions). Let's first train on the raw data and see what accuracy we get. Since KNN is the only classification algorithm we have covered so far, we'll use it for the classification.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import timeit

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # split the dataset into a training set and a test set
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test, y_test))

Output:

(1347, 64)
0.0025633269999999486
0.9866666666666667

The output shows that fitting KNN on the raw dataset took about 2.5 ms, and the trained classifier scores 0.9867 on the test set, an accuracy of 98.66%.
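
One caveat about these numbers: for KNN, fit essentially just stores the training set, so the time measured above is the cheap part, and the expensive distance computations happen inside score. A small sketch (reusing the objects above) that times the two phases separately:

start_time = timeit.default_timer()
knn_clf.fit(X_train, y_train)  # fit just stores the data: fast
print('fit time:', timeit.default_timer() - start_time)
start_time = timeit.default_timer()
acc = knn_clf.score(X_test, y_test)  # the distance search: the slow part
print('score time:', timeit.default_timer() - start_time, 'accuracy:', acc)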

Now let's reduce the dimensionality of the raw data and train on the reduced data.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # split the dataset into a training set and a test set
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test, y_test))
    # reduce the original 64 features to 2 dimensions
    pca = PCA(n_components=2)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)
    # fit KNN on the dimensionality-reduced data
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train_reduction, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test_reduction, y_test))

Output:

(1347, 64)
0.0021396940000000253
0.9866666666666667
0.0005477309999999402
0.6066666666666667

From the output, training on the reduced data took only about 0.55 ms, a large drop in training time, but the accuracy is only 60.67%. We went from 64 dimensions straight down to 2, and accuracy fell from 98.66% to 60.67%. Could we keep more dimensions to win the accuracy back? And if so, how many dimensions are enough?

In fact, PCA provides exactly the metric we need: it makes it easy to decide, for a given dataset, how many dimensions to keep. The metric is called the explained variance ratio. Let's look at it for the reduction to 2 dimensions.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # split the dataset into a training set and a test set
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test, y_test))
    # reduce the original 64 features to 2 dimensions
    pca = PCA(n_components=2)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)
    # fit KNN on the dimensionality-reduced data
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train_reduction, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test_reduction, y_test))
    # explained variance ratio
    print(pca.explained_variance_ratio_)

Output:

(1347, 64)
0.002134913000000016
0.9866666666666667
0.0005182209999999854
0.6066666666666667
[0.14566817 0.13735469]

According to the output, the first of the two retained dimensions explains about 14.5% of the original data's variance and the second about 13.7%. PCA looks for the projection directions of maximum variance, and this metric reports what fraction of the total variance each retained dimension preserves. Together the two dimensions preserve about 14.5% + 13.7% ≈ 28.3% of the variance, so the remaining ~72% of the information is lost, which is clearly far too much.
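
The exact total is a one-liner on the fitted pca object:

print(np.sum(pca.explained_variance_ratio_))  # ≈ 0.283 for the 2-D projection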

Now let's look at the explained variance ratios of all the components on the training set.

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # split the dataset into a training set and a test set
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test, y_test))
    # reduce the original 64 features to 2 dimensions
    pca = PCA(n_components=2)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)
    # fit KNN on the dimensionality-reduced data
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train_reduction, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test_reduction, y_test))
    # explained variance ratio
    print(pca.explained_variance_ratio_)
    # examine the explained variance ratios of all the components
    pca = PCA(n_components=X_train.shape[1])
    pca.fit(X_train)
    print(pca.explained_variance_ratio_)

Output:

(1347, 64)
0.0023309580000000496
0.9866666666666667
0.0005462749999999295
0.6066666666666667
[0.14566817 0.13735469]
[1.45668166e-01 1.37354688e-01 1.17777287e-01 8.49968861e-02
 5.86018996e-02 5.11542945e-02 4.26605279e-02 3.60119663e-02
 3.41105814e-02 3.05407804e-02 2.42337671e-02 2.28700570e-02
 1.80304649e-02 1.79346003e-02 1.45798298e-02 1.42044841e-02
 1.29961033e-02 1.26617002e-02 1.01728635e-02 9.09314698e-03
 8.85220461e-03 7.73828332e-03 7.60516219e-03 7.11864860e-03
 6.85977267e-03 5.76411920e-03 5.71688020e-03 5.08255707e-03
 4.89020776e-03 4.34888085e-03 3.72917505e-03 3.57755036e-03
 3.26989470e-03 3.14917937e-03 3.09269839e-03 2.87619649e-03
 2.50362666e-03 2.25417403e-03 2.20030857e-03 1.98028746e-03
 1.88195578e-03 1.52769283e-03 1.42823692e-03 1.38003340e-03
 1.17572392e-03 1.07377463e-03 9.55152460e-04 9.00017642e-04
 5.79162563e-04 3.82793717e-04 2.38328586e-04 8.40132221e-05
 5.60545588e-05 5.48538930e-05 1.08077650e-05 4.01354717e-06
 1.23186515e-06 1.05783059e-06 6.06659094e-07 5.86686040e-07
 1.71368535e-33 7.44075955e-34 7.44075955e-34 7.15189459e-34]

Now we can see the explained variance ratios of all 64 components, listed in descending order. Let's plot the cumulative variance explained by the first n components.

# cumulative explained variance ratio of the first n components
plt.plot(range(X_train.shape[1]),
         np.cumsum(pca.explained_variance_ratio_))
plt.show()

The curve shows that the more components we keep, approaching the original number of features, the larger the fraction of variance we can explain. If we want to retain at least 95% of the variance, we only need to read off the x-value where the curve crosses 0.95 on the y-axis. scikit-learn already has this built in: passing a float between 0 and 1 to PCA asks for the smallest number of components whose cumulative explained variance reaches that fraction.

pca = PCA(0.95)
pca.fit(X_train)
print(pca.n_components_)

Output:

28

So to retain 95% of the variance we need to keep 28 dimensions. We can now reduce the data to 28 dimensions and measure the training time and the test-set accuracy.

pca = PCA(0.95)
pca.fit(X_train)
print(pca.n_components_)
X_train_reduction = pca.transform(X_train)
X_test_reduction = pca.transform(X_test)
start_time = timeit.default_timer()
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train_reduction, y_train)
print(timeit.default_timer() - start_time)
print(knn_clf.score(X_test_reduction, y_test))

Output:

28
0.0014477489999999982
0.98

The output shows that fitting KNN now takes about 1.4 ms, cutting the time by roughly 40% compared with the full 64-dimensional data, while the accuracy is 98%, less than 0.7 percentage points below the full-data result. That is an entirely acceptable trade-off: on very large datasets, this kind of dimensionality reduction can cut training time dramatically while keeping accuracy high.
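
As an aside, the PCA-then-KNN steps can also be chained with scikit-learn's Pipeline, which guarantees the test data go through exactly the transform fitted on the training data; a minimal sketch under the same setup as above:

from sklearn.pipeline import Pipeline

pca_knn = Pipeline([
    ('pca', PCA(0.95)),  # keep at least 95% of the variance
    ('knn', KNeighborsClassifier()),
])
pca_knn.fit(X_train, y_train)
print(pca_knn.score(X_test, y_test))  # should land close to 0.98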

Finally, reducing the original data all the way down to 2 dimensions is not pointless either: its value is that it makes visualization easy.

pca = PCA(n_components=2)
pca.fit(X)
X_reduction = pca.transform(X)
print(X_reduction.shape)
for i in range(10):
    # plot the 2-D points of one digit class at a time
    plt.scatter(X_reduction[y == i, 0], X_reduction[y == i, 1], alpha=0.8)
plt.show()

Output:

(1797, 2)

The plot shows (here we used the whole dataset, without a train/test split) that even in two dimensions the ten digit classes are separated remarkably well.
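
To check which cluster belongs to which digit, a small variation of the loop above (our sketch) adds a legend:

for i in range(10):
    plt.scatter(X_reduction[y == i, 0], X_reduction[y == i, 1],
                alpha=0.8, label=str(i))
plt.legend()
plt.show()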

The MNIST Dataset

Now let's move to the more standard handwritten-digit dataset, MNIST. First, fetch the data.

import numpy as np
from sklearn.datasets import fetch_openml

if __name__ == "__main__":

    mnist = fetch_openml('mnist_784')
    print(mnist)

Output:

{'data': array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', ..., 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  **Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) ... The MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples ...", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', ..., 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', ...}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(long output truncated)

Let's check the dataset's number of samples and features.

import numpy as np
from sklearn.datasets import fetch_openml

if __name__ == "__main__":

    mnist = fetch_openml('mnist_784')
    X, y = mnist['data'], mnist['target']
    print(X.shape)

Output:

(70000, 784)

The output shows that it has 70,000 samples with 784 features each.
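
These 784 features are simply the pixels of a 28x28 grayscale image. As a quick sketch (our addition; per the target array above, the first sample's label is '5'), we can reshape one row and display it:

import matplotlib.pyplot as plt

some_digit = np.array(X[0], dtype=float).reshape(28, 28)
plt.imshow(some_digit, cmap='binary')  # shows the handwritten digit '5'
plt.show()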

For an ordinary dataset we would split off training and test data with from sklearn.model_selection import train_test_split, but MNIST comes pre-split: the first 60,000 samples are the training set and the last 10,000 are the test set.

import numpy as np
from sklearn.datasets import fetch_openml

if __name__ == "__main__":

    mnist = fetch_openml('mnist_784')
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)
    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)

Output:

(70000, 784)
(60000, 784)
(60000,)
(10000, 784)
(10000,)
(70000, 784)
(60000, 784)
(60000,)
(10000, 784)
(10000,)

Now let's run KNN classification on this raw data.

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
import timeit

if __name__ == "__main__":

    mnist = fetch_openml('mnist_784')
    print(mnist)  # prints the full dataset dict; the output is very long and is truncated below
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    # By convention, the first 60,000 samples of mnist_784 are the training set
    # and the last 10,000 are the test set. Note that the targets are strings
    # ('5', '0', ...), so we also convert them to numeric values here.
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)
    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)
    # Run KNN classification on the raw data
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
Output

{'data': array([[0., 0., 0., ..., 0., 0., 0.], ...]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), ...} (the full mnist dict again, identical to the output above; omitted here)
(70000, 784)
(60000, 784)
(60000,)
(10000, 784)
(10000,)
12.399795246999998

As the results show, training the KNN classifier on this raw dataset of 60,000 samples and 784 features took 12.39 seconds.
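As an aside, newer versions of scikit-learn may return a DataFrame from fetch_openml instead of numpy arrays. If you run into that, you can pass as_frame=False explicitly; the following is only a sketch, and the exact behavior depends on your scikit-learn version:

from sklearn.datasets import fetch_openml

# Sketch: explicitly request numpy arrays so that the slicing and dtype
# conversions below behave the same across versions (assumes your
# scikit-learn release supports the as_frame parameter)
mnist = fetch_openml('mnist_784', as_frame=False)
X, y = mnist['data'], mnist['target']

Now let's see what the recognition accuracy is.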

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
import timeit

if __name__ == "__main__":

    mnist = fetch_openml('mnist_784')
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)
    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)
    # Run KNN classification on the raw data
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    # Score (predict) on the test set and time the prediction separately
    start_time = timeit.default_timer()
    print(knn_clf.score(X_test, y_test))
    print(timeit.default_timer() - start_time)
Output

{'data': array([[0., 0., 0., ..., 0., 0., 0.], ...]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), ...} (the full mnist dict again, identical to the output above; omitted here)
(70000, 784)
(60000, 784)
(60000,)
(10000, 784)
(10000,)
12.444964293999998
0.9688
537.592075492

As the results show, the recognition accuracy is 96.88%, but scoring the test set took 537.59 seconds, roughly 9 minutes.
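The time is concentrated in the scoring step because KNN does very little work at fit time: it mostly stores the training data (or builds an index over it), and the heavy distance computations all happen at prediction time, when every test sample must be compared against the training samples. Below is a rough sketch for estimating the total prediction time; it assumes the knn_clf and X_test variables fitted above and is not part of the original code:

import timeit

# Time the predictions for the first 100 test samples, then scale up
# to estimate the cost of predicting all 10,000 test samples
start_time = timeit.default_timer()
knn_clf.predict(X_test[:100])
elapsed = timeit.default_timer() - start_time
print(elapsed / 100 * len(X_test))  # an order-of-magnitude estimate; should be near the 537 s above

Now let's reduce the dimensionality of the raw data (to save time, the KNN classification of the raw data is commented out in the code below).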

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    mnist = fetch_openml('mnist_784')
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)
    # print(X_train.shape)
    # print(y_train.shape)
    # print(X_test.shape)
    # print(y_test.shape)
    # Run KNN classification on the raw data
    # start_time = timeit.default_timer()
    # knn_clf = KNeighborsClassifier()
    # knn_clf.fit(X_train, y_train)
    # print(timeit.default_timer() - start_time)
    # start_time = timeit.default_timer()
    # print(knn_clf.score(X_test, y_test))
    # print(timeit.default_timer() - start_time)
    # Reduce the raw data, keeping 90% of the explained variance ratio
    pca = PCA(0.9)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    print(X_train_reduction.shape)

Output

{'data': array([[0., 0., 0., ..., 0., 0., 0.], ...]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), ...} (the full mnist dict again, identical to the output above; omitted here)
(70000, 784)
(60000, 87)

As the results show, keeping 90% of the explained variance reduces the data from 784 dimensions down to 87.
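In fact, PCA(0.9) simply accumulates the explained variance ratios of the principal components, from largest to smallest, until the running total first reaches 90%. The sketch below verifies this; it assumes the pca object fitted above and is not part of the original code:

import numpy as np

# Cumulative explained variance ratio: the array length equals the number
# of retained components, and its last value should be at least 0.9
cum_ratio = np.cumsum(pca.explained_variance_ratio_)
print(len(cum_ratio))     # 87
print(cum_ratio[-1])      # >= 0.9
print(pca.n_components_)  # the retained component count, also 87

Now let's run KNN classification on the reduced data.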

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    mnist = fetch_openml('mnist_784')
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)
    # print(X_train.shape)
    # print(y_train.shape)
    # print(X_test.shape)
    # print(y_test.shape)
    # Run KNN classification on the raw data
    # start_time = timeit.default_timer()
    # knn_clf = KNeighborsClassifier()
    # knn_clf.fit(X_train, y_train)
    # print(timeit.default_timer() - start_time)
    # start_time = timeit.default_timer()
    # print(knn_clf.score(X_test, y_test))
    # print(timeit.default_timer() - start_time)
    # Reduce the raw data, keeping 90% of the explained variance ratio
    pca = PCA(0.9)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)
    print(X_train_reduction.shape)
    # Run KNN classification on the reduced dataset
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train_reduction, y_train)
    print(timeit.default_timer() - start_time)
    # Score on the reduced test set and time the prediction separately
    start_time = timeit.default_timer()
    print(knn_clf.score(X_test_reduction, y_test))
    print(timeit.default_timer() - start_time)

Output

{'data': array([[0., 0., 0., ..., 0., 0., 0.], ...]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), ...} (the full mnist dict again, identical to the output above; omitted here)
(70000, 784)
(60000, 87)
0.353538382
0.9728
66.771680345

通过结果,我们惊讶地发现降维后识别准确率反而提升了(原来是96.88%,现在是97.28%),而时间也大大缩短,从差不多9分钟降到了1分6秒。这其实体现了PCA的另外一个用途:降噪。降维的过程不仅减少了数据的维度,还有可能把原始数据中包含的噪音消除掉,使得我们可以更好、更准确地拿到数据集对应的特征,从而让识别准确率得到提升。

三维数据上的PCA

之前我们的数据都是构建在二维图形上的,现在来构建一个三维数据图形

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

if __name__ == "__main__":

    # 在三维空间随机生成100个样本点
    np.random.seed(8888)
    X_random = np.random.random(size=(100, 3))
    ax = plt.axes(projection='3d')
    ax.scatter3D(X_random[:, 0], X_random[:, 1], X_random[:, 2])
    plt.show()

对于三维数据,我们依然要进行一个demean操作(样本均值归0)

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

if __name__ == "__main__":

    # 在三维空间随机生成100个样本点
    np.random.seed(8888)
    X_random = np.random.random(size=(100, 3))
    # ax = plt.axes(projection='3d')
    # ax.scatter3D(X_random[:, 0], X_random[:, 1], X_random[:, 2])
    # plt.show()

    def demean(X):
        return X - np.mean(X, axis=0)
    # demean操作
    X_demean = demean(X_random)
    ax = plt.axes(projection='3d')
    ax.scatter3D(X_demean[:, 0], X_demean[:, 1], X_demean[:, 2])
    plt.show()
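
作为一个小小的验证(示意代码,沿用上面的 X_demean):demean 之后,每一列的均值都应非常接近 0。

# 验证:demean 后各列的均值应是非常接近 0 的极小值
print(np.mean(X_demean, axis=0))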

现在我们加入梯度上升法,先求出第一主成分,再把数据在第一主成分方向上的分量去掉,得到用于求第二主成分的分量

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

if __name__ == "__main__":

    # 在三维空间随机生成100个样本点
    np.random.seed(8888)
    X_random = np.random.random(size=(100, 3))
    # ax = plt.axes(projection='3d')
    # ax.scatter3D(X_random[:, 0], X_random[:, 1], X_random[:, 2])
    # plt.show()

    def demean(X):
        return X - np.mean(X, axis=0)
    # demean操作
    X_demean = demean(X_random)
    # ax = plt.axes(projection='3d')
    # ax.scatter3D(X_demean[:, 0], X_demean[:, 1],X_demean[:, 2])
    # plt.show()


    def f(w, X):
        # 目标函数
        return np.sum((X.dot(w) ** 2)) / len(X)


    def df(w, X):
        # 梯度
        return X.T.dot(X.dot(w)) * 2 / len(X)


    def direction(w):
        # 把w变成单位向量,只表示方向
        return w / np.linalg.norm(w)


    def first_component(X, initial_w, eta, n_iters=1e4, epsilon=1e-8):
        """
        梯度上升法,求出第一主成分
        :param X: 数据矩阵
        :param initial_w: 初始的系数向量(搜索起点)。需要注意,真正待求的是方向向量w,求偏导也是对w求
        :param eta: 步长,学习率
        :param n_iters: 最大迭代次数
        :param epsilon: 误差值
        :return:
        """
        # 将初始的常数向量变成单位向量
        w = direction(initial_w)
        # 真实迭代次数
        cur_iter = 0
        while cur_iter < n_iters:
            # 获取梯度
            gradient = df(w, X)
            last_w = w
            # 迭代更新w,不断顺着梯度方向寻找新的w
            # 跟梯度下降法不同的是,梯度下降法是-,梯度上升法这里是+
            w = w + eta * gradient
            # 将获取的新的w重新变成单位向量
            w = direction(w)
            # 计算前后两次迭代后的目标函数差值的绝对值
            if abs(f(w, X) - f(last_w, X)) < epsilon:
                break
            # 更新迭代次数
            cur_iter += 1
        return w

    initial_w = np.random.random(X_demean.shape[1])
    eta = 0.01
    # 求出第一主成分
    w1 = first_component(X_demean, initial_w, eta)
    # 求第二主成分分量
    X2 = np.empty(X_demean.shape)
    for i in range(len(X_demean)):
        X2[i] = X_demean[i] - X_demean[i].dot(w1) * w1
    ax = plt.axes(projection='3d')
    ax.scatter3D(X2[:, 0], X2[:, 1], X2[:, 2])
    plt.show()

现在我们继续求出第二主成分,并把数据在第二主成分方向上的分量也去掉,得到第三主成分的分量

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

if __name__ == "__main__":

    # 在三维空间随机生成100个样本点
    np.random.seed(8888)
    X_random = np.random.random(size=(100, 3))
    # ax = plt.axes(projection='3d')
    # ax.scatter3D(X_random[:, 0], X_random[:, 1], X_random[:, 2])
    # plt.show()

    def demean(X):
        return X - np.mean(X, axis=0)
    # demean操作
    X_demean = demean(X_random)
    # ax = plt.axes(projection='3d')
    # ax.scatter3D(X_demean[:, 0], X_demean[:, 1],X_demean[:, 2])
    # plt.show()


    def f(w, X):
        # 目标函数
        return np.sum((X.dot(w) ** 2)) / len(X)


    def df(w, X):
        # 梯度
        return X.T.dot(X.dot(w)) * 2 / len(X)


    def direction(w):
        # 把w变成单位向量,只表示方向
        return w / np.linalg.norm(w)


    def first_component(X, initial_w, eta, n_iters=1e4, epsilon=1e-8):
        """
        梯度上升法,求出第一主成分
        :param X: 数据矩阵
        :param initial_w: 初始的系数向量(搜索起点)。需要注意,真正待求的是方向向量w,求偏导也是对w求
        :param eta: 步长,学习率
        :param n_iters: 最大迭代次数
        :param epsilon: 误差值
        :return:
        """
        # 将初始的常数向量变成单位向量
        w = direction(initial_w)
        # 真实迭代次数
        cur_iter = 0
        while cur_iter < n_iters:
            # 获取梯度
            gradient = df(w, X)
            last_w = w
            # 迭代更新w,不断顺着梯度方向寻找新的w
            # 跟梯度下降法不同的是,梯度下降法是-,梯度上升法这里是+
            w = w + eta * gradient
            # 将获取的新的w重新变成单位向量
            w = direction(w)
            # 计算前后两次迭代后的目标函数差值的绝对值
            if abs(f(w, X) - f(last_w, X)) < epsilon:
                break
            # 更新迭代次数
            cur_iter += 1
        return w

    initial_w = np.random.random(X_demean.shape[1])
    eta = 0.01
    # 求出第一主成分
    w1 = first_component(X_demean, initial_w, eta)
    # 求第二主成分分量
    X2 = np.empty(X_demean.shape)
    for i in range(len(X_demean)):
        X2[i] = X_demean[i] - X_demean[i].dot(w1) * w1
    # ax = plt.axes(projection='3d')
    # ax.scatter3D(X2[:, 0], X2[:, 1], X2[:, 2])
    # plt.show()
    # 求第二主成分
    w2 = first_component(X2, initial_w, eta)
    # 求第三主成分分量
    X3 = np.empty(X2.shape)
    for i in range(len(X2)):
        X3[i] = X2[i] - X2[i].dot(w2) * w2
    ax = plt.axes(projection='3d')
    ax.scatter3D(X3[:, 0], X3[:, 1], X3[:, 2])
    plt.show()
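
顺便做一个小验证(示意代码,沿用上面求得的 w1、w2 和 X2):相邻的主成分方向应近似正交,前面逐行相减的 for 循环也可以用向量化的写法代替。

# w1 与 w2 应近似正交,内积接近 0
print(w1.dot(w2))
# 向量化求残差:X.dot(w) 是每个样本在 w 方向上的投影长度,
# reshape 成列向量后乘以 w,得到每个样本在 w 方向上的分量
X2_vec = X_demean - X_demean.dot(w1).reshape(-1, 1) * w1
print(np.allclose(X2, X2_vec))  # 应输出 True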

使用PCA对数据进行降噪

之前我们说过,将一组二维数据用PCA降为一维,再反向恢复成二维,就有可能是下面这个样子

这个过程我们也可以理解成:原始数据(蓝色的点)是带有噪音的,噪音可能来自测量不准确、精度不够、粗心大意,或者测量方法本身就有问题;而真实数据本身就呈现一种线性的样子。那么通过PCA降为一维再恢复成二维(红色的点)之后,就相当于对数据做了降噪处理,得到的是更接近真实数据本身的样子。

现在我们来看一个手写数字识别数据的降噪例子。我们先获取手写识别的数据,然后再构造一个有噪音的数据集

from sklearn import datasets
import numpy as np

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 创造有噪音的手写数据集,噪音为随机正态分布的矩阵(均值为0,标准差为4)
    noisy_digits = X + np.random.normal(0, 4, size=X.shape)
    # 由于数据量比较大,我们取出一些样例
    example_digits = noisy_digits[y == 0, :][: 10]
    for num in range(1, 10):
        X_num = noisy_digits[y == num, :][: 10]
        example_digits = np.vstack([example_digits, X_num])
    print(example_digits.shape)

运行结果

(100, 64)

通过结果可以看到,样例共有100个样本、64个特征。我们把这些样例绘制出来

from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 创造有噪音的手写数据集,噪音为随机正态分布的矩阵(均值为0,标准差为4)
    noisy_digits = X + np.random.normal(0, 4, size=X.shape)
    # 由于数据量比较大,我们取出一些样例
    example_digits = noisy_digits[y == 0, :][: 10]
    for num in range(1, 10):
        X_num = noisy_digits[y == num, :][: 10]
        example_digits = np.vstack([example_digits, X_num])
    print(example_digits.shape)

    def plot_digits(data):
        # 绘制样本
        fig, axes = plt.subplots(10, 10, figsize=(10, 10),
                                 subplot_kw={'xticks': [], 'yticks': []},
                                 gridspec_kw=dict(hspace=0.1, wspace=0.1))
        for i, ax in enumerate(axes.flat):
            ax.imshow(data[i].reshape(8, 8), cmap='binary', interpolation='nearest', clim=(0, 16))
        plt.show()

    plot_digits(example_digits)

运行结果

从这个结果可以看到,这些数据的噪音非常明显,很难看清它们是什么数字。现在我们来对其进行降噪。

# 对样例进行降噪,保留50%的方差比例
pca = PCA(0.5)
pca.fit(noisy_digits)
print(pca.n_components_)

运行结果

12

可见我们从64维降到了12维(当 n_components 传入0~1之间的小数时,scikit-learn 会自动选取使解释的方差比例达到该值所需的最少主成分个数)。然后开始降噪

from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 创造有噪音的手写数据集,噪音为随机正态分布的矩阵(均值为0,标准差为4)
    noisy_digits = X + np.random.normal(0, 4, size=X.shape)
    # 由于数据量比较大,我们取出一些样例
    example_digits = noisy_digits[y == 0, :][: 10]
    for num in range(1, 10):
        X_num = noisy_digits[y == num, :][: 10]
        example_digits = np.vstack([example_digits, X_num])
    print(example_digits.shape)

    def plot_digits(data):
        # 绘制样本
        fig, axes = plt.subplots(10, 10, figsize=(10, 10),
                                 subplot_kw={'xticks': [], 'yticks': []},
                                 gridspec_kw=dict(hspace=0.1, wspace=0.1))
        for i, ax in enumerate(axes.flat):
            ax.imshow(data[i].reshape(8, 8), cmap='binary', interpolation='nearest', clim=(0, 16))
        plt.show()

    plot_digits(example_digits)
    # 对样例进行降噪,保留50%的方差比例
    pca = PCA(0.5)
    pca.fit(noisy_digits)
    print(pca.n_components_)
    components = pca.transform(example_digits)
    # 将降维后的12维数据返回到64维
    filtered_digits = pca.inverse_transform(components)
    plot_digits(filtered_digits)

运行结果

通过结果可以看到,降噪后的图像比含噪音的图像平滑了很多。

人脸识别与特征脸

之前我们在讲高维数据映射为低维数据的时候,给出过相应的式子:低维数据等于 X 乘以前 k 个主成分组成的矩阵 Wk 的转置。

X的行数m就是样本量,列数n就是维度。W是计算出来的主成分矩阵,构成了另外一个坐标系:它的每一行都代表一个方向,第一行是最重要的方向,第二行是次重要的方向,以此类推。在人脸识别中,X的每一行就是一张人脸图像;而如果把W的每一行也看作一个"样本",那么第一行就是最重要的、最能反映X原有样本特征的那个样本,第二行次之,它也能很好地反映原来X中样本的特征。这样,W的每一行也可以看成一张人脸,这样的人脸就称为特征脸(eigenface)。之所以叫特征脸,是因为每一个特征脸对应一个主成分,它表达了原来那些人脸数据中的一部分特征;这也对应线性代数中特征值和特征向量的概念,有兴趣的朋友可以参考线性代数整理(三)中有关特征值和特征向量的内容。
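
用 numpy 把这个映射关系写出来大致如下(仅为示意:这里用随机生成的正交方向代替真正求出的主成分,Wk 的每一行代表一个方向):

import numpy as np

# 示意:m 个样本、n 个特征的 X,与前 k 个主成分组成的矩阵 Wk(k 行 n 列)
m, n, k = 5, 4, 2
rng = np.random.default_rng(666)
X = rng.random((m, n))
Q, _ = np.linalg.qr(rng.random((n, k)))
Wk = Q.T                          # (k, n),每一行是一个单位方向
X_reduction = X.dot(Wk.T)         # (m, k):高维映射到低维
X_restore = X_reduction.dot(Wk)   # (m, n):再映射回高维
print(X_reduction.shape, X_restore.shape)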

现在我们来获取人脸数据(下载数据的时间会比较长)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people

if __name__ == "__main__":

    faces = fetch_lfw_people()
    print(faces)
    # 打印数据字典
    print(faces.keys())
    # 打印数据量和特征数
    print(faces.data.shape)
    # 打印每个图像的矩阵样式
    print(faces.images.shape)

运行结果

{'data': array([[ 34.      ,  29.333334,  22.333334, ...,  14.666667,  16.      ,  14.      ],
       ...,
       [ 58.333332,  48.      ,  20.      , ..., 116.      , 106.333336, 143.33333 ]], dtype=float32), 'images': array(...(13233个62×47的图像矩阵,内容略)...), 'target': array([5360, 3434, 3807, ..., 2175,  373, 2941]), 'target_names': array(['AJ Cook', 'AJ Lamas', 'Aaron Eckhart', ..., 'Zumrati Juma', 'Zurab Tsereteli', 'Zydrunas Ilgauskas'], dtype='<U35'), 'DESCR': "...(LFW 数据集的英文描述,此处省略)"}
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(13233, 2914)
(13233, 62, 47)

通过结果我们可以看出,共有13233个样本、2914个特征(维度);每个样本对应一张62×47的图像。现在我们来绘制这些人脸图像

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people

if __name__ == "__main__":

    faces = fetch_lfw_people()
    print(faces)
    # 打印数据字典
    print(faces.keys())
    # 打印数据量和特征数
    print(faces.data.shape)
    # 打印每个图像的矩阵样式
    print(faces.images.shape)
    # 绘制人脸图像
    # 获取随机排列
    random_indexes = np.random.permutation(len(faces.data))
    # 取出随机排列的脸的数据
    X = faces.data[random_indexes]
    # 取出前36张脸
    example_faces = X[: 36, :]
    print(example_faces.shape)

    def plot_faces(faces):
        # 绘制样本
        fig, axes = plt.subplots(6, 6, figsize=(10, 10),
                                 subplot_kw={'xticks': [], 'yticks': []},
                                 gridspec_kw=dict(hspace=0.1, wspace=0.1))
        for i, ax in enumerate(axes.flat):
            ax.imshow(faces[i].reshape(62, 47), cmap='bone')
        plt.show()

    plot_faces(example_faces)

运行结果

……(与上一次运行相同的 faces 数据字典输出,此处省略)
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(13233, 2914)
(13233, 62, 47)
(36, 2914)

结果显示我们拿到了36个示例样本,每个样本有2914个特征。现在我们就来获取和绘制特征脸

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    faces = fetch_lfw_people()
    print(faces)
    # 打印数据字典
    print(faces.keys())
    # 打印数据量和特征数
    print(faces.data.shape)
    # 打印每个图像的矩阵样式
    print(faces.images.shape)
    # 绘制人脸图像
    # 获取随机排列
    random_indexes = np.random.permutation(len(faces.data))
    # 取出随机排列的脸的数据
    X = faces.data[random_indexes]
    # 取出前36张脸
    example_faces = X[: 36, :]
    print(example_faces.shape)

    def plot_faces(faces):
        # 绘制样本
        fig, axes = plt.subplots(6, 6, figsize=(10, 10),
                                 subplot_kw={'xticks': [], 'yticks': []},
                                 gridspec_kw=dict(hspace=0.1, wspace=0.1))
        for i, ax in enumerate(axes.flat):
            ax.imshow(faces[i].reshape(62, 47), cmap='bone')
        plt.show()

    # plot_faces(example_faces)
    # 获取特征脸
    pca = PCA(svd_solver='randomized')
    start_time = timeit.default_timer()
    pca.fit(X)
    print(timeit.default_timer() - start_time)
    print(pca.components_.shape)
    # 绘制特征脸
    plot_faces(pca.components_[: 36, :])

运行结果

……(与上一次运行相同的 faces 数据字典输出,此处省略)
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(13233, 2914)
(13233, 62, 47)
(36, 2914)
12.20208079
(2914, 2914)

由结果可以看到,获取主成分用了12.2秒;总共有2914个主成分,每个主成分都是一个2914维的向量,图中绘出的就是前36个特征脸。我们会发现越靠前的特征脸越模糊(靠前的主成分刻画的是更整体、更重要的信息),越靠后的则逐渐清晰一些。通过求出特征脸,一方面我们可以直观地看到,在人脸识别的过程中算法是如何"看"每一张人脸的特征的;另一方面,由之前的式子也可以看出,每一张人脸其实都是这些特征脸的一个线性组合,而特征脸按照重要程度依次排在了图中。
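
作为示意(假设上面训练好的 pca 和随机排列后的 X 仍在内存中,k 的取值仅作演示),每张人脸都可以近似写成"均值脸加上前 k 个特征脸的线性组合":

# 取前 k 个特征脸,对第一张人脸做近似重建
k = 100
face = X[0]
# 去均值后在前 k 个主成分方向上的坐标,也就是线性组合的系数
coef = pca.components_[: k].dot(face - pca.mean_)
# 均值脸加上按系数加权的特征脸之和,得到近似还原的人脸
face_approx = pca.mean_ + coef.dot(pca.components_[: k])
print(np.linalg.norm(face - face_approx))  # k 越大,这个近似误差越小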

另外,在这个人脸库中,每个人的照片数量是不同的,我们可以只取出照片数量较多的那些人的人脸,作为人脸识别的训练数据。

# 只获取至少有60张照片的人的人脸数据
faces = fetch_lfw_people(min_faces_per_person=60)
print(faces)
# 打印数据字典
print(faces.keys())
# 打印数据量和特征数
print(faces.data.shape)
# 打印这些人名
print(faces.target_names)
# 打印满足条件的人数
print(len(faces.target_names))

运行结果

{'data': array([[138.        , 135.66667   , 127.666664  , ...,   1.6666666 ,   1.6666666 ,   0.33333334],
       ...,
       [ 31.        ,  26.333334  ,  28.        , ...,  34.        ,  42.        ,  69.666664  ]], dtype=float32), 'images': array(...(1348个62×47的图像矩阵,内容略)...), 'target': array([1, 3, 3, ..., 7, 3, 5]), 'target_names': array(['Ariel Sharon', 'Colin Powell', 'Donald Rumsfeld', 'George W Bush', 'Gerhard Schroeder', 'Hugo Chavez', 'Junichiro Koizumi', 'Tony Blair'], dtype='<U17'), 'DESCR': "...(LFW 数据集的英文描述,此处省略)"}
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(1348, 2914)
['Ariel Sharon' 'Colin Powell' 'Donald Rumsfeld' 'George W Bush'
 'Gerhard Schroeder' 'Hugo Chavez' 'Junichiro Koizumi' 'Tony Blair']
8

此时我们可以看到,样本数只有1348个,比之前的13233个少了很多;满足条件的总共只有8个人,他们的名字也被打印了出来。最后,我们完全可以基于这个人脸库自己开发一个人脸识别的应用程序,使用KNN算法就可以,大致思路见下面的草图。
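
下面是一个大致的思路草图(仅为示意:n_components 等参数是随手取的,并非调优结果):

from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier

# 思路草图:先用 PCA 对人脸数据降维,再用 KNN 做分类
X_train, X_test, y_train, y_test = train_test_split(
    faces.data, faces.target, random_state=666)
pca = PCA(n_components=150, svd_solver='randomized')
pca.fit(X_train)
knn_clf = KNeighborsClassifier()
knn_clf.fit(pca.transform(X_train), y_train)
print(knn_clf.score(pca.transform(X_test), y_test))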

多项式回归与模型泛化

什么是多项式回归

之前我们在讲线性回归的时候说过:有一些数据样本点,我们猜测它们存在一种线性关系 y=ax+b(这里假设是一维的简单线性回归)。这条线性函数上的函数值我们称为预测值,而真实值(数据样本值)与预测值之差的平方和要尽可能小,这个平方和就被称为线性回归的损失函数;于是可以使用最小二乘法或者梯度下降法求出这条线性函数的系数a(斜率)和b(截距)。但大多数情况下,我们的样本数据未必呈现出线性关系,如

虽然我们也可以使用线性回归来拟合这些数据,但是其实它呈现出的一种非线性的关系。

如果我们使用二次曲线来拟合的话,效果会更好。假设这些数据也只有一个特征(维度),那么它的方程就可以写成 y = ax^2 + bx + c。虽然这是一个二次方程,但如果我们把 x^2 看成一个特征,把 x 看成另一个特征,那么它就变成了一个有两个特征的线性回归。现在我们来模拟生成一些呈二次曲线分布的点。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":

    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    plt.scatter(x, y)
    plt.show()

运行结果

显然,x和y是一种非线性关系,但我们先使用线性回归的方式来拟合这些数据。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":

    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    y_predict = lin_reg.predict(X)
    plt.scatter(x, y)
    plt.plot(x, y_predict, color='r')
    plt.show()

运行结果

很明显,我们用一根直线来拟合这样一个有弧度的数据集,这样的拟合效果是不够好的。现在我们使用多项式回归来重新拟合这些数据。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":

    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(x, y_predict, color='r')
    # plt.show()
    # 打印出X的平方特征的样式
    print((X**2).shape)
    # 将X的平方和X合并成一个新的矩阵
    X2 = np.hstack([X, X**2])
    print(X2.shape)
    # 使用线性回归来拟合这个新的具有2个特征的矩阵
    lin_reg2 = LinearRegression()
    lin_reg2.fit(X2, y)
    y_predict2 = lin_reg2.predict(X2)
    plt.scatter(x, y)
    # 画曲线需要按x从小到大的顺序连线,因此先对x排序,并用argsort同步调整预测值的顺序
    plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r')
    plt.show()
    # 打印线性方程的系数
    print(lin_reg2.coef_)
    # 打印截距
    print(lin_reg2.intercept_)

运行结果

(100, 1)
(100, 2)
[1.00934827 0.50673415]
2.0050492749040076

这里需要注意的是,虽然我们是使用线性回归来拟合具有x²和x两个特征的数据集,但从x的角度来看,拟合出来的是一条二次曲线。而打印出来的形状也说明我们在一维的x之上加入了x²,变成了一个二维的特征矩阵。最后我们还得到了x和x²前面的系数分别为1.00934827和0.50673415,这跟我们构造数据时y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)中的1和0.5是十分接近的,而截距2.0050492749040076跟构造时的2也非常接近。

多项式回归跟上一章的PCA有着截然相反的思路,多项式回归是将原有的数据集进行升维变成高维度的数据集从而更好的拟合数据集,而PCA则是在进行降维处理。

scikit-learn中的多项式回归与Pipeline

scikit-learn将对原始数据升维的操作封装在了sklearn.preprocessing的PolynomialFeatures中

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures

if __name__ == "__main__":

    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, 100)
    # 添加一个二次幂的维度
    poly = PolynomialFeatures(degree=2)
    # 将X矩阵进行升维
    poly.fit(X)
    X2 = poly.transform(X)
    print(X2.shape)
    # 打印X2的前5行
    print(X2[: 5, :])

运行结果

(100, 3)
[[ 1.          2.83592438  8.04246708]
 [ 1.         -2.96465367  8.78917138]
 [ 1.         -1.7820983   3.17587436]
 [ 1.         -0.18299249  0.03348625]
 [ 1.          0.59608116  0.35531275]]

这里需要注意的是,scikit-learn中的升维并不像我们上一小节中那样从一维升到二维,而是升到了三维。通过前5行数据可以看出它的第一列都是1,其实是它为我们加入了X^0这一列。后面两列就好理解了,第二列是X,第三列是X²。通过PolynomialFeatures,我们获得了多项式的特征数据集,接下来就可以交给线性回归进行处理了。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":

    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, 100)
    # 添加一个二次幂的维度
    poly = PolynomialFeatures(degree=2)
    # 将X矩阵进行升维
    poly.fit(X)
    X2 = poly.transform(X)
    print(X2.shape)
    # 打印X2的前5行
    print(X2[: 5, :])
    lin_reg = LinearRegression()
    lin_reg.fit(X2, y)
    y_predict = lin_reg.predict(X2)
    plt.scatter(x, y)
    # 画曲线需要按x从小到大的顺序连线,因此先对x排序,并用argsort同步调整预测值的顺序
    plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    plt.show()

运行结果

(100, 3)
[[ 1.          2.83592438  8.04246708]
 [ 1.         -2.96465367  8.78917138]
 [ 1.         -1.7820983   3.17587436]
 [ 1.         -0.18299249  0.03348625]
 [ 1.          0.59608116  0.35531275]]
[0.         0.97626103 0.48271288]
1.9847292103040368

这里我们也看到了线性方程的系数以及截距,只不过系数中多了一个0,也就是说X^0这一列的系数为0。这是因为LinearRegression会单独拟合截距(intercept_),所以全为1的X^0列对应的系数自然被置为0。
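
顺带一提(include_bias是scikit-learn的公开参数,此处属笔者补充):如果不想要这一列全为1的X^0,可以在构造PolynomialFeatures时传入include_bias=False,这样转换结果里就只剩下X和X²两列。

import numpy as np
from sklearn.preprocessing import PolynomialFeatures

if __name__ == "__main__":

    x = np.array([[1.0], [2.0], [3.0]])
    # include_bias=False表示不生成全为1的0次幂列
    poly = PolynomialFeatures(degree=2, include_bias=False)
    print(poly.fit_transform(x))
    # 输出只包含X和X^2两列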

之前我们一直都说的是一维的数据集进行多项式回归,现在我们来看一下多维的数据集进行多项式回归会如何?

# 生成1到10的整数,并分成两列(5行2列)
X = np.arange(1, 11).reshape(-1, 2)
print(X.shape)
print(X)
poly = PolynomialFeatures(degree=2)
poly.fit(X)
X2 = poly.transform(X)
print(X2.shape)
print(X2)

运行结果

(5, 2)
[[ 1  2]
 [ 3  4]
 [ 5  6]
 [ 7  8]
 [ 9 10]]
(5, 6)
[[  1.   1.   2.   1.   2.   4.]
 [  1.   3.   4.   9.  12.  16.]
 [  1.   5.   6.  25.  30.  36.]
 [  1.   7.   8.  49.  56.  64.]
 [  1.   9.  10.  81.  90. 100.]]

首先生成的数据集有2维,经过PolynomialFeatures升维后变成了6维,升维后,第一列依然是0次幂,这个很好理解。然后第二列和第三列就是原始数据集中的1次幂的值,这个也没问题。第四列是第二列的二次幂,第六列是第三列的二次幂,那第五列是什么呢?第五列其实就是第二列和第三列的乘积。

那如果原始数据集是二维的,PolynomialFeatures(degree=3)会如何呢?

# 生成1到10的整数,并分成两列(5行2列)
X = np.arange(1, 11).reshape(-1, 2)
print(X.shape)
print(X)
poly = PolynomialFeatures(degree=3)
poly.fit(X)
X3 = poly.transform(X)
print(X3.shape)
print(X3)

运行结果

(5, 2)
[[ 1  2]
 [ 3  4]
 [ 5  6]
 [ 7  8]
 [ 9 10]]
(5, 10)
[[   1.    1.    2.    1.    2.    4.    1.    2.    4.    8.]
 [   1.    3.    4.    9.   12.   16.   27.   36.   48.   64.]
 [   1.    5.    6.   25.   30.   36.  125.  150.  180.  216.]
 [   1.    7.    8.   49.   56.   64.  343.  392.  448.  512.]
 [   1.    9.   10.   81.   90.  100.  729.  810.  900. 1000.]]

由结果可以看到当PolynomialFeatures(degree=3)的时候,升维后变成了10维。第一列是0次幂,第二列和第三列是两个不同的X1,X2的1次幂,第四列是X1的二次幂,第五列是X1*X2,第六列是X2的二次幂,第七列是X1的三次幂,第八列是X1^2*X2,第九列是X1*X2^2,第十列是X2的三次幂。

总结下来就是,原始的二维数据升维后,新特征可以按次数分组:0次项、1次项、2次项、3次项,每一组内部则包含两个原始特征所有可能的乘积组合。
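
如果不想手工推断每一列的含义,也可以让PolynomialFeatures自己报出每一列对应的多项式项(get_feature_names_out()在scikit-learn 1.0及以上版本可用,此处属笔者补充):

import numpy as np
from sklearn.preprocessing import PolynomialFeatures

if __name__ == "__main__":

    X = np.arange(1, 11).reshape(-1, 2)
    poly = PolynomialFeatures(degree=3)
    poly.fit(X)
    # x0、x1代表两个原始特征,输出形如['1' 'x0' 'x1' 'x0^2' 'x0 x1' ...]
    print(poly.get_feature_names_out())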

Pipeline(管道)

管道可以将几个步骤的写法合并到一起,一次性去执行这些不同的步骤,从而减少了书写代码的繁琐性。

之前我们在说线性回归的时候有一个归一化(标准化)的过程:在梯度下降法中,如果各个特征的数量级相差很大,会影响梯度的结果,导致搜索的步长要么太大要么太小。多项式回归中x、x²等特征的数量级差异更明显,同样存在这个问题。现在我们使用Pipeline把升维、归一化、线性回归这几个步骤合并起来。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":

    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x ** 2 + x + 2 + np.random.normal(0, 1, 100)
    poly_reg = Pipeline([
        ("poly", PolynomialFeatures(degree=2)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression())
    ])
    poly_reg.fit(X, y)
    y_predict = poly_reg.predict(X)
    plt.scatter(x, y)
    # 画曲线需要按x从小到大的顺序连线,因此先对x排序,并用argsort同步调整预测值的顺序
    plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    plt.show()

运行结果

通过Pipeline,先将数据升维,再进行归一化,最后通过线性回归求出多项式回归的曲线。
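
补充一点(named_steps是Pipeline的公开属性,此处属笔者补充):Pipeline中的每一步都可以通过构造时起的名字取出来。比如想查看最终线性回归的系数,可以接着上面的代码这样写:

# 接上面的代码,按名字取出Pipeline中的线性回归这一步
lin_reg = poly_reg.named_steps["lin_reg"]
print(lin_reg.coef_)
print(lin_reg.intercept_)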

过拟合与欠拟合

之前我们在线性回归中讲到过一个衡量拟合程度的指标——R Squared:它等于1减去“我们自己的模型产生的误差”与“使用Baseline Model产生的误差”之比,衡量的是我们的模型拟合住了多少数据、没有产生错误的程度。现在我们用线性回归来看看这个指标会是什么样的结果。先做数据准备。

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    plt.scatter(x, y)
    plt.show()

运行结果

我们先使用线性回归来拟合这些数据集

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    plt.scatter(x, y)
    plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    plt.show()

运行结果

0.4953707811865009

这里的R方值只有0.49,显然拟合度是非常低的(R方最高为1,越小说明拟合越差)。通过R方值也可以看出来,样本x和样本标记y之间的线性关系是很弱的,所以直接使用线性回归的方式可能并不合适。现在我们要使用多项式回归来拟合这些数据集。虽然多项式回归本质上也是线性回归,也可以使用R方来判断拟合度,但不同阶数的多项式回归,方程和系数的个数都不一样,相当于在不同的维度上做回归,为了避免歧义,我们使用均方误差MSE来看数据拟合的结果。因为我们是对同一组数据集进行拟合,使用不同的方法得到的均方误差的结果是具有可比性的。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

运行结果

3.0750025765636577

现在我们使用多项式回归来看一下这个均方误差

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    plt.scatter(x, y)
    plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    plt.show()

运行结果

3.0750025765636577
1.0987392142417856

通过结果,我们可以看出使用多项式回归后,它的均方误差是1.09,比线性回归要低。如果我们把多项式回归的阶数degree调得更高,会如何呢?

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    # plt.show()

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    plt.scatter(x, y)
    plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    plt.show()

运行结果

3.0750025765636577
1.0987392142417856
1.0508466763764148

通过结果,我们可以看出当degree=10的时候,它的均方误差更小了。现在将这个degree传入更大的值来看看结果。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    # plt.show()

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    # plt.show()

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    plt.scatter(x, y)
    plt.plot(np.sort(x), y100_predict[np.argsort(x)], color='r')
    plt.show()

运行结果

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811

通过结果我们可以看出,degree=100的时候,它的均方误差更小了。但此时画出来的图形并不准确:它只是把原有数据点对应的预测值依次连接起来的结果,而很多区间内并没有数据点,所以连接出来的折线和模型真正的曲线并不一样。现在我们来尝试还原模型真实的曲线。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    # plt.show()

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    # plt.show()

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y100_predict[np.argsort(x)], color='r')
    # plt.show()
    # 重新在-3,3之间均匀生成数据点
    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    plt.scatter(x, y)
    plt.plot(X_plot[:, 0], y_plot, color='r')
    # 限定预测值的范围
    plt.axis([-3, 3, -1, 10])
    plt.show()

运行结果

通过以上的分析,我们不难发现,多项式回归中degree传入的值越高,对训练数据的拟合就越好。我们有这么多的样本点,只要degree足够大,总能找到一根曲线拟合住所有的样本点,使得均方误差为0。这个结果从均方误差的角度来看虽然更小、更“好”,但它真的是能更好地反映样本走势的曲线吗?从上面的图形中可以看出,在某些区间,它完全不是我们想要的样子:它为了拟合所有的样本点,变得太过复杂,这种情况我们就称之为过拟合。而之前我们直接用线性回归的一根直线来拟合数据,它显然也没有很好地反映原始数据的特征,它是太过简单了,这种情况我们称之为欠拟合。

为什么要有训练数据集和测试数据集

图中蓝色的样本点为原始的样本点,而紫色的点是新的样本点经过我们拟合的曲线预测出的结果。这个预测结果跟原始样本点的趋势明显不一致,我们直观地判断这个预测值很可能是错误的。在过拟合的情况下,虽然这条曲线把原始的样本点拟合得非常好,总体的均方误差非常小,但是一旦来了新的样本点,它就不能很好地进行预测了。这时我们就称这个模型(曲线)的泛化能力非常弱。泛化能力指的是由此及彼的能力:我们根据训练数据得到了这条曲线,可这条曲线在面对新数据的时候预测能力却非常弱,也就是泛化能力非常差。我们训练模型,为的不是最大程度地拟合这些点,而是获得一个可以预测的模型,当有了新的样本、新的数据的时候,这个模型可以给出很好的解答。正因为如此,衡量这个模型对训练数据的拟合程度有多好是没有意义的,我们真正需要衡量的是这个模型的泛化能力有多好。在这种情况下,我们应该使用训练、测试数据集的分离。

在这里,我们只使用我们的训练数据集来获得模型,测试数据对于我们的模型来说就是全新的数据,如果我们的模型面对测试数据也能获得很好的结果的话,那么我们就说我们这个模型泛化能力是很强的。因为它能通过训练数据得到的结果,很好的给出测试数据相应的结果。如果我们这个模型面对测试数据是很弱的,那么我们多半就遭遇了过拟合。事实上,这是测试数据集更大的意义。现在我们来将之前的数据集分成训练数据和测试数据,先使用线性回归来看一下测试数据的均方误差。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    # plt.show()

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    # plt.show()

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y100_predict[np.argsort(x)], color='r')
    # plt.show()
    # 重新在-3,3之间均匀生成数据点
    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    # plt.scatter(x, y)
    # plt.plot(X_plot[:, 0], y_plot, color='r')
    # # 限定预测值的范围
    # plt.axis([-3, 3, -1, 10])
    # plt.show()
    # 将原始数据分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression()
    lin_reg.fit(X_train, y_train)
    y_predict = lin_reg.predict(X_test)
    print(mean_squared_error(y_test, y_predict))

运行结果

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573

由结果我们可以看出使用线性回归,测试数据集的均方误差为2.21,我们再来看一下多项式回归

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    # plt.show()

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    # plt.show()

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y100_predict[np.argsort(x)], color='r')
    # plt.show()
    # 重新在-3,3之间均匀生成数据点
    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    # plt.scatter(x, y)
    # plt.plot(X_plot[:, 0], y_plot, color='r')
    # # 限定预测值的范围
    # plt.axis([-3, 3, -1, 10])
    # plt.show()
    # 将原始数据分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression()
    lin_reg.fit(X_train, y_train)
    y_predict = lin_reg.predict(X_test)
    print(mean_squared_error(y_test, y_predict))

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X_train, y_train)
    y2_predict = poly2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))

运行结果

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573
0.8035641056297901

这里的多项式回归,degree=2的时候,均方误差为0.80,现在提升degree的值

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    # plt.show()

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    # plt.show()

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y100_predict[np.argsort(x)], color='r')
    # plt.show()
    # 重新在-3,3之间均匀生成数据点
    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    # plt.scatter(x, y)
    # plt.plot(X_plot[:, 0], y_plot, color='r')
    # # 限定预测值的范围
    # plt.axis([-3, 3, -1, 10])
    # plt.show()
    # 将原始数据分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression()
    lin_reg.fit(X_train, y_train)
    y_predict = lin_reg.predict(X_test)
    print(mean_squared_error(y_test, y_predict))

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X_train, y_train)
    y2_predict = poly2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X_train, y_train)
    y10_predict = poly10_reg.predict(X_test)
    print(mean_squared_error(y_test, y10_predict))

运行结果

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573
0.8035641056297901
0.9212930722150788

当degree=10的时候,均方误差为0.92,比degree=2的时候要高,这就说明了它的泛化能力其实变差了。现在来看一下degree=100的情况。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2+ np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    # plt.show()
    # 打印均方误差
    print(mean_squared_error(y, y_predict))

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    # plt.show()

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    # plt.show()

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    # plt.scatter(x, y)
    # plt.plot(np.sort(x), y100_predict[np.argsort(x)], color='r')
    # plt.show()
    # 重新在-3,3之间均匀生成数据点
    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    # plt.scatter(x, y)
    # plt.plot(X_plot[:, 0], y_plot, color='r')
    # # 限定预测值的范围
    # plt.axis([-3, 3, -1, 10])
    # plt.show()
    # 将原始数据分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression()
    lin_reg.fit(X_train, y_train)
    y_predict = lin_reg.predict(X_test)
    print(mean_squared_error(y_test, y_predict))

    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X_train, y_train)
    y2_predict = poly2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))

    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X_train, y_train)
    y10_predict = poly10_reg.predict(X_test)
    print(mean_squared_error(y_test, y10_predict))

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X_train, y_train)
    y100_predict = poly100_reg.predict(X_test)
    print(mean_squared_error(y_test, y100_predict))

运行结果

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573
0.8035641056297901
0.9212930722150788
11695751727.09301

我们可以看到当degree取100的时候,均方误差已经上亿了。通过上一小节的比较,虽然它对训练数据集的拟合比2,10都要好的多,但面对新的测试数据集的时候,它的预测结果是极差的。所以我们要放入生产环境的模型绝不是degree=100的模型。
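
为了找到一个既不欠拟合也不过拟合的阶数,一个直接的做法是批量扫描degree,比较各自在测试数据集上的均方误差。下面是笔者补充的示意,degree的扫描范围1~10是假设值:

import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    # 扫描不同的degree,打印测试集均方误差,误差最小的阶数即为较优选择
    for degree in range(1, 11):
        poly_reg = PolynomialRegression(degree)
        poly_reg.fit(X_train, y_train)
        y_predict = poly_reg.predict(X_test)
        print(degree, mean_squared_error(y_test, y_predict))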

模型复杂度

模型复杂度对于不同的算法来说代表的是不同的意思。对于多项式回归来说,阶数degree越高,模型越复杂。对KNN算法来说,则是K越小,模型越复杂;K越大,模型越简单。当K取最大值、跟样本总数一样的时候,模型是最简单的,因为此时KNN算法就变成了看整个样本里哪种类别最多就预测谁;而K=1的时候是KNN算法中最复杂的模型,因为对于每一个点,我们都要找到离它最近的那个点。每一个模型都可以通过参数的调整使它从简单变复杂。对于训练数据集来说,随着模型越来越复杂,模型准确率会越来越高。对于测试数据集来说,准确率通常是一根曲线:模型最简单的时候,测试数据集上的准确率比较低;随着模型逐渐变复杂,测试准确率逐渐提升;提升到一定程度之后,模型再复杂下去,测试准确率又会开始下降。其实就是从欠拟合到正合适再到过拟合的这么一个过程。下面用一个小实验来直观感受这条曲线。
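
下面的小实验是笔者补充的示意:用KNN中的K代表模型复杂度(K越小模型越复杂),在手写数字数据集上分别画出训练准确率和测试准确率随K变化的曲线,K的扫描范围1~50为假设值。

import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

if __name__ == "__main__":

    digits = datasets.load_digits()
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, random_state=666)
    ks = range(1, 51)
    train_scores = []
    test_scores = []
    for k in ks:
        knn_clf = KNeighborsClassifier(n_neighbors=k)
        knn_clf.fit(X_train, y_train)
        # 分别记录训练集和测试集上的分类准确度
        train_scores.append(knn_clf.score(X_train, y_train))
        test_scores.append(knn_clf.score(X_test, y_test))
    # k越大模型越简单,曲线从左到右即从复杂到简单
    plt.plot(ks, train_scores, label="train")
    plt.plot(ks, test_scores, label="test")
    plt.legend()
    plt.show()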

欠拟合 underfitting

算法所训练的模型不能完整表述数据关系。

过拟合 overfitting

算法所训练的模型过多地表达了数据间的噪音关系。我们真实采集到的数据,通常都是有噪音的,而不是非常纯的数据,那么我们在拟合这些数据的时候,很有可能把这些噪音当作了特征来进行了训练,所以就产生了过拟合的结果。

一个生活中的例子

假设我们的机器学习系统要分辨猫和狗,对于这样的系统如果对所有的有眼睛的动物都管它叫做猫或者狗,那么这样的模型显然就是一个欠拟合的模型。因为它寻找的这个特征太普遍了,太一般了。不仅是猫和狗,几乎大部分动物都是有眼睛的,所以它不能完整的表达我们要识别的那个内容所代表的那个特征。从另外一方面,如果我们的机器学习算法说毛发的颜色是黄色的才是狗,那么此时它就是一个过拟合的例子。这是因为可能只有这一只狗它的毛发是黄色的。我们的机器学习算法其实只针对我们的训练数据集进行了学习,其实学习到的可能是噪音。在这里狗的毛发是黄色的,这只是我们给出的训练数据里的特征,这个特征不是一般性的特征。有很多狗的毛发的颜色可能是斑点的,是黑色的,是白色的都是有可能的,所以在这种情况下,我们又找了一个太细节的特征作为我们的预测标准,这就成为了过拟合。

其实我们要找的就是泛化能力最好的那个地方,就是我们的测试数据集来说,模型准确率最高的地方。

学习曲线

欠拟合和过拟合可以通过一条曲线可视化地看到,那就是学习曲线:随着训练样本的逐渐增多,算法训练出的模型在训练集和测试集上的表现变化,画出来就是学习曲线。

依然是先创造数据,画出散点图

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    plt.scatter(x, y)
    plt.show()

运行结果

将创建的数据分成训练数据集和测试数据集

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
    print(X_train.shape)

运行结果

(75, 1)

这里可以看到训练数据集有75个一维的数据。现在我们先使用线性回归,依次取训练数据集的前1个、前2个……直到全部75个样本来训练模型,每次都计算训练集和测试集上的均方根误差RMSE。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
    # print(X_train.shape)

    def plot_learning_curve(algo, X_train, X_test, y_train, y_test):
        # 画出学习曲线
        train_score = []
        test_score = []
        for i in range(1, len(X_train) + 1):
            # 使用前i个训练样本进行拟合
            algo.fit(X_train[: i], y_train[: i])
            # 每次拟合后对这前i个训练样本进行预测
            y_train_predict = algo.predict(X_train[: i])
            # 记录每一次训练数据集预测的均方误差
            train_score.append(mean_squared_error(y_train[: i], y_train_predict))

            # 每次拟合后对测试数据集进行预测
            y_test_predict = algo.predict(X_test)
            # 记录每一次测试数据集预测的均方误差
            test_score.append(mean_squared_error(y_test, y_test_predict))
        # 画出训练数据集的均方根误差线段
        plt.plot([i for i in range(1, len(X_train) + 1)], np.sqrt(train_score), label="train")
        # 画出测试数据集的均方根误差线段
        plt.plot([i for i in range(1, len(X_train) + 1)], np.sqrt(test_score), label="test")
        plt.legend()
        plt.axis([0, len(X_train) + 1, 0, 4])
        plt.show()

    plot_learning_curve(LinearRegression(), X_train, X_test, y_train, y_test)

运行结果

从图形中我们可以看到,训练数据集的均方根误差一开始很小,但随着训练数据的增多不断增大,并且一开始增大的趋势非常快,到了一定的训练数据量之后增大开始放缓,最后趋于平稳。测试数据集的均方根误差一开始迅速增大,到了一个阈值后又极速下降,最后下降趋于平缓。在最终的时候,训练误差和测试误差大体在一个级别上,不过测试误差还是会比训练误差高一些,这是因为拟合的过程可以把训练数据集拟合得比较好,相对误差小一些,但泛化到测试数据上的时候,还是可能多一些误差。现在我们来看一下使用多项式回归的学习曲线是什么样子的。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
    # print(X_train.shape)

    def plot_learning_curve(algo, X_train, X_test, y_train, y_test):
        # 画出学习曲线
        train_score = []
        test_score = []
        for i in range(1, len(X_train) + 1):
            # 使用前i个训练样本进行拟合
            algo.fit(X_train[: i], y_train[: i])
            # 每次拟合后对这前i个训练样本进行预测
            y_train_predict = algo.predict(X_train[: i])
            # 记录每一次训练数据集预测的均方误差
            train_score.append(mean_squared_error(y_train[: i], y_train_predict))

            # 每次拟合后对测试数据集进行预测
            y_test_predict = algo.predict(X_test)
            # 记录每一次测试数据集预测的均方误差
            test_score.append(mean_squared_error(y_test, y_test_predict))
        # 画出训练数据集的均方根误差线段
        plt.plot([i for i in range(1, len(X_train) + 1)], np.sqrt(train_score), label="train")
        # 画出测试数据集的均方根误差线段
        plt.plot([i for i in range(1, len(X_train) + 1)], np.sqrt(test_score), label="test")
        plt.legend()
        plt.axis([0, len(X_train) + 1, 0, 4])
        plt.show()

    # plot_learning_curve(LinearRegression(), X_train, X_test, y_train, y_test)

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    plot_learning_curve(poly2_reg, X_train, X_test, y_train, y_test)

运行结果

这就是我们使用2阶的多项式回归得到的学习曲线,它的基本趋势跟线性回归是差不多的,但它们最大的不同就在于线性回归的稳定值在1.6左右,而多项式回归的稳定值在1左右。说明多项式回归拟合的是比较好的。现在我们将多项式回归的阶数提升到20,来看看学习曲线是什么样子的。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":

    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
    # print(X_train.shape)

    def plot_learning_curve(algo, X_train, X_test, y_train, y_test):
        # 画出学习曲线
        train_score = []
        test_score = []
        for i in range(1, len(X_train) + 1):
            # 使用前i个训练样本进行拟合
            algo.fit(X_train[: i], y_train[: i])
            # 每次拟合后对这前i个训练样本进行预测
            y_train_predict = algo.predict(X_train[: i])
            # 记录每一次训练数据集预测的均方误差
            train_score.append(mean_squared_error(y_train[: i], y_train_predict))

            # 每次拟合后对测试数据集进行预测
            y_test_predict = algo.predict(X_test)
            # 记录每一次测试数据集预测的均方误差
            test_score.append(mean_squared_error(y_test, y_test_predict))
        # 画出训练数据集的均方根误差线段
        plt.plot([i for i in range(1, len(X_train) + 1)], np.sqrt(train_score), label="train")
        # 画出测试数据集的均方根误差线段
        plt.plot([i for i in range(1, len(X_train) + 1)], np.sqrt(test_score), label="test")
        plt.legend()
        plt.axis([0, len(X_train) + 1, 0, 4])
        plt.show()

    # plot_learning_curve(LinearRegression(), X_train, X_test, y_train, y_test)

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    poly2_reg = PolynomialRegression(degree=2)
    # plot_learning_curve(poly2_reg, X_train, X_test, y_train, y_test)

    poly20_reg = PolynomialRegression(degree=20)
    plot_learning_curve(poly20_reg, X_train, X_test, y_train, y_test)

运行结果

对比阶数2和20的两个图像,我们会发现,虽然整体趋势是差不多的,但是在趋于稳定的时候,阶数20的训练数据集和测试数据集的学习曲线的间距是比较大的。这说明了我们的模型虽然在训练数据集上已经拟合的非常好,但是在测试数据集上相应的它的误差依然是很大的,离训练数据集的这根学习曲线比较远。这种情况通常就是过拟合的情况。它的泛化能力是不够的。

以上的三个样例对应了三种情况——欠拟合,最佳,过拟合

欠拟合

最佳

对比欠拟合和最佳的情况,欠拟合的学习曲线趋于稳定的位置比最佳情况的学习曲线要高一些,说明无论对训练数据集还是测试数据集来说,相应的误差都比较大。这是因为我们本身模型选得就不对,所以即使在训练数据集上,它的误差也是大的。

过拟合

在过拟合的这种情况,相应的它在训练数据集上它的误差不大,甚至随着阶数的增大,训练数据集趋于稳定时的误差会更小。但是对于测试数据集的误差是比较大的。并且测试数据集的误差离训练数据集的误差比较远,它们之间的差距比较大。这就说明了此时我们这个模型的泛化能力不够好,对于新的数据来说误差是比较大的。

验证数据集与交叉验证

之前我们把数据分成了训练数据集和测试数据集,我们使用测试数据来考验模型的泛化能力比只使用训练数据要靠谱的多。但是严格来说它也有不靠谱的地方,这个问题就在于:针对特定测试数据集过拟合。因为这个测试数据集是已知的,我们相当于针对这个测试数据集进行调参,那么它也有可能产生过拟合的情况,也就是说我们得到的这个模型针对这个测试数据集过拟合了。解决这个问题,我们需要将我们的数据分成三部分——训练数据集,验证数据集和测试数据集。

我们训练好了模型之后,将验证数据集送给这个模型,看看它相应的效果如何。如果效果不好的话,我们重新换参数,重新训练这个模型。直到我们找到了一组参数,使得我们的这个模型针对我们的验证数据集来说已经达到最优了。这样一个模型经过这个过程之后,我们的测试数据再给这个模型,作为衡量最终模型性能的数据集。而测试数据集是不参与这个模型的创建的,而训练数据集和验证数据集都参与了模型的创建。训练数据集进行训练,验证数据集用来评判,一旦不好的话我们就需要重新进行训练,这两种形式都叫参与了模型的创建。但是我们的测试数据对于模型来说是完全不可知的。而此时验证数据集是调整超参数使用的数据集
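
把这个流程落到代码上大致是下面这个样子(笔者补充的示意,切分比例与超参数范围均为假设):先切出测试数据集,再从剩下的数据里切出验证数据集,用验证数据集来选KNN的超参数k,最后才动用测试数据集。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 先切出测试数据集,再从剩余数据中切出验证数据集
    X_tv, X_test, y_tv, y_test = train_test_split(X, y, test_size=0.2, random_state=666)
    X_train, X_val, y_train, y_val = train_test_split(X_tv, y_tv, test_size=0.25, random_state=666)

    best_score, best_k = 0, 0
    for k in range(2, 11):
        knn_clf = KNeighborsClassifier(n_neighbors=k)
        knn_clf.fit(X_train, y_train)
        # 用验证数据集来评判当前超参数的好坏
        score = knn_clf.score(X_val, y_val)
        if score > best_score:
            best_score, best_k = score, k

    # 用选出的超参数在训练+验证数据上重新训练,最后用测试数据集衡量最终性能
    best_knn_clf = KNeighborsClassifier(n_neighbors=best_k)
    best_knn_clf.fit(X_tv, y_tv)
    print("Best K =", best_k)
    print("Test Score =", best_knn_clf.score(X_test, y_test))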

而这种方式也可能存在问题,因为验证数据集是从我们的原始数据集中随机的切割出来的,我们训练的模型有可能过拟合这个验证数据集。一旦这个验证数据集里有比较极端的数据,就有可能导致我们的这个模型相应的不准确。为了解决这个问题就有了交叉验证(Cross Validation)

交叉验证

我们将训练数据集分成k份(上图中k=3),它们分别组合用来做训练数据和验证数据,就会得到k个模型。每一个模型在验证数据集上都会求出来一个性能指标。那么这些性能指标的平均作为最终我们衡量当前的这个算法得到的这个模型它的标准是怎么样的。k个模型的均值作为结果调参。如果这个结果不够好的话,就相应的调整一下超参数,然后继续将我们的训练数据集分成k份,每一份数据集都作为验证数据集得到k个模型,这k个模型进行平均来看最终的结果。由于我们有一个求平均的过程,所以说不会由于某一份验证数据集有极端的数据,而导致我们最终训练出来的模型有过大的偏差。这样做比我们只设立一个验证数据集要靠谱。所以通常我们在调参的时候是要使用这种交叉验证的方式的。现在我们使用手写数据集来进行交叉验证,先进行训练数据集和测试数据集的分类检测。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 将手写数据集分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)
    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            # 将训练数据集导入
            knn_clf.fit(X_train, y_train)
            # 获取分类准确度
            score = knn_clf.score(X_test, y_test)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)

运行结果

Best K = 3
Best P = 4
Best Score = 0.9860917941585535

由结果可知,当K=3、p=4的时候分类准确度最高。现在我们来进行交叉验证。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score可以自动完成交叉验证的过程,返回每一折验证集上的准确率
from sklearn.model_selection import cross_val_score

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 将手写数据集分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)
    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            # 将训练数据集导入
            knn_clf.fit(X_train, y_train)
            # 获取分类准确度
            score = knn_clf.score(X_test, y_test)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)

    knn_clf = KNeighborsClassifier()
    print(cross_val_score(knn_clf, X_train, y_train))

运行结果

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.99537037 0.98148148 0.97685185 0.97674419 0.97209302]

由结果可以看出,交叉验证把训练数据分成了五份,并给出了5个分类准确度的值。现在我们使用交叉验证的方式来进行调参。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score可以自动完成交叉验证的过程,返回每一折验证集上的准确率
from sklearn.model_selection import cross_val_score

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 将手写数据集分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)
    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            # 将训练数据集导入
            knn_clf.fit(X_train, y_train)
            # 获取分类准确度
            score = knn_clf.score(X_test, y_test)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)

    knn_clf = KNeighborsClassifier()
    print(cross_val_score(knn_clf, X_train, y_train))

    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            scores = cross_val_score(knn_clf, X_train, y_train)
            score = np.mean(scores)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)

运行结果

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.99537037 0.98148148 0.97685185 0.97674419 0.97209302]
Best K = 2
Best P = 2
Best Score = 0.9851507321274763

通过结果,我们会发现,使用交叉验证得到的结果和直接使用train、test是不一样的。在这种情况下,我们通常会更相信通过交叉验证得到的结果。因为在train、test中得到的这组参数很有可能只是过拟合了分离出来的这组测试数据集。在我们的交叉验证中得到的这个最佳的分类准确度它是低于在train、test中得到的最佳准确度,这是因为在交叉验证的过程中,通常不会过拟合某一组的测试数据集,所以平均来讲这个分数会稍微低一些。现在我们得到了最终的K和P,那么最终我们这个模型的分类准确率是多少呢,就是0.985吗?当然不是,我们这个交叉验证的过程为的就是拿到最好的K和P而已。当我们拿到了这组K和P之后,我们就可以获得一个最佳的分类准确率。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score可以自动完成交叉验证的过程,返回每一折验证集上的准确率
from sklearn.model_selection import cross_val_score

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 将手写数据集分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)
    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            # 将训练数据集导入
            knn_clf.fit(X_train, y_train)
            # 获取分类准确度
            score = knn_clf.score(X_test, y_test)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)

    knn_clf = KNeighborsClassifier()
    print(cross_val_score(knn_clf, X_train, y_train))

    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            scores = cross_val_score(knn_clf, X_train, y_train)
            score = np.mean(scores)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)
    # 通过最佳超参数来重新训练和获取分类准确度
    best_knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=2, p=2)
    best_knn_clf.fit(X_train, y_train)
    print(best_knn_clf.score(X_test, y_test))

运行结果

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.99537037 0.98148148 0.97685185 0.97674419 0.97209302]
Best K = 2
Best P = 2
Best Score = 0.9851507321274763
0.980528511821975
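
顺便一提(属笔者补充):上面“双重for循环 + 交叉验证取平均”的调参过程,scikit-learn在GridSearchCV中做了现成的封装,写法大致如下。

from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

if __name__ == "__main__":

    digits = datasets.load_digits()
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, test_size=0.4, random_state=666)
    # 要搜索的超参数网格,对应前面手写的k和p两层循环
    param_grid = {
        "weights": ["distance"],
        "n_neighbors": list(range(2, 11)),
        "p": list(range(1, 6)),
    }
    grid_search = GridSearchCV(KNeighborsClassifier(), param_grid, cv=5)
    grid_search.fit(X_train, y_train)
    # 最佳超参数组合与对应的交叉验证平均准确率
    print(grid_search.best_params_)
    print(grid_search.best_score_)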

如果我们想在交叉验证中只把训练数据分成3份,只需要调整一下cross_val_score的cv参数就可以了。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score可以自动完成交叉验证的过程,返回每一折验证集上的准确率
from sklearn.model_selection import cross_val_score

if __name__ == "__main__":

    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # 将手写数据集分成训练数据集和测试数据集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)
    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            # 将训练数据集导入
            knn_clf.fit(X_train, y_train)
            # 获取分类准确度
            score = knn_clf.score(X_test, y_test)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)

    knn_clf = KNeighborsClassifier()
    print(cross_val_score(knn_clf, X_train, y_train, cv=3))

    best_score, best_p, best_k = 0, 0, 0
    # 尝试k从2到10,获取最大的分类准确度
    for k in range(2, 11):
        # p代表用哪一种方式的距离来计算
        # 当p = 1时,得到绝对值距离,也叫曼哈顿距离
        # 当p = 2时,得到欧几里得距离(Euclidean distance),也就是两点之间的直线距离
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            scores = cross_val_score(knn_clf, X_train, y_train, cv=3)
            score = np.mean(scores)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)
    # 通过最佳超参数来重新训练和获取分类准确度
    best_knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=2, p=2)
    best_knn_clf.fit(X_train, y_train)
    print(best_knn_clf.score(X_test, y_test))

运行结果

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.98888889 0.97771588 0.96935933]
Best K = 2
Best P = 2
Best Score = 0.9833023831631073
0.980528511821975

这种交叉验证法又称为k-folds交叉验证。它的缺点是每次要训练k个模型,相当于整体性能慢了k倍。在极端情况下,k-folds可以变成留一法(LOO-CV,Leave-One-Out Cross Validation):把训练数据集分成m份(m为样本总数),每次用m-1个样本训练,看剩下的那1个样本预测得准不准,再把m次的结果综合起来求平均,作为衡量当前参数下模型预测准确度的指标。这样做完全不受随机切分的影响,最接近模型真正的性能指标,缺点就是计算量巨大。下面是用scikit-learn做留一法的示意。
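
下面这段代码属笔者补充:把LeaveOneOut对象传给cross_val_score的cv参数即可。训练集中的每个样本都会各对应训练一个模型(这里有一千多个),运行会比较慢。

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import LeaveOneOut
from sklearn.neighbors import KNeighborsClassifier

if __name__ == "__main__":

    digits = datasets.load_digits()
    X_train, X_test, y_train, y_test = train_test_split(
        digits.data, digits.target, test_size=0.4, random_state=666)
    knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=2, p=2)
    # cv=LeaveOneOut():每次留1个样本做验证,其余样本全部用于训练
    scores = cross_val_score(knn_clf, X_train, y_train, cv=LeaveOneOut())
    print(np.mean(scores))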

偏差和方差平衡(Bias Variance Trade off)

这张图片可以很好的解释偏差和方差。所谓的偏差见左下角的圆环靶,我们所有的点全都完全偏离了中心的位置,这种情况就叫做偏差。所谓的方差见右上角的圆环靶,我们射击的这些点看起来都围绕着靶心,就是没有大的偏差,但是它整体太过分散,不集中,所以就是有非常高的方差。而上面的两个圆环靶整体就是没有偏差,其中左边的圆环靶既没有偏差也没有方差,全都集中在靶心。而右边的圆环靶虽然没有偏差却有高的方差。而下面的两个圆环靶就是都有高的偏差,其中左边的圆环靶有高的偏差,但是方差是低的。而右边的圆环靶既有高的偏差又有高的方差。在我们进行机器学习的过程中,我们实际要训练的那个模型,都是要预测一个问题,问题本身我们可以理解成靶心,而我们根据数据来拟合一个模型,进而预测这个问题,我们拟合的这个模型其实就是我们射出去的这些点,那么我们的模型就有可能犯偏差和方差这样两种错误

模型误差

模型误差 = 偏差(Bias) + 方差(Variance) + 不可避免的误差

不可避免的误差对于这部分错误,我们是无能为力的,它是客观存在的,比如我们采集的数据本身就是有噪音的,这是怎样改进我们的模型,改进我们的算法都不能避免的。但是偏差和方差却是和我们的模型,我们的算法相关的两个问题。

导致偏差的主要原因:对问题本身的假设不正确!如:非线形数据使用线性回归。而欠拟合就是这样一个例子。当然带来偏差还有其他的可能性,最典型的例子就是我们训练的数据所采用的那个特征其实跟这个问题完全没有关系,比如说我们想预测一个学生的考试成绩,但是我们是用这个学生的名字来预测考试成绩,这显然一定是高偏差的。因为这个特征本身离我们要预测的那个问题——考试成绩之间是高度不相关的。

方差:数据的一点点扰动都会较大地影响模型。通常的原因是使用的模型太复杂,如高阶多项式回归。换句话说,我们的模型没有学习到这个问题的实质,而是学习到了很多噪音。过拟合就会极大地引入方差。

偏差和方差

有一些算法天生是高方差的算法,如KNN。因为KNN完全是数据驱动的,新样本离训练数据的远近决定了预测的结果,一旦离它最近的那些样本中多数是错误的结果,我们的预测就不准确了。

非参数学习通常都是高方差算法。因为不对数据进行任何假设,只能根据训练数据来进行预测,所以它对训练数据的准确性依赖程度非常高,它对这些训练数据本身是非常敏感的。

有一些算法天生就是高偏差算法。如线性回归。因为我们在现实生活中,很多问题它不是线性的,非要使用线性的手段去拟合这些问题的话,那得到的结果就会产生错误,这种错误通常都是偏差错误。

参数学习通常都是高偏差算法,因为它对数据具有极强的假设。我们认为这些数据是符合某个数学模型的,可是一旦问题不符合这个数学模型,也就是说这个假设是错误的,相应地我们训练出来的模型就会带来错误,这种错误通常都是高偏差的错误。

大多数算法具有相应的参数,可以调整偏差和方差。如kNN中的k。k越小,说明我们的模型越复杂,相应的我们的模型的方差越大,偏差越小。而k越大,我们的模型越简单,直到k达到最大值,也就是k=样本总数的时候,我们的kNN算法本质其实就是看我们的全部样本中谁最多,我就预测谁,在这种情况下达到kNN算法的偏差最大、方差最小;如线性回归中使用多项式回归。我们可以调整多项式回归的阶数degree,它相应的就会改变我们的线性回归的偏差和方差,degree值越小,最低的时候是1,那么我们的模型越简单,它相应的偏差就会越大。而degree值越大,我们拟合出来的曲线越弯曲,形状会越来越奇怪,那么相应的它引入的方差误差就会越来越大。

偏差和方差通常是矛盾的。降低偏差就会提高方差。降低方差,会提高偏差。通常我们要找到一个平衡。首先来看我们的算法主要的错误到底是集中在偏差的位置还是方差的位置。看看能不能让它的偏差和方差达到一定的平衡。换句话说不要特别高的方差,因为此时模型的泛化能力太差了,也不要特别高的偏差,因为在这种情况下,我们的模型太偏离原问题了。我们不能完全杜绝错误,但是我们让它有一点偏差有一点方差,不是集中在一个方向上。达到这个目标通常也是我们在机器学习调参的过程中要做的一个主要的事情。

通常在机器学习领域,主要的挑战来自于方差。但是在问题这个层面上不一定如此,因为我们还对很多的问题太过肤浅,比如对疾病的理解,比如对金融市场的理解。比如在过去,有人尝试用历史的金融数据来预测未来的金融情况,那么通常这个预测的结果都不是很理想。很有可能是因为历史的金融数据本身并不能非常好的反映未来的金融走向。那么这种预测方法本身带来了非常高的偏差。

但是在数据完备,特征选择正确的情况下,方差是主要的问题,解决高方差的通常手段:

  1. 降低模型复杂度
  2. 减少数据维度;降噪
  3. 增加样本数
  4. 使用验证集
  5. 模型正则化

模型泛化与岭回归

模型正则化(Regularization):限制系数的大小

之前我们在讲过拟合的时候,它生成的曲线非常的弯曲,非常的陡峭。对于这一根曲线来说,相应的每一个多项式项前面的系数会非常的大。模型正则化就是限制这些系数的大小。现在我们用代码来验证一下,先生成线性的训练数据

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    plt.scatter(x, y)
    plt.show()

运行结果

再进行多项式拟合

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()
    lin_reg = LinearRegression()
    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", lin_reg)
        ])

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    # 查看拟合曲线的各项系数
    print(lin_reg.coef_)

    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    plt.scatter(x, y)
    plt.plot(X_plot[:, 0], y_plot, color='r')
    plt.axis([-3, 3, 0, 10])
    plt.show()

运行结果

0.3759549053775048
[ 1.42439073e+13 -2.12466748e+00  1.54479690e+02  1.67098517e+03
 -1.68354221e+04 -1.67356966e+05  7.69671197e+05  7.19167159e+06
 -1.91564362e+07 -1.73386970e+08  2.94158467e+08  2.65855494e+09
 -2.98779993e+09 -2.77902518e+10  2.08205157e+10  2.06137067e+11
 -1.00503399e+11 -1.10794836e+12  3.27932426e+11  4.33721479e+12
 -6.51849432e+11 -1.22101693e+13  4.33621529e+11  2.37019684e+13
  1.39932363e+12 -2.82060254e+13 -4.10258949e+12  1.18702565e+13
  3.48399831e+12  1.55678622e+13  1.94003321e+12 -1.71368365e+13
 -3.89936356e+12 -1.09652871e+13 -2.30837691e+12  1.49553243e+13
  3.45115415e+12  1.30288924e+13  4.29214358e+12 -8.22504346e+12
 -1.50375379e+12 -1.55744264e+13 -5.98634047e+12 -3.52006563e+12
 -1.59409626e+12  1.15482917e+13  2.09128451e+12  1.42192391e+13
  6.69760485e+12  6.94440143e+11  2.68881401e+12 -1.08847554e+13
 -3.22254621e+12 -1.37622857e+13 -6.20116568e+12 -1.61659413e+12
 -4.42220524e+12  5.80908289e+12 -4.00657127e+10  1.26136586e+13
  6.42455657e+12  6.94601810e+12  6.53005938e+12  6.31892663e+11
  3.57264555e+12 -8.75913024e+12 -2.97618352e+12 -1.03119202e+13
 -5.67563631e+12 -6.78110437e+12 -7.42635518e+12  3.73270228e+11
 -3.68184310e+12  7.60158227e+12  2.32388705e+12  1.13431444e+13
  6.24335439e+12  7.08572883e+12  9.38808268e+12 -1.50553406e+12
  4.16083097e+12 -7.73762602e+12 -2.31468569e+12 -1.00502179e+13
 -7.35406158e+12 -6.36365989e+12 -8.63079902e+12  2.40655509e+12
 -6.95260658e+12  9.67202712e+12  4.10574155e+12  6.90151222e+12
  1.23520497e+13  1.40956708e+12  9.16331620e+12 -7.16918098e+12
 -3.34748423e+12 -5.97986471e+12 -1.81238211e+13  4.96934673e+12
  9.42782890e+12]

从结果我们可以看出,当阶数到100的时候,它对应的各项系数非常大,数量级达到了10^13。

之前我们在讲多元线性回归的时候,是让它的损失函数

∑(y(i) - ŷ(i))²  (i从1到m)

尽可能的小,即求原始数据y和使用θ预测出的ŷ之间的均方误差尽可能的小。但是在这里如果我们过拟合的话,θ这些系数就会非常大。我们需要限制系数θ不要太大,则我们需要改变损失函数,加入模型正则化,使得

J(θ) = MSE(y, ŷ; θ) + α · (1/2)∑θi²  (i从1到n)

尽可能的小。现在我们要让目标函数J(θ)尽可能的小,就不仅要顾及均方误差这一项,还要顾及后面这一项。后面这一项是所有θi的平方和,要想让它尽可能的小,就只能让每一个θi都尽可能的小。所以我们在让J(θ)尽可能小的同时,也就让所有的θ都尽可能的小了,就不会出现像之前过拟合时每一个θ都那么大、曲线那么陡峭的情况。这就是模型正则化的原理。

这里有几个细节需要注意。第一,正则化项里的θ是从θ1到θn,它们都是斜率,而表示截距的θ0并不包含在其中。这是因为θ0并不是任何一个特征X的系数,它决定了整条曲线的高低,但不决定曲线每一部分的陡峭和缓和,所以在模型正则化的时候不需要加上θ0。第二,正则化项里有一个1/2,这只是一个惯例:在使用梯度下降法时需要对各个θi求偏导,由于(θi²)' = 2θi,求导出来的2正好可以跟1/2约掉,方便计算而已,实际上不要这个1/2也是可以的。第三是α,它是一个新的超参数,代表在新的损失函数中,“让每一个θ都尽可能小”这件事在整个优化目标中所占的比重。如果α=0,相当于损失函数没有加入模型正则化,此时和之前是一样的;当α=+∞的时候,前面的均方误差占整个J(θ)的比重就会非常非常小,优化任务就变成了要让每一个θi都尽可能的小,在极端情况下,每一个θi都等于0才能让J(θ)最小。我们实际要做的,是在预测的准确度和让每个θi尽量小这两件事之间取得平衡,对于不同的数据,就要尝试让α取不同的值。而这个1/2也可以和α融合到一起,这并不影响。

事实上,模型正则化不止这一种方式,而上面这种方式通常称为岭回归(Ridge Regression)。现在我们就从代码层面看一下岭回归,先是不包含正则化的多项式回归。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    plot_model(poly_reg)

运行结果

167.9401085999025

现在我们加入岭回归

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    # plot_model(poly_reg)

    # 使用岭回归,对超参数α赋值
    def RidgeRegression(degree, alpha):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha))
        ])

    ridge1_reg = RidgeRegression(20, 0.0001)
    ridge1_reg.fit(X_train, y_train)
    y1_predict = ridge1_reg.predict(X_test)
    print(mean_squared_error(y_test, y1_predict))
    plot_model(ridge1_reg)

运行结果

167.9401085999025
1.3233492754136291

通过结果我们可以看出,加入了岭回归后,测试数据集的均方误差从167.94一下子降到了1.32,模型的泛化能力得到了大大的提高,而绘制出的模型曲线也比之前缓和了很多。现在我们调整岭回归的α值来看一下。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    # plot_model(poly_reg)

    # Use ridge regression, assigning the hyperparameter α
    def RidgeRegression(degree, alpha):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha))
        ])

    ridge1_reg = RidgeRegression(20, 0.0001)
    ridge1_reg.fit(X_train, y_train)
    y1_predict = ridge1_reg.predict(X_test)
    print(mean_squared_error(y_test, y1_predict))
    # plot_model(ridge1_reg)

    ridge2_reg = RidgeRegression(20, 1)
    ridge2_reg.fit(X_train, y_train)
    y2_predict = ridge2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))
    plot_model(ridge2_reg)

Running results

167.9401085999025
1.3233492754136291
1.1888759304218461

Judging from the results, after raising α from 0.0001 to 1 the test-set mean squared error improves a bit more, and the curve becomes a little flatter as well. Let's increase α further and see:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    # plot_model(poly_reg)

    # Use ridge regression, assigning the hyperparameter α
    def RidgeRegression(degree, alpha):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha))
        ])

    ridge1_reg = RidgeRegression(20, 0.0001)
    ridge1_reg.fit(X_train, y_train)
    y1_predict = ridge1_reg.predict(X_test)
    print(mean_squared_error(y_test, y1_predict))
    # plot_model(ridge1_reg)

    ridge2_reg = RidgeRegression(20, 1)
    ridge2_reg.fit(X_train, y_train)
    y2_predict = ridge2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))
    # plot_model(ridge2_reg)

    ridge3_reg = RidgeRegression(20, 100)
    ridge3_reg.fit(X_train, y_train)
    y3_predict = ridge3_reg.predict(X_test)
    print(mean_squared_error(y_test, y3_predict))
    plot_model(ridge3_reg)

Running results

167.9401085999025
1.3233492754136291
1.1888759304218461
1.3196456113086197

This time the test-set mean squared error increases, which suggests we may have over-regularized; from the plot, however, the curve is smoother still, with none of the earlier wiggles. Finally, let's simulate the case of α approaching +∞.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    # plot_model(poly_reg)

    # Use ridge regression, assigning the hyperparameter α
    def RidgeRegression(degree, alpha):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha))
        ])

    ridge1_reg = RidgeRegression(20, 0.0001)
    ridge1_reg.fit(X_train, y_train)
    y1_predict = ridge1_reg.predict(X_test)
    print(mean_squared_error(y_test, y1_predict))
    # plot_model(ridge1_reg)

    ridge2_reg = RidgeRegression(20, 1)
    ridge2_reg.fit(X_train, y_train)
    y2_predict = ridge2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))
    # plot_model(ridge2_reg)

    ridge3_reg = RidgeRegression(20, 100)
    ridge3_reg.fit(X_train, y_train)
    y3_predict = ridge3_reg.predict(X_test)
    print(mean_squared_error(y_test, y3_predict))
    # plot_model(ridge3_reg)

    ridge4_reg = RidgeRegression(20, 1000000000000)
    ridge4_reg.fit(X_train, y_train)
    y4_predict = ridge4_reg.predict(X_test)
    print(mean_squared_error(y_test, y4_predict))
    plot_model(ridge4_reg)

Running results

167.9401085999025
1.3233492754136291
1.1888759304218461
1.3196456113086197
1.8408939654674448

The test-set mean squared error keeps growing, yet it is still much better than with no ridge regression at all, and the plot is now a straight line parallel to the x-axis. As we said earlier, when α = +∞ the mean-squared-error term accounts for a vanishingly small share of the whole J(θ), and the optimization task reduces to making every θᵢ as small as possible; in the extreme, J(θ) is minimized only when every θᵢ equals 0. With all the slope coefficients at 0, the model is a perfectly flat line.
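
As a quick check (a minimal sketch, assuming the four fitted pipelines and the np import from the script above are still in scope), we can look inside each pipeline and watch the coefficients shrink as α grows:

    for name, model in [("α=0.0001", ridge1_reg), ("α=1", ridge2_reg),
                        ("α=100", ridge3_reg), ("α=1e12", ridge4_reg)]:
        coef = model.named_steps["ridge_reg"].coef_
        # the largest coefficient magnitude shrinks toward 0 as α grows,
        # which is why the α=1e12 model degenerates into a horizontal line
        print(name, "max |θ|:", np.abs(coef).max())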

LASSO

LASSO is short for Least Absolute Shrinkage and Selection Operator Regression. As a regularization term it differs slightly from ridge regression, but its purpose is the same: keep the coefficients θ as small as possible. Rewriting the loss function for LASSO, we want to make

J(θ) = MSE(y, ŷ; θ) + α·Σᵢ |θᵢ|

as small as possible.

Let's see in code how LASSO differs from ridge regression. As before, we first generate the data set we need for training.

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    plt.scatter(x, y)
    plt.show()

Running results

[Scatter plot of the generated data]

Then, as before, we first fit the data with polynomial regression.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    plot_model(poly_reg)

Running results

167.9401085999025

Now let's start using LASSO:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    # plot_model(poly_reg)

    # LASSO
    def LassoRegression(degree, alpha):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lasso_reg", Lasso(alpha=alpha))
        ])

    lasso1_reg = LassoRegression(20, 0.01)
    lasso1_reg.fit(X_train, y_train)
    y1_predict = lasso1_reg.predict(X_test)
    print(mean_squared_error(y_test, y1_predict))
    plot_model(lasso1_reg)

Running results

167.9401085999025
1.1496080843259966

Now let's increase LASSO's hyperparameter α:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    # plot_model(poly_reg)

    # LASSO
    def LassoRegression(degree, alpha):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lasso_reg", Lasso(alpha=alpha))
        ])

    lasso1_reg = LassoRegression(20, 0.01)
    lasso1_reg.fit(X_train, y_train)
    y1_predict = lasso1_reg.predict(X_test)
    print(mean_squared_error(y_test, y1_predict))
    # plot_model(lasso1_reg)

    lasso2_reg = LassoRegression(20, 0.1)
    lasso2_reg.fit(X_train, y_train)
    y2_predict = lasso2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))
    plot_model(lasso2_reg)

Running results

167.9401085999025
1.1496080843259966
1.1213911351818648

Now let's continue increasing LASSO's hyperparameter α:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso

if __name__ == "__main__":

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    # plt.scatter(x, y)
    # plt.show()

    def PolynomialRegression(degree):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression())
        ])

    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    # plot_model(poly_reg)

    # LASSO
    def LassoRegression(degree, alpha):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lasso_reg", Lasso(alpha=alpha))
        ])

    lasso1_reg = LassoRegression(20, 0.01)
    lasso1_reg.fit(X_train, y_train)
    y1_predict = lasso1_reg.predict(X_test)
    print(mean_squared_error(y_test, y1_predict))
    # plot_model(lasso1_reg)

    lasso2_reg = LassoRegression(20, 0.1)
    lasso2_reg.fit(X_train, y_train)
    y2_predict = lasso2_reg.predict(X_test)
    print(mean_squared_error(y_test, y2_predict))
    # plot_model(lasso2_reg)

    lasso3_reg = LassoRegression(20, 1)
    lasso3_reg.fit(X_train, y_train)
    y3_predict = lasso3_reg.predict(X_test)
    print(mean_squared_error(y_test, y3_predict))
    plot_model(lasso3_reg)

Running results

167.9401085999025
1.1496080843259966
1.1213911351818648
1.8408939659515595

From these results we find that by the time α = 1, LASSO's regularization is already very strong. In fact, when choosing this hyperparameter we are looking for the best degree of regularization somewhere between not regularizing at all and over-regularizing into a line parallel to the x-axis; a small sweep, sketched below, helps locate it. Comparing with ridge regression: as α increased, ridge's curve grew ever gentler but always remained a curve, whereas with LASSO, at α = 0.1 the fitted line is already nearly straight. This behavior is determined by the special form of LASSO's regularization term.
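
A minimal sketch of that sweep (reusing LassoRegression, the train/test split, and mean_squared_error from the script above; the α grid is illustrative, and very small α values may emit convergence warnings):

    for alpha in (0.001, 0.01, 0.1, 1):
        reg = LassoRegression(20, alpha)
        reg.fit(X_train, y_train)
        # compare test-set MSE across regularization strengths
        print(alpha, mean_squared_error(y_test, reg.predict(X_test)))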

Ridge regression (α=100)

LASSO (α=0.1)

Comparing the two figures above: with ridge regression at α = 100, the model we get is still a curve; in fact it is hard to make ridge regression produce a sloped straight line, as it always preserves some curvature. With LASSO at α = 0.1, what we get is still nominally a curve, but it is clearly less bent than ridge's and looks much more like a straight line; in other words, the models LASSO produces lean toward straight lines. The difference between the curve and the line is this: with ridge regression, many features (the powers of x) keep nonzero coefficients in front of them; with LASSO, many of those features no longer carry a coefficient at all, i.e. the coefficient is 0, corresponding to θ = 0.

In the course of optimizing the loss function, LASSO tends to make a subset of the θ values exactly 0, rather than making all of them merely small, so it can be used for feature selection. When LASSO sets some θ to 0, it is saying that the feature that θ multiplies is completely useless, while the features whose θ remain nonzero are the ones LASSO deems useful.
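
A minimal, self-contained sketch of this effect (the α values are illustrative and not from the original scripts) fits both degree-20 pipelines on the same data and counts the coefficients that are exactly 0:

    import numpy as np
    from sklearn.linear_model import Lasso, Ridge
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import PolynomialFeatures, StandardScaler

    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    for name, reg in [("ridge", Ridge(alpha=0.1)), ("lasso", Lasso(alpha=0.1))]:
        model = Pipeline([
            ("poly", PolynomialFeatures(degree=20)),
            ("std_scaler", StandardScaler()),
            ("reg", reg)
        ])
        model.fit(X, y)
        coef = model.named_steps["reg"].coef_
        # ridge keeps (almost) all coefficients nonzero; LASSO zeroes many out
        print(name, "zero coefficients:", np.sum(coef == 0), "of", coef.size)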

For ridge regression, as α → +∞, let's look at how θ travels from its initial value down to 0, from the perspective of gradient descent. The gradient of the regularization term is

∇(α·Σᵢ θᵢ²) = 2α·θ

When the loss function consists of nothing but this regularization term, J(θ) = α·Σᵢ θᵢ². Starting θ from an initial point (the blue dot in the original figure), gradient descent searches step by step toward 0 and θ eventually becomes 0, but at every intermediate step it still holds nonzero values.

LASSO is different. As α → +∞, even though |θᵢ| is not differentiable at 0, we can describe its gradient with a simple piecewise function, the sign function:

∇(α·Σᵢ |θᵢ|) = α·sign(θ), where sign(θᵢ) = −1 if θᵢ < 0, 0 if θᵢ = 0, +1 if θᵢ > 0

This piecewise function's range is just the three values −1, 0, and 1. When the loss function consists of nothing but the regularization term, J(θ) = α·Σᵢ |θᵢ|. Starting θ from an initial point like the blue dot in the original figure, it first travels in one fixed direction, for example (−1, −1), until it reaches a place where one axis coordinate is 0; having reached the y-axis, it then moves along the y-axis toward the origin, the direction now becoming (0, −1). LASSO cannot follow a smooth curve to the origin the way ridge regression can; it can only move in these very regular ways, and as it moves it keeps landing on the zero points of individual axes, so LASSO's final result has many θ values equal to exactly 0. This also explains why ridge regression is more like a smooth walk down the hill, while LASSO can serve as feature selection, directly discarding unimportant features (their coefficients go straight to 0). Precisely because of this behavior, however, LASSO may mistakenly turn some genuinely useful features to 0 as well, so in terms of computational accuracy ridge regression remains the more precise choice. But when the features are extremely numerous, say a polynomial regression that insists on degree 100, the feature count becomes huge, and LASSO then does an excellent job of shrinking the model's feature set.

L1, L2, and Elastic Net

Comparing ridge regression and LASSO, it is easy to notice the resemblance to things we have seen before: linear regression's mean squared error (MSE) versus mean absolute error (MAE), and, in the KNN algorithm, the Euclidean distance versus the Manhattan distance. These pairs take the same two forms as the ridge and LASSO regularization terms: one is a sum of squares, the other a sum of absolute values. The three pairs do different jobs, but the mathematical idea behind them is essentially the same, and the mathematical meaning they express is nearly identical; applied in different settings they produce different effects, and hence different names.

L1 regularization, L2 regularization

The KNN algorithm also has one more distance, the Minkowski distance, whose expression is

d(x, y) = (Σᵢ |xᵢ − yᵢ|^p)^(1/p)

that is, for two points, take the absolute difference in each corresponding dimension, raise it to the p-th power, sum, and take the p-th root. Generalizing this expression further:

‖X‖p = (Σᵢ |Xᵢ|^p)^(1/p)

For any vector X, summing the p-th powers of the absolute values of its components and taking the p-th root gives the norm of X, written with a subscript p; in mathematics this is called the Lp norm. When p = 1 it is the L1 norm, which is simply the Manhattan distance from the origin to X; when p = 2 it is the L2 norm, the Euclidean distance from the origin to X.

With the Lp norm in hand, let's look back at the two regularization schemes. Ridge regression's penalty is usually called the L2 regularization term, and LASSO's the L1 regularization term. Compared with the corresponding norm, the regularization term simply omits the p-th root; since the penalty sits inside the loss function being optimized, keeping or dropping the root does not affect the final result of the optimization, and without the root the whole expression is simpler. For that reason people sometimes loosely call the ridge penalty the "L2 norm" anyway. Given L1 and L2 terms, an Ln term for any n can of course be written down by following the Lp-norm formula, but in model regularization we rarely use terms with p > 2; L3 and L4 regularization terms exist in theory but are essentially never used.

There is actually also an L0 regularization term, which adds to the loss function the term

J(θ) = MSE(y, ŷ; θ) + α·‖θ‖₀, where ‖θ‖₀ = the number of nonzero θᵢ

that is, we want the number of nonzero θ to be as small as possible. The L1 and L2 terms are explicit mathematical expressions we can write down, while the L0 term just says "the fewer nonzero θ the better": it counts the nonzero elements of θ. In practice we very rarely use L0 regularization, because optimizing it is an NP-hard problem: we cannot apply gradient descent or solve a closed-form formula to find the optimum. The term is inherently discrete, so this becomes a discrete optimization problem; we would essentially have to enumerate every combination of which θ to set to 0, compute J(θ) for each, and from that decide which θ are 0 and which are not. When we genuinely want to limit the number of θ, we usually use L1 instead.
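A minimal brute-force sketch of that enumeration (a toy 3-feature least-squares problem on hypothetical data; with d features there are 2^d support patterns to try, which is exactly why this does not scale):

    import itertools
    import numpy as np

    rng = np.random.default_rng(0)
    X = rng.normal(size=(30, 3))
    y = X @ np.array([1.5, 0.0, -2.0]) + rng.normal(0, 0.1, size=30)

    alpha, best = 0.05, None
    for k in range(4):                                     # number of nonzero θ
        for support in itertools.combinations(range(3), k):
            theta = np.zeros(3)
            idx = list(support)
            if idx:
                # least-squares fit restricted to the chosen support
                theta[idx] = np.linalg.lstsq(X[:, idx], y, rcond=None)[0]
            J = np.mean((X @ theta - y) ** 2) + alpha * k  # MSE + α·‖θ‖0
            if best is None or J < best[0]:
                best = (J, theta)
    print(best)  # should recover a support excluding the zero coefficient
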

Elastic Net

The Elastic Net combines the two schemes, ridge regression's L2 regularization and LASSO's L1 regularization: after the mean squared error it adds both an L1 term and an L2 term, introducing a new hyperparameter r to set the ratio between them:

J(θ) = MSE(y, ŷ; θ) + r·α·Σᵢ |θᵢ| + ((1−r)/2)·α·Σᵢ θᵢ²

The L1 term is weighted by r and the L2 term by 1−r; the 1/2 belongs to the L2 term itself and has nothing to do with r. Handled this way, the model combines the advantages of both ridge regression and LASSO. In practice one should usually try ridge regression first, since its computation is fairly precise; but ridge has no feature-selection ability and cannot set any θ to 0, so when the number of features is very large the overall computation can become enormous, and then the Elastic Net is the better first choice. It keeps ridge regression's strengths while gaining LASSO's ability to select features, and it tempers LASSO's weakness of being too eager to turn some θ into 0, a process that can make mistakes and leave the final model with a relatively large bias.
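
A minimal sketch in the same pipeline style as above (reusing the Pipeline/PolynomialFeatures/StandardScaler imports, the train/test split, and mean_squared_error from the earlier scripts; the alpha and l1_ratio values are illustrative — scikit-learn's ElasticNet exposes the mixing ratio r as l1_ratio):

    from sklearn.linear_model import ElasticNet

    def ElasticNetRegression(degree, alpha, l1_ratio):
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("elastic_reg", ElasticNet(alpha=alpha, l1_ratio=l1_ratio))
        ])

    elastic_reg = ElasticNetRegression(20, 0.1, 0.5)  # r = 0.5: equal mix of L1 and L2
    elastic_reg.fit(X_train, y_train)
    print(mean_squared_error(y_test, elastic_reg.predict(X_test)))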
