Gradient

C++ API

class onnxruntime.capi._pybind_state.GradientGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientGraphBuilder, arg0: bytes, arg1: Set[str], arg2: Set[str], arg3: str)

A utility for making a gradient graph that can be used to help train a model.

class onnxruntime.capi._pybind_state.GradientNodeAttributeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeAttributeDefinition)

Attribute definition for gradient graph nodes.

class onnxruntime.capi._pybind_state.GradientNodeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition)

Definition for gradient graph nodes.

onnxruntime.capi._pybind_state.register_gradient_definition(arg0: str, arg1: List[onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition]) -> None
onnxruntime.capi._pybind_state.register_aten_op_executor(arg0: str, arg1: str) -> None
onnxruntime.capi._pybind_state.register_backward_runner(arg0: object) -> None
onnxruntime.capi._pybind_state.register_forward_runner(arg0: object) -> None

Python API

onnxruntime.training.experimental.gradient_graph._gradient_graph_tools.export_gradient_graph(model: Module, loss_fn: Callable[[Any, Any], Any], example_input: Tensor, example_labels: Tensor, gradient_graph_path: Union[Path, str], opset_version=12) -> None

Build a gradient graph for model so that an inference session can output gradients when given a specific input and its corresponding labels.

Parameters:
  • model (torch.nn.Module) – A gradient graph will be built for this model.

  • loss_fn (Callable[[Any, Any], Any]) – A function to compute the loss given the model’s output and the example_labels. Predefined loss functions such as torch.nn.CrossEntropyLoss() will work, but you might not be able to load the graph in other environments such as an InferenceSession in ONNX Runtime Web; instead, use a custom Python method.

  • example_input (torch.Tensor) – Example input that you would give your model for inference/prediction.

  • example_labels (torch.Tensor) – The expected labels for example_input. This could be the output of your model when given example_input, but it may differ if your loss function expects labels in a different form than the model’s output.

  • gradient_graph_path (Union[Path, str]) – The path at which to save the gradient graph.

  • opset_version (int) – The ONNX opset version used for the export (default 12). See torch.onnx.export.