{"cells": [{"cell_type": "markdown", "id": "15794c25", "metadata": {}, "source": ["# onnxruntime-training, scikit-learn\n", "\n", "Simple examples mixing packages. The notebook takes a neural network from [scikit-learn](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPRegressor.html) (regression), converts it into [ONNX](https://onnx.ai/) and trains it with [onnxruntime-training](https://github.com/microsoft/onnxruntime-training-examples)."]}, {"cell_type": "code", "execution_count": 1, "id": "b1d6834c", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 2, "metadata": {}, "output_type": "execute_result"}], "source": ["from jyquickhelper import add_notebook_menu\n", "add_notebook_menu()"]}, {"cell_type": "code", "execution_count": 2, "id": "9c7d85a7", "metadata": {}, "outputs": [], "source": ["%matplotlib inline"]}, {"cell_type": "code", "execution_count": 3, "id": "f3772d83", "metadata": {}, "outputs": [], "source": ["%load_ext mlprodict"]}, {"cell_type": "code", "execution_count": 4, "id": "95b5be9a", "metadata": {}, "outputs": [], "source": ["import warnings\n", "from time import perf_counter\n", "warnings.filterwarnings(\"ignore\")"]}, {"cell_type": "markdown", "id": "19c76814", "metadata": {}, "source": ["## Data and first model"]}, {"cell_type": "code", "execution_count": 5, "id": "c9059636", "metadata": {}, "outputs": [], "source": ["from sklearn.datasets import load_diabetes\n", "from sklearn.model_selection import train_test_split\n", "\n", "data = load_diabetes()\n", "X, y = data.data, data.target\n", "y /= 100\n", "X_train, X_test, y_train, y_test = train_test_split(X, y)"]}, {"cell_type": "code", "execution_count": 6, "id": "4a6af2a3", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["training time: 0.9272374000000028\n"]}], "source": ["from sklearn.neural_network import MLPRegressor\n", "\n", "nn = MLPRegressor(hidden_layer_sizes=(20,), max_iter=400)\n", "begin = perf_counter()\n", "nn.fit(X_train, y_train)\n", "print(\"training time: %r\" % (perf_counter() - begin))"]}, {"cell_type": "code", "execution_count": 7, "id": "305af27c", "metadata": {}, "outputs": [{"data": {"text/plain": ["0.4844054606180469"]}, "execution_count": 8, "metadata": {}, "output_type": "execute_result"}], "source": ["from sklearn.metrics import r2_score\n", "r2_score(y_test, nn.predict(X_test))"]}, {"cell_type": "markdown", "id": "9ad32de8", "metadata": {}, "source": ["## Conversion to ONNX\n", "\n", "With 
[skl2onnx](https://github.com/onnx/sklearn-onnx)."]}, {"cell_type": "code", "execution_count": 8, "id": "375d3a0d", "metadata": {}, "outputs": [{"data": {"text/html": ["
\n", ""], "text/plain": [""]}, "execution_count": 9, "metadata": {}, "output_type": "execute_result"}], "source": ["import numpy\n", "from skl2onnx import to_onnx\n", "\n", "nn = MLPRegressor(hidden_layer_sizes=(20,), max_iter=1).fit(X_train, y_train)\n", "nn_onnx = to_onnx(nn, X_train[1:].astype(numpy.float32))\n", "\n", "%onnxview nn_onnx"]}, {"cell_type": "markdown", "id": "95dfff89", "metadata": {}, "source": ["## Training with pytorch + ONNX\n", "\n", "We could use onnxruntime-training alone (see [Train a linear regression with onnxruntime-training](http://www.xavierdupre.fr/app/onnxcustom/helpsphinx/gyexamples/plot_orttraining_linear_regression.html)), but instead we extend [pytorch](https://pytorch.org/) with a custom function defined by an ONNX graph: the one obtained by converting the scikit-learn neural network into ONNX. First, let's get the list of parameters of the model."]}, {"cell_type": "code", "execution_count": 9, "id": "6f261cbf", "metadata": {}, "outputs": [{"data": {"text/plain": ["['coefficient', 'intercepts', 'coefficient1', 'intercepts1']"]}, "execution_count": 10, "metadata": {}, "output_type": "execute_result"}], "source": ["from onnx.numpy_helper import to_array\n", "\n", "weights = [(init.name, to_array(init))\n", "           for init in nn_onnx.graph.initializer\n", "           if 'shape' not in init.name]\n", "[w[0] for w in weights]"]}, {"cell_type": "markdown", "id": "43625d28", "metadata": {}, "source": ["Class [TorchOrtFactory](http://www.xavierdupre.fr/app/deeponnxcustom/helpsphinx/deeponnxcustom/onnxtorch/torchort.html#deeponnxcustom.onnxtorch.torchort.TorchOrtFactory) builds a torch function from the ONNX graph and the names of the weights to learn."]}, {"cell_type": "code", "execution_count": 10, "id": "9d1e2863", "metadata": {}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["No CUDA runtime is found, using CUDA_HOME='C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v11.4'\n", "List of weights to train must be sorted 
but is not in ['coefficient', 'intercepts', 'coefficient1', 'intercepts1']. You shoud use function onnx_rename_weights to do that before calling this class.\n"]}], "source": ["from deeponnxcustom.onnxtorch import TorchOrtFactory\n", "\n", "try:\n", "    fact = TorchOrtFactory(nn_onnx, [w[0] for w in weights])\n", "except ValueError as e:\n", "    print(e)"]}, {"cell_type": "markdown", "id": "491b8432", "metadata": {}, "source": ["The constructor fails because the weights need to be listed in alphabetical order. The function `onnx_rename_weights` renames them in place, adding ordered prefixes (`I0_`, `I1_`, ...) so that the original order becomes alphabetical."]}, {"cell_type": "code", "execution_count": 11, "id": "cd11eba0", "metadata": {}, "outputs": [{"data": {"text/plain": ["['I0_coefficient', 'I1_intercepts', 'I2_coefficient1', 'I3_intercepts1']"]}, "execution_count": 12, "metadata": {}, "output_type": "execute_result"}], "source": ["from deeponnxcustom.tools.onnx_helper import onnx_rename_weights\n", "\n", "onnx_rename_weights(nn_onnx)\n", "weights = [(init.name, to_array(init))\n", "           for init in nn_onnx.graph.initializer\n", "           if 'shape' not in init.name]\n", "[w[0] for w in weights]"]}, {"cell_type": "markdown", "id": "ba778d93", "metadata": {}, "source": ["With the renamed weights, we create the factory again."]}, {"cell_type": "code", "execution_count": 12, "id": "3801ace8", "metadata": {"scrolled": false}, "outputs": [], "source": ["fact = TorchOrtFactory(nn_onnx, [w[0] for w in weights])"]}, {"cell_type": "markdown", "id": "de16a998", "metadata": {}, "source": ["Let's create the torch function. The factory returns a new class; the following cells inspect its base classes."]}, {"cell_type": "code", "execution_count": 13, "id": "22ba8477", "metadata": {}, "outputs": [{"data": {"text/plain": ["deeponnxcustom.onnxtorch.torchort.TorchOrtFunction_2140275442256"]}, "execution_count": 14, "metadata": {}, "output_type": "execute_result"}], "source": ["cls = fact.create_class()\n", "cls"]}, {"cell_type": "code", "execution_count": 14, "id": "427f65f6", "metadata": {}, "outputs": [{"data": {"text/plain": ["(deeponnxcustom.onnxtorch.torchort.TorchOrtFunction,)"]}, "execution_count": 15, "metadata": {}, 
"output_type": "execute_result"}], "source": ["cls.__bases__"]}, {"cell_type": "code", "execution_count": 15, "id": "70574953", "metadata": {}, "outputs": [{"data": {"text/plain": ["(torch.autograd.function.Function,)"]}, "execution_count": 16, "metadata": {}, "output_type": "execute_result"}], "source": ["cls.__bases__[0].__bases__"]}, {"cell_type": "markdown", "id": "3b473d24", "metadata": {}, "source": ["Let's train it."]}, {"cell_type": "code", "execution_count": 16, "id": "7dd0e117", "metadata": {"scrolled": false}, "outputs": [{"name": "stdout", "output_type": "stream", "text": ["device: cpu\n"]}, {"name": "stderr", "output_type": "stream", "text": ["100%|\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588\u2588| 400/400 [00:00<00:00, 428.75it/s]\n"]}, {"name": "stdout", "output_type": "stream", "text": ["training time: 0.9546589000000054\n"]}], "source": ["from tqdm import tqdm\n", "import torch\n", "\n", "\n", "def from_numpy(v, device=None, requires_grad=False):\n", " v = torch.from_numpy(v)\n", " if device is not None:\n", " v = v.to(device)\n", " v.requires_grad_(requires_grad)\n", " return v\n", "\n", "\n", "def train_cls(cls, device, X_train, y_train, weights, n_iter=20, learning_rate=1e-3):\n", " x = from_numpy(X_train.astype(numpy.float32), \n", " requires_grad=True, device=device)\n", " y = from_numpy(y_train.astype(numpy.float32),\n", " requires_grad=True, device=device)\n", " fact = torch.tensor([x.shape[0]], dtype=torch.float32).to(device)\n", " fact.requires_grad_(True)\n", "\n", " weights_tch = [(w[0], from_numpy(w[1], requires_grad=True, device=device))\n", " for w in weights]\n", " weights_values = [w[1] for w in weights_tch]\n", "\n", " all_losses = []\n", " for t in tqdm(range(n_iter)):\n", " # forward - backward\n", " y_pred = cls.apply(x, *weights_values)\n", " loss = (y_pred - y).pow(2).sum() / fact\n", " loss.backward()\n", "\n", " # update weights\n", " with torch.no_grad():\n", " for w in weights_values:\n", " w -= w.grad * 
learning_rate\n", " w.grad.zero_()\n", "\n", " all_losses.append((t, float(loss.detach().numpy())))\n", " return all_losses, weights_tch\n", "\n", "\n", "device_name = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", "device = torch.device(device_name)\n", "print(\"device:\", device)\n", "\n", "begin = perf_counter()\n", "train_losses, final_weights = train_cls(cls, device, X_train, y_test, weights, n_iter=400)\n", "print(\"training time: %r\" % (perf_counter() - begin))"]}, {"cell_type": "code", "execution_count": 17, "id": "0e41f8ac", "metadata": {}, "outputs": [{"data": {"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYAAAAEWCAYAAABv+EDhAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/MnkTPAAAACXBIWXMAAAsTAAALEwEAmpwYAAAt4UlEQVR4nO3deXxU9bnH8c+TjUACSQghbIGEHVzYwi7WiqLiglu9tdWCVXHXtrZe673Xpe3tfm21AmpRa61aBXdF3BcERNYgCLIokrCGJewQQp77xxwwYgIDJJlJ5vt+vfJizu+cM+eZQzLfOb/fOWfM3RERkdgTF+kCREQkMhQAIiIxSgEgIhKjFAAiIjFKASAiEqMUACIiMUoBIDHNzF43s5HVvewR1nCKmRVV9/OKHE5CpAsQOVJmtr3CZCNgD7AvmL7G3Z8M97nc/ayaWFakLlAASJ3j7qn7H5vZCuAqd3/74OXMLMHdy2qzNpG6RF1AUm/s70oxs/80s7XAY2aWYWavmlmxmW0OHrepsM77ZnZV8HiUmX1kZn8Olv3SzM46ymXzzOxDM9tmZm+b2Rgz+1eYr6NbsK0SM1toZudVmDfczD4LnneVmf08aG8WvLYSM9tkZlPMTH/fckj6BZH6pgXQFGgHjCb0O/5YMN0W2AU8cIj1+wOfA82APwKPmJkdxbJPAZ8AmcDdwOXhFG9micArwJtAc+Am4Ekz6xIs8gihbq7GwPHAu0H7rUARkAVkA3cAus+LHJICQOqbcuAud9/j7rvcfaO7P+fuO919G/C/wHcOsf5X7v53d98HPA60JPSGGvayZtYW6Avc6e6l7v4R8HKY9Q8AUoHfB+u+C7wKXBrM3wt0N7Mm7r7Z3edUaG8JtHP3ve4+xXWjLzkMBYDUN8Xuvnv/hJk1MrOHzOwrM9sKfAikm1l8Feuv3f/A3XcGD1OPcNlWwKYKbQCFYdbfCih09/IKbV8BrYPHFwHDga/M7AMzGxi0/wlYBrxpZl+Y2e1hbk9imAJA6puDP/XeCnQB+rt7E+DkoL2qbp3qsAZoamaNKrTlhLnuaiDnoP77tsAqAHef6e4jCHUPvQg8G7Rvc/db3b09cB7wMzMbemwvQ+o7BYDUd40J9fuXmFlT4K6a3qC7fwXMAu42s6TgU/q5Ya4+A9gJ3GZmiWZ2SrDuv4Pn+qGZpbn7XmAroS4vzOwcM+sYjEFsIXRabHmlWxAJKACkvvsr0BDYAHwMTK6l7f4QGAhsBH4DPEPoeoVDcvdSQm/4ZxGqeSzwI3dfHCxyObAi6M66NtgOQCfgbWA7MB0Y6+7vVdurkXrJNE4kUvPM7BlgsbvX+BGISLh0BCBSA8ysr5l1MLM4MzsTGEGoz14kauhKYJGa0QJ4ntB1AEXAd
e4+N7IliXyTuoBERGKUuoBERGJUneoCatasmefm5ka6DBGROmX27Nkb3D3r4PY6FQC5ubnMmjUr0mWIiNQpZvZVZe3qAhIRiVEKABGRGKUAEBGJUXVqDEBE6p+9e/dSVFTE7t27D7+wHFJycjJt2rQhMTExrOUVACISUUVFRTRu3Jjc3Fyq/u4dORx3Z+PGjRQVFZGXlxfWOuoCEpGI2r17N5mZmXrzP0ZmRmZm5hEdSSkARCTi9OZfPY50P8ZEALwwt4h/fVzpabAiIjErJgLglYI1PDMz3G/kExGJDTERAInxxt59+nIkEfm2kpISxo4de8TrDR8+nJKSkiNeb9SoUUycOPGI16sJMREACfFxlCoARKQSVQVAWVnZIdebNGkS6enpNVRV7YiJ00CT4uMo26fbXotEu3teWchnq7dW63N2b9WEu849rsr5t99+O8uXL6dnz54kJiaSnJxMRkYGixcvZsmSJZx//vkUFhaye/dubrnlFkaPHg18fW+y7du3c9ZZZ3HSSScxbdo0WrduzUsvvUTDhg0PW9s777zDz3/+c8rKyujbty/jxo2jQYMG3H777bz88sskJCQwbNgw/vznPzNhwgTuuece4uPjSUtL48MPPzzmfRMTAZAQpy4gEanc73//exYsWMC8efN4//33Ofvss1mwYMGBc+kfffRRmjZtyq5du+jbty8XXXQRmZmZ33iOpUuX8vTTT/P3v/+dSy65hOeee47LLrvskNvdvXs3o0aN4p133qFz58786Ec/Yty4cVx++eW88MILLF68GDM70M30q1/9ijfeeIPWrVsfVddTZWIiABIT4hQAInXAoT6p15Z+/fp940Kq+++/nxdeeAGAwsJCli5d+q0AyMvLo2fPngD06dOHFStWHHY7n3/+OXl5eXTu3BmAkSNHMmbMGG688UaSk5O58sorOeecczjnnHMAGDx4MKNGjeKSSy7hwgsvrIZXGiNjAEnxcexVF5CIhCElJeXA4/fff5+3336b6dOnU1BQQK9evSq90KpBgwYHHsfHxx92/OBQEhIS+OSTT7j44ot59dVXOfPMMwF48MEH+c1vfkNhYSF9+vRh48aNR72NA9s65meoA9QFJCJVady4Mdu2bat03pYtW8jIyKBRo0YsXryYjz/+uNq226VLF1asWMGyZcvo2LEjTzzxBN/5znfYvn07O3fuZPjw4QwePJj27dsDsHz5cvr370///v15/fXXKSws/NaRyJGKiQBQF5CIVCUzM5PBgwdz/PHH07BhQ7Kzsw/MO/PMM3nwwQfp1q0bXbp0YcCAAdW23eTkZB577DG+973vHRgEvvbaa9m0aRMjRoxg9+7duDv33nsvAL/4xS9YunQp7s7QoUPp0aPHMddQp74UPj8/34/mG8HufWsJ97+zlC9/N1yXnItEmUWLFtGtW7dIl1FvVLY/zWy2u+cfvGxMjAEkxoXe9MvK607YiYjUtJjpAgLYu6+cxPiYyDwRibAbbriBqVOnfqPtlltu4YorrohQRd8WGwEQvz8AdAQgEo3cvd51z44ZM6bWt3mkXfox8XE4MT70i6WBYJHok5yczMaNG4/4zUu+af8XwiQnJ4e9TowdASgARKJNmzZtKCoqori4ONKl1Hn7vxIyXDEVALofkEj0SUxMDPsrDKV6hdUFZGbpZjbRzBab2SIzG1hh3q1m5mbWrIp1J5tZiZm9WsX8+81s+9GVH579XUC6I6iIyNfCPQK4D5js7hebWRLQCMDMcoBhwMpDrPunYPlrDp5hZvlAxhFVfBTUBSQi8m2HPQIwszTgZOARAHcvdfeSYPZfgNuAKvtW3P0d4FvXWZtZPKFwuO2Iqz5CCfuvA1AXkIjIAeF0AeUBxcBjZjbXzMabWYqZjQBWuXvBUW77RuBld19zlOuHbf91AOoCEhH5WjhdQAlAb+Amd59hZvcBdxM6Khh2NBs1s1bA94BTwlh2NDAaoG3btkezOZI0CCwi8i3hH
AEUAUXuPiOYnkgoEPKAAjNbAbQB5phZizC32wvoCCwL1m9kZssqW9DdH3b3fHfPz8rKCvPpv2l/F5DGAEREvnbYIwB3X2tmhWbWxd0/B4YCc9x96P5lgjfxfHffEM5G3f014EBYmNl2d+94xNWHSV1AIiLfFu6VwDcBT5rZfKAn8NuqFjSzfDMbX2F6CjABGGpmRWZ2xjHUe1TUBSQi8m1hnQbq7vOAb91KtML83AqPZwFXVZgeEsbzp4ZTx9FK0K0gRES+JUbuBaTrAEREDhYTAZCku4GKiHxLTASAuoBERL4tJgJAXUAiIt8WYwGgLiARkf1iJADUBSQicrAYCYDgCKBMASAisl9MBMCBW0GUqwtIRGS/mAgAMyMx3tQFJCJSQUwEAIS6gdQFJCLytZgJgEZJCWzfUxbpMkREokbMBEBO04as3LQz0mWIiESNmAmAtk0bKQBERCqIqQBYXbKLUo0DiIgAMRYA5Q6rS3ZFuhQRkagQUwEA8OXGHRGuREQkOsRMAHRr1YQmyQk8+tGXuOuCMBGRmAmAJsmJ/Oz0zkxZuoGb/z2PXaX7Il2SiEhEhfWVkPXFyEG57Ny7jz+98TnL1m9nzA960T6rRr+NUkQkasXMEQCEbglx/SkdeWRkPmu27OKcv33Ei3NXRbosEZGIiKkA2O/Urtm8fssQjmvVhJ88M4/bJhaws1RXCYtIbInJAABomdaQp68ewI3f7ciE2UWMeGAqS9Zti3RZIiK1JmYDACAhPo6fn9GFf/64H5t3lnLeAx/xzMyVOktIRGJCTAfAfkM6ZTHpliH0bpvBfz73KT95Zp5uHCci9Z4CINC8cTJPXNmfW0/vzCsFqznn/iksWLUl0mWJiNSYsALAzNLNbKKZLTazRWY2sMK8W83MzaxZFetONrMSM3v1oPYnzexzM1tgZo+aWeKxvZRjFx9n3DS0E09fPYBde/dx4dhp/HP6CnUJiUi9FO4RwH3AZHfvCvQAFgGYWQ4wDFh5iHX/BFxeSfuTQFfgBKAhcFWYtdS4/u0zmXTzEAZ1zOTOlxZy3b/msGXX3kiXJSJSrQ4bAGaWBpwMPALg7qXuXhLM/gtwG1DlR2R3fwf41uk17j7JA8AnQJsjrr4GZaY24NGRfbljeFfeXrSOs++fwtyVmyNdlohItQnnCCAPKAYeM7O5ZjbezFLMbASwyt0LjqWAoOvncmByFfNHm9ksM5tVXFx8LJs6YnFxxuiTO/DstQNxh+89OJ3xU75Ql5CI1AvhBEAC0BsY5+69gB3A3cAdwJ3VUMNY4EN3n1LZTHd/2N3z3T0/KyurGjZ35Hq3zWDSzUM4tWtzfvPaIq5/cg7bdqtLSETqtnACoAgocvcZwfREQoGQBxSY2QpC3TdzzKzFkWzczO4CsoCfHcl6kZDWKJGHLu/DHcO78uZn6zjvgaksXrs10mWJiBy1wwaAu68FCs2sS9A0FJjj7s3dPdfdcwmFRO9g2bCY2VXAGcCl7l4nvqbLLNQl9NRV/dm+p4zzx0zludlFkS5LROSohHsW0E3Ak2Y2H+gJ/LaqBc0s38zGV5ieAkwAhppZkZmdEcx6EMgGppvZPDOrju6kWtG/fSav3XwSPXPSuXVCAb98/lN279XtpUWkbrG6NKCZn5/vs2bNinQZB5TtK+fet5Yw9v3ldG/ZhLE/7E1us5RIlyUi8g1mNtvd8w9u15XAxyAhPo7bzuzKo6PyWVWyi3P/9hGvf7om0mWJiIRFAVANTu2azWs3n0SH5qlc9+Qc7nllIaVldWJYQ0RimAKgmrTJaMSz1wzkisG5PDZ1BZc8NJ1VJbsiXZaISJUUANUoKSGOu849jrE/7M2y9ds5+/4pvLd4faTLEhGplAKgBgw/oSWv3nQSLdMacsU/ZvLHyYsp26cuIRGJLgqAGpLbLIUXrh/Epf1yGPv+cn44fgbrt+6OdFkiIgcoAGpQcmI8v7vwRO69pAfzi7Yw/P6PmLZ8Q6TLEhEBF
AC14sLebXjpxsGkN0rksvEzGPv+MsrL6871FyJSPykAaknn7Ma8dMNgzj6xFX+c/Dmjn5it7xgQkYhSANSilAYJ3P/9ntx9bnfe/3w95/7tIxau1tdOikhkKABqmZkxanAez1wzkNKyci4cO40JswojXZaIxCAFQIT0aZfBqzefRJ92Gfxi4nxum1jArlLdUE5Eao8CIIKapTbgiSv7c+N3OzJhdhHnj5nK8uLtkS5LRGKEAiDC4uOMn5/RhX9c0Y/i7Xs4928f8dK8VZEuS0RigAIgSnyncxav3XwSx7Vqwi3/nqfvGBCRGqcAiCIt0xry1NUDuOY77Xn6k5VcOHYaKzbsiHRZIlJPKQCiTGJ8HL88qxuPjAx9x8A5f/uI1+brOwZEpPopAKLU0G7ZTLplCJ2yU7nhqTnc9dIC9pSpS0hEqo8CIIq1Tm/IM6MHctVJeTw+/SsuHjedlRt3RrosEaknFABRLikhjv8+pzsPXd6Hrzbu4Oz7pzBJXzspItVAAVBHnHFcC167eQjtm6dy/ZNz+O8XdZaQiBwbBUAdktO0EROuGcjok9vzr49XcsHYaXyhC8dE5CgpAOqYpIQ47hjejUdH5bN2S+gsoRfn6sIxETlyCoA66tSuobOEjmvVhJ88M0/3EhKRI6YAqMNapjXk6asHHLiX0HkPfMSSddsiXZaI1BFhBYCZpZvZRDNbbGaLzGxghXm3mpmbWbMq1p1sZiVm9upB7XlmNsPMlpnZM2aWdGwvJTYlxMfx8zO68M8f92PzzlLOe+Ajnp1ZiLu+cUxEDi3cI4D7gMnu3hXoASwCMLMcYBiw8hDr/gm4vJL2PwB/cfeOwGbgynCLlm8b0imLSbcMoXfbDG57bj4/fWYe2/eURbosEYlihw0AM0sDTgYeAXD3UncvCWb/BbgNqPLjpru/A3yjX8LMDDgVmBg0PQ6cf2Sly8GaN07miSv7c+vpnXm5YLW+cUxEDimcI4A8oBh4zMzmmtl4M0sxsxHAKncvOIrtZgIl7r7/I2oR0LqyBc1stJnNMrNZxcXFR7Gp2BIfZ9w0tBNPXz2AnaVlXDB2Gv/+5FAHaCISq8IJgASgNzDO3XsBO4C7gTuAO2uutBB3f9jd8909Pysrq6Y3V2/0b5/JpJuH0D+vKbc//yn3vLKQsn3lkS5LRKJIOAFQBBS5+4xgeiKhQMgDCsxsBdAGmGNmLcLc7kYg3cwSguk2gE5mr2aZqQ14bFRffjw4j8emruCKf8xky869kS5LRKLEYQPA3dcChWbWJWgaCsxx9+bunuvuuYRConew7GF56BSV94CLg6aRwEtHWrwcXkJ8HHee250/XHQCH3+xkQvGTtXVwyIChH8W0E3Ak2Y2H+gJ/LaqBc0s38zGV5ieAkwAhppZkZmdEcz6T+BnZraM0JjAI0dRv4TpP/q25cmrBlCyay/nj5nKh0s0niIS66wunS+en5/vs2bNinQZdVrhpp1c/c9ZLFm3jf85pzujBuUSOilLROorM5vt7vkHt+tK4BiT07QRE68bxNBu2dzzymfc8cKnlJZpcFgkFikAYlBqgwQeuqwPN3y3A09/Ushlj8xg4/Y9kS5LRGqZAiBGxcUZvzijK/d9vyfzCksYMWYqi9dujXRZIlKLFAAxbkTP1jx7zUBKy8q5aOw03lu8PtIliUgtUQAIPXPSefnGk8jLSuGqf87SlcMiMUIBIAC0SEvm36MHclLHZtz+/Kfc+9YS3VFUpJ5TAMgBqQ0SGD8yn+/1acP97yzlFxPns1e3jxCptxIOv4jEksT4OP548Ym0Sm/Ife8sZd3W3Yy7rA+pDfSrIlLf6AhAvsXM+OnpnfnDRScwbflG/uOh6azfujvSZYlINVMASJX+o29bxo/M58sNO7hg7DSWrdfXTYrUJwoAOaTvdmnOM6MHsqesnIvGTWfmik2RLklEqokCQA7rhDZpvHD9IDJTkrhs/AwmLwjrpq8iEuUUABKW/fcQ6t6qC
dc9OZsnpq+IdEkicowUABK2pilJPHXVAE7t0pz/eWkhf3pjsa4VEKnDFAByRBomxfPQ5X24tF8OY95brmsFROowndwtRywhPo7fXnAC2U2S+evbSynetoexP+xNiq4VEKlTdAQgR8XM+MlpnfndhScwZWkxl/79YzboltIidYoCQI7Jpf3a8vDl+SxZt42Lxk1jxYYdkS5JRMKkAJBjdlr3bJ66egBbd+3lonHTmF9UEumSRCQMCgCpFr3bZjDxukE0TIrn+w9/zHuf63sFRKKdAkCqTYesVJ6/bhC5mSlc9fgsJswqjHRJInIICgCpVs2bJPPMNQMY2D6TX0ycz5j3lulaAZEopQCQatc4OZFHR/Xl/J6t+NMbn3PnSwvZV64QEIk2OnFbakRSQhz3XtKT7LRkHvrgC9Zv28193+9FcmJ8pEsTkYCOAKTGxMUZvzyrG3ed2503P1vHZeNnULKzNNJliUggrAAws3Qzm2hmi81skZkNrDDvVjNzM2tWxbojzWxp8DOyQvulZvapmc03s8lVrS913xWD8/jbpb2YX7SFix+czqqSXZEuSUQI/wjgPmCyu3cFegCLAMwsBxgGrKxsJTNrCtwF9Af6AXeZWYaZJQTP+V13PxGYD9x4LC9Eots5J7bi8R/3Y93W3Vw4diqL1myNdEkiMe+wAWBmacDJwCMA7l7q7iXB7L8AtwFVjfCdAbzl7pvcfTPwFnAmYMFPipkZ0ARYfQyvQ+qAgR0ymXDtQAzjew9O54MlxZEuSSSmhXMEkAcUA4+Z2VwzG29mKWY2Aljl7gWHWLc1UPFk8CKgtbvvBa4DPiX0xt+dIGAOZmajzWyWmc0qLtYbRl3XtUUTXrhhEDlNG/Hjf8zkqRmVHjyKSC0IJwASgN7AOHfvBewA7gbuAO48mo2aWSKhAOgFtCLUBfTLypZ194fdPd/d87Oyso5mcxJlWqY1ZMK1AxnSqRl3vPApv5u0iHKdJipS68IJgCKgyN1nBNMTCQVCHlBgZiuANsAcM2tx0LqrgJwK022Ctp4A7r7cQ1cJPQsMOsrXIHVQaoMExv8on8sGtOWhD7/ghqfmsKt0X6TLEokphw0Ad18LFJpZl6BpKDDH3Zu7e6675xIKid7BshW9AQwLBn4zCA0Yv0EoBLqb2f6P9KcTDCxL7EiIj+PXI47nv8/uxuSFa/n+w9NZt3V3pMsSiRnhngV0E/Ckmc0n9On9t1UtaGb5ZjYewN03Ab8GZgY/vwoGhFcD9wAfhvOcUn+ZGVcNac/Dl+ezdP12RjwwlQWrtkS6LJGYYHXpPi35+fk+a9asSJchNWTRmq1c9fgsNu7Yw72X9GT4CS0jXZJIvWBms909/+B2XQksUaNbyya8eMNgurdswvVPzuH+d5ZqcFikBikAJKpkNW7AU1cP4MJerbn3rSWMfmI2W3btjXRZIvWSAkCiTnJiPP93SQ/uPrc773++nvMe+IiFqzUuIFLdFAASlcyMUYPzeOaaAezZW86FY6cxcXZRpMsSqVcUABLV+rRryqs3n0Tvthn8fEIBtz83X9cLiFQTBYBEvWapDXjiyn5cf0oH/j2zkBFjPmLpum2RLkukzlMASJ2QEB/HbWd25Z8/7semHaWc+8BHPDuzUF83KXIMFABSp5zcOYtJNw+hd9sMbntuPj99Zh7b95RFuiyROkkBIHVO8ybJPHFlf352emdeLljNuX/TWUIiR0MBIHVSfJxx89BOPH31AHaWlnHBmGn8c/oKdQmJHAEFgNRp/dtnMunmIQzqmMmdLy3kun/N0YVjImFSAEidl5nagEdH9uWO4V15e9E6zr5/CnNXbo50WSJRTwEg9UJcnDH65A5MuHYg7vC9B6fz8IfLdS8hkUNQAEi90qttBpNuHsJp3bL57aTFXPn4TDbtKI10WSJRSQEg9U5ao0TGXdabX484jqnLNnLWfR8y44uNkS5LJOooAKReMjMuH5jL89cPolFSApf+/WPuf2cp+
9QlJHKAAkDqteNbp/HKTSdxXo9W3PvWEi5/ZAbF2/ZEuiyRqKAAkHovtUECf/mPnvzxohOZ/dVmzr5/irqERFAASIwwMy7pm8OLNwwmpUECPxg/gwc/WK4LxySmKQAkpnRr2YSXbxzMmce14PevL+bKx2exftvuSJclEhEKAIk5jZMTeeAHvbjnvOOYumwDZ/51CpMXrI10WSK1TgEgMcnMGDkol9duPonW6Q259l+zufXZArbu1m0kJHYoACSmdWzemOevH8TNp3bkxXmrOOuvU/hgSXGkyxKpFQoAiXmJ8XH8bFgXJl47kAaJcYx89BOufWI2q0t2Rbo0kRqlABAJ9Gqbweu3DOEXZ3Th/SXrGfp/HzDu/eWUlpVHujSRGhFWAJhZuplNNLPFZrbIzAZWmHermbmZNati3ZFmtjT4GVmhPcnMHjazJcHzXnTsL0fk2DRIiOeG73bk7Z99hyGdmvGHyYs5674PmbZsQ6RLE6l24R4B3AdMdveuQA9gEYCZ5QDDgJWVrWRmTYG7gP5AP+AuM8sIZv8XsN7dOwPdgQ+O9kWIVLc2GY14+Ef5PDaqL3v3OT8YP4Obnp7Luq06ZVTqj8MGgJmlAScDjwC4e6m7lwSz/wLcBlR1Nc0ZwFvuvsndNwNvAWcG834M/C54znJ310csiTrf7dqcN396Mj85rRNvLFzLqX9+n4c/XM7uvfsiXZrIMQvnCCAPKAYeM7O5ZjbezFLMbASwyt0LDrFua6CwwnQR0NrM0oPpX5vZHDObYGbZlT2BmY02s1lmNqu4WGdnSO1LToznJ6d15q2fnky/vKb8dtJiTvnT+zw1YyV792l8QOqucAIgAegNjHP3XsAO4G7gDuDOo9xuAtAGmObuvYHpwJ8rW9DdH3b3fHfPz8rKOsrNiRy7dpkpPHZFP566uj+t0pO544VPOe3eD3hhbpHuMip1UjgBUAQUufuMYHoioUDIAwrMbAWhN/M5ZtbioHVXATkVptsEbRuBncDzQfuE4DlFot6gDs147rpBPDoqn5SkBH76TAFn/vVDJi9Yo3sLSZ1y2ABw97VAoZl1CZqGAnPcvbm757p7LqGQ6B0sW9EbwDAzywgGf4cBb3jor+QV4JQKz/nZMb8akVpiZpzaNZtXbzqJMT/oTbk71/5rDuc+8BHvfb5eQSB1goXzi2pmPYHxQBLwBXBFMKi7f/4KIN/dN5hZPnCtu18VzPsxoe4igP9198eC9nbAE0A6oTGGK9y90rOJ9svPz/dZs2YdyesTqRX7yp0X567ir+8soXDTLnq1TefmUztxSpcszCzS5UmMM7PZ7p7/rfa69ElFASDRrrSsnAmzCxn73nJWlezihNZp3HRqR07vnq0gkIhRAIjUor37ynlhzioeeG8ZKzftpGuLxtzw3Y4MP6El8XEKAqldCgCRCCjbV87LBasZ894ylhfvoH2zFK47pQPn92pNYrzuxCK1QwEgEkHl5c7khWv527vLWLRma+gW1Kd04Ht92pCcGB/p8qSeUwCIRAF3593F6/nbu8uYV1hCVuMGXDE4lx/2b0daw8RIlyf1lAJAJIq4O9OWb+TBD5YzZekGUhskcEl+DpcNaEv7rNRIlyf1jAJAJEotXL2Fhz/8gkmfrmHvPmdIp2ZcPqAdp3ZtToLGCaQaKABEotz6bbt5dmYhT85YyZotu2mVlswPB7Tjkvwcsho3iHR5UocpAETqiLJ95by9aD3/+vgrPlq2gfg4Y0inZlzQqzWnd8+mUVJCpEuUOkYBIFIHLS/eznOzi3hp3mpWleyiUVI8Zx7XghG9WjO4Q6a6iCQsCgCROqy83Jm5YhMvzlvFa/PXsHV3Gc1SG3Bej1Zc0Ks1x7duoiuNpUoKAJF6Yk/ZPt5bXMyLc1fx7uL1lO4rp0NWChf0as2Inq3Jadoo0iVKlFEAiNRDW3buZdKCNbwwdxWffLkJgBPbpHF6t2xOPy6bLtmNdWQgCgCR+q5o805eKVjDm
5+tZe7KEgBymjbk9G4tGHZcNvntMjRmEKMUACIxZP3W3by9aD1vfbaWqcs3UlpWTnqjRE7t2pxh3bM5uXOWziaKIQoAkRi1Y08ZHy4p5q3P1vHO4vVs2bWXpIQ4TurYjGHdsxnaLVvXGdRzCgARoWxfOTNXbOatz9bx5mdrKdq8CzPolZPOkE5ZDOqQSc+26TRI0A3q6hMFgIh8g7uzeO220JHBonV8umoL5Q7JiXHkt2vKwA6ZDOqQyQmt0zR2UMcpAETkkLbs2ssnX25i2vINTF++kcVrtwGQ2iCBfnlNGdQhk/55mXRt2VjfZVDHVBUAGgUSEQDSGiZyevdsTu+eDcDG7Xv4+IuvA+HdxesBaJgYT4+cNPLbNaVPuwx6tU0nvVFSJEuXo6QjABEJy9otu5m5YhOzv9rMnJWbWbh6K/vKQ+8fnZqn0qddBr3bZdCnXQbtm6Xo+oMooi4gEalWO0vLKCjcwpyVm5m1YhNzVpawZddeADIaJX4dCG0z6JGTrm8+iyB1AYlItWqUlMDADpkM7JAJhO5X9MWG7cxasZnZX21m9srNvL0o1G2UEGcc1zqNPm1DRwh92mXQIi05kuULOgIQkRq0aUcpc4IwmL1iMwVFJewpKwegRZNkTmyTRo+cdHq0SeeENmn6WswaoiMAEal1TVOSOK17NqcFA8ulZeV8tmYrs7/azPyiEgoKS3jzs3UHlm/fLIUeOekHgqF7yybqOqpBYQWAmaUD44HjAQd+7O7Tg3m3An8Gstx9QyXrjgT+O5j8jbs/ftD8l4H27n780b4IEakbkhLi6JmTTs+c9ANtJTtLmV+0hflFJcwr3MLUZRt4Ye4qINR11LVlY05sk86JrdM4oU0anbN1Gmp1CfcI4D5gsrtfbGZJQCMAM8sBhgErK1vJzJoCdwH5hIJjtpm97O6bg/kXAtuP7SWISF2W3iiJkztncXLnrANta7fsZl5hSegooaiEVwpW89SM0NtMUkIc3Vo2oef+7qOcdPIyU4iL01lHR+qwYwBmlgbMI/Qp3Q+aNxH4NfASkH/wEYCZXQqc4u7XBNMPAe+7+9NmlgpMBkYDz4ZzBKAxAJHY5O58tXEn81dtYcGqLRQUlrBg1RZ2lO4DoHFyAj3apNMjJ40ebUJHGM2baJB5v2MZA8gDioHHzKwHMBu4BTgNWOXuBYc437c1UFhhuihog1Bw/B+wM6xXICIxy8zIbZZCbrMUzuvRCoB95c7y4u3MKwyNJRQUlfDQB19QFlyb0DItmZ7BEcL+QebUBhr2rCicvZEA9AZucvcZZnYfcDdwMqHunyNmZj2BDu7+UzPLPcyyowkdJdC2bduj2ZyI1EPxcUbn7MZ0zm7MJfk5AOzeu4+Fq7cwr3DLgVB4fcFaAMxCF6yFjhRCRwldWsT2eEI4XUAtgI/dPTeYHkIoAE7g60/vbYDVQD93X1th3Uq7gIB04H+AUkIB0xyY5u6nHKoWdQGJyJHavKOUgqISCgq3BP+WsHFHKQANEuI4rlUTTmwTOuOoW8smdMpOrXdnHh3TlcBmNgW4yt0/N7O7gRR3/0WF+SuofAygKaEuo95B0xygj7tvqrBMLvCqxgBEpDa4O0Wbdx0Ig3mFJSxcvZWdwXhCfJzRISuFbkEgdG0ROspomZZcZ29vcazXAdwEPBmcAfQFcMUhNpQPXOvuV7n7JjP7NTAzmP2rim/+IiK1zczIadqInKaNOOfE0HhCebnz1aadLFqzlc9Wb2XRmq188uUmXpq3+sB6qQ0S6NA8lc7NU+mUnUqn5o3plJ1Kq7SGdfYMJF0JLCJShc07Svl83TaWrt/OsuDfJeu2s2H7ngPLNEqKp2PzVDo2D4VC5yAc2mRETzDoSmARkSOUkZLEgPaZDGif+Y32zTtKWVa8naXrtrNk3TaWrd/O1GUbeH7OqgPLJCfG0SErlc7ZjYNwSKVTdmPaNm1EfJQEgwJAROQIZaQk0TelKX1zm36jfcuuvSxbv
42l67azdH3o5+MvNh64shlCF7J1yAoCIQiFTtmptGvaqNa/eU0BICJSTdIaJtKnXVP6tPtmMGzbvZdlQSAsWx86apj91WZeLvh6jCEpPo68ZinfGF/o1DyV3GYpNXaqqgJARKSGNU5OpFfbDHq1zfhG+449ZSwvDo0rLF2/jWXrtjO/aAuvfbqG/cOzCXFGXrMUxl3Wh47NU6u1LgWAiEiEpDRICN3ork36N9p3lpbxRfEOlgbdSUvWbadZavV/7aYCQEQkyjRKSuD41mkc3zqtRrcTu9dAi4jEOAWAiEiMUgCIiMQoBYCISIxSAIiIxCgFgIhIjFIAiIjEKAWAiEiMqlO3gzazYuCrMBZtBmw47FKRodqOTrTWFq11gWo7GtFaFxxbbe3cPevgxjoVAOEys1mV3fs6Gqi2oxOttUVrXaDajka01gU1U5u6gEREYpQCQEQkRtXXAHg40gUcgmo7OtFaW7TWBartaERrXVADtdXLMQARETm8+noEICIih6EAEBGJUfUuAMzsTDP73MyWmdntUVDPCjP71MzmmdmsoK2pmb1lZkuDfzMO9zzVVMujZrbezBZUaKu0Fgu5P9iP882sdy3XdbeZrQr22zwzG15h3i+Duj43szNqqq5gWzlm9p6ZfWZmC83slqA9ovvtEHVFfL+ZWbKZfWJmBUFt9wTteWY2I6jhGTNLCtobBNPLgvm5EajtH2b2ZYX91jNor7W/g2B78WY218xeDaZrdp+5e735AeKB5UB7IAkoALpHuKYVQLOD2v4I3B48vh34Qy3VcjLQG1hwuFqA4cDrgAEDgBm1XNfdwM8rWbZ78P/aAMgL/r/ja7C2lkDv4HFjYElQQ0T32yHqivh+C157avA4EZgR7Itnge8H7Q8C1wWPrwceDB5/H3imBv8/q6rtH8DFlSxfa38HwfZ+BjwFvBpM1+g+q29HAP2AZe7+hbuXAv8GRkS4psqMAB4PHj8OnF8bG3X3D4FNYdYyAvinh3wMpJtZy1qsqyojgH+7+x53/xJYRuj/vUa4+xp3nxM83gYsAloT4f12iLqqUmv7LXjt24PJxODHgVOBiUH7wfts/76cCAw1M6vl2qpSa38HZtYGOBsYH0wbNbzP6lsAtAYKK0wXceg/itrgwJtmNtvMRgdt2e6+Jni8FsiOTGmHrCUa9uWNwWH3oxW6ySJWV3CY3YvQp8ao2W8H1QVRsN+Crox5wHrgLUJHHCXuXlbJ9g/UFszfAmTWVm3uvn+//W+w3/5iZg0Orq2SuqvbX4HbgPJgOpMa3mf1LQCi0Unu3hs4C7jBzE6uONNDx3BRcS5uNNUCjAM6AD2BNcD/RbIYM0sFngN+4u5bK86L5H6rpK6o2G/uvs/dewJtCB1pdI1EHZU5uDYzOx74JaEa+wJNgf+szZrM7BxgvbvPrs3t1rcAWAXkVJhuE7RFjLuvCv5dD7xA6I9h3f7DyODf9ZGrsMpaIrov3X1d8IdaDvydr7srar0uM0sk9Cb7pLs/HzRHfL9VVlc07begnhLgPWAgoe6ThEq2f6C2YH4asLEWazsz6FJzd98DPEbt77fBwHlmtoJQ1/WpwH3U8D6rbwEwE+gUjJwnERoceTlSxZhZipk13v8YGAYsCGoaGSw2EngpMhXCIWp5GfhRcBbEAGBLhS6PGndQP+sFhPbb/rq+H5wFkQd0Aj6pwToMeARY5O73VpgV0f1WVV3RsN/MLMvM0oPHDYHTCY1RvAdcHCx28D7bvy8vBt4Njqpqq7bFFcLcCPWzV9xvNf7/6e6/dPc27p5L6H3rXXf/ITW9z6pzBDsafgiN2i8h1Of4XxGupT2hMy8KgIX76yHUV/cOsBR4G2haS/U8TahbYC+h/sQrq6qF0FkPY4L9+CmQX8t1PRFsd37wy96ywvL/FdT1OXBWDe+zkwh178wH5gU/wyO93w5RV8T3G3AiMDeoYQFwZ4W/h08IDUBPABoE7cnB9LJgfvsI1PZusN8WA
P/i6zOFau3voEKNp/D1WUA1us90KwgRkRhV37qAREQkTAoAEZEYpQAQEYlRCgARkRilABARiVEKAJEwmdm04N9cM/tBpOsROVYKAJEwufug4GEucEQBUOFqTpGooQAQCZOZ7b+L5O+BIcF9438a3FzsT2Y2M7iZ2DXB8qeY2RQzexn4LGKFi1RBn0pEjtzthO65fw5AcJfXLe7eN7iL5FQzezNYtjdwvIduwSwSVRQAIsduGHCime2/Z0saoXvtlAKf6M1fopUCQOTYGXCTu7/xjUazU4AdkShIJBwaAxA5ctsIfQ3jfm8A1wW3Z8bMOgd3fxWJajoCEDly84F9ZlZA6Ltk7yN0ZtCc4HbCxdTS13yKHAvdDVREJEapC0hEJEYpAEREYpQCQEQkRikARERilAJARCRGKQBERGKUAkBEJEb9PyVepHxYnNXlAAAAAElFTkSuQmCC\n", "text/plain": ["
"]}, "metadata": {"needs_background": "light"}, "output_type": "display_data"}], "source": ["from pandas import DataFrame\n", "\n", "df = DataFrame(data=train_losses, columns=['iter', 'train_loss'])\n", "df[6:].plot(x=\"iter\", y=\"train_loss\", title=\"Training loss\");"]}, {"cell_type": "code", "execution_count": 18, "id": "dfe34617", "metadata": {}, "outputs": [], "source": []}], "metadata": {"kernelspec": {"display_name": "Python 3", "language": "python", "name": "python3"}, "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.5"}}, "nbformat": 4, "nbformat_minor": 5}