{"cells": [{"cell_type": "markdown", "metadata": {}, "source": ["# Quantile MLPRegressor\n", "\n", "[scikit-learn](http://scikit-learn.org/stable/) does not implement quantile regression for its multi-layer perceptron. [mlinsights](http://www.xavierdupre.fr/app/mlinsights/helpsphinx/index.html) provides one, ``QuantileMLPRegressor``, built on top of the *scikit-learn* model: the implementation overrides the method ``_backprop``."]}, {"cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": ["import warnings\n", "warnings.simplefilter(\"ignore\")"]}, {"cell_type": "markdown", "metadata": {}, "source": ["We generate some dummy data: a linear relation $Y = 3.4 X + 5.6$ with small uniform noise on 900 points and large positive noise (up to 10) on the remaining 100 points, which act as outliers."]}, {"cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": ["import numpy\n", "X = numpy.random.random(1000)\n", "# 900 points with small noise in [-0.05, 0.05)\n", "eps1 = (numpy.random.random(900) - 0.5) * 0.1\n", "# 100 outliers with large positive noise in [0, 10)\n", "eps2 = numpy.random.random(100) * 10\n", "eps = numpy.hstack([eps1, eps2])\n", "# scikit-learn expects a 2D feature matrix\n", "X = X.reshape((1000, 1))\n", "Y = X.ravel() * 3.4 + 5.6 + eps"]}, {"cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [{"data": {"text/plain": ["MLPRegressor(activation='tanh', hidden_layer_sizes=(30,))"]}, "execution_count": 4, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.neural_network import MLPRegressor\n", "clr = MLPRegressor(hidden_layer_sizes=(30,), activation='tanh')\n", "clr.fit(X, Y)"]}, {"cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [{"data": {"text/plain": ["QuantileMLPRegressor(activation='tanh', hidden_layer_sizes=(30,))"]}, "execution_count": 5, "metadata": {}, "output_type": "execute_result"}], "source": ["from mlinsights.mlmodel import QuantileMLPRegressor\n", "clq = QuantileMLPRegressor(hidden_layer_sizes=(30,), activation='tanh')\n", "clq.fit(X, Y)"]}, {"cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [{"data": {"text/html": ["
|  | X | Y | clr | clq |
|---|---|---|---|---|
| 0 | 0.251734 | 6.470634 | 7.059780 | 6.481283 |
| 1 | 0.538065 | 7.423694 | 8.029974 | 7.510084 |
| 2 | 0.530510 | 7.411181 | 8.006414 | 7.485186 |
| 3 | 0.048348 | 5.808051 | 6.278572 | 5.646920 |
| 4 | 0.882162 | 8.624456 | 8.986741 | 8.519049 |