Coverage for onnxcustom/utils/onnx_helper.py: 100%
75 statements
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
« prev ^ index » next coverage.py v7.0.5, created at 2023-01-17 01:42 +0100
1# pylint: disable=C0415,E0611,E1101
2"""
3@file
4@brief Onnx implementation of common functions used to train a model.
5"""
6import math
7import numpy
8from onnx import TensorProto, numpy_helper, helper
9from onnxruntime import OrtValue
10from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue
def onnx_rename_weights(onx):
    """
    Renames ONNX initializers to make sure their name
    follows the alphabetical order. The model is
    modified inplace. This function calls
    :func:`onnx_rename_names
    <mlprodict.onnx_tools.onnx_manipulations.onnx_rename_names>`.

    Initializer number *i* is renamed ``I<i>_<old name>`` where
    ``<i>`` is zero-padded so that sorting the new names
    alphabetically gives back the order of the initializers
    in the graph.

    :param onx: ONNX model
    :return: same model

    .. note::
        The function does not go into subgraphs.
    """
    from mlprodict.onnx_tools.onnx_manipulations import (  # pylint: disable=C0415
        onnx_rename_names)

    names = [init.name for init in onx.graph.initializer]
    # Number of digits used to pad the index. The logarithm is undefined
    # for an empty graph, a single digit is enough in that case.
    if names:
        ninit = max(1, int(math.log(len(names)) / math.log(10) + 1))
    else:
        ninit = 1
    fmt = f"I%0{ninit}d_%s"
    new_names = [fmt % (i, name) for i, name in enumerate(names)]
    repl = dict(zip(names, new_names))
    return onnx_rename_names(onx, recursive=False, replace=repl)
def get_onnx_opset(onx, domain=''):
    """
    Returns the opset version a model imports for a given domain.

    :param onx: onnx model (any object exposing ``opset_import``)
    :param domain: domain to look for (default ``''``)
    :return: version of the first opset import whose domain matches
    :raises ValueError: if no opset import matches *domain*
    """
    matching = [imp.version for imp in onx.opset_import
                if imp.domain == domain]
    if not matching:
        raise ValueError(
            f"Unable to find opset for domain={domain!r}.")
    return matching[0]
def proto_type_to_dtype(proto_type):
    """
    Maps an ONNX element type onto the matching numpy type.

    Accepts the integer codes ``TensorProto.FLOAT`` and
    ``TensorProto.DOUBLE``, or the strings ``'tensor(float)'``
    and ``'tensor(double)'``. No other type is supported.

    :param proto_type: integer code or type string
    :return: numpy type, ``numpy.float32`` or ``numpy.float64``
    :raises ValueError: for any other value
    """
    # Integer codes are compared first, the string forms
    # (slower comparisons) come last.
    candidates = (
        (TensorProto.FLOAT, numpy.float32),
        (TensorProto.DOUBLE, numpy.float64),
        ('tensor(float)', numpy.float32),
        ('tensor(double)', numpy.float64),
    )
    for key, np_type in candidates:
        if proto_type == key:
            return np_type
    raise ValueError(
        f"Unexpected value proto_type={proto_type!r} (type={type(proto_type)!r}).")
def dtype_to_var_type(dtype):
    """
    Returns the skl2onnx variable type class matching a numpy dtype.

    :param dtype: numpy dtype (float32, float64, int64 or int32)
    :return: one of ``FloatTensorType``, ``DoubleTensorType``,
        ``Int64TensorType``, ``Int32TensorType``
    :raises ValueError: for any other dtype
    """
    from skl2onnx.common.data_types import (
        FloatTensorType, DoubleTensorType,
        Int32TensorType, Int64TensorType)
    table = (
        (numpy.float32, FloatTensorType),
        (numpy.float64, DoubleTensorType),
        (numpy.int64, Int64TensorType),
        (numpy.int32, Int32TensorType),
    )
    for np_type, var_type in table:
        if dtype == np_type:
            return var_type
    raise ValueError(
        f"Unexpected value dtype={dtype!r}.")
