Partial Training#
OrtValueCache#
- class onnxruntime.capi._pybind_state.OrtValueCache(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache)#
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None #
- count(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> int #
- insert(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValue) -> None #
- remove(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache, arg0: str) -> None #
TrainingAgent#
- class onnxruntime.capi._pybind_state.TrainingAgent(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg5: int)#
This is the main class used to run an ORTModule model.
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.InferenceSession, arg1: List[str], arg2: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg3: List[str], arg4: List[onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice], arg5: int) -> None #
- run_backward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState) -> None #
- run_forward(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingAgent, arg0: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg1: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueVector, arg2: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState, arg3: onnxruntime.capi.onnxruntime_pybind11_state.OrtValueCache) -> None #
PartialGraphExecutionState#
- class onnxruntime.capi._pybind_state.PartialGraphExecutionState(self: onnxruntime.capi.onnxruntime_pybind11_state.PartialGraphExecutionState)#
OrtModuleGraphBuilder#
- class onnxruntime.capi._pybind_state.OrtModuleGraphBuilder(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder)#
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None #
- build(*args, **kwargs)#
Overloaded function.
build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> None
build(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder, arg0: List[List[int]]) -> None
- get_forward_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes #
- get_gradient_model(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilder) -> bytes #
OrtModuleGraphBuilderConfiguration#
- class onnxruntime.capi._pybind_state.OrtModuleGraphBuilderConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration)#
Configuration information for the module graph builder.
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.OrtModuleGraphBuilderConfiguration) -> None #
- property build_gradient_graph#
- property enable_caching#
- property graph_transformer_config#
- property initializer_names#
- property initializer_names_to_train#
- property input_names_require_grad#
- property loglevel#
- property use_memory_efficient_gradient#
TrainingGraphTransformerConfiguration#
- class onnxruntime.capi._pybind_state.TrainingGraphTransformerConfiguration(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration)#
Training graph transformer configuration.
- __init__(self: onnxruntime.capi.onnxruntime_pybind11_state.TrainingGraphTransformerConfiguration) -> None #
- property attn_dropout_recompute#
- property enable_gelu_approximation#
- property gelu_recompute#
- property number_recompute_layers#
- property propagate_cast_ops_config#
- property transformer_layer_recompute#