.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_orttraining_linear_regression_fwbw.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_orttraining_linear_regression_fwbw.py:


.. _l-orttraining-linreg-fwbw:

Train a linear regression with forward backward
===============================================

This example rewrites :ref:`l-orttraining-linreg` with another
optimizer, :class:`OrtGradientForwardBackwardOptimizer `.
This optimizer relies on class :epkg:`TrainingAgent`
from :epkg:`onnxruntime-training`. In this case, the user does not have to
modify the graph to compute the error. The optimizer builds another graph
which returns the gradient of every weight, assuming the gradient on the
output is known. Finally, the optimizer updates the weights with these
gradients, scaled by the learning rate.

To summarize, it starts from the following graph:

.. image:: images/onnxfwbw1.png

Class :class:`OrtGradientForwardBackwardOptimizer `
builds other ONNX graphs to implement a gradient descent algorithm:

.. image:: images/onnxfwbw2.png

The blue node is built by class :epkg:`TrainingAgent`
(from :epkg:`onnxruntime-training`). The green nodes are added by
class :class:`OrtGradientForwardBackwardOptimizer `.
This implementation relies on ONNX for the computation, but ONNX could
be replaced by any other framework such as :epkg:`pytorch`. This design
gives users more freedom to implement their own training algorithm.

.. contents::
    :local:

A simple linear regression with scikit-learn
++++++++++++++++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 43-61
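
Before turning to scikit-learn, the forward/backward mechanism described in
the introduction can be illustrated with numpy alone. This is a minimal
sketch, not the :epkg:`TrainingAgent` implementation: it assumes the linear
model *Y = XW + b* and a sum of squared errors *E*. Given the gradient
*dE/dY* on the output, the chain rule gives *dE/dW = X^T dE/dY*, and *dE/db*
is the sum of *dE/dY* over the rows. The helpers ``forward``, ``backward``
and ``squared_error`` are hypothetical names; a central finite difference
checks one coefficient of the gradient.

.. code-block:: python

    import numpy as np


    def forward(X, W, b):
        # Assumed linear model: Y = X @ W + b, one output per row of X.
        return X @ W + b


    def backward(X, dY):
        # Hypothetical helper: gradients of E with respect to W and b,
        # given dE/dY, by the chain rule through Y = X @ W + b.
        return X.T @ dY, dY.sum(axis=0)


    def squared_error(Y, y):
        # Assumed error: E = sum((Y - y) ** 2), so dE/dY = 2 * (Y - y).
        return float(((Y - y) ** 2).sum())


    rng = np.random.default_rng(0)
    X = rng.normal(size=(8, 2))
    y = rng.normal(size=(8, 1))
    W = rng.normal(size=(2, 1))
    b = rng.normal(size=(1,))

    # Forward pass; the gradient on the output is then known.
    Y = forward(X, W, b)
    loss_before = squared_error(Y, y)
    dY = 2 * (Y - y)
    dW, db = backward(X, dY)

    # Central finite difference on W[0, 0] to check the backward pass.
    eps = 1e-6
    W_plus, W_minus = W.copy(), W.copy()
    W_plus[0, 0] += eps
    W_minus[0, 0] -= eps
    numeric = (squared_error(forward(X, W_plus, b), y)
               - squared_error(forward(X, W_minus, b), y)) / (2 * eps)
    print("analytic:", dW[0, 0], "numeric:", numeric)

    # Gradient descent step: move each weight against its gradient.
    learning_rate = 1e-2
    W = W - learning_rate * dW
    b = b - learning_rate * db
    loss_after = squared_error(forward(X, W, b), y)
    print("loss before:", loss_before, "after:", loss_after)

The final update is the step the introduction attributes to the optimizer
itself once the gradients are known; the backward pass is the part delegated
to :epkg:`TrainingAgent`.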

.. code-block:: default

    from pprint import pprint
    import numpy
    from pandas import DataFrame
    from onnxruntime import get_device
    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.neural_network import MLPRegressor
    from mlprodict.onnx_conv import to_onnx
    from onnxcustom.plotting.plotting_onnx import plot_onnxs
    from onnxcustom.utils.orttraining_helper import get_train_initializer
    from onnxcustom.training.optimizers_partial import (
        OrtGradientForwardBackwardOptimizer)

    X, y = make_regression(n_features=2, bias=2)
    X = X.astype(numpy.float32)
    y = y.astype(numpy.float32)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

.. GENERATED FROM PYTHON SOURCE LINES 62-63

We use a :class:`sklearn.neural_network.MLPRegressor` with no hidden layer
and an identity activation, which makes it a linear regression trained
with stochastic gradient descent.

.. GENERATED FROM PYTHON SOURCE LINES 63-73

.. code-block:: default

    lr = MLPRegressor(hidden_layer_sizes=tuple(), activation='identity',
                      max_iter=50, batch_size=10, solver='sgd',
                      alpha=0, learning_rate_init=1e-2,
                      n_iter_no_change=200, momentum=0,
                      nesterovs_momentum=False)
    lr.fit(X, y)

    print(lr.predict(X[:5]))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    somewhere/workspace/onnxcustom/onnxcustom_UT_39_std/_venv/lib/python3.9/site-packages/sklearn/neural_network/_multilayer_perceptron.py:679: ConvergenceWarning: Stochastic Optimizer: Maximum iterations (50) reached and the optimization hasn't converged yet.
      warnings.warn(
    [-31.613749 74.62066 1.8960948 34.660263 49.48107 ]

.. GENERATED FROM PYTHON SOURCE LINES 74-75

The trained coefficients are:

.. GENERATED FROM PYTHON SOURCE LINES 75-77

.. code-block:: default

    print("trained coefficients:", lr.coefs_, lr.intercepts_)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    trained coefficients: [array([[53.34389 ], [24.865873]], dtype=float32)] [array([1.7797871], dtype=float32)]
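
``make_regression`` uses ``noise=0.0`` by default, so ``y`` is an exact
linear function of ``X`` plus the bias, and the ConvergenceWarning above
shows that 50 SGD iterations stop before convergence. A closed-form
least-squares fit gives the reference the iterations should approach. The
sketch below applies numpy's ``lstsq`` to synthetic data built the same way;
the true coefficients ``[50, 25]`` are hypothetical values chosen for
illustration, not the ones drawn by ``make_regression``.

.. code-block:: python

    import numpy as np

    # Hypothetical data built like make_regression(n_features=2, bias=2) with
    # its default noise=0.0: targets are exactly linear in X. The coefficients
    # [50, 25] are illustrative, not the ones make_regression draws.
    rng = np.random.default_rng(0)
    X = rng.normal(size=(100, 2)).astype(np.float32)
    true_coef = np.array([50.0, 25.0], dtype=np.float32)
    y = X @ true_coef + np.float32(2.0)

    # Append a column of ones so the intercept is solved with the slopes.
    X1 = np.hstack([X, np.ones((X.shape[0], 1), dtype=np.float32)])
    solution, residuals, rank, _ = np.linalg.lstsq(X1, y, rcond=None)
    coef, intercept = solution[:2], solution[2]
    print("coefficients:", coef, "intercept:", intercept)

Running the same solve on the example's ``X`` and ``y`` gives a reference to
compare with ``lr.coefs_`` and ``lr.intercepts_``.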

.. GENERATED FROM PYTHON SOURCE LINES 78-84

ONNX graph
++++++++++

Training with :epkg:`onnxruntime-training` starts with an ONNX
graph which defines the model to learn. It is obtained by
converting the previous linear regression into ONNX.

.. GENERATED FROM PYTHON SOURCE LINES 84-90

.. code-block:: default

    onx = to_onnx(lr, X_train[:1].astype(numpy.float32), target_opset=15,
                  black_op={'LinearRegressor'})

    plot_onnxs(onx, title="Linear Regression in ONNX")

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_001.png
   :alt: Linear Regression in ONNX
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 91-101

Weights
+++++++

In principle, every initializer holds weights which can be trained, and a
gradient is computed for each of them. However, an initializer used to
modify a shape or to extract a part of a tensor does not need training.
:func:`get_train_initializer ` removes such initializers.
The code below also explicitly excludes the initializer ``shape_tensor``.

.. GENERATED FROM PYTHON SOURCE LINES 101-106

.. code-block:: default

    inits = get_train_initializer(onx)
    weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
    pprint(list((k, v[0].shape) for k, v in weights.items()))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [('coefficient', (2, 1)), ('intercepts', (1, 1))]

.. GENERATED FROM PYTHON SOURCE LINES 107-109

Train on CPU or GPU if available
++++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 109-113

.. code-block:: default

    device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

    print(f"device={device!r} get_device()={get_device()!r}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    device='cpu' get_device()='CPU'
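
The optimizer used in the next section processes each mini-batch with the
chain of graphs described in the introduction: a forward pass, a loss graph
returning the square loss and *dE/dY*, the backward graph returning the
weight gradients, and an update graph. The numpy sketch below mimics that
loop for a linear model with a constant learning rate. The function names
(``loss_grad``, ``backward``, ``axpy``, ``fit_sgd``) are hypothetical, and
averaging the loss over the batch is an assumption; onnxcustom may scale it
differently. The parameters mirror ``learning_rate``, ``max_iter`` and
``batch_size`` from the constructor call below.

.. code-block:: python

    import numpy as np


    def loss_grad(Y, y):
        # Square loss averaged over the batch (assumed scaling) and dE/dY.
        diff = Y - y
        return float((diff ** 2).mean()), 2 * diff / diff.shape[0]


    def backward(X, dY):
        # Gradients of Y = X @ W + b with respect to W and b, given dE/dY.
        return X.T @ dY, dY.sum(axis=0)


    def axpy(a, x, y):
        # a * x + y, used with a = -learning_rate to update a weight.
        return a * x + y


    def fit_sgd(X, y, learning_rate=1e-2, max_iter=200, batch_size=10, seed=0):
        # Hypothetical mini-batch SGD loop; returns weights and epoch losses.
        rng = np.random.default_rng(seed)
        y = y.reshape((-1, 1))
        W = np.zeros((X.shape[1], 1), dtype=X.dtype)
        b = np.zeros((1,), dtype=X.dtype)
        losses = []
        for _ in range(max_iter):
            order = rng.permutation(X.shape[0])
            epoch_loss = []
            for start in range(0, X.shape[0], batch_size):
                idx = order[start:start + batch_size]
                Y = X[idx] @ W + b                   # forward
                loss, dY = loss_grad(Y, y[idx])      # loss and dE/dY
                dW, db = backward(X[idx], dY)        # backward
                W = axpy(-learning_rate, dW, W)      # weight update
                b = axpy(-learning_rate, db, b)
                epoch_loss.append(loss)
            losses.append(float(np.mean(epoch_loss)))
        return W, b, losses


    # Usage on noise-free synthetic data (illustrative coefficients).
    rng = np.random.default_rng(1)
    X = rng.normal(size=(100, 2)).astype(np.float32)
    y = X @ np.array([50.0, 25.0], dtype=np.float32) + np.float32(2.0)
    W, b, losses = fit_sgd(X, y)
    print(W.ravel(), b, losses[-1])

Keeping the loss gradient, the backward pass and the update as separate
functions reflects the design described in the introduction: any of them
can be replaced without touching the others.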

.. GENERATED FROM PYTHON SOURCE LINES 114-123

Stochastic Gradient Descent
+++++++++++++++++++++++++++

The training logic is hidden in class
:class:`OrtGradientForwardBackwardOptimizer `.
It follows the :epkg:`scikit-learn` API (see `SGDRegressor `_).

.. GENERATED FROM PYTHON SOURCE LINES 123-130

.. code-block:: default

    train_session = OrtGradientForwardBackwardOptimizer(
        onx, list(weights), device=device, verbose=1, learning_rate=1e-2,
        warm_start=False, max_iter=200, batch_size=10)

    train_session.fit(X, y)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0/200 [00:00, ]] last_losses: [0.0054119555, 0.0043852413, 0.00564376, 0.0044368017, 0.0050066994]

.. GENERATED FROM PYTHON SOURCE LINES 143-153

The convergence speed is almost the same.

Gradient Graph
++++++++++++++

As mentioned in the introduction, the computation relies
on a few more graphs than the initial one.
When the loss is needed but not the gradient, class
:epkg:`TrainingAgent` creates another, faster graph
which takes the trained initializers as additional inputs.

.. GENERATED FROM PYTHON SOURCE LINES 153-158

.. code-block:: default

    onx_loss = train_session.train_session_.cls_type_._optimized_pre_grad_model

    plot_onnxs(onx, onx_loss, title=['regression', 'loss'])

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_003.png
   :alt: regression, loss
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array([, ], dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 159-160

The following graph computes both the loss and the gradient.

.. GENERATED FROM PYTHON SOURCE LINES 160-165

.. code-block:: default

    onx_gradient = train_session.train_session_.cls_type_._trained_onnx

    plot_onnxs(onx_loss, onx_gradient, title=['loss', 'gradient + loss'])

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_004.png
   :alt: loss, gradient + loss
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_004.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array([, ], dtype=object)

.. GENERATED FROM PYTHON SOURCE LINES 166-172

The last ONNX graphs are used to compute the gradient *dE/dY*
and to update the weights. The first graph takes the predicted and the
expected labels and returns the square loss and its gradient.
The second graph takes the weights, their gradients and the learning rate
as inputs and returns the updated weights. This graph works on tensors of
any shape, provided they share the same element type.

.. GENERATED FROM PYTHON SOURCE LINES 172-179

.. code-block:: default

    plot_onnxs(train_session.learning_loss.loss_grad_onnx_,
               train_session.learning_rate.axpy_onnx_,
               title=['error gradient + loss', 'gradient update'])

    # import matplotlib.pyplot as plt
    # plt.show()

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_005.png
   :alt: error gradient + loss, gradient update
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_linear_regression_fwbw_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    array([, ], dtype=object)

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 8.743 seconds)


.. _sphx_glr_download_gyexamples_plot_orttraining_linear_regression_fwbw.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_orttraining_linear_regression_fwbw.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_orttraining_linear_regression_fwbw.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_