{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Classification multi-classe et stacking\n", "\n", "On cherche \u00e0 pr\u00e9dire la note d'un vin avec un classifieur multi-classe puis \u00e0 am\u00e9liorer le score obtenu avec une m\u00e9thode dite de [stacking](https://www.quora.com/What-is-stacking-in-machine-learning)."]}, {"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [{"data": {"text/html": ["
run previous cell, wait for 2 seconds
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## Le probl\u00e8me\n", "\n", "Il n'est pas \u00e9vident que les scores des diff\u00e9rents mod\u00e8les qu'on apprend sur chacune des classes soient comparables. Si le mod\u00e8le n'est pas assez performant, on peut songer \u00e0 ajouter un dernier mod\u00e8le qui prend la d\u00e9cision finale en fonction du r\u00e9sultat de chaque mod\u00e8le."]}, {"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": [""]}, "execution_count": 3, "metadata": {"image/png": {"width": 400}}, "output_type": "execute_result"}], "source": ["from pyquickhelper.helpgen import NbImage\n", "NbImage('images/stackmulti.png', width=400)"]}, {"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 4, "metadata": {"scrolled": true}, "outputs": [], "source": ["from papierstat.datasets import load_wines_dataset\n", "df = load_wines_dataset()\n", "X = df.drop(['quality', 'color'], axis=1)\n", "y = df['quality']"]}, {"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": ["from sklearn.model_selection import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)"]}, {"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [{"data": {"text/plain": ["OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None,\n", " dual=False, fit_intercept=True,\n", " intercept_scaling=1,\n", " l1_ratio=None, max_iter=1500,\n", " multi_class='auto',\n", " n_jobs=None, penalty='l2',\n", " random_state=None,\n", " solver='lbfgs', tol=0.0001,\n", " verbose=0, warm_start=False),\n", " n_jobs=None)"]}, "execution_count": 7, "metadata": {}, 
"output_type": "execute_result"}], "source": ["from sklearn.linear_model import LogisticRegression\n", "from sklearn.multiclass import OneVsRestClassifier\n", "clr = OneVsRestClassifier(LogisticRegression(max_iter=1500))\n", "clr.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [{"data": {"text/plain": ["53.907692307692315"]}, "execution_count": 8, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "numpy.mean(clr.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On regarde la matrice de confusion."]}, {"cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
3456789
30021010
4003915100
500314201200
600170537900
700172332500
800247600
90002100
\n", "
"], "text/plain": [" 3 4 5 6 7 8 9\n", "3 0 0 2 1 0 1 0\n", "4 0 0 39 15 1 0 0\n", "5 0 0 314 201 2 0 0\n", "6 0 0 170 537 9 0 0\n", "7 0 0 17 233 25 0 0\n", "8 0 0 2 47 6 0 0\n", "9 0 0 0 2 1 0 0"]}, "execution_count": 9, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.metrics import confusion_matrix\n", "import pandas\n", "# Rows/columns cover every class known to the model or present in the test split.\n", "# Passing them as `labels` keeps each count under its true class even when a class\n", "# is missing from one side; slicing `clr.classes_` to the matrix shape would\n", "# silently shift the labels in that case.\n", "labels = sorted(set(clr.classes_) | set(y_test))\n", "names = [str(c) for c in labels]\n", "# Dedicated name: `df` keeps holding the wine dataset loaded earlier.\n", "confusion = pandas.DataFrame(\n", "    confusion_matrix(y_test, clr.predict(X_test), labels=labels),\n", "    index=names, columns=names)\n", "confusion"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On cale d'abord une random forest sur les donn\u00e9es brutes."]}, {"cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [{"data": {"text/plain": ["67.44615384615385"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.ensemble import RandomForestClassifier\n", "rfc = RandomForestClassifier()\n", "rfc.fit(X_train, y_train)\n", "numpy.mean(rfc.predict(X_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On cale une random forest avec les sorties de la r\u00e9gression logistique."]}, {"cell_type": "code", "execution_count": 10, "metadata": {"scrolled": false}, "outputs": [{"data": {"text/plain": ["RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n", " criterion='gini', max_depth=None, max_features='auto',\n", " max_leaf_nodes=None, max_samples=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=100,\n", " n_jobs=None, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False)"]}, "execution_count": 11, "metadata": {}, "output_type": 
"execute_result"}], "source": ["rf_train = clr.decision_function(X_train)\n", "\n", "rfc_y = RandomForestClassifier()\n", "rfc_y.fit(rf_train, y_train)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On calcule le taux de bonne classification."]}, {"cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [{"data": {"text/plain": ["64.8"]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["rf_test = clr.decision_function(X_test)\n", "numpy.mean(rfc_y.predict(rf_test).ravel() == y_test.ravel()) * 100"]}, {"cell_type": "markdown", "metadata": {}, "source": ["C'est presque \u00e9quivalent \u00e0 une random forest cal\u00e9e sur les donn\u00e9es brutes. On trace les courbes ROC pour la classe 4."]}, {"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.7485466126230457, 0.6752634626519977, 0.6984597568037059)"]}, "execution_count": 13, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.metrics import roc_curve, roc_auc_score\n", "# Score columns follow each model's `classes_`, so the column of the target class\n", "# is looked up instead of hard-coded: with the labels 3..9 shown in the confusion\n", "# matrix, index 2 is class 5 and would not match the `y_test == 4` target.\n", "target = 4\n", "is_target = y_test == target\n", "score_lr = clr.decision_function(X_test)[:, list(clr.classes_).index(target)]\n", "score_rfc = rfc.predict_proba(X_test)[:, list(rfc.classes_).index(target)]\n", "score_rfc_y = rfc_y.predict_proba(rf_test)[:, list(rfc_y.classes_).index(target)]\n", "fpr_lr, tpr_lr, th_lr = roc_curve(is_target, score_lr)\n", "fpr_rfc, tpr_rfc, th_rfc = roc_curve(is_target, score_rfc)\n", "fpr_rfc_y, tpr_rfc_y, th_rfc_y = roc_curve(is_target, score_rfc_y)\n", "auc_lr = roc_auc_score(is_target, score_lr)\n", "auc_rfc = roc_auc_score(is_target, score_rfc)\n", "auc_rfc_y = roc_auc_score(is_target, score_rfc_y)\n", "auc_lr, auc_rfc, auc_rfc_y"]}, {"cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots(1, 1, figsize=(4,4))\n", "ax.plot([0, 1], [0, 1], 'k--')\n", "ax.plot(fpr_lr, tpr_lr, label=\"OneVsRest + LR\")\n", "ax.plot(fpr_rfc, tpr_rfc, label=\"RF\")\n", "ax.plot(fpr_rfc_y, tpr_rfc_y, label=\"OneVsRest + LR + RF\")\n", "ax.set_title('Courbe ROC - comparaison de deux\\nmod\u00e8les pour la classe 4')\n", "ax.legend();"]}, {"cell_type": "markdown", "metadata": {}, "source": ["La courbe ROC ne montre rien de probant. Il faudrait v\u00e9rifier avec une cross-validation qu'il serait pratique de faire avec un [pipeline](http://scikit-learn.org/stable/modules/pipeline.html) mais ceux-ci n'acceptent qu'un seul pr\u00e9dicteur final."]}, {"cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["ERREUR :\n", "All intermediate steps should be transformers and implement fit and transform or be the string 'passthrough' 'OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None,\n", " dual=False, fit_intercept=True,\n", " intercept_scaling=1,\n", " l1_ratio=None, max_iter=1500,\n", " multi_class='auto',\n", " n_jobs=None, penalty='l2',\n", " random_state=None,\n", " solver='lbfgs', tol=0.0001,\n", " verbose=0, warm_start=False),\n", " n_jobs=None)' (type ) doesn't\n"]}], "source": ["from sklearn.pipeline import make_pipeline\n", "try:\n", " pipe = make_pipeline(OneVsRestClassifier(LogisticRegression(max_iter=1500)),\n", " RandomForestClassifier())\n", "except Exception as e:\n", " print('ERREUR :')\n", " print(e)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On construit une ROC sur toutes les classes."]}, {"cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.5510406569489914, 0.753562533633215, 0.7354245943989535)"]}, "execution_count": 16, "metadata": {}, "output_type": 
"execute_result"}], "source": ["# For each model: was the predicted class right, and how confident was the model?\n", "# The confidence is the highest predicted probability; the same score feeds both\n", "# the ROC curve and the AUC so the reported AUC describes the plotted curve.\n", "ok_lr = y_test == clr.predict(X_test)\n", "score_lr = clr.predict_proba(X_test).max(axis=1)\n", "ok_rfc = y_test == rfc.predict(X_test)\n", "score_rfc = rfc.predict_proba(X_test).max(axis=1)\n", "ok_rfc_y = y_test == rfc_y.predict(rf_test)\n", "score_rfc_y = rfc_y.predict_proba(rf_test).max(axis=1)\n", "fpr_lr, tpr_lr, th_lr = roc_curve(ok_lr, score_lr, drop_intermediate=False)\n", "fpr_rfc, tpr_rfc, th_rfc = roc_curve(ok_rfc, score_rfc, drop_intermediate=False)\n", "fpr_rfc_y, tpr_rfc_y, th_rfc_y = roc_curve(ok_rfc_y, score_rfc_y, drop_intermediate=False)\n", "auc_lr = roc_auc_score(ok_lr, score_lr)\n", "auc_rfc = roc_auc_score(ok_rfc, score_rfc)\n", "auc_rfc_y = roc_auc_score(ok_rfc_y, score_rfc_y)\n", "auc_lr, auc_rfc, auc_rfc_y"]}, {"cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots(1, 1, figsize=(4,4))\n", "ax.plot([0, 1], [0, 1], 'k--')\n", "ax.plot(fpr_lr, tpr_lr, label=\"OneVsRest + LR\")\n", "ax.plot(fpr_rfc, tpr_rfc, label=\"RF\")\n", "ax.plot(fpr_rfc_y, tpr_rfc_y, label=\"OneVsRest + LR + RF\")\n", "ax.set_title('Courbe ROC - comparaison de deux\\nmod\u00e8les pour toutes les classes')\n", "ax.legend();"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Sur ce mod\u00e8le, le score produit par le classifieur final para\u00eet plus pertinent que le score obtenu en prenant le score maximum sur toutes les classes. On tente une derni\u00e8re approche o\u00f9 le mod\u00e8le final doit valider ou non la r\u00e9ponse : c'est un classifieur binaire. Avec celui-ci, tous les classifieurs estim\u00e9s sont binaires."]}, {"cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [{"data": {"text/plain": ["RandomForestClassifier(bootstrap=True, ccp_alpha=0.0, class_weight=None,\n", " criterion='gini', max_depth=None, max_features='auto',\n", " max_leaf_nodes=None, max_samples=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, n_estimators=100,\n", " n_jobs=None, oob_score=False, random_state=None,\n", " verbose=0, warm_start=False)"]}, "execution_count": 18, "metadata": {}, "output_type": "execute_result"}], "source": ["rf_train_bin = clr.decision_function(X_train)\n", "y_train_bin = clr.predict(X_train) == y_train\n", "rfc = RandomForestClassifier()\n", "rfc.fit(rf_train_bin, y_train_bin)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On regarde les premi\u00e8res r\u00e9ponses."]}, {"cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [{"data": {"text/plain": ["array([[0.32, 0.68],\n", " [0.5 , 0.5 ],\n", " [0.89, 0.11]])"]}, "execution_count": 19, 
"metadata": {}, "output_type": "execute_result"}], "source": ["rf_test_bin = clr.decision_function(X_test)\n", "rfc.predict_proba(rf_test_bin)[:3]"]}, {"cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": ["y_test_bin = clr.predict(X_test) == y_test"]}, {"cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [{"data": {"text/plain": ["(0.5510406569489914, 0.7668718108162481)"]}, "execution_count": 21, "metadata": {}, "output_type": "execute_result"}], "source": ["fpr_rfc_bin, tpr_rfc_bin, th_rfc_bin = roc_curve(y_test_bin, rfc.predict_proba(rf_test_bin)[:, 1])\n", "auc_rfc_bin = roc_auc_score(y_test_bin, rfc.predict_proba(rf_test_bin)[:, 1])\n", "auc_lr, auc_rfc_bin"]}, {"cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["import matplotlib.pyplot as plt\n", "fig, ax = plt.subplots(1, 1, figsize=(4,4))\n", "ax.plot([0, 1], [0, 1], 'k--')\n", "ax.plot(fpr_lr, tpr_lr, label=\"OneVsRest + LR\")\n", "# The stacked model's curve is (fpr_rfc_y, tpr_rfc_y); (fpr_rfc, tpr_rfc) belongs\n", "# to the random forest trained on the raw features and would not match this label.\n", "ax.plot(fpr_rfc_y, tpr_rfc_y, label=\"OneVsRest + LR + RF\")\n", "ax.plot(fpr_rfc_bin, tpr_rfc_bin, label=\"OneVsRest + LR + RF binaire\")\n", "ax.set_title('Courbe ROC - comparaison de deux\\nmod\u00e8les pour toutes les classes')\n", "ax.legend();"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Un peu mieux mais il faudrait encore valider avec une validation crois\u00e9e et plusieurs jeux de donn\u00e9es, y compris artificiels. Il reste n\u00e9anmoins l'id\u00e9e."]}, {"cell_type": "markdown", "metadata": {}, "source": ["## Automatisation avec une impl\u00e9mentation\n", "\n", "Comme c'est fastidieux de faire tout cela, on impl\u00e9mente une classe qui convertit un mod\u00e8le de machine learning en un *transform* qu'on peut ins\u00e9rer dans un pipeline."]}, {"cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [{"data": {"text/plain": ["Pipeline(memory=None,\n", " steps=[('skbasetransformstacking',\n", " SkBaseTransformStacking([OneVsRestClassifier(estimator=LogisticRegress\n", " ion(C=1.0, class_weight=None,\n", " dual=False, fit_intercept=True,\n", " intercept_scaling=1,\n", " l1_ratio=None, max_iter=1500,\n", " multi_class='auto',\n", " n_jobs=None, penalty='l2',\n", " random_state=None,\n", " solver='lbfgs', tol=0.0001,\n", " verbose=0, warm_start=False), n_j...\n", " RandomForestClassifier(bootstrap=True, ccp_alpha=0.0,\n", " class_weight=None, criterion='gini',\n", " max_depth=None, max_features='auto',\n", " max_leaf_nodes=None, max_samples=None,\n", " min_impurity_decrease=0.0,\n", " min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0,\n", " n_estimators=100, n_jobs=None,\n", " oob_score=False, 
random_state=None,\n", " verbose=0, warm_start=False))],\n", " verbose=False)"]}, "execution_count": 23, "metadata": {}, "output_type": "execute_result"}], "source": ["from mlinsights.sklapi import SkBaseTransformStacking\n", "model = make_pipeline(\n", " SkBaseTransformStacking(\n", " [OneVsRestClassifier(LogisticRegression(max_iter=1500))], \n", " 'decision_function'), \n", " RandomForestClassifier())\n", "model.fit(X_train, y_train)"]}, {"cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [{"data": {"text/plain": ["0.7330188804547779"]}, "execution_count": 24, "metadata": {}, "output_type": "execute_result"}], "source": ["fpr_pipe, tpr_pipe, th_pipe = roc_curve(y_test == model.predict(X_test), \n", " model.predict_proba(X_test).max(axis=1),\n", " drop_intermediate=False)\n", "auc_pipe = roc_auc_score(y_test == model.predict(X_test),\n", " model.predict_proba(X_test).max(axis=1))\n", "auc_pipe"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On n'oublie pas de m\u00e9langer les donn\u00e9es avant de faire tourner la validation crois\u00e9e."]}, {"cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": ["df = load_wines_dataset(shuffle=True)\n", "X = df.drop(['quality', 'color'], axis=1)\n", "y = df['quality']"]}, {"cell_type": "markdown", "metadata": {}, "source": ["On retrouve les m\u00eames r\u00e9sultats mais on peut maintenant faire une validation crois\u00e9e."]}, {"cell_type": "code", "execution_count": 25, "metadata": {"scrolled": false}, "outputs": [{"data": {"text/plain": ["array([0.66923077, 0.66153846, 0.67513472, 0.66897614, 0.65588915])"]}, "execution_count": 26, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.model_selection import cross_val_score\n", "from sklearn.model_selection import cross_validate\n", "cross_val_score(model, X, y, cv=5)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["## scikit-learn 0.22 - stacking\n", "\n", "A partir de la version 
0.22, *scikit-learn* a introduit le mod\u00e8le [StackingClassifier](https://scikit-learn.org/dev/modules/generated/sklearn.ensemble.StackingClassifier.html#sklearn.ensemble.StackingClassifier) avec un design un peu diff\u00e9rent"]}, {"cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [{"name": "stderr", "output_type": "stream", "text": ["C:\\xavierdupre\\__home_\\github_fork\\scikit-learn\\sklearn\\model_selection\\_split.py:667: UserWarning: The least populated class in y has only 2 members, which is less than n_splits=5.\n", " % (min_groups, self.n_splits)), UserWarning)\n", "C:\\xavierdupre\\__home_\\github_fork\\scikit-learn\\sklearn\\model_selection\\_split.py:667: UserWarning: The least populated class in y has only 2 members, which is less than n_splits=5.\n", " % (min_groups, self.n_splits)), UserWarning)\n", "C:\\xavierdupre\\__home_\\github_fork\\scikit-learn\\sklearn\\linear_model\\_logistic.py:939: ConvergenceWarning: lbfgs failed to converge (status=1):\n", "STOP: TOTAL NO. 
of ITERATIONS REACHED LIMIT.\n", "\n", "Increase the number of iterations (max_iter) or scale the data as shown in:\n", " https://scikit-learn.org/stable/modules/preprocessing.html.\n", "Please also refer to the documentation for alternative solver options:\n", " https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n", " extra_warning_msg=_LOGISTIC_SOLVER_CONVERGENCE_MSG)\n"]}, {"data": {"text/plain": ["StackingClassifier(cv=None,\n", " estimators=[('ovrlr',\n", " OneVsRestClassifier(estimator=LogisticRegression(C=1.0,\n", " class_weight=None,\n", " dual=False,\n", " fit_intercept=True,\n", " intercept_scaling=1,\n", " l1_ratio=None,\n", " max_iter=1500,\n", " multi_class='auto',\n", " n_jobs=None,\n", " penalty='l2',\n", " random_state=None,\n", " solver='lbfgs',\n", " tol=0.0001,\n", " verbose=0,\n", " warm_start=False),\n", " n_jobs=None)),\n", " ('rf',\n", " RandomForestCla...\n", " max_depth=None,\n", " max_features='auto',\n", " max_leaf_nodes=None,\n", " max_samples=None,\n", " min_impurity_decrease=0.0,\n", " min_impurity_split=None,\n", " min_samples_leaf=1,\n", " min_samples_split=2,\n", " min_weight_fraction_leaf=0.0,\n", " n_estimators=100,\n", " n_jobs=None,\n", " oob_score=False,\n", " random_state=None,\n", " verbose=0,\n", " warm_start=False))],\n", " final_estimator=None, n_jobs=None, passthrough=False,\n", " stack_method='auto', verbose=0)"]}, "execution_count": 27, "metadata": {}, "output_type": "execute_result"}], "source": ["try:\n", " from sklearn.ensemble import StackingClassifier\n", " skl = True\n", "except ImportError:\n", " # scikit-learn pas assez r\u00e9cent\n", " skl = False\n", "if skl:\n", " model = StackingClassifier([\n", " ('ovrlr', OneVsRestClassifier(LogisticRegression(max_iter=1500))),\n", " ('rf', RandomForestClassifier())\n", " ])\n", " model.fit(X_train, y_train)\n", "else:\n", " model = None\n", "model"]}, {"cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [{"data": 
{"text/plain": ["0.7301920159931777"]}, "execution_count": 28, "metadata": {}, "output_type": "execute_result"}], "source": ["if model is not None:\n", " fpr_pipe, tpr_pipe, th_pipe = roc_curve(y_test == model.predict(X_test),\n", " model.predict_proba(X_test).max(axis=1),\n", " drop_intermediate=False)\n", " auc_pipe = roc_auc_score(y_test == model.predict(X_test),\n", " model.predict_proba(X_test).max(axis=1))\n", "else:\n", " auc_pipe = None\n", "auc_pipe"]}, {"cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": []}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.2"}}, "nbformat": 4, "nbformat_minor": 2}