机器学习 August 19, 2018

2-2 测试算法

Words count 27k Reading time 24 mins. Read count 0

测试算法

import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
# Load the Iris dataset: X is the feature matrix, y the integer class labels.
iris = datasets.load_iris()
X, y = iris.data, iris.target
X.shape
(150, 4)
# Shape of the label vector.
np.shape(y)
(150,)

train_test_split

# Show the labels as a single row (1 x n) for a compact display.
np.reshape(y, (1, -1))
array([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]])
# Stack X and y side by side (column-wise): the labels become the last column.
# Mixing float features with int labels yields a float array.
tatol_data = np.column_stack((X, y))
tatol_data
array([[5.1, 3.5, 1.4, 0.2, 0. ],
       [4.9, 3. , 1.4, 0.2, 0. ],
       [4.7, 3.2, 1.3, 0.2, 0. ],
       [4.6, 3.1, 1.5, 0.2, 0. ],
       [5. , 3.6, 1.4, 0.2, 0. ],
       [5.4, 3.9, 1.7, 0.4, 0. ],
       [4.6, 3.4, 1.4, 0.3, 0. ],
       [5. , 3.4, 1.5, 0.2, 0. ],
       [4.4, 2.9, 1.4, 0.2, 0. ],
       [4.9, 3.1, 1.5, 0.1, 0. ],
       [5.4, 3.7, 1.5, 0.2, 0. ],
       [4.8, 3.4, 1.6, 0.2, 0. ],
       [4.8, 3. , 1.4, 0.1, 0. ],
       [4.3, 3. , 1.1, 0.1, 0. ],
       [5.8, 4. , 1.2, 0.2, 0. ],
       [5.7, 4.4, 1.5, 0.4, 0. ],
       [5.4, 3.9, 1.3, 0.4, 0. ],
       [5.1, 3.5, 1.4, 0.3, 0. ],
       [5.7, 3.8, 1.7, 0.3, 0. ],
       [5.1, 3.8, 1.5, 0.3, 0. ],
       [5.4, 3.4, 1.7, 0.2, 0. ],
       [5.1, 3.7, 1.5, 0.4, 0. ],
       [4.6, 3.6, 1. , 0.2, 0. ],
       [5.1, 3.3, 1.7, 0.5, 0. ],
       [4.8, 3.4, 1.9, 0.2, 0. ],
       [5. , 3. , 1.6, 0.2, 0. ],
       [5. , 3.4, 1.6, 0.4, 0. ],
       [5.2, 3.5, 1.5, 0.2, 0. ],
       [5.2, 3.4, 1.4, 0.2, 0. ],
       [4.7, 3.2, 1.6, 0.2, 0. ],
       [4.8, 3.1, 1.6, 0.2, 0. ],
       [5.4, 3.4, 1.5, 0.4, 0. ],
       [5.2, 4.1, 1.5, 0.1, 0. ],
       [5.5, 4.2, 1.4, 0.2, 0. ],
       [4.9, 3.1, 1.5, 0.1, 0. ],
       [5. , 3.2, 1.2, 0.2, 0. ],
       [5.5, 3.5, 1.3, 0.2, 0. ],
       [4.9, 3.1, 1.5, 0.1, 0. ],
       [4.4, 3. , 1.3, 0.2, 0. ],
       [5.1, 3.4, 1.5, 0.2, 0. ],
       [5. , 3.5, 1.3, 0.3, 0. ],
       [4.5, 2.3, 1.3, 0.3, 0. ],
       [4.4, 3.2, 1.3, 0.2, 0. ],
       [5. , 3.5, 1.6, 0.6, 0. ],
       [5.1, 3.8, 1.9, 0.4, 0. ],
       [4.8, 3. , 1.4, 0.3, 0. ],
       [5.1, 3.8, 1.6, 0.2, 0. ],
       [4.6, 3.2, 1.4, 0.2, 0. ],
       [5.3, 3.7, 1.5, 0.2, 0. ],
       [5. , 3.3, 1.4, 0.2, 0. ],
       [7. , 3.2, 4.7, 1.4, 1. ],
       [6.4, 3.2, 4.5, 1.5, 1. ],
       [6.9, 3.1, 4.9, 1.5, 1. ],
       [5.5, 2.3, 4. , 1.3, 1. ],
       [6.5, 2.8, 4.6, 1.5, 1. ],
       [5.7, 2.8, 4.5, 1.3, 1. ],
       [6.3, 3.3, 4.7, 1.6, 1. ],
       [4.9, 2.4, 3.3, 1. , 1. ],
       [6.6, 2.9, 4.6, 1.3, 1. ],
       [5.2, 2.7, 3.9, 1.4, 1. ],
       [5. , 2. , 3.5, 1. , 1. ],
       [5.9, 3. , 4.2, 1.5, 1. ],
       [6. , 2.2, 4. , 1. , 1. ],
       [6.1, 2.9, 4.7, 1.4, 1. ],
       [5.6, 2.9, 3.6, 1.3, 1. ],
       [6.7, 3.1, 4.4, 1.4, 1. ],
       [5.6, 3. , 4.5, 1.5, 1. ],
       [5.8, 2.7, 4.1, 1. , 1. ],
       [6.2, 2.2, 4.5, 1.5, 1. ],
       [5.6, 2.5, 3.9, 1.1, 1. ],
       [5.9, 3.2, 4.8, 1.8, 1. ],
       [6.1, 2.8, 4. , 1.3, 1. ],
       [6.3, 2.5, 4.9, 1.5, 1. ],
       [6.1, 2.8, 4.7, 1.2, 1. ],
       [6.4, 2.9, 4.3, 1.3, 1. ],
       [6.6, 3. , 4.4, 1.4, 1. ],
       [6.8, 2.8, 4.8, 1.4, 1. ],
       [6.7, 3. , 5. , 1.7, 1. ],
       [6. , 2.9, 4.5, 1.5, 1. ],
       [5.7, 2.6, 3.5, 1. , 1. ],
       [5.5, 2.4, 3.8, 1.1, 1. ],
       [5.5, 2.4, 3.7, 1. , 1. ],
       [5.8, 2.7, 3.9, 1.2, 1. ],
       [6. , 2.7, 5.1, 1.6, 1. ],
       [5.4, 3. , 4.5, 1.5, 1. ],
       [6. , 3.4, 4.5, 1.6, 1. ],
       [6.7, 3.1, 4.7, 1.5, 1. ],
       [6.3, 2.3, 4.4, 1.3, 1. ],
       [5.6, 3. , 4.1, 1.3, 1. ],
       [5.5, 2.5, 4. , 1.3, 1. ],
       [5.5, 2.6, 4.4, 1.2, 1. ],
       [6.1, 3. , 4.6, 1.4, 1. ],
       [5.8, 2.6, 4. , 1.2, 1. ],
       [5. , 2.3, 3.3, 1. , 1. ],
       [5.6, 2.7, 4.2, 1.3, 1. ],
       [5.7, 3. , 4.2, 1.2, 1. ],
       [5.7, 2.9, 4.2, 1.3, 1. ],
       [6.2, 2.9, 4.3, 1.3, 1. ],
       [5.1, 2.5, 3. , 1.1, 1. ],
       [5.7, 2.8, 4.1, 1.3, 1. ],
       [6.3, 3.3, 6. , 2.5, 2. ],
       [5.8, 2.7, 5.1, 1.9, 2. ],
       [7.1, 3. , 5.9, 2.1, 2. ],
       [6.3, 2.9, 5.6, 1.8, 2. ],
       [6.5, 3. , 5.8, 2.2, 2. ],
       [7.6, 3. , 6.6, 2.1, 2. ],
       [4.9, 2.5, 4.5, 1.7, 2. ],
       [7.3, 2.9, 6.3, 1.8, 2. ],
       [6.7, 2.5, 5.8, 1.8, 2. ],
       [7.2, 3.6, 6.1, 2.5, 2. ],
       [6.5, 3.2, 5.1, 2. , 2. ],
       [6.4, 2.7, 5.3, 1.9, 2. ],
       [6.8, 3. , 5.5, 2.1, 2. ],
       [5.7, 2.5, 5. , 2. , 2. ],
       [5.8, 2.8, 5.1, 2.4, 2. ],
       [6.4, 3.2, 5.3, 2.3, 2. ],
       [6.5, 3. , 5.5, 1.8, 2. ],
       [7.7, 3.8, 6.7, 2.2, 2. ],
       [7.7, 2.6, 6.9, 2.3, 2. ],
       [6. , 2.2, 5. , 1.5, 2. ],
       [6.9, 3.2, 5.7, 2.3, 2. ],
       [5.6, 2.8, 4.9, 2. , 2. ],
       [7.7, 2.8, 6.7, 2. , 2. ],
       [6.3, 2.7, 4.9, 1.8, 2. ],
       [6.7, 3.3, 5.7, 2.1, 2. ],
       [7.2, 3.2, 6. , 1.8, 2. ],
       [6.2, 2.8, 4.8, 1.8, 2. ],
       [6.1, 3. , 4.9, 1.8, 2. ],
       [6.4, 2.8, 5.6, 2.1, 2. ],
       [7.2, 3. , 5.8, 1.6, 2. ],
       [7.4, 2.8, 6.1, 1.9, 2. ],
       [7.9, 3.8, 6.4, 2. , 2. ],
       [6.4, 2.8, 5.6, 2.2, 2. ],
       [6.3, 2.8, 5.1, 1.5, 2. ],
       [6.1, 2.6, 5.6, 1.4, 2. ],
       [7.7, 3. , 6.1, 2.3, 2. ],
       [6.3, 3.4, 5.6, 2.4, 2. ],
       [6.4, 3.1, 5.5, 1.8, 2. ],
       [6. , 3. , 4.8, 1.8, 2. ],
       [6.9, 3.1, 5.4, 2.1, 2. ],
       [6.7, 3.1, 5.6, 2.4, 2. ],
       [6.9, 3.1, 5.1, 2.3, 2. ],
       [5.8, 2.7, 5.1, 1.9, 2. ],
       [6.8, 3.2, 5.9, 2.3, 2. ],
       [6.7, 3.3, 5.7, 2.5, 2. ],
       [6.7, 3. , 5.2, 2.3, 2. ],
       [6.3, 2.5, 5. , 1.9, 2. ],
       [6.5, 3. , 5.2, 2. , 2. ],
       [6.2, 3.4, 5.4, 2.3, 2. ],
       [5.9, 3. , 5.1, 1.8, 2. ]])
