Coverage for mlprodict/onnxrt/ops_onnxruntime/_op.py: 96%
142 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_onnxruntime*.
5"""
6import numpy
7import onnx.defs
8from onnx.helper import make_tensor
9from onnx.onnx_cpp2py_export.shape_inference import InferenceError # pylint: disable=E0401,E0611
10from ...tools.ort_wrapper import InferenceSession
11from ...onnx_tools.onnx2py_helper import guess_proto_dtype
12from ...onnx_tools.optim.graph_schema_helper import (
13 get_defined_inputs, get_defined_outputs, proto2vars)
# Operator type -> ONNX schema. ``get_all_schemas_with_history`` returns
# every registered version of every operator; when a name appears several
# times, the schema listed last overwrites the previous ones.
_schemas = dict(
    (op_schema.name, op_schema)
    for op_schema in onnx.defs.get_all_schemas_with_history())
class OpRunOnnxRuntime:
    """
    Unique operator which calls :epkg:`onnxruntime`
    to compute predictions for one operator.

    The node is converted into a single-node ONNX graph
    (attribute ``onnx_``) which is loaded into its own
    :epkg:`onnxruntime` session (attribute ``sess_``).
    """

    def __init__(self, onnx_node, desc=None, variables=None,
                 dtype=None, runtime=None, **options):
        """
        :param onnx_node: :epkg:`onnx` node
        :param desc: internal representation, every entry of
            ``desc['atts']`` must be a dictionary with a key ``'value'``
        :param variables: registered variables created by previous operators
        :param dtype: float computation type
        :param runtime: `onnxruntime1`, `onnxruntime1-cuda`, ...
        :param options: runtime options, the attribute values found
            in *desc* are added to them
        :raises ValueError: if an attribute in *desc* has no ``'value'``
        """
        self._provider = 'onnxruntime'
        self.onnx_node = onnx_node
        self.desc = desc
        self.runtime = runtime
        self._schema = _schemas.get(onnx_node.op_type, None)
        if desc is not None and 'atts' in desc:
            for a, b in desc['atts'].items():
                if not isinstance(b, dict) or 'value' not in b:
                    raise ValueError(  # pragma: no cover
                        f"Unexpected value {b}.")
                options[a] = b['value']

        self.options = options
        self.dtype = dtype
        self._init(variables)

        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            InvalidArgument as OrtInvalidArgument)
        self.OrtInvalidArgument = OrtInvalidArgument

    def _name_mapping(self, inputs):
        """
        Makes every input name unique. A repeated *name* is renamed
        into the first free name among ``name_0``, ``name_1``, ...

        :param inputs: list of input names, possibly with duplicates
        :return: tuple *(mapping, new_inputs)*, *mapping* maps every
            new name to the original one, *new_inputs* holds the unique
            names in the same order as *inputs*
        """
        mapping = {}
        new_inputs = []
        for name in inputs:
            if name in mapping:
                i = 0
                new_name = f"{name}_{i}"
                while new_name in mapping:
                    i += 1  # pragma: no cover
                    new_name = f"{name}_{i}"  # pragma: no cover
                mapping[new_name] = name
                new_inputs.append(new_name)
            else:
                new_inputs.append(name)
                mapping[name] = name
        return mapping, new_inputs

    def _guess_proto_type(self, dtype):
        """
        Returns the ONNX element type for *dtype*
        (delegates to ``guess_proto_dtype``).
        """
        return guess_proto_dtype(dtype)

    def _resolve_alg_class(self):
        """
        Returns the class used to build the ONNX graph for this node.

        A class registered in option *nodes* for the operator type wins,
        otherwise ``'Onnx' + op_type`` is looked up in
        ``mlprodict.onnx_conv.onnx_ops``, then
        ``skl2onnx.algebra.custom_ops``, then ``skl2onnx.algebra.onnx_ops``.

        :raises AttributeError: if no module defines the class
        """
        op_type = self.onnx_node.op_type
        custom_nodes = self.options.get('nodes', None)
        if custom_nodes is not None and op_type in custom_nodes:
            return custom_nodes[op_type]
        class_name = 'Onnx' + op_type
        try:
            import mlprodict.onnx_conv.onnx_ops as alg0
            return getattr(alg0, class_name)
        except AttributeError:
            pass
        import skl2onnx.algebra.custom_ops as alg2  # delayed
        try:
            return getattr(alg2, class_name)
        except AttributeError:
            pass
        import skl2onnx.algebra.onnx_ops as alg  # delayed
        return getattr(alg, class_name)

    @staticmethod
    def _make_session_options(session_options, disable_optimisation):
        """
        Creates the options used to load and run the session.

        :param session_options: user defined ``SessionOptions``,
            None or False to let this method create them
        :param disable_optimisation: disables every graph optimisation,
            only allowed when *session_options* is not specified
        :return: tuple *(session options, run options)*
        :raises RuntimeError: if both parameters are specified
        """
        from onnxruntime import (  # pylint: disable=E0611
            SessionOptions, RunOptions, GraphOptimizationLevel)

        run_options = RunOptions()
        if session_options:
            if disable_optimisation:
                raise RuntimeError(  # pragma: no cover
                    "session_options and disable_optimisation cannot be defined "
                    "at the same time.")
            return session_options, run_options

        sess_options = SessionOptions()
        try:
            sess_options.session_log_severity_level = 3
            # sess_options.sessions_log_verbosity_level = 0
        except AttributeError:  # pragma: no cover
            # onnxruntime not recent enough.
            pass
        try:
            run_options.run_log_severity_level = 3
            # self.run_options.run_log_verbosity_level = 0
        except AttributeError:  # pragma: no cover
            # onnxruntime not recent enough.
            pass
        if disable_optimisation:
            sess_options.graph_optimization_level = (  # pragma: no cover
                GraphOptimizationLevel.ORT_DISABLE_ALL)
        return sess_options, run_options

    def _init(self, variables=None):
        """
        Initializes the node: builds a single-node ONNX graph
        equivalent to ``self.onnx_node`` and loads it into an
        :epkg:`onnxruntime` session.

        :param variables: registered variables created by previous operators

        Entries *nodes*, *target_opset*, *domain*, *disable_optimisation*,
        *session_options* and *ir_version* of ``self.options`` configure
        the conversion, every other option is given to the operator class.

        The current implementation for operator *Scan*
        only works for matrices.
        """
        self.alg_class = self._resolve_alg_class()

        node_inputs = list(self.onnx_node.input)
        self.mapping, self.inputs = self._name_mapping(node_inputs)
        self.outputs = list(self.onnx_node.output)

        options = self.options.copy()
        options.pop('nodes', None)
        target_opset = options.pop('target_opset', None)
        domain = options.pop('domain', None)
        disable_optimisation = options.pop('disable_optimisation', False)
        # None means "not specified": the session options are then created
        # by _make_session_options (a default of False used to make that
        # path unreachable).
        session_options = options.pop('session_options', None)
        ir_version = options.pop('ir_version', None)

        # For the main domain (''), target_opset should be >= 9. A lower
        # value is assumed to be the one the graph was created with and is
        # used as is (target_opset may also be None).

        if self.onnx_node.op_type == 'ZipMap':
            from skl2onnx.common.data_types import (  # delayed
                DictionaryType, FloatTensorType, Int64TensorType, StringTensorType)
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            otype = (Int64TensorType if 'classlabels_int64s' in options
                     else StringTensorType)
            outvar = [(name, DictionaryType(otype([1]), FloatTensorType([1])))]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ArrayFeatureExtractor':
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            otype = inputs[0][1].__class__
            outvar = [(name, otype())]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ConstantOfShape':
            # Array attributes must be given as TensorProto. make_tensor
            # expects a flat list of values, hence ravel (tolist alone
            # returns nested lists, or a scalar for a 0-d array).
            for k, v in list(options.items()):
                if isinstance(v, numpy.ndarray):
                    options[k] = make_tensor(
                        k, self._guess_proto_type(v.dtype),
                        v.shape, v.ravel().tolist())

            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            try:
                self.onnx_ = self.inst_.to_onnx(inputs, target_opset=target_opset,
                                                domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        f"Probable issue as one dimension is null.\n--\n{self.onnx_}")
            except AttributeError as e:  # pragma: no cover
                # older version of skl2onnx
                self.onnx_ = self.inst_.to_onnx(inputs)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from e
            forced = False
        elif self.onnx_node.op_type == 'Scan':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype)
            # Only matrices are supported: every input/output becomes 2D
            # with unknown dimensions.
            inputs = [(name, cl.__class__([None, None]))
                      for (name, cl) in inputs]
            outputs = [(name, cl.__class__([None, None]))
                       for (name, cl) in outputs]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    f"Probable issue as one dimension is null.\n--\n{self.onnx_}")
            forced = True
        else:
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, domain=domain,
                                        **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype,
                schema=self.alg_class.expected_inputs)

            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}\n---\n{}".format(
                            self.onnx_, inputs))
                forced = False
            except (RuntimeError, ValueError, InferenceError) as eo:
                # Let's try again by forcing output types.
                forced = True
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype, schema=self.alg_class.expected_outputs,
                    schema_inputs=self.alg_class.expected_inputs)
                try:
                    self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                                    target_opset=target_opset,
                                                    domain=domain)
                except NotImplementedError as e:  # pragma: no cover
                    raise NotImplementedError(
                        "Unable to instantiate node {} inputs={} "
                        "self.inputs={} outputs={} variables={} "
                        "dtype={} e={} eo={}".format(
                            self.alg_class, inputs, self.inputs,
                            outputs, variables, self.dtype, e, eo)) from e
                if "dim_value: 0" in str(self.onnx_):
                    # Chained to eo: e is unbound here, Python deletes the
                    # name bound by 'except ... as e' when the handler ends.
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from eo

        if len(self.onnx_.graph.output) > len(self.outputs):  # pragma: no cover
            # Something is wrong, falls back to default plan.
            forced = True
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype, schema=self.alg_class.expected_outputs)
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    f"Probable issue as one dimension is null.\n--\n{self.onnx_}")
        else:
            lo = list(self.onnx_.graph.output)
            outputs = proto2vars(lo)

        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
            NotImplemented as OrtNotImplemented)

        sess_options, self.run_options = self._make_session_options(
            session_options, disable_optimisation)

        if ir_version is not None:
            self.onnx_.ir_version = ir_version
        try:
            self.sess_ = InferenceSession(
                self.onnx_.SerializeToString(), sess_options=sess_options,
                runtime=self.runtime)
        except (RuntimeError, OrtNotImplemented, OrtInvalidGraph, OrtFail) as e:
            raise RuntimeError(
                "Unable to load node '{}' (output type was {}) inputs={} "
                "self.inputs={} self.onnx_node.input={} "
                "variables={} mapping={} "
                "expected_inputs={}\n{}".format(
                    self.onnx_node.op_type,
                    "guessed" if forced else "inferred",
                    inputs, self.inputs, self.onnx_node.input,
                    variables, self.mapping,
                    self.alg_class.expected_inputs,
                    self.onnx_)) from e
        self.typed_outputs_ = outputs

    def run(self, *args, **kwargs):
        """
        Computes the node outputs with :epkg:`onnxruntime`.

        :param args: input values, in the same order as the node inputs
        :param kwargs: ignored
        :return: tuple of the values returned by the session
        :raises RuntimeError: if the session fails, the message lists
            the given inputs and the inputs the session expects
        """
        inputs = dict(zip(self.inputs, args))

        try:
            res = self.sess_.run(None, inputs, self.run_options)
        except (RuntimeError, self.OrtInvalidArgument) as e:  # pragma: no cover
            # getattr: an input without dtype/shape must not hide the
            # original error behind an AttributeError.
            dtypes = {k: getattr(v, 'dtype', type(v)) for k, v in inputs.items()}
            shapes = {k: getattr(v, 'shape', None) for k, v in inputs.items()}
            exp = [_.name for _ in self.sess_.get_inputs()]
            exp_types = [_.type for _ in self.sess_.get_inputs()]
            raise RuntimeError(
                "Predictions failed. List of inputs: {}, class={}"
                "\ndtypes={}\nshapes={}\nexpected={}\nexpected={}\n"
                "exception={}\n--ONNX--\n{}".format(
                    sorted(inputs), self.alg_class, dtypes,
                    shapes, exp, exp_types, e, self.onnx_)) from e
        return tuple(res)

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False