.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "gyexamples/plot_orttraining_nn_gpu.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_gyexamples_plot_orttraining_nn_gpu.py:


.. _l-orttraining-nn-gpu:

Train a scikit-learn neural network with onnxruntime-training on GPU
====================================================================

This example builds on example :ref:`l-orttraining-linreg-gpu` to
train a neural network from :epkg:`scikit-learn` on GPU.
However, the code uses classes implemented in this module,
following the pattern introduced in example :ref:`l-orttraining-linreg`.

.. contents::
    :local:

A neural network with scikit-learn
++++++++++++++++++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 20-50

.. code-block:: default

    import warnings
    from pprint import pprint
    import numpy
    from pandas import DataFrame
    from onnxruntime import get_device, InferenceSession
    from sklearn.datasets import make_regression
    from sklearn.model_selection import train_test_split
    from sklearn.neural_network import MLPRegressor
    from sklearn.metrics import mean_squared_error
    from onnxcustom.plotting.plotting_onnx import plot_onnxs
    from mlprodict.onnx_conv import to_onnx
    from onnxcustom.utils.orttraining_helper import (
        add_loss_output, get_train_initializer)
    from onnxcustom.training.optimizers import OrtGradientOptimizer

    X, y = make_regression(1000, n_features=10, bias=2)
    X = X.astype(numpy.float32)
    y = y.astype(numpy.float32)
    X_train, X_test, y_train, y_test = train_test_split(X, y)

    nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200,
                      solver='sgd', learning_rate_init=1e-4, alpha=0,
                      n_iter_no_change=1000, batch_size=10,
                      momentum=0, nesterovs_momentum=False)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        nn.fit(X_train, y_train)


.. GENERATED FROM PYTHON SOURCE LINES 51-52

Score:

.. GENERATED FROM PYTHON SOURCE LINES 52-56

.. code-block:: default

    print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    mean_squared_error=0.093900256

.. GENERATED FROM PYTHON SOURCE LINES 57-59

Conversion to ONNX
++++++++++++++++++

.. GENERATED FROM PYTHON SOURCE LINES 59-63

.. code-block:: default

    onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
    plot_onnxs(onx)

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_001.png
   :alt: plot orttraining nn gpu
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 64-72

Training graph
++++++++++++++

The loss function is the squared error.
We use function :func:`add_loss_output `.
It does something similar to what is implemented in example
:ref:`l-orttraining-linreg-cpu`.

.. GENERATED FROM PYTHON SOURCE LINES 72-76

.. code-block:: default

    onx_train = add_loss_output(onx)
    plot_onnxs(onx_train)

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_002.png
   :alt: plot orttraining nn gpu
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 77-78

Let's check that inference works. The first output holds the sum of
the squared errors over all rows, so dividing it by the number of rows
should give back the mean squared error computed with scikit-learn.

.. GENERATED FROM PYTHON SOURCE LINES 78-84

.. code-block:: default

    sess = InferenceSession(onx_train.SerializeToString(),
                            providers=['CPUExecutionProvider'])
    res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
    print(f"onnx loss={res[0][0, 0] / X_test.shape[0]!r}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    onnx loss=0.0939002456665039

.. GENERATED FROM PYTHON SOURCE LINES 85-87

Let's retrieve the constants, that is, the weights to optimize.
We remove the initializers which cannot be optimized.

.. GENERATED FROM PYTHON SOURCE LINES 87-93

.. code-block:: default

    inits = get_train_initializer(onx)
    weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
    pprint(list((k, v[0].shape) for k, v in weights.items()))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [('coefficient', (10, 10)),
     ('intercepts', (1, 10)),
     ('coefficient1', (10, 10)),
     ('intercepts1', (1, 10)),
     ('coefficient2', (10, 1)),
     ('intercepts2', (1, 1))]

.. GENERATED FROM PYTHON SOURCE LINES 94-99

Training
++++++++

Device selection: if a GPU is available, the training runs on CUDA,
otherwise it falls back to CPU.

.. GENERATED FROM PYTHON SOURCE LINES 99-104

.. code-block:: default

    device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

    print(f"device={device!r} get_device()={get_device()!r}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    device='cpu' get_device()='CPU'

.. GENERATED FROM PYTHON SOURCE LINES 105-106

The training session is created and fitted, then its training losses
are plotted against the loss curve of the scikit-learn model.

.. GENERATED FROM PYTHON SOURCE LINES 106-122

.. code-block:: default

    train_session = OrtGradientOptimizer(
        onx_train, list(weights), device=device, verbose=1,
        learning_rate=5e-4, warm_start=False, max_iter=200,
        batch_size=10)
    train_session.fit(X, y)
    state_tensors = train_session.get_state()

    print(train_session.train_losses_)

    df = DataFrame({'ort losses': train_session.train_losses_,
                    'skl losses:': nn.loss_curve_})
    df.plot(title="Train loss against iterations", logy=True)

    # import matplotlib.pyplot as plt
    # plt.show()

.. image-sg:: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_003.png
   :alt: Train loss against iterations
   :srcset: /gyexamples/images/sphx_glr_plot_orttraining_nn_gpu_003.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    0%| | 0/200 [00:00

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 0 minutes 27.891 seconds)


.. _sphx_glr_download_gyexamples_plot_orttraining_nn_gpu.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: plot_orttraining_nn_gpu.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: plot_orttraining_nn_gpu.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
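
In section *Training graph*, the inference check divides the loss returned
by the ONNX graph by the number of rows and recovers the mean squared error
printed by scikit-learn. The following minimal NumPy sketch illustrates
that relation on synthetic data. It assumes, as the matching printed values
suggest, that the loss output added by ``add_loss_output`` is the sum of the
squared residuals stored in a ``(1, 1)`` tensor like ``res[0]``. It does not
call onnxruntime and is not the library's implementation.

.. code-block:: python

    import numpy
    from sklearn.metrics import mean_squared_error

    # Synthetic data for illustration only (not the example's dataset).
    rng = numpy.random.default_rng(0)
    y_true = rng.standard_normal(100).astype(numpy.float32)
    y_pred = (y_true + 0.3 * rng.standard_normal(100)).astype(numpy.float32)

    # Assumption: the training graph's loss output is the sum of the
    # squared residuals, returned as a (1, 1) tensor like res[0].
    residuals = (y_pred - y_true).reshape((-1, 1))
    loss = (residuals ** 2).sum(axis=0, keepdims=True)

    # Dividing by the number of rows gives the mean squared error,
    # which is the quantity the example prints as "onnx loss".
    mse_from_loss = loss[0, 0] / y_true.shape[0]
    mse_sklearn = mean_squared_error(y_true, y_pred)
    print(f"from loss={mse_from_loss!r} sklearn={mse_sklearn!r}")

The same normalization question matters when reading the loss plot:
scikit-learn's ``MLPRegressor`` records half of the mean squared error in
``loss_curve_``, while this example does not state how ``train_losses_``
is normalized, so the two curves may differ by a constant factor.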