Train a scikit-learn neural network with onnxruntime-training on GPU#
This example builds on the example Train a linear regression with onnxruntime-training on GPU in details to train a neural network from scikit-learn on GPU. However, the code uses classes implemented in this module, following the pattern introduced in example Train a linear regression with onnxruntime-training.
A neural network with scikit-learn#
import warnings
from pprint import pprint
import numpy
from pandas import DataFrame
from onnxruntime import get_device, InferenceSession
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from onnxcustom.utils.orttraining_helper import (
    add_loss_output, get_train_initializer)
from onnxcustom.training.optimizers import OrtGradientOptimizer
X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)
nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200,
                  solver='sgd', learning_rate_init=1e-4, alpha=0,
                  n_iter_no_change=1000, batch_size=10,
                  momentum=0, nesterovs_momentum=False)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)
Score:
print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")
mean_squared_error=0.093900256
Conversion to ONNX#
onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)

Training graph#
The loss function is the squared error. We use function add_loss_output, which implements what is done manually in example Train a linear regression with onnxruntime-training in details.
onx_train = add_loss_output(onx)
plot_onnxs(onx_train)

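As a side note, here is a minimal sketch of how a squared-error loss could be appended to an ONNX graph with the onnx helper API. It is only an illustration, not the actual implementation of add_loss_output; the model output name 'variable' and the loss shape are assumptions.

from onnx import TensorProto, helper

def append_square_loss(onx, output_name="variable"):
    # Assumed sketch: declare a 'label' input, subtract it from the
    # model output and reduce the squared difference to a 1x1 'loss'.
    graph = onx.graph
    label = helper.make_tensor_value_info(
        "label", TensorProto.FLOAT, [None, 1])
    loss = helper.make_tensor_value_info(
        "loss", TensorProto.FLOAT, [1, 1])
    nodes = list(graph.node) + [
        helper.make_node("Sub", [output_name, "label"], ["diff"]),
        helper.make_node("ReduceSumSquare", ["diff"], ["loss"]),
    ]
    new_graph = helper.make_graph(
        nodes, graph.name, list(graph.input) + [label],
        [loss] + list(graph.output), graph.initializer)
    return helper.make_model(new_graph, opset_imports=onx.opset_import)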
Let’s check that inference is working.
sess = InferenceSession(onx_train.SerializeToString(),
                        providers=['CPUExecutionProvider'])
res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
print(f"onnx loss={res[0][0, 0] / X_test.shape[0]!r}")
onnx loss=0.0939002456665039
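Dividing the summed squared error by the number of rows gives back the mean squared error obtained with scikit-learn above.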
Let’s retrieve the constants, the weights to optimize. We remove the initializers which cannot be optimized.
inits = get_train_initializer(onx)
weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
pprint(list((k, v[0].shape) for k, v in weights.items()))
[('coefficient', (10, 10)),
 ('intercepts', (1, 10)),
 ('coefficient1', (10, 10)),
 ('intercepts1', (1, 10)),
 ('coefficient2', (10, 1)),
 ('intercepts2', (1, 1))]
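Retrieving initializers can also be done directly with the onnx API. The following is an assumed simplification; the actual get_train_initializer returns richer values (the shapes above are read from tuples) and may filter differently.

from onnx import TensorProto, numpy_helper

def list_float_initializers(onx):
    # Assumed sketch: collect the float initializers of an ONNX graph
    # as numpy arrays, keyed by initializer name.
    return {
        init.name: numpy_helper.to_array(init)
        for init in onx.graph.initializer
        if init.data_type == TensorProto.FLOAT
    }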
Training#
The training session chooses CUDA if a GPU is available; otherwise it falls back to CPU.
device = "cuda" if get_device().upper() == 'GPU' else 'cpu'
print(f"device={device!r} get_device()={get_device()!r}")
device='cpu' get_device()='CPU'
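When a GPU is detected, plain inference sessions can use it too. The provider list below follows standard onnxruntime usage and is only a hint; the example itself keeps 'CPUExecutionProvider'.

# Assumed pattern: match the InferenceSession providers to the device.
providers = (['CUDAExecutionProvider', 'CPUExecutionProvider']
             if device == 'cuda' else ['CPUExecutionProvider'])
sess_check = InferenceSession(onx.SerializeToString(), providers=providers)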
The training session.
train_session = OrtGradientOptimizer(
    onx_train, list(weights), device=device, verbose=1,
    learning_rate=5e-4, warm_start=False, max_iter=200, batch_size=10)
train_session.fit(X, y)
state_tensors = train_session.get_state()
print(train_session.train_losses_)
df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations", logy=True)
# import matplotlib.pyplot as plt
# plt.show()

  0%|          | 0/200 [00:00<?, ?it/s]
...
100%|##########| 200/200 [00:08<00:00, 24.55it/s]
[18954.137, 31050.977, 19407.963, 15871.107, 13330.832, 14443.731, 13159.874, 12349.546, 12676.277, 11961.014, 11906.464, 10416.354, 11611.566, 10417.972, 10804.054, 10527.281, 9124.444, 11262.38, 9301.433, 8569.867, 8347.148, 9405.403, 7794.7485, 8109.105, 7257.8, 8218.562, 7245.007, 7782.9004, 7871.074, 7626.823, 8530.531, 7548.5454, 7778.264, 8224.826, 6902.9785, 7205.895, 7423.84, 6401.8545, 7116.817, 6488.8345, 6521.5986, 5753.2183, 5130.2563, 6996.4136, 6577.326, 5748.6694, 5768.578, 6988.4805, 5768.353, 5787.759, 6372.6465, 5348.3955, 5723.533, 4274.1353, 5269.3623, 5723.1226, 5444.287, 4706.5713, 5622.877, 5366.4336, 4514.6543, 4708.142, 5282.5, 6357.183, 4880.45, 5901.449, 4857.69, 5836.1177, 5845.1675, 5479.284, 4007.0168, 4709.371, 4036.6123, 5027.9785, 4353.3804, 4181.4995, 5830.0483, 4586.8594, 4254.7236, 4470.1934, 4015.5544, 4500.4434, 4568.95, 3513.9038, 3589.9968, 5498.219, 3294.775, 4083.7854, 3769.8765, 3219.3784, 3342.8894, 4108.8115, 3915.1362, 3150.9622, 3695.3713, 4342.0454, 3436.6143, 3273.1628, 4461.0205, 3909.701, 3074.854, 2944.664, 2993.8105, 4340.0195, 4000.6184, 3482.1833, 3112.2393, 3339.338, 3909.8728, 2852.2725, 3751.4966, 2770.0007, 2824.8738, 3748.8206, 3219.9255, 4720.1367, 3590.0186, 2802.2827, 3265.167, 3195.3096, 3490.1147, 2791.9478, 3799.0225, 3131.083, 3419.3914, 2652.4563, 3068.9573, 3152.8445, 4149.6323, 2918.5454, 4073.206, 3721.1946, 3552.317, 3330.7195, 2765.9038, 2887.549, 2938.086, 3276.3047, 3061.6812, 2664.5476, 3259.952, 2773.6226, 3141.1484, 2949.0503, 2356.173, 2734.7847, 2682.4731, 2455.6746, 2534.1172, 2317.5605, 3253.5571, 2783.5588, 2728.8015, 2713.7693, 2384.4868, 3035.3994, 2864.5825, 2632.3708, 2155.6738, 3013.587, 2784.9998, 3078.027, 2101.6091, 1709.1598, 2793.0637, 2505.3882, 2687.3157, 2117.1045, 1873.9286, 3022.3792, 2428.8342, 2512.3032, 2486.1597, 2481.2144, 2090.9, 2438.1907, 2082.797, 2191.189, 2421.3044, 2187.6257, 1895.4113, 2646.2888, 1953.2219, 2136.0935, 2026.9938, 2546.2915, 3144.115, 2129.9836, 3022.9688, 2555.0815, 2281.376, 2198.972, 1996.7545, 2399.6152, 2222.954, 2709.7866, 2269.5564, 2658.1304, 1772.8275, 2272.3315]
[Figure: Train loss against iterations]
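To reuse the learned coefficients outside of the optimizer, they can be written back into the converted model. The sketch below assumes train_session.get_state() returns a mapping from initializer names to numpy arrays, which may differ from the actual API.

from onnx import numpy_helper

# Assumed sketch: copy the optimized weights into a fresh ONNX model
# and measure the test error with a plain inference session.
trained = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
for init in trained.graph.initializer:
    if init.name in state_tensors:
        init.CopyFrom(numpy_helper.from_array(
            state_tensors[init.name], name=init.name))

sess2 = InferenceSession(trained.SerializeToString(),
                         providers=['CPUExecutionProvider'])
pred = sess2.run(None, {'X': X_test})[0]
print(f"trained MSE={mean_squared_error(y_test, pred)!r}")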
Total running time of the script: ( 0 minutes 27.891 seconds)