Gradient
C++ API
- class onnxruntime.capi._pybind_state.GradientGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientGraphBuilder, arg0: bytes, arg1: Set[str], arg2: Set[str], arg3: str)
A utility for making a gradient graph that can be used to help train a model.
- class onnxruntime.capi._pybind_state.GradientNodeAttributeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeAttributeDefinition)
Attribute definition for gradient graph nodes.
- class onnxruntime.capi._pybind_state.GradientNodeDefinition(self: onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition)
Definition for gradient graph nodes.
- onnxruntime.capi._pybind_state.register_gradient_definition(arg0: str, arg1: List[onnxruntime.capi.onnxruntime_pybind11_state.GradientNodeDefinition]) → None
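GradientGraphBuilder exists to produce gradients that help train a model. As a framework-free illustration of how such gradients are typically used, the following plain-Python sketch computes a loss and its gradients by hand and applies gradient-descent steps. Everything here is an assumption chosen for illustration: the toy model (prediction = w * x + b), the squared-error loss, and the helper names loss_and_gradients and sgd_step, which are hypothetical and not part of ONNX Runtime. The idea is that, with a gradient graph, these gradient values would come from running the graph instead of from hand-derived formulas.

```python
from typing import Tuple

# Framework-free sketch (no ONNX Runtime or PyTorch).
# Assumed toy model: prediction = w * x + b, with loss = (prediction - label) ** 2.


def loss_and_gradients(w: float, b: float, x: float, label: float) -> Tuple[float, float, float]:
    """Return the loss and its gradients with respect to the parameters w and b."""
    prediction = w * x + b
    error = prediction - label
    loss = error ** 2
    # Chain rule: dloss/dprediction = 2 * error, dprediction/dw = x, dprediction/db = 1.
    grad_w = 2.0 * error * x
    grad_b = 2.0 * error
    return loss, grad_w, grad_b


def sgd_step(w: float, b: float, x: float, label: float,
             learning_rate: float = 0.05) -> Tuple[float, float]:
    """Move each parameter a small step against its gradient (plain gradient descent)."""
    _, grad_w, grad_b = loss_and_gradients(w, b, x, label)
    return w - learning_rate * grad_w, b - learning_rate * grad_b


if __name__ == "__main__":
    w, b = 0.0, 0.0
    x, label = 2.0, 3.0  # one example input and its label
    for step in range(3):
        loss, _, _ = loss_and_gradients(w, b, x, label)
        print(f"step {step}: loss = {loss:.4f}")
        w, b = sgd_step(w, b, x, label)
```

With this input and learning rate, the printed loss shrinks on every step. Note that the gradients depend on both the input and the label, not only on the parameters.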
Python API
- onnxruntime.training.experimental.gradient_graph._gradient_graph_tools.export_gradient_graph(model: Module, loss_fn: Callable[[Any, Any], Any], example_input: Tensor, example_labels: Tensor, gradient_graph_path: Union[Path, str], opset_version=12) → None
Build a gradient graph for model so that you can output gradients in an inference session when given a specific input and its corresponding labels.
- Parameters:
model (torch.nn.Module) – A gradient graph will be built for this model.
loss_fn (Callable[[Any, Any], Any]) – A function to compute the loss given the model’s output and the example_labels. Predefined loss functions such as torch.nn.CrossEntropyLoss() will work, but you might not be able to load the graph in other environments such as an InferenceSession in ONNX Runtime Web; instead, use a custom Python method.
example_input (torch.Tensor) – Example input that you would give your model for inference/prediction.
example_labels (torch.Tensor) – The expected labels for example_input. This could be the output of your model when given example_input, but it might be different if your loss function expects labels in a different form.
gradient_graph_path (Union[Path, str]) – The path to where you would like to save the gradient graph.
opset_version (int) – See torch.onnx.export.
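The loss_fn and example_labels descriptions can be illustrated in plain Python. The sketch below is a hypothetical custom loss function (cross_entropy_from_scores is an illustrative name, not part of ONNX Runtime or PyTorch) that takes the model's output and the labels and returns a scalar loss. Here the model output is a batch of per-class scores while the labels are class indices, which is one case where example_labels would not simply be the model's output. A function actually passed to export_gradient_graph would receive torch.Tensor arguments and would presumably be written with PyTorch tensor operations; this version only shows the (output, labels) → loss contract and the arithmetic.

```python
import math
from typing import List


def cross_entropy_from_scores(scores: List[List[float]], labels: List[int]) -> float:
    """Hypothetical custom loss: mean cross-entropy of per-class scores vs. class indices.

    scores: one list of raw per-class scores per example (the model's output).
    labels: one integer class index per example, so labels differ in shape from scores.
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same batch size")
    total = 0.0
    for row, label in zip(scores, labels):
        # Stable log-sum-exp: subtract the row maximum before exponentiating.
        row_max = max(row)
        log_sum_exp = row_max + math.log(sum(math.exp(s - row_max) for s in row))
        # Negative log-softmax probability of the correct class.
        total += log_sum_exp - row[label]
    return total / len(labels)


if __name__ == "__main__":
    model_output = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]  # batch of 2 examples, 3 classes
    example_labels = [0, 2]  # class indices, not scores
    print(cross_entropy_from_scores(model_output, example_labels))
```

Writing the loss by hand from elementary operations, rather than calling a predefined loss module, mirrors the advice to use a custom Python method.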