# Shuffle the rows of the combined data in place.
# Use np.random.shuffle, which permutes a 2-D ndarray along its first axis.
# The stdlib random.shuffle must NOT be used on an ndarray: its swap
# `x[i], x[j] = x[j], x[i]` works on row *views*, so one row overwrites another
# and the result contains duplicated rows while other rows are lost.
np.random.shuffle(tatol_data)
tatol_data
array([[5.1, 3.5, 1.4, 0.2, 0. ],
       [5.1, 3.5, 1.4, 0.2, 0. ],
       [4.9, 3. , 1.4, 0.2, 0. ],
       [4.9, 3. , 1.4, 0.2, 0. ],
       [5.1, 3.5, 1.4, 0.2, 0. ],
       [4.6, 3.1, 1.5, 0.2, 0. ],
       [4.6, 3.1, 1.5, 0.2, 0. ],
       [5. , 3.4, 1.5, 0.2, 0. ],
       [4.6, 3.4, 1.4, 0.3, 0. ],
       [5.1, 3.5, 1.4, 0.2, 0. ],
       [5.4, 3.7, 1.5, 0.2, 0. ],
       [5.4, 3.9, 1.7, 0.4, 0. ],
       [5. , 3.4, 1.5, 0.2, 0. ],
       [5.4, 3.9, 1.7, 0.4, 0. ],
       [4.8, 3. , 1.4, 0.1, 0. ],
       [4.8, 3.4, 1.6, 0.2, 0. ],
       [5.7, 4.4, 1.5, 0.4, 0. ],
       [4.8, 3. , 1.4, 0.1, 0. ],
       [5.4, 3.9, 1.7, 0.4, 0. ],
       [5.1, 3.5, 1.4, 0.2, 0. ],
       [5.4, 3.7, 1.5, 0.2, 0. ],
       [4.6, 3.4, 1.4, 0.3, 0. ],
       [5.4, 3.4, 1.7, 0.2, 0. ],
       [5.7, 4.4, 1.5, 0.4, 0. ],
       [4.6, 3.4, 1.4, 0.3, 0. ],
       [4.4, 2.9, 1.4, 0.2, 0. ],
       [5.1, 3.3, 1.7, 0.5, 0. ],
       [4.8, 3.4, 1.6, 0.2, 0. ],
       [4.6, 3.4, 1.4, 0.3, 0. ],
       [5.1, 3.3, 1.7, 0.5, 0. ],
       [5.2, 3.5, 1.5, 0.2, 0. ],
       [5. , 3. , 1.6, 0.2, 0. ],
       [4.8, 3.1, 1.6, 0.2, 0. ],
       [5.1, 3.8, 1.5, 0.3, 0. ],
       [4.8, 3. , 1.4, 0.1, 0. ],
       [5. , 3.6, 1.4, 0.2, 0. ],
       [4.6, 3.6, 1. , 0.2, 0. ],
       [4.8, 3. , 1.4, 0.1, 0. ],
       [5.1, 3.5, 1.4, 0.2, 0. ],
       [5.7, 4.4, 1.5, 0.4, 0. ],
       [5. , 3.4, 1.5, 0.2, 0. ],
       [4.4, 2.9, 1.4, 0.2, 0. ],
       [5.1, 3.8, 1.5, 0.3, 0. ],
       [5.1, 3.5, 1.4, 0.2, 0. ],
       [5.1, 3.3, 1.7, 0.5, 0. ],
       [4.8, 3.4, 1.9, 0.2, 0. ],
       [5.4, 3.9, 1.7, 0.4, 0. ],
       [5.2, 3.5, 1.5, 0.2, 0. ],
       [4.7, 3.2, 1.6, 0.2, 0. ],
       [5.5, 3.5, 1.3, 0.2, 0. ],
       [4.6, 3.2, 1.4, 0.2, 0. ],
       [4.7, 3.2, 1.6, 0.2, 0. ],
       [5. , 3.4, 1.5, 0.2, 0. ],
       [4.6, 3.4, 1.4, 0.3, 0. ],
       [5.4, 3.4, 1.5, 0.4, 0. ],
       [5.5, 4.2, 1.4, 0.2, 0. ],
       [5. , 3.5, 1.3, 0.3, 0. ],
       [6.5, 2.8, 4.6, 1.5, 1. ],
       [4.6, 3.6, 1. , 0.2, 0. ],
       [5. , 3.5, 1.3, 0.3, 0. ],
       [6.4, 3.2, 4.5, 1.5, 1. ],
       [5.3, 3.7, 1.5, 0.2, 0. ],
       [6. , 2.2, 4. , 1. , 1. ],
       [4.8, 3.1, 1.6, 0.2, 0. ],
       [6. , 2.2, 4. , 1. , 1. ],
       [5.7, 2.8, 4.5, 1.3, 1. ],
       [4.9, 3.1, 1.5, 0.1, 0. ],
       [4.7, 3.2, 1.6, 0.2, 0. ],
       [5. , 3.3, 1.4, 0.2, 0. ],
       [4.3, 3. , 1.1, 0.1, 0. ],
       [4.9, 3. , 1.4, 0.2, 0. ],
       [5.4, 3.4, 1.7, 0.2, 0. ],
       [5.5, 4.2, 1.4, 0.2, 0. ],
       [6.5, 2.8, 4.6, 1.5, 1. ],
       [4.8, 3.1, 1.6, 0.2, 0. ],
       [5.9, 3.2, 4.8, 1.8, 1. ],
       [7. , 3.2, 4.7, 1.4, 1. ],
       [6.7, 3.1, 4.4, 1.4, 1. ],
       [6.8, 2.8, 4.8, 1.4, 1. ],
       [4.8, 3.4, 1.9, 0.2, 0. ],
       [6.2, 2.2, 4.5, 1.5, 1. ],
       [5.7, 4.4, 1.5, 0.4, 0. ],
       [5.6, 2.9, 3.6, 1.3, 1. ],
       [5.8, 2.7, 3.9, 1.2, 1. ],
       [4.4, 2.9, 1.4, 0.2, 0. ],
       [6.6, 3. , 4.4, 1.4, 1. ],
       [5.1, 3.7, 1.5, 0.4, 0. ],
       [4.9, 3. , 1.4, 0.2, 0. ],
       [4.6, 3.6, 1. , 0.2, 0. ],
       [5. , 3.2, 1.2, 0.2, 0. ],
       [4.4, 2.9, 1.4, 0.2, 0. ],
       [5.1, 3.8, 1.5, 0.3, 0. ],
       [5.1, 3.4, 1.5, 0.2, 0. ],
       [5.8, 2.7, 3.9, 1.2, 1. ],
       [5. , 2. , 3.5, 1. , 1. ],
       [7. , 3.2, 4.7, 1.4, 1. ],
       [4.7, 3.2, 1.6, 0.2, 0. ],
       [5.8, 2.7, 3.9, 1.2, 1. ],
       [4.6, 3.2, 1.4, 0.2, 0. ],
       [5.6, 2.9, 3.6, 1.3, 1. ],
       [4.7, 3.2, 1.6, 0.2, 0. ],
       [5.8, 2.6, 4. , 1.2, 1. ],
       [6.2, 2.2, 4.5, 1.5, 1. ],
       [5.7, 2.6, 3.5, 1. , 1. ],
       [5.1, 3.4, 1.5, 0.2, 0. ],
       [6.8, 2.8, 4.8, 1.4, 1. ],
       [6.1, 2.9, 4.7, 1.4, 1. ],
       [5.9, 3. , 4.2, 1.5, 1. ],
       [5.5, 2.3, 4. , 1.3, 1. ],
       [5. , 2.3, 3.3, 1. , 1. ],
       [6.6, 2.9, 4.6, 1.3, 1. ],
       [6.3, 3.3, 4.7, 1.6, 1. ],
       [5.5, 2.5, 4. , 1.3, 1. ],
       [5.6, 2.9, 3.6, 1.3, 1. ],
       [5.4, 3.9, 1.7, 0.4, 0. ],
       [5.7, 2.9, 4.2, 1.3, 1. ],
       [5.7, 2.8, 4.5, 1.3, 1. ],
       [6.3, 2.9, 5.6, 1.8, 2. ],
       [5.7, 2.9, 4.2, 1.3, 1. ],
       [6.3, 3.3, 4.7, 1.6, 1. ],
       [6. , 2.2, 5. , 1.5, 2. ],
       [5.5, 2.4, 3.8, 1.1, 1. ],
       [6.5, 3. , 5.5, 1.8, 2. ],
       [6.2, 2.9, 4.3, 1.3, 1. ],
       [5.2, 3.5, 1.5, 0.2, 0. ],
       [7.3, 2.9, 6.3, 1.8, 2. ],
       [4.6, 3.6, 1. , 0.2, 0. ],
       [6.3, 3.3, 6. , 2.5, 2. ],
       [6.7, 3.3, 5.7, 2.1, 2. ],
       [5.6, 3. , 4.5, 1.5, 1. ],
       [6.2, 2.9, 4.3, 1.3, 1. ],
       [6.3, 2.5, 4.9, 1.5, 1. ],
       [4.8, 3. , 1.4, 0.3, 0. ],
       [6. , 3.4, 4.5, 1.6, 1. ],
       [5. , 3.4, 1.5, 0.2, 0. ],
       [4.8, 3. , 1.4, 0.1, 0. ],
       [5.1, 3.5, 1.4, 0.2, 0. ],
       [5.7, 3. , 4.2, 1.2, 1. ],
       [5.2, 2.7, 3.9, 1.4, 1. ],
       [6.3, 3.3, 4.7, 1.6, 1. ],
       [4.6, 3.1, 1.5, 0.2, 0. ],
       [6.9, 3.1, 4.9, 1.5, 1. ],
       [6.7, 3.3, 5.7, 2.1, 2. ],
       [5.8, 2.7, 5.1, 1.9, 2. ],
       [7.9, 3.8, 6.4, 2. , 2. ],
       [6.2, 2.8, 4.8, 1.8, 2. ],
       [6.5, 3. , 5.8, 2.2, 2. ],
       [6.3, 3.3, 6. , 2.5, 2. ],
       [5.7, 3. , 4.2, 1.2, 1. ],
       [5.4, 3.4, 1.7, 0.2, 0. ]])