def _finalize_new_onnx(graph, onx):
    """
    Wraps *graph* into a new model which carries over the
    model-level information of *onx*.

    Copied from *onx*:
    - ir_version, producer name and version, domain, model version and
      doc string;
    - metadata properties;
    - opset imports;
    - model local functions, when the installed onnx defines them.

    :param graph: GraphProto of the new model
    :param onx: original ModelProto
    :return: new ModelProto
    """
    onnx_model = helper.make_model(graph)
    onnx_model.ir_version = onx.ir_version
    onnx_model.producer_name = onx.producer_name
    onnx_model.producer_version = onx.producer_version
    onnx_model.domain = onx.domain
    onnx_model.model_version = onx.model_version
    onnx_model.doc_string = onx.doc_string
    if len(onx.metadata_props) > 0:  # pragma: no cover
        values = {p.key: p.value for p in onx.metadata_props}
        helper.set_model_props(onnx_model, values)

    # Discards whatever opset imports make_model set and copies
    # the original ones.
    del onnx_model.opset_import[:]  # pylint: disable=E1101
    for oimp in onx.opset_import:
        op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
        op_set.domain = oimp.domain
        op_set.version = oimp.version

    # Nodes of the graph may call model local functions. Dropping them
    # would leave those nodes without a definition. Older onnx packages
    # do not have this field, hence the getattr.
    functions = getattr(onx, 'functions', None)
    if functions:
        onnx_model.functions.extend(functions)  # pylint: disable=E1101
    return onnx_model
def add_initializer(model, name, value):
    """
    Adds an initializer to graph.

    The model is not modified, a new one is returned. Existing
    initializers, sparse initializers, value_info and the graph
    doc string are kept.

    :param model: onnx model
    :param name: initializer name
    :param value: value, a numpy array or anything
        :func:`numpy.asarray` can convert
    :return: new ONNX model
    :raises ValueError: if *name* is already an initializer name
    """
    inits = set(i.name for i in model.graph.initializer)
    if name in inits:
        raise ValueError(  # pragma: no cover
            f"Name {name!r} is already taken among {inits!r}.")
    list_inits = list(model.graph.initializer)
    # numpy_helper.from_array expects an array: python scalars
    # or lists are converted first (no copy for an ndarray).
    list_inits.append(
        numpy_helper.from_array(numpy.asarray(value), name=name))
    graph_def = helper.make_graph(
        model.graph.node, model.graph.name,
        model.graph.input, model.graph.output,
        list_inits,
        doc_string=model.graph.doc_string,
        value_info=model.graph.value_info,
        # Dropping sparse initializers would break nodes using them.
        sparse_initializer=model.graph.sparse_initializer)
    return _finalize_new_onnx(graph_def, model)
def replace_initializers_into_onnx(model, results):
    """
    Replaces initializers by other initializers,
    usually trained ones.

    The model is not modified, a new one is returned. Sparse
    initializers and the graph doc string are kept. value_info is
    not copied because new values may not have the original shapes.

    :param model: onnx model
    :param results: dictionary ``{initializer name: new value}``.
        A value may be a :class:`numpy.ndarray`, an :class:`OrtValue`,
        a TensorProto (used as is) or anything :func:`numpy.asarray`
        can convert.
    :return: new onnx model
    :raises RuntimeError: if a key of *results* is not an initializer name
    """
    inputs = list(model.graph.input)
    outputs = list(model.graph.output)
    inits = list(model.graph.initializer)

    inits_dict = {init.name: i for i, init in enumerate(inits)}
    for k, v in results.items():
        if k not in inits_dict:
            raise RuntimeError(  # pragma: no cover
                f"Unable to find initializer {k!r} in {inits_dict!r}.")
        if isinstance(v, (C_OrtValue, OrtValue)):
            v = numpy_helper.from_array(v.numpy(), k)
        elif not isinstance(v, TensorProto):
            # numpy arrays but also python lists or scalars, which would
            # otherwise only fail later inside make_graph.
            v = numpy_helper.from_array(numpy.asarray(v), k)
        inits[inits_dict[k]] = v

    graph = helper.make_graph(
        list(model.graph.node), model.graph.name, inputs,
        outputs, inits,
        doc_string=model.graph.doc_string,
        sparse_initializer=model.graph.sparse_initializer)
    return _finalize_new_onnx(graph, model)