module onnxtorch.torchort

Inheritance diagram of deeponnxcustom.onnxtorch.torchort

Short summary

module deeponnxcustom.onnxtorch.torchort

Experimental.

source on GitHub

Classes

  • TorchOrtFactory – A class which dynamically creates another class which implements a custom function (see autograd functions). Use …

  • TorchOrtFunction – Ancestor to all classes created by TorchOrtFactory. It implements simple functions to move the ownership of …

Functions

  • ort_backward – Implements backward function. See autograd functions.

  • ort_forward – Implements forward function. See autograd functions.

Static Methods

  • _provider_name_to_device_type

  • _repr_helper_ – used to improve logging messages

  • from_ort_to_torch – Converts an OrtValueVector into a tuple of pytorch tensors.

  • from_torch_to_ort – Converts a list of pytorch tensors into an OrtValueVector.

Methods

  • __init__

  • __repr__ – usual string representation

  • create_class – Creates a class which inherits from torch.autograd.Function() and implements forward, backward methods …

Documentation

Experimental.

class deeponnxcustom.onnxtorch.torchort.TorchOrtFactory(onnx_model, weights_to_train, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0)

Bases: object

A class which dynamically creates another class implementing a custom function (see autograd functions), so that an ONNX graph can be used inside a torch function. Only initializers can be trained, not parameters.

Parameters
  • onnx_model – ONNX model

  • weights_to_train – names of the weights to train

  • input_names – input names, or None for all inputs

  • output_names – output names, or None for all outputs

  • class_name – name of the class to create

  • sess_options – see SessionOptions

  • providers – see InferenceSession

  • provider_options – see InferenceSession

  • run_options – see RunOptions

  • graph_builder_config – see OrtModuleGraphBuilderConfiguration

  • device_index – CUDA device index (0 for cuda:0, 1 for cuda:1, …), 0 by default
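input_names and output_names accept None to mean all names. A minimal sketch of how such a parameter can be resolved; select_names is a hypothetical helper, and rejecting unknown names is an assumption rather than documented behavior:

```python
from typing import List, Optional, Sequence


def select_names(requested: Optional[Sequence[str]],
                 available: Sequence[str]) -> List[str]:
    """Hypothetical helper: None selects every available name."""
    if requested is None:
        # None means "all", kept in the model's own order.
        return list(available)
    # Assumption: names not present in the model are rejected.
    unknown = [name for name in requested if name not in available]
    if unknown:
        raise ValueError(f"Unknown names: {unknown}")
    return list(requested)


model_inputs = ["X", "Y"]
print(select_names(None, model_inputs))   # ['X', 'Y']
print(select_names(["Y"], model_inputs))  # ['Y']
```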

Note

The current implementation of onnxruntime requires the weights to train to be listed in alphabetical order. The constructor checks that this condition holds.
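The ordering condition from this note can be expressed as a small check. check_weights_order is a hypothetical name, and the sketch assumes "alphabetical" means Python's default string ordering (code-point order, so uppercase letters sort before lowercase ones):

```python
from typing import Sequence


def check_weights_order(weights_to_train: Sequence[str]) -> None:
    """Hypothetical check: raise if the names are not sorted."""
    expected = sorted(weights_to_train)  # code-point (lexicographic) order
    if list(weights_to_train) != expected:
        raise ValueError(
            "weights_to_train must be in alphabetical order: "
            f"got {list(weights_to_train)}, expected {expected}")


check_weights_order(["bias", "coef"])      # sorted, passes silently
try:
    check_weights_order(["coef", "bias"])  # wrong order
except ValueError as exc:
    print(exc)
```

Passing sorted(names) as weights_to_train is one way to satisfy the condition.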

Warning

This class does not consider subgraphs.

__init__(onnx_model, weights_to_train, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0)

__repr__()

usual string representation

static _provider_name_to_device_type(provider_name)
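This static method is listed without a docstring. The sketch below only illustrates what such a mapping might do, under stated assumptions: it turns the onnxruntime provider names CPUExecutionProvider and CUDAExecutionProvider into torch-style device strings, using device_index for CUDA. The function name and the string return type are hypothetical; the real method may return a different object.

```python
def provider_name_to_device(provider_name: str, device_index: int = 0) -> str:
    """Hypothetical mapping from an onnxruntime provider name to a
    torch-style device string (assumption, not the library's code)."""
    if provider_name == "CPUExecutionProvider":
        return "cpu"
    if provider_name == "CUDAExecutionProvider":
        # device_index selects the GPU: 0 -> cuda:0, 1 -> cuda:1, ...
        return f"cuda:{device_index}"
    raise ValueError(f"Unsupported provider {provider_name!r}")


print(provider_name_to_device("CPUExecutionProvider"))      # cpu
print(provider_name_to_device("CUDAExecutionProvider", 1))  # cuda:1
```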
static _repr_helper_(obj, indent=0)

used to improve logging messages
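The indent argument suggests a recursive, indented rendering of nested objects for log messages. A hypothetical sketch of that idea; repr_helper and its exact layout are assumptions, not the library's output format:

```python
def repr_helper(obj, indent: int = 0) -> str:
    """Hypothetical indented rendering of nested dicts/lists for logs."""
    pad = "  " * indent
    if isinstance(obj, dict):
        lines = [pad + "{"]
        for key in sorted(obj, key=str):  # stable order for readable logs
            lines.append(f"{pad}  {key!r}:")
            lines.append(repr_helper(obj[key], indent + 2))
        lines.append(pad + "}")
        return "\n".join(lines)
    if isinstance(obj, (list, tuple)):
        lines = [pad + "["]
        lines.extend(repr_helper(item, indent + 1) for item in obj)
        lines.append(pad + "]")
        return "\n".join(lines)
    # Leaves fall back to the built-in repr, indented.
    return pad + repr(obj)


print(repr_helper({"inputs": ["X"], "n_outputs": 1}))
```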

create_class(enable_logging=False, keep_models=False, debug=False)

Creates a class which inherits from torch.autograd.Function and implements the forward and backward methods using ONNX. The function dynamically creates a new class and stores every needed object as a static attribute of that class.

Parameters
  • enable_logging – used for debugging: logs every building step at info level, and logs information while processing forward and backward at debug level

  • keep_models – stores additional information as static attributes

  • debug – displays information
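The enable_logging description splits messages across two levels. A sketch with the standard logging module, assuming a hypothetical logger name and helper make_logger; the ListHandler class only exists to show which records get emitted:

```python
import logging


def make_logger(enable_logging: bool) -> logging.Logger:
    """Hypothetical: building steps at INFO, forward/backward at DEBUG."""
    logger = logging.getLogger("torchort_sketch")  # hypothetical name
    logger.setLevel(logging.DEBUG if enable_logging else logging.WARNING)
    return logger


class ListHandler(logging.Handler):
    """Collects (level, message) pairs instead of printing them."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append((record.levelname, record.getMessage()))


logger = make_logger(enable_logging=True)
handler = ListHandler()
logger.addHandler(handler)
logger.info("build step: create gradient graph")  # building step, INFO
logger.debug("forward: %d inputs", 2)             # runtime detail, DEBUG
print(handler.records)
# [('INFO', 'build step: create gradient graph'), ('DEBUG', 'forward: 2 inputs')]
```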

Returns

a new class

The pattern follows the one described in the autograd functions documentation. Methods forward and backward are replaced by ONNX implementations, and the runtime is onnxruntime-training.

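import torch


class CustomClass(torch.autograd.Function):
    # ReLU pattern from the PyTorch autograd documentation; in the generated
    # class, both methods run ONNX graphs with onnxruntime-training instead.

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)  # keep the input for the gradient
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0  # no gradient where the input was negative
        return grad_input


x = torch.randn(4, requires_grad=True)
CustomClass.apply(x).sum().backward()  # x.grad is 1 where x >= 0, else 0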
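create_class is described as building a new class at runtime and storing every needed object as a static attribute. The built-in type() does exactly that. In this sketch FunctionBase is a hypothetical stand-in for torch.autograd.Function so it runs without torch, create_class_sketch is a hypothetical helper, and dropping _trained_onnx unless keep_models is True imitates the keep_models behavior listed for the generated class:

```python
class FunctionBase:
    """Stand-in for torch.autograd.Function so the sketch runs without torch."""


def create_class_sketch(class_name, keep_models=False, **static_attrs):
    """Hypothetical: create a subclass at runtime and store every needed
    object as a class (static) attribute prefixed with an underscore."""
    attrs = {"__doc__": f"Generated class {class_name}.",
             "__module__": __name__}
    attrs.update({f"_{key}": value for key, value in static_attrs.items()})
    if not keep_models:
        # Model objects are only kept when keep_models is True.
        attrs.pop("_trained_onnx", None)
    return type(class_name, (FunctionBase,), attrs)


Cls = create_class_sketch("MyOrtFunction", input_names=["X"],
                          output_names=["Y"], trained_onnx="<graph>")
print(Cls.__name__, Cls._input_names, hasattr(Cls, "_trained_onnx"))
# MyOrtFunction ['X'] False
```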

The new class has the following attributes:

  • __doc__: doc string

  • __module__: module name (this file)

  • _run_options: see RunOptions

  • _sess: InferenceSession with the original graph

  • _sess_eval: InferenceSession on the graph with weights as inputs

  • _training_agent: TrainingAgent

  • _cache: OrtValueCache

  • _update_cache: update the cache or not

  • _states: a list

  • _logger: logger

  • _input_names: input names

  • _debug: use debug mode

  • _grad_input_names: gradient input names

  • _output_names: output names

  • _weights_to_train: names of the weights to train

Torch API:

  • forward: forward static method

  • backward: backward static method

Training attributes:

  • _bw_fetches_names: bw_fetches_names

  • _fw_outputs_device_info: fw_outputs_device_info

  • _bw_outputs_device_info: bw_outputs_device_info

  • _fw_no_grad_output_device_info: fw_no_grad_output_device_info

  • _graph_info: graph_info

Additional attributes added if keep_models is True:

  • _trained_onnx: ONNX graph for the gradient

  • _optimized_pre_grad_model: evaluation ONNX graph taking weights as inputs

  • _graph_builder: OrtModuleGraphBuilder

class deeponnxcustom.onnxtorch.torchort.TorchOrtFunction(*args, **kwargs)

Bases: torch.autograd.function.Function

Ancestor to all classes created by TorchOrtFactory. It implements simple functions to move the ownership of tensors from onnxruntime to pytorch (or the other way around) through DLPack structures. This class requires onnxruntime-training.

Differences between onnxruntime and onnxruntime-training

onnxruntime-training is an extension of onnxruntime that supports training. Version 1.10 was obtained by compiling onnxruntime from source with training-related flags, for example:

python ./tools/ci_build/build.py --build_dir ./build/debian \
       --config Release --build_wheel --numpy_version= \
       --skip_tests --build_shared_lib --enable_training \
       --enable_training_ops --enable_training_torch_interop \
       --parallel
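To check which flavor is installed, one option is to look up the training subpackage with importlib.util.find_spec, which imports the parent onnxruntime package but not the subpackage itself. This sketch assumes that training builds ship an onnxruntime.training subpackage while inference-only builds do not; has_onnxruntime_training is a hypothetical helper:

```python
import importlib.util


def has_onnxruntime_training() -> bool:
    """Return True if an onnxruntime.training subpackage can be found.

    Assumption: training builds install onnxruntime.training, while
    inference-only builds do not.
    """
    try:
        return importlib.util.find_spec("onnxruntime.training") is not None
    except ImportError:
        # onnxruntime itself is missing or failed to import.
        return False


print(has_onnxruntime_training())
```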

_backward_cls

alias of torch.autograd.function.TorchOrtFunctionBackward

static from_ort_to_torch(ort_values)

Converts an OrtValueVector into a tuple of pytorch tensors.

static from_torch_to_ort(tensors)

Converts a list of pytorch tensors into an OrtValueVector.
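TorchOrtFunction hands tensors between onnxruntime and pytorch through DLPack. The same zero-copy hand-off can be shown with NumPy alone (np.from_dlpack, NumPy 1.22 or later) acting as both producer and consumer. from_list_to_vector and from_vector_to_tuple are hypothetical analogues of from_torch_to_ort and from_ort_to_torch, and a plain Python list stands in for OrtValueVector:

```python
import numpy as np  # np.from_dlpack needs NumPy >= 1.22


def from_list_to_vector(arrays):
    """Hypothetical analogue of from_torch_to_ort: list in, list-like
    'vector' of buffers re-wrapped through DLPack out."""
    return [np.from_dlpack(a) for a in arrays]


def from_vector_to_tuple(vector):
    """Hypothetical analogue of from_ort_to_torch: returns a tuple."""
    return tuple(np.from_dlpack(v) for v in vector)


inputs = [np.arange(3, dtype=np.float32), np.ones((2, 2))]
vector = from_list_to_vector(inputs)
outputs = from_vector_to_tuple(inputs)  # a plain list plays the vector here
# DLPack passes a pointer to the same buffer instead of copying the data.
print(all(np.shares_memory(a, b) for a, b in zip(inputs, vector)))  # True
print(type(outputs).__name__)                                      # tuple
```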

deeponnxcustom.onnxtorch.torchort.ort_backward(ctx, *grad_outputs)

Implements backward function. See autograd functions.

deeponnxcustom.onnxtorch.torchort.ort_forward(ctx, *inputs)

Implements forward function. See autograd functions.
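ort_forward(ctx, *inputs) and ort_backward(ctx, *grad_outputs) follow the autograd contract: forward stores what backward needs on ctx, and backward returns one gradient per input. A pure-Python sketch of that contract with a ReLU on lists; Ctx, relu_forward and relu_backward are hypothetical stand-ins, whereas the real functions run ONNX graphs with onnxruntime-training:

```python
class Ctx:
    """Hypothetical stand-in for the autograd context object."""

    def save_for_backward(self, *values):
        self.saved_tensors = values


def relu_forward(ctx, *inputs):
    """Same signature as ort_forward(ctx, *inputs); ReLU on lists."""
    ctx.save_for_backward(*inputs)
    return tuple([max(v, 0.0) for v in x] for x in inputs)


def relu_backward(ctx, *grad_outputs):
    """Same signature as ort_backward(ctx, *grad_outputs): one gradient
    per input, zeroed where the saved input was negative."""
    return tuple(
        [0.0 if v < 0 else g for g, v in zip(grad, x)]
        for grad, x in zip(grad_outputs, ctx.saved_tensors))


ctx = Ctx()
print(relu_forward(ctx, [-1.0, 2.0]))  # ([0.0, 2.0],)
print(relu_backward(ctx, [1.0, 1.0]))  # ([0.0, 1.0],)
```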