# Split off the last column: every column but the last is a feature,
# the last column (kept 2-D) holds the labels.
X_train = tatol_data[:, :-1]
y_train = tatol_data[:, -1:]
X_train
array([[5.1, 3.5, 1.4, 0.2],
       [5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [5.1, 3.5, 1.4, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.4, 1.5, 0.2],
       [4.6, 3.4, 1.4, 0.3],
       [5.1, 3.5, 1.4, 0.2],
       [5.4, 3.7, 1.5, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [5. , 3.4, 1.5, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.8, 3. , 1.4, 0.1],
       [4.8, 3.4, 1.6, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [4.8, 3. , 1.4, 0.1],
       [5.4, 3.9, 1.7, 0.4],
       [5.1, 3.5, 1.4, 0.2],
       [5.4, 3.7, 1.5, 0.2],
       [4.6, 3.4, 1.4, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [4.4, 2.9, 1.4, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.6, 0.2],
       [4.6, 3.4, 1.4, 0.3],
       [5.1, 3.3, 1.7, 0.5],
       [5.2, 3.5, 1.5, 0.2],
       [5. , 3. , 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.1, 3.8, 1.5, 0.3],
       [4.8, 3. , 1.4, 0.1],
       [5. , 3.6, 1.4, 0.2],
       [4.6, 3.6, 1. , 0.2],
       [4.8, 3. , 1.4, 0.1],
       [5.1, 3.5, 1.4, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [5.1, 3.8, 1.5, 0.3],
       [5.1, 3.5, 1.4, 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [5. , 3.4, 1.5, 0.2],
       [4.6, 3.4, 1.4, 0.3],
       [5.4, 3.4, 1.5, 0.4],
       [5.5, 4.2, 1.4, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [6.5, 2.8, 4.6, 1.5],
       [4.6, 3.6, 1. , 0.2],
       [5. , 3.5, 1.3, 0.3],
       [6.4, 3.2, 4.5, 1.5],
       [5.3, 3.7, 1.5, 0.2],
       [6. , 2.2, 4. , 1. ],
       [4.8, 3.1, 1.6, 0.2],
       [6. , 2.2, 4. , 1. ],
       [5.7, 2.8, 4.5, 1.3],
       [4.9, 3.1, 1.5, 0.1],
       [4.7, 3.2, 1.6, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [4.3, 3. , 1.1, 0.1],
       [4.9, 3. , 1.4, 0.2],
       [5.4, 3.4, 1.7, 0.2],
       [5.5, 4.2, 1.4, 0.2],
       [6.5, 2.8, 4.6, 1.5],
       [4.8, 3.1, 1.6, 0.2],
       [5.9, 3.2, 4.8, 1.8],
       [7. , 3.2, 4.7, 1.4],
       [6.7, 3.1, 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [4.8, 3.4, 1.9, 0.2],
       [6.2, 2.2, 4.5, 1.5],
       [5.7, 4.4, 1.5, 0.4],
       [5.6, 2.9, 3.6, 1.3],
       [5.8, 2.7, 3.9, 1.2],
       [4.4, 2.9, 1.4, 0.2],
       [6.6, 3. , 4.4, 1.4],
       [5.1, 3.7, 1.5, 0.4],
       [4.9, 3. , 1.4, 0.2],
       [4.6, 3.6, 1. , 0.2],
       [5. , 3.2, 1.2, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [5.1, 3.8, 1.5, 0.3],
       [5.1, 3.4, 1.5, 0.2],
       [5.8, 2.7, 3.9, 1.2],
       [5. , 2. , 3.5, 1. ],
       [7. , 3.2, 4.7, 1.4],
       [4.7, 3.2, 1.6, 0.2],
       [5.8, 2.7, 3.9, 1.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.6, 2.9, 3.6, 1.3],
       [4.7, 3.2, 1.6, 0.2],
       [5.8, 2.6, 4. , 1.2],
       [6.2, 2.2, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.1, 3.4, 1.5, 0.2],
       [6.8, 2.8, 4.8, 1.4],
       [6.1, 2.9, 4.7, 1.4],
       [5.9, 3. , 4.2, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [5. , 2.3, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [5.5, 2.5, 4. , 1.3],
       [5.6, 2.9, 3.6, 1.3],
       [5.4, 3.9, 1.7, 0.4],
       [5.7, 2.9, 4.2, 1.3],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 2.9, 5.6, 1.8],
       [5.7, 2.9, 4.2, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [6. , 2.2, 5. , 1.5],
       [5.5, 2.4, 3.8, 1.1],
       [6.5, 3. , 5.5, 1.8],
       [6.2, 2.9, 4.3, 1.3],
       [5.2, 3.5, 1.5, 0.2],
       [7.3, 2.9, 6.3, 1.8],
       [4.6, 3.6, 1. , 0.2],
       [6.3, 3.3, 6. , 2.5],
       [6.7, 3.3, 5.7, 2.1],
       [5.6, 3. , 4.5, 1.5],
       [6.2, 2.9, 4.3, 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [4.8, 3. , 1.4, 0.3],
       [6. , 3.4, 4.5, 1.6],
       [5. , 3.4, 1.5, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [5.1, 3.5, 1.4, 0.2],
       [5.7, 3. , 4.2, 1.2],
       [5.2, 2.7, 3.9, 1.4],
       [6.3, 3.3, 4.7, 1.6],
       [4.6, 3.1, 1.5, 0.2],
       [6.9, 3.1, 4.9, 1.5],
       [6.7, 3.3, 5.7, 2.1],
       [5.8, 2.7, 5.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.2, 2.8, 4.8, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [6.3, 3.3, 6. , 2.5],
       [5.7, 3. , 4.2, 1.2],
       [5.4, 3.4, 1.7, 0.2]])
# The split above leaves the labels as an (n, 1) float column, because they were
# stacked into a float array. Transposing would only give a (1, n) row, which is
# still 2-D. Flatten to a 1-D integer label vector instead.
y_train = y_train.ravel().astype(int)
# Draw a random permutation of the row indices 0 .. n-1.
shuffle_indexs = np.random.permutation(X.shape[0])
shuffle_indexs
array([126,  60, 115,  72,  12, 107,  51,  36,  52,  18,  32,  87, 135,
       137, 105,  24,  37,  89,  64,  31, 131,  43,   0,  17,  23, 106,
        97,  45,  92,  70,  65,   4,  68,  95,  40, 121,  81, 139,  69,
       133,  49,  20,  79,  74,  46,  48,   6,   9,  63, 120, 103,  91,
        38,  42,  35, 109,  41, 101,  30,  86,  47,  94, 111, 116, 129,
        29,  88,  55, 127, 117,  26, 113,  78,  13,  25,  57,  44,  34,
       149,  67,  28,  22, 123, 132, 142,  90, 102, 110,  15,  21,  16,
        54,  98,  39,  58,  75,   1,  96,  82,   3, 143, 144,  71,  59,
       125,  19,  56, 114, 141,  83,  61,  50, 148, 118,  84,  11,  10,
         8,  33, 147, 146,  85,  80,  76, 100,  99, 108, 128,   7, 145,
        93,  27, 138,   2, 134,  77,  53, 112, 119,  66, 124, 140,  62,
       136, 130,   5,  73,  14, 104, 122])
# Hold out 20% of the samples for testing (truncated to a whole count).
test_ratio = 0.2
test_size = int(test_ratio * len(X))
test_size
30
# The first test_size shuffled indices form the test set; the rest form the training set.
test_indexes, train_indexs = shuffle_indexs[:test_size], shuffle_indexs[test_size:]
print(test_indexes)

X_train, y_train = X[train_indexs], y[train_indexs]

X_test, y_test = X[test_indexes], y[test_indexes]
X_train
[126  60 115  72  12 107  51  36  52  18  32  87 135 137 105  24  37  89
  64  31 131  43   0  17  23 106  97  45  92  70]





array([[6.7, 3.1, 4.4, 1.4],
       [5. , 3.6, 1.4, 0.2],
       [6.2, 2.2, 4.5, 1.5],
       [5.7, 3. , 4.2, 1.2],
       [5. , 3.5, 1.3, 0.3],
       [5.6, 2.8, 4.9, 2. ],
       [5.5, 2.4, 3.7, 1. ],
       [6.9, 3.1, 5.4, 2.1],
       [5.6, 2.5, 3.9, 1.1],
       [6.3, 2.8, 5.1, 1.5],
       [5. , 3.3, 1.4, 0.2],
       [5.4, 3.4, 1.7, 0.2],
       [5.7, 2.6, 3.5, 1. ],
       [6.4, 2.9, 4.3, 1.3],
       [5.1, 3.8, 1.6, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [4.6, 3.4, 1.4, 0.3],
       [4.9, 3.1, 1.5, 0.1],
       [6.1, 2.9, 4.7, 1.4],
       [6.9, 3.2, 5.7, 2.3],
       [6.3, 2.9, 5.6, 1.8],
       [6.1, 3. , 4.6, 1.4],
       [4.4, 3. , 1.3, 0.2],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.2, 1.2, 0.2],
       [7.2, 3.6, 6.1, 2.5],
       [4.5, 2.3, 1.3, 0.3],
       [5.8, 2.7, 5.1, 1.9],
       [4.8, 3.1, 1.6, 0.2],
       [6.7, 3.1, 4.7, 1.5],
       [4.6, 3.2, 1.4, 0.2],
       [5.6, 2.7, 4.2, 1.3],
       [6.4, 2.7, 5.3, 1.9],
       [6.5, 3. , 5.5, 1.8],
       [7.2, 3. , 5.8, 1.6],
       [4.7, 3.2, 1.6, 0.2],
       [5.6, 3. , 4.1, 1.3],
       [5.7, 2.8, 4.5, 1.3],
       [6.1, 3. , 4.9, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [5. , 3.4, 1.6, 0.4],
       [5.7, 2.5, 5. , 2. ],
       [6. , 2.9, 4.5, 1.5],
       [4.3, 3. , 1.1, 0.1],
       [5. , 3. , 1.6, 0.2],
       [4.9, 2.4, 3.3, 1. ],
       [5.1, 3.8, 1.9, 0.4],
       [4.9, 3.1, 1.5, 0.1],
       [5.9, 3. , 5.1, 1.8],
       [5.8, 2.7, 4.1, 1. ],
       [5.2, 3.4, 1.4, 0.2],
       [4.6, 3.6, 1. , 0.2],
       [6.3, 2.7, 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.2],
       [5.8, 2.7, 5.1, 1.9],
       [5.5, 2.6, 4.4, 1.2],
       [7.1, 3. , 5.9, 2.1],
       [6.5, 3.2, 5.1, 2. ],
       [5.7, 4.4, 1.5, 0.4],
       [5.1, 3.7, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [6.5, 2.8, 4.6, 1.5],
       [5.1, 2.5, 3. , 1.1],
       [5.1, 3.4, 1.5, 0.2],
       [6.6, 2.9, 4.6, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [4.9, 3. , 1.4, 0.2],
       [5.7, 2.9, 4.2, 1.3],
       [5.8, 2.7, 3.9, 1.2],
       [4.6, 3.1, 1.5, 0.2],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.1, 2.8, 4. , 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [7.2, 3.2, 6. , 1.8],
       [5.1, 3.8, 1.5, 0.3],
       [6.3, 3.3, 4.7, 1.6],
       [5.8, 2.8, 5.1, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [6. , 2.7, 5.1, 1.6],
       [5.9, 3. , 4.2, 1.5],
       [7. , 3.2, 4.7, 1.4],
       [6.2, 3.4, 5.4, 2.3],
       [7.7, 2.6, 6.9, 2.3],
       [5.4, 3. , 4.5, 1.5],
       [4.8, 3.4, 1.6, 0.2],
       [5.4, 3.7, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [5.5, 4.2, 1.4, 0.2],
       [6.5, 3. , 5.2, 2. ],
       [6.3, 2.5, 5. , 1.9],
       [6. , 3.4, 4.5, 1.6],
       [5.5, 2.4, 3.8, 1.1],
       [6.8, 2.8, 4.8, 1.4],
       [6.3, 3.3, 6. , 2.5],
       [5.7, 2.8, 4.1, 1.3],
       [6.7, 2.5, 5.8, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [5. , 3.4, 1.5, 0.2],
       [6.7, 3. , 5.2, 2.3],
       [5. , 2.3, 3.3, 1. ],
       [5.2, 3.5, 1.5, 0.2],
       [6. , 3. , 4.8, 1.8],
       [4.7, 3.2, 1.3, 0.2],
       [6.1, 2.6, 5.6, 1.4],
       [6.7, 3. , 5. , 1.7],
       [5.5, 2.3, 4. , 1.3],
       [6.8, 3. , 5.5, 2.1],
       [6. , 2.2, 5. , 1.5],
       [5.6, 3. , 4.5, 1.5],
       [6.7, 3.3, 5.7, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6. , 2.2, 4. , 1. ],
       [6.3, 3.4, 5.6, 2.4],
       [7.4, 2.8, 6.1, 1.9],
       [5.4, 3.9, 1.7, 0.4],
       [6.1, 2.8, 4.7, 1.2],
       [5.8, 4. , 1.2, 0.2],
       [6.5, 3. , 5.8, 2.2],
       [7.7, 2.8, 6.7, 2. ]])

使用算法

# Split X, y into training and test sets using the project's own train_test_split
# from the module script.kNN_function.model_selection (not shown here).
from script.kNN_function.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y)
# NOTE(review): the displayed shape (120, 4) suggests a default 20% test split — confirm in model_selection.
X_train.shape
(120, 4)
# Display the full original feature matrix X.
X
array([[5.1, 3.5, 1.4, 0.2],
       [4.9, 3. , 1.4, 0.2],
       [4.7, 3.2, 1.3, 0.2],
       [4.6, 3.1, 1.5, 0.2],
       [5. , 3.6, 1.4, 0.2],
       [5.4, 3.9, 1.7, 0.4],
       [4.6, 3.4, 1.4, 0.3],
       [5. , 3.4, 1.5, 0.2],
       [4.4, 2.9, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5.4, 3.7, 1.5, 0.2],
       [4.8, 3.4, 1.6, 0.2],
       [4.8, 3. , 1.4, 0.1],
       [4.3, 3. , 1.1, 0.1],
       [5.8, 4. , 1.2, 0.2],
       [5.7, 4.4, 1.5, 0.4],
       [5.4, 3.9, 1.3, 0.4],
       [5.1, 3.5, 1.4, 0.3],
       [5.7, 3.8, 1.7, 0.3],
       [5.1, 3.8, 1.5, 0.3],
       [5.4, 3.4, 1.7, 0.2],
       [5.1, 3.7, 1.5, 0.4],
       [4.6, 3.6, 1. , 0.2],
       [5.1, 3.3, 1.7, 0.5],
       [4.8, 3.4, 1.9, 0.2],
       [5. , 3. , 1.6, 0.2],
       [5. , 3.4, 1.6, 0.4],
       [5.2, 3.5, 1.5, 0.2],
       [5.2, 3.4, 1.4, 0.2],
       [4.7, 3.2, 1.6, 0.2],
       [4.8, 3.1, 1.6, 0.2],
       [5.4, 3.4, 1.5, 0.4],
       [5.2, 4.1, 1.5, 0.1],
       [5.5, 4.2, 1.4, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [5. , 3.2, 1.2, 0.2],
       [5.5, 3.5, 1.3, 0.2],
       [4.9, 3.1, 1.5, 0.1],
       [4.4, 3. , 1.3, 0.2],
       [5.1, 3.4, 1.5, 0.2],
       [5. , 3.5, 1.3, 0.3],
       [4.5, 2.3, 1.3, 0.3],
       [4.4, 3.2, 1.3, 0.2],
       [5. , 3.5, 1.6, 0.6],
       [5.1, 3.8, 1.9, 0.4],
       [4.8, 3. , 1.4, 0.3],
       [5.1, 3.8, 1.6, 0.2],
       [4.6, 3.2, 1.4, 0.2],
       [5.3, 3.7, 1.5, 0.2],
       [5. , 3.3, 1.4, 0.2],
       [7. , 3.2, 4.7, 1.4],
       [6.4, 3.2, 4.5, 1.5],
       [6.9, 3.1, 4.9, 1.5],
       [5.5, 2.3, 4. , 1.3],
       [6.5, 2.8, 4.6, 1.5],
       [5.7, 2.8, 4.5, 1.3],
       [6.3, 3.3, 4.7, 1.6],
       [4.9, 2.4, 3.3, 1. ],
       [6.6, 2.9, 4.6, 1.3],
       [5.2, 2.7, 3.9, 1.4],
       [5. , 2. , 3.5, 1. ],
       [5.9, 3. , 4.2, 1.5],
       [6. , 2.2, 4. , 1. ],
       [6.1, 2.9, 4.7, 1.4],
       [5.6, 2.9, 3.6, 1.3],
       [6.7, 3.1, 4.4, 1.4],
       [5.6, 3. , 4.5, 1.5],
       [5.8, 2.7, 4.1, 1. ],
       [6.2, 2.2, 4.5, 1.5],
       [5.6, 2.5, 3.9, 1.1],
       [5.9, 3.2, 4.8, 1.8],
       [6.1, 2.8, 4. , 1.3],
       [6.3, 2.5, 4.9, 1.5],
       [6.1, 2.8, 4.7, 1.2],
       [6.4, 2.9, 4.3, 1.3],
       [6.6, 3. , 4.4, 1.4],
       [6.8, 2.8, 4.8, 1.4],
       [6.7, 3. , 5. , 1.7],
       [6. , 2.9, 4.5, 1.5],
       [5.7, 2.6, 3.5, 1. ],
       [5.5, 2.4, 3.8, 1.1],
       [5.5, 2.4, 3.7, 1. ],
       [5.8, 2.7, 3.9, 1.2],
       [6. , 2.7, 5.1, 1.6],
       [5.4, 3. , 4.5, 1.5],
       [6. , 3.4, 4.5, 1.6],
       [6.7, 3.1, 4.7, 1.5],
       [6.3, 2.3, 4.4, 1.3],
       [5.6, 3. , 4.1, 1.3],
       [5.5, 2.5, 4. , 1.3],
       [5.5, 2.6, 4.4, 1.2],
       [6.1, 3. , 4.6, 1.4],
       [5.8, 2.6, 4. , 1.2],
       [5. , 2.3, 3.3, 1. ],
       [5.6, 2.7, 4.2, 1.3],
       [5.7, 3. , 4.2, 1.2],
       [5.7, 2.9, 4.2, 1.3],
       [6.2, 2.9, 4.3, 1.3],
       [5.1, 2.5, 3. , 1.1],
       [5.7, 2.8, 4.1, 1.3],
       [6.3, 3.3, 6. , 2.5],
       [5.8, 2.7, 5.1, 1.9],
       [7.1, 3. , 5.9, 2.1],
       [6.3, 2.9, 5.6, 1.8],
       [6.5, 3. , 5.8, 2.2],
       [7.6, 3. , 6.6, 2.1],
       [4.9, 2.5, 4.5, 1.7],
       [7.3, 2.9, 6.3, 1.8],
       [6.7, 2.5, 5.8, 1.8],
       [7.2, 3.6, 6.1, 2.5],
       [6.5, 3.2, 5.1, 2. ],
       [6.4, 2.7, 5.3, 1.9],
       [6.8, 3. , 5.5, 2.1],
       [5.7, 2.5, 5. , 2. ],
       [5.8, 2.8, 5.1, 2.4],
       [6.4, 3.2, 5.3, 2.3],
       [6.5, 3. , 5.5, 1.8],
       [7.7, 3.8, 6.7, 2.2],
       [7.7, 2.6, 6.9, 2.3],
       [6. , 2.2, 5. , 1.5],
       [6.9, 3.2, 5.7, 2.3],
       [5.6, 2.8, 4.9, 2. ],
       [7.7, 2.8, 6.7, 2. ],
       [6.3, 2.7, 4.9, 1.8],
       [6.7, 3.3, 5.7, 2.1],
       [7.2, 3.2, 6. , 1.8],
       [6.2, 2.8, 4.8, 1.8],
       [6.1, 3. , 4.9, 1.8],
       [6.4, 2.8, 5.6, 2.1],
       [7.2, 3. , 5.8, 1.6],
       [7.4, 2.8, 6.1, 1.9],
       [7.9, 3.8, 6.4, 2. ],
       [6.4, 2.8, 5.6, 2.2],
       [6.3, 2.8, 5.1, 1.5],
       [6.1, 2.6, 5.6, 1.4],
       [7.7, 3. , 6.1, 2.3],
       [6.3, 3.4, 5.6, 2.4],
       [6.4, 3.1, 5.5, 1.8],
       [6. , 3. , 4.8, 1.8],
       [6.9, 3.1, 5.4, 2.1],
       [6.7, 3.1, 5.6, 2.4],
       [6.9, 3.1, 5.1, 2.3],
       [5.8, 2.7, 5.1, 1.9],
       [6.8, 3.2, 5.9, 2.3],
       [6.7, 3.3, 5.7, 2.5],
       [6.7, 3. , 5.2, 2.3],
       [6.3, 2.5, 5. , 1.9],
       [6.5, 3. , 5.2, 2. ],
       [6.2, 3.4, 5.4, 2.3],
       [5.9, 3. , 5.1, 1.8]])
# Fit the project's k-NN classifier (kNNClassifier, defined outside this file)
# on the training split and predict labels for the test split.
from script.kNN_function.kNN import kNNClassifier
# presumably the positional argument 3 is k, the number of neighbours — confirm against kNNClassifier.__init__
my_knn_clf = kNNClassifier(3)
my_knn_clf.fit(X_train,y_train)
y_predict = my_knn_clf.predict(X_test)
y_predict
array([2, 0, 1, 1, 2, 1, 2, 0, 2, 0, 1, 1, 2, 0, 0, 1, 2, 2, 0, 0, 0, 0,
       0, 1, 2, 2, 2, 2, 0, 0])
# Accuracy: fraction of test samples whose predicted label equals the true label.
# np.mean over the boolean match array is vectorized (the builtin sum iterates in Python).
np.mean(y_predict == y_test)
0.9333333333333333

sklearn 中的 train_test_split

# scikit-learn's train_test_split. Its default test_size is 0.25, so pass 0.2 explicitly
# to match the 20% test fraction (test_ratio = 0.2) used earlier.
# A fixed random_state makes the split reproducible between runs (any fixed integer works).
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)
0%