{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Classification multi-classe\n", "\n", "On cherche \u00e0 pr\u00e9dire la note d'un vin avec un classifieur multi-classe."]}, {"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 2, "metadata": {"scrolled": true}, "outputs": [], "source": ["from papierstat.datasets import load_wines_dataset\n", "df = load_wines_dataset()\n", "X = df.drop(['quality', 'color'], axis=1)\n", "y = df['quality']"]}, {"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": ["from sklearn.model_selection import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)"]}, {"cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [{"data": {"text/plain": ["LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False)"]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.linear_model import LogisticRegression\n", "clr = LogisticRegression()\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [{"data": {"text/plain": ["53.84615384615385"]}, "execution_count": 6, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On regarde la matrice de confusion."]}, {"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "
\n", " \n", " \n", " | \n", " 0 | \n", " 1 | \n", " 2 | \n", " 3 | \n", " 4 | \n", " 5 | \n", " 6 | \n", "
\n", " \n", " \n", " \n", " 0 | \n", " 0 | \n", " 0 | \n", " 6 | \n", " 0 | \n", " 0 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 1 | \n", " 0 | \n", " 0 | \n", " 39 | \n", " 14 | \n", " 1 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 2 | \n", " 0 | \n", " 0 | \n", " 338 | \n", " 208 | \n", " 2 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 3 | \n", " 0 | \n", " 0 | \n", " 195 | \n", " 517 | \n", " 17 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 4 | \n", " 0 | \n", " 0 | \n", " 19 | \n", " 200 | \n", " 20 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 5 | \n", " 0 | \n", " 0 | \n", " 2 | \n", " 38 | \n", " 8 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 6 | \n", " 0 | \n", " 0 | \n", " 0 | \n", " 1 | \n", " 0 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", "
\n", "
"], "text/plain": [" 0 1 2 3 4 5 6\n", "0 0 0 6 0 0 0 0\n", "1 0 0 39 14 1 0 0\n", "2 0 0 338 208 2 0 0\n", "3 0 0 195 517 17 0 0\n", "4 0 0 19 200 20 0 0\n", "5 0 0 2 38 8 0 0\n", "6 0 0 0 1 0 0 0"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.metrics import confusion_matrix\n", "import pandas\n", "pandas.DataFrame(confusion_matrix(y_test, clr.predict(X_test)))"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On l'affiche diff\u00e9remment avec le nom des classes."]}, {"cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [{"data": {"text/html": ["\n", "\n", "
\n", " \n", " \n", " | \n", " 3 | \n", " 4 | \n", " 5 | \n", " 6 | \n", " 7 | \n", " 8 | \n", " 9 | \n", "
\n", " \n", " \n", " \n", " 3 | \n", " 0 | \n", " 0 | \n", " 6 | \n", " 0 | \n", " 0 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 4 | \n", " 0 | \n", " 0 | \n", " 39 | \n", " 14 | \n", " 1 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 5 | \n", " 0 | \n", " 0 | \n", " 338 | \n", " 208 | \n", " 2 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 6 | \n", " 0 | \n", " 0 | \n", " 195 | \n", " 517 | \n", " 17 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 7 | \n", " 0 | \n", " 0 | \n", " 19 | \n", " 200 | \n", " 20 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 8 | \n", " 0 | \n", " 0 | \n", " 2 | \n", " 38 | \n", " 8 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 9 | \n", " 0 | \n", " 0 | \n", " 0 | \n", " 1 | \n", " 0 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", "
\n", "
"], "text/plain": [" 3 4 5 6 7 8 9\n", "3 0 0 6 0 0 0 0\n", "4 0 0 39 14 1 0 0\n", "5 0 0 338 208 2 0 0\n", "6 0 0 195 517 17 0 0\n", "7 0 0 19 200 20 0 0\n", "8 0 0 2 38 8 0 0\n", "9 0 0 0 1 0 0 0"]}, "execution_count": 8, "metadata": {}, "output_type": "execute_result"}], "source": ["# Predict once and reuse the result for both the matrix and its labels.\n", "pred = clr.predict(X_test)\n", "conf = confusion_matrix(y_test, pred)\n", "# Class 9 is very rare and can be missing from the train or the test split,\n", "# so clr.classes_ may not match the matrix. confusion_matrix orders its\n", "# rows/columns by the sorted labels found in y_test or pred: rebuild that list.\n", "labels = numpy.union1d(y_test, pred)\n", "dfconf = pandas.DataFrame(conf, index=labels, columns=labels)\n", "dfconf"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Pas extraordinaire. On applique la strat\u00e9gie [OneVsRestClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html)."]}, {"cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [{"data": {"text/plain": ["OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False),\n", " n_jobs=1)"]}, "execution_count": 9, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.multiclass import OneVsRestClassifier\n", "clr = OneVsRestClassifier(LogisticRegression())\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [{"data": {"text/plain": ["53.784615384615385"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Le mod\u00e8le de r\u00e9gression logistique multi-classe est \u00e9quivalent \u00e0 la 
strat\u00e9gie *OneVsRest*. Voyons l'autre."]}, {"cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [{"data": {"text/plain": ["OneVsOneClassifier(estimator=LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n", " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n", " penalty='l2', random_state=None, solver='liblinear', tol=0.0001,\n", " verbose=0, warm_start=False),\n", " n_jobs=1)"]}, "execution_count": 11, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.multiclass import OneVsOneClassifier\n", "clr = OneVsOneClassifier(LogisticRegression())\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [{"data": {"text/plain": ["53.47692307692308"]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [{"data": {"text/html": ["\n", "\n", "
\n", " \n", " \n", " | \n", " 3 | \n", " 4 | \n", " 5 | \n", " 6 | \n", " 7 | \n", " 8 | \n", " 9 | \n", "
\n", " \n", " \n", " \n", " 3 | \n", " 0 | \n", " 0 | \n", " 6 | \n", " 0 | \n", " 0 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 4 | \n", " 0 | \n", " 0 | \n", " 38 | \n", " 15 | \n", " 1 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 5 | \n", " 0 | \n", " 0 | \n", " 335 | \n", " 208 | \n", " 5 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 6 | \n", " 0 | \n", " 0 | \n", " 197 | \n", " 491 | \n", " 41 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 7 | \n", " 0 | \n", " 0 | \n", " 20 | \n", " 176 | \n", " 43 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 8 | \n", " 0 | \n", " 0 | \n", " 1 | \n", " 34 | \n", " 13 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", " 9 | \n", " 0 | \n", " 0 | \n", " 0 | \n", " 1 | \n", " 0 | \n", " 0 | \n", " 0 | \n", "
\n", " \n", "
\n", "
"], "text/plain": [" 3 4 5 6 7 8 9\n", "3 0 0 6 0 0 0 0\n", "4 0 0 38 15 1 0 0\n", "5 0 0 335 208 5 0 0\n", "6 0 0 197 491 41 0 0\n", "7 0 0 20 176 43 0 0\n", "8 0 0 1 34 13 0 0\n", "9 0 0 0 1 0 0 0"]}, "execution_count": 13, "metadata": {}, "output_type": "execute_result"}], "source": ["# Predict once and reuse the result for both the matrix and its labels.\n", "pred = clr.predict(X_test)\n", "conf = confusion_matrix(y_test, pred)\n", "# Class 9 is very rare and can be missing from the train or the test split,\n", "# so clr.classes_ may not match the matrix. confusion_matrix orders its\n", "# rows/columns by the sorted labels found in y_test or pred: rebuild that list.\n", "labels = numpy.union1d(y_test, pred)\n", "dfconf = pandas.DataFrame(conf, index=labels, columns=labels)\n", "dfconf"]}, {"cell_type": "markdown", "metadata": {}, "source": ["\u00c0 peu pr\u00e8s pareil : la diff\u00e9rence n'est sans doute pas significative. Voyons avec un arbre de d\u00e9cision."]}, {"cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [{"data": {"text/plain": ["DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort=False, random_state=None,\n", " splitter='best')"]}, "execution_count": 14, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.tree import DecisionTreeClassifier\n", "clr = DecisionTreeClassifier()\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [{"data": {"text/plain": ["59.50769230769231"]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Et avec [OneVsRestClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsRestClassifier.html) :"]}, {"cell_type": 
"code", "execution_count": 15, "metadata": {}, "outputs": [{"data": {"text/plain": ["OneVsRestClassifier(estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort=False, random_state=None,\n", " splitter='best'),\n", " n_jobs=1)"]}, "execution_count": 16, "metadata": {}, "output_type": "execute_result"}], "source": ["clr = OneVsRestClassifier(DecisionTreeClassifier())\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [{"data": {"text/plain": ["52.92307692307693"]}, "execution_count": 17, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Et avec [OneVsOneClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.multiclass.OneVsOneClassifier.html)"]}, {"cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [{"data": {"text/plain": ["OneVsOneClassifier(estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort=False, random_state=None,\n", " splitter='best'),\n", " n_jobs=1)"]}, "execution_count": 18, "metadata": {}, "output_type": "execute_result"}], "source": ["clr = OneVsOneClassifier(DecisionTreeClassifier())\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [{"data": {"text/plain": ["60.12307692307692"]}, "execution_count": 19, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, 
{"cell_type": "markdown", "metadata": {}, "source": ["Mieux."]}, {"cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [{"data": {"text/plain": ["RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n", " oob_score=False, random_state=None, verbose=0,\n", " warm_start=False)"]}, "execution_count": 20, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.ensemble import RandomForestClassifier\n", "clr = RandomForestClassifier()\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [{"data": {"text/plain": ["66.46153846153847"]}, "execution_count": 21, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [{"data": {"text/plain": ["OneVsRestClassifier(estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto', max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,\n", " oob_score=False, random_state=None, verbose=0,\n", " warm_start=False),\n", " n_jobs=1)"]}, "execution_count": 22, "metadata": {}, "output_type": "execute_result"}], "source": ["clr = OneVsRestClassifier(RandomForestClassifier())\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [{"data": {"text/plain": ["65.90769230769232"]}, "execution_count": 23, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == 
y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Proche, il faut affiner avec une validation crois\u00e9e."]}, {"cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [{"data": {"text/plain": ["MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,\n", " beta_2=0.999, early_stopping=False, epsilon=1e-08,\n", " hidden_layer_sizes=30, learning_rate='constant',\n", " learning_rate_init=0.001, max_iter=600, momentum=0.9,\n", " nesterovs_momentum=True, power_t=0.5, random_state=None,\n", " shuffle=True, solver='adam', tol=0.0001, validation_fraction=0.1,\n", " verbose=False, warm_start=False)"]}, "execution_count": 24, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.neural_network import MLPClassifier\n", "clr = MLPClassifier(hidden_layer_sizes=30, max_iter=600)\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [{"data": {"text/plain": ["51.323076923076925"]}, "execution_count": 25, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [{"data": {"text/plain": ["OneVsRestClassifier(estimator=MLPClassifier(activation='relu', alpha=0.0001, batch_size='auto', beta_1=0.9,\n", " beta_2=0.999, early_stopping=False, epsilon=1e-08,\n", " hidden_layer_sizes=30, learning_rate='constant',\n", " learning_rate_init=0.001, max_iter=600, momentum=0.9,\n", " nesterovs_momentum=True, power_t=0.5, random_state=None,\n", " shuffle=True, solver='adam', tol=0.0001, validation_fraction=0.1,\n", " verbose=False, warm_start=False),\n", " n_jobs=1)"]}, "execution_count": 26, "metadata": {}, "output_type": "execute_result"}], "source": ["clr = OneVsRestClassifier(MLPClassifier(hidden_layer_sizes=30, max_iter=600))\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 26, 
"metadata": {}, "outputs": [{"data": {"text/plain": ["47.56923076923077"]}, "execution_count": 27, "metadata": {}, "output_type": "execute_result"}], "source": ["numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Pas foudroyant."]}, {"cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": []}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4"}}, "nbformat": 4, "nbformat_minor": 2}