# 机器学习算法整理(二)

2021/08/29 14:54

scikit-learn中的PCA

from sklearn.decomposition import PCA
import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Synthetic 2-D data: x is uniform on [0, 100) and y is roughly 0.75 * x + 3
    # plus Gaussian noise (std 10), so most of the variance lies along one direction.
    X = np.empty((100, 2))
    X[:, 0] = np.random.uniform(0, 100, size=100)
    X[:, 1] = 0.75 * X[:, 0] + 3 + np.random.normal(0, 10, size=100)

    # Keep only the first principal component.
    pca = PCA(n_components=1)
    pca.fit(X)
    print(pca.components_)

    # Project every sample onto that component: (100, 2) -> (100, 1).
    X_reduction = pca.transform(X)
    print(X_reduction.shape)

    # Map the 1-D projections back into the original 2-D space: (100, 1) -> (100, 2).
    # The restored points lie on the principal axis; whatever variation existed
    # along the discarded direction cannot be recovered.
    X_restore = pca.inverse_transform(X_reduction)
    print(X_restore.shape)

    # Blue: original samples; red: their reconstruction from one component.
    plt.scatter(X[:, 0], X[:, 1], color='b', alpha=0.5)
    plt.scatter(X_restore[:, 0], X_restore[:, 1], color='r', alpha=0.5)
    plt.show()

[[-0.78144234 -0.62397746]]
(100, 1)
(100, 2)


import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split

if __name__ == "__main__":
    # Load scikit-learn's bundled handwritten-digit dataset
    # (8x8 images flattened to 64 features per sample).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Split the dataset into training and test sets; the fixed random_state
    # makes the shuffle, and therefore every printed result, reproducible.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)

(1347, 64)

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
import timeit

if __name__ == "__main__":
    # Load scikit-learn's bundled handwritten-digit dataset
    # (8x8 images flattened to 64 features per sample).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Split into training and test sets; fixed random_state for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)

    # Baseline KNN on all original features. Only construction + fit are
    # timed; score() (where KNN performs its neighbour searches) is not included.
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    # Mean accuracy on the held-out test set.
    print(knn_clf.score(X_test, y_test))

(1347, 64)
0.0025633269999999486
0.9866666666666667

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":
    # Load scikit-learn's bundled handwritten-digit dataset
    # (8x8 images flattened to 64 features per sample).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Split into training and test sets; fixed random_state for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)

    # Baseline: KNN on all original features (construction + fit timed).
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test, y_test))

    # Reduce the original features to 2 dimensions. PCA is fitted on the
    # training split only, and the same projection is applied to the test
    # split so no test information leaks into the model.
    pca = PCA(n_components=2)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)

    # Train KNN on the reduced data to compare fit time and accuracy
    # against the baseline.
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train_reduction, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test_reduction, y_test))

(1347, 64)
0.0021396940000000253
0.9866666666666667
0.0005477309999999402
0.6066666666666667

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":
    # Load scikit-learn's bundled handwritten-digit dataset
    # (8x8 images flattened to 64 features per sample).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Split into training and test sets; fixed random_state for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)

    # Baseline: KNN on all original features (construction + fit timed).
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test, y_test))

    # Reduce the original features to 2 dimensions (PCA fitted on the
    # training split only, then applied to both splits).
    pca = PCA(n_components=2)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)

    # Train KNN on the reduced data.
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train_reduction, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test_reduction, y_test))

    # Fraction of the total variance explained by each of the 2 kept components;
    # a small sum here accounts for the accuracy drop of the reduced model.
    print(pca.explained_variance_ratio_)

(1347, 64)
0.002134913000000016
0.9866666666666667
0.0005182209999999854
0.6066666666666667
[0.14566817 0.13735469]

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":
    # Load scikit-learn's bundled handwritten-digit dataset
    # (8x8 images flattened to 64 features per sample).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Split into training and test sets; fixed random_state for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    print(X_train.shape)

    # Baseline: KNN on all original features (construction + fit timed).
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test, y_test))

    # Reduce the original features to 2 dimensions (PCA fitted on the
    # training split only, then applied to both splits).
    pca = PCA(n_components=2)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)

    # Train KNN on the reduced data.
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train_reduction, y_train)
    print(timeit.default_timer() - start_time)
    print(knn_clf.score(X_test_reduction, y_test))

    # Fraction of the total variance explained by each of the 2 kept components.
    print(pca.explained_variance_ratio_)

    # Refit with as many components as there are features to inspect the
    # explained-variance ratio of every component, in decreasing order.
    pca = PCA(n_components=X_train.shape[1])
    pca.fit(X_train)
    print(pca.explained_variance_ratio_)

(1347, 64)
0.0023309580000000496
0.9866666666666667
0.0005462749999999295
0.6066666666666667
[0.14566817 0.13735469]
[1.45668166e-01 1.37354688e-01 1.17777287e-01 8.49968861e-02
5.86018996e-02 5.11542945e-02 4.26605279e-02 3.60119663e-02
3.41105814e-02 3.05407804e-02 2.42337671e-02 2.28700570e-02
1.80304649e-02 1.79346003e-02 1.45798298e-02 1.42044841e-02
1.29961033e-02 1.26617002e-02 1.01728635e-02 9.09314698e-03
8.85220461e-03 7.73828332e-03 7.60516219e-03 7.11864860e-03
6.85977267e-03 5.76411920e-03 5.71688020e-03 5.08255707e-03
4.89020776e-03 4.34888085e-03 3.72917505e-03 3.57755036e-03
3.26989470e-03 3.14917937e-03 3.09269839e-03 2.87619649e-03
2.50362666e-03 2.25417403e-03 2.20030857e-03 1.98028746e-03
1.88195578e-03 1.52769283e-03 1.42823692e-03 1.38003340e-03
1.17572392e-03 1.07377463e-03 9.55152460e-04 9.00017642e-04
5.79162563e-04 3.82793717e-04 2.38328586e-04 8.40132221e-05
5.60545588e-05 5.48538930e-05 1.08077650e-05 4.01354717e-06
1.23186515e-06 1.05783059e-06 6.06659094e-07 5.86686040e-07
1.71368535e-33 7.44075955e-34 7.44075955e-34 7.15189459e-34]

# Cumulative explained-variance ratio versus the number of kept components.
# Assumes `pca` is a fitted PCA; for the curve to cover every feature it
# should have been fitted with n_components=X_train.shape[1].
cumulative_ratio = np.cumsum(pca.explained_variance_ratio_)
# Point k on the x axis is the variance retained by the first k components.
plt.plot(np.arange(1, len(cumulative_ratio) + 1), cumulative_ratio)
plt.xlabel('number of components')
plt.ylabel('cumulative explained variance ratio')
plt.show()

# A float n_components in (0, 1) asks PCA for the smallest number of
# components whose cumulative explained variance reaches that fraction (95%).
# fit() returns the estimator itself, so construction and fitting chain.
pca = PCA(n_components=0.95).fit(X_train)
print(pca.n_components_)

28

# Keep just enough components to explain 95% of the training variance.
pca = PCA(n_components=0.95).fit(X_train)
print(pca.n_components_)
# Apply the training-set projection to both splits.
X_train_reduction = pca.transform(X_train)
X_test_reduction = pca.transform(X_test)

# Time KNN construction + fit on the reduced features, then report
# accuracy on the held-out split.
t0 = timeit.default_timer()
knn_clf = KNeighborsClassifier().fit(X_train_reduction, y_train)
elapsed = timeit.default_timer() - t0
print(elapsed)
print(knn_clf.score(X_test_reduction, y_test))

28
0.0014477489999999982
0.98

# Project the whole digits dataset onto its first two principal components
# purely for visualization (no model is evaluated, so fitting on all of X is fine).
pca = PCA(n_components=2)
pca.fit(X)
X_reduction = pca.transform(X)
print(X_reduction.shape)
# Plot one class at a time in the 2-D plane so each class gets its own color.
for label in np.unique(y):
    mask = y == label
    plt.scatter(X_reduction[mask, 0], X_reduction[mask, 1], alpha=0.8, label=str(label))
plt.legend()
plt.show()

(1797, 2)

MNIST数据集

import numpy as np
from sklearn.datasets import fetch_openml

if __name__ == "__main__":
    # Download (or load from the local cache) MNIST from OpenML.
    # version=1 pins the dataset version instead of relying on "latest active";
    # as_frame=False returns NumPy arrays (pixel data as floats, target as an
    # object array of label strings) regardless of the scikit-learn default.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    print(mnist)

{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}

import numpy as np
from sklearn.datasets import fetch_openml

if __name__ == "__main__":
    # Download (or load from the local cache) MNIST from OpenML, pinned to
    # version 1 and returned as NumPy arrays rather than a DataFrame.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    print(mnist)
    # data: one row per image, one column per pixel; target: label strings.
    X, y = mnist['data'], mnist['target']
    print(X.shape)

{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(70000, 784)

import numpy as np
from sklearn.datasets import fetch_openml

if __name__ == "__main__":
    # Download MNIST from OpenML. version=1 pins the dataset version and
    # as_frame=False requests plain NumPy arrays, so the positional slicing
    # below behaves the same regardless of the scikit-learn release's
    # default return type (newer releases default to pandas objects).
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    # Prints the whole Bunch, including the long DESCR text and all
    # feature names.
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)

    # Pixel values as floats; targets arrive as strings ('5', '0', ...),
    # so convert them to integer class labels.
    X = X.astype(float)
    y = y.astype(int)

    # Conventional MNIST split: the first 60,000 samples are the training
    # set, the remaining ones the test set.
    n_train = 60000
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]

    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)

{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(70000, 784)
(60000, 784)
{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(70000, 784)
(60000, 784)
(60000,)
(10000, 784)
(10000,)

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
import timeit

if __name__ == "__main__":
    # Download MNIST from OpenML. version=1 pins the dataset version and
    # as_frame=False requests plain NumPy arrays, so the positional slicing
    # below behaves the same regardless of the scikit-learn release's
    # default return type (newer releases default to pandas objects).
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    # Prints the whole Bunch, including the long DESCR text and all
    # feature names.
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)

    # Pixel values as floats; targets arrive as strings ('5', '0', ...),
    # so convert them to integer class labels.
    X = X.astype(float)
    y = y.astype(int)

    # Conventional MNIST split: the first 60,000 samples are the training
    # set, the remaining ones the test set.
    n_train = 60000
    X_train, X_test = X[:n_train], X[n_train:]
    y_train, y_test = y[:n_train], y[n_train:]

    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)

    # Run KNN classification on the raw (not dimension-reduced) data.
    # Only the fit step is timed here; for KNN the neighbour search cost
    # is paid later, at predict/score time, and is not included.
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(70000, 784)
(60000, 784)
(60000,)
(10000, 784)
(10000,)
12.399795246999998

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
import timeit

if __name__ == "__main__":

    # Pin the dataset version and request plain NumPy arrays. Newer
    # scikit-learn releases default to as_frame='auto', which returns
    # pandas objects instead of the ndarrays this script slices below.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    # Standard MNIST split (per the dataset description): the first 60,000
    # samples are the training set, the remaining 10,000 the test set.
    # Targets are stored as strings ('5', '0', ...), so convert them to float.
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)
    print(X_train.shape)
    print(y_train.shape)
    print(X_test.shape)
    print(y_test.shape)
    # Run KNN classification on the raw (unreduced) 784-dimensional data.
    # Time fit() and score() separately: score() predicts every test sample,
    # and that neighbour search dominates the total runtime.
    start_time = timeit.default_timer()
    knn_clf = KNeighborsClassifier()
    knn_clf.fit(X_train, y_train)
    print(timeit.default_timer() - start_time)
    start_time = timeit.default_timer()
    print(knn_clf.score(X_test, y_test))
    print(timeit.default_timer() - start_time)

{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(70000, 784)
(60000, 784)
(60000,)
(10000, 784)
(10000,)
12.444964293999998
0.9688
537.592075492

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    # Pin the dataset version and request plain NumPy arrays. Newer
    # scikit-learn releases default to as_frame='auto', which returns
    # pandas objects instead of the ndarrays this script slices below.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    # Standard MNIST split: first 60,000 samples train, last 10,000 test.
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)
    # print(X_train.shape)
    # print(y_train.shape)
    # print(X_test.shape)
    # print(y_test.shape)
    # KNN classification on the raw data (disabled here; it was run in the
    # previous experiment and its score() step is very slow).
    # start_time = timeit.default_timer()
    # knn_clf = KNeighborsClassifier()
    # knn_clf.fit(X_train, y_train)
    # print(timeit.default_timer() - start_time)
    # start_time = timeit.default_timer()
    # print(knn_clf.score(X_test, y_test))
    # print(timeit.default_timer() - start_time)
    # Reduce the raw data's dimensionality while keeping 90% of the variance.
    # A float n_components in (0, 1) makes PCA choose the smallest number of
    # components whose explained-variance ratio exceeds that fraction.
    pca = PCA(0.9)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    print(X_train_reduction.shape)

{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(70000, 784)
(60000, 87)

import numpy as np
from sklearn.datasets import fetch_openml
from sklearn.neighbors import KNeighborsClassifier
from sklearn.decomposition import PCA
import timeit

if __name__ == "__main__":

    # as_frame=False returns plain NumPy arrays (matching the output shown below);
    # without it, newer scikit-learn versions may return a pandas DataFrame instead.
    mnist = fetch_openml('mnist_784', version=1, as_frame=False)
    print(mnist)
    X, y = mnist['data'], mnist['target']
    print(X.shape)
    # Per the dataset description, the first 60,000 rows form the training set and
    # the remaining 10,000 the test set. Labels are stored as strings ('5', '0', ...),
    # so everything is converted to numeric arrays.
    X_train = np.array(X[:60000], dtype=float)
    y_train = np.array(y[:60000], dtype=float)
    X_test = np.array(X[60000:], dtype=float)
    y_test = np.array(y[60000:], dtype=float)

    def time_knn(X_fit, y_fit, X_eval, y_eval):
        """Fit a default KNN classifier and print fit time, test accuracy, scoring time."""
        start_time = timeit.default_timer()
        knn_clf = KNeighborsClassifier()
        knn_clf.fit(X_fit, y_fit)
        print(timeit.default_timer() - start_time)
        start_time = timeit.default_timer()
        print(knn_clf.score(X_eval, y_eval))
        print(timeit.default_timer() - start_time)

    # Set to True to also run KNN on the raw 784-dimensional data for comparison.
    run_raw_baseline = False
    if run_raw_baseline:
        time_knn(X_train, y_train, X_test, y_test)

    # Reduce dimensionality with PCA, keeping 90% of the variance
    # (the output below shows this yields 87 components).
    pca = PCA(0.9)
    pca.fit(X_train)
    X_train_reduction = pca.transform(X_train)
    X_test_reduction = pca.transform(X_test)
    print(X_train_reduction.shape)
    # KNN classification on the reduced data.
    time_knn(X_train_reduction, y_train, X_test_reduction, y_test)

{'data': array([[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
...,
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.],
[0., 0., 0., ..., 0., 0., 0.]]), 'target': array(['5', '0', '4', ..., '4', '5', '6'], dtype=object), 'frame': None, 'feature_names': ['pixel1', 'pixel2', 'pixel3', 'pixel4', 'pixel5', 'pixel6', 'pixel7', 'pixel8', 'pixel9', 'pixel10', 'pixel11', 'pixel12', 'pixel13', 'pixel14', 'pixel15', 'pixel16', 'pixel17', 'pixel18', 'pixel19', 'pixel20', 'pixel21', 'pixel22', 'pixel23', 'pixel24', 'pixel25', 'pixel26', 'pixel27', 'pixel28', 'pixel29', 'pixel30', 'pixel31', 'pixel32', 'pixel33', 'pixel34', 'pixel35', 'pixel36', 'pixel37', 'pixel38', 'pixel39', 'pixel40', 'pixel41', 'pixel42', 'pixel43', 'pixel44', 'pixel45', 'pixel46', 'pixel47', 'pixel48', 'pixel49', 'pixel50', 'pixel51', 'pixel52', 'pixel53', 'pixel54', 'pixel55', 'pixel56', 'pixel57', 'pixel58', 'pixel59', 'pixel60', 'pixel61', 'pixel62', 'pixel63', 'pixel64', 'pixel65', 'pixel66', 'pixel67', 'pixel68', 'pixel69', 'pixel70', 'pixel71', 'pixel72', 'pixel73', 'pixel74', 'pixel75', 'pixel76', 'pixel77', 'pixel78', 'pixel79', 'pixel80', 'pixel81', 'pixel82', 'pixel83', 'pixel84', 'pixel85', 'pixel86', 'pixel87', 'pixel88', 'pixel89', 'pixel90', 'pixel91', 'pixel92', 'pixel93', 'pixel94', 'pixel95', 'pixel96', 'pixel97', 'pixel98', 'pixel99', 'pixel100', 'pixel101', 'pixel102', 'pixel103', 'pixel104', 'pixel105', 'pixel106', 'pixel107', 'pixel108', 'pixel109', 'pixel110', 'pixel111', 'pixel112', 'pixel113', 'pixel114', 'pixel115', 'pixel116', 'pixel117', 'pixel118', 'pixel119', 'pixel120', 'pixel121', 'pixel122', 'pixel123', 'pixel124', 'pixel125', 'pixel126', 'pixel127', 'pixel128', 'pixel129', 'pixel130', 'pixel131', 'pixel132', 'pixel133', 'pixel134', 'pixel135', 'pixel136', 'pixel137', 'pixel138', 'pixel139', 'pixel140', 'pixel141', 'pixel142', 'pixel143', 'pixel144', 'pixel145', 'pixel146', 'pixel147', 'pixel148', 'pixel149', 'pixel150', 'pixel151', 'pixel152', 'pixel153', 'pixel154', 'pixel155', 'pixel156', 'pixel157', 'pixel158', 'pixel159', 'pixel160', 'pixel161', 'pixel162', 'pixel163', 'pixel164', 
'pixel165', 'pixel166', 'pixel167', 'pixel168', 'pixel169', 'pixel170', 'pixel171', 'pixel172', 'pixel173', 'pixel174', 'pixel175', 'pixel176', 'pixel177', 'pixel178', 'pixel179', 'pixel180', 'pixel181', 'pixel182', 'pixel183', 'pixel184', 'pixel185', 'pixel186', 'pixel187', 'pixel188', 'pixel189', 'pixel190', 'pixel191', 'pixel192', 'pixel193', 'pixel194', 'pixel195', 'pixel196', 'pixel197', 'pixel198', 'pixel199', 'pixel200', 'pixel201', 'pixel202', 'pixel203', 'pixel204', 'pixel205', 'pixel206', 'pixel207', 'pixel208', 'pixel209', 'pixel210', 'pixel211', 'pixel212', 'pixel213', 'pixel214', 'pixel215', 'pixel216', 'pixel217', 'pixel218', 'pixel219', 'pixel220', 'pixel221', 'pixel222', 'pixel223', 'pixel224', 'pixel225', 'pixel226', 'pixel227', 'pixel228', 'pixel229', 'pixel230', 'pixel231', 'pixel232', 'pixel233', 'pixel234', 'pixel235', 'pixel236', 'pixel237', 'pixel238', 'pixel239', 'pixel240', 'pixel241', 'pixel242', 'pixel243', 'pixel244', 'pixel245', 'pixel246', 'pixel247', 'pixel248', 'pixel249', 'pixel250', 'pixel251', 'pixel252', 'pixel253', 'pixel254', 'pixel255', 'pixel256', 'pixel257', 'pixel258', 'pixel259', 'pixel260', 'pixel261', 'pixel262', 'pixel263', 'pixel264', 'pixel265', 'pixel266', 'pixel267', 'pixel268', 'pixel269', 'pixel270', 'pixel271', 'pixel272', 'pixel273', 'pixel274', 'pixel275', 'pixel276', 'pixel277', 'pixel278', 'pixel279', 'pixel280', 'pixel281', 'pixel282', 'pixel283', 'pixel284', 'pixel285', 'pixel286', 'pixel287', 'pixel288', 'pixel289', 'pixel290', 'pixel291', 'pixel292', 'pixel293', 'pixel294', 'pixel295', 'pixel296', 'pixel297', 'pixel298', 'pixel299', 'pixel300', 'pixel301', 'pixel302', 'pixel303', 'pixel304', 'pixel305', 'pixel306', 'pixel307', 'pixel308', 'pixel309', 'pixel310', 'pixel311', 'pixel312', 'pixel313', 'pixel314', 'pixel315', 'pixel316', 'pixel317', 'pixel318', 'pixel319', 'pixel320', 'pixel321', 'pixel322', 'pixel323', 'pixel324', 'pixel325', 'pixel326', 'pixel327', 'pixel328', 'pixel329', 'pixel330', 
'pixel331', 'pixel332', 'pixel333', 'pixel334', 'pixel335', 'pixel336', 'pixel337', 'pixel338', 'pixel339', 'pixel340', 'pixel341', 'pixel342', 'pixel343', 'pixel344', 'pixel345', 'pixel346', 'pixel347', 'pixel348', 'pixel349', 'pixel350', 'pixel351', 'pixel352', 'pixel353', 'pixel354', 'pixel355', 'pixel356', 'pixel357', 'pixel358', 'pixel359', 'pixel360', 'pixel361', 'pixel362', 'pixel363', 'pixel364', 'pixel365', 'pixel366', 'pixel367', 'pixel368', 'pixel369', 'pixel370', 'pixel371', 'pixel372', 'pixel373', 'pixel374', 'pixel375', 'pixel376', 'pixel377', 'pixel378', 'pixel379', 'pixel380', 'pixel381', 'pixel382', 'pixel383', 'pixel384', 'pixel385', 'pixel386', 'pixel387', 'pixel388', 'pixel389', 'pixel390', 'pixel391', 'pixel392', 'pixel393', 'pixel394', 'pixel395', 'pixel396', 'pixel397', 'pixel398', 'pixel399', 'pixel400', 'pixel401', 'pixel402', 'pixel403', 'pixel404', 'pixel405', 'pixel406', 'pixel407', 'pixel408', 'pixel409', 'pixel410', 'pixel411', 'pixel412', 'pixel413', 'pixel414', 'pixel415', 'pixel416', 'pixel417', 'pixel418', 'pixel419', 'pixel420', 'pixel421', 'pixel422', 'pixel423', 'pixel424', 'pixel425', 'pixel426', 'pixel427', 'pixel428', 'pixel429', 'pixel430', 'pixel431', 'pixel432', 'pixel433', 'pixel434', 'pixel435', 'pixel436', 'pixel437', 'pixel438', 'pixel439', 'pixel440', 'pixel441', 'pixel442', 'pixel443', 'pixel444', 'pixel445', 'pixel446', 'pixel447', 'pixel448', 'pixel449', 'pixel450', 'pixel451', 'pixel452', 'pixel453', 'pixel454', 'pixel455', 'pixel456', 'pixel457', 'pixel458', 'pixel459', 'pixel460', 'pixel461', 'pixel462', 'pixel463', 'pixel464', 'pixel465', 'pixel466', 'pixel467', 'pixel468', 'pixel469', 'pixel470', 'pixel471', 'pixel472', 'pixel473', 'pixel474', 'pixel475', 'pixel476', 'pixel477', 'pixel478', 'pixel479', 'pixel480', 'pixel481', 'pixel482', 'pixel483', 'pixel484', 'pixel485', 'pixel486', 'pixel487', 'pixel488', 'pixel489', 'pixel490', 'pixel491', 'pixel492', 'pixel493', 'pixel494', 'pixel495', 'pixel496', 
'pixel497', 'pixel498', 'pixel499', 'pixel500', 'pixel501', 'pixel502', 'pixel503', 'pixel504', 'pixel505', 'pixel506', 'pixel507', 'pixel508', 'pixel509', 'pixel510', 'pixel511', 'pixel512', 'pixel513', 'pixel514', 'pixel515', 'pixel516', 'pixel517', 'pixel518', 'pixel519', 'pixel520', 'pixel521', 'pixel522', 'pixel523', 'pixel524', 'pixel525', 'pixel526', 'pixel527', 'pixel528', 'pixel529', 'pixel530', 'pixel531', 'pixel532', 'pixel533', 'pixel534', 'pixel535', 'pixel536', 'pixel537', 'pixel538', 'pixel539', 'pixel540', 'pixel541', 'pixel542', 'pixel543', 'pixel544', 'pixel545', 'pixel546', 'pixel547', 'pixel548', 'pixel549', 'pixel550', 'pixel551', 'pixel552', 'pixel553', 'pixel554', 'pixel555', 'pixel556', 'pixel557', 'pixel558', 'pixel559', 'pixel560', 'pixel561', 'pixel562', 'pixel563', 'pixel564', 'pixel565', 'pixel566', 'pixel567', 'pixel568', 'pixel569', 'pixel570', 'pixel571', 'pixel572', 'pixel573', 'pixel574', 'pixel575', 'pixel576', 'pixel577', 'pixel578', 'pixel579', 'pixel580', 'pixel581', 'pixel582', 'pixel583', 'pixel584', 'pixel585', 'pixel586', 'pixel587', 'pixel588', 'pixel589', 'pixel590', 'pixel591', 'pixel592', 'pixel593', 'pixel594', 'pixel595', 'pixel596', 'pixel597', 'pixel598', 'pixel599', 'pixel600', 'pixel601', 'pixel602', 'pixel603', 'pixel604', 'pixel605', 'pixel606', 'pixel607', 'pixel608', 'pixel609', 'pixel610', 'pixel611', 'pixel612', 'pixel613', 'pixel614', 'pixel615', 'pixel616', 'pixel617', 'pixel618', 'pixel619', 'pixel620', 'pixel621', 'pixel622', 'pixel623', 'pixel624', 'pixel625', 'pixel626', 'pixel627', 'pixel628', 'pixel629', 'pixel630', 'pixel631', 'pixel632', 'pixel633', 'pixel634', 'pixel635', 'pixel636', 'pixel637', 'pixel638', 'pixel639', 'pixel640', 'pixel641', 'pixel642', 'pixel643', 'pixel644', 'pixel645', 'pixel646', 'pixel647', 'pixel648', 'pixel649', 'pixel650', 'pixel651', 'pixel652', 'pixel653', 'pixel654', 'pixel655', 'pixel656', 'pixel657', 'pixel658', 'pixel659', 'pixel660', 'pixel661', 'pixel662', 
'pixel663', 'pixel664', 'pixel665', 'pixel666', 'pixel667', 'pixel668', 'pixel669', 'pixel670', 'pixel671', 'pixel672', 'pixel673', 'pixel674', 'pixel675', 'pixel676', 'pixel677', 'pixel678', 'pixel679', 'pixel680', 'pixel681', 'pixel682', 'pixel683', 'pixel684', 'pixel685', 'pixel686', 'pixel687', 'pixel688', 'pixel689', 'pixel690', 'pixel691', 'pixel692', 'pixel693', 'pixel694', 'pixel695', 'pixel696', 'pixel697', 'pixel698', 'pixel699', 'pixel700', 'pixel701', 'pixel702', 'pixel703', 'pixel704', 'pixel705', 'pixel706', 'pixel707', 'pixel708', 'pixel709', 'pixel710', 'pixel711', 'pixel712', 'pixel713', 'pixel714', 'pixel715', 'pixel716', 'pixel717', 'pixel718', 'pixel719', 'pixel720', 'pixel721', 'pixel722', 'pixel723', 'pixel724', 'pixel725', 'pixel726', 'pixel727', 'pixel728', 'pixel729', 'pixel730', 'pixel731', 'pixel732', 'pixel733', 'pixel734', 'pixel735', 'pixel736', 'pixel737', 'pixel738', 'pixel739', 'pixel740', 'pixel741', 'pixel742', 'pixel743', 'pixel744', 'pixel745', 'pixel746', 'pixel747', 'pixel748', 'pixel749', 'pixel750', 'pixel751', 'pixel752', 'pixel753', 'pixel754', 'pixel755', 'pixel756', 'pixel757', 'pixel758', 'pixel759', 'pixel760', 'pixel761', 'pixel762', 'pixel763', 'pixel764', 'pixel765', 'pixel766', 'pixel767', 'pixel768', 'pixel769', 'pixel770', 'pixel771', 'pixel772', 'pixel773', 'pixel774', 'pixel775', 'pixel776', 'pixel777', 'pixel778', 'pixel779', 'pixel780', 'pixel781', 'pixel782', 'pixel783', 'pixel784'], 'target_names': ['class'], 'DESCR': "**Author**: Yann LeCun, Corinna Cortes, Christopher J.C. Burges  \n**Source**: [MNIST Website](http://yann.lecun.com/exdb/mnist/) - Date unknown  \n**Please cite**:  \n\nThe MNIST database of handwritten digits with 784 features, raw data available at: http://yann.lecun.com/exdb/mnist/. It can be split in a training set of the first 60,000 examples, and a test set of 10,000 examples  \n\nIt is a subset of a larger set available from NIST. 
The digits have been size-normalized and centered in a fixed-size image. It is a good database for people who want to try learning techniques and pattern recognition methods on real-world data while spending minimal efforts on preprocessing and formatting. The original black and white (bilevel) images from NIST were size normalized to fit in a 20x20 pixel box while preserving their aspect ratio. The resulting images contain grey levels as a result of the anti-aliasing technique used by the normalization algorithm. the images were centered in a 28x28 image by computing the center of mass of the pixels, and translating the image so as to position this point at the center of the 28x28 field.  \n\nWith some classification methods (particularly template-based methods, such as SVM and K-nearest neighbors), the error rate improves when the digits are centered by bounding box rather than center of mass. If you do this kind of pre-processing, you should report it in your publications. The MNIST database was constructed from NIST's NIST originally designated SD-3 as their training set and SD-1 as their test set. However, SD-3 is much cleaner and easier to recognize than SD-1. The reason for this can be found on the fact that SD-3 was collected among Census Bureau employees, while SD-1 was collected among high-school students. Drawing sensible conclusions from learning experiments requires that the result be independent of the choice of training set and test among the complete set of samples. Therefore it was necessary to build a new database by mixing NIST's datasets.  \n\nThe MNIST training set is composed of 30,000 patterns from SD-3 and 30,000 patterns from SD-1. Our test set was composed of 5,000 patterns from SD-3 and 5,000 patterns from SD-1. The 60,000 pattern training set contained examples from approximately 250 writers. We made sure that the sets of writers of the training set and test set were disjoint. 
SD-1 contains 58,527 digit images written by 500 different writers. In contrast to SD-3, where blocks of data from each writer appeared in sequence, the data in SD-1 is scrambled. Writer identities for SD-1 is available and we used this information to unscramble the writers. We then split SD-1 in two: characters written by the first 250 writers went into our new training set. The remaining 250 writers were placed in our test set. Thus we had two sets with nearly 30,000 examples each. The new training set was completed with enough examples from SD-3, starting at pattern # 0, to make a full set of 60,000 training patterns. Similarly, the new test set was completed with SD-3 examples starting at pattern # 35,000 to make a full set with 60,000 test patterns. Only a subset of 10,000 test images (5,000 from SD-1 and 5,000 from SD-3) is available on this site. The full 60,000 sample training set is available.\n\nDownloaded from openml.org.", 'details': {'id': '554', 'name': 'mnist_784', 'version': '1', 'description_version': '1', 'format': 'ARFF', 'creator': ['Yann LeCun', 'Corinna Cortes', 'Christopher J.C. Burges'], 'upload_date': '2014-09-29T03:28:38', 'language': 'English', 'licence': 'Public', 'url': 'https://www.openml.org/data/v1/download/52667/mnist_784.arff', 'file_id': '52667', 'default_target_attribute': 'class', 'tag': ['AzurePilot', 'OpenML-CC18', 'OpenML100', 'study_1', 'study_123', 'study_41', 'study_99', 'vision'], 'visibility': 'public', 'status': 'active', 'processing_date': '2020-11-20 20:12:09', 'md5_checksum': '0298d579eb1b86163de7723944c7e495'}, 'categories': {}, 'url': 'https://www.openml.org/d/554'}
(70000, 784)
(60000, 87)
0.353538382
0.9728
66.771680345

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Scatter-plot 100 points drawn uniformly from [0, 1) in each of 3 dimensions.
    np.random.seed(8888)
    cloud = np.random.random(size=(100, 3))
    xs, ys, zs = cloud.T
    axes_3d = plt.axes(projection='3d')
    axes_3d.scatter3D(xs, ys, zs)
    plt.show()

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

def demean(X):
    """Return a copy of X with each column's mean subtracted.

    :param X: 2-D data matrix, one sample per row
    :return: array of the same shape whose columns all have mean 0
    """
    return X - np.mean(X, axis=0)


if __name__ == "__main__":
    # Generate 100 random sample points in 3-D space.
    np.random.seed(8888)
    X_random = np.random.random(size=(100, 3))
    # De-mean the data and plot it: same point cloud, now centred on the origin.
    X_demean = demean(X_random)
    ax = plt.axes(projection='3d')
    ax.scatter3D(X_demean[:, 0], X_demean[:, 1], X_demean[:, 2])
    plt.show()

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

def demean(X):
    """Return a copy of X with each column's mean subtracted."""
    return X - np.mean(X, axis=0)


def f(w, X):
    """Objective: mean squared projection of the rows of X onto w.

    For de-meaned X and unit-length w this is the variance of the data along w.
    """
    return np.sum(X.dot(w) ** 2) / len(X)


def df(w, X):
    """Analytic gradient of ``f`` with respect to w: 2/m * X^T (X w)."""
    return X.T.dot(X.dot(w)) * 2 / len(X)


def direction(w):
    """Scale w to unit length so that it only encodes a direction."""
    return w / np.linalg.norm(w)


def first_component(X, initial_w, eta, n_iters=1e4, epsilon=1e-8):
    """Find the first principal component of X by gradient ascent.

    The data is not centred here; the result is a principal direction only
    when the caller passes de-meaned data.

    :param X: data matrix, one sample per row
    :param initial_w: initial guess for the direction vector w (must be non-zero).
        w is the unknown being solved for, and the gradient is taken with respect to w.
    :param eta: step size (learning rate)
    :param n_iters: maximum number of iterations
    :param epsilon: stop once the objective changes by less than this between iterations
    :return: unit vector approximating the first principal component
    """
    # Work with a unit vector throughout.
    w = direction(initial_w)
    cur_iter = 0
    while cur_iter < n_iters:
        # Gradient of the objective at the current w.
        gradient = df(w, X)
        last_w = w
        # Gradient *ascent*: move along +gradient (descent would subtract it),
        # then re-normalise so w stays a pure direction.
        w = direction(w + eta * gradient)
        if abs(f(w, X) - f(last_w, X)) < epsilon:
            break
        cur_iter += 1
    return w


def remove_component(X, w):
    """Return X with each row's projection onto the unit vector w subtracted."""
    return X - X.dot(w).reshape(-1, 1) * w


if __name__ == "__main__":
    # Generate 100 random sample points in 3-D space and de-mean them.
    np.random.seed(8888)
    X_random = np.random.random(size=(100, 3))
    X_demean = demean(X_random)

    initial_w = np.random.random(X_demean.shape[1])
    eta = 0.01
    # First principal component.
    w1 = first_component(X_demean, initial_w, eta)
    # Remove the first component; the remaining data lies in the plane orthogonal to w1.
    X2 = remove_component(X_demean, w1)
    ax = plt.axes(projection='3d')
    ax.scatter3D(X2[:, 0], X2[:, 1], X2[:, 2])
    plt.show()

import numpy as np
from mpl_toolkits import mplot3d
import matplotlib.pyplot as plt

def demean(X):
    """Return a copy of X with each column's mean subtracted."""
    return X - np.mean(X, axis=0)


def f(w, X):
    """Objective: mean squared projection of the rows of X onto w.

    For de-meaned X and unit-length w this is the variance of the data along w.
    """
    return np.sum(X.dot(w) ** 2) / len(X)


def df(w, X):
    """Analytic gradient of ``f`` with respect to w: 2/m * X^T (X w)."""
    return X.T.dot(X.dot(w)) * 2 / len(X)


def direction(w):
    """Scale w to unit length so that it only encodes a direction."""
    return w / np.linalg.norm(w)


def first_component(X, initial_w, eta, n_iters=1e4, epsilon=1e-8):
    """Find the first principal component of X by gradient ascent.

    The data is not centred here; the result is a principal direction only
    when the caller passes de-meaned data.

    :param X: data matrix, one sample per row
    :param initial_w: initial guess for the direction vector w (must be non-zero).
        w is the unknown being solved for, and the gradient is taken with respect to w.
    :param eta: step size (learning rate)
    :param n_iters: maximum number of iterations
    :param epsilon: stop once the objective changes by less than this between iterations
    :return: unit vector approximating the first principal component
    """
    # Work with a unit vector throughout.
    w = direction(initial_w)
    cur_iter = 0
    while cur_iter < n_iters:
        # Gradient of the objective at the current w.
        gradient = df(w, X)
        last_w = w
        # Gradient *ascent*: move along +gradient (descent would subtract it),
        # then re-normalise so w stays a pure direction.
        w = direction(w + eta * gradient)
        if abs(f(w, X) - f(last_w, X)) < epsilon:
            break
        cur_iter += 1
    return w


def remove_component(X, w):
    """Return X with each row's projection onto the unit vector w subtracted."""
    return X - X.dot(w).reshape(-1, 1) * w


if __name__ == "__main__":
    # Generate 100 random sample points in 3-D space and de-mean them.
    np.random.seed(8888)
    X_random = np.random.random(size=(100, 3))
    X_demean = demean(X_random)

    initial_w = np.random.random(X_demean.shape[1])
    eta = 0.01
    # First principal component, then the data with that component removed.
    w1 = first_component(X_demean, initial_w, eta)
    X2 = remove_component(X_demean, w1)
    # The first component of the residual is the second principal component.
    w2 = first_component(X2, initial_w, eta)
    # Remove it as well; what remains lies along a single line (the third component).
    X3 = remove_component(X2, w2)
    ax = plt.axes(projection='3d')
    ax.scatter3D(X3[:, 0], X3[:, 1], X3[:, 2])
    plt.show()

from sklearn import datasets
import numpy as np

if __name__ == "__main__":

    # Load the 8x8 handwritten digits dataset (previously referenced but never loaded).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Build a noisy copy of the digits: add Gaussian noise with mean 0 and
    # standard deviation 4 (i.e. variance 16).
    noisy_digits = X + np.random.normal(0, 4, size=X.shape)
    # The dataset is large, so take the first 10 noisy samples of each digit 0-9.
    example_digits = np.vstack([noisy_digits[y == num, :][:10] for num in range(10)])
    print(example_digits.shape)

(100, 64)

from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # PCA is used below but this snippet's imports do not include it.
    from sklearn.decomposition import PCA

    # Load the 8x8 handwritten digits dataset (previously referenced but never loaded).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Build a noisy copy of the digits: add Gaussian noise with mean 0 and
    # standard deviation 4 (i.e. variance 16).
    noisy_digits = X + np.random.normal(0, 4, size=X.shape)
    # The dataset is large, so take the first 10 noisy samples of each digit 0-9.
    example_digits = np.vstack([noisy_digits[y == num, :][:10] for num in range(10)])
    print(example_digits.shape)

    def plot_digits(data):
        """Show the first 100 rows of ``data`` as 8x8 images on a 10x10 grid."""
        _, axes = plt.subplots(10, 10, figsize=(10, 10),
                               subplot_kw={'xticks': [], 'yticks': []},
                               gridspec_kw=dict(hspace=0.1, wspace=0.1))
        for i, ax in enumerate(axes.flat):
            ax.imshow(data[i].reshape(8, 8), cmap='binary', interpolation='nearest', clim=(0, 16))
        plt.show()

    plot_digits(example_digits)

    # Fit PCA on all noisy digits, keeping 50% of the variance, and report
    # how many components that requires.
    pca = PCA(0.5)
    pca.fit(noisy_digits)
    print(pca.n_components_)

12

from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

if __name__ == "__main__":

    # Load the 8x8 handwritten digits dataset (previously referenced but never loaded).
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Build a noisy copy of the digits: add Gaussian noise with mean 0 and
    # standard deviation 4 (i.e. variance 16).
    noisy_digits = X + np.random.normal(0, 4, size=X.shape)
    # The dataset is large, so take the first 10 noisy samples of each digit 0-9.
    example_digits = np.vstack([noisy_digits[y == num, :][:10] for num in range(10)])
    print(example_digits.shape)

    def plot_digits(data):
        """Show the first 100 rows of ``data`` as 8x8 images on a 10x10 grid."""
        _, axes = plt.subplots(10, 10, figsize=(10, 10),
                               subplot_kw={'xticks': [], 'yticks': []},
                               gridspec_kw=dict(hspace=0.1, wspace=0.1))
        for i, ax in enumerate(axes.flat):
            ax.imshow(data[i].reshape(8, 8), cmap='binary', interpolation='nearest', clim=(0, 16))
        plt.show()

    plot_digits(example_digits)
    # Denoise: fit PCA on all noisy digits keeping 50% of the variance, then
    # project the examples onto those components.
    pca = PCA(0.5)
    pca.fit(noisy_digits)
    print(pca.n_components_)
    components = pca.transform(example_digits)
    # Map the reduced representation back to the original 64 dimensions. The
    # intent is that the discarded low-variance directions hold mostly noise.
    filtered_digits = pca.inverse_transform(components)
    plot_digits(filtered_digits)

矩阵X的行数m是样本量，列数n是维度（特征数）。W是计算出来的主成分矩阵（k行n列），它构成了另外一个坐标系：其中的每一行都代表一个方向，第一行是最重要的方向（数据方差最大的方向），第二行是次重要的方向，以此类推。在人脸识别中，X的每一行就是一张人脸图像；由于W的每一行同样是n维向量，如果把W的每一行也看作一个样本，那么第一行就是最重要的“样本”，它最能反映X中原有样本的特征；第二行是次重要的“样本”，同样能较好地反映X中原有样本的特征。因此，W的每一行也可以看成是一张人脸，这样的人脸称为特征脸（eigenface）。之所以称为特征脸，是因为每一张特征脸都对应一个主成分，它表达了原始人脸数据中的一部分特征；主成分也对应着线性代数中特征值和特征向量的概念，有兴趣的朋友可以参考《线性代数整理(三)》中有关特征值和特征向量的内容。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people

if __name__ == "__main__":
    lfw = fetch_lfw_people()
    # Print, in order: the whole Bunch, its keys, the (n_samples, n_features)
    # shape of the flattened data, and the shape of the per-image arrays.
    for item in (lfw, lfw.keys(), lfw.data.shape, lfw.images.shape):
        print(item)

{'data': array([[ 34.      ,  29.333334,  22.333334, ...,  14.666667,  16.      ,
14.      ],
[158.      , 160.66667 , 169.66667 , ..., 138.66667 , 135.33333 ,
130.33333 ],
[ 77.      ,  81.333336,  88.      , ..., 192.      , 145.33333 ,
66.333336],
...,
[ 38.      ,  41.666668,  55.333332, ...,  66.      ,  63.666668,
54.333332],
[ 16.666666,  24.333334,  60.333332, ..., 219.      , 143.33333 ,
69.333336],
[ 58.333332,  48.      ,  20.      , ..., 116.      , 106.333336,
143.33333 ]], dtype=float32), 'images': array([[[ 34.      ,  29.333334,  22.333334, ...,  20.      ,
25.666666,  30.666666],
[ 37.333332,  32.      ,  25.333334, ...,  21.      ,
26.666666,  32.      ],
[ 33.333332,  32.333332,  40.333332, ...,  23.666666,
28.      ,  35.666668],
...,
[166.      ,  97.      ,  44.333332, ...,   9.666667,
14.333333,  12.333333],
[ 64.      ,  38.666668,  30.      , ...,  12.666667,
16.      ,  14.      ],
[ 30.666666,  29.      ,  26.333334, ...,  14.666667,
16.      ,  14.      ]],

[[158.      , 160.66667 , 169.66667 , ...,  74.333336,
28.      ,  15.666667],
[156.      , 155.33333 , 163.33333 , ...,  83.      ,
25.666666,  14.      ],
[146.66667 , 143.66667 , 144.66667 , ...,  82.333336,
26.      ,  14.666667],
...,
[118.666664, 120.      , 170.      , ..., 131.33333 ,
127.333336, 126.      ],
[125.      , 117.666664, 141.33333 , ..., 133.33333 ,
132.      , 129.33333 ],
[128.66667 , 122.666664, 121.666664, ..., 138.66667 ,
135.33333 , 130.33333 ]],

[[ 77.      ,  81.333336,  88.      , ...,  71.      ,
80.666664,  65.333336],
[ 77.666664,  89.      , 104.666664, ...,  75.666664,
72.      ,  69.      ],
[ 83.      ,  97.      , 117.      , ...,  80.333336,
80.333336,  66.      ],
...,
[ 13.333333,  16.      ,  32.      , ..., 183.66667 ,
131.      ,  58.333332],
[ 46.333332,  70.      ,  94.333336, ..., 188.66667 ,
143.66667 ,  65.      ],
[112.333336, 122.      , 118.      , ..., 192.      ,
145.33333 ,  66.333336]],

...,

[[ 38.      ,  41.666668,  55.333332, ...,  28.666666,
26.333334,  29.333334],
[ 46.666668,  49.333332,  60.666668, ...,  30.333334,
26.333334,  31.      ],
[ 50.333332,  53.666668,  63.      , ...,  32.333332,
27.666666,  34.333332],
...,
[ 71.      , 123.      , 204.66667 , ...,  58.333332,
47.      ,  38.      ],
[ 70.666664,  85.666664, 148.66667 , ...,  64.      ,
57.666668,  46.666668],
[ 70.333336,  71.666664,  97.333336, ...,  66.      ,
63.666668,  54.333332]],

[[ 16.666666,  24.333334,  60.333332, ..., 177.33333 ,
175.      , 174.33333 ],
[ 16.333334,  22.333334,  55.666668, ..., 177.33333 ,
174.66667 , 174.      ],
[ 18.333334,  25.666666,  61.      , ..., 175.33333 ,
172.66667 , 172.33333 ],
...,
[ 21.666666,  21.333334,  22.666666, ..., 221.33333 ,
132.66667 ,  52.      ],
[ 22.      ,  21.666666,  22.333334, ..., 219.      ,
137.      ,  59.      ],
[ 22.333334,  22.      ,  22.666666, ..., 219.      ,
143.33333 ,  69.333336]],

[[ 58.333332,  48.      ,  20.      , ...,  66.      ,
101.666664,  94.666664],
[ 62.      ,  32.666668,  26.333334, ...,  50.      ,
89.666664, 101.333336],
[ 56.333332,  29.333334,  47.      , ...,  55.333332,
76.666664, 106.333336],
...,
[116.333336, 106.333336,  95.      , ..., 113.333336,
100.333336,  88.      ],
[116.666664, 104.666664,  93.333336, ..., 115.666664,
103.666664, 112.      ],
[116.333336, 104.      ,  95.333336, ..., 116.      ,
106.333336, 143.33333 ]]], dtype=float32), 'target': array([5360, 3434, 3807, ..., 2175,  373, 2941]), 'target_names': array(['AJ Cook', 'AJ Lamas', 'Aaron Eckhart', ..., 'Zumrati Juma',
'Zurab Tsereteli', 'Zydrunas Ilgauskas'], dtype='<U35'), 'DESCR': ".. _labeled_faces_in_the_wild_dataset:\n\nThe Labeled Faces in the Wild face recognition dataset\n------------------------------------------------------\n\nThis dataset is a collection of JPEG pictures of famous people collected\nover the internet, all details are available on the official website:\n\n    http://vis-www.cs.umass.edu/lfw/\n\nEach picture is centered on a single face. The typical task is called\nFace Verification: given a pair of two pictures, a binary classifier\nmust predict whether the two images are from the same person.\n\nAn alternative task, Face Recognition or Face Identification is:\ngiven the picture of the face of an unknown person, identify the name\nof the person by referring to a gallery of previously seen pictures of\nidentified persons.\n\nBoth Face Verification and Face Recognition are tasks that are typically\nperformed on the output of a model trained to perform Face Detection. The\nmost popular model for Face Detection is called Viola-Jones and is\nimplemented in the OpenCV library. The LFW faces were extracted by this\nface detector from various online websites.\n\n**Data Set Characteristics:**\n\n    =================   =======================\n    Classes                                5749\n    Samples total                         13233\n    Dimensionality                         5828\n    Features            real, between 0 and 255\n    =================   =======================\n\nUsage\n~~~~~\n\nscikit-learn provides two loaders that will automatically download,\ncache, parse the metadata files, decode the jpeg and convert the\ninteresting slices into memmapped numpy arrays. This dataset size is more\nthan 200 MB. The first load typically takes more than a couple of minutes\nto fully decode the relevant part of the JPEG files into numpy arrays. 
If\nthe dataset has  been loaded once, the following times the loading times\nless than 200ms by using a memmapped version memoized on the disk in the\n~/scikit_learn_data/lfw_home/ folder using joblib.\n\nThe first loader is used for the Face Identification task: a multi-class\nclassification task (hence supervised learning)::\n\n  >>> from sklearn.datasets import fetch_lfw_people\n  >>> lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)\n\n  >>> for name in lfw_people.target_names:\n  ...     print(name)\n  ...\n  Ariel Sharon\n  Colin Powell\n  Donald Rumsfeld\n  George W Bush\n  Gerhard Schroeder\n  Hugo Chavez\n  Tony Blair\n\nThe default slice is a rectangular shape around the face, removing\nmost of the background::\n\n  >>> lfw_people.data.dtype\n  dtype('float32')\n\n  >>> lfw_people.data.shape\n  (1288, 1850)\n\n  >>> lfw_people.images.shape\n  (1288, 50, 37)\n\nEach of the 1140 faces is assigned to a single person id in the target\narray::\n\n  >>> lfw_people.target.shape\n  (1288,)\n\n  >>> list(lfw_people.target[:10])\n  [5, 6, 3, 1, 0, 1, 3, 4, 3, 0]\n\nThe second loader is typically used for the face verification task: each sample\nis a pair of two picture belonging or not to the same person::\n\n  >>> from sklearn.datasets import fetch_lfw_pairs\n  >>> lfw_pairs_train = fetch_lfw_pairs(subset='train')\n\n  >>> list(lfw_pairs_train.target_names)\n  ['Different persons', 'Same person']\n\n  >>> lfw_pairs_train.pairs.shape\n  (2200, 2, 62, 47)\n\n  >>> lfw_pairs_train.data.shape\n  (2200, 5828)\n\n  >>> lfw_pairs_train.target.shape\n  (2200,)\n\nBoth for the :func:sklearn.datasets.fetch_lfw_people and\n:func:sklearn.datasets.fetch_lfw_pairs function it is\npossible to get an additional dimension with the RGB color channels by\npassing color=True, in that case the shape will be\n(2200, 2, 62, 47, 3).\n\nThe :func:sklearn.datasets.fetch_lfw_pairs datasets is subdivided into\n3 subsets: the development train set, the development test set 
and\nan evaluation 10_folds set meant to compute performance metrics using a\n10-folds cross validation scheme.\n\n.. topic:: References:\n\n * Labeled Faces in the Wild: A Database for Studying Face Recognition\n   in Unconstrained Environments.\n   <http://vis-www.cs.umass.edu/lfw/lfw.pdf>_\n   Gary B. Huang, Manu Ramesh, Tamara Berg, and Erik Learned-Miller.\n   University of Massachusetts, Amherst, Technical Report 07-49, October, 2007.\n\n\nExamples\n~~~~~~~~\n\n:ref:sphx_glr_auto_examples_applications_plot_face_recognition.py\n"}
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(13233, 2914)
(13233, 62, 47)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people

def plot_faces(faces, n_rows=6, n_cols=6, image_shape=(62, 47)):
    """Show flattened face images on an ``n_rows`` x ``n_cols`` grid.

    Each row of ``faces`` is reshaped to ``image_shape`` and drawn with the
    'bone' colormap, with tick marks hidden. The defaults reproduce a
    6 x 6 grid of 62 x 47 images, which is the per-image shape of the data
    returned by ``fetch_lfw_people()`` with default arguments.

    Parameters
    ----------
    faces : numpy.ndarray
        One flattened image per row; each row must be reshapeable to
        ``image_shape``.
    n_rows, n_cols : int
        Grid dimensions.
    image_shape : tuple of int
        (height, width) of a single image.

    If ``faces`` has fewer rows than grid cells, the remaining cells stay
    blank; rows beyond ``n_rows * n_cols`` are ignored.
    """
    # squeeze=False keeps ``axes`` a 2-D array even for a 1 x 1 grid, so
    # ``axes.flat`` always works.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10),
                           subplot_kw={'xticks': [], 'yticks': []},
                           gridspec_kw=dict(hspace=0.1, wspace=0.1),
                           squeeze=False)
    # zip() stops at the shorter sequence, so fewer faces than cells leaves
    # the rest blank instead of raising IndexError.
    for ax, face in zip(axes.flat, faces):
        ax.imshow(face.reshape(image_shape), cmap='bone')
    plt.show()


if __name__ == "__main__":
    faces = fetch_lfw_people()
    print(faces)
    # Print the keys of the returned dataset Bunch
    print(faces.keys())
    # Print (number of samples, number of features)
    print(faces.data.shape)
    # Print the shape of the per-image matrices
    print(faces.images.shape)

    # Shuffle the samples with a random permutation of their indices
    random_indexes = np.random.permutation(len(faces.data))
    X = faces.data[random_indexes]
    # Keep the first 36 shuffled faces for plotting
    example_faces = X[:36, :]
    print(example_faces.shape)

    plot_faces(example_faces)

{'data': array([[ 34.      ,  29.333334,  22.333334, ...,  14.666667,  16.      ,
14.      ],
[158.      , 160.66667 , 169.66667 , ..., 138.66667 , 135.33333 ,
130.33333 ],
[ 77.      ,  81.333336,  88.      , ..., 192.      , 145.33333 ,
66.333336],
...,
[ 38.      ,  41.666668,  55.333332, ...,  66.      ,  63.666668,
54.333332],
[ 16.666666,  24.333334,  60.333332, ..., 219.      , 143.33333 ,
69.333336],
[ 58.333332,  48.      ,  20.      , ..., 116.      , 106.333336,
143.33333 ]], dtype=float32), 'images': array([[[ 34.      ,  29.333334,  22.333334, ...,  20.      ,
25.666666,  30.666666],
[ 37.333332,  32.      ,  25.333334, ...,  21.      ,
26.666666,  32.      ],
[ 33.333332,  32.333332,  40.333332, ...,  23.666666,
28.      ,  35.666668],
...,
[166.      ,  97.      ,  44.333332, ...,   9.666667,
14.333333,  12.333333],
[ 64.      ,  38.666668,  30.      , ...,  12.666667,
16.      ,  14.      ],
[ 30.666666,  29.      ,  26.333334, ...,  14.666667,
16.      ,  14.      ]],

[[158.      , 160.66667 , 169.66667 , ...,  74.333336,
28.      ,  15.666667],
[156.      , 155.33333 , 163.33333 , ...,  83.      ,
25.666666,  14.      ],
[146.66667 , 143.66667 , 144.66667 , ...,  82.333336,
26.      ,  14.666667],
...,
[118.666664, 120.      , 170.      , ..., 131.33333 ,
127.333336, 126.      ],
[125.      , 117.666664, 141.33333 , ..., 133.33333 ,
132.      , 129.33333 ],
[128.66667 , 122.666664, 121.666664, ..., 138.66667 ,
135.33333 , 130.33333 ]],

[[ 77.      ,  81.333336,  88.      , ...,  71.      ,
80.666664,  65.333336],
[ 77.666664,  89.      , 104.666664, ...,  75.666664,
72.      ,  69.      ],
[ 83.      ,  97.      , 117.      , ...,  80.333336,
80.333336,  66.      ],
...,
[ 13.333333,  16.      ,  32.      , ..., 183.66667 ,
131.      ,  58.333332],
[ 46.333332,  70.      ,  94.333336, ..., 188.66667 ,
143.66667 ,  65.      ],
[112.333336, 122.      , 118.      , ..., 192.      ,
145.33333 ,  66.333336]],

...,

[[ 38.      ,  41.666668,  55.333332, ...,  28.666666,
26.333334,  29.333334],
[ 46.666668,  49.333332,  60.666668, ...,  30.333334,
26.333334,  31.      ],
[ 50.333332,  53.666668,  63.      , ...,  32.333332,
27.666666,  34.333332],
...,
[ 71.      , 123.      , 204.66667 , ...,  58.333332,
47.      ,  38.      ],
[ 70.666664,  85.666664, 148.66667 , ...,  64.      ,
57.666668,  46.666668],
[ 70.333336,  71.666664,  97.333336, ...,  66.      ,
63.666668,  54.333332]],

[[ 16.666666,  24.333334,  60.333332, ..., 177.33333 ,
175.      , 174.33333 ],
[ 16.333334,  22.333334,  55.666668, ..., 177.33333 ,
174.66667 , 174.      ],
[ 18.333334,  25.666666,  61.      , ..., 175.33333 ,
172.66667 , 172.33333 ],
...,
[ 21.666666,  21.333334,  22.666666, ..., 221.33333 ,
132.66667 ,  52.      ],
[ 22.      ,  21.666666,  22.333334, ..., 219.      ,
137.      ,  59.      ],
[ 22.333334,  22.      ,  22.666666, ..., 219.      ,
143.33333 ,  69.333336]],

[[ 58.333332,  48.      ,  20.      , ...,  66.      ,
101.666664,  94.666664],
[ 62.      ,  32.666668,  26.333334, ...,  50.      ,
89.666664, 101.333336],
[ 56.333332,  29.333334,  47.      , ...,  55.333332,
76.666664, 106.333336],
...,
[116.333336, 106.333336,  95.      , ..., 113.333336,
100.333336,  88.      ],
[116.666664, 104.666664,  93.333336, ..., 115.666664,
103.666664, 112.      ],
[116.333336, 104.      ,  95.333336, ..., 116.      ,
106.333336, 143.33333 ]]], dtype=float32), 'target': array([5360, 3434, 3807, ..., 2175,  373, 2941]), 'target_names': array(['AJ Cook', 'AJ Lamas', 'Aaron Eckhart', ..., 'Zumrati Juma',
'Zurab Tsereteli', 'Zydrunas Ilgauskas'], dtype='<U35'), 'DESCR': ".. _labeled_faces_in_the_wild_dataset:\n\nThe Labeled Faces in the Wild face recognition dataset\n------------------------------------------------------\n\nThis dataset is a collection of JPEG pictures of famous people collected\nover the internet, all details are available on the official website:\n\n    http://vis-www.cs.umass.edu/lfw/\n\nEach picture is centered on a single face. The typical task is called\nFace Verification: given a pair of two pictures, a binary classifier\nmust predict whether the two images are from the same person.\n\nAn alternative task, Face Recognition or Face Identification is:\ngiven the picture of the face of an unknown person, identify the name\nof the person by referring to a gallery of previously seen pictures of\nidentified persons.\n\nBoth Face Verification and Face Recognition are tasks that are typically\nperformed on the output of a model trained to perform Face Detection. The\nmost popular model for Face Detection is called Viola-Jones and is\nimplemented in the OpenCV library. The LFW faces were extracted by this\nface detector from various online websites.\n\n**Data Set Characteristics:**\n\n    =================   =======================\n    Classes                                5749\n    Samples total                         13233\n    Dimensionality                         5828\n    Features            real, between 0 and 255\n    =================   =======================\n\nUsage\n~~~~~\n\nscikit-learn provides two loaders that will automatically download,\ncache, parse the metadata files, decode the jpeg and convert the\ninteresting slices into memmapped numpy arrays. This dataset size is more\nthan 200 MB. The first load typically takes more than a couple of minutes\nto fully decode the relevant part of the JPEG files into numpy arrays. 
If\nthe dataset has  been loaded once, the following times the loading times\nless than 200ms by using a memmapped version memoized on the disk in the\n~/scikit_learn_data/lfw_home/ folder using joblib.\n\nThe first loader is used for the Face Identification task: a multi-class\nclassification task (hence supervised learning)::\n\n  >>> from sklearn.datasets import fetch_lfw_people\n  >>> lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)\n\n  >>> for name in lfw_people.target_names:\n  ...     print(name)\n  ...\n  Ariel Sharon\n  Colin Powell\n  Donald Rumsfeld\n  George W Bush\n  Gerhard Schroeder\n  Hugo Chavez\n  Tony Blair\n\nThe default slice is a rectangular shape around the face, removing\nmost of the background::\n\n  >>> lfw_people.data.dtype\n  dtype('float32')\n\n  >>> lfw_people.data.shape\n  (1288, 1850)\n\n  >>> lfw_people.images.shape\n  (1288, 50, 37)\n\nEach of the 1140 faces is assigned to a single person id in the target\narray::\n\n  >>> lfw_people.target.shape\n  (1288,)\n\n  >>> list(lfw_people.target[:10])\n  [5, 6, 3, 1, 0, 1, 3, 4, 3, 0]\n\nThe second loader is typically used for the face verification task: each sample\nis a pair of two picture belonging or not to the same person::\n\n  >>> from sklearn.datasets import fetch_lfw_pairs\n  >>> lfw_pairs_train = fetch_lfw_pairs(subset='train')\n\n  >>> list(lfw_pairs_train.target_names)\n  ['Different persons', 'Same person']\n\n  >>> lfw_pairs_train.pairs.shape\n  (2200, 2, 62, 47)\n\n  >>> lfw_pairs_train.data.shape\n  (2200, 5828)\n\n  >>> lfw_pairs_train.target.shape\n  (2200,)\n\nBoth for the :func:sklearn.datasets.fetch_lfw_people and\n:func:sklearn.datasets.fetch_lfw_pairs function it is\npossible to get an additional dimension with the RGB color channels by\npassing color=True, in that case the shape will be\n(2200, 2, 62, 47, 3).\n\nThe :func:sklearn.datasets.fetch_lfw_pairs datasets is subdivided into\n3 subsets: the development train set, the development test set 
and\nan evaluation 10_folds set meant to compute performance metrics using a\n10-folds cross validation scheme.\n\n.. topic:: References:\n\n * Labeled Faces in the Wild: A Database for Studying Face Recognition\n   in Unconstrained Environments.\n   <http://vis-www.cs.umass.edu/lfw/lfw.pdf>_\n   Gary B. Huang, Manu Ramesh, Tamara Berg, and Erik Learned-Miller.\n   University of Massachusetts, Amherst, Technical Report 07-49, October, 2007.\n\n\nExamples\n~~~~~~~~\n\n:ref:sphx_glr_auto_examples_applications_plot_face_recognition.py\n"}
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(13233, 2914)
(13233, 62, 47)
(36, 2914)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_lfw_people
from sklearn.decomposition import PCA
import timeit

def plot_faces(faces, n_rows=6, n_cols=6, image_shape=(62, 47)):
    """Show flattened face images on an ``n_rows`` x ``n_cols`` grid.

    Each row of ``faces`` is reshaped to ``image_shape`` and drawn with the
    'bone' colormap, with tick marks hidden. The defaults reproduce a
    6 x 6 grid of 62 x 47 images, which is the per-image shape of the data
    returned by ``fetch_lfw_people()`` with default arguments.

    Parameters
    ----------
    faces : numpy.ndarray
        One flattened image per row; each row must be reshapeable to
        ``image_shape``.
    n_rows, n_cols : int
        Grid dimensions.
    image_shape : tuple of int
        (height, width) of a single image.

    If ``faces`` has fewer rows than grid cells, the remaining cells stay
    blank; rows beyond ``n_rows * n_cols`` are ignored.
    """
    # squeeze=False keeps ``axes`` a 2-D array even for a 1 x 1 grid, so
    # ``axes.flat`` always works.
    _, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10),
                           subplot_kw={'xticks': [], 'yticks': []},
                           gridspec_kw=dict(hspace=0.1, wspace=0.1),
                           squeeze=False)
    # zip() stops at the shorter sequence, so fewer faces than cells leaves
    # the rest blank instead of raising IndexError.
    for ax, face in zip(axes.flat, faces):
        ax.imshow(face.reshape(image_shape), cmap='bone')
    plt.show()


if __name__ == "__main__":
    faces = fetch_lfw_people()
    print(faces)
    # Print the keys of the returned dataset Bunch
    print(faces.keys())
    # Print (number of samples, number of features)
    print(faces.data.shape)
    # Print the shape of the per-image matrices
    print(faces.images.shape)

    # Shuffle the samples with a random permutation of their indices
    random_indexes = np.random.permutation(len(faces.data))
    X = faces.data[random_indexes]
    # First 36 shuffled faces (plotted in the previous step)
    example_faces = X[:36, :]
    print(example_faces.shape)

    # Eigenfaces: fit PCA on every face. With n_components left unset all
    # components are kept, so components_ has one row per feature.
    # NOTE(review): only 36 components are plotted below;
    # PCA(n_components=36, svd_solver='randomized') would likely fit much
    # faster if the full component matrix is not needed.
    pca = PCA(svd_solver='randomized')
    start_time = timeit.default_timer()
    pca.fit(X)
    # Fit time in seconds
    print(timeit.default_timer() - start_time)
    print(pca.components_.shape)
    # Draw the leading 36 principal components as images
    plot_faces(pca.components_[:36, :])

{'data': array([[ 34.      ,  29.333334,  22.333334, ...,  14.666667,  16.      ,
14.      ],
[158.      , 160.66667 , 169.66667 , ..., 138.66667 , 135.33333 ,
130.33333 ],
[ 77.      ,  81.333336,  88.      , ..., 192.      , 145.33333 ,
66.333336],
...,
[ 38.      ,  41.666668,  55.333332, ...,  66.      ,  63.666668,
54.333332],
[ 16.666666,  24.333334,  60.333332, ..., 219.      , 143.33333 ,
69.333336],
[ 58.333332,  48.      ,  20.      , ..., 116.      , 106.333336,
143.33333 ]], dtype=float32), 'images': array([[[ 34.      ,  29.333334,  22.333334, ...,  20.      ,
25.666666,  30.666666],
[ 37.333332,  32.      ,  25.333334, ...,  21.      ,
26.666666,  32.      ],
[ 33.333332,  32.333332,  40.333332, ...,  23.666666,
28.      ,  35.666668],
...,
[166.      ,  97.      ,  44.333332, ...,   9.666667,
14.333333,  12.333333],
[ 64.      ,  38.666668,  30.      , ...,  12.666667,
16.      ,  14.      ],
[ 30.666666,  29.      ,  26.333334, ...,  14.666667,
16.      ,  14.      ]],

[[158.      , 160.66667 , 169.66667 , ...,  74.333336,
28.      ,  15.666667],
[156.      , 155.33333 , 163.33333 , ...,  83.      ,
25.666666,  14.      ],
[146.66667 , 143.66667 , 144.66667 , ...,  82.333336,
26.      ,  14.666667],
...,
[118.666664, 120.      , 170.      , ..., 131.33333 ,
127.333336, 126.      ],
[125.      , 117.666664, 141.33333 , ..., 133.33333 ,
132.      , 129.33333 ],
[128.66667 , 122.666664, 121.666664, ..., 138.66667 ,
135.33333 , 130.33333 ]],

[[ 77.      ,  81.333336,  88.      , ...,  71.      ,
80.666664,  65.333336],
[ 77.666664,  89.      , 104.666664, ...,  75.666664,
72.      ,  69.      ],
[ 83.      ,  97.      , 117.      , ...,  80.333336,
80.333336,  66.      ],
...,
[ 13.333333,  16.      ,  32.      , ..., 183.66667 ,
131.      ,  58.333332],
[ 46.333332,  70.      ,  94.333336, ..., 188.66667 ,
143.66667 ,  65.      ],
[112.333336, 122.      , 118.      , ..., 192.      ,
145.33333 ,  66.333336]],

...,

[[ 38.      ,  41.666668,  55.333332, ...,  28.666666,
26.333334,  29.333334],
[ 46.666668,  49.333332,  60.666668, ...,  30.333334,
26.333334,  31.      ],
[ 50.333332,  53.666668,  63.      , ...,  32.333332,
27.666666,  34.333332],
...,
[ 71.      , 123.      , 204.66667 , ...,  58.333332,
47.      ,  38.      ],
[ 70.666664,  85.666664, 148.66667 , ...,  64.      ,
57.666668,  46.666668],
[ 70.333336,  71.666664,  97.333336, ...,  66.      ,
63.666668,  54.333332]],

[[ 16.666666,  24.333334,  60.333332, ..., 177.33333 ,
175.      , 174.33333 ],
[ 16.333334,  22.333334,  55.666668, ..., 177.33333 ,
174.66667 , 174.      ],
[ 18.333334,  25.666666,  61.      , ..., 175.33333 ,
172.66667 , 172.33333 ],
...,
[ 21.666666,  21.333334,  22.666666, ..., 221.33333 ,
132.66667 ,  52.      ],
[ 22.      ,  21.666666,  22.333334, ..., 219.      ,
137.      ,  59.      ],
[ 22.333334,  22.      ,  22.666666, ..., 219.      ,
143.33333 ,  69.333336]],

[[ 58.333332,  48.      ,  20.      , ...,  66.      ,
101.666664,  94.666664],
[ 62.      ,  32.666668,  26.333334, ...,  50.      ,
89.666664, 101.333336],
[ 56.333332,  29.333334,  47.      , ...,  55.333332,
76.666664, 106.333336],
...,
[116.333336, 106.333336,  95.      , ..., 113.333336,
100.333336,  88.      ],
[116.666664, 104.666664,  93.333336, ..., 115.666664,
103.666664, 112.      ],
[116.333336, 104.      ,  95.333336, ..., 116.      ,
106.333336, 143.33333 ]]], dtype=float32), 'target': array([5360, 3434, 3807, ..., 2175,  373, 2941]), 'target_names': array(['AJ Cook', 'AJ Lamas', 'Aaron Eckhart', ..., 'Zumrati Juma',
'Zurab Tsereteli', 'Zydrunas Ilgauskas'], dtype='<U35'), 'DESCR': ".. _labeled_faces_in_the_wild_dataset:\n\nThe Labeled Faces in the Wild face recognition dataset\n------------------------------------------------------\n\nThis dataset is a collection of JPEG pictures of famous people collected\nover the internet, all details are available on the official website:\n\n    http://vis-www.cs.umass.edu/lfw/\n\nEach picture is centered on a single face. The typical task is called\nFace Verification: given a pair of two pictures, a binary classifier\nmust predict whether the two images are from the same person.\n\nAn alternative task, Face Recognition or Face Identification is:\ngiven the picture of the face of an unknown person, identify the name\nof the person by referring to a gallery of previously seen pictures of\nidentified persons.\n\nBoth Face Verification and Face Recognition are tasks that are typically\nperformed on the output of a model trained to perform Face Detection. The\nmost popular model for Face Detection is called Viola-Jones and is\nimplemented in the OpenCV library. The LFW faces were extracted by this\nface detector from various online websites.\n\n**Data Set Characteristics:**\n\n    =================   =======================\n    Classes                                5749\n    Samples total                         13233\n    Dimensionality                         5828\n    Features            real, between 0 and 255\n    =================   =======================\n\nUsage\n~~~~~\n\nscikit-learn provides two loaders that will automatically download,\ncache, parse the metadata files, decode the jpeg and convert the\ninteresting slices into memmapped numpy arrays. This dataset size is more\nthan 200 MB. The first load typically takes more than a couple of minutes\nto fully decode the relevant part of the JPEG files into numpy arrays. 
If\nthe dataset has  been loaded once, the following times the loading times\nless than 200ms by using a memmapped version memoized on the disk in the\n~/scikit_learn_data/lfw_home/ folder using joblib.\n\nThe first loader is used for the Face Identification task: a multi-class\nclassification task (hence supervised learning)::\n\n  >>> from sklearn.datasets import fetch_lfw_people\n  >>> lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)\n\n  >>> for name in lfw_people.target_names:\n  ...     print(name)\n  ...\n  Ariel Sharon\n  Colin Powell\n  Donald Rumsfeld\n  George W Bush\n  Gerhard Schroeder\n  Hugo Chavez\n  Tony Blair\n\nThe default slice is a rectangular shape around the face, removing\nmost of the background::\n\n  >>> lfw_people.data.dtype\n  dtype('float32')\n\n  >>> lfw_people.data.shape\n  (1288, 1850)\n\n  >>> lfw_people.images.shape\n  (1288, 50, 37)\n\nEach of the 1140 faces is assigned to a single person id in the target\narray::\n\n  >>> lfw_people.target.shape\n  (1288,)\n\n  >>> list(lfw_people.target[:10])\n  [5, 6, 3, 1, 0, 1, 3, 4, 3, 0]\n\nThe second loader is typically used for the face verification task: each sample\nis a pair of two picture belonging or not to the same person::\n\n  >>> from sklearn.datasets import fetch_lfw_pairs\n  >>> lfw_pairs_train = fetch_lfw_pairs(subset='train')\n\n  >>> list(lfw_pairs_train.target_names)\n  ['Different persons', 'Same person']\n\n  >>> lfw_pairs_train.pairs.shape\n  (2200, 2, 62, 47)\n\n  >>> lfw_pairs_train.data.shape\n  (2200, 5828)\n\n  >>> lfw_pairs_train.target.shape\n  (2200,)\n\nBoth for the :func:sklearn.datasets.fetch_lfw_people and\n:func:sklearn.datasets.fetch_lfw_pairs function it is\npossible to get an additional dimension with the RGB color channels by\npassing color=True, in that case the shape will be\n(2200, 2, 62, 47, 3).\n\nThe :func:sklearn.datasets.fetch_lfw_pairs datasets is subdivided into\n3 subsets: the development train set, the development test set 
and\nan evaluation 10_folds set meant to compute performance metrics using a\n10-folds cross validation scheme.\n\n.. topic:: References:\n\n * Labeled Faces in the Wild: A Database for Studying Face Recognition\n   in Unconstrained Environments.\n   <http://vis-www.cs.umass.edu/lfw/lfw.pdf>_\n   Gary B. Huang, Manu Ramesh, Tamara Berg, and Erik Learned-Miller.\n   University of Massachusetts, Amherst, Technical Report 07-49, October, 2007.\n\n\nExamples\n~~~~~~~~\n\n:ref:sphx_glr_auto_examples_applications_plot_face_recognition.py\n"}
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(13233, 2914)
(13233, 62, 47)
(36, 2914)
12.20208079
(2914, 2914)

# Keep only the people who have at least 60 pictures in the dataset
faces = fetch_lfw_people(min_faces_per_person=60)
print(faces)
print(faces.keys())      # fields of the returned Bunch
print(faces.data.shape)  # (n_samples, n_features)
people = faces.target_names
print(people)            # names of the people that were kept
print(len(people))       # how many people were kept

{'data': array([[138.        , 135.66667   , 127.666664  , ...,   1.6666666 ,
1.6666666 ,   0.33333334],
[ 71.333336  ,  56.        ,  67.666664  , ..., 247.66667   ,
243.        , 238.33333   ],
[ 84.333336  ,  97.333336  ,  72.333336  , ..., 114.        ,
194.33333   , 241.        ],
...,
[ 29.333334  ,  29.        ,  29.333334  , ..., 145.        ,
147.        , 141.66667   ],
[ 49.333332  ,  55.666668  ,  76.666664  , ..., 186.33333   ,
176.33333   , 161.        ],
[ 31.        ,  26.333334  ,  28.        , ...,  34.        ,
42.        ,  69.666664  ]], dtype=float32), 'images': array([[[138.        , 135.66667   , 127.666664  , ...,  69.        ,
68.333336  ,  67.333336  ],
[146.        , 139.33333   , 125.        , ...,  68.333336  ,
67.666664  ,  67.333336  ],
[150.        , 138.33333   , 124.333336  , ...,  68.333336  ,
67.666664  ,  66.666664  ],
...,
[153.        , 174.        , 110.666664  , ...,   1.6666666 ,
0.6666667 ,   0.6666667 ],
[122.        , 193.        , 167.33333   , ...,   1.3333334 ,
1.6666666 ,   1.3333334 ],
[ 88.        , 177.33333   , 206.        , ...,   1.6666666 ,
1.6666666 ,   0.33333334]],

[[ 71.333336  ,  56.        ,  67.666664  , ...,  74.333336  ,
89.666664  ,  78.666664  ],
[ 64.333336  ,  61.666668  ,  84.333336  , ...,  72.        ,
87.        ,  78.666664  ],
[ 74.        ,  76.        ,  94.333336  , ...,  69.666664  ,
84.666664  ,  83.333336  ],
...,
[ 28.333334  ,  26.666666  ,  20.666666  , ..., 242.        ,
236.33333   , 232.        ],
[ 24.        ,  20.666666  ,  18.666666  , ..., 247.        ,
242.33333   , 238.33333   ],
[ 19.666666  ,  14.666667  ,  16.666666  , ..., 247.66667   ,
243.        , 238.33333   ]],

[[ 84.333336  ,  97.333336  ,  72.333336  , ...,  82.666664  ,
51.        ,  71.333336  ],
[ 98.333336  , 101.        ,  75.        , ..., 100.        ,
54.666668  ,  60.666668  ],
[104.666664  , 100.        ,  76.        , ..., 110.666664  ,
67.        ,  62.666668  ],
...,
[ 56.        ,  56.333332  ,  55.        , ...,  91.        ,
106.666664  , 204.66667   ],
[ 58.333332  ,  58.        ,  56.666668  , ...,  90.666664  ,
140.        , 226.        ],
[ 61.666668  ,  63.        ,  63.333332  , ..., 114.        ,
194.33333   , 241.        ]],

...,

[[ 29.333334  ,  29.        ,  29.333334  , ...,  85.333336  ,
80.333336  ,  77.        ],
[ 30.        ,  31.666666  ,  43.333332  , ...,  82.        ,
85.        ,  82.333336  ],
[ 35.333332  ,  42.        ,  72.        , ...,  85.666664  ,
83.        ,  87.        ],
...,
[ 59.333332  ,  57.333332  ,  56.666668  , ..., 145.        ,
143.33333   , 144.        ],
[ 59.333332  ,  58.        ,  58.        , ..., 146.33333   ,
143.66667   , 144.        ],
[ 61.666668  ,  60.333332  ,  59.666668  , ..., 145.        ,
147.        , 141.66667   ]],

[[ 49.333332  ,  55.666668  ,  76.666664  , ..., 160.        ,
158.33333   , 149.66667   ],
[ 55.666668  ,  68.        ,  93.333336  , ..., 156.        ,
153.        , 152.        ],
[ 61.333332  ,  76.        , 104.333336  , ..., 151.66667   ,
143.        , 146.33333   ],
...,
[ 60.333332  ,  60.333332  ,  61.333332  , ..., 178.33333   ,
169.        , 165.        ],
[ 60.666668  ,  61.333332  ,  62.666668  , ..., 188.33333   ,
172.        , 168.33333   ],
[ 61.        ,  61.333332  ,  61.333332  , ..., 186.33333   ,
176.33333   , 161.        ]],

[[ 31.        ,  26.333334  ,  28.        , ...,  65.333336  ,
49.        ,  47.666668  ],
[ 31.333334  ,  29.333334  ,  34.        , ...,  71.        ,
45.666668  ,  42.333332  ],
[ 33.333332  ,  32.333332  ,  33.333332  , ...,  84.333336  ,
52.333332  ,  45.666668  ],
...,
[ 44.666668  ,  42.666668  ,  44.666668  , ...,  22.333334  ,
25.333334  ,  46.333332  ],
[ 42.333332  ,  42.333332  ,  45.        , ...,  25.333334  ,
32.666668  ,  49.666668  ],
[ 46.        ,  49.333332  ,  51.666668  , ...,  34.        ,
42.        ,  69.666664  ]]], dtype=float32), 'target': array([1, 3, 3, ..., 7, 3, 5]), 'target_names': array(['Ariel Sharon', 'Colin Powell', 'Donald Rumsfeld', 'George W Bush',
'Gerhard Schroeder', 'Hugo Chavez', 'Junichiro Koizumi',
'Tony Blair'], dtype='<U17'), 'DESCR': ".. _labeled_faces_in_the_wild_dataset:\n\nThe Labeled Faces in the Wild face recognition dataset\n------------------------------------------------------\n\nThis dataset is a collection of JPEG pictures of famous people collected\nover the internet, all details are available on the official website:\n\n    http://vis-www.cs.umass.edu/lfw/\n\nEach picture is centered on a single face. The typical task is called\nFace Verification: given a pair of two pictures, a binary classifier\nmust predict whether the two images are from the same person.\n\nAn alternative task, Face Recognition or Face Identification is:\ngiven the picture of the face of an unknown person, identify the name\nof the person by referring to a gallery of previously seen pictures of\nidentified persons.\n\nBoth Face Verification and Face Recognition are tasks that are typically\nperformed on the output of a model trained to perform Face Detection. The\nmost popular model for Face Detection is called Viola-Jones and is\nimplemented in the OpenCV library. The LFW faces were extracted by this\nface detector from various online websites.\n\n**Data Set Characteristics:**\n\n    =================   =======================\n    Classes                                5749\n    Samples total                         13233\n    Dimensionality                         5828\n    Features            real, between 0 and 255\n    =================   =======================\n\nUsage\n~~~~~\n\nscikit-learn provides two loaders that will automatically download,\ncache, parse the metadata files, decode the jpeg and convert the\ninteresting slices into memmapped numpy arrays. This dataset size is more\nthan 200 MB. The first load typically takes more than a couple of minutes\nto fully decode the relevant part of the JPEG files into numpy arrays. 
If\nthe dataset has  been loaded once, the following times the loading times\nless than 200ms by using a memmapped version memoized on the disk in the\n~/scikit_learn_data/lfw_home/ folder using joblib.\n\nThe first loader is used for the Face Identification task: a multi-class\nclassification task (hence supervised learning)::\n\n  >>> from sklearn.datasets import fetch_lfw_people\n  >>> lfw_people = fetch_lfw_people(min_faces_per_person=70, resize=0.4)\n\n  >>> for name in lfw_people.target_names:\n  ...     print(name)\n  ...\n  Ariel Sharon\n  Colin Powell\n  Donald Rumsfeld\n  George W Bush\n  Gerhard Schroeder\n  Hugo Chavez\n  Tony Blair\n\nThe default slice is a rectangular shape around the face, removing\nmost of the background::\n\n  >>> lfw_people.data.dtype\n  dtype('float32')\n\n  >>> lfw_people.data.shape\n  (1288, 1850)\n\n  >>> lfw_people.images.shape\n  (1288, 50, 37)\n\nEach of the 1140 faces is assigned to a single person id in the target\narray::\n\n  >>> lfw_people.target.shape\n  (1288,)\n\n  >>> list(lfw_people.target[:10])\n  [5, 6, 3, 1, 0, 1, 3, 4, 3, 0]\n\nThe second loader is typically used for the face verification task: each sample\nis a pair of two picture belonging or not to the same person::\n\n  >>> from sklearn.datasets import fetch_lfw_pairs\n  >>> lfw_pairs_train = fetch_lfw_pairs(subset='train')\n\n  >>> list(lfw_pairs_train.target_names)\n  ['Different persons', 'Same person']\n\n  >>> lfw_pairs_train.pairs.shape\n  (2200, 2, 62, 47)\n\n  >>> lfw_pairs_train.data.shape\n  (2200, 5828)\n\n  >>> lfw_pairs_train.target.shape\n  (2200,)\n\nBoth for the :func:sklearn.datasets.fetch_lfw_people and\n:func:sklearn.datasets.fetch_lfw_pairs function it is\npossible to get an additional dimension with the RGB color channels by\npassing color=True, in that case the shape will be\n(2200, 2, 62, 47, 3).\n\nThe :func:sklearn.datasets.fetch_lfw_pairs datasets is subdivided into\n3 subsets: the development train set, the development test set 
and\nan evaluation 10_folds set meant to compute performance metrics using a\n10-folds cross validation scheme.\n\n.. topic:: References:\n\n * Labeled Faces in the Wild: A Database for Studying Face Recognition\n   in Unconstrained Environments.\n   <http://vis-www.cs.umass.edu/lfw/lfw.pdf>_\n   Gary B. Huang, Manu Ramesh, Tamara Berg, and Erik Learned-Miller.\n   University of Massachusetts, Amherst, Technical Report 07-49, October, 2007.\n\n\nExamples\n~~~~~~~~\n\n:ref:sphx_glr_auto_examples_applications_plot_face_recognition.py\n"}
dict_keys(['data', 'images', 'target', 'target_names', 'DESCR'])
(1348, 2914)
['Ariel Sharon' 'Colin Powell' 'Donald Rumsfeld' 'George W Bush'
'Gerhard Schroeder' 'Hugo Chavez' 'Junichiro Koizumi' 'Tony Blair']
8

# 多项式回归与模型泛化

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":
    # 100 inputs drawn uniformly from [-3, 3]
    x = np.random.uniform(-3, 3, size=100)
    # Column-vector form of the inputs (one feature per sample)
    X = x.reshape(-1, 1)
    # Quadratic target y = 0.5 x^2 + x + 2 with standard-normal noise
    noise = np.random.normal(0, 1, size=100)
    y = 0.5 * x ** 2 + x + 2 + noise
    plt.scatter(x, y)
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":
    # Noisy samples of the quadratic y = 0.5 x^2 + x + 2 on [-3, 3]
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    noise = np.random.normal(0, 1, size=100)
    y = 0.5 * x ** 2 + x + 2 + noise

    # Fit a straight line to the curved data
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    y_predict = lin_reg.predict(X)

    # Overlay the fitted line on the raw points
    plt.scatter(x, y)
    plt.plot(x, y_predict, color='r')
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":
    # Noisy samples of the quadratic y = 0.5 x^2 + x + 2 on [-3, 3]
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    noise = np.random.normal(0, 1, size=100)
    y = 0.5 * x ** 2 + x + 2 + noise

    # Straight-line baseline (fitted but not plotted in this step)
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    y_predict = lin_reg.predict(X)

    # Append x^2 as a second feature column: X2 = [x, x^2]
    X_sq = X ** 2
    print(X_sq.shape)
    X2 = np.hstack([X, X_sq])
    print(X2.shape)

    # An ordinary linear model on [x, x^2] can follow the curve
    lin_reg2 = LinearRegression().fit(X2, y)
    y_predict2 = lin_reg2.predict(X2)

    plt.scatter(x, y)
    # plt.plot joins points in array order, so draw against ascending x
    order = np.argsort(x)
    plt.plot(x[order], y_predict2[order], color='r')
    plt.show()

    print(lin_reg2.coef_)       # weights of [x, x^2]
    print(lin_reg2.intercept_)  # constant term


(100, 1)
(100, 2)
[1.00934827 0.50673415]
2.0050492749040076

scikit-learn中的多项式回归与Pipeline

scikit-learn将对原始数据升维（生成多项式特征）的操作封装在了sklearn.preprocessing模块的PolynomialFeatures类中

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures

if __name__ == "__main__":
    # Noisy samples of the quadratic y = 0.5 x^2 + x + 2 on [-3, 3]
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    noise = np.random.normal(0, 1, 100)
    y = 0.5 * x ** 2 + x + 2 + noise

    # Expand every sample to its degree-2 polynomial features
    poly = PolynomialFeatures(degree=2)
    X2 = poly.fit_transform(X)
    print(X2.shape)
    # Show the first five expanded rows
    print(X2[:5])


(100, 3)
[[ 1.          2.83592438  8.04246708]
[ 1.         -2.96465367  8.78917138]
[ 1.         -1.7820983   3.17587436]
[ 1.         -0.18299249  0.03348625]
[ 1.          0.59608116  0.35531275]]

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, 100)
    # Expand each sample x into the feature columns [1, x, x^2]
    poly = PolynomialFeatures(degree=2)
    poly.fit(X)
    X2 = poly.transform(X)
    print(X2.shape)
    # First 5 rows of the expanded features
    print(X2[: 5, :])
    # Ordinary least squares on the expanded features
    lin_reg = LinearRegression()
    lin_reg.fit(X2, y)
    y_predict = lin_reg.predict(X2)
    plt.scatter(x, y)
    # plt.plot joins points in array order, so draw the curve against x
    # sorted ascending, reordering the predictions to match
    plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    plt.show()
    # Report the fitted model. The constant column added by
    # PolynomialFeatures duplicates LinearRegression's own intercept, so its
    # coefficient comes out as 0 and the constant term is in intercept_.
    print(lin_reg.coef_)
    print(lin_reg.intercept_)


(100, 3)
[[ 1.          2.83592438  8.04246708]
[ 1.         -2.96465367  8.78917138]
[ 1.         -1.7820983   3.17587436]
[ 1.         -0.18299249  0.03348625]
[ 1.          0.59608116  0.35531275]]
[0.         0.97626103 0.48271288]
1.9847292103040368

# Build the integers 1..10 (np.arange excludes the stop value 11) and
# reshape them into 5 samples with 2 features each
X = np.arange(1, 11).reshape(-1, 2)
print(X.shape)
print(X)
# Degree-2 expansion of a sample (a, b) yields the 6 columns
# [1, a, b, a^2, a*b, b^2]
poly = PolynomialFeatures(degree=2)
poly.fit(X)
X2 = poly.transform(X)
print(X2.shape)
print(X2)

(5, 2)
[[ 1  2]
[ 3  4]
[ 5  6]
[ 7  8]
[ 9 10]]
(5, 6)
[[  1.   1.   2.   1.   2.   4.]
[  1.   3.   4.   9.  12.  16.]
[  1.   5.   6.  25.  30.  36.]
[  1.   7.   8.  49.  56.  64.]
[  1.   9.  10.  81.  90. 100.]]

# Build the integers 1..10 (np.arange excludes the stop value 11) and
# reshape them into 5 samples with 2 features each
X = np.arange(1, 11).reshape(-1, 2)
print(X.shape)
print(X)
# Degree-3 expansion of a sample (a, b) yields the 10 columns
# [1, a, b, a^2, a*b, b^2, a^3, a^2*b, a*b^2, b^3]
poly = PolynomialFeatures(degree=3)
poly.fit(X)
X3 = poly.transform(X)
print(X3.shape)
print(X3)

(5, 2)
[[ 1  2]
[ 3  4]
[ 5  6]
[ 7  8]
[ 9 10]]
(5, 10)
[[   1.    1.    2.    1.    2.    4.    1.    2.    4.    8.]
[   1.    3.    4.    9.   12.   16.   27.   36.   48.   64.]
[   1.    5.    6.   25.   30.   36.  125.  150.  180.  216.]
[   1.    7.    8.   49.   56.   64.  343.  392.  448.  512.]
[   1.    9.   10.   81.   90.  100.  729.  810.  900. 1000.]]

Pipeline(管道)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

if __name__ == "__main__":
    # Noisy samples of the quadratic y = 0.5 x^2 + x + 2 on [-3, 3]
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    noise = np.random.normal(0, 1, 100)
    y = 0.5 * x ** 2 + x + 2 + noise

    # Chain the steps: polynomial expansion -> standardization -> OLS fit
    steps = [
        ("poly", PolynomialFeatures(degree=2)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ]
    poly_reg = Pipeline(steps)
    poly_reg.fit(X, y)
    y_predict = poly_reg.predict(X)

    plt.scatter(x, y)
    # plt.plot joins points in array order, so draw against ascending x
    order = np.argsort(x)
    plt.plot(x[order], y_predict[order], color='r')
    plt.show()

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Fixed seed so the generated sample is reproducible between runs
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    # Quadratic target y = 0.5 x^2 + x + 2 with standard-normal noise
    noise = np.random.normal(0, 1, size=100)
    y = 0.5 * x ** 2 + x + 2 + noise
    plt.scatter(x, y)
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression

if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Fit a straight line to quadratic data to illustrate under-fitting.
    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    # R^2 of the fit on the training data.
    print(lin_reg.score(X, y))
    y_predict = lin_reg.predict(X)
    plt.scatter(x, y)
    # Sort by x so the fitted line is drawn left to right.
    plt.plot(np.sort(x), y_predict[np.argsort(x)], color='r')
    plt.show()

0.4953707811865009

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    lin_reg = LinearRegression()
    lin_reg.fit(X, y)
    y_predict = lin_reg.predict(X)
    # Mean squared error of the straight-line fit on the training data.
    print(mean_squared_error(y, y_predict))

3.0750025765636577

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Baseline: training MSE of a straight line.
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))

    # Degree 2 matches the generating function.
    poly2_reg = PolynomialRegression(degree=2)
    poly2_reg.fit(X, y)
    y2_predict = poly2_reg.predict(X)
    print(mean_squared_error(y, y2_predict))
    plt.scatter(x, y)
    plt.plot(np.sort(x), y2_predict[np.argsort(x)], color='r')
    plt.show()

3.0750025765636577
1.0987392142417856

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Training MSE for the linear and degree-2 models.
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))
    poly2_reg = PolynomialRegression(degree=2).fit(X, y)
    print(mean_squared_error(y, poly2_reg.predict(X)))

    # Degree 10 lowers training MSE further and bends to fit the noise.
    poly10_reg = PolynomialRegression(degree=10)
    poly10_reg.fit(X, y)
    y10_predict = poly10_reg.predict(X)
    print(mean_squared_error(y, y10_predict))
    plt.scatter(x, y)
    plt.plot(np.sort(x), y10_predict[np.argsort(x)], color='r')
    plt.show()

3.0750025765636577
1.0987392142417856
1.0508466763764148

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Training MSE for the linear, degree-2 and degree-10 models.
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))
    for degree in (2, 10):
        model = PolynomialRegression(degree=degree).fit(X, y)
        print(mean_squared_error(y, model.predict(X)))

    # Degree 100: lowest training MSE so far, i.e. the model memorizes noise.
    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    plt.scatter(x, y)
    plt.plot(np.sort(x), y100_predict[np.argsort(x)], color='r')
    plt.show()

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Training MSE for the linear, degree-2 and degree-10 models.
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))
    for degree in (2, 10):
        model = PolynomialRegression(degree=degree).fit(X, y)
        print(mean_squared_error(y, model.predict(X)))

    poly100_reg = PolynomialRegression(degree=100).fit(X, y)
    print(mean_squared_error(y, poly100_reg.predict(X)))

    # Predict on an evenly spaced grid over [-3, 3] so the curve also shows
    # how the model behaves between the training samples.
    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    plt.scatter(x, y)
    plt.plot(X_plot[:, 0], y_plot, color='r')
    # Restrict only the visible window; the predictions are not clipped.
    plt.axis([-3, 3, -1, 10])
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Training MSE keeps falling as the model gets more complex...
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))
    for degree in (2, 10, 100):
        model = PolynomialRegression(degree=degree).fit(X, y)
        print(mean_squared_error(y, model.predict(X)))

    # ...but that says nothing about unseen data, so hold out a test set.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression().fit(X_train, y_train)
    print(mean_squared_error(y_test, lin_reg.predict(X_test)))

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Training-set MSE for increasingly complex models.
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))
    for degree in (2, 10, 100):
        model = PolynomialRegression(degree=degree).fit(X, y)
        print(mean_squared_error(y, model.predict(X)))

    # Held-out MSE: fit on the training split, score on the test split.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression().fit(X_train, y_train)
    print(mean_squared_error(y_test, lin_reg.predict(X_test)))
    for degree in (2,):
        model = PolynomialRegression(degree=degree).fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573
0.8035641056297901

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Training-set MSE for increasingly complex models.
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))
    for degree in (2, 10, 100):
        model = PolynomialRegression(degree=degree).fit(X, y)
        print(mean_squared_error(y, model.predict(X)))

    # Held-out MSE: degree 10 already does worse than degree 2 on the test set.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression().fit(X_train, y_train)
    print(mean_squared_error(y_test, lin_reg.predict(X_test)))
    for degree in (2, 10):
        model = PolynomialRegression(degree=degree).fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573
0.8035641056297901
0.9212930722150788

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)

    # Training-set MSE keeps falling as model complexity grows.
    lin_reg = LinearRegression().fit(X, y)
    print(mean_squared_error(y, lin_reg.predict(X)))
    for degree in (2, 10, 100):
        model = PolynomialRegression(degree=degree).fit(X, y)
        print(mean_squared_error(y, model.predict(X)))

    # Held-out MSE: the degree-100 model generalizes terribly (over-fitting).
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=666)
    lin_reg = LinearRegression().fit(X_train, y_train)
    print(mean_squared_error(y_test, lin_reg.predict(X_test)))
    for degree in (2, 10, 100):
        model = PolynomialRegression(degree=degree).fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))

3.0750025765636577
1.0987392142417856
1.0508466763764148
0.6879768981520811
2.2199965269396573
0.8035641056297901
0.9212930722150788
11695751727.09301

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Same seeded quadratic data set used for the learning-curve experiments.
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    plt.scatter(x, y)
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    # With the default test_size (0.25), 75 of the 100 samples go to training.
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)
    print(X_train.shape)

(75, 1)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error


def plot_learning_curve(algo, X_train, X_test, y_train, y_test):
    """Plot train/test RMSE as the training set grows one sample at a time.

    For each i in 1..len(X_train), ``algo`` is refit on the first i training
    samples. Its RMSE on those same i samples and on the whole test set is
    recorded, and both curves are plotted. ``algo`` is left fitted on the
    full training set.
    """
    sizes = np.arange(1, len(X_train) + 1)
    train_score = []
    test_score = []
    for i in sizes:
        algo.fit(X_train[:i], y_train[:i])
        y_train_predict = algo.predict(X_train[:i])
        train_score.append(mean_squared_error(y_train[:i], y_train_predict))
        y_test_predict = algo.predict(X_test)
        test_score.append(mean_squared_error(y_test, y_test_predict))
    # sqrt(MSE) = RMSE, in the same units as y.
    plt.plot(sizes, np.sqrt(train_score), label="train")
    plt.plot(sizes, np.sqrt(test_score), label="test")
    plt.legend()
    plt.axis([0, len(X_train) + 1, 0, 4])
    plt.show()


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)

    plot_learning_curve(LinearRegression(), X_train, X_test, y_train, y_test)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler


def plot_learning_curve(algo, X_train, X_test, y_train, y_test):
    """Plot train/test RMSE as the training set grows one sample at a time.

    For each i in 1..len(X_train), ``algo`` is refit on the first i training
    samples. Its RMSE on those same i samples and on the whole test set is
    recorded, and both curves are plotted. ``algo`` is left fitted on the
    full training set.
    """
    sizes = np.arange(1, len(X_train) + 1)
    train_score = []
    test_score = []
    for i in sizes:
        algo.fit(X_train[:i], y_train[:i])
        y_train_predict = algo.predict(X_train[:i])
        train_score.append(mean_squared_error(y_train[:i], y_train_predict))
        y_test_predict = algo.predict(X_test)
        test_score.append(mean_squared_error(y_test, y_test_predict))
    # sqrt(MSE) = RMSE, in the same units as y.
    plt.plot(sizes, np.sqrt(train_score), label="train")
    plt.plot(sizes, np.sqrt(test_score), label="test")
    plt.legend()
    plt.axis([0, len(X_train) + 1, 0, 4])
    plt.show()


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)

    # Learning curve of the correctly specified (degree-2) model.
    poly2_reg = PolynomialRegression(degree=2)
    plot_learning_curve(poly2_reg, X_train, X_test, y_train, y_test)

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler


def plot_learning_curve(algo, X_train, X_test, y_train, y_test):
    """Plot train/test RMSE as the training set grows one sample at a time.

    For each i in 1..len(X_train), ``algo`` is refit on the first i training
    samples. Its RMSE on those same i samples and on the whole test set is
    recorded, and both curves are plotted. ``algo`` is left fitted on the
    full training set.
    """
    sizes = np.arange(1, len(X_train) + 1)
    train_score = []
    test_score = []
    for i in sizes:
        algo.fit(X_train[:i], y_train[:i])
        y_train_predict = algo.predict(X_train[:i])
        train_score.append(mean_squared_error(y_train[:i], y_train_predict))
        y_test_predict = algo.predict(X_test)
        test_score.append(mean_squared_error(y_test, y_test_predict))
    # sqrt(MSE) = RMSE, in the same units as y.
    plt.plot(sizes, np.sqrt(train_score), label="train")
    plt.plot(sizes, np.sqrt(test_score), label="test")
    plt.legend()
    plt.axis([0, len(X_train) + 1, 0, 4])
    plt.show()


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS."""
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(666)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x**2 + x + 2 + np.random.normal(0, 1, size=100)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=10)

    # Learning curve of an over-complex (degree-20) model.
    poly20_reg = PolynomialRegression(degree=20)
    plot_learning_curve(poly20_reg, X_train, X_test, y_train, y_test)

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

if __name__ == "__main__":
    # The original never loaded the data set, so `digits` was undefined.
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    # Split the handwritten-digits data into training and test sets.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)
    best_score, best_p, best_k = 0, 0, 0
    # Try k = 2..10 neighbours and keep the most accurate combination.
    for k in range(2, 11):
        # p is the Minkowski distance exponent:
        # p = 1 is Manhattan distance, p = 2 is Euclidean (straight-line) distance.
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            knn_clf.fit(X_train, y_train)
            # Classification accuracy on the test set.
            score = knn_clf.score(X_test, y_test)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)


Best K = 3
Best P = 4
Best Score = 0.9860917941585535

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score runs k-fold cross-validation and returns one score per fold.
from sklearn.model_selection import cross_val_score

if __name__ == "__main__":
    # The original never loaded the data set, so `digits` was undefined.
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)

    # Tuning (k, p) by test-set accuracy lets the test data leak into model
    # selection; cross-validation on the training set avoids that.
    best_score, best_p, best_k = 0, 0, 0
    # Try k = 2..10 neighbours.
    for k in range(2, 11):
        # p is the Minkowski distance exponent:
        # p = 1 is Manhattan distance, p = 2 is Euclidean distance.
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            knn_clf.fit(X_train, y_train)
            score = knn_clf.score(X_test, y_test)
            if score > best_score:
                best_score, best_p, best_k = score, p, k

    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)

    # Per-fold accuracy of a default KNN classifier on the training set.
    knn_clf = KNeighborsClassifier()
    print(cross_val_score(knn_clf, X_train, y_train))

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.99537037 0.98148148 0.97685185 0.97674419 0.97209302]

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score runs k-fold cross-validation and returns one score per fold.
from sklearn.model_selection import cross_val_score


def grid_search_knn(evaluate):
    """Search k in 2..10 and Minkowski p in 1..5 for a distance-weighted KNN.

    ``evaluate`` receives an unfitted classifier and returns its score.
    Returns ``(best_score, best_p, best_k)``. On ties, the first combination
    in (k, p) iteration order wins.
    """
    best_score, best_p, best_k = 0, 0, 0
    for k in range(2, 11):
        # p = 1 is Manhattan distance, p = 2 is Euclidean distance.
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            score = evaluate(knn_clf)
            if score > best_score:
                best_score, best_p, best_k = score, p, k
    return best_score, best_p, best_k


def print_best(best_score, best_p, best_k):
    """Print the winning hyper-parameters and their score."""
    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)


if __name__ == "__main__":
    # The original never loaded the data set, so `digits` was undefined.
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)

    # 1) Tune by test-set accuracy (lets the test data leak into model selection).
    print_best(*grid_search_knn(
        lambda clf: clf.fit(X_train, y_train).score(X_test, y_test)))

    # Per-fold accuracy of a default KNN classifier on the training set.
    print(cross_val_score(KNeighborsClassifier(), X_train, y_train))

    # 2) Tune by mean cross-validated accuracy on the training set only.
    print_best(*grid_search_knn(
        lambda clf: np.mean(cross_val_score(clf, X_train, y_train))))

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.99537037 0.98148148 0.97685185 0.97674419 0.97209302]
Best K = 2
Best P = 2
Best Score = 0.9851507321274763

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score runs k-fold cross-validation and returns one score per fold.
from sklearn.model_selection import cross_val_score


def grid_search_knn(evaluate):
    """Search k in 2..10 and Minkowski p in 1..5 for a distance-weighted KNN.

    ``evaluate`` receives an unfitted classifier and returns its score.
    Returns ``(best_score, best_p, best_k)``. On ties, the first combination
    in (k, p) iteration order wins.
    """
    best_score, best_p, best_k = 0, 0, 0
    for k in range(2, 11):
        # p = 1 is Manhattan distance, p = 2 is Euclidean distance.
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            score = evaluate(knn_clf)
            if score > best_score:
                best_score, best_p, best_k = score, p, k
    return best_score, best_p, best_k


def print_best(best_score, best_p, best_k):
    """Print the winning hyper-parameters and their score."""
    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)


if __name__ == "__main__":
    # The original never loaded the data set, so `digits` was undefined.
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)

    # 1) Tune by test-set accuracy (lets the test data leak into model selection).
    print_best(*grid_search_knn(
        lambda clf: clf.fit(X_train, y_train).score(X_test, y_test)))

    # Per-fold accuracy of a default KNN classifier on the training set.
    print(cross_val_score(KNeighborsClassifier(), X_train, y_train))

    # 2) Tune by mean cross-validated accuracy on the training set only.
    best_score, best_p, best_k = grid_search_knn(
        lambda clf: np.mean(cross_val_score(clf, X_train, y_train)))
    print_best(best_score, best_p, best_k)

    # Retrain with the hyper-parameters the CV search actually selected
    # (previously hard-coded as n_neighbors=2, p=2), then score on the
    # untouched test set.
    best_knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=best_k, p=best_p)
    best_knn_clf.fit(X_train, y_train)
    print(best_knn_clf.score(X_test, y_test))

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.99537037 0.98148148 0.97685185 0.97674419 0.97209302]
Best K = 2
Best P = 2
Best Score = 0.9851507321274763
0.980528511821975

import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
# cross_val_score runs k-fold cross-validation and returns one score per fold.
from sklearn.model_selection import cross_val_score


def grid_search_knn(evaluate):
    """Search k in 2..10 and Minkowski p in 1..5 for a distance-weighted KNN.

    ``evaluate`` receives an unfitted classifier and returns its score.
    Returns ``(best_score, best_p, best_k)``. On ties, the first combination
    in (k, p) iteration order wins.
    """
    best_score, best_p, best_k = 0, 0, 0
    for k in range(2, 11):
        # p = 1 is Manhattan distance, p = 2 is Euclidean distance.
        for p in range(1, 6):
            knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=k, p=p)
            score = evaluate(knn_clf)
            if score > best_score:
                best_score, best_p, best_k = score, p, k
    return best_score, best_p, best_k


def print_best(best_score, best_p, best_k):
    """Print the winning hyper-parameters and their score."""
    print("Best K =", best_k)
    print("Best P =", best_p)
    print("Best Score =", best_score)


if __name__ == "__main__":
    # The original never loaded the data set, so `digits` was undefined.
    digits = datasets.load_digits()
    X = digits.data
    y = digits.target
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=666)

    # 1) Tune by test-set accuracy (lets the test data leak into model selection).
    print_best(*grid_search_knn(
        lambda clf: clf.fit(X_train, y_train).score(X_test, y_test)))

    # Per-fold accuracy of a default KNN classifier, using 3-fold CV.
    # (The stray token `对` that stood here raised a NameError.)
    print(cross_val_score(KNeighborsClassifier(), X_train, y_train, cv=3))

    # 2) Tune by mean 3-fold cross-validated accuracy on the training set only.
    best_score, best_p, best_k = grid_search_knn(
        lambda clf: np.mean(cross_val_score(clf, X_train, y_train, cv=3)))
    print_best(best_score, best_p, best_k)

    # Retrain with the hyper-parameters the CV search actually selected
    # (previously hard-coded), then score on the untouched test set.
    best_knn_clf = KNeighborsClassifier(weights="distance", n_neighbors=best_k, p=best_p)
    best_knn_clf.fit(X_train, y_train)
    print(best_knn_clf.score(X_test, y_test))

Best K = 3
Best P = 4
Best Score = 0.9860917941585535
[0.98888889 0.97771588 0.96935933]
Best K = 2
Best P = 2
Best Score = 0.9833023831631073
0.980528511821975

解决过拟合（高方差）问题的常用手段：

1. 降低模型复杂度
2. 减少数据维度；降噪
3. 增加样本数
4. 使用验证集
5. 模型正则化

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Noisy linear data: y = 0.5*x + 3 + N(0, 1), used for the regularization demos.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)
    plt.scatter(x, y)
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error


def PolynomialRegression(degree):
    """Return a pipeline: polynomial features -> standardization -> OLS.

    Each call builds a fresh LinearRegression. The original closed over one
    shared instance, so every pipeline it built refit the same estimator.
    """
    return Pipeline([
        ("poly", PolynomialFeatures(degree=degree)),
        ("std_scaler", StandardScaler()),
        ("lin_reg", LinearRegression()),
    ])


if __name__ == "__main__":
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    poly100_reg = PolynomialRegression(degree=100)
    poly100_reg.fit(X, y)
    y100_predict = poly100_reg.predict(X)
    print(mean_squared_error(y, y100_predict))
    # Coefficients of the fitted regression step. Their enormous magnitudes
    # are the symptom that regularization penalizes.
    print(poly100_reg.named_steps["lin_reg"].coef_)

    X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
    y_plot = poly100_reg.predict(X_plot)
    plt.scatter(x, y)
    plt.plot(X_plot[:, 0], y_plot, color='r')
    plt.axis([-3, 3, 0, 10])
    plt.show()

0.3759549053775048
[ 1.42439073e+13 -2.12466748e+00  1.54479690e+02  1.67098517e+03
-1.68354221e+04 -1.67356966e+05  7.69671197e+05  7.19167159e+06
-1.91564362e+07 -1.73386970e+08  2.94158467e+08  2.65855494e+09
-2.98779993e+09 -2.77902518e+10  2.08205157e+10  2.06137067e+11
-1.00503399e+11 -1.10794836e+12  3.27932426e+11  4.33721479e+12
-6.51849432e+11 -1.22101693e+13  4.33621529e+11  2.37019684e+13
1.39932363e+12 -2.82060254e+13 -4.10258949e+12  1.18702565e+13
3.48399831e+12  1.55678622e+13  1.94003321e+12 -1.71368365e+13
-3.89936356e+12 -1.09652871e+13 -2.30837691e+12  1.49553243e+13
3.45115415e+12  1.30288924e+13  4.29214358e+12 -8.22504346e+12
-1.50375379e+12 -1.55744264e+13 -5.98634047e+12 -3.52006563e+12
-1.59409626e+12  1.15482917e+13  2.09128451e+12  1.42192391e+13
6.69760485e+12  6.94440143e+11  2.68881401e+12 -1.08847554e+13
-3.22254621e+12 -1.37622857e+13 -6.20116568e+12 -1.61659413e+12
-4.42220524e+12  5.80908289e+12 -4.00657127e+10  1.26136586e+13
6.42455657e+12  6.94601810e+12  6.53005938e+12  6.31892663e+11
3.57264555e+12 -8.75913024e+12 -2.97618352e+12 -1.03119202e+13
-5.67563631e+12 -6.78110437e+12 -7.42635518e+12  3.73270228e+11
-3.68184310e+12  7.60158227e+12  2.32388705e+12  1.13431444e+13
6.24335439e+12  7.08572883e+12  9.38808268e+12 -1.50553406e+12
4.16083097e+12 -7.73762602e+12 -2.31468569e+12 -1.00502179e+13
-7.35406158e+12 -6.36365989e+12 -8.63079902e+12  2.40655509e+12
-6.95260658e+12  9.67202712e+12  4.10574155e+12  6.90151222e+12
1.23520497e+13  1.40956708e+12  9.16331620e+12 -7.16918098e+12
-3.34748423e+12 -5.97986471e+12 -1.81238211e+13  4.96934673e+12
9.42782890e+12]


import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # Unpenalized degree-20 fit on linear data: expected to overfit the
    # training split, which shows up as a large test-set MSE.
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    plot_model(poly_reg)

167.9401085999025

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    def fit_and_report(model):
        """Fit *model* on the training split, print its test-set MSE, return it."""
        model.fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))
        return model

    # Baseline: unpenalized degree-20 polynomial regression (overfits).
    fit_and_report(PolynomialRegression(degree=20))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    def RidgeRegression(degree, alpha):
        """Return an unfitted ridge-regression pipeline.

        Same polynomial expansion and scaling as ``PolynomialRegression``, but
        the final estimator is ``Ridge(alpha=alpha)`` (L2-penalized least
        squares); a larger *alpha* shrinks the coefficients harder.
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha)),
        ])

    # Even a very small L2 penalty tames the degree-20 curve.
    ridge1_reg = fit_and_report(RidgeRegression(20, 0.0001))
    plot_model(ridge1_reg)

167.9401085999025
1.3233492754136291

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    def fit_and_report(model):
        """Fit *model* on the training split, print its test-set MSE, return it."""
        model.fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))
        return model

    # Baseline: unpenalized degree-20 polynomial regression (overfits).
    fit_and_report(PolynomialRegression(degree=20))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    def RidgeRegression(degree, alpha):
        """Return an unfitted ridge-regression pipeline.

        Same polynomial expansion and scaling as ``PolynomialRegression``, but
        the final estimator is ``Ridge(alpha=alpha)`` (L2-penalized least
        squares); a larger *alpha* shrinks the coefficients harder.
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha)),
        ])

    # Increasing penalty strengths; only the last model is plotted.
    fit_and_report(RidgeRegression(20, 0.0001))
    ridge2_reg = fit_and_report(RidgeRegression(20, 1))
    plot_model(ridge2_reg)

167.9401085999025
1.3233492754136291
1.1888759304218461

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    def fit_and_report(model):
        """Fit *model* on the training split, print its test-set MSE, return it."""
        model.fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))
        return model

    # Baseline: unpenalized degree-20 polynomial regression (overfits).
    fit_and_report(PolynomialRegression(degree=20))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    def RidgeRegression(degree, alpha):
        """Return an unfitted ridge-regression pipeline.

        Same polynomial expansion and scaling as ``PolynomialRegression``, but
        the final estimator is ``Ridge(alpha=alpha)`` (L2-penalized least
        squares); a larger *alpha* shrinks the coefficients harder.
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha)),
        ])

    # Increasing penalty strengths; only the last model is plotted.
    fit_and_report(RidgeRegression(20, 0.0001))
    fit_and_report(RidgeRegression(20, 1))
    ridge3_reg = fit_and_report(RidgeRegression(20, 100))
    plot_model(ridge3_reg)

167.9401085999025
1.3233492754136291
1.1888759304218461
1.3196456113086197

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Ridge

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    def fit_and_report(model):
        """Fit *model* on the training split, print its test-set MSE, return it."""
        model.fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))
        return model

    # Baseline: unpenalized degree-20 polynomial regression (overfits).
    fit_and_report(PolynomialRegression(degree=20))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    def RidgeRegression(degree, alpha):
        """Return an unfitted ridge-regression pipeline.

        Same polynomial expansion and scaling as ``PolynomialRegression``, but
        the final estimator is ``Ridge(alpha=alpha)`` (L2-penalized least
        squares); a larger *alpha* shrinks the coefficients harder.
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("ridge_reg", Ridge(alpha=alpha)),
        ])

    # Increasing penalty strengths; only the last model is plotted.
    fit_and_report(RidgeRegression(20, 0.0001))
    fit_and_report(RidgeRegression(20, 1))
    fit_and_report(RidgeRegression(20, 100))
    # An extreme penalty (10**12) shrinks every coefficient to almost zero,
    # so the fitted curve is close to a horizontal line.
    ridge4_reg = fit_and_report(RidgeRegression(20, 1e12))
    plot_model(ridge4_reg)

167.9401085999025
1.3233492754136291
1.1888759304218461
1.3196456113086197
1.8408939654674448

LASSO

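LASSO是Least Absolute Shrinkage and Selection Operator Regression的缩写。它同样是在损失函数中加入一个正则化项，与岭回归的区别在于正则项的形式：岭回归使用θ的平方和，LASSO使用θ的绝对值之和。两者的作用是一样的，都是为了让系数θ尽可能小。把损失函数改写为LASSO的形式，就是要使

$$J(\theta) = \mathrm{MSE}(y, \hat{y}; \theta) + \alpha \sum_{i=1}^{n} \lvert \theta_i \rvert$$

尽可能小，其中α是控制正则化强度的超参数。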

import numpy as np
import matplotlib.pyplot as plt

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # 2-D feature matrix form; unused in this plot-only script
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    # Visualize the raw samples.
    plt.scatter(x, y)
    plt.show()

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    # Unpenalized degree-20 fit on linear data: expected to overfit the
    # training split, which shows up as a large test-set MSE.
    poly_reg = PolynomialRegression(degree=20)
    poly_reg.fit(X_train, y_train)
    y_poly_predict = poly_reg.predict(X_test)
    print(mean_squared_error(y_test, y_poly_predict))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    plot_model(poly_reg)

167.9401085999025

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    def fit_and_report(model):
        """Fit *model* on the training split, print its test-set MSE, return it."""
        model.fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))
        return model

    # Baseline: unpenalized degree-20 polynomial regression (overfits).
    fit_and_report(PolynomialRegression(degree=20))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    def LassoRegression(degree, alpha):
        """Return an unfitted LASSO pipeline.

        Same polynomial expansion and scaling as ``PolynomialRegression``, but
        the final estimator is ``Lasso(alpha=alpha)`` (L1-penalized least
        squares); a larger *alpha* drives more coefficients exactly to zero.
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lasso_reg", Lasso(alpha=alpha)),
        ])

    lasso1_reg = fit_and_report(LassoRegression(20, 0.01))
    plot_model(lasso1_reg)

167.9401085999025
1.1496080843259966

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    def fit_and_report(model):
        """Fit *model* on the training split, print its test-set MSE, return it."""
        model.fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))
        return model

    # Baseline: unpenalized degree-20 polynomial regression (overfits).
    fit_and_report(PolynomialRegression(degree=20))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    def LassoRegression(degree, alpha):
        """Return an unfitted LASSO pipeline.

        Same polynomial expansion and scaling as ``PolynomialRegression``, but
        the final estimator is ``Lasso(alpha=alpha)`` (L1-penalized least
        squares); a larger *alpha* drives more coefficients exactly to zero.
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lasso_reg", Lasso(alpha=alpha)),
        ])

    # Increasing penalty strengths; only the last model is plotted.
    fit_and_report(LassoRegression(20, 0.01))
    lasso2_reg = fit_and_report(LassoRegression(20, 0.1))
    plot_model(lasso2_reg)

167.9401085999025
1.1496080843259966
1.1213911351818648


import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import PolynomialFeatures
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso

if __name__ == "__main__":
    # Synthetic 1-D regression data: 100 points with x ~ Uniform(-3, 3) and
    # y = 0.5 * x + 3 plus standard-normal noise, i.e. the true relation is linear.
    np.random.seed(42)
    x = np.random.uniform(-3, 3, size=100)
    X = x.reshape(-1, 1)  # scikit-learn estimators expect a 2-D feature matrix
    y = 0.5 * x + 3 + np.random.normal(0, 1, size=100)

    def PolynomialRegression(degree):
        """Return an unfitted polynomial-regression pipeline.

        Steps: ``PolynomialFeatures(degree)`` -> ``StandardScaler`` ->
        ordinary least squares (``LinearRegression``).
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lin_reg", LinearRegression()),
        ])

    # train_test_split is called without random_state, so it shuffles with
    # NumPy's global RNG; seeding it here makes the split reproducible.
    np.random.seed(666)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    def fit_and_report(model):
        """Fit *model* on the training split, print its test-set MSE, return it."""
        model.fit(X_train, y_train)
        print(mean_squared_error(y_test, model.predict(X_test)))
        return model

    # Baseline: unpenalized degree-20 polynomial regression (overfits).
    fit_and_report(PolynomialRegression(degree=20))

    def plot_model(model):
        """Scatter all (x, y) samples and overlay *model*'s prediction curve.

        The curve is evaluated on 100 evenly spaced points in [-3, 3]; the view
        is fixed to x in [-3, 3], y in [0, 6].
        """
        X_plot = np.linspace(-3, 3, 100).reshape(100, 1)
        y_plot = model.predict(X_plot)
        plt.scatter(x, y)
        plt.plot(X_plot[:, 0], y_plot, color='r')
        plt.axis([-3, 3, 0, 6])
        plt.show()

    def LassoRegression(degree, alpha):
        """Return an unfitted LASSO pipeline.

        Same polynomial expansion and scaling as ``PolynomialRegression``, but
        the final estimator is ``Lasso(alpha=alpha)`` (L1-penalized least
        squares); a larger *alpha* drives more coefficients exactly to zero.
        """
        return Pipeline([
            ("poly", PolynomialFeatures(degree=degree)),
            ("std_scaler", StandardScaler()),
            ("lasso_reg", Lasso(alpha=alpha)),
        ])

    # Increasing penalty strengths; only the last model is plotted.
    fit_and_report(LassoRegression(20, 0.01))
    fit_and_report(LassoRegression(20, 0.1))
    # A strong L1 penalty can zero out most coefficients, flattening the curve.
    lasso3_reg = fit_and_report(LassoRegression(20, 1))
    plot_model(lasso3_reg)

167.9401085999025
1.1496080843259966
1.1213911351818648
1.8408939659515595

LASSO(α=0.1)

LASSO在优化损失函数的过程中，倾向于使一部分θ值直接变为0，而不是像岭回归那样让θ都变为一个很小的值。所以LASSO可以用作特征选择：如果某些θ=0，就代表LASSO认为这些θ对应的特征完全没有用；而θ不等于0所对应的那些特征，就是LASSO认为有用的特征。

L1，L2和弹性网络

L1正则，L2正则

0
0 收藏

0 评论
0 收藏
0