The nearest-neighbors model KNeighborsRegressor is configurable. The number of neighbors is variable: the prediction can depend on the single nearest neighbor or on the $k$ nearest neighbors. How should $k$ be chosen?
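To make the role of $k$ concrete, here is a minimal pure-Python sketch of what a uniform-weight nearest-neighbors regressor computes: the mean target of the $k$ closest training points. The function `knn_predict` and the toy data are hypothetical illustrations, not part of scikit-learn or of the wine dataset.

```python
# Minimal k-nearest-neighbors regression in pure Python (illustrative sketch).

def knn_predict(X_train, y_train, x, k):
    """Predict the target of x as the mean target of its k nearest training points."""
    # Squared Euclidean distance from x to every training point.
    dists = [sum((a - b) ** 2 for a, b in zip(row, x)) for row in X_train]
    # Indices of the k closest training points.
    nearest = sorted(range(len(X_train)), key=lambda i: dists[i])[:k]
    return sum(y_train[i] for i in nearest) / k

# Hypothetical one-feature dataset (assumption made for this example).
X_toy = [[0.0], [1.0], [2.0], [3.0], [10.0]]
y_toy = [0.0, 1.0, 2.0, 3.0, 10.0]

pred_k1 = knn_predict(X_toy, y_toy, [2.6], k=1)  # only the closest point counts
pred_k3 = knn_predict(X_toy, y_toy, [2.6], k=3)  # average over three neighbors
print(pred_k1, pred_k3)
```

With $k=1$ the prediction copies the closest point; with $k=3$ it is smoothed by averaging, which is why the choice of $k$ matters.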
%matplotlib inline
from papierstat.datasets import load_wines_dataset
df = load_wines_dataset()
import numpy.random as rnd
index = list(df.index)
rnd.shuffle(index)
df_alea = df.iloc[index, :].reset_index(drop=True)
X = df_alea.drop(['quality', 'color'], axis=1)
y = df_alea['quality']
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y)
We loop over one parameter.
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import r2_score
voisins = []
r2s = []
for n in range(1, 10):
    # Fit a model with n neighbors and score it on the test set.
    knn = KNeighborsRegressor(n_neighbors=n)
    knn.fit(X_train, y_train)
    r2 = r2_score(y_test, knn.predict(X_test))
    voisins.append(n)
    r2s.append(r2)
import pandas
df = pandas.DataFrame(dict(voisin=voisins, r2=r2s))
ax = df.plot(x='voisin', y='r2')
ax.set_title("Performance as a function\nof the number of neighbors");
The GridSearchCV class automates the search for an optimum among the hyperparameters; in particular, it relies on cross-validation. We try every value of $k$ from 1 to 30.
parameters = {'n_neighbors': list(range(1,31))}
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import GridSearchCV
knn = KNeighborsRegressor()
grid = GridSearchCV(knn, parameters, verbose=2, return_train_score=True)
grid.fit(X, y)
Fitting 3 folds for each of 30 candidates, totalling 90 fits
[Parallel(n_jobs=1)]: Done 90 out of 90 | elapsed: 13.1s finished
GridSearchCV(cv=None, error_score='raise', estimator=KNeighborsRegressor(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=1, n_neighbors=5, p=2, weights='uniform'), fit_params=None, iid=True, n_jobs=1, param_grid={'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]}, pre_dispatch='2*n_jobs', refit=True, return_train_score=True, scoring=None, verbose=2)
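The principle behind this cross-validated search can be sketched by hand: for each candidate $k$, score the model on each held-out fold and keep the $k$ with the best mean score. The sketch below is pure Python on hypothetical synthetic data; `knn_predict`, `r2`, and `cv_score` are illustrative helpers, and the use of three contiguous folds is an assumption of this sketch, not a description of scikit-learn internals.

```python
import random

def knn_predict(X_train, y_train, x, k):
    """Mean target of the k training points closest to x (Euclidean distance)."""
    dists = [sum((a - b) ** 2 for a, b in zip(row, x)) for row in X_train]
    nearest = sorted(range(len(X_train)), key=lambda i: dists[i])[:k]
    return sum(y_train[i] for i in nearest) / k

def r2(y_true, y_pred):
    """Coefficient of determination."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

def cv_score(X, y, k, n_folds=3):
    """Mean R^2 of a k-NN regressor over n_folds contiguous folds."""
    n = len(X)
    bounds = [n * f // n_folds for f in range(n_folds + 1)]
    fold_scores = []
    for f in range(n_folds):
        lo, hi = bounds[f], bounds[f + 1]
        # Train on everything outside [lo, hi), evaluate on the held-out fold.
        X_tr, y_tr = X[:lo] + X[hi:], y[:lo] + y[hi:]
        preds = [knn_predict(X_tr, y_tr, x, k) for x in X[lo:hi]]
        fold_scores.append(r2(y[lo:hi], preds))
    return sum(fold_scores) / n_folds

# Hypothetical noisy linear data (assumption made for this example).
rng = random.Random(0)
X_toy = [[rng.uniform(0.0, 10.0)] for _ in range(60)]
y_toy = [0.5 * x[0] + rng.gauss(0.0, 1.0) for x in X_toy]

# "Grid search": evaluate every candidate k and keep the best mean score.
scores = {k: cv_score(X_toy, y_toy, k) for k in range(1, 16)}
best_k = max(scores, key=scores.get)
print(best_k, round(scores[best_k], 3))
```

On a fitted GridSearchCV object, the corresponding information is available through the `best_params_` and `cv_results_` attributes.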
res = grid.cv_results_
k = res['param_n_neighbors']
train_score = res['mean_train_score']
test_score = res['mean_test_score']
import pandas
df_score = pandas.DataFrame(dict(k=k, test=test_score, train=train_score))
ax = df_score.plot(x='k', y='train', figsize=(6, 4))
df_score.plot(x='k', y='test', ax=ax, grid=True)
ax.set_title("Performance on the training and test sets" +
             "\nas a function of the number of neighbors")
ax.set_ylabel("r2");
The model's performance on the test set improves, and the best number of neighbors among those tried lies around 15.
df_score[12:17]
| | k | test | train |
|---|---|---|---|
| 12 | 13 | 0.159266 | 0.279302 |
| 13 | 14 | 0.160284 | 0.269703 |
| 14 | 15 | 0.157910 | 0.261720 |
| 15 | 16 | 0.159066 | 0.256823 |
| 16 | 17 | 0.158029 | 0.249684 |
The error on the training set increases noticeably ($R^2$ drops). Let us see what happens for larger values of $k$.
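For reference, the $R^2$ score discussed here is the coefficient of determination, which is what `r2_score` computes. The notation is introduced for this note only: $y_i$ are the observed targets, $\hat{y}_i$ the predictions, and $\bar{y}$ the mean of the observed targets.

$$
R^2 = 1 - \frac{\sum_{i=1}^{n} (y_i - \hat{y}_i)^2}{\sum_{i=1}^{n} (y_i - \bar{y})^2},
\qquad \bar{y} = \frac{1}{n} \sum_{i=1}^{n} y_i
$$

A perfect model reaches $R^2 = 1$, a model that always predicts $\bar{y}$ gets $R^2 = 0$, and a growing squared error lowers $R^2$.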
parameters = {'n_neighbors': list(range(5, 51, 5)) + list(range(50, 201, 20))}
grid = GridSearchCV(knn, parameters, verbose=2, return_train_score=True)
grid.fit(X, y)
Fitting 3 folds for each of 18 candidates, totalling 54 fits
[Parallel(n_jobs=1)]: Done 54 out of 54 | elapsed: 18.0s finished
GridSearchCV(cv=None, error_score='raise', estimator=KNeighborsRegressor(algorithm='auto', leaf_size=30, metric='minkowski', metric_params=None, n_jobs=1, n_neighbors=5, p=2, weights='uniform'), fit_params=None, iid=True, n_jobs=1, param_grid={'n_neighbors': [5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 50, 70, 90, 110, 130, 150, 170, 190]}, pre_dispatch='2*n_jobs', refit=True, return_train_score=True, scoring=None, verbose=2)
res = grid.cv_results_
k = res['param_n_neighbors']
train_score = res['mean_train_score']
test_score = res['mean_test_score']
import pandas
df_score = pandas.DataFrame(dict(k=k, test=test_score, train=train_score))
ax = df_score.plot(x='k', y='train', figsize=(6, 4))
df_score.plot(x='k', y='test', ax=ax, grid=True)
ax.set_title("Performance on the training and test sets" +
             "\nas a function of the number of neighbors")
ax.set_ylabel("r2");
Beyond 25 neighbors, the model's performance drops sharply, which seems expected: the more neighbors there are, the less local the prediction is, so to speak.
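The limiting case makes this loss of locality visible: when $k$ equals the number of training points, a uniform-weight nearest-neighbors regressor predicts the training mean everywhere, so $R^2$ on the training data falls to 0. The sketch below checks this in pure Python on a hypothetical toy dataset; `knn_predict` and `r2` are illustrative helpers, not scikit-learn functions.

```python
def knn_predict(X_train, y_train, x, k):
    """Mean target of the k training points closest to x (Euclidean distance)."""
    dists = [sum((a - b) ** 2 for a, b in zip(row, x)) for row in X_train]
    nearest = sorted(range(len(X_train)), key=lambda i: dists[i])[:k]
    return sum(y_train[i] for i in nearest) / k

def r2(y_true, y_pred):
    """Coefficient of determination."""
    mean = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot

# Hypothetical noiseless linear data: y = 2x + 1 for x = 0..9.
X_toy = [[float(i)] for i in range(10)]
y_toy = [2.0 * i + 1.0 for i in range(10)]

# R^2 on the training data for a local, an intermediate, and a global model.
train_r2 = {}
for k in (1, 3, len(X_toy)):
    preds = [knn_predict(X_toy, y_toy, x, k) for x in X_toy]
    train_r2[k] = r2(y_toy, preds)
    print(k, train_r2[k])
```

With $k=1$ each training point is its own nearest neighbor and the fit is perfect; with $k=10$ every prediction is the same constant.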