Coverage for mlprodict/onnxrt/onnx_inference_node.py: 87%
271 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief OnnxInferenceNode definition.
4"""
5import sys
6import numpy
7from onnx import GraphProto, onnx_pb as onnx_proto
8from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611
9from ..onnx_tools.onnx2py_helper import get_onnx_schema
10from .excs import MissingOperatorError
11from .ops import load_op
14class OnnxInferenceNode:
15 """
16 A node to execute.
18 :param onnx_node: onnx_node
19 :param desc: internal description
20 :param global_index: it is a function which returns a unique index
21 for the output this operator generates
22 """
    class OnnxInferenceWrapper:
        """
        Wraps @see cl OnnxInference in a wrapper and exposes
        the necessary function.

        :param oinf: instance of @see cl OnnxInference
        """

        def __init__(self, oinf):
            if oinf is None:
                raise ValueError(  # pragma: no cover
                    "oinf cannot be None.")
            # Underlying inference session, every call is delegated to it.
            self.oinf = oinf

        @property
        def args_default(self):
            "Returns the list of default arguments."
            return []

        @property
        def args_default_modified(self):
            "Returns the list of modified arguments."
            return []

        @property
        def args_mandatory(self):
            "Returns the list of mandatory arguments."
            return self.oinf.input_names

        @property
        def args_optional(self):
            "Returns the list of optional arguments."
            return []

        @property
        def obj(self):
            "Returns the ONNX graph."
            return self.oinf.obj

        def run(self, *args, **kwargs):
            "Calls run."
            return self.oinf.run(*args, **kwargs)

        def to_python(self, inputs, *args, **kwargs):
            """
            Calls to_python.

            :param inputs: input names used to build the final call
            :param args: sent to ``oinf.to_python``
            :param kwargs: sent to ``oinf.to_python``
            :return: tuple *(imports, code)*, both strings
            """
            res = self.oinf.to_python(*args, **kwargs)
            if len(res) != 1:
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented if the code has multiple files.")
            keys = list(res)
            value = res[keys[0]]
            lines = value.split('\n')
            # Locates the first function definition: everything before it
            # is treated as the import section, everything after as code.
            last = 0
            for i, line in enumerate(lines):
                if line.startswith('def '):
                    # NOTE(review): if the very first line is a 'def ',
                    # last becomes -1 and the slices below wrap around —
                    # confirm the generated code always starts with imports.
                    last = i - 1
                    break
            imports = '\n'.join(
                line for line in lines[:last] if 'import ' in line)
            lines.append('')
            lines.append(
                f"return OnnxPythonInference().run({', '.join(inputs)})")
            code = '\n'.join(lines[last:])
            return imports, code

        def need_context(self):
            "Needs context?"
            return False

        def enable_inplace_compute(self, index):
            "Not implemented."
            pass
96 def __init__(self, onnx_node, desc, global_index):
97 if desc is None:
98 raise ValueError("desc should not be None.") # pragma: no cover
99 self.desc = desc
100 self.onnx_node = onnx_node
101 self._init(global_index)
103 @property
104 def name(self):
105 "Returns the ONNX name."
106 return "_".join(
107 [self.desc['domain'], self.onnx_node.op_type]).replace(
108 ".", "_").replace('__', '_').strip('_')
110 def _init(self, global_index):
111 """
112 Prepares the node.
113 """
114 self.op_type = self.onnx_node.op_type
115 self.order = -1
116 self.variable_to_clean = []
117 self.inputs = list(self.onnx_node.input)
118 self.outputs = list(self.onnx_node.output)
119 self.inplaces = []
120 self.inputs_indices = [global_index(name) for name in self.inputs]
121 self.outputs_indices = [global_index(name) for name in self.outputs]
122 self._global_index = global_index
124 def set_order(self, order):
125 """
126 Defines the order of execution.
127 """
128 self.order = order
130 def add_variable_to_clean(self, name):
131 """
132 Adds a variable which can be cleaned after the node
133 execution.
134 """
135 self.variable_to_clean.append(name)
137 def __str__(self):
138 "usual"
139 return "Onnx-{}({}) -> {}{}".format(
140 self.op_type, ", ".join(self.inputs), ", ".join(self.outputs),
141 " (name=%r)" % self.onnx_node.name
142 if self.onnx_node.name else "")
144 def __repr__(self):
145 "usual"
146 return self.__str__()
    def setup_runtime(self, runtime=None, variables=None, rt_class=None,
                      target_opset=None, dtype=None, domain=None,
                      ir_version=None, runtime_options=None,
                      build_inference_node_function=None,
                      existing_functions=None):
        """
        Loads runtime.

        :param runtime: runtime options
        :param variables: registered variables created by previous operators
        :param rt_class: runtime class used to compute
            prediction of subgraphs
        :param target_opset: use a specific target opset
        :param dtype: float computational type
        :param domain: node domain
        :param ir_version: if not None, changes the default value
            given by :epkg:`ONNX`
        :param runtime_options: runtime options
        :param build_inference_node_function: function creating an inference
            runtime from an ONNX graph
        :param existing_functions: existing function as a dictionary
            `{ (domain, name): fct }`

        .. versionchanged:: 0.9
            Parameters *build_inference_node_function* and *existing_functions*
            were added.
        """
        if self.desc is None:
            raise AttributeError(
                "desc should not be None.")  # pragma: no cover
        if rt_class is None:
            # path used when this operator is a function:
            # *runtime* then carries an OnnxInference instance.
            self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(runtime)
            self.ops_ = None
            return

        self.function_ = None
        # Instantiates a runtime for every subgraph stored as a graph
        # attribute (If, Loop, Scan, ...), stored under 'value_rt'.
        self.preprocess_parameters(
            runtime, rt_class, ir_version=ir_version,
            target_opset=target_opset, existing_functions=existing_functions)
        options = {'provider': runtime} if runtime else {}
        if domain is not None:
            options['domain'] = domain
        if target_opset is not None:
            options['target_opset'] = target_opset
        if ir_version is not None:
            options['ir_version'] = ir_version
        if runtime_options is not None:
            # 'log_severity_level' is meant for the session itself,
            # it is not forwarded to the operator.
            options.update({
                k: v for k, v in runtime_options.items()
                if k not in {'log_severity_level'}})

        # existing functions?
        # NOTE(review): the lookup key uses the node *name* while
        # existing_functions is documented as `{ (domain, name): fct }` —
        # confirm node name (not op_type) is how the dictionary is keyed.
        key = (self.onnx_node.domain, self.onnx_node.name)
        if existing_functions is not None and key in existing_functions:
            self.ops_ = existing_functions[key]
            return

        # regular node
        try:
            if runtime is not None and runtime.startswith('onnxruntime2'):
                self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                    options=options if options else None,
                                    variables=variables, dtype=dtype,
                                    runtime=runtime)
            elif runtime in ('python_compiled', 'python_compiled_debug'):
                # the operator itself runs with the python provider,
                # compilation happens at a higher level
                options['provider'] = 'python'
                self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                    options=options if options else None,
                                    variables=variables, dtype=dtype,
                                    runtime=runtime)
            else:
                self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                    options=options if options else None,
                                    variables=variables, dtype=dtype,
                                    runtime=runtime)
        except MissingOperatorError as e:
            # The runtime does not implement this operator.
            # Falls back to the function body the ONNX schema may define.
            try:
                onnx_schema = get_onnx_schema(
                    self.onnx_node.op_type, self.onnx_node.domain,
                    opset=target_opset)
            except SchemaError:
                fct_names = (
                    list(existing_functions.keys()) if existing_functions
                    else [])
                raise MissingOperatorError(
                    "Unable to find runtime for node (%r, %r), "
                    "available functions=%r." % (
                        self.onnx_node.domain, self.onnx_node.op_type,
                        fct_names)) from e
            if onnx_schema is None or not onnx_schema.has_function:
                raise e
            self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(
                build_inference_node_function(onnx_schema.function_body))
            self.ops_ = None
244 @staticmethod
245 def _find_static_inputs(body):
246 """
247 Determines the loop inputs. It is any defined inputs
248 by the subgraphs + any result used as a constant
249 in the subgraphs.
250 """
251 inputs_set = set(i.name for i in body.input)
252 for init in body.initializer:
253 inputs_set.add(init.name)
254 for node in body.node:
255 for i in node.output:
256 inputs_set.add(i)
257 add_inputs = []
258 for node in body.node:
259 for i in node.input:
260 if i not in inputs_set:
261 # no graph input or output node matches
262 # it must be a constant from the below graph
263 add_inputs.append(i)
264 inputs_set.add(i)
265 for att in node.attribute:
266 if (att.type == onnx_proto.AttributeProto.GRAPH and # pylint: disable=E1101
267 hasattr(att, 'g') and att.g is not None):
268 inside = OnnxInferenceNode._find_static_inputs(att.g)
269 for i in inside:
270 if i not in inputs_set:
271 add_inputs.append(i)
272 inputs_set.add(i)
273 # If there is no node, we add the outputs as well.
274 if len(body.node) == 0:
275 for o in body.output:
276 i = o.name
277 if i not in inputs_set:
278 add_inputs.append(i)
279 inputs_set.add(i)
280 return add_inputs
282 @staticmethod
283 def _find_local_inputs(graph):
284 """
285 Determines the local inputs. It is any defined input
286 used by the subgraph and defined in the parent graph.
287 """
288 if not isinstance(graph, GraphProto):
289 raise TypeError(
290 f"Unexpected type {type(graph)!r}.")
291 local = set()
292 known = set()
293 for init in graph.initializer:
294 known.add(init.name)
295 for init in graph.input:
296 known.add(init.name)
297 for node in graph.node:
298 for o in node.output:
299 known.add(o)
300 for i in node.input:
301 if i not in known:
302 local.add(i)
303 return list(local)
305 def get_local_inputs(self):
306 """
307 Returns any local input used by this node in a subgraph
308 defined as an attribute and not declared as an input of this subgraph.
309 """
310 req = set()
311 for att in self.onnx_node.attribute:
312 if hasattr(att, 'g') and att.g is not None:
313 req |= set(self._find_local_inputs(att.g))
314 return req
    def preprocess_parameters(self, runtime, rt_class, ir_version=None,
                              target_opset=None, existing_functions=None):
        """
        Preprocesses the parameters, loads *GraphProto*
        (equivalent to :epkg:`ONNX` graph with less metadata).

        :param runtime: runtime options
        :param rt_class: runtime class used to compute
            prediction of subgraphs
        :param ir_version: if not None, overwrites the default value
        :param target_opset: use a specific target opset
        :param existing_functions: existing functions
        """
        if 'atts' not in self.desc:
            return  # pragma: no cover
        # A Loop body may use results computed outside the loop
        # (static inputs), the subgraph session must know about it.
        inside_loop = self.onnx_node.op_type in {'Loop'}
        for _, v in self.desc['atts'].items():
            if 'value' not in v:
                continue  # pragma: no cover
            value = v['value']
            if isinstance(value, onnx_proto.GraphProto):
                # graph attribute: builds a dedicated inference session,
                # stored under 'value_rt' and reused at execution time
                static_inputs = OnnxInferenceNode._find_static_inputs(value)
                if len(value.node) > 0:
                    try:
                        sess = rt_class(value, runtime=runtime,
                                        ir_version=ir_version,
                                        target_opset=target_opset,
                                        inside_loop=inside_loop,
                                        static_inputs=static_inputs,
                                        existing_functions=existing_functions)
                    except RuntimeError as e:  # pragma: no cover
                        raise RuntimeError(
                            "Unable to instantiate a node of type %r and name %r."
                            "" % (self.onnx_node.op_type, self.onnx_node.name)) from e
                else:
                    # outputs already exists, usually branch then of else for If node
                    sess = rt_class(value, runtime=runtime,
                                    ir_version=ir_version,
                                    target_opset=target_opset,
                                    inside_loop=inside_loop,
                                    static_inputs=static_inputs,
                                    existing_functions=existing_functions)
                v['value_rt'] = sess
360 def _build_context(self, values, input_list):
361 context = {}
362 # input_list does not need to be sorted but when
363 # an input is not found, the returned error is always
364 # related to the same input.
365 for n in sorted(input_list):
366 try:
367 v = values[self._global_index(n)]
368 except IndexError as e: # pragma: no cover
369 raise IndexError(
370 f"Unable to find an index for result {n!r} in onnx object.") from e
371 if v is None:
372 raise ValueError( # pragma: no cover
373 f"Input {n!r} is None.")
374 context[n] = v
375 return context
    def run(self, values, attributes=None, verbose=0, fLOG=None):
        """
        Runs the node.
        The function updates values with outputs.

        :param values: list of existing values
        :param attributes: attributes known at function level
        :param verbose: verbosity
        :param fLOG: logging function
        """
        # This code takes time if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.inputs_indices is None:
            args = list(values[k] for k in self.inputs)
        else:
            args = list(values[k] for k in self.inputs_indices)

        if self.ops_ is None:
            # Then a function (self.function_ wraps an inference session).
            if 'atts' in self.desc:
                # attributes of a function, node attributes override
                # the ones inherited from the caller
                if attributes is None:
                    attributes = {}
                else:
                    attributes = attributes.copy()
                attributes.update(self.desc['atts'])

            # binds positional args to the function's declared input names
            feeds = {}
            for name, val in zip(self.function_.obj.input, args):
                if val is None:
                    raise ValueError(  # pragma: no cover
                        f"Input name {name!r} is None.")
                feeds[name] = val

            if verbose == 0 or fLOG is None:
                outputs = self.function_.run(feeds, attributes=attributes)
            else:
                if verbose > 0:
                    fLOG('-- >%s[%s](%s) -- len(feeds)=%d' %
                         (self.function_.obj.name, self.function_.obj.domain,
                          ", ".join(self.function_.obj.input), len(feeds)))
                outputs = self.function_.run(
                    feeds, attributes=attributes, verbose=verbose, fLOG=fLOG)
                if verbose > 0:
                    fLOG('-- <%s[%s][%s]' %
                         (self.function_.obj.name, self.function_.obj.domain,
                          ", ".join(self.function_.obj.output)))

            res = [outputs[k] for k in self.function_.obj.output]
        else:
            # Or an operator.
            try:
                if self.ops_.need_context():
                    # some operators (subgraphs) need results computed
                    # elsewhere in the main graph
                    context = self._build_context(values,
                                                  self.ops_.additional_inputs)
                    res = self.ops_.run(*args, context=context,
                                        attributes=attributes,
                                        verbose=verbose, fLOG=fLOG)
                else:
                    res = self.ops_.run(
                        *args, attributes=attributes,
                        verbose=verbose, fLOG=fLOG)
            except (ValueError, TypeError) as e:
                raise RuntimeError(  # pragma: no cover
                    "Unable to run operator %r, inputs=%r."
                    "" % (type(self.ops_), self.inputs)) from e
            except OverflowError as e:
                raise RuntimeError(  # pragma: no cover
                    "Unable to run operator %r, inputs=%r."
                    "" % (type(self.ops_), self.inputs)) from e

        if not isinstance(res, tuple):
            raise RuntimeError(  # pragma: no cover
                f"Results of operator {type(self.ops_)!r} should be a tuple.")
        # NOTE(review): the check is one-sided, fewer results than declared
        # outputs is accepted — presumably for optional outputs; confirm.
        if len(self.outputs) < len(res):
            raise RuntimeError(  # pragma: no cover
                f"Mismatch number of outputs got {len(res)} "
                f"for names {list(self.outputs)} "
                f"for class {self.name!r})."
                f"\n{self.desc}")

        # This code takes times if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.outputs_indices is None:
            for name, value in zip(self.outputs, res):
                values[name] = value
        else:
            for i, r in enumerate(res):
                values[self.outputs_indices[i]] = r
468 def switch_initializers_dtype(self, dtype_in=numpy.float32,
469 dtype_out=numpy.float64):
470 """
471 Switches all initializers to ``numpy.float64``.
472 This only works if the runtime is ``'python'``.
474 @param dtype_in previous type
475 @param dtype_out next type
476 @return done operations
477 """
478 done = []
479 for k, v in self.desc['atts'].items():
480 if 'value_rt' not in v:
481 continue
482 if isinstance(v['value_rt'], numpy.ndarray):
483 if v['value_rt'].dtype == dtype_in:
484 v['value_rt'] = v['value_rt'].astype(dtype_out)
485 done.append(("+", "desc", k, v['value_rt']))
486 else:
487 done.append(("-", "desc", k, v['value_rt']))
488 if hasattr(self, 'ops_') and self.ops_ is not None:
489 res = self.ops_.switch_initializers_dtype(dtype_in, dtype_out)
490 for r in res:
491 done.append(("ops_", ) + r)
492 return done
494 def enable_inplace_compute(self, name):
495 """
496 Let the node know that one input can be overwritten.
498 @param name input name
499 """
500 self.inplaces.append(name)
501 (self.ops_ or self.function_).enable_inplace_compute(
502 self.inputs.index(name))
    @property
    def inputs_args(self):
        """
        Returns the list of arguments as well as
        the list of parameters with the default values
        (close to the signature).
        """
        if not hasattr(self, 'ops_'):
            raise AttributeError(
                "Attribute 'ops_' is missing.")  # pragma: no cover
        sigs = []
        # the runtime is either an operator (ops_) or a function (function_)
        ops_or_function = self.function_ if self.ops_ is None else self.ops_
        mand = ops_or_function.args_mandatory
        if mand is None:
            mand = self.python_inputs
        sigs.extend(mand)
        if len(ops_or_function.args_optional) > 0:
            sigs.extend(ops_or_function.args_optional)
            # '/' marks the end of positional arguments,
            # syntax available since python 3.8
            if sys.version_info[:2] >= (3, 8):
                sigs.append('/')
        sigs.extend(ops_or_function.args_default)
        return sigs
527 @property
528 def python_inputs(self):
529 """
530 Returns the python arguments.
531 """
532 if not hasattr(self, 'ops_'):
533 raise AttributeError(
534 "Attribute 'ops_' is missing.") # pragma: no cover
535 if hasattr(self.ops_, 'python_inputs'):
536 return self.ops_.python_inputs
537 return self.inputs
539 @property
540 def modified_args(self):
541 """
542 Returns the list of modified parameters.
543 """
544 if not hasattr(self, 'ops_'):
545 raise AttributeError(
546 "Attribute 'ops_' is missing.") # pragma: no cover
547 if self.ops_ is None:
548 return self.function_.args_default_modified
549 return self.ops_.args_default_modified
551 def to_python(self, inputs):
552 """
553 Returns a python code for this operator.
555 @param inputs inputs name
556 @return imports, python code, both as strings
557 """
558 if not hasattr(self, 'ops_'):
559 raise AttributeError(
560 "Attribute 'ops_' is missing.") # pragma: no cover
561 if self.ops_ is None:
562 return self.function_.to_python(inputs)
563 return self.ops_.to_python(inputs)