module onnxtorch.torchort#
Short summary#
module deeponnxcustom.onnxtorch.torchort
Experimental.
Classes#
class | truncated documentation
---|---
TorchOrtFactory | A class which dynamically creates another class which implements a custom function (see autograd functions). Use …
TorchOrtFunction | Ancestor to all classes created by TorchOrtFactory.
Functions#
function | truncated documentation
---|---
ort_backward | Implements backward function. See autograd functions.
ort_forward | Implements forward function. See autograd functions.
Static Methods#
staticmethod | truncated documentation
---|---
_repr_helper_ | used to improve logging messages
from_ort_to_torch | Converts an OrtValueVector into a tuple of pytorch tensors.
from_torch_to_ort | Converts a list of pytorch tensors into an OrtValueVector.
Methods#
method | truncated documentation
---|---
__repr__ | usual
create_class | Creates a class which inherits from torch.autograd.Function. …
Documentation#
Experimental.
- class deeponnxcustom.onnxtorch.torchort.TorchOrtFactory(onnx_model, weights_to_train, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0)#
Bases: object
A class which dynamically creates another class which implements a custom function (see autograd functions). It uses ONNX inside a torch function. Only initializers can be trained, not parameters.
- Parameters
onnx_model – onnx model
weights_to_train – names of the weights to train
input_names – input names or None for all
output_names – output names or None for all
class_name – class name
sess_options – see SessionOptions
providers – see InferenceSession
provider_options – see InferenceSession
run_options – see RunOptions
graph_builder_config – see OrtModuleGraphBuilderConfiguration
device_index – used for CUDA (0 for cuda:0, 1 for cuda:1, …), 0 by default
Note
The current implementation of onnxruntime forces the weights to train to appear in alphabetical order. The constructor checks that this condition is met.
Warning
This class does not consider subgraphs.
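The snippet below sketches a typical construction, assuming a model stored in a hypothetical file model.onnx and training every initializer; onnxruntime-training must be installed.

    import onnx
    from deeponnxcustom.onnxtorch.torchort import TorchOrtFactory

    onnx_model = onnx.load("model.onnx")  # hypothetical model file
    # Per the note above, the weights to train must appear in alphabetical
    # order, hence the sorted() call. Here every initializer is trained.
    weights_to_train = sorted(init.name for init in onnx_model.graph.initializer)
    factory = TorchOrtFactory(onnx_model, weights_to_train)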
- __init__(onnx_model, weights_to_train, input_names=None, output_names=None, class_name=None, sess_options=None, providers=None, provider_options=None, run_options=None, graph_builder_config=None, device_index=0)#
- __repr__()#
usual
- static _provider_name_to_device_type(provider_name)#
- static _repr_helper_(obj, indent=0)#
used to improve logging messages
- create_class(enable_logging=False, keep_models=False, debug=False)#
Creates a class which inherits from torch.autograd.Function and implements the forward and backward methods using ONNX. The function dynamically creates a new class and attaches every needed object as a static attribute of the new class.
- Parameters
enable_logging – used to debug; logs every building step at info level and logs information while processing forward and backward at debug level
keep_models – stores additional information as static attributes
debug – display information
- Returns
a new class
The pattern follows the documentation described in autograd functions. Methods forward and backward are replaced by ONNX implementations; the runtime is onnxruntime-training.
    class CustomClass(torch.autograd.Function):

        @staticmethod
        def forward(ctx, *input):
            ctx.save_for_backward(*input)
            return ...

        @staticmethod
        def backward(ctx, grad_output):
            input, = ctx.saved_tensors
            grad_input = grad_output.clone()
            grad_input[input < 0] = 0
            return grad_input
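The dynamic creation itself can be pictured with Python's built-in type. This is only a sketch of the mechanism, with a hypothetical helper name; the real create_class attaches many more static attributes (listed below).

    import torch

    def make_custom_function(class_name, forward_fn, backward_fn, static_attrs):
        # Build a new torch.autograd.Function subclass at runtime and attach
        # every needed object as a static attribute of the new class.
        attrs = {"forward": staticmethod(forward_fn),
                 "backward": staticmethod(backward_fn)}
        attrs.update(static_attrs)  # e.g. _sess, _run_options, _input_names, ...
        return type(class_name, (torch.autograd.Function,), attrs)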
The new class has the following attributes:
__doc__: doc string
__module__: module name (this file)
_run_options: see RunOptions
_sess: InferenceSession with the original graph
_sess_eval: InferenceSession on the graph with weights as inputs
_training_agent: TrainingAgent
_cache: OrtValueCache
_update_cache: update the cache or not
_states: a list
_logger: logger
_input_names: input names
_debug: use debug mode
_grad_input_names: gradient input names
_output_names: output names
_weights_to_train: names of the weights to train
Torch API:
forward: forward static method
backward: backward static method
Training attributes:
_bw_fetches_names: bw_fetches_names
_fw_outputs_device_info: fw_outputs_device_info
_bw_outputs_device_info: bw_outputs_device_info
_fw_no_grad_output_device_info: fw_no_grad_output_device_info
_graph_info: graph_info
Additional attributes added if keep_models is True:
_trained_onnx: ONNX graph for the gradient
_optimized_pre_grad_model: evaluation ONNX graph taking weights as inputs
_graph_builder: OrtModuleGraphBuilder
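Putting it together, a hedged usage sketch: the tensor shapes and the calling convention are assumptions (the evaluation graph takes weights as inputs, so the sketch passes the model input followed by the trained weights).

    import torch

    TrainFunction = factory.create_class()  # 'factory' built as above

    # Hypothetical tensors: one model input plus the trained weights; the
    # weights carry requires_grad=True so backward fills their .grad fields.
    x = torch.randn(2, 10)
    weight_tensors = [torch.randn(10, 1, requires_grad=True)]

    output = TrainFunction.apply(x, *weight_tensors)  # forward through ONNX
    output.sum().backward()                           # backward through ONNX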
- class deeponnxcustom.onnxtorch.torchort.TorchOrtFunction(*args, **kwargs)#
Bases: torch.autograd.function.Function
Ancestor to all classes created by TorchOrtFactory. It implements simple functions to move the ownership of tensors from onnxruntime to pytorch (or the other way around) through DLPack structures. This class requires onnxruntime-training.
Differences between onnxruntime and onnxruntime-training
onnxruntime-training is an extension of onnxruntime that supports training. Version 1.10 is obtained by compiling onnxruntime from source with different flags. One example:
    python ./tools/ci_build/build.py --build_dir ./build/debian \
        --config Release --build_wheel --numpy_version= \
        --skip_tests --build_shared_lib --enable_training \
        --enable_training_ops --enable_training_torch_interop \
        --parallel
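Once installed, one way to check that the wheel is a training build is to import a training-only symbol. This is a heuristic, not an official API; TrainingAgent is the class referenced by the _training_agent attribute above.

    try:
        # Shipped with onnxruntime-training wheels only.
        from onnxruntime.training import TrainingAgent  # noqa: F401
        HAS_ORT_TRAINING = True
    except ImportError:
        HAS_ORT_TRAINING = False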
- _backward_cls#
alias of torch.autograd.function.TorchOrtFunctionBackward
- static from_ort_to_torch(ort_values)#
Converts an OrtValueVector into a tuple of pytorch tensors.
- static from_torch_to_ort(tensors)#
Converts a list of pytorch tensors into an OrtValueVector.
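Both conversions go through DLPack so that no copy is made. The sketch below shows the ort-to-torch direction only; it assumes a training build where each OrtValue exposes to_dlpack(), an API whose exact shape varies across onnxruntime versions.

    from torch.utils.dlpack import from_dlpack

    def ort_values_to_torch(ort_values):
        # from_dlpack wraps the memory without copying it, moving the
        # ownership of each tensor from onnxruntime to pytorch.
        return tuple(from_dlpack(v.to_dlpack()) for v in ort_values)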
- deeponnxcustom.onnxtorch.torchort.ort_backward(ctx, *grad_outputs)#
Implements backward function. See autograd functions.
- deeponnxcustom.onnxtorch.torchort.ort_forward(ctx, *inputs)#
Implements forward function. See autograd functions.