Training utilities


onnxcustom.utils.onnx_helper.add_initializer (model, name, value)

Adds an initializer to the graph.

onnxcustom.utils.onnx_helper.dtype_to_var_type (dtype)

Converts a numpy dtype into a var type.

onnxcustom.utils.onnx_helper.get_onnx_opset (onx, domain = '')

Returns the opset version associated with the given domain.
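
The lookup can be pictured with a minimal sketch. It assumes the model exposes an opset_import list whose entries carry domain and version fields, as an ONNX ModelProto does. The stand-in classes and the name get_opset_sketch are hypothetical, and returning None for an undeclared domain is an assumption, not documented behaviour.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class OpsetId:
    # Hypothetical stand-in for onnx.OperatorSetIdProto.
    domain: str
    version: int


@dataclass
class Model:
    # Hypothetical stand-in for onnx.ModelProto; only opset_import is modelled.
    opset_import: List[OpsetId] = field(default_factory=list)


def get_opset_sketch(model: Model, domain: str = "") -> Optional[int]:
    """Returns the opset version declared for `domain` ('' is the main ONNX domain)."""
    for opset in model.opset_import:
        if opset.domain == domain:
            return opset.version
    return None  # assumption: None when the domain is not declared


model = Model([OpsetId("", 15), OpsetId("ai.onnx.ml", 2)])
print(get_opset_sketch(model))                # 15
print(get_opset_sketch(model, "ai.onnx.ml"))  # 2
```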

onnxcustom.utils.orttraining_helper.get_train_initializer (onx)

Returns the list of initializers to train.
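
A minimal sketch of the selection, under the assumption that the initializers to train are the floating-point ones, while integer initializers such as shapes or indices stay constant. Initializers are modelled here as a plain dict of numpy arrays and train_initializers_sketch is a hypothetical name; the real function works on an ONNX graph and its return type may differ.

```python
import numpy as np


def train_initializers_sketch(initializers):
    """Returns the names of floating-point initializers, in their original order.

    Assumption: only float tensors are trainable; integer tensors
    (shapes, axes, indices) are treated as constants.
    """
    return [name for name, value in initializers.items()
            if np.issubdtype(value.dtype, np.floating)]


inits = {
    "coef": np.array([[0.5], [1.5]], dtype=np.float32),
    "intercept": np.array([0.1], dtype=np.float32),
    "shape": np.array([-1, 1], dtype=np.int64),
}
print(train_initializers_sketch(inits))  # ['coef', 'intercept']
```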

onnxcustom.utils.onnx_helper.proto_type_to_dtype (proto_type)

Converts an ONNX TensorProto type into a numpy type.
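
The integer codes in the table below are the element types defined by onnx.TensorProto (FLOAT is 1, DOUBLE is 11, and so on). The subset of codes covered, the ValueError raised for unknown codes and the name proto_type_to_dtype_sketch are assumptions of this sketch.

```python
import numpy as np

# Element-type codes of onnx.TensorProto; only a subset is covered here.
_PROTO_TO_NUMPY = {
    1: np.float32,   # TensorProto.FLOAT
    2: np.uint8,     # TensorProto.UINT8
    3: np.int8,      # TensorProto.INT8
    4: np.uint16,    # TensorProto.UINT16
    5: np.int16,     # TensorProto.INT16
    6: np.int32,     # TensorProto.INT32
    7: np.int64,     # TensorProto.INT64
    9: np.bool_,     # TensorProto.BOOL
    10: np.float16,  # TensorProto.FLOAT16
    11: np.float64,  # TensorProto.DOUBLE
}


def proto_type_to_dtype_sketch(proto_type: int) -> np.dtype:
    """Maps a TensorProto element type code to a numpy dtype."""
    try:
        return np.dtype(_PROTO_TO_NUMPY[proto_type])
    except KeyError:
        raise ValueError(f"Unsupported TensorProto type {proto_type}.") from None


print(proto_type_to_dtype_sketch(1))   # float32
print(proto_type_to_dtype_sketch(11))  # float64
```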

onnxcustom.utils.onnx_helper.onnx_rename_weights (onx)

Renames ONNX initializers so that their names follow alphabetical order. The model is modified in place. This function calls onnx_rename_names.
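
The idea behind the renaming can be sketched on a plain list of names: prefixing each name with its zero-padded position makes alphabetical order coincide with the original order. The prefix format I<index>_ and the name rename_in_order_sketch are hypothetical and not necessarily what onnxcustom produces.

```python
def rename_in_order_sketch(names):
    """Returns a mapping old name -> new name whose sorted order matches `names`.

    Hypothetical scheme: 'I' + zero-padded index + '_' + old name.
    Zero padding prevents 'I10_...' from sorting before 'I2_...'.
    """
    width = len(str(max(len(names) - 1, 0)))
    return {name: f"I{i:0{width}d}_{name}" for i, name in enumerate(names)}


print(rename_in_order_sketch(["w_out", "bias", "w_in"]))
# {'w_out': 'I0_w_out', 'bias': 'I1_bias', 'w_in': 'I2_w_in'}
```

With twelve names the index is padded to two digits, so the new names still sort in their original order.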

onnxcustom.utils.onnx_rewriter.onnx_rewrite_operator (onx, op_type, sub_onx, recursive = True, debug_info = None)

Replaces one operator with an ONNX graph.

onnxcustom.utils.onnx_helper.replace_initializers_into_onnx (model, results)

Replaces initializers with other initializers, usually trained ones.


onnxcustom.utils.onnxruntime_helper.device_to_providers (device)

Returns the corresponding providers for a specific device.
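
A minimal sketch of such a mapping, assuming devices are given as strings like 'cpu' or 'cuda:0'. CPUExecutionProvider and CUDAExecutionProvider are the provider names onnxruntime uses; keeping CPU as a fallback after CUDA is an assumption of this sketch, and device_to_providers_sketch is hypothetical.

```python
def device_to_providers_sketch(device: str):
    """Returns the onnxruntime execution providers for a device string."""
    kind = device.split(":", 1)[0].strip().lower()  # 'cuda:1' -> 'cuda'
    if kind == "cpu":
        return ["CPUExecutionProvider"]
    if kind == "cuda":
        # CPU stays in the list so nodes without a CUDA kernel can still run.
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    raise ValueError(f"Unexpected device {device!r}.")


print(device_to_providers_sketch("cuda:0"))
# ['CUDAExecutionProvider', 'CPUExecutionProvider']
```

A list of this kind is what onnxruntime.InferenceSession accepts through its providers argument.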

onnxcustom.utils.onnxruntime_helper.numpy_to_ort_value (arr, device = None)

Converts a numpy array into a C_OrtValue.

onnxcustom.utils.onnxruntime_helper.get_ort_device (device)

Converts a device into a C_OrtDevice.

onnxcustom.utils.onnxruntime_helper.get_ort_device_type (device)

Converts a device into a device type.

onnxcustom.utils.onnxruntime_helper.ort_device_to_string (device)

Returns a string representing the device. This is the inverse of get_ort_device.
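
The device conversions above, and their inverse, revolve around splitting a device string into a kind and an index. A sketch of that parsing step and of the reverse formatting, assuming strings of the form 'cpu', 'cuda' or 'cuda:<id>'. Both function names are hypothetical, and the sketch works on plain tuples rather than on onnxruntime's C_OrtDevice objects.

```python
from typing import Tuple


def parse_device_sketch(device: str) -> Tuple[str, int]:
    """Splits 'cuda:1' into ('cuda', 1); the index defaults to 0."""
    kind, _, index = device.partition(":")
    kind = kind.strip().lower()
    if kind not in ("cpu", "cuda"):
        raise ValueError(f"Unexpected device {device!r}.")
    return kind, (int(index) if index else 0)


def device_to_string_sketch(kind: str, index: int) -> str:
    """Reverse direction: ('cuda', 1) -> 'cuda:1', ('cpu', 0) -> 'cpu'."""
    return kind if kind == "cpu" else f"{kind}:{index}"


print(parse_device_sketch("cuda:1"))                            # ('cuda', 1)
print(device_to_string_sketch(*parse_device_sketch("cuda:1")))  # cuda:1
```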

onnxcustom.utils.onnxruntime_helper.provider_to_device (provider_name)

Converts a provider name into a device.
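
Going the other way, a provider name maps back to a device string. The sketch below covers only the CPU and CUDA providers of onnxruntime; the table, the error for other providers and the name provider_to_device_sketch are assumptions.

```python
def provider_to_device_sketch(provider_name: str) -> str:
    """Maps an onnxruntime provider name back to a device string."""
    table = {
        "CPUExecutionProvider": "cpu",
        "CUDAExecutionProvider": "cuda",
    }
    if provider_name not in table:
        raise ValueError(f"Unexpected provider {provider_name!r}.")
    return table[provider_name]


print(provider_to_device_sketch("CUDAExecutionProvider"))  # cuda
```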


onnxcustom.utils.orttraining_helper.add_loss_output (onx, score_name = 'squared_error', loss_name = 'loss', label_name = 'label', weight_name = None, penalty = None, output_index = None, **kwargs)

Modifies an ONNX graph by adding operators that compute a score (the loss) so that the graph can be trained.
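
The signature suggests that, with the default score_name='squared_error', the added nodes compare the model output with the input named by label_name and produce an output named by loss_name. The numpy sketch below shows what such a loss computes, assuming a sum reduction over all elements and, when weight_name is set, per-sample weights multiplying the squared errors. The exact reduction used by onnxcustom is not stated here, so treat both points as assumptions; squared_error_loss_sketch is hypothetical.

```python
import numpy as np


def squared_error_loss_sketch(pred, label, weight=None):
    """Sum of (optionally weighted) squared errors.

    Assumptions: sum reduction; weights apply per sample (first axis).
    """
    err = (np.asarray(pred, dtype=np.float64) - np.asarray(label, dtype=np.float64)) ** 2
    if weight is not None:
        # Broadcast one weight per sample over the remaining axes.
        w = np.asarray(weight, dtype=np.float64).reshape((-1,) + (1,) * (err.ndim - 1))
        err = err * w
    return float(err.sum())


pred = [[1.0], [2.0], [4.0]]
label = [[1.0], [3.0], [2.0]]
print(squared_error_loss_sketch(pred, label))                   # 5.0
print(squared_error_loss_sketch(pred, label, [1.0, 2.0, 0.5]))  # 4.0
```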

onnxcustom.utils.onnx_function.get_supported_functions ()

Returns the list of functions supported by function_onnx_graph.

onnxcustom.utils.onnx_function.function_onnx_graph (name, target_opset = None, dtype = <class 'numpy.float32'>, weight_name = None, **kwargs)

Returns the ONNX graph corresponding to a function.

onnxcustom.utils.orttraining_helper.penalty_loss_onnx (name, dtype, l1 = None, l2 = None, existing_names = None)

Returns ONNX nodes to compute $|w| \alpha + w^2 \beta$, where $\alpha$ is l1 and $\beta$ is l2.
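
In numpy terms, and assuming the penalty is summed over all coefficients of a weight tensor $w$, the expression evaluates as follows. penalty_sketch is hypothetical and stands for the value the returned nodes compute, not for the nodes themselves; treating a missing l1 or l2 as zero is also an assumption.

```python
import numpy as np


def penalty_sketch(w, l1=None, l2=None):
    """alpha * sum(|w|) + beta * sum(w ** 2), with alpha = l1 and beta = l2."""
    w = np.asarray(w, dtype=np.float64)
    alpha = 0.0 if l1 is None else l1
    beta = 0.0 if l2 is None else l2
    return float(alpha * np.abs(w).sum() + beta * (w ** 2).sum())


# 0.5 * (1 + 2 + 3) + 0.1 * (1 + 4 + 9) = 3.0 + 1.4, up to rounding.
print(penalty_sketch([1.0, -2.0, 3.0], l1=0.5, l2=0.1))
```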

gradient (onx, weights = None, inputs = None, options = DerivativeOptions.Zero, loss = None, label = None, path_name = None)

Builds the gradient of an ONNX graph.
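
A gradient graph is commonly checked against central finite differences. The sketch below swaps in that numerical technique, with a plain Python function standing in for the model and its loss; it does not build an ONNX graph and is not how the function above works. The helper names and the quadratic test function are hypothetical.

```python
import numpy as np


def finite_difference_grad(f, x, eps=1e-6):
    """Central differences: df/dx_i ~ (f(x + eps e_i) - f(x - eps e_i)) / (2 eps)."""
    x = np.asarray(x, dtype=np.float64)
    grad = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step.flat[i] = eps
        grad.flat[i] = (f(x + step) - f(x - step)) / (2 * eps)
    return grad


# Stand-in for a model followed by a loss: f(w) = sum((X @ w - y) ** 2).
X = np.array([[1.0, 2.0], [3.0, 4.0]])
y = np.array([1.0, 0.0])
w = np.array([0.5, -0.5])


def f(v):
    return float(((X @ v - y) ** 2).sum())


analytic = 2 * X.T @ (X @ w - y)  # exact gradient of f at w
numeric = finite_difference_grad(f, w)
print(np.allclose(analytic, numeric, atol=1e-5))  # True
```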