{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# 2A.ml - Decision Trees / Random Forest\n", "\n", "Classification, regression, and visualization with ensemble methods (trees, forests, ...)."]}, {"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": ["import matplotlib.pyplot as plt"]}, {"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 4, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["### Problem description\n", "\n", "The code below downloads the required data: [salaries2010.zip](http://www.xavierdupre.fr/enseignement/complements/salaries2010.zip)."]}, {"cell_type": "markdown", "metadata": {}, "source": ["Machine learning can be summarized as building a prediction function $Y=f(X) + \\epsilon$. $f$ is most often obtained by minimizing the error $\\sum_i E(Y_i,f(X_i))$, where $(X_i,Y_i)$ is a list of (features, target) pairs. [Decision trees](http://fr.wikipedia.org/wiki/Arbre_de_d%C3%A9cision) are fairly easy to train, and they have the advantage of accepting both continuous and discrete [features](http://en.wikipedia.org/wiki/Feature_%28machine_learning%29). For this exercise, we reuse the employee database seen in a previous notebook and try to predict the salary from more variables than just age or sex:"]}, {"cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", "
"], "text/plain": [" TRNETTOT AGE SEXE DEPT DEPR TYP_EMPLOI PCS CS CONT_TRAV CONV_COLL \\\n", "0 14 50.0 1 972 O 628G 62 ZZZ \n", "1 14 41.0 1 75 75 O 354C 35 CDD 1734 \n", "2 14 29.0 1 75 75 O 373C 37 CDD 0014 \n", "3 14 30.0 1 75 75 O 651A 65 CDD 9999 \n", "4 14 55.0 1 78 92 O 623E 62 ZZZ \n", "\n", " VARIABLE MODALITE MODLIBELLE montant \n", "0 TRNETTOT 14 18 000 \u00e0 19 999 euros 18999.5 \n", "1 TRNETTOT 14 18 000 \u00e0 19 999 euros 18999.5 \n", "2 TRNETTOT 14 18 000 \u00e0 19 999 euros 18999.5 \n", "3 TRNETTOT 14 18 000 \u00e0 19 999 euros 18999.5 \n", "4 TRNETTOT 14 18 000 \u00e0 19 999 euros 18999.5 "]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["import os\n", "if not os.path.exists(\"salaries2010.db3\"):\n", "    import pyensae.datasource\n", "    db3 = pyensae.datasource.download_data(\"salaries2010.zip\")\n", "\n", "import sqlite3, pandas\n", "con = sqlite3.connect(\"salaries2010.db3\")\n", "df = pandas.io.sql.read_sql(\"select * from varmod\", con)\n", "con.close()\n", "\n", "values = df[df.VARIABLE == \"TRNETTOT\"].copy()\n", "\n", "def process_intervalle(s):\n", "    # Convert a salary bracket label into one amount: the midpoint of 'X \u00e0 Y',\n", "    # half the bound for 'Moins de X', the bound itself for 'X euros et plus'.\n", "    if \"euros et plus\" in s:\n", "        return float(s.replace(\"euros et plus\", \"\").replace(\" \", \"\"))\n", "    spl = s.split(\"\u00e0\")\n", "    if len(spl) == 2:\n", "        s1 = spl[0].replace(\"Moins de\", \"\").replace(\"euros\", \"\").replace(\" \", \"\")\n", "        s2 = spl[1].replace(\"Moins de\", \"\").replace(\"euros\", \"\").replace(\" \", \"\")\n", "        return (float(s1) + float(s2)) / 2\n", "    else:\n", "        s = spl[0].replace(\"Moins de\", \"\").replace(\"euros\", \"\").replace(\" \", \"\")\n", "        return float(s) / 2\n", "\n", "values[\"montant\"] = values.apply(lambda r: process_intervalle(r[\"MODLIBELLE\"]), axis=1)\n", "\n", "con = sqlite3.connect(\"salaries2010.db3\")\n", "data = pandas.io.sql.read_sql(\"select TRNETTOT,AGE,SEXE,DEPT,DEPR,TYP_EMPLOI,PCS,CS,CONT_TRAV,CONV_COLL from salaries\", con)\n", "con.close()\n", "\n", "salaires = data.merge(values, 
left_on=\"TRNETTOT\", right_on=\"MODALITE\")\n", "salaires.dropna(inplace=True)\n", "salaires.head()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The scikit-learn module does not accept features given as character strings:\n", "* [Encoding categorical features](http://scikit-learn.org/stable/modules/feature_extraction.html#dict-feature-extraction)\n", "* [Loading features from dicts](http://scikit-learn.org/stable/modules/feature_extraction.html#loading-features-from-dicts)\n", "* [Vectorizing a Pandas dataframe for Scikit-Learn](http://stackoverflow.com/questions/20024584/vectorizing-a-pandas-dataframe-for-scikit-learn)\n", "\n", "Variables that are not numerical (and not ordered) must be converted into boolean variables (we first do this on a sample):"]}, {"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": ["import random\n", "# Keep roughly one row out of 51 as a random sample.\n", "salaires[\"rnd\"] = salaires.apply(lambda r: random.randint(0, 50), axis=1)\n", "ech = salaires[salaires.rnd == 0]"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The sample size must be adjusted to the computer's memory, and it is also preferable to start with a small sample: developing the model takes less time. 
We increase the sample size once everything works (a lot of time is often lost because a variable does not have the expected type, a name is misspelled, a value is missing...)."]}, {"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [{"data": {"text/plain": ["(43111, 4)"]}, "execution_count": 7, "metadata": {}, "output_type": "execute_result"}], "source": ["X,Y = ech[[\"AGE\",\"SEXE\",\"TYP_EMPLOI\",\"CONT_TRAV\"]], ech[[\"montant\"]]\n", "Xd = X.T.to_dict().values()\n", "X.shape"]}, {"cell_type": "markdown", "metadata": {}, "source": ["We convert the string variables into binary variables:"]}, {"cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": ["from sklearn.feature_extraction import DictVectorizer\n", "prep = DictVectorizer()\n", "Xt = prep.fit_transform(Xd).toarray()"]}, {"cell_type": "markdown", "metadata": {}, "source": ["``Xt`` is a [numpy.ndarray](http://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.html), but the variable ``prep`` has kept the feature names."]}, {"cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [{"data": {"text/plain": ["['AGE',\n", " 'CONT_TRAV=APP',\n", " 'CONT_TRAV=AUT',\n", " 'CONT_TRAV=CDD',\n", " 'CONT_TRAV=CDI',\n", " 'CONT_TRAV=TTP',\n", " 'CONT_TRAV=ZZZ',\n", " 'SEXE=',\n", " 'SEXE=1',\n", " 'SEXE=2',\n", " 'TYP_EMPLOI=A',\n", " 'TYP_EMPLOI=O',\n", " 'TYP_EMPLOI=X']"]}, "execution_count": 9, "metadata": {}, "output_type": "execute_result"}], "source": ["prep.feature_names_"]}, {"cell_type": "markdown", "metadata": {}, "source": ["**Note:** A categorical variable is converted into a series of boolean variables, but when the categories are mutually exclusive, each observation necessarily falls into exactly one of them. The sum of the resulting boolean variables is equal to 1. 
This amounts to creating a set of variables whose sum is constant, hence collinear with the intercept: this must be avoided in a linear model such as regression, where one of the variables has to be removed. Since we fit a decision tree next, this is not essential."]}, {"cell_type": "markdown", "metadata": {}, "source": ["We train the tree and limit its depth to 3 so that the resulting tree can be visualized. This is certainly not enough: the tree has at most $2^3=8$ leaves, which is fewer than the number of possible salary brackets."]}, {"cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [{"data": {"text/plain": ["0.21069611975253544"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.tree import DecisionTreeRegressor\n", "clf = DecisionTreeRegressor(min_samples_leaf=10, max_depth=3)\n", "clf = clf.fit(Xt,Y)\n", "clf.score(Xt,Y)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["We draw the decision tree (and it gets a bit complex):"]}, {"cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": ["from sklearn.tree import export_graphviz\n", "export_graphviz(clf, out_file=\"arbre.dot\") "]}, {"cell_type": "markdown", "metadata": {}, "source": ["To visualize the tree, install [graphviz](http://www.graphviz.org/) and run the command below (you will probably need to replace the path with that of your own Graphviz installation)."]}, {"cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": ["import sys\n", "cwd = os.getcwd()\n", "if sys.platform.startswith(\"win\"):\n", "    exe = 'C:\\\\Program Files (x86)\\\\Graphviz2.38\\\\bin\\\\dot.exe'\n", "    if not os.path.exists(exe):\n", "        raise FileNotFoundError(exe)\n", "    exe = '\"{0}\"'.format(exe)\n", "else:\n", "    exe = \"dot\"\n", "# os.path.join builds the file paths with the separator of the current platform\n", "cmd = '\"{0}\" -Tpng {1} -o {2}'.format(\n", "    exe, os.path.join(cwd, \"arbre.dot\"), os.path.join(cwd, \"arbre.png\"))"]}, {"cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [{"data": {"text/plain": ["0"]}, "execution_count": 13, "metadata": {}, "output_type": "execute_result"}], "source": ["os.system(cmd)"]}, {"cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [{"data": {"image/png": "\n", "text/plain": [""]}, "execution_count": 14, "metadata": {}, "output_type": "execute_result"}], "source": ["from IPython.display import Image\n", "Image(\"arbre.png\")"]}, {"cell_type": "markdown", "metadata": {}, "source": ["In a notebook, JavaScript can also be used to draw the graph (see [Visualiser un arbre de d\u00e9cision](http://www.xavierdupre.fr/app/papierstat/helpsphinx/notebooks/decision_tree_visualization.html))."]}, {"cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 15, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import RenderJsVis\n", "dot = export_graphviz(clf, out_file=None, feature_names=prep.feature_names_)\n", "RenderJsVis(dot=dot, height=\"400px\", layout='hierarchical')"]}, {"cell_type": "markdown", "metadata": {}, "source": ["Oops... I forgot to split the data into a training set and a test set. After that, all that remains is to plot the ROC curve: [Receiver operating characteristic (ROC)](http://scikit-learn.org/0.11/auto_examples/plot_roc.html)."]}, {"cell_type": "markdown", "metadata": {}, "source": ["### Exercise 1: Training set, test set, curves\n", "\n", "Your turn. A few ideas:\n", "\n", "* [train_test_split](http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)\n", "* [random forest](http://blog.yhathq.com/posts/random-forests-in-python.html)"]}, {"cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": []}, {"cell_type": "markdown", "metadata": {}, "source": ["### Exercise 2: ROC curves\n", " \n", "We turn the problem around and try to predict sex from the other variables, including the salary.\n", "\n", "* [RandomForestClassifier](http://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)"]}, {"cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": []}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.4"}}, "nbformat": 4, "nbformat_minor": 2}