Train a scikit-learn neural network with onnxruntime-training on GPU

This example builds on example Train a linear regression with onnxruntime-training on GPU in details to train a neural network from scikit-learn on GPU. However, the code relies on classes implemented in this module (onnxcustom), following the pattern introduced in example Train a linear regression with onnxruntime-training.

A neural network with scikit-learn

import warnings
from pprint import pprint
import numpy
from pandas import DataFrame
from onnxruntime import get_device, InferenceSession
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from onnxcustom.utils.orttraining_helper import (
    add_loss_output, get_train_initializer)
from onnxcustom.training.optimizers import OrtGradientOptimizer


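# Synthetic regression problem: 1000 rows, 10 features, cast to float32,
# the type used by the ONNX graph below.
X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

# Plain mini-batch SGD: constant learning rate, no momentum, no L2 penalty
# (alpha=0), and n_iter_no_change > max_iter so all 200 epochs are run.
nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200,
                  solver='sgd', learning_rate_init=1e-4, alpha=0,
                  n_iter_no_change=1000, batch_size=10,
                  momentum=0, nesterovs_momentum=False)

with warnings.catch_warnings():
    # Silence the ConvergenceWarning raised when max_iter is reached.
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)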
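With solver='sgd', the default learning_rate='constant', momentum=0, nesterovs_momentum=False and alpha=0, each MLPRegressor update is plain mini-batch gradient descent. As a sketch based on scikit-learn's squared loss (half the mean squared error, with no L2 term since alpha=0), for a batch $B$ of batch_size=10 rows, a network $f_W$ and $\eta$ = learning_rate_init:

```latex
\mathcal{L}_B(W) = \frac{1}{2\,|B|} \sum_{i \in B} \bigl(y_i - f_W(x_i)\bigr)^2,
\qquad
W \leftarrow W - \eta \, \nabla_W \mathcal{L}_B(W),
\qquad
\eta = 10^{-4}
```

This is also the convention behind nn.loss_curve_, which matters when its values are compared with other loss curves.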

Score:

print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")
mean_squared_error=0.093900256

Conversion to ONNX

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)
[Figure: the ONNX graph of the converted model, drawn by plot_onnxs]
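The converted graph evaluates the fitted network. Assuming MLPRegressor's default activation='relu' (this example does not change it) and its identity output layer, the computation amounts to the numpy sketch below. mlp_forward is a hypothetical helper, and the random weights only reproduce the shapes implied by 10 features and hidden_layer_sizes=(10, 10).

```python
import numpy as np


def mlp_forward(X, coefs, intercepts):
    """Forward pass of an MLP with ReLU hidden layers and a linear output layer."""
    h = X
    for W, b in zip(coefs[:-1], intercepts[:-1]):
        h = np.maximum(h @ W + b, 0)  # hidden layer followed by ReLU
    return h @ coefs[-1] + intercepts[-1]  # identity activation on the output


rng = np.random.default_rng(0)
shapes = [(10, 10), (10, 10), (10, 1)]  # 10 features, two hidden layers of 10, 1 output
coefs = [rng.standard_normal(s).astype(np.float32) for s in shapes]
intercepts = [rng.standard_normal(s[1]).astype(np.float32) for s in shapes]

X_demo = rng.standard_normal((5, 10)).astype(np.float32)
print(mlp_forward(X_demo, coefs, intercepts).shape)  # (5, 1)

# With the fitted model, mlp_forward(X_test, nn.coefs_, nn.intercepts_).ravel()
# should match nn.predict(X_test) up to float rounding.
```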

Training graph

The loss function is the squared error. Function add_loss_output adds a label input and a loss output to the graph, doing what example Train a linear regression with onnxruntime-training in details implements step by step.

onx_train = add_loss_output(onx)
plot_onnxs(onx_train)
[Figure: the ONNX graph extended with the loss output]

Let’s check that inference works.

sess = InferenceSession(onx_train.SerializeToString(),
                        providers=['CPUExecutionProvider'])
res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
print(f"onnx loss={res[0][0, 0] / X_test.shape[0]!r}")
onnx loss=0.0939002456665039
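The loss output holds the sum of squared errors over all rows, which is why the code above divides it by X_test.shape[0]: the result then matches the mean_squared_error printed earlier. scikit-learn's loss_curve_ uses yet another convention, half the mean. A small numpy sketch on made-up values shows the three conventions side by side; squared_error_losses is a hypothetical helper.

```python
import numpy as np


def squared_error_losses(y_true, y_pred):
    """Return the same squared error under three conventions."""
    diff = np.asarray(y_true, dtype=np.float64).ravel() - np.asarray(y_pred, dtype=np.float64).ravel()
    total = float(np.sum(diff ** 2))  # sum over rows, like the ONNX 'loss' output
    mean = total / diff.shape[0]      # what sklearn.metrics.mean_squared_error returns
    half_mean = mean / 2              # convention used for MLPRegressor.loss_curve_
    return total, mean, half_mean


y_true = np.array([1.0, 2.0, 3.0, 4.0])
y_pred = np.array([1.5, 2.0, 2.0, 4.5])
total, mean, half_mean = squared_error_losses(y_true, y_pred)
print(total, mean, half_mean)  # 1.5 0.375 0.1875
```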

Let’s retrieve the constants, i.e. the weights to optimize, and remove the initializer that cannot be optimized (shape_tensor).

inits = get_train_initializer(onx)
weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
pprint(list((k, v[0].shape) for k, v in weights.items()))
[('coefficient', (10, 10)),
 ('intercepts', (1, 10)),
 ('coefficient1', (10, 10)),
 ('intercepts1', (1, 10)),
 ('coefficient2', (10, 1)),
 ('intercepts2', (1, 1))]
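Removing shape_tensor by name works for this model. A more generic alternative, sketched here under the assumption that every non-trainable initializer has an integer type (as a shape given to a Reshape node does), is to filter by dtype. The dictionary below is a made-up stand-in for the result of get_train_initializer, holding plain arrays instead of its values, and trainable_names is a hypothetical helper.

```python
import numpy as np

# Made-up initializers: float weights plus an integer shape constant
# (assumed here to be int64, as a Reshape shape would be).
initializers = {
    "coefficient": np.zeros((10, 10), dtype=np.float32),
    "intercepts": np.zeros((1, 10), dtype=np.float32),
    "shape_tensor": np.array([-1, 1], dtype=np.int64),
}


def trainable_names(inits):
    """Keep the names of floating-point initializers; integer ones have no gradient."""
    return [name for name, value in inits.items()
            if np.issubdtype(value.dtype, np.floating)]


print(trainable_names(initializers))  # ['coefficient', 'intercepts']
```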

Training

Select the device: if the installed onnxruntime reports a GPU, training runs with CUDA, otherwise it falls back to CPU.

device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

print(f"device={device!r} get_device()={get_device()!r}")
device='cpu' get_device()='CPU'
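get_device() returns the device the installed onnxruntime package was built for ('CPU' or 'GPU'), not whether a GPU is actually present. The same choice matters when an InferenceSession is created later, for example to run the trained model: its providers argument lists execution providers in priority order. A minimal sketch, assuming the CUDA and CPU providers shipped with onnxruntime-gpu; providers_for_device is a hypothetical helper, not part of onnxcustom.

```python
def providers_for_device(device):
    """Map the device string used above to an onnxruntime providers list."""
    if device == "cuda":
        # CUDA first; CPU stays as a fallback for operators CUDA does not implement.
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    if device == "cpu":
        return ["CPUExecutionProvider"]
    raise ValueError(f"unexpected device {device!r}")


print(providers_for_device("cuda"))
print(providers_for_device("cpu"))

# Usage with onnxruntime (not executed here):
# sess = InferenceSession(onx.SerializeToString(),
#                         providers=providers_for_device(device))
```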

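The get_device() output above shows this page was generated with a CPU build of onnxruntime, so the training below runs on CPU. Now create the training session and fit it.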

train_session = OrtGradientOptimizer(
    onx_train, list(weights), device=device, verbose=1,
    learning_rate=5e-4, warm_start=False, max_iter=200, batch_size=10)

train_session.fit(X, y)
state_tensors = train_session.get_state()

print(train_session.train_losses_)

# The two curves are not on the same scale: loss_curve_ is half the mean squared error.
df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations", logy=True)

# import matplotlib.pyplot as plt
# plt.show()
[Figure: Train loss against iterations]
100%|##########| 200/200 [00:08<00:00, 24.55it/s]
[18954.137, 31050.977, 19407.963, 15871.107, 13330.832, 14443.731, 13159.874, 12349.546, 12676.277, 11961.014, 11906.464, 10416.354, 11611.566, 10417.972, 10804.054, 10527.281, 9124.444, 11262.38, 9301.433, 8569.867, 8347.148, 9405.403, 7794.7485, 8109.105, 7257.8, 8218.562, 7245.007, 7782.9004, 7871.074, 7626.823, 8530.531, 7548.5454, 7778.264, 8224.826, 6902.9785, 7205.895, 7423.84, 6401.8545, 7116.817, 6488.8345, 6521.5986, 5753.2183, 5130.2563, 6996.4136, 6577.326, 5748.6694, 5768.578, 6988.4805, 5768.353, 5787.759, 6372.6465, 5348.3955, 5723.533, 4274.1353, 5269.3623, 5723.1226, 5444.287, 4706.5713, 5622.877, 5366.4336, 4514.6543, 4708.142, 5282.5, 6357.183, 4880.45, 5901.449, 4857.69, 5836.1177, 5845.1675, 5479.284, 4007.0168, 4709.371, 4036.6123, 5027.9785, 4353.3804, 4181.4995, 5830.0483, 4586.8594, 4254.7236, 4470.1934, 4015.5544, 4500.4434, 4568.95, 3513.9038, 3589.9968, 5498.219, 3294.775, 4083.7854, 3769.8765, 3219.3784, 3342.8894, 4108.8115, 3915.1362, 3150.9622, 3695.3713, 4342.0454, 3436.6143, 3273.1628, 4461.0205, 3909.701, 3074.854, 2944.664, 2993.8105, 4340.0195, 4000.6184, 3482.1833, 3112.2393, 3339.338, 3909.8728, 2852.2725, 3751.4966, 2770.0007, 2824.8738, 3748.8206, 3219.9255, 4720.1367, 3590.0186, 2802.2827, 3265.167, 3195.3096, 3490.1147, 2791.9478, 3799.0225, 3131.083, 3419.3914, 2652.4563, 3068.9573, 3152.8445, 4149.6323, 2918.5454, 4073.206, 3721.1946, 3552.317, 3330.7195, 2765.9038, 2887.549, 2938.086, 3276.3047, 3061.6812, 2664.5476, 3259.952, 2773.6226, 3141.1484, 2949.0503, 2356.173, 2734.7847, 2682.4731, 2455.6746, 2534.1172, 2317.5605, 3253.5571, 2783.5588, 2728.8015, 2713.7693, 2384.4868, 3035.3994, 2864.5825, 2632.3708, 2155.6738, 3013.587, 2784.9998, 3078.027, 2101.6091, 1709.1598, 2793.0637, 2505.3882, 2687.3157, 2117.1045, 1873.9286, 3022.3792, 2428.8342, 2512.3032, 2486.1597, 2481.2144, 2090.9, 2438.1907, 2082.797, 2191.189, 2421.3044, 2187.6257, 1895.4113, 2646.2888, 1953.2219, 2136.0935, 2026.9938, 2546.2915, 3144.115, 
2129.9836, 3022.9688, 2555.0815, 2281.376, 2198.972, 1996.7545, 2399.6152, 2222.954, 2709.7866, 2269.5564, 2658.1304, 1772.8275, 2272.3315]

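OrtGradientOptimizer lets onnxruntime-training compute the gradients of the loss with respect to the listed weights and updates them batch after batch, using the learning_rate and batch_size given above. The numpy sketch below illustrates that kind of loop for a plain linear model with hand-written gradients of a summed squared-error loss. It is not onnxcustom's implementation: the constant step size and recording the mean batch loss once per epoch (to mimic train_losses_) are assumptions of the sketch, and sgd_train is a hypothetical helper.

```python
import numpy as np


def sgd_train(X, y, lr=1e-3, batch_size=10, max_iter=50, seed=0):
    """Mini-batch SGD on a linear model y ~ X @ w + b with a summed squared-error loss."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    w = np.zeros(d)
    b = 0.0
    losses = []
    for _ in range(max_iter):
        order = rng.permutation(n)  # shuffle rows at every epoch
        batch_losses = []
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            err = X[idx] @ w + b - y[idx]                 # residuals of the batch
            batch_losses.append(float(np.sum(err ** 2)))  # summed squared error
            # Gradient of sum(err ** 2) with respect to w and b, constant step size.
            w -= lr * 2 * X[idx].T @ err
            b -= lr * 2 * float(np.sum(err))
        losses.append(float(np.mean(batch_losses)))  # one value per epoch
    return w, b, losses


rng = np.random.default_rng(1)
X_demo = rng.standard_normal((200, 3))
y_demo = X_demo @ np.array([1.0, -2.0, 0.5]) + 2.0
w, b, losses = sgd_train(X_demo, y_demo)
print(f"first epoch loss={losses[0]:.3f} last epoch loss={losses[-1]:.3g}")
print(np.round(w, 2), round(b, 2))
```

onnxruntime-training replaces the hand-written gradient lines with gradients derived automatically from the ONNX training graph, which is what makes the same loop usable for the neural network above.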

Total running time of the script: ( 0 minutes 27.891 seconds)

Gallery generated by Sphinx-Gallery