Partial Training

OrtValueCache

class onnxruntime.capi._pybind_state.OrtValueCache(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache)
__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
clear(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None
count(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> int
insert(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValue) -> None
keys(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> list
remove(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> None

TrainingAgent

class onnxruntime.capi._pybind_state.TrainingAgent(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg5: int)

This is the main class used to run an ORTModule model.

__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg5: int) -> None
run_backward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None
run_forward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState, arg3: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None

PartialGraphExecutionState

class onnxruntime.capi._pybind_state.PartialGraphExecutionState(self: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState)
__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None

OrtModuleGraphBuilder

class onnxruntime.capi._pybind_state.OrtModuleGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder)
__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None
build(*args, **kwargs)

Overloaded function.

  1. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None

  2. build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: List[List[int]]) -> None

get_forward_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
get_gradient_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes
get_graph_info(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> onnxruntime.capi.onnxruntime_pybind11_state.GraphInfo
initialize(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: bytes, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None

OrtModuleGraphBuilderConfiguration

class onnxruntime.capi._pybind_state.OrtModuleGraphBuilderConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration)

Configuration information for the module graph builder.

__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None
property build_gradient_graph
property enable_caching
property graph_transformer_config
property initializer_names
property initializer_names_to_train
property input_names_require_grad
property loglevel
property use_memory_efficient_gradient

TrainingGraphTransformerConfiguration

class onnxruntime.capi._pybind_state.TrainingGraphTransformerConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration)

Training graph transformer configuration.

__init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration) -> None
property attn_dropout_recompute
property enable_gelu_approximation
property gelu_recompute
property number_recompute_layers
property propagate_cast_ops_config
property transformer_layer_recompute