Train a scikit-learn neural network with onnxruntime-training on GPU#
This example leverages example Train a linear regression with onnxruntime-training on GPU in details to train a neural network from scikit-learn on GPU. However, the code uses classes implemented in this module, following the pattern introduced in example Train a linear regression with onnxruntime-training.
A neural network with scikit-learn#
import warnings
from pprint import pprint
import numpy
from pandas import DataFrame
from onnxruntime import get_device, InferenceSession
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from onnxcustom.utils.orttraining_helper import (
    add_loss_output, get_train_initializer)
from onnxcustom.training.optimizers import OrtGradientOptimizer
X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
# Plain SGD without momentum, to keep the comparison with
# onnxruntime-training straightforward.
nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200,
                  solver='sgd', learning_rate_init=1e-4, alpha=0,
                  n_iter_no_change=1000, batch_size=10,
                  momentum=0, nesterovs_momentum=False)
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)
Score:
print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")
mean_squared_error=0.45765528
Conversion to ONNX#
onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)
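As a side note, the converted model is a regular ONNX protobuf and could be saved to disk like any other model. A minimal sketch, where the file name "nn.onnx" is an arbitrary choice:
# Sketch: persist the converted model; any path works.
with open("nn.onnx", "wb") as f:
    f.write(onx.SerializeToString())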
Training graph#
The loss function is the square function. We use function add_loss_output. It does what is implemented in example Train a linear regression with onnxruntime-training in details.
onx_train = add_loss_output(onx)
plot_onnxs(onx_train)
Let’s check that inference is working.
sess = InferenceSession(onx_train.SerializeToString(),
                        providers=['CPUExecutionProvider'])
res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
print(f"onnx loss={res[0][0, 0] / X_test.shape[0]!r}")
onnx loss=0.457655029296875
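The number matches the scikit-learn score computed above. As a quick sanity check, the training graph returns the summed squared error, so recomputing it with numpy from the inference model gives the same mean (a sketch, assuming the loss is the first output of the training graph):
# Sketch: recompute the loss with numpy from the inference graph.
# The training graph returns the summed squared error, hence the
# division by the number of rows to obtain the mean.
sess0 = InferenceSession(onx.SerializeToString(),
                         providers=['CPUExecutionProvider'])
pred = sess0.run(None, {'X': X_test})[0]
print(numpy.sum((pred.ravel() - y_test) ** 2) / X_test.shape[0])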
Let’s retrieve the constants, the weights to optimize. We remove the initializers which cannot be optimized.
inits = get_train_initializer(onx)
weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
pprint(list((k, v[0].shape) for k, v in weights.items()))
[('coefficient', (10, 10)),
('intercepts', (1, 10)),
('coefficient1', (10, 10)),
('intercepts1', (1, 10)),
('coefficient2', (10, 1)),
('intercepts2', (1, 1))]
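These names correspond to the layers of the scikit-learn model. The shapes can be cross-checked against the attributes MLPRegressor exposes (the intercepts appear as 1-D arrays on the scikit-learn side):
# Cross-check: MLPRegressor stores one weight matrix per layer in
# coefs_ and one bias vector per layer in intercepts_.
pprint([c.shape for c in nn.coefs_])
pprint([i.shape for i in nn.intercepts_])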
Training#
The training session. If a GPU is available, it chooses CUDA, otherwise it falls back to CPU.
device = "cuda" if get_device().upper() == 'GPU' else 'cpu'
print(f"device={device!r} get_device()={get_device()!r}")
device='cpu' get_device()='CPU'
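The list of execution providers the installed onnxruntime build supports can also be printed directly:
from onnxruntime import get_available_providers

# Typically ['CPUExecutionProvider'] on a CPU-only build, with
# 'CUDAExecutionProvider' added on a GPU build.
print(get_available_providers())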
The training session.
train_session = OrtGradientOptimizer(
    onx_train, list(weights), device=device, verbose=1,
    learning_rate=5e-4, warm_start=False, max_iter=200, batch_size=10)
train_session.fit(X, y)
state_tensors = train_session.get_state()
print(train_session.train_losses_)
df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations", logy=True)
# import matplotlib.pyplot as plt
# plt.show()

100%|##########| 200/200 [00:07<00:00, 25.52it/s]
[10356.855, 3629.3547, 5468.5664, 56.93733, 21.106413, 9.469525, ..., 0.053351857, 0.066473596, 0.06475205]
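The trained tensors were retrieved above with get_state. A minimal sketch to inspect them, assuming the method returns a mapping from initializer name to numpy array:
# Sketch assuming get_state() returned a dict of numpy arrays.
pprint({k: v.shape for k, v in state_tensors.items()})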
Total running time of the script: (0 minutes 26.421 seconds)