Train a scikit-learn neural network with onnxruntime-training on GPU

This example builds on example Train a linear regression with onnxruntime-training on GPU in details to train a scikit-learn neural network on GPU. Here, however, the code relies on the classes implemented in this module, following the pattern introduced in example Train a linear regression with onnxruntime-training.

A neural network with scikit-learn

import warnings
from pprint import pprint
import numpy
from pandas import DataFrame
from onnxruntime import get_device, InferenceSession
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from onnxcustom.utils.orttraining_helper import (
    add_loss_output, get_train_initializer)
from onnxcustom.training.optimizers import OrtGradientOptimizer


X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200,
                  solver='sgd', learning_rate_init=1e-4, alpha=0,
                  n_iter_no_change=1000, batch_size=10,
                  momentum=0, nesterovs_momentum=False)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

Score:

print("mean_squared_error=%r" % mean_squared_error(y_test, nn.predict(X_test)))

Out:

mean_squared_error=0.6655399

Conversion to ONNX

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)

Out:

<AxesSubplot:>

Training graph

The loss function is the squared error. We use function add_loss_output, which automates what is implemented step by step in example Train a linear regression with onnxruntime-training in details.

onx_train = add_loss_output(onx)
plot_onnxs(onx_train)

Out:

<AxesSubplot:>

Let’s check that inference works.

sess = InferenceSession(onx_train.SerializeToString(),
                        providers=['CPUExecutionProvider'])
res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
print("onnx loss=%r" % (res[0][0, 0] / X_test.shape[0]))

Out:

onnx loss=0.6655394897460938
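The division by X_test.shape[0] above matches if the appended loss node sums the squared errors over the batch; dividing that sum by the number of rows then recovers the mean squared error. A minimal numpy sketch with hypothetical predictions and labels (stand-ins for the model output and y_test):

```python
import numpy

# hypothetical predictions and labels, stand-ins for the model output and y_test
pred = numpy.array([1.0, 2.0, 4.0], dtype=numpy.float32)
label = numpy.array([1.5, 2.0, 3.0], dtype=numpy.float32)

# summed squared error, as a loss node summing over the batch would produce it
sq_loss = float(((pred - label) ** 2).sum())
# dividing by the number of rows recovers the mean squared error
mse = sq_loss / pred.shape[0]
print(sq_loss, mse)
```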

Let’s retrieve the initializers, the weights to optimize. We remove any initializer which cannot be optimized, such as shape tensors.

inits = get_train_initializer(onx)
weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
pprint(list((k, v[0].shape) for k, v in weights.items()))

Out:

[('coefficient', (10, 10)),
 ('intercepts', (1, 10)),
 ('coefficient1', (10, 10)),
 ('intercepts1', (1, 10)),
 ('coefficient2', (10, 1)),
 ('intercepts2', (1, 1))]
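Summing the products of the shapes listed above gives the number of trainable parameters the optimizer will update, a quick sanity check for the (10, 10) architecture:

```python
# initializer shapes reported above: two 10-unit hidden layers plus the output layer
shapes = [(10, 10), (1, 10), (10, 10), (1, 10), (10, 1), (1, 1)]
total = sum(rows * cols for rows, cols in shapes)
print(total)  # 100 + 10 + 100 + 10 + 10 + 1 = 231 trainable parameters
```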

Training

First, the device: if a GPU is available, the code chooses CUDA, otherwise it falls back to CPU.

device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

print("device=%r get_device()=%r" % (device, get_device()))

Out:

device='cpu' get_device()='CPU'

The training session.

train_session = OrtGradientOptimizer(
    onx_train, list(weights), device=device, verbose=1,
    learning_rate=5e-4, warm_start=False, max_iter=200, batch_size=10)

train_session.fit(X, y)
state_tensors = train_session.get_state()

print(train_session.train_losses_)

df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses:': nn.loss_curve_})
df.plot(title="Train loss against iterations", logy=True)

# import matplotlib.pyplot as plt
# plt.show()
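What train_session.fit does can be pictured as plain mini-batch SGD over the graph's weights. A hedged numpy sketch of one such loop, using a linear model and synthetic data as stand-ins for the MLP and the ONNX training graph, with the same batch_size=10 and learning_rate=5e-4 as above:

```python
import numpy

rng = numpy.random.RandomState(0)
X = rng.randn(100, 10).astype(numpy.float32)
w_true = rng.randn(10, 1).astype(numpy.float32)
y = X @ w_true  # synthetic regression target

# plain SGD mirroring the settings above; a linear model stands in for the MLP,
# the real optimizer updates the ONNX initializers through computed gradients
w = numpy.zeros((10, 1), dtype=numpy.float32)
lr, batch_size = 5e-4, 10
losses = []
for epoch in range(200):
    perm = rng.permutation(X.shape[0])
    for start in range(0, X.shape[0], batch_size):
        idx = perm[start:start + batch_size]
        err = X[idx] @ w - y[idx]      # residual on the mini-batch
        grad = 2 * X[idx].T @ err      # gradient of the summed squared error
        w -= lr * grad                 # SGD step, no momentum
    losses.append(float(((X @ w - y) ** 2).mean()))
print(losses[0], losses[-1])
```

The recorded losses decrease over iterations, which is what the comparison plot against scikit-learn's loss_curve_ shows for the real session.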

Out:

  0%|          | 0/200 [00:00<?, ?it/s]
...
100%|##########| 200/200 [00:10<00:00, 19.37it/s]
[16436.879, 14902.491, 5674.8354, 566.09204, 120.756935, 81.74709, 19.599201, 13.427179, 7.0914106, 6.1223063, 7.606578, 5.297221, 3.3950295, 3.3763013, 3.3155093, 2.0030444, 2.1310534, 2.019372, 1.5568511, 2.20432, 2.3076622, 1.8339978, 1.839129, 1.319651, 1.5677245, 1.3874806, 1.3038124, 1.0870141, 1.085475, 1.101761, 0.91161865, 0.9858739, 0.8887112, 1.2029111, 1.0190339, 0.91901106, 0.8482574, 0.5112161, 0.73553103, 0.8175488, 0.7595738, 0.73705554, 0.5749286, 0.7384508, 0.7043623, 0.5548471, 0.5371983, 0.6098751, 0.60137606, 0.52753675, 0.55970067, 0.45928496, 0.5683602, 0.60115665, 0.5185862, 0.45665458, 0.38131404, 0.60609925, 0.47240117, 0.35278502, 0.3959964, 0.4827702, 0.4593915, 0.3595308, 0.49254858, 0.373629, 0.42077777, 0.33036146, 0.48065296, 0.3712532, 0.3855768, 0.39852604, 0.33634844, 0.31005704, 0.28867978, 0.28298303, 0.3355928, 0.40739077, 0.3506148, 0.34875366, 0.32470128, 0.34540513, 0.26682815, 0.39616358, 0.33194813, 0.2949814, 0.25135845, 0.331502, 0.2743682, 0.3686386, 0.35658973, 0.2277636, 0.2253767, 0.3170266, 0.3710035, 0.28745717, 0.25517413, 0.25017792, 0.26485375, 0.27284276, 0.2749549, 0.2521683, 0.2084716, 0.29643622, 0.28641498, 0.1661274, 0.18573241, 0.21098083, 0.1957551, 0.28017148, 0.21550007, 0.2528343, 0.17376713, 0.2503166, 0.21103568, 0.2989259, 0.21069498, 0.18429331, 0.16732803, 0.17649384, 0.110683925, 0.21496479, 0.29996985, 0.15340587, 0.23661381, 0.17454056, 0.21417572, 0.16218118, 0.1924737, 0.14655374, 0.1977415, 0.17899685, 0.10682064, 0.16559431, 0.18972574, 0.14772667, 0.17117491, 0.20505154, 0.23113133, 0.13904168, 0.10518817, 0.11671767, 0.16133738, 0.2646562, 0.1289357, 0.1862377, 0.16447243, 0.24826303, 0.16002838, 0.18587168, 0.16572273, 0.15251066, 0.10744582, 0.15665536, 0.1349173, 0.19306754, 0.18153165, 0.24369375, 0.20521504, 0.1061354, 0.16147436, 0.11600265, 0.13725594, 0.116161585, 0.13569254, 0.1354539, 0.18902661, 0.14189872, 0.24732019, 0.0925035, 0.19683427, 0.12531371, 0.12991042, 0.1160886, 
0.089676335, 0.11825717, 0.09971924, 0.13621303, 0.06729801, 0.12358891, 0.08002738, 0.11493575, 0.08486778, 0.10554214, 0.13393642, 0.1026533, 0.097287275, 0.097009726, 0.13138635, 0.18428284, 0.118060455, 0.15535262, 0.10996106, 0.06857138, 0.077335455, 0.08316788, 0.1604995, 0.13469936, 0.15550925, 0.1174179]

<AxesSubplot:title={'center':'Train loss against iterations'}>

Total running time of the script: ( 0 minutes 39.500 seconds)

Gallery generated by Sphinx-Gallery