Train a scikit-learn neural network with onnxruntime-training on GPU#

This example builds on the example Train a linear regression with onnxruntime-training on GPU in details to train a scikit-learn neural network on GPU. Instead of spelling out every step, the code uses classes implemented in this module, following the pattern introduced in the example Train a linear regression with onnxruntime-training.

A neural network with scikit-learn#

import warnings
from pprint import pprint
import numpy
from pandas import DataFrame
from onnxruntime import get_device, InferenceSession
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from onnxcustom.utils.orttraining_helper import (
    add_loss_output, get_train_initializer)
from onnxcustom.training.optimizers import OrtGradientOptimizer


X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200,
                  solver='sgd', learning_rate_init=1e-4, alpha=0,
                  n_iter_no_change=1000, batch_size=10,
                  momentum=0, nesterovs_momentum=False)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

Score:

print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")
mean_squared_error=0.45765528

Conversion to ONNX#

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)
<AxesSubplot:>

Training graph#

The loss function is the squared error. Function add_loss_output extends the ONNX graph with this loss. It does what the example Train a linear regression with onnxruntime-training in details implements step by step.

onx_train = add_loss_output(onx)
plot_onnxs(onx_train)
<AxesSubplot:>
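
As a rough illustration of what the added loss computes, the NumPy sketch below assumes the loss is the sum of squared differences between prediction and label over the batch. This is an assumption, consistent with the inference check in this example, which divides the loss by the number of rows to recover the mean squared error. The name `square_loss` is hypothetical and not part of onnxcustom.

```python
import numpy


def square_loss(pred, label):
    """Sum of squared differences, returned with shape (1, 1).

    Assumption: this mirrors the loss output added to the graph;
    it is not the onnxcustom implementation.
    """
    diff = pred - label
    return numpy.sum(diff ** 2, keepdims=True).reshape((1, 1))


rng = numpy.random.default_rng(0)
label = rng.normal(size=(5, 1)).astype(numpy.float32)
pred = label + numpy.float32(0.5)  # every prediction is off by 0.5

loss = square_loss(pred, label)
# Dividing by the number of rows turns the summed loss into the MSE.
mse = loss[0, 0] / label.shape[0]
print(loss.shape, mse)
```

Keeping the loss as a sum rather than a mean means the value grows with the batch size, which is why the comparison with `mean_squared_error` needs the division.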

Let’s check inference is working.

sess = InferenceSession(onx_train.SerializeToString(),
                        providers=['CPUExecutionProvider'])
res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
print(f"onnx loss={res[0][0, 0] / X_test.shape[0]!r}")
onnx loss=0.457655029296875

Let’s retrieve the initializers, that is, the weights to optimize. We remove the initializer shape_tensor, which cannot be optimized.

inits = get_train_initializer(onx)
weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
pprint(list((k, v[0].shape) for k, v in weights.items()))
[('coefficient', (10, 10)),
 ('intercepts', (1, 10)),
 ('coefficient1', (10, 10)),
 ('intercepts1', (1, 10)),
 ('coefficient2', (10, 1)),
 ('intercepts2', (1, 1))]
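
To show how these six tensors fit together, here is a minimal NumPy sketch of the forward pass of a network with the same shapes. It assumes ReLU on the two hidden layers (the default activation of MLPRegressor) and an identity output. The weights are random placeholders, not the trained values, and `mlp_forward` is a hypothetical helper.

```python
import numpy


def mlp_forward(X, weights):
    """Forward pass: two ReLU hidden layers followed by a linear output."""
    h = numpy.maximum(X @ weights['coefficient'] + weights['intercepts'], 0)
    h = numpy.maximum(h @ weights['coefficient1'] + weights['intercepts1'], 0)
    return h @ weights['coefficient2'] + weights['intercepts2']


# Shapes copied from the list printed above.
shapes = {
    'coefficient': (10, 10), 'intercepts': (1, 10),
    'coefficient1': (10, 10), 'intercepts1': (1, 10),
    'coefficient2': (10, 1), 'intercepts2': (1, 1),
}
rng = numpy.random.default_rng(0)
# Random placeholder weights, only used to check that the shapes chain.
weights = {k: rng.normal(size=s).astype(numpy.float32)
           for k, s in shapes.items()}

X = rng.normal(size=(4, 10)).astype(numpy.float32)
pred = mlp_forward(X, weights)
n_params = sum(v.size for v in weights.values())
print(pred.shape, n_params)
```

The six tensors hold 231 trainable values in total, all of which are passed to the optimizer by name.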

Training#

Device selection. If a GPU is available, the example uses CUDA; otherwise it falls back to CPU.

device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

print(f"device={device!r} get_device()={get_device()!r}")
device='cpu' get_device()='CPU'
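
The same selection can be wrapped in a small helper that also copes with a missing onnxruntime installation. This is only a sketch: `pick_device` is a hypothetical name, and returning `'cpu'` when the import fails is an assumption made for illustration.

```python
def pick_device():
    """Return 'cuda' when onnxruntime reports a GPU, else 'cpu'."""
    try:
        from onnxruntime import get_device
    except ImportError:
        # Assumption: without onnxruntime, default to CPU.
        return 'cpu'
    return 'cuda' if get_device().upper() == 'GPU' else 'cpu'


device = pick_device()
print(f"device={device!r}")
```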

The training session.

train_session = OrtGradientOptimizer(
    onx_train, list(weights), device=device, verbose=1,
    learning_rate=5e-4, warm_start=False, max_iter=200, batch_size=10)

train_session.fit(X, y)
state_tensors = train_session.get_state()

print(train_session.train_losses_)

df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations", logy=True)

# import matplotlib.pyplot as plt
# plt.show()
100%|##########| 200/200 [00:07<00:00, 25.52it/s]
[10356.855, 3629.3547, 5468.5664, 56.93733, 21.106413, 9.469525, 6.4444423, 5.2016644, 3.6126451, 2.1732643, 1.5829742, 1.860848, 1.2751704, 0.9995298, 0.7592125, 0.9017307, 0.9579516, 0.69556296, 0.76768607, 0.4751102, 0.6556558, 0.48176563, 0.47170743, 0.5104978, 0.38221976, 0.38100347, 0.37095225, 0.2985677, 0.3834814, 0.3384063, 0.3193733, 0.32169548, 0.35431206, 0.32055867, 0.28405303, 0.31619704, 0.2606541, 0.24231105, 0.31852734, 0.26802945, 0.28784284, 0.23087664, 0.24087748, 0.241963, 0.21733445, 0.2596913, 0.20460276, 0.21412516, 0.1899021, 0.23373967, 0.24461012, 0.18258958, 0.18951698, 0.21923462, 0.19709414, 0.19156687, 0.16901016, 0.19155061, 0.12159678, 0.13814253, 0.18099941, 0.15530445, 0.17474842, 0.1900974, 0.15872309, 0.14015315, 0.11668899, 0.16468897, 0.14514636, 0.13027368, 0.14565146, 0.13048214, 0.15211047, 0.13929005, 0.14334789, 0.10838359, 0.12183977, 0.1149121, 0.13145874, 0.12681946, 0.143691, 0.09989229, 0.14133218, 0.13635536, 0.12463614, 0.10697017, 0.14476025, 0.11179676, 0.12284869, 0.10495844, 0.12956223, 0.11531146, 0.09658861, 0.11387021, 0.10891845, 0.11061918, 0.106412075, 0.10080794, 0.08622, 0.12126827, 0.112020016, 0.11092112, 0.09028332, 0.0996965, 0.10365843, 0.11601756, 0.112796746, 0.08982027, 0.10317604, 0.1039547, 0.09923191, 0.099138364, 0.09271118, 0.10959093, 0.09520128, 0.09023816, 0.10819502, 0.0705075, 0.078772366, 0.11278635, 0.10365344, 0.10405728, 0.0912991, 0.08606227, 0.074225515, 0.09489141, 0.09398032, 0.0870906, 0.07308353, 0.087707445, 0.079344146, 0.07076826, 0.08865332, 0.07360243, 0.084816806, 0.08144194, 0.11355602, 0.07774, 0.08522823, 0.07332131, 0.06729105, 0.07457762, 0.074100055, 0.08964393, 0.06868683, 0.07093787, 0.066420816, 0.07660875, 0.06951344, 0.09136959, 0.095407724, 0.097968504, 0.089812316, 0.06884828, 0.08935883, 0.070964575, 0.076458275, 0.082713984, 0.08508434, 0.083531894, 0.076436475, 0.066789396, 0.06575546, 0.08968302, 0.05860862, 0.06622231, 0.057139497, 0.07177789, 
0.08635917, 0.071850315, 0.07057329, 0.068253726, 0.06798706, 0.058654524, 0.05781555, 0.055289954, 0.08243201, 0.06664098, 0.06659025, 0.06493263, 0.05997903, 0.08052919, 0.06449012, 0.08076675, 0.05476511, 0.060389828, 0.060871974, 0.058537036, 0.058930896, 0.05889268, 0.07461551, 0.0679916, 0.060866784, 0.06353437, 0.04434182, 0.05416678, 0.06558044, 0.053351857, 0.066473596, 0.06475205]

<AxesSubplot:title={'center':'Train loss against iterations'}>
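
For intuition about what a gradient optimizer does with `learning_rate`, `max_iter` and `batch_size`, the sketch below runs plain mini-batch SGD on a linear model in NumPy. It is not the implementation of OrtGradientOptimizer: the model, the function `sgd_fit`, and the choice to record the average batch loss per iteration are all assumptions made for illustration.

```python
import numpy


def sgd_fit(X, y, learning_rate=1e-2, max_iter=50, batch_size=10, seed=0):
    """Mini-batch SGD on a linear model with a squared-error loss.

    Returns the weights, the bias, and one averaged loss per iteration.
    """
    rng = numpy.random.default_rng(seed)
    w = numpy.zeros((X.shape[1], 1), dtype=X.dtype)
    b = numpy.zeros((1, 1), dtype=X.dtype)
    losses = []
    for _ in range(max_iter):
        order = rng.permutation(X.shape[0])  # reshuffle at every iteration
        batch_losses = []
        for start in range(0, X.shape[0], batch_size):
            idx = order[start:start + batch_size]
            xb, yb = X[idx], y[idx].reshape((-1, 1))
            diff = xb @ w + b - yb
            batch_losses.append(float((diff ** 2).mean()))
            # Gradients of the mean squared error over the batch.
            grad_w = 2 * xb.T @ diff / len(idx)
            grad_b = 2 * diff.mean(keepdims=True)
            w -= learning_rate * grad_w
            b -= learning_rate * grad_b
        losses.append(float(numpy.mean(batch_losses)))
    return w, b, losses


rng = numpy.random.default_rng(1)
X_demo = rng.normal(size=(200, 3)).astype(numpy.float32)
true_w = numpy.array([[1.0], [-2.0], [0.5]], dtype=numpy.float32)
y_demo = (X_demo @ true_w + 2).ravel()  # noiseless target with bias 2

w, b, losses = sgd_fit(X_demo, y_demo)
print(w.ravel(), b[0, 0], losses[0], losses[-1])
```

As in the plot above, the recorded loss drops quickly during the first iterations and then flattens out.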

Total running time of the script: ( 0 minutes 26.421 seconds)
