Training with onnxruntime

onnxruntime can compute gradients. With a few extra lines of code, that is enough to implement a gradient descent.

Training capabilities are part of the same repository but are released as a separate package, onnxruntime-training. It is not an extension: it replaces onnxruntime and uses the same import name. It can be built from source with different compilation settings or downloaded from PyPI. Two packages exist so that the inference-only version stays small.
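Because both packages share the same import name, it can be useful to check which one is installed. The following is a minimal sketch, assuming that only the training build ships the `onnxruntime.training` submodule; the helper name `has_training_support` is hypothetical.

```python
import importlib.util


def has_training_support() -> bool:
    """Return True if the installed onnxruntime exposes the training submodule."""
    try:
        # Assumption: only the training build ships `onnxruntime.training`.
        return importlib.util.find_spec("onnxruntime.training") is not None
    except ImportError:
        # onnxruntime is not installed or cannot be imported.
        return False


print("training build installed:", has_training_support())
```

The function only inspects what is importable, it does not load any training class.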

Two training APIs are available. The first one assumes the loss is part of the graph to train, so the graph can be trained as a whole. The second one assumes the graph is only a piece, or a layer, of a model trained by another framework, or at least by some logic which updates the weights. This mechanism is convenient when a model is trained with pytorch.

First API: TrainingSession

TrainingSession is used by class OrtGradientOptimizer in example Train a linear regression with onnxruntime-training to show how it could be wrapped to train a model. Example Train a linear regression with onnxruntime-training in details digs into the details of the implementation. It goes through the following steps:

That’s what method OrtGradientOptimizer._create_training_session does. It does not implement a training algorithm, only a single iteration (forward + backward) which takes the expected label, the learning rate and the features as inputs. The class updates its weights internally. When the training ends, the user must collect the updated weights and create a new ONNX file with the optimized weights.
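The contract described above (one call per iteration, weights kept inside the object, weights collected at the end) can be illustrated without onnxruntime. This is a pure-Python sketch on the model y = w * x + b; the class `ToyTrainingStep` and its methods are hypothetical stand-ins, not the TrainingSession API.

```python
class ToyTrainingStep:
    """Hypothetical stand-in for one training iteration on y = w * x + b."""

    def __init__(self, w=0.0, b=0.0):
        self.w = w
        self.b = b

    def run(self, features, labels, learning_rate):
        """One iteration: forward, backward, in-place update. Returns the loss."""
        n = len(features)
        # forward: predictions and mean square loss
        errors = [self.w * x + self.b - y for x, y in zip(features, labels)]
        loss = sum(e * e for e in errors) / n
        # backward: gradient of the loss with respect to w and b
        grad_w = 2 * sum(e * x for e, x in zip(errors, features)) / n
        grad_b = 2 * sum(errors) / n
        # update: the object modifies its own weights
        self.w -= learning_rate * grad_w
        self.b -= learning_rate * grad_b
        return loss

    def get_weights(self):
        """Weights the caller has to collect once the training is over."""
        return {"w": self.w, "b": self.b}


features = [0.0, 1.0, 2.0, 3.0]
labels = [1.0, 3.0, 5.0, 7.0]  # generated with y = 2x + 1

step = ToyTrainingStep()
# The training algorithm (here a plain loop) lives outside the class.
losses = [step.run(features, labels, learning_rate=0.1) for _ in range(500)]
weights = step.get_weights()
```

The loop, the stopping rule and the storage of the final weights all stay in the caller's hands, which mirrors what a wrapper around the session has to provide.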

onnxruntime-training does not implement loss functions. That must be done independently. That’s what function onnxcustom.utils.orttraining_helper.add_loss_output() does. It implements a couple of usual losses in ONNX. Another function, onnxcustom.utils.orttraining_helper.get_train_initializer(), guesses all the coefficients of an ONNX graph if the user does not specify any. This is another common need which onnxruntime-training does not implement.

Training on GPU follows the same logic. Only the syntax changes because data has to be moved to the device first. Example Train a linear regression with onnxruntime-training on GPU in details adapts the previous example to this configuration. Finally, a last example compares this approach against scikit-learn in the same conditions.

Second API: TrainingAgent

TrainingAgent is used by class OrtGradientForwardBackwardOptimizer to train the same model. The training is split into the forward step, the backward step (gradient computation) and the weight updating step. TrainingAgent implements forward and backward. Everything else must be explicitly implemented outside of this class or be taken care of by an existing framework such as this one or pytorch. First, forward and backward with TrainingAgent. To build it, the following steps are needed:

That’s what method OrtGradientForwardBackward._create_onnx_graphs does. Forward and backward steps must be called separately. It is not trivial to guess how to call them (a forward step can be called to predict, or to train if it is followed by a backward step). Class OrtGradientForwardBackwardFunction implements those two steps with the proper API. The next lines give an idea of how it can be done. First the forward step.

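# OrtValueVector, PartialGraphExecutionState and SessionIOBinding
# come from onnxruntime.capi._pybind_state
def forward(self, inputs, training=False):
    cls = self.__class__
    # convert the inputs into OrtValue placed on the expected devices
    forward_inputs = cls.input_to_ort(
        inputs, cls._devices, cls._debug)

    if training:
        # training: keep the execution state, the backward step needs it
        forward_outputs = OrtValueVector()
        state = PartialGraphExecutionState()
        self.states_.append(state)
        cls._training_agent.run_forward(
            forward_inputs, forward_outputs, state, cls._cache)
        return forward_outputs
    else:
        # prediction: inference session called through IO binding
        iobinding = SessionIOBinding(cls._sess_eval._sess)
        for name, inp in zip(
                cls._grad_input_names, forward_inputs):
            iobinding.bind_ortvalue_input(name, inp)

        for name, dev in zip(
                cls._output_names, cls._fw_no_grad_output_device_info):
            iobinding.bind_output(name, dev)

        cls._sess_eval._sess.run_with_iobinding(
            iobinding, cls._run_options)
        return iobinding.get_outputs()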

Then the backward step.

def backward(self, grad_outputs):
    cls = self.__class__
    inputs = self.saved_tensors
    # retrieve the state stored by the matching forward call
    state = self.states_.pop()
    backward_inputs = cls.input_to_ort(
        grad_outputs, cls._bw_outputs_device_info, cls._debug)

    backward_outputs = OrtValueVector()
    cls._training_agent.run_backward(
        backward_inputs, backward_outputs, state)
    return backward_outputs
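The pairing between a training forward call and the backward call that follows it relies on a saved state. The sketch below reproduces that mechanism in pure Python on a one-weight layer y = w * x. The class `ToyForwardBackward` is hypothetical and only illustrates the state stack, it does not call onnxruntime.

```python
class ToyForwardBackward:
    """Hypothetical layer y = w * x mimicking the forward/backward pairing."""

    def __init__(self, w):
        self.w = w
        # one entry per training forward call not yet consumed by backward
        self.states_ = []

    def forward(self, x, training=False):
        if training:
            # keep what backward needs; a plain prediction stores nothing
            self.states_.append(x)
        return self.w * x

    def backward(self, grad_output):
        # consume the state saved by the matching forward call
        x = self.states_.pop()
        grad_x = grad_output * self.w  # gradient with respect to the input
        grad_w = grad_output * x       # gradient with respect to the weight
        return grad_x, grad_w


layer = ToyForwardBackward(w=3.0)
prediction = layer.forward(2.0)             # prediction, no state kept
output = layer.forward(2.0, training=True)  # training, state pushed
grad_x, grad_w = layer.backward(1.0)        # state popped
```

Calling backward without a preceding training forward fails on the empty list, which is one reason the two calls are not trivial to sequence by hand.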

The API implemented by class TrainingAgent does not use named inputs, only a list of inputs: the features followed by the current weights. Initializers must be given names in alphabetical order to avoid any confusion with that API.
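One way to keep the positional convention explicit is to build the input list from a dictionary of weights sorted by name. This is a small sketch under that assumption; the helper `build_positional_inputs` and the weight names are hypothetical.

```python
def build_positional_inputs(features, weights):
    """Return the features followed by the weights sorted by name,
    together with the order used for the weights."""
    names = sorted(weights)
    return list(features) + [weights[name] for name in names], names


weights = {"coef": [0.5, -1.0], "bias": [0.1]}
inputs, order = build_positional_inputs([[1.0, 2.0]], weights)
print(order)   # ['bias', 'coef']
print(inputs)  # [[1.0, 2.0], [0.1], [0.5, -1.0]]
```

Keeping the sorted names alongside the list makes it possible to map any positional result, such as a list of gradients, back to the weight it belongs to.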

Train a linear regression with forward backward changes the previous example to use class OrtGradientForwardBackwardOptimizer and explains the details of the implementation. This example is the best place to continue if the goal is to use the raw API of onnxruntime-training. The same example is then changed to use GPU: Forward backward on a neural network on GPU. Finally, a benchmark compares this approach with scikit-learn: Benchmark, comparison scikit-learn - forward-backward. Besides forward and backward, the training needs three elements to be complete.

  • a loss: a square loss for example SquareLearningLoss but it could be ElasticLearningPenalty.

  • a way to update the weights: a simple learning rate for a stochastic gradient descent LearningRateSGD or something more complex such as LearningRateSGDNesterov.

  • a regularization applied to the weights. It could be seen as an extension of the loss, but this design seemed simpler as it does not mix the gradient applied to the output with the gradient due to the regularization. The simplest regularization is no regularization with NoLearningPenalty, but it could be an L1 or L2 penalty as well with ElasticLearningPenalty.
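The three elements listed above can be sketched as three small interchangeable objects plugged into a training loop. The code below is a pure-Python illustration on the model y = w * x. The classes `SquareLoss`, `ConstantRateSGD` and `L2Penalty` are hypothetical and much simpler than the onnxcustom classes named in the list.

```python
class SquareLoss:
    """Loss applied to the output, returns the loss and its gradient."""

    def loss_gradient(self, pred, label):
        diff = pred - label
        return diff * diff, 2.0 * diff


class ConstantRateSGD:
    """Weight update with a constant learning rate."""

    def __init__(self, eta):
        self.eta = eta

    def update(self, weight, gradient):
        return weight - self.eta * gradient


class L2Penalty:
    """Regularization on the weight, kept apart from the loss."""

    def __init__(self, l2):
        self.l2 = l2

    def gradient(self, weight):
        return 2.0 * self.l2 * weight


def train(xs, ys, loss, rate, penalty, epochs=200):
    w = 0.0
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            pred = w * x                               # forward
            _, grad_out = loss.loss_gradient(pred, y)  # gradient on the output
            grad_w = grad_out * x                      # backward
            grad_w += penalty.gradient(w)              # gradient of the penalty
            w = rate.update(w, grad_w)                 # weight update
    return w


xs = [1.0, 2.0, 3.0]
ys = [2.0, 4.0, 6.0]  # generated with y = 2x
w_plain = train(xs, ys, SquareLoss(), ConstantRateSGD(0.01), L2Penalty(0.0))
w_ridge = train(xs, ys, SquareLoss(), ConstantRateSGD(0.01), L2Penalty(1.0))
```

With no penalty the weight converges to 2, and the L2 penalty shrinks it toward zero. Swapping one object for another does not touch the forward and backward lines, which is the point of keeping the three roles separate.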

These parts can easily be replaced by the same pieces implemented in pytorch. That’s what wrapper class ORTModule offers, except that it starts from a pytorch model which is then converted into ONNX. Example Benchmark, comparison torch - forward-backward shows how. Class OrtGradientForwardBackwardOptimizer starts directly from the ONNX graph and adds the pieces not implemented in onnxruntime-training.