{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Visualizing a decision tree\n", "\n", "Decision trees are appealing models because they can be interpreted. That still requires being able to see them."]}, {"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": ["from sklearn import datasets\n", "iris = datasets.load_iris()\n", "X = iris.data[:, :2] # we only take the first two features.\n", "y = iris.target"]}, {"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [{"data": {"text/plain": ["DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort=False, random_state=None,\n", " splitter='best')"]}, "execution_count": 3, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.tree import DecisionTreeClassifier\n", "clf = DecisionTreeClassifier()\n", "clf.fit(X, y)"]}, {"cell_type": "markdown", "metadata": {}, "source": ["[scikit-learn](http://scikit-learn.org/stable/) implements a function that exports a fitted tree as a graph in [DOT](https://en.wikipedia.org/wiki/DOT_(graph_description_language)) format: [export_graphviz](http://scikit-learn.org/stable/modules/generated/sklearn.tree.export_graphviz.html). 
This graph can be rendered with the [graphviz](https://www.graphviz.org/) tool or with modules such as [pydot](https://github.com/erocarrera/pydot), but either way [graphviz](https://www.graphviz.org/) must be installed."]}, {"cell_type": "code", "execution_count": 3, "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["digraph Tree {\n", "node [shape=box] ;\n", "0 [label=\"X[0] <= 5.45\\ngini = 0.667\\nsamples = 150\\nvalue = [50, 50, 50]\"] ;\n", "1 [label=\"X[1] <= 2.8\\ngini = 0.237\\nsamples = 52\\nvalue = [45, 6, 1]\"] ;\n", "0 -> 1 [labeldistance=2.5, labelangle=45, headlabel=\"True\"] ;\n", "2 [label=\"X[0] <= 4.7\\ngini = 0.449\\nsamples = 7\\nvalue = [1, 5, 1]\"] ;\n", "1 -> 2 ;\n", "3 [label=\"gini = 0.0\\nsamples = 1\\nvalue = [1, 0, 0]\"] ;\n", "2 -> 3 ;\n", "4 [label=\"X[0] <= 4.95\\ngini = 0.278\\nsamples = 6\\nvalue = [0, 5, 1]\"] ;\n", "...\n"]}], "source": ["from sklearn.tree import export_graphviz\n", "dot = export_graphviz(clf, out_file=None)\n", "print(\"\\n\".join(dot.split('\\n')[:10]) + \"\\n...\")"]}, {"cell_type": "markdown", "metadata": {}, "source": ["The [viz.js](https://github.com/mdaines/viz.js/) library is a JavaScript version of [graphviz](https://www.graphviz.org/). With the available wrapper [RenderJsDot](http://www.xavierdupre.fr/app/jyquickhelper/helpsphinx/jyquickhelper/jspy/render_nb_js_dot.html?highlight=renderjsdot#jyquickhelper.jspy.render_nb_js_dot.RenderJsDot), this becomes:"]}, {"cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [{"data": {"text/html": ["