Forward backward on a neural network on GPU (Nesterov) and penalty#

This example does the same as Forward backward on a neural network on GPU but updates the weights with Nesterov momentum and, in a second part, adds a regularization penalty.
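
As a reminder, Nesterov momentum keeps a velocity term fed by the gradients. A minimal numpy sketch of one update, consistent with scikit-learn's convention (the names lr and momentum are illustrative):

import numpy

def nesterov_update(w, grad, velocity, lr=1e-4, momentum=0.9):
    # the velocity accumulates the scaled gradient
    velocity = momentum * velocity - lr * grad
    # Nesterov looks one step ahead: the weight update reuses
    # the updated velocity instead of the previous one
    w = w + momentum * velocity - lr * grad
    return w, velocity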

A neural network with scikit-learn#

import warnings
import numpy
import onnx
from pandas import DataFrame
from onnxruntime import get_device
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPRegressor
from sklearn.metrics import mean_squared_error
from onnxcustom.plotting.plotting_onnx import plot_onnxs
from mlprodict.onnx_conv import to_onnx
from mlprodict.plotting.text_plot import onnx_simple_text_plot
from onnxcustom.utils.orttraining_helper import get_train_initializer
from onnxcustom.utils.onnx_helper import onnx_rename_weights
from onnxcustom.training.optimizers_partial import (
    OrtGradientForwardBackwardOptimizer)
from onnxcustom.training.sgd_learning_rate import LearningRateSGDNesterov
from onnxcustom.training.sgd_learning_penalty import ElasticLearningPenalty


X, y = make_regression(1000, n_features=10, bias=2)
X = X.astype(numpy.float32)
y = y.astype(numpy.float32)
X_train, X_test, y_train, y_test = train_test_split(X, y)

nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=100,
                  solver='sgd', learning_rate_init=5e-5,
                  n_iter_no_change=1000, batch_size=10, alpha=0,
                  momentum=0.9, nesterovs_momentum=True)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

print(nn.loss_curve_)
[10054.449475097656, 189.1372885163625, 3.7245753371715544, 1.860364637374878, 1.274503894249598, 0.8726946311195691, 0.6833320026596387, 0.542678053577741, 0.49827574759721754, 0.424820536673069, 0.386580089405179, 0.33915482232968014, 0.31883926351865133, 0.2761942940453688, 0.27035212762653826, 0.24684627493222555, 0.21701112889995178, 0.21617389185974994, 0.20163319108386835, 0.18282731930414836, 0.16717151386042436, 0.1578208313261469, 0.15087749325359862, 0.14043892733752728, 0.13359561673055093, 0.12180486931776006, 0.1177771899725, 0.11096388620634874, 0.10386681098490953, 0.09780385530243317, 0.09634242203086614, 0.09435003049050768, 0.08390011178950468, 0.08096387526020407, 0.07830714888870716, 0.0765130594248573, 0.07293080120968322, 0.06970383134980997, 0.06543640260274211, 0.06315434299409389, 0.06127051665758093, 0.06381957611069083, 0.06015455489978194, 0.057736615799367424, 0.05688619713609417, 0.05612969385460019, 0.05420623283833265, 0.053199999170998734, 0.052347233413408197, 0.05181567606826623, 0.04814249501253168, 0.050154063176984585, 0.0475235358128945, 0.047307331937675674, 0.04616504316839079, 0.04499405236293872, 0.046677144082883995, 0.04502580287711074, 0.043434124618458254, 0.042241301809748014, 0.042568021453917027, 0.04173387736702959, 0.04200659551347295, 0.041385810108234486, 0.04010269008887311, 0.03797309905601044, 0.03775240391182403, 0.03843328082623581, 0.03881608373485505, 0.03571364842976133, 0.03581802526178459, 0.03502363634606202, 0.03564320217818022, 0.034188425450896225, 0.03460605141396324, 0.03415382140936951, 0.033409435441717504, 0.033176217232830825, 0.033438047903279464, 0.03274326724465936, 0.033322533718310295, 0.031592415530855456, 0.0336309671929727, 0.03201467059552669, 0.030856537763029337, 0.0300901587602372, 0.029325366423775753, 0.030038521621997157, 0.030455003576353192, 0.031351924588282905, 0.030277252153803905, 0.029705787518372138, 0.029387272868771106, 0.029001247140889367, 0.028990144620959956, 0.027676615475987393, 0.02735211929306388, 0.02829690771177411, 0.028539744125058254, 0.02718373096858462]

Score:

print(f"mean_squared_error={mean_squared_error(y_test, nn.predict(X_test))!r}")
mean_squared_error=0.21088222

Conversion to ONNX#

onx = to_onnx(nn, X_train[:1].astype(numpy.float32), target_opset=15)
plot_onnxs(onx)

weights = list(sorted(get_train_initializer(onx)))
print(weights)
['coefficient', 'coefficient1', 'coefficient2', 'intercepts', 'intercepts1', 'intercepts2']
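
Function get_train_initializer retrieves the initializers of the graph, i.e. the coefficients and intercepts the training algorithm will update.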

Training graph with forward backward#

device = "cuda" if get_device().upper() == 'GPU' else 'cpu'

print(f"device={device!r} get_device()={get_device()!r}")

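# onnx_rename_weights renames the initializers in alphabetical order
# (I0_coefficient, I1_intercepts, ...) so that their order is deterministic,
# which the training optimizer relies on (see the names in the repr below)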
onx = onnx_rename_weights(onx)
train_session = OrtGradientForwardBackwardOptimizer(
    onx, device=device, verbose=1,
    learning_rate=LearningRateSGDNesterov(1e-4, nesterov=True, momentum=0.9),
    warm_start=False, max_iter=100, batch_size=10)
train_session.fit(X, y)
device='cpu' get_device()='CPU'

  0%|          | 0/100 [00:00<?, ?it/s]
  ...
100%|##########| 100/100 [00:21<00:00,  4.62it/s]

OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train="['I0_coeff...", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate=LearningRateSGDNesterov(eta0=0.0001, alpha=0.0001, power_t=0.25, learning_rate='invscaling', momentum=0.9, nesterov=True), value=3.1622776601683795e-05, device='cpu', warm_start=False, verbose=1, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=NoLearningPenalty(), exc=True)

Let’s see the weights.

state_tensors = train_session.get_state()

And the loss.

print(train_session.train_losses_)

df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations (Nesterov)", logy=True)
[2482.5952, 16.686783, 7.5055327, 4.5816164, 3.202926, 2.6682677, 1.6791431, 1.4285043, 1.3416656, 1.1326325, 0.8957688, 0.82788086, 0.72842926, 0.6954739, 0.510436, 0.45630062, 0.51432884, 0.4509862, 0.38555565, 0.35307777, 0.42497143, 0.38387772, 0.30492845, 0.32352805, 0.25249144, 0.2725671, 0.26502246, 0.23859224, 0.201042, 0.19988051, 0.18336533, 0.19642735, 0.21465172, 0.18288924, 0.2016055, 0.2164054, 0.18049853, 0.2003468, 0.17537096, 0.1347837, 0.14113492, 0.16640556, 0.13889928, 0.1176382, 0.14449322, 0.13111524, 0.13592185, 0.14986442, 0.11279977, 0.13506591, 0.1107577, 0.119454585, 0.13107957, 0.1243559, 0.11652054, 0.10295862, 0.11147666, 0.10620057, 0.107936494, 0.10766308, 0.09265014, 0.13104309, 0.098379545, 0.09291313, 0.104170606, 0.1126431, 0.07876549, 0.08391076, 0.097605065, 0.10825506, 0.08321762, 0.08030946, 0.08336864, 0.09467383, 0.09617934, 0.07945354, 0.071173996, 0.06519376, 0.070295304, 0.07086661, 0.080556214, 0.077396005, 0.06331115, 0.067252904, 0.08449259, 0.07706063, 0.06558259, 0.06927825, 0.08264201, 0.074191496, 0.05676177, 0.08522473, 0.053270973, 0.0675921, 0.05769639, 0.058698535, 0.06474799, 0.06399178, 0.060638495, 0.06651075]

<AxesSubplot: title={'center': 'Train loss against iterations (Nesterov)'}>

The convergence rates differ, but the two classes do not update the learning rate in exactly the same way.
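
For instance, LearningRateSGDNesterov follows an inverse-scaling schedule (learning_rate='invscaling', eta0=0.0001, power_t=0.25 in the repr above). Assuming it matches scikit-learn's convention, the rate at iteration t would be eta0 / t ** power_t:

# assumed schedule (scikit-learn's 'invscaling'): eta_t = eta0 / t ** power_t
eta0, power_t = 1e-4, 0.25
print(eta0 / 100 ** power_t)  # ~3.16e-05, the 'value' reported in the repr above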

Regularization#

The default parameters of MLPRegressor penalize the weights during training: alpha=1e-4.
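
With an elastic penalty, the objective becomes loss(w) + l1 * Σ|w| + l2 * Σ w². A minimal numpy sketch of what ElasticLearningPenalty(l1=0, l2=1e-4) adds to the loss (the corresponding ONNX graphs are dumped at the end of this page); scikit-learn scales its alpha penalty differently, so the two losses are comparable rather than identical:

def elastic_penalty(loss, weights, l1=0.0, l2=1e-4):
    # adds l1 * sum(|w|) + l2 * sum(w^2) for every trained tensor
    for w in weights:
        loss = loss + l1 * numpy.sum(numpy.abs(w)) + l2 * numpy.sum(w ** 2)
    return loss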

nn = MLPRegressor(hidden_layer_sizes=(10, 10), max_iter=100,
                  solver='sgd', learning_rate_init=5e-5,
                  n_iter_no_change=1000, batch_size=10, alpha=1e-4,
                  momentum=0.9, nesterovs_momentum=True)

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    nn.fit(X_train, y_train)

print(nn.loss_curve_)
[8788.052516152955, 125.04397798087665, 16.87310730919902, 6.11450787386373, 3.458880722833761, 2.36566677540706, 1.8096329200389225, 1.306794376330185, 1.0745067866942082, 0.9214900698321022, 0.8421491919091537, 0.7012372042871158, 0.655693690352726, 0.5857730045863309, 0.5574902763084094, 0.49682948769280144, 0.44812331218053486, 0.44328559742104223, 0.4179707387905757, 0.3797781364775023, 0.35639786191590633, 0.347460354052957, 0.3402620181665817, 0.3148887930714526, 0.2916719015270869, 0.2921010421114604, 0.26369987149974505, 0.255980564050746, 0.2525165998914719, 0.2388780372351408, 0.2432799556310097, 0.2276666573224625, 0.2273612847690424, 0.21398382606608085, 0.20447203167285916, 0.19683600255259678, 0.20020803300402168, 0.18580255492033962, 0.17874163727478193, 0.1793333719394843, 0.17181670444157912, 0.16127423663226764, 0.1655236408702453, 0.15610014390267532, 0.14962380448606816, 0.16484000797971488, 0.14847695694236757, 0.13644790766873358, 0.13814148018380007, 0.1353792354401112, 0.13489122693561714, 0.130907506252416, 0.13452651561396514, 0.1262878094845136, 0.12272184734819729, 0.11663721134821571, 0.11422777272015816, 0.12307965929588084, 0.10886620407215757, 0.11307516354598995, 0.10874837014035679, 0.10788932543419008, 0.10724228933554888, 0.10371358414481877, 0.10454944889689088, 0.10111476764650344, 0.10043875837034386, 0.10101110488689342, 0.0993876853888532, 0.09902955709837477, 0.09152927426936626, 0.09476917640967368, 0.08917339748057126, 0.09118896747116052, 0.08638342410014074, 0.09065931073526538, 0.08807614384918014, 0.08323922289647462, 0.0863466943374236, 0.0825813520977636, 0.08168856354325413, 0.07915624106135168, 0.0780415243612071, 0.07572586084063844, 0.0767130674306671, 0.07370180915985504, 0.07431589064689277, 0.0729574887900392, 0.07454802576505742, 0.07312724698729713, 0.07044385365910931, 0.07036599419771035, 0.06590895634582142, 0.06710517400419512, 0.0714021099541227, 0.06640130982647337, 0.06465087263889911, 0.0643265079466601, 0.06849931089646419, 0.06270189993597171]

Let’s do the same with onnxruntime.

train_session = OrtGradientForwardBackwardOptimizer(
    onx, device=device, verbose=1,
    learning_rate=LearningRateSGDNesterov(1e-4, nesterov=True, momentum=0.9),
    learning_penalty=ElasticLearningPenalty(l1=0, l2=1e-4),
    warm_start=False, max_iter=100, batch_size=10)
train_session.fit(X, y)
  0%|          | 0/100 [00:00<?, ?it/s]
  ...
100%|##########| 100/100 [00:29<00:00,  3.35it/s]

OrtGradientForwardBackwardOptimizer(model_onnx='ir_version...', weights_to_train="['I0_coeff...", loss_output_name='loss', max_iter=100, training_optimizer_name='SGDOptimizer', batch_size=10, learning_rate=LearningRateSGDNesterov(eta0=0.0001, alpha=0.0001, power_t=0.25, learning_rate='invscaling', momentum=0.9, nesterov=True), value=3.1622776601683795e-05, device='cpu', warm_start=False, verbose=1, validation_every=10, learning_loss=SquareLearningLoss(), enable_logging=False, weight_name=None, learning_penalty=ElasticLearningPenalty(l1=0, l2=0.0001), exc=True)

Let’s see the weights.

state_tensors = train_session.get_state()

And the loss.

print(train_session.train_losses_)

df = DataFrame({'ort losses': train_session.train_losses_,
                'skl losses': nn.loss_curve_})
df.plot(title="Train loss against iterations (Nesterov + penalty)", logy=True)
[3464.9238, 21.151068, 8.625083, 6.2473383, 3.9837186, 3.2991576, 3.2089133, 2.6556857, 2.1656234, 1.738589, 1.5543369, 1.6887753, 1.2635362, 1.2149721, 0.9751648, 0.9839249, 0.9172316, 0.9104758, 0.84075683, 0.6801579, 0.6840201, 0.5834166, 0.543251, 0.52099663, 0.4822024, 0.55136734, 0.44080117, 0.42342654, 0.41430926, 0.38076505, 0.4324667, 0.38826057, 0.33802825, 0.3422785, 0.31227237, 0.31036714, 0.28148624, 0.282955, 0.2430402, 0.28611413, 0.29229707, 0.27799365, 0.23675726, 0.2337163, 0.23705846, 0.23123205, 0.25134268, 0.23188902, 0.22295786, 0.20838757, 0.20398639, 0.20057708, 0.18738486, 0.20114826, 0.1861803, 0.16981076, 0.1922723, 0.19230647, 0.172244, 0.15982531, 0.17538148, 0.1476835, 0.15446931, 0.15051419, 0.14535221, 0.14153738, 0.18138912, 0.13479964, 0.15239921, 0.12375267, 0.13673629, 0.12308879, 0.1449343, 0.119252235, 0.13855606, 0.120667726, 0.12693155, 0.11828371, 0.13156797, 0.11487253, 0.122170426, 0.11302893, 0.110918514, 0.10241692, 0.104556344, 0.097668216, 0.11328426, 0.11257473, 0.1051394, 0.11605338, 0.11633578, 0.10778884, 0.10735649, 0.10253565, 0.11987091, 0.10584743, 0.11028698, 0.11730781, 0.10583352, 0.111658856]

<AxesSubplot: title={'center': 'Train loss against iterations (Nesterov + penalty)'}>

All ONNX graphs#

Method save_onnx_graph exports all the ONNX graphs used by the model to disk.

def print_graph(d):
    for k, v in sorted(d.items()):
        if isinstance(v, dict):
            print_graph(v)
        else:
            print("\n++++++", v.replace("\\", "/"), "\n")
            with open(v, "rb") as f:
                print(onnx_simple_text_plot(onnx.load(f)))


all_files = train_session.save_onnx_graph('.')
print_graph(all_files)


# import matplotlib.pyplot as plt
# plt.show()
++++++ ./SquareLLoss.learning_loss.loss_grad_onnx_.onnx

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=[None, None]
input: name='X2' type=dtype('float32') shape=[None, None]
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.5], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([-1.], dtype=float32)
Sub(X1, X2) -> Su_C0
  Mul(Su_C0, Mu_Mulcst1) -> Y_grad
ReduceSumSquare(Su_C0) -> Re_reduced0
  Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
    Reshape(Mu_C0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=None
output: name='Y_grad' type=dtype('float32') shape=None

++++++ ./SquareLLoss.learning_loss.loss_score_onnx_.onnx

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=[None, None]
input: name='X2' type=dtype('float32') shape=[None, None]
Sub(X1, X2) -> Su_C0
  Mul(Su_C0, Su_C0) -> Y
output: name='Y' type=dtype('float32') shape=[None, 1]
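
The first graph above (loss_grad_onnx_) returns both the loss value and its gradient. Assuming X1 holds the expected values and X2 the predictions, a numpy transcription reads:

def square_loss_grad(expected, predicted):
    # Y = 0.5 * sum((X1 - X2)^2) reshaped into a 1D tensor,
    # Y_grad = -(X1 - X2), the gradient with respect to the prediction
    diff = expected - predicted
    return numpy.array([0.5 * numpy.sum(diff ** 2)]), -diff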

++++++ ./ElasticLPenalty.learning_penalty.penalty_grad_onnx_.onnx

opset: domain='' version=14
input: name='X' type=dtype('float32') shape=None
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.9998], dtype=float32)
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
Mul(X, Mu_Mulcst) -> Mu_C0
Sign(X) -> Si_output0
  Mul(Si_output0, Mu_Mulcst1) -> Mu_C02
  Sub(Mu_C0, Mu_C02) -> Y
output: name='Y' type=dtype('float32') shape=None

++++++ ./ElasticLPenalty.learning_penalty.penalty_onnx_.onnx

opset: domain='' version=14
input: name='loss' type=dtype('float32') shape=None
input: name='W0' type=dtype('float32') shape=None
input: name='W1' type=dtype('float32') shape=None
input: name='W2' type=dtype('float32') shape=None
input: name='W3' type=dtype('float32') shape=None
input: name='W4' type=dtype('float32') shape=None
input: name='W5' type=dtype('float32') shape=None
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
init: name='Mu_Mulcst1' type=dtype('float32') shape=(1,) -- array([1.e-04], dtype=float32)
init: name='Re_Reshapecst' type=dtype('int64') shape=(1,) -- array([-1])
Abs(W0) -> Ab_Y0
  ReduceSum(Ab_Y0) -> Re_reduced0
    Mul(Re_reduced0, Mu_Mulcst) -> Mu_C0
Abs(W1) -> Ab_Y02
  ReduceSum(Ab_Y02) -> Re_reduced03
Abs(W2) -> Ab_Y03
  ReduceSum(Ab_Y03) -> Re_reduced05
Identity(Mu_Mulcst) -> Mu_Mulcst2
  Mul(Re_reduced03, Mu_Mulcst2) -> Mu_C03
ReduceSumSquare(W1) -> Re_reduced04
Identity(Mu_Mulcst1) -> Mu_Mulcst3
  Mul(Re_reduced04, Mu_Mulcst3) -> Mu_C04
    Add(Mu_C03, Mu_C04) -> Ad_C07
ReduceSumSquare(W0) -> Re_reduced02
  Mul(Re_reduced02, Mu_Mulcst1) -> Mu_C02
    Add(Mu_C0, Mu_C02) -> Ad_C06
      Add(Ad_C06, Ad_C07) -> Ad_C05
ReduceSumSquare(W4) -> Re_reduced010
ReduceSumSquare(W5) -> Re_reduced012
ReduceSumSquare(W2) -> Re_reduced06
Identity(Mu_Mulcst1) -> Mu_Mulcst11
  Mul(Re_reduced012, Mu_Mulcst11) -> Mu_C012
Identity(Mu_Mulcst) -> Mu_Mulcst4
  Mul(Re_reduced05, Mu_Mulcst4) -> Mu_C05
ReduceSumSquare(W3) -> Re_reduced08
Identity(Mu_Mulcst1) -> Mu_Mulcst5
  Mul(Re_reduced06, Mu_Mulcst5) -> Mu_C06
    Add(Mu_C05, Mu_C06) -> Ad_C08
      Add(Ad_C05, Ad_C08) -> Ad_C04
Abs(W3) -> Ab_Y04
  ReduceSum(Ab_Y04) -> Re_reduced07
Identity(Mu_Mulcst) -> Mu_Mulcst6
  Mul(Re_reduced07, Mu_Mulcst6) -> Mu_C07
Identity(Mu_Mulcst1) -> Mu_Mulcst7
  Mul(Re_reduced08, Mu_Mulcst7) -> Mu_C08
    Add(Mu_C07, Mu_C08) -> Ad_C09
      Add(Ad_C04, Ad_C09) -> Ad_C03
Abs(W4) -> Ab_Y05
  ReduceSum(Ab_Y05) -> Re_reduced09
Identity(Mu_Mulcst) -> Mu_Mulcst8
  Mul(Re_reduced09, Mu_Mulcst8) -> Mu_C09
Identity(Mu_Mulcst1) -> Mu_Mulcst9
  Mul(Re_reduced010, Mu_Mulcst9) -> Mu_C010
    Add(Mu_C09, Mu_C010) -> Ad_C010
      Add(Ad_C03, Ad_C010) -> Ad_C02
Abs(W5) -> Ab_Y06
  ReduceSum(Ab_Y06) -> Re_reduced011
Identity(Mu_Mulcst) -> Mu_Mulcst10
  Mul(Re_reduced011, Mu_Mulcst10) -> Mu_C011
    Add(Mu_C011, Mu_C012) -> Ad_C011
      Add(Ad_C02, Ad_C011) -> Ad_C01
        Add(loss, Ad_C01) -> Ad_C0
          Reshape(Ad_C0, Re_Reshapecst) -> Y
output: name='Y' type=dtype('float32') shape=[None]
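
The penalty_grad graph shown before folds the penalty into the weight update itself: with l2=1e-4, the constant 0.9998 is 1 - 2 * l2, and Mu_Mulcst1 is l1 = 0. A numpy transcription:

def penalty_grad(w, l1=0.0, l2=1e-4):
    # Y = X * (1 - 2 * l2) - l1 * sign(X); note 0.9998 == 1 - 2 * 1e-4
    return w * (1 - 2 * l2) - l1 * numpy.sign(w)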

++++++ ./LRateSGDNesterov.learning_rate.axpyw_onnx_.onnx

opset: domain='' version=14
input: name='X1' type=dtype('float32') shape=None
input: name='X2' type=dtype('float32') shape=None
input: name='G' type=dtype('float32') shape=None
input: name='alpha' type=dtype('float32') shape=[1]
input: name='beta' type=dtype('float32') shape=[1]
Mul(X1, alpha) -> Mu_C0
Mul(G, beta) -> Mu_C03
  Add(Mu_C0, Mu_C03) -> Z
    Mul(Z, beta) -> Mu_C02
  Add(Mu_C0, Mu_C02) -> Ad_C0
    Add(Ad_C0, X2) -> Y
output: name='Y' type=dtype('float32') shape=None
output: name='Z' type=dtype('float32') shape=None
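
This axpyw graph is the Nesterov update discussed at the beginning of the example. Assuming X1 receives the gradient (with alpha holding -learning_rate), G the velocity and X2 the weights, a direct numpy transcription reads:

def axpyw(X1, X2, G, alpha, beta):
    Z = X1 * alpha + G * beta       # updated velocity
    Y = X1 * alpha + Z * beta + X2  # weights moved with the new velocity
    return Y, Z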

++++++ ./GradFBOptimizer.model_onnx.onnx

opset: domain='' version=14
input: name='X' type=dtype('float32') shape=[None, 10]
init: name='I0_coefficient' type=dtype('float32') shape=(100,)
init: name='I1_intercepts' type=dtype('float32') shape=(10,)
init: name='I2_coefficient1' type=dtype('float32') shape=(100,)
init: name='I3_intercepts1' type=dtype('float32') shape=(10,)
init: name='I4_coefficient2' type=dtype('float32') shape=(10,)
init: name='I5_intercepts2' type=dtype('float32') shape=(1,) -- array([-1.0380448], dtype=float32)
init: name='I6_shape_tensor' type=dtype('int64') shape=(2,) -- array([-1,  1])
Cast(X, to=1) -> r0
  MatMul(r0, I0_coefficient) -> r1
    Add(r1, I1_intercepts) -> r2
      Relu(r2) -> r3
        MatMul(r3, I2_coefficient1) -> r4
          Add(r4, I3_intercepts1) -> r5
            Relu(r5) -> r6
              MatMul(r6, I4_coefficient2) -> r7
                Add(r7, I5_intercepts2) -> r8
                  Reshape(r8, I6_shape_tensor) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]

++++++ ./OrtGradientForwardBackwardFunction_140079623617408.train_function_._optimized_pre_grad_model.onnx

opset: domain='' version=14
opset: domain='com.microsoft.experimental' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='ai.onnx.training' version=1
opset: domain='com.ms.internal.nhwc' version=17
opset: domain='org.pytorch.aten' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='ai.onnx.ml' version=3
opset: domain='com.microsoft' version=1
input: name='X' type=dtype('float32') shape=[None, 10]
input: name='I0_coefficient' type=dtype('float32') shape=[10, 10]
input: name='I1_intercepts' type=dtype('float32') shape=[1, 10]
input: name='I2_coefficient1' type=dtype('float32') shape=[10, 10]
input: name='I3_intercepts1' type=dtype('float32') shape=[1, 10]
input: name='I4_coefficient2' type=dtype('float32') shape=[10, 1]
input: name='I5_intercepts2' type=dtype('float32') shape=[1, 1]
init: name='I6_shape_tensor' type=dtype('int64') shape=(2,) -- array([-1,  1])
MatMul(X, I0_coefficient) -> r1
  Add(r1, I1_intercepts) -> r2
    Relu(r2) -> r3
      MatMul(r3, I2_coefficient1) -> r4
        Add(r4, I3_intercepts1) -> r5
          Relu(r5) -> r6
            MatMul(r6, I4_coefficient2) -> r7
              Add(r7, I5_intercepts2) -> r8
                Reshape(r8, I6_shape_tensor, allowzero=0) -> variable
output: name='variable' type=dtype('float32') shape=[None, 1]

++++++ ./OrtGradientForwardBackwardFunction_140079623617408.train_function_._trained_onnx.onnx

opset: domain='' version=14
opset: domain='com.microsoft.experimental' version=1
opset: domain='ai.onnx.preview.training' version=1
opset: domain='ai.onnx.training' version=1
opset: domain='com.ms.internal.nhwc' version=17
opset: domain='org.pytorch.aten' version=1
opset: domain='com.microsoft.nchwc' version=1
opset: domain='ai.onnx.ml' version=3
opset: domain='com.microsoft' version=1
input: name='X' type=dtype('float32') shape=[None, 10]
input: name='I0_coefficient' type=dtype('float32') shape=[10, 10]
input: name='I1_intercepts' type=dtype('float32') shape=[1, 10]
input: name='I2_coefficient1' type=dtype('float32') shape=[10, 10]
input: name='I3_intercepts1' type=dtype('float32') shape=[1, 10]
input: name='I4_coefficient2' type=dtype('float32') shape=[10, 1]
input: name='I5_intercepts2' type=dtype('float32') shape=[1, 1]
init: name='I6_shape_tensor' type=dtype('int64') shape=(2,) -- array([-1,  1])
init: name='n1_Grad/A_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n1_Grad/dY_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n4_Grad/A_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n4_Grad/dY_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n7_Grad/A_target_shape' type=dtype('int64') shape=(2,) -- array([-1, 10])
init: name='n7_Grad/dY_target_shape' type=dtype('int64') shape=(2,) -- array([-1,  1])
MatMul(X, I0_coefficient) -> r1
  Add(r1, I1_intercepts) -> r2
    Relu(r2) -> r3
      MatMul(r3, I2_coefficient1) -> r4
        Add(r4, I3_intercepts1) -> r5
          Relu(r5) -> r6
            MatMul(r6, I4_coefficient2) -> r7
              Add(r7, I5_intercepts2) -> r8
                Reshape(r8, I6_shape_tensor, allowzero=0) -> variable
                  YieldOp[com.microsoft](variable, full_shape_outputs=[0]) -> variable_grad
                Shape(r8) -> n9_Grad/x_shape
                  Reshape(variable_grad, n9_Grad/x_shape, allowzero=0) -> r8_grad
Shape(I5_intercepts2) -> n8_Grad/Shape_I5_intercepts2
Shape(r7) -> n8_Grad/Shape_r7
  BroadcastGradientArgs[com.microsoft](n8_Grad/Shape_r7, n8_Grad/Shape_I5_intercepts2) -> n8_Grad/ReduceAxes_r7, n8_Grad/ReduceAxes_I5_intercepts2
    ReduceSum(r8_grad, n8_Grad/ReduceAxes_I5_intercepts2, noop_with_empty_axes=1, keepdims=1) -> n8_Grad/ReduceSum_r8_grad_for_I5_intercepts2
  Reshape(n8_Grad/ReduceSum_r8_grad_for_I5_intercepts2, n8_Grad/Shape_I5_intercepts2, allowzero=0) -> I5_intercepts2_grad
ReduceSum(r8_grad, n8_Grad/ReduceAxes_r7, noop_with_empty_axes=1, keepdims=1) -> n8_Grad/ReduceSum_r8_grad_for_r7
  Reshape(n8_Grad/ReduceSum_r8_grad_for_r7, n8_Grad/Shape_r7, allowzero=0) -> r7_grad
    Reshape(r7_grad, n7_Grad/dY_target_shape, allowzero=0) -> n7_Grad/dY_reshape_2d
Reshape(r6, n7_Grad/A_target_shape, allowzero=0) -> n7_Grad/A_reshape_2d
  Gemm(n7_Grad/A_reshape_2d, n7_Grad/dY_reshape_2d, beta=1.00, transB=0, transA=1, alpha=1.00) -> I4_coefficient2_grad
FusedMatMul[com.microsoft](r7_grad, I4_coefficient2, transBatchB=0, transB=1, alpha=1.00, transA=0, transBatchA=0) -> n7_Grad/PreReduceGrad0
  Shape(n7_Grad/PreReduceGrad0) -> n7_Grad/Shape_n7_Grad/PreReduceGrad0
Shape(r6) -> n7_Grad/Shape_r6
  BroadcastGradientArgs[com.microsoft](n7_Grad/Shape_r6, n7_Grad/Shape_n7_Grad/PreReduceGrad0) -> n7_Grad/ReduceAxes_r6_for_r6,
  ReduceSum(n7_Grad/PreReduceGrad0, n7_Grad/ReduceAxes_r6_for_r6, noop_with_empty_axes=1, keepdims=1) -> n7_Grad/ReduceSum_n7_Grad/PreReduceGrad0_for_r6
  Reshape(n7_Grad/ReduceSum_n7_Grad/PreReduceGrad0_for_r6, n7_Grad/Shape_r6, allowzero=0) -> r6_grad
    ReluGrad[com.microsoft](r6_grad, r6) -> r5_grad
Shape(I3_intercepts1) -> n5_Grad/Shape_I3_intercepts1
Shape(r4) -> n5_Grad/Shape_r4
  BroadcastGradientArgs[com.microsoft](n5_Grad/Shape_r4, n5_Grad/Shape_I3_intercepts1) -> n5_Grad/ReduceAxes_r4, n5_Grad/ReduceAxes_I3_intercepts1
    ReduceSum(r5_grad, n5_Grad/ReduceAxes_I3_intercepts1, noop_with_empty_axes=1, keepdims=1) -> n5_Grad/ReduceSum_r5_grad_for_I3_intercepts1
  Reshape(n5_Grad/ReduceSum_r5_grad_for_I3_intercepts1, n5_Grad/Shape_I3_intercepts1, allowzero=0) -> I3_intercepts1_grad
ReduceSum(r5_grad, n5_Grad/ReduceAxes_r4, noop_with_empty_axes=1, keepdims=1) -> n5_Grad/ReduceSum_r5_grad_for_r4
  Reshape(n5_Grad/ReduceSum_r5_grad_for_r4, n5_Grad/Shape_r4, allowzero=0) -> r4_grad
    Reshape(r4_grad, n4_Grad/dY_target_shape, allowzero=0) -> n4_Grad/dY_reshape_2d
Reshape(r3, n4_Grad/A_target_shape, allowzero=0) -> n4_Grad/A_reshape_2d
  Gemm(n4_Grad/A_reshape_2d, n4_Grad/dY_reshape_2d, beta=1.00, transB=0, transA=1, alpha=1.00) -> I2_coefficient1_grad
FusedMatMul[com.microsoft](r4_grad, I2_coefficient1, transBatchB=0, transB=1, alpha=1.00, transA=0, transBatchA=0) -> n4_Grad/PreReduceGrad0
  Shape(n4_Grad/PreReduceGrad0) -> n4_Grad/Shape_n4_Grad/PreReduceGrad0
Shape(r3) -> n4_Grad/Shape_r3
  BroadcastGradientArgs[com.microsoft](n4_Grad/Shape_r3, n4_Grad/Shape_n4_Grad/PreReduceGrad0) -> n4_Grad/ReduceAxes_r3_for_r3,
  ReduceSum(n4_Grad/PreReduceGrad0, n4_Grad/ReduceAxes_r3_for_r3, noop_with_empty_axes=1, keepdims=1) -> n4_Grad/ReduceSum_n4_Grad/PreReduceGrad0_for_r3
  Reshape(n4_Grad/ReduceSum_n4_Grad/PreReduceGrad0_for_r3, n4_Grad/Shape_r3, allowzero=0) -> r3_grad
    ReluGrad[com.microsoft](r3_grad, r3) -> r2_grad
Shape(I1_intercepts) -> n2_Grad/Shape_I1_intercepts
Shape(r1) -> n2_Grad/Shape_r1
  BroadcastGradientArgs[com.microsoft](n2_Grad/Shape_r1, n2_Grad/Shape_I1_intercepts) -> n2_Grad/ReduceAxes_r1, n2_Grad/ReduceAxes_I1_intercepts
    ReduceSum(r2_grad, n2_Grad/ReduceAxes_I1_intercepts, noop_with_empty_axes=1, keepdims=1) -> n2_Grad/ReduceSum_r2_grad_for_I1_intercepts
  Reshape(n2_Grad/ReduceSum_r2_grad_for_I1_intercepts, n2_Grad/Shape_I1_intercepts, allowzero=0) -> I1_intercepts_grad
ReduceSum(r2_grad, n2_Grad/ReduceAxes_r1, noop_with_empty_axes=1, keepdims=1) -> n2_Grad/ReduceSum_r2_grad_for_r1
  Reshape(n2_Grad/ReduceSum_r2_grad_for_r1, n2_Grad/Shape_r1, allowzero=0) -> r1_grad
    Reshape(r1_grad, n1_Grad/dY_target_shape, allowzero=0) -> n1_Grad/dY_reshape_2d
Reshape(X, n1_Grad/A_target_shape, allowzero=0) -> n1_Grad/A_reshape_2d
  Gemm(n1_Grad/A_reshape_2d, n1_Grad/dY_reshape_2d, beta=1.00, transB=0, transA=1, alpha=1.00) -> I0_coefficient_grad
FusedMatMul[com.microsoft](r1_grad, I0_coefficient, transBatchB=0, transB=1, alpha=1.00, transA=0, transBatchA=0) -> n1_Grad/PreReduceGrad0
  Shape(n1_Grad/PreReduceGrad0) -> n1_Grad/Shape_n1_Grad/PreReduceGrad0
Shape(X) -> n1_Grad/Shape_X
  BroadcastGradientArgs[com.microsoft](n1_Grad/Shape_X, n1_Grad/Shape_n1_Grad/PreReduceGrad0) -> n1_Grad/ReduceAxes_X_for_X,
  ReduceSum(n1_Grad/PreReduceGrad0, n1_Grad/ReduceAxes_X_for_X, noop_with_empty_axes=1, keepdims=1) -> n1_Grad/ReduceSum_n1_Grad/PreReduceGrad0_for_X
  Reshape(n1_Grad/ReduceSum_n1_Grad/PreReduceGrad0_for_X, n1_Grad/Shape_X, allowzero=0) -> X_grad
output: name='X_grad' type=dtype('float32') shape=[None, 10]
output: name='I0_coefficient_grad' type=dtype('float32') shape=[10, 10]
output: name='I1_intercepts_grad' type=dtype('float32') shape=[1, 10]
output: name='I2_coefficient1_grad' type=dtype('float32') shape=[10, 10]
output: name='I3_intercepts1_grad' type=dtype('float32') shape=[1, 10]
output: name='I4_coefficient2_grad' type=dtype('float32') shape=[10, 1]
output: name='I5_intercepts2_grad' type=dtype('float32') shape=[1, 1]

++++++ ./GradFBOptimizer.zero_onnx_.onnx

opset: domain='' version=14
input: name='X' type=dtype('float32') shape=None
init: name='Mu_Mulcst' type=dtype('float32') shape=(1,) -- array([0.], dtype=float32)
Mul(X, Mu_Mulcst) -> Y
output: name='Y' type=dtype('float32') shape=None
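
This last graph multiplies its input by zero; it is presumably used to reset gradient or velocity buffers between iterations.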

Total running time of the script: 1 minute 18.636 seconds
