Train a scikit-learn neural network with onnxruntime-training on GPU

This example reuses what example Train a linear regression with onnxruntime-training on GPU in details does to train a neural network from scikit-learn on GPU. However, the code relies on classes implemented in this module, following the pattern introduced in example Train a linear regression with onnxruntime-training.

A neural network with scikit-learn

import warnings
from pprint import pprint
import numpy
from pandas import DataFrame
from onnxruntime import get_device, InferenceSession
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from onnxcustom.utils.orttraining_helper import (
    add_loss_output, get_train_initializer)
from import OrtGradientOptimizer

X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=200,
                  solver='sgd', learning_rate_init=1e-4, alpha=0,
                  n_iter_no_change=1000, batch_size=10,
                  momentum=0, nesterovs_momentum=False)

with warnings.catch_warnings():
    # Silence warnings such as the ConvergenceWarning raised when max_iter is reached.
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)


print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")
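The parameters `momentum=0`, `nesterovs_momentum=False` and `alpha=0` switch off momentum and L2 regularization, so the `sgd` solver should reduce to plain mini-batch gradient descent, which makes it a fair baseline for the training done with onnxruntime-training below. Assuming scikit-learn's squared loss for regression (half the mean squared error over a mini-batch $B$), one update with learning rate $\eta$ (`learning_rate_init`) and batch size $|B|$ (`batch_size`) is:

```latex
L_B(w) = \frac{1}{2|B|} \sum_{i \in B} \bigl(y_i - f_w(x_i)\bigr)^2,
\qquad
w \leftarrow w - \eta \, \nabla_w L_B(w),
\qquad |B| = 10,\ \eta = 10^{-4}
```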

Conversion to ONNX

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)
Figure: ONNX graph of the converted model.

Training graph

The loss function is the squared error. We use function add_loss_output, which does what is implemented in example Train a linear regression with onnxruntime-training in details.
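A minimal pure-Python sketch of the quantity the training graph is assumed to expose: the sum of squared errors between the model output and the `label` input. That the ONNX output is a sum rather than a mean is an assumption inferred from the division by `X_test.shape[0]` applied when the loss is printed; the function `squared_error_loss` and the sample values are hypothetical.

```python
def squared_error_loss(predictions, labels):
    """Sum of squared differences, a stand-in (assumption) for the loss
    output appended by add_loss_output."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    return sum((p - t) ** 2 for p, t in zip(predictions, labels))


predictions = [1.0, 2.5, -0.5, 4.0]
labels = [1.5, 2.0, 0.0, 3.0]

total = squared_error_loss(predictions, labels)
# Dividing by the number of rows turns the sum into a mean squared error.
mse = total / len(labels)
print(f"sum={total!r} mse={mse!r}")  # sum=1.75 mse=0.4375
```
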

onx_train = add_loss_output(onx)
plot_onnxs(onx_train)
Figure: ONNX graph extended with the loss output.

Let’s check that inference works.

sess = InferenceSession(onx_train.SerializeToString(),
                        providers=['CPUExecutionProvider'])
res = sess.run(None, {'X': X_test, 'label': y_test.reshape((-1, 1))})
print(f"onnx loss={res[0][0, 0] / X_test.shape[0]!r}")
onnx loss=0.0939002456665039

Let’s retrieve the constants, that is, the weights to optimize. We remove the initializer which cannot be optimized (shape_tensor).

inits = get_train_initializer(onx)
weights = {k: v for k, v in inits.items() if k != "shape_tensor"}
pprint(list((k, v[0].shape) for k, v in weights.items()))
[('coefficient', (10, 10)),
 ('intercepts', (1, 10)),
 ('coefficient1', (10, 10)),
 ('intercepts1', (1, 10)),
 ('coefficient2', (10, 1)),
 ('intercepts2', (1, 1))]
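As a sanity check, the shapes listed above match a network with 10 inputs, two hidden layers of 10 units (`hidden_layer_sizes=(10, 10)`) and one output. The snippet repeats the filtering step on a plain dictionary of shapes and counts the trainable parameters; the shape given to `shape_tensor` is made up for the illustration since it is not printed above.

```python
import math

# Shapes of the trainable initializers, copied from the output above.
weight_shapes = {
    'coefficient': (10, 10),
    'intercepts': (1, 10),
    'coefficient1': (10, 10),
    'intercepts1': (1, 10),
    'coefficient2': (10, 1),
    'intercepts2': (1, 1),
}
# 'shape_tensor' is added with a made-up shape to mimic the full initializer list.
all_initializers = dict(weight_shapes, shape_tensor=(2,))

# Same filter as above: drop the initializer that cannot be optimized.
trainable = {k: v for k, v in all_initializers.items() if k != "shape_tensor"}


def count_parameters(shapes):
    """Total number of scalar values across all tensor shapes."""
    return sum(math.prod(shape) for shape in shapes.values())


n_params = count_parameters(trainable)
print(n_params)  # 231
```
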


Let's select the device. If a GPU is available, CUDA is chosen; otherwise the training falls back to CPU.

device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

print(f"device={device!r} get_device()={get_device()!r}")
device='cpu' get_device()='CPU'
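`get_device()` reports the kind of device the installed onnxruntime package targets. When a session is created for the selected device, the device string usually translates into a list of execution providers. The helpers `pick_device` and `providers_for_device` below are hypothetical, they are not part of onnxcustom, and `OrtGradientOptimizer` may handle this mapping differently internally; only the provider names are the standard onnxruntime ones.

```python
def pick_device(ort_device):
    """Hypothetical helper: same rule as above, applied to the string
    returned by onnxruntime.get_device() ('CPU' in the run above)."""
    return 'cuda' if ort_device.upper() == 'GPU' else 'cpu'


def providers_for_device(device):
    """Hypothetical helper: map a device string to the onnxruntime
    execution providers one could pass to InferenceSession."""
    if device == 'cuda':
        # Keep the CPU provider as a fallback for nodes CUDA does not support.
        return ['CUDAExecutionProvider', 'CPUExecutionProvider']
    if device == 'cpu':
        return ['CPUExecutionProvider']
    raise ValueError(f"unexpected device {device!r}")


print(providers_for_device(pick_device('CPU')))
print(providers_for_device(pick_device('GPU')))
```
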

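The training session: OrtGradientOptimizer trains the weights listed above with the same number of iterations (200) and the same batch size (10) as the scikit-learn model, but with a learning rate of 5e-4.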

train_session = OrtGradientOptimizer(
    onx_train, list(weights), device=device, verbose=1,
    learning_rate=5e-4, warm_start=False, max_iter=200, batch_size=10)
train_session.fit(X, y)
state_tensors = train_session.get_state()
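`OrtGradientOptimizer` records one training loss per iteration in `train_losses_`, as scikit-learn does in `loss_curve_`, and both curves are compared just below. The pure-Python loop here is only a sketch of what such a curve measures, assuming one iteration means one pass over the data in mini-batches: it fits a one-variable linear model with plain mini-batch SGD and stores the mean squared error of each epoch. It does not reproduce the internals of `OrtGradientOptimizer`; the function `train_linear_sgd` and the synthetic data are hypothetical.

```python
import random


def train_linear_sgd(xs, ys, learning_rate=0.1, max_iter=200, batch_size=10, seed=0):
    """Fit ys ~ a * xs + b with plain mini-batch SGD (hypothetical helper).

    Returns (a, b, losses) where losses holds one mean squared error per
    epoch, accumulated with the weights in use when each batch was seen.
    """
    rng = random.Random(seed)
    a, b = 0.0, 0.0
    indices = list(range(len(xs)))
    losses = []
    for _ in range(max_iter):
        rng.shuffle(indices)  # new batch composition at every epoch
        epoch_loss = 0.0
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            grad_a = grad_b = 0.0
            for i in batch:
                err = a * xs[i] + b - ys[i]
                epoch_loss += err ** 2
                # Accumulate the gradient of the squared error.
                grad_a += 2 * err * xs[i]
                grad_b += 2 * err
            # Average over the batch, then take one gradient step.
            a -= learning_rate * grad_a / len(batch)
            b -= learning_rate * grad_b / len(batch)
        losses.append(epoch_loss / len(xs))
    return a, b, losses


# Synthetic, noise-free data: y = 3x + 2 on [0, 1).
xs = [i / 100 for i in range(100)]
ys = [3 * x + 2 for x in xs]
a, b, losses = train_linear_sgd(xs, ys)
print(f"a={a:.3f} b={b:.3f} first loss={losses[0]:.4f} last loss={losses[-1]:.2e}")
```

Plotting `losses` with a logarithmic y-axis gives the same kind of curve as the one drawn below for both trainings.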


df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations", logy=True)

# import matplotlib.pyplot as plt
Figure: Train loss against iterations.

