Coverage for mlprodict/npy/xop.py: 92%
1783 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# pylint: disable=E1101,C0302
2"""
3@file
4@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`.
6.. versionadded:: 0.9
7"""
8import os
9import pprint
10import logging
11import hashlib
12import json
13from collections import OrderedDict
14import numpy
15from scipy.sparse.coo import coo_matrix
16import onnx
17from onnx import GraphProto, TensorProto, ValueInfoProto
18from onnx.helper import (
19 make_node, make_graph, make_model, make_value_info,
20 make_tensor_value_info, make_function, make_opsetid,
21 make_tensor_type_proto, make_operatorsetid)
22from onnx.numpy_helper import from_array, to_array
23from onnx.shape_inference import infer_shapes
24from ..onnx_tools.model_checker import check_onnx
25from ._cache import cache_folder
26from .xop_variable import (
27 Variable, is_numpy_dtype, numpy_type_prototype, max_supported_opset,
28 DetectedVariable, InputDetectedVariable, OutputDetectedVariable,
29 NodeResultName, guess_numpy_type, ExistingVariable)
30from .xop_auto import get_rst_doc
31from .xop_helper import _infer_node_output
34class _WrapperLogger:
35 """
36 Wrappers around class :class:`logging.Logger`
37 to take indentation into account.
38 """
40 def __init__(self, lg):
41 "constructor"
42 self._logger = lg
43 self._indent = 0
45 def debug(self, msg, *args):
46 "debug"
47 self._logger.debug("%s" + msg, " " * self._indent, *args)
49 def indent(self):
50 "indent"
51 self._indent += 1
53 def dedent(self):
54 "unindent"
55 self._indent -= 1
56 if self._indent < 0:
57 raise RuntimeError( # pragma: no cover
58 "Indentation cannot be negative.")
class _WrapperPrint(_WrapperLogger):
    """
    Same interface as :class:`_WrapperLogger` but writes to stdout,
    which is handy while debugging.
    """

    def __init__(self):
        "Initializes the base class with no underlying logger."
        _WrapperLogger.__init__(self, None)

    def debug(self, msg, *args, indent=None):
        "Prints *msg*; *indent* dedents before (False) or indents after (True)."
        if indent is None:
            sign = ""
        elif indent:
            sign = '> '
        else:
            self.dedent()
            sign = '< '
        print(f"{' ' * self._indent}{sign}{msg} {' '.join(map(str, args))}")
        if indent:
            self.indent()
# Module-wide logger (channel 'xop') with indentation support,
# used to trace how graphs are being built.
logger = _WrapperLogger(logging.getLogger('xop'))
# Shortcut printing indented debug messages to stdout.
local_print = _WrapperPrint().debug
89def _default_OPSET_TO_IR_VERSION():
90 """
91 Returns the default mapping between opset and ir_version.
93 .. runpython::
94 :showcode:
96 import pprint
97 from mlprodict.npy.xop import _default_OPSET_TO_IR_VERSION
98 pprint.pprint(_default_OPSET_TO_IR_VERSION())
99 """
100 return {
101 1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3,
102 7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7,
103 13: 7, 14: 7, 15: 8, 16: 8, 17: 8}
106def _domain_to_class_name(domain):
107 """
108 Converts domain into a name.
110 :param domain: domain name such as `ai.onnx.ml`
111 :return: string
113 .. runpython::
114 :showcode:
116 from mlprodict.npy.xop import _domain_to_class_name
117 print(_domain_to_class_name('ai.onnx.ml'))
118 """
119 if domain == 'ai.onnx':
120 return ''
121 dom = domain.split('.')
122 res = []
123 for d in dom:
124 if len(d) == 0:
125 res.append(d)
126 elif len(d) == 1:
127 res.append(d.upper())
128 else:
129 res.append(d[0].upper() + d[1:])
130 return "".join(res)
class _CustomSchema:
    """
    For operators defined outside onnx.

    Mimics the parts of `onnx.defs.OpSchema` this module relies on, so that
    schemas coming from other packages (e.g. onnxruntime) can be handled
    uniformly, serialized to json and parsed back.
    """

    class _empty:
        "dummy class"

        @staticmethod
        def from_attribute(data):
            "Creates an instance of `_CustomSchema._attribute`."
            if not isinstance(data, dict):
                raise TypeError(  # pragma: no cover
                    f"Unexpected type {type(data)!r}.")
            # Rebuilds an attribute-like object from a json dictionary,
            # nesting a second _empty so that `.type.value` resolves.
            self = _CustomSchema._empty()
            setattr(self, 'name', data['name'])
            setattr(self, 'description', data['description'])
            setattr(self, 'required', data['required'])
            setattr(self, 'type', _CustomSchema._empty())
            setattr(self.type, 'value', data['type'])
            setattr(self, 'default_value', '?')
            return self

        @staticmethod
        def from_io(data):
            "Creates an instance of `_CustomSchema._io`."
            if not isinstance(data, dict):
                raise TypeError(  # pragma: no cover
                    f"Unexpected type {type(data)!r}.")
            # Rebuilds an input/output-like object from a json dictionary,
            # nesting a second _empty so that `.option.value` resolves.
            self = _CustomSchema._empty()
            setattr(self, 'name', data['name'])
            setattr(self, 'typeStr', data['typeStr'])
            setattr(self, 'description', data['description'])
            setattr(self, 'option', _CustomSchema._empty())
            setattr(self.option, 'value', data['option'])
            setattr(self, 'isHomogeneous', data['isHomogeneous'])
            return self

    class _io:
        "input, output"

        def __init__(self, t):
            # *t* is either a schema formal parameter or an `_empty`
            # instance built by `from_io`; `option` is normalized to int.
            self.name = t.name
            self.typeStr = t.typeStr
            if isinstance(t.option, int):
                self.option = t.option
            else:
                self.option = t.option.value
            self.description = t.description
            self.isHomogeneous = t.isHomogeneous

        def data(self):
            "Returns all data in that class in a dictionary."
            return {'name': self.name, 'typeStr': self.typeStr,
                    'description': self.description,
                    'isHomogeneous': self.isHomogeneous,
                    'option': self.option}

        def __eq__(self, ot):
            # Equality only looks at name and type string.
            return self.name == ot.name and self.typeStr == ot.typeStr

    class _attribute:
        "attribute"

        def __init__(self, att):
            # *att* is either a schema attribute or an `_empty` instance
            # built by `from_attribute`; `type` is normalized to int.
            self.name = att.name
            if isinstance(att.type, int):
                self.type = att.type
            else:
                self.type = att.type.value
            self.default_value = '?'
            self.description = att.description
            self.required = att.required

        def data(self):
            "Returns all data in that class in a dictionary."
            return {'name': self.name, 'type': self.type,
                    'description': self.description,
                    'required': self.required}

        def __eq__(self, ot):
            # Equality only looks at name and type.
            return self.name == ot.name and self.type == ot.type

    def __init__(self, schema):
        # Copies the relevant pieces of *schema* into plain attributes.
        self._schema = schema
        self.domain = schema.domain
        self.name = schema.name
        self.since_version = schema.since_version
        try:
            self.inputs = [_CustomSchema._io(t) for t in schema.inputs]
        except AttributeError as e:  # pragma: no cover
            raise AttributeError(
                "Issue with operator=%r domain=%r since_version=%r, "
                "type(schema)=%r" % (
                    schema.name, schema.domain, schema.since_version,
                    type(schema))) from e
        try:
            self.outputs = [_CustomSchema._io(t) for t in schema.outputs]
        except AttributeError as e:  # pragma: no cover
            raise AttributeError(
                "Issue with operator=%r domain=%r since_version=%r, "
                "type(schema)=%r" % (
                    schema.name, schema.domain, schema.since_version,
                    type(schema))) from e
        self.attributes = {a.name: _CustomSchema._attribute(a)
                           for a in schema.attributes.values()}
        self.min_input = schema.min_input
        self.max_input = schema.max_input
        self.min_output = schema.min_output
        self.max_output = schema.max_output
        self.doc = schema.doc

    # Attributes used for equality, serialization and deserialization.
    _atts = ['domain', 'name', 'since_version', 'inputs', 'outputs',
             'attributes', 'min_input', 'max_input',
             'min_output', 'max_output', 'doc']

    def __eq__(self, ot):
        # Two schemas are equal when every attribute in `_atts` matches.
        for k in _CustomSchema._atts:
            if getattr(self, k) == getattr(ot, k):
                continue
            return False
        return True

    def data(self):
        "Returns all data in that class in a dictionary."
        def _(x):
            # Recursively converts nested values into json-friendly types.
            if x is None:
                return None
            if isinstance(x, (str, int)):
                return x
            if isinstance(x, list):
                return [_(e) for e in x]
            if isinstance(x, dict):
                return {k: _(v) for k, v in x.items()}
            if hasattr(x, 'data'):
                return x.data()
            raise TypeError(  # pragma: no cover
                f"Unable to handle type {type(x)!r} - {x!r}.")

        return {k: _(getattr(self, k)) for k in _CustomSchema._atts}

    def SerializeToString(self):
        "Serializes this class into json."
        return json.dumps(self.data())

    @staticmethod
    def ParseFromString(s):
        "Parses this class from a json string."
        obj = json.loads(s)
        # Rebuilds an `_empty` schema-like object, then wraps it.
        e = _CustomSchema._empty()
        for k in _CustomSchema._atts:
            if k == 'attributes':
                setattr(e, k, {a['name']: _CustomSchema._empty.from_attribute(a)
                               for a in obj[k].values()})
            elif k in ('inputs', 'outputs'):
                setattr(e, k, [_CustomSchema._empty.from_io(o)
                               for o in obj[k]])
            else:
                setattr(e, k, obj[k])
        return _CustomSchema(e)

    def __repr__(self):
        return f"_CustomSchema(**{pprint.pformat(self.data())})"
def _get_all_operator_schema():
    """
    Loads a local copy of onnxruntime operator schemas stored beside
    this module, skipping the first (header) line of the file.
    """
    filename = os.path.join(
        os.path.dirname(__file__), "ort_get_all_operator_schema.tmpl")
    with open(filename, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    return [_CustomSchema.ParseFromString(line) for line in lines[1:]]
def _populate_schemas():
    """
    Populates all schemas.

    Returns a tuple ``(res, versions, domains)``:

    * *res* maps ``(domain, name)`` and ``(domain, name_version)``
      to the corresponding schema (most recent version wins),
    * *versions* maps ``(domain, name)`` to the set of versioned names,
    * *domains* maps an operator name to the set of domains defining it.
    """
    def _populate_schema(schema):
        # Multiple version can coexist. The last one is kept.
        key = schema.domain, schema.name
        if key in res:
            if schema.since_version > res[key].since_version:
                # We keep the most recent one.
                res[key] = schema
        else:
            res[key] = schema
        # A versioned entry (e.g. 'Add_7') is always registered too.
        full_name = schema.name + '_' + str(schema.since_version)
        res[schema.domain, full_name] = schema
        if key not in versions:
            versions[key] = set()
        if schema.name not in domains:
            domains[schema.name] = set()
        domains[schema.name].add(schema.domain)
        versions[key].add(full_name)

    res = {}
    versions = {}
    domains = {}
    for schema in onnx.defs.get_all_schemas_with_history():
        if schema.support_level == schema.SupportType.EXPERIMENTAL:
            # Skips experimental operators.
            continue
        _populate_schema(schema)

    try:
        import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
    except ImportError:  # pragma: no cover
        rtpy = None

    if rtpy is not None:
        # If onnxruntime is available, it is being populated with these operators as well.
        try:
            get_schemas = rtpy.get_all_operator_schema
        except AttributeError:
            # onnxruntime must be compiled with flag --gen_doc.
            # a local copy is retrieved.
            get_schemas = _get_all_operator_schema
        for op in get_schemas():
            if (op.domain, op.name) in res:
                # an existing onnx schema
                continue
            sch = _CustomSchema(op)
            _populate_schema(sch)

    return res, versions, domains
def _find_operator_domain(name):
    """
    Determines the domain of an operator.
    Raises an exception if not found or if there is an ambiguity.

    :param name: operator name
    :return: domain
    """
    known = _S.all_domains
    if name not in known:
        raise ValueError(
            "Unable to guess domain for operator %r. "
            "Not found in %r." % (name, list(known)))
    domains = known[name]
    if len(domains) != 1:
        raise ValueError(  # pragma: no cover
            f"Unable to guess domain of operator {name!r}, found domains {domains!r}.")
    return list(domains)[0]
379def _split_op_name(name):
380 spl = name.split('_')
381 try:
382 i = int(spl[-1])
383 except ValueError:
384 return name, None
385 return "_".join(spl[:-1]), i
def ClassFactory(class_name, op_name, inputs, outputs,
                 input_range, output_range,
                 domain, attr_names, doc,
                 deprecated, since_version,
                 past_version):
    """
    Dynamically creates a class for a specific operator.

    :param class_name: class name
    :param op_name: operator type
    :param inputs: expected inputs
    :param outputs: expected outputs
    :param input_range: input range
    :param output_range: output_range
    :param domain: domain
    :param attr_names: attributes names
    :param doc: docstring
    :param deprecated: is the operator deprecated
    :param since_version: available since version
    :param past_version: list of versions
    """

    def __init__(self, *args, **kwargs):
        # Constructor shared by every generated class. It validates the
        # inputs, resolves the opset version and delegates to
        # OnnxOperator.__init__.
        op_version = kwargs.pop('op_version', None)

        if op_version is None:
            if len(args) == 0 and input_range[0] == input_range[1]:
                # No input given: fall back on the default input names.
                args = [_[0] for _ in self.__class__.expected_inputs]
            if not (input_range[0] <= len(args) <= input_range[1]):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of inputs, "
                    "got {}, expecting {} for operator "
                    "'{}'.".format(
                        len(args), len(inputs), op_name))

        attr_names = self.attr_names
        # A versioned class name (e.g. 'OnnxAdd_7') constrains op_version.
        _, op_version_class = _split_op_name(self.__class__.__name__)
        if op_version_class is not None:
            if op_version is None:
                op_version = op_version_class
            try:
                op_version = min(op_version, op_version_class)
            except TypeError:  # pragma: no cover
                raise TypeError(  # pylint: disable=W0707
                    "Could not compare versions {} ? {} for "
                    "class '{}' since_version {}. Parameter 'op_version' "
                    "is probably missing when the class "
                    "is instantiated.".format(
                        op_version, op_version_class, class_name,
                        since_version))
        else:
            op_version_class = None

        # By default, the op_version is None.
        # None means the latest available.
        if op_version is None:
            op_version = since_version

        found = None
        if op_version is not None:
            # attr_names refers to the most recent version of
            # this operator. We may need an older one.
            for op in range(op_version, 0, -1):
                name = f'{self.__class__.__name__}_{op}'
                if name in self.past_version:
                    found = (name, op)
                    attr_names = self.past_version[name].attr_names
                    if len(attr_names) > 0 and not isinstance(attr_names[0], str):
                        raise TypeError(  # pragma: no cover
                            "attr_names must be a list of string not a list of %r for "
                            "operator %r and domain %r." % (
                                type(attr_names[0]), name, domain))
                    break
        if (op_version_class is not None and found is not None and
                found[-1] != op_version_class):
            raise RuntimeError(  # pragma: no cover
                "op_version={} does not refer to the same opset as the class "
                "name ('{}').".format(op_version, self.__class__.__name__))
        for key in kwargs:
            # These keyword arguments are handled by OnnxOperator itself,
            # all others must be declared attributes of the operator.
            if key in {'output_names', 'op_version', 'domain', 'ir_version',
                       'global_context', 'clear_subgraph_inputs'}:
                continue
            if key not in attr_names:
                raise TypeError(  # pragma: no cover
                    "Argument '%s' not valid for '%s' domain=%r opset=%s "
                    "(should be in %r, type(self)=%r)." % (
                        key, op_name, domain, op_version, attr_names,
                        type(self)))

        if op_version is not None:
            kwargs['op_version'] = op_version
        if 'domain' not in kwargs:
            kwargs['domain'] = domain
        # This class can only be created by a user. Let's check
        # types are either a variable, an operator or an array.
        for i, a in enumerate(args):
            if isinstance(a, tuple):
                if len(a) != 2:
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple or class %r, it must have two "
                        "elements (name, type) not %r." % (i, class_name, a))
                if not isinstance(a[0], str):
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple or class %r, it must be a tuple "
                        "(name, type) not %r." % (i, class_name, a))
                continue
            if not isinstance(a, (
                    Variable, OnnxOperator, numpy.ndarray, str,
                    OnnxOperatorItem, coo_matrix)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for input %r of operator %r. "
                    "It must be an instance of Variable (or a string), "
                    "OnnxOperator, OnnxOperatorItem, numpy.ndarray, "
                    "coo_matrix)." % (
                        type(a), i, class_name))
        OnnxOperator.__init__(self, *args, **kwargs)

    # The generated class inherits from OnnxOperator and stores the
    # schema information as class attributes.
    newclass = type(class_name, (OnnxOperator,),
                    {"__init__": __init__, '__doc__': doc,
                     'expected_inputs': inputs,
                     'expected_outputs': outputs,
                     'operator_name': op_name,
                     'input_range': input_range,
                     'output_range': output_range,
                     'domain': domain,
                     'is_deprecated': deprecated,
                     'since_version': since_version,
                     'past_version': past_version,
                     'attr_names': attr_names,
                     'op_type': op_name,
                     '__module__': __name__})
    return newclass
def _dynamic_class_creation(operator_names=None, cache=False, include_past=False,
                            verbose=0, fLOG=print):
    """
    Automatically generates classes for each of the operators
    module *onnx* defines and described at
    `Operators
    <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
    and `Operators
    <https://github.com/onnx/onnx/blob/master/docs/
    Operators-ml.md>`_.

    :param operator_names: list of operators to request or None for all
    :param cache: extract the documentation from onnx package and
        saves it on disk it True
    :param include_past: includes past versions if operator_names is None
    :param verbose: display some progress
    :param fLOG: logging function
    :return: list of requested operators as a tuple
    """
    def _c(obj, label, i):
        # Builds a (name, type string) pair for input/output *i*, falling
        # back on *label* + index when the schema leaves the name empty.
        name = '%s%d' % (obj.name or label, i)
        tys = obj.typeStr or ''
        return (name, tys)

    cache_dir = cache_folder()
    if operator_names is None:
        operator_names = list(_S.all_schemas_versions)
        if include_past:
            add = []
            for domain, op in operator_names:
                add.extend(
                    [(domain, k)
                     for k in _S.all_schemas_versions[domain, op]])
            operator_names.extend(add)
        operator_names.sort()

    # type verification
    ops = []
    for name in operator_names:
        if isinstance(name, str):
            if name.startswith('Onnx'):
                raise ValueError(
                    f"Operator name cannot start with Onnx: {name!r}.")
            n_name, _ = _split_op_name(name)
            domain = _find_operator_domain(n_name)
            ops.append((domain, name))
        elif isinstance(name, tuple) and len(name) == 2:
            if name[1].startswith('Onnx'):
                raise ValueError(  # pragma: no cover
                    f"Operator name cannot starts with Onnx: {name!r}.")
            ops.append(name)
        else:
            raise ValueError(  # pragma: no cover
                "Operator to fetch must be a string or a "
                "`tuple(domain, name)` not %r." % (name))
    operator_names = ops

    # versions
    res = _S.all_schemas
    cls = {}
    set_names = dict()
    set_skip = set()
    for pos, (op_domain, op_name) in enumerate(operator_names):
        if op_domain == 'ai.onnx':
            # The main domain is stored as an empty string.
            op_domain = ''
        set_names[op_domain, op_name] = pos
        n, v = _split_op_name(op_name)
        if v is not None and not include_past:
            set_skip.add((op_domain, n))
            # NOTE(review): `set_names` is keyed by (domain, name) tuples
            # while `n` is a bare string, so this test is always true —
            # confirm whether `(op_domain, n) not in set_names` was meant.
            if n not in set_names:
                set_names[op_domain, n] = -1

    if verbose > 1 and fLOG is not None:  # pragma: no cover
        fLOG(f"[_dynamic_class_creation] set_names={set_names!r}")
        fLOG(f"[_dynamic_class_creation] set_skip={set_skip!r}")

    returned_classes = []
    positions = {}

    for (op_domain, op_name), position in set_names.items():
        cl_name = 'Onnx' + _domain_to_class_name(op_domain) + op_name
        if verbose > 3 and fLOG is not None:
            fLOG(  # pragma: no cover
                '[_dynamic_class_creation] cl_name=%r op_domain=%r op_name=%r (in=%d) '
                'position=%r' % (
                    cl_name, op_domain, op_name,
                    1 if cl_name in _S.all_classes else 0,
                    position))
        if cl_name in _S.all_classes:
            # NOTE(review): `set_skip` holds (domain, name) tuples while
            # `cl_name` is a string, so this membership test is always
            # true — confirm the intended comparison.
            if cl_name not in set_skip:
                if position >= 0:
                    returned_classes.append(
                        (position, _S.all_classes[cl_name]))
            continue

        # operator name without domain
        n, v = _split_op_name(op_name)
        if v is not None:
            # A versioned name refers to exactly one class.
            names = [op_name]
        else:
            try:
                names = _S.all_schemas_versions[op_domain, op_name].copy()
            except KeyError as e:  # pragma: no cover
                raise ValueError(
                    "Operator %r (domain=%r) does not exists." % (
                        op_name, op_domain)) from e
            names.add(op_name)

        if verbose > 0 and fLOG is not None:
            fLOG(  # pragma: no cover
                "[_dynamic_class_creation] op_domain=%r op_name=%r, cl_name=%r names=%r"
                "" % (op_domain, op_name, cl_name, names))

        for name in names:
            try:
                schema = res[op_domain, name]
            except KeyError as e:
                raise ValueError(
                    "Operator (%r, %r) does not exists (available=%r)" % (
                        op_domain, name, pprint.pformat(list(res)))) from e
            inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)]
            outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)]
            args = [p if isinstance(p, str) else p.name
                    for p in schema.attributes]
            if len(args) > 0 and not isinstance(args[0], str):
                raise TypeError(  # pragma: no cover
                    "args must be a list of string not a list of %r for "
                    "operator %r and domain %r." % (
                        type(args[0]), name, op_domain))

            n_name, v = _split_op_name(name)

            if v is not None:
                if op_domain == 'com.microsoft' and name in {
                        'SoftmaxGrad_13', 'LogSoftmaxGrad_13'}:
                    # exception
                    pass
                elif v != schema.since_version:
                    raise ValueError(  # pragma: no cover
                        "Inconsistent version number %d != %d for operator "
                        " %r, %r (%r)." % (
                            v, schema.since_version, schema.domain,
                            schema.name, name))
                class_name = "Onnx" + _domain_to_class_name(op_domain) + name
            else:
                class_name = (
                    "Onnx" + _domain_to_class_name(op_domain) + schema.name)

            if verbose > 0 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[_dynamic_class_creation] op_name=%r, cl_name=%r cache=%r v=%r"
                    "" % (op_name, class_name, cache, v))

            # The rst documentation is cached on disk to avoid
            # regenerating it at every call.
            filename = os.path.join(
                cache_dir,
                schema.name + '_' + str(schema.since_version) + ".rst")
            if not cache and os.path.exists(filename):
                with open(filename, "r", encoding="utf-8") as f:  # pragma: no cover
                    doc = f.read()
            else:
                doc = get_rst_doc(schema.name, domain=schema.domain,
                                  version=schema.since_version)
                if cache:  # pragma: no cover
                    with open(filename, 'w', encoding='utf-8') as f:
                        f.write(doc)

            cl = ClassFactory(class_name, schema.name, inputs, outputs,
                              [schema.min_input, schema.max_input],
                              [schema.min_output, schema.max_output],
                              schema.domain, args,
                              "**Version**" + doc.split('**Version**')[-1],
                              getattr(schema, 'deprecated', False),
                              schema.since_version, {})
            cls[class_name] = cl
            if name == op_name:
                positions[class_name] = position

    # Retrieves past classes.
    for name in cls:  # pylint: disable=C0206
        main, v = _split_op_name(name)
        if v is None:
            continue
        if main in cls:  # pylint: disable=R1715
            last = cls[main]
        else:
            last = _S.all_classes[main]
        last.past_version[name] = cls[name]

    # final
    _S.all_classes.update(cls)
    for cl_name, v in cls.items():
        # NOTE(review): `v` is a class while `set_skip` holds
        # (domain, name) tuples, so `v not in set_skip` is always true —
        # confirm the intended filter.
        if v not in set_skip and positions.get(cl_name, -1) >= 0:
            returned_classes.append((positions[cl_name], v))

    returned_classes.sort()
    return tuple(e[1] for e in returned_classes)
def loadop(*names, cache=False, verbose=0, fLOG=print):
    """
    Dynamically creates a class for every operator type in
    the given list. Returns a single class when only one
    operator is requested, a tuple otherwise.
    """
    classes = _dynamic_class_creation(
        names, cache=cache, verbose=verbose, fLOG=fLOG)
    return classes[0] if len(classes) == 1 else classes
class OnnxLoadFactory:
    """
    Automatically creating all operators from onnx packages
    takes time. That's why function @see cl loadop only creates
    classes for the requested operators. This class does the same
    when an attribute is requested.

    ::

        cl = OnnxLoadOperators()
        x = cl.Add(...)

    It is equivalent to:

    ::

        OnnxAdd = loadop('Add')
        x = OnnxAdd(...)
    """

    def __init__(self):
        # Cache of already created classes, indexed by both the short
        # operator name and the full class name.
        self._loaded_classes = {}

    def __getattr__(self, name):
        """
        Enables expressions such as:

        ::

            ops = OnnxLoadFactory()
            op = ops.Abs('X')
        """
        # Guard against lookups happening before __init__ ran.
        if name == '_loaded_classes':
            return self._loaded_classes
        known = self._loaded_classes
        if name not in known:
            cl = loadop(name)
            known[name] = cl
            known[cl.__name__] = cl
        return known[name]
class OnnxOperatorBase:
    """
    Base class for @see cl OnnxOperator, @see cl OnnxOperatorItem,
    @see cl OnnxOperatorTuple.

    Every method raises :class:`NotImplementedError`; subclasses are
    expected to override the ones they support.
    """

    def __init__(self):
        pass

    def add_to(self, builder):
        "This method should be overwritten."
        raise NotImplementedError(  # pragma: no cover
            f"Not overwritten for class {type(self)!r}.")

    @property
    def output_names(self):
        "This method should be overwritten."
        raise NotImplementedError(  # pragma: no cover
            f"Not overwritten for class {type(self)!r}.")

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        raise NotImplementedError(  # pragma: no cover
            f"Method 'find_named_inputs' must be overloaded for type {type(self)}.")

    def f(self, *args, **kwargs):
        """
        Evaluates this node.
        """
        raise NotImplementedError(  # pragma: no cover
            f"Method 'f' must be overloaded for type {type(self)}.")

    def _set_control_op(self, op, subgraph_inputs=None):
        """
        Tells this operator is part of a subgraph.
        """
        raise NotImplementedError(  # pragma: no cover
            f"Method '_set_control_op' must be overloaded for type {type(self)}.")

    def add_external_input(self, op):
        """
        Tells a subgraph this node comes from the main graph.
        It may be used only by the subgraph but it must be processed as well.
        """
        # Bug fix: the message previously named '_set_control_op'
        # (copy-paste error), which made the error misleading.
        raise NotImplementedError(  # pragma: no cover
            f"Method 'add_external_input' must be overloaded for type {type(self)}.")
class OnnxOperatorItem(OnnxOperatorBase):
    """
    Accessor to one of the output returned by a @see cl OnnxOperator.

    :param onx_op: @see cl OnnxOperator
    :param index: integer
    :param op_version: defines the opset version
    """

    def __init__(self, onx_op, index, op_version=None):
        OnnxOperatorBase.__init__(self)
        if not isinstance(index, int):
            raise TypeError(  # pragma: no cover
                f"index must be an integer not {type(index)!r}.")
        logger.debug("op:%s-%d(%r, %d, op_version=%r)",
                     self.__class__.__name__, id(self), onx_op, index, op_version)
        if not isinstance(onx_op, OnnxOperatorBase):
            raise TypeError(  # pragma: no cover
                f"onx_op must be an OnnxOperator not {type(onx_op)!r}.")
        # Wrapped operator and the index of the output being accessed.
        self.onx_op = onx_op
        self.index = index
        self.op_version = op_version

    @property
    def output_names(self):
        "Returns None."
        return None

    @property
    def inputs(self):
        "Returns the only inputs in a list."
        return [NodeResultName(self.onx_op, self.index)]

    def add_to(self, builder):
        """
        Adds to graph builder.
        Does nothing because the original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        pass

    def __str__(self):
        "usual"
        return "%s[%d]" % (str(self.onx_op), self.index)

    def __repr__(self):
        "usual"
        return "%s(%s[%d])" % (
            self.__class__.__name__,
            self.onx_op.__class__.__name__,
            self.index)

    def get_output_result(self, i=0):
        """
        Returns the output name at position *i*.
        """
        # Only index 0 makes sense: this accessor already points at one
        # specific output of the wrapped operator.
        if i != 0:
            raise IndexError(  # pragma: no cover
                "Can only return the first item.")
        return self.onx_op.get_output_result(self.index)

    def _to_onnx_attributes(self, inputs=None, target_opset=None,
                            optim=True, verbose=0, run_shape=True,
                            fLOG=print, processed=None):
        """
        Calls `self.onx_op._to_onnx_attributes`.
        """
        return self.onx_op._to_onnx_attributes(
            inputs=inputs, target_opset=target_opset, optim=optim,
            run_shape=run_shape, verbose=verbose, fLOG=fLOG,
            processed=processed)

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        return self.onx_op.find_named_inputs()

    def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
          clear_cache=False, runtime=None):
        """
        Computes the predictions for this node.
        Similar to an eager evaluation.

        :param inputs: inputs as dictionary or a list of inputs
            (see below)
        :param verbose: display information while predicting
        :param fLOG: logging function if *verbose > 0*
        :param clear_cache: onnx graph is created once unless
            this parameter is True
        :param runtime: runtime to use for the evaluation,
            see @see cl OnnxInference
        :return: outputs as a dictionary if the input were given as a
            dictionary or a single result or a tuple otherwise

        The inputs refer to the inputs of the graph.
        The method walks through all inputs and finds inputs defined as
        string. It replaces them by the value found in the dictionary.
        If the inputs are specified in a list, the function retrieves the
        list of inputs defined as a string and assigns them a value.
        Logging function can be used to get more insight about it.
        During the evaluation every node is independently converted
        into ONNX. The ONNX graph is cached in the class itself.
        """
        # Evaluates the whole wrapped operator, then extracts the single
        # output this accessor refers to.
        res = self.onx_op.f(*inputs, verbose=verbose, fLOG=fLOG,
                            clear_cache=clear_cache, runtime=runtime)
        if isinstance(res, dict):
            names = self.onx_op.output_names
            if names is None:
                names = self.onx_op.expected_outputs
                name = names[self.index][0]
            else:
                name = names[self.index]
            return {name: res[name]}
        return res[self.index]
944class OnnxOperatorTuple(OnnxOperatorBase):
945 """
946 Class used to return multiple @see cl OnnxVar
947 at the same time.
948 """
    def __init__(self, first, *args):
        # With several values they are stored in `self.values` and
        # `self.unique` is None; with a single value, `self.unique` holds
        # it and `self.values` is None. Exactly one of the two is set.
        OnnxOperatorBase.__init__(self)
        logger.debug("op:%s-%d([%r], %d in)",
                     self.__class__.__name__, id(self), type(first),
                     len(args))
        if isinstance(first, (list, tuple)):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for first {type(first)!r}.")
        logger.debug('op:%s-%d(%d in)', self.__class__.__name__,
                     id(self), 1 + len(args))
        if len(args) > 0:
            self.values = (first,) + args
            self.unique = None
        else:
            self.values = None
            self.unique = first
        # Sanity checks on the values/unique invariant.
        if self.values is not None and self.unique is not None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected configuration. One member (values or unique) must be "
                "null, unique=%r, values=%r" % (self.unique, self.values))
        if self.values is None and self.unique is None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected configuration. One member (values or unique) must be "
                "not null.")
    def __repr__(self):
        "usual"
        # Only types are displayed, not the full operators.
        if self.values is None:
            return f"{self.__class__.__name__}({type(self.unique)!r})"
        return "%s(%s)" % (self.__class__.__name__, ", ".join(
            "%r" % type(v) for v in self.values))
    @property
    def inputs(self):
        "Returns the only inputs in a list."
        if self.values is None:
            return [self.unique]
        # The multi-value case is not supported yet.
        raise NotImplementedError(  # pragma: no cover
            "OnnxOperatorTuple.inputs is missing.")
    @property
    def external_inputs(self):
        """
        Returns the list of implicit inputs the subgraph
        assumes to be existing even if they are not referenced as
        explicit input for the graph.
        """
        if self.values is None:
            return self.unique.external_inputs
        # Concatenates the external inputs of every held operator.
        res = []
        for op in self.values:
            res.extend(op.external_inputs)
        return res
    def add_to(self, builder):
        """
        Adds to graph builder.
        Does nothing because the original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        pass
    def __len__(self):
        "usual"
        # Only meaningful when several values are held.
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case unique=%r, "
                "values=%r." % (self.unique, self.values))
        return len(self.values)
    def __iter__(self):
        "Iterates on the outputs."
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case.")
        for v in self.values:
            yield v
    def __getitem__(self, i):
        "usual"
        # Delegates to the unique operator when there is only one.
        if self.values is None:
            return self.unique[i]
        return self.values[i]
    @property
    def outputs(self):
        "Returns 'output_names' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.outputs
        raise NotImplementedError(  # pragma: no cover
            f"Not implemented yet unique={self.unique!r} values={self.values!r}.")
    @property
    def output_names(self):
        "Returns 'output_names' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.output_names
        raise NotImplementedError(  # pragma: no cover
            f"Not implemented yet unique={self.unique!r} values={self.values!r}.")
    @output_names.setter
    def output_names(self, value):
        """
        Updates 'output_names' of attribute 'unique'
        or every output name of attribute 'values'.
        """
        logger.debug("op:%s:output_names:set(%r)",
                     self.__class__.__name__, value)
        OnnxIdentity = loadop('Identity')  # pylint: disable=W0621
        if self.values is None:
            if (hasattr(self.unique, 'to_onnx') or
                    hasattr(self.unique, 'add_to')):
                if len(value) > 1:
                    # Several names for a single operator: one Identity
                    # node per requested name, then switch to `values`.
                    self.values = tuple(
                        OnnxIdentity(
                            self.unique[i], output_names=value[i:i + 1],
                            op_version=self.unique.op_version)
                        for i in range(0, len(value)))
                    self.unique = None
                    return
                self.unique.output_names = [Variable(v) for v in value]
                return
            raise NotImplementedError(  # pragma: no cover
                "Not implemented yet, value=%r, unique=%r values=%r." % (
                    value, self.unique, self.values))
        if self.values is not None and len(self.values) == len(value):
            # One name per held operator.
            for name, v in zip(value, self.values):
                v.output_names = [Variable(name)]
            return
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet, value=%r, unique=%r values=%r." % (
                value, self.unique, self.values))
1087 def _to_onnx_attributes(self, inputs=None, target_opset=None,
1088 optim=True, verbose=0, run_shape=True,
1089 fLOG=print, processed=None):
1090 """
1091 Calls `self.onx_op._to_onnx_attributes`.
1092 """
1093 if self.values is None:
1094 return self.unique._to_onnx_attributes(
1095 inputs=inputs, target_opset=target_opset, optim=optim,
1096 run_shape=run_shape, verbose=verbose, fLOG=fLOG,
1097 processed=processed)
1098 res = []
1099 for v in self.values:
1100 res.append(v._to_onnx_attributes(
1101 inputs=inputs, target_opset=target_opset, optim=optim,
1102 run_shape=run_shape, verbose=verbose, fLOG=fLOG,
1103 processed=processed))
1104 return res
1106 def to_onnx(self, inputs=None, outputs=None,
1107 other_outputs=None, target_opset=None,
1108 optim=True, verbose=0, run_shape=True,
1109 processed=None, check_model=True,
1110 return_builder=False, fLOG=None):
1111 """
1112 Converts this operator into an ONNX graph.
1113 It follows the same signature as :meth:`OnnxOperator.to_onnx
1114 <mlprodict.npy.xop.OnnxOperator.to_onnx>` and calls this
1115 method of the unique input object or the first one
1116 if there are several. In that case, other inputs in
1117 attribute `values` are moved into container
1118 `other_outputs`.
1120 (OnnxOperatorTuple)
1121 """
1122 logger.debug('op:%s-%d.to_onnx:%r:%r:%r',
1123 self.__class__.__name__, id(self),
1124 inputs, outputs, other_outputs)
1125 logger.indent()
1126 if self.values is None:
1127 res = self.unique.to_onnx(
1128 inputs=inputs, outputs=outputs, other_outputs=other_outputs,
1129 target_opset=target_opset, optim=optim, verbose=verbose,
1130 run_shape=run_shape, processed=processed, check_model=check_model,
1131 fLOG=fLOG, return_builder=return_builder)
1132 logger.dedent()
1133 return res
1134 new_other_outputs = self.values[1:]
1135 if other_outputs is not None:
1136 new_other_outputs.extend(other_outputs)
1137 res = self.values[0].to_onnx(
1138 inputs=inputs, outputs=outputs, other_outputs=new_other_outputs,
1139 target_opset=target_opset, optim=optim, verbose=verbose,
1140 run_shape=run_shape, processed=processed, check_model=check_model,
1141 fLOG=fLOG, return_builder=return_builder)
1142 logger.dedent()
1143 return res
1145 def find_named_inputs(self):
1146 """
1147 Returns all inputs to the graph.
1148 """
1149 if self.values is None:
1150 return self.unique.find_named_inputs()
1151 named = []
1152 for value in self.values:
1153 tmp = value.find_named_inputs()
1154 named.extend(tmp)
1155 return named
1157 def _set_control_op(self, op, subgraph_inputs=None):
1158 """
1159 Tells this operator is part of a subgraph.
1160 """
1161 logger.debug('op:%s-%d._set_control_op:%r',
1162 self.__class__.__name__, id(self), op)
1163 logger.indent()
1164 if self.values is None:
1165 raise NotImplementedError( # pragma: no cover
1166 "Not implemented yet.")
1167 for value in self.values:
1168 value._set_control_op(op, subgraph_inputs)
1169 logger.dedent()
1172class OnnxOperator(OnnxOperatorBase):
1173 """
1174 Ancestor to every *ONNX* operator exposed in
1175 :mod:`mlprodict.npy.xops` and :mod:`mlprodict.npy.xops_ml`.
1177 :param inputs: list of inputs expected by the operator
1178 :param op_version: to select a specific version of the operator
1179 :param output_names: used defined names for the outputs
1180 :param domain: to overwrite the default domain
1181 :param global_context: operator *If* executes one subgraph
1182 whose nodes may use one existing output in the current
1183 context. If not used in the main graph, these operators
1184 are not linked to the output and cannot be retrieved.
1185 *global_context* is a dictionary mapped the subgraph input
1186 names to these operators.
1187 :param kwargs: additional parameters of the operator
1189 .. versionadd:: 0.9
1190 """
1191 @classmethod
1192 def __class_getitem__(cls, opset):
1193 """
1194 Enables expression `cls[opset]`. It returns the appropriate class
1195 `cls_opset`. Parameter *op_version* should be specified.
1196 """
1197 if not isinstance(opset, int):
1198 raise ValueError(
1199 f"opset must an integer not {type(opset)!r}.")
1200 best = None
1201 for _, v in cls.past_version.items():
1202 if v.since_version == opset:
1203 return lambda *args, **kwargs: v(
1204 *args, op_version=opset, **kwargs)
1205 if v.since_version <= opset and (
1206 best is None or best.since_version < v.since_version):
1207 best = v
1208 if best is None:
1209 raise ValueError(
1210 "Unable to find a version of operator %r and opset %r." % (
1211 cls.__name__, opset))
1212 return lambda *args, **kwargs: best(
1213 *args, op_version=opset, **kwargs)
    def __init__(self, *inputs, op_version=None, output_names=None,
                 domain=None, global_context=None, **kwargs):
        """
        Normalizes inputs, output names, opset version and attributes.
        See the class docstring for the meaning of every parameter.
        """
        OnnxOperatorBase.__init__(self)
        logger.debug("op:%s-%d(%d in, op_version=%r, output_names=%r)",
                     self.__class__.__name__, id(self),
                     len(inputs), op_version,
                     output_names)
        if (output_names is None and
                self.__class__.__name__.startswith("OnnxScan")):
            raise NotImplementedError(  # pragma: no cover
                "The class cannot infer the number of variables "
                "for node '{}' yet. output_names must be specified"
                ".".format(self.__class__.__name__))
        # Normalize output_names into a list of Variable.
        if isinstance(output_names, (str, Variable)):
            output_names = [output_names]
            if isinstance(output_names[0], str):
                output_names[0] = Variable(output_names[0])
        elif isinstance(output_names, (list, OnnxOperator._InputContainer)):
            if len(output_names) == 0:
                raise ValueError(  # pragma: no cover
                    "output_names cannot be empty (operator %r)."
                    "" % self.__class__.__name__)
            # copy so that the caller's list is not mutated
            output_names = output_names.copy()
            for i in range(len(output_names)):  # pylint: disable=C0200
                if isinstance(output_names[i], str):
                    output_names[i] = Variable(output_names[i])
        elif output_names is not None:
            raise TypeError(  # pragma: no cover
                f"output_names must be a string or a list not {type(output_names)!r}.")

        # Resolve the opset version to use for this node.
        if op_version is None:
            if domain == '':
                self.op_version = max_supported_opset()
            else:
                self.op_version = None
        else:
            self.op_version = op_version
        self.since_version = self.__class__.since_version

        if (self.op_version is not None and
                self.op_version < self.since_version):
            # The requested opset is older than this class: look for an
            # older schema matching it.
            schema = self.find_schema(self.op_version)
            self.since_version = schema.since_version
            self.expected_inputs = schema.expected_inputs.copy()
            self.expected_outputs = schema.expected_outputs.copy()
            self.input_range = schema.input_range
            self.output_range = schema.output_range
        else:
            self.expected_inputs = (
                None if self.__class__.expected_inputs is None
                else self.__class__.expected_inputs.copy())
            self.expected_outputs = (
                None if self.__class__.expected_outputs is None
                else self.__class__.expected_outputs.copy())
            self.input_range = self.__class__.input_range
            self.output_range = self.__class__.output_range
            if self.__class__.__name__ not in {
                    'OnnxScan', 'OnnxLoop', 'OnnxIf'}:
                # The minimum opset depends on embedded graph
                # by default, it takes the given op_version but the
                # optimal value could be lower.
                self.op_version = self.since_version
            if self.op_version is None:
                self.op_version = self.since_version

        if (self.op_version is not None and
                self.op_version < self.since_version):
            raise RuntimeError(  # pragma: no cover
                "Operator '{}': requested version {} < "
                "{} schema version.".format(
                    self.__class__.__name__,
                    self.op_version, self.since_version))

        self.state = None
        self.domain = domain
        self.kwargs = kwargs
        # greatest requested output index (see update_max_item)
        self.max_item_ = None

        # check inputs
        self.inputs = []
        if len(inputs) > 0:
            for inp in inputs:
                if isinstance(inp, str):
                    self.inputs.append(Variable(inp))
                elif isinstance(inp, tuple):
                    # tuple (name, type) - sklearn-onnx style
                    if len(inp) != 2:
                        raise RuntimeError(  # pragma: no cover
                            f"Unexpected tuple {inp!r}.")
                    self.inputs.append(
                        Variable(inp[0], dtype=guess_numpy_type(inp[1]),
                                 shape=inp[1].shape))
                elif isinstance(inp, (OnnxOperatorBase, Variable)):
                    self.inputs.append(inp)
                elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)):
                    # constants are stored as is and converted later
                    self.inputs.append(inp)
                elif isinstance(inp, ValueInfoProto):
                    self.inputs.append(inp.type.tensor_type)
                else:
                    raise TypeError(  # pragma: no cover
                        "Unable to interpret the input name for type {} in "
                        "operator '{}' (value={}).".format(
                            type(inp), self.__class__.__name__, inp))

        if (self.inputs is not None and
                (len(self.inputs) < self.input_range[0] or
                 len(self.inputs) > self.input_range[1])):
            raise RuntimeError(  # pragma: no cover
                "Operator '{}' expects a number of inputs in [{}, {}] not {} "
                "(expected opset={}, class opset={})".format(
                    getattr(self, 'operator_name', '?'), *self.input_range,
                    len(self.inputs), op_version, self.op_version))
        # global context
        if global_context is None:
            self.global_context = None
        else:
            if not isinstance(global_context, dict):
                raise TypeError(  # pragma: no cover
                    "global_context must be a dictionary not %r."
                    "" % type(global_context))
            for k, v in global_context.items():
                if not isinstance(v, OnnxOperatorBase):
                    raise TypeError(  # pragma: no cover
                        f"Value {k!r} in must be an OnnxOperatorBase not {type(v)!r}.")
            self.global_context = global_context

        # check output
        self.output_names_ = output_names
        self.output_variables = None

        if self.output_names is not None:
            if len(self.output_names) == 0:
                raise ValueError(  # pragma: no cover
                    "output_names can be None but cannot be empty for "
                    "operator %r." % self)
            if self.output_variables is None:
                self.output_variables = [None for o in self.output_names]
            for i in range(len(self.output_names)):  # pylint: disable=C0200
                name = self.output_names[i]
                if isinstance(name, Variable):
                    self.output_variables[i] = name
                else:
                    raise TypeError(  # pragma: no cover
                        "output_names must be a list of strings "
                        "and element %r is %r (%r)" % (
                            i, type(name), name))
            if all(map(lambda x: x is None, self.output_variables)):
                self.output_variables = None

        # Extend expected_outputs so that every requested output name
        # has an entry (type unknown -> None).
        if (self.output_names is not None and (
                self.expected_outputs is None or
                len(self.output_names) > len(self.expected_outputs))):
            if self.expected_outputs is None:
                self.expected_outputs = []
            for i in range(len(self.expected_outputs),
                           len(self.output_names)):
                self.expected_outputs.append((self.output_names[i], None))

        # Extend expected_inputs likewise, inventing names for inputs
        # which are nodes rather than variables.
        if (self.expected_inputs is None or
                len(self.inputs) > len(self.expected_inputs)):
            if self.expected_inputs is None:
                self.expected_inputs = []
            for i in range(len(self.expected_inputs),
                           len(self.inputs)):
                inp = self.inputs[i]
                if isinstance(inp, str):
                    inp = (inp, None)
                elif hasattr(inp, 'add_to'):
                    # OnnxOperator
                    existing = set(_[0] for _ in self.expected_inputs)
                    i = 10
                    name = "input%d" % (10 + i)
                    while name in existing:
                        i += 1
                        name = "input%d" % (10 + i)
                    inp = (name, None)
                self.expected_inputs.append(inp)

        self._post_process_attributes()
        self._check()
        # inputs required by subgraphs but produced by outer graphs
        self.external_inputs = []
    def add_external_input(self, op):
        """
        Tells a subgraph this node comes from a graph calling this one.

        :param op: node or variable required by this subgraph but
            produced by an outer graph
        """
        logger.debug("op:%s.add_external_input:%r",
                     self.__class__.__name__, op)
        self.external_inputs.append(op)
    def do(self, body, subgraph_inputs=None):
        """
        Fills attribute *body*.

        :param body: onnx graph or @see cl OnnxOperator
        :param subgraph_inputs: additional parameter to convert
            the subgraph into ONNX
        :return: self
        """
        # subgraph_inputs only makes sense when the subgraph still has
        # to be converted; a GraphProto/ModelProto is already converted.
        if (isinstance(body, (onnx.GraphProto, onnx.ModelProto)) and
                subgraph_inputs is not None):
            raise RuntimeError(  # pragma: no cover
                "inputs cannot be defined if body is a "
                "GraphProto or a ModelProto.")
        return self._add_subgraph(
            'body', body, subgraph_inputs=subgraph_inputs)
1422 def then_do(self, branch):
1423 """
1424 Fills attribute *then_branch*.
1426 :param branch: onnx graph or @see cl OnnxOperator
1427 :return: self
1428 """
1429 if isinstance(branch, onnx.GraphProto) and len(branch.input) > 0:
1430 raise RuntimeError( # pragma: no cover
1431 "then_branch subgraph cannot have any input.")
1432 return self._add_subgraph('then_branch', branch)
1434 def else_do(self, branch):
1435 """
1436 Fills attribute *else_branch*.
1438 :param branch: onnx graph or @see cl OnnxOperator
1439 :return: self
1440 """
1441 if isinstance(branch, onnx.GraphProto) and len(branch.input) > 0:
1442 raise RuntimeError( # pragma: no cover
1443 "else_branch subgraph cannot have any input.")
1444 return self._add_subgraph('else_branch', branch)
    def _add_subgraph(self, attribute, branch, subgraph_inputs=None):
        """
        Fills attribute *attribute*.

        :param attribute: attribute name
        :param branch: onnx graph, @see cl OnnxOperator or a string
            naming an existing result
        :param subgraph_inputs: additional parameter to convert
            the subgraph into ONNX
        :return: self
        """
        if isinstance(branch, str):
            # branch is an input.
            OnnxIdentity = loadop('Identity')
            branch = OnnxIdentity(OnnxExisting(branch),
                                  op_version=self.op_version)
        logger.debug("op:%s:_add_subgraph:%s=type(branch)=%r",
                     self.__class__.__name__, attribute, type(branch))
        if isinstance(branch, onnx.ModelProto):
            # unwrap the model and recurse with its graph
            return self._add_subgraph(attribute, branch.graph)
        if isinstance(branch, onnx.GraphProto):
            self.kwargs[attribute] = branch
            return self
        if isinstance(branch, (OnnxOperator, OnnxOperatorTuple)):
            self.kwargs[attribute] = branch
            # tell the subgraph which node controls it
            branch._set_control_op(self, subgraph_inputs=subgraph_inputs)
            return self
        raise TypeError(  # pragma: no cover
            "Unexpected type %r for a subgraph, attribute %r "
            "and class %r." % (
                type(branch), attribute, self.__class__.__name__))
    def _set_control_op(self, op, subgraph_inputs=None):
        """
        Sets *control_op* for every instance of @see cl OnnxExisting node.

        :param op: operator calling the subgraph.
        :param subgraph_inputs: additional parameters to convert
            into ONNX
        """
        if subgraph_inputs is not None:
            self.subgraph_inputs = subgraph_inputs

        # Propagate first through the inputs, then through subgraphs
        # stored as attributes in kwargs.
        for i, inp in enumerate(self.inputs):
            if isinstance(inp, OnnxOperatorBase):
                logger.debug("op:%s-%d:_set_control_op:propagate-into-input:%d:p:%d",
                             self.__class__.__name__, id(self), i, id(op))
                logger.indent()
                inp._set_control_op(op)
                logger.dedent()
        if self.kwargs is None:
            return
        for k, v in self.kwargs.items():
            if isinstance(v, OnnxOperatorBase):
                logger.debug("op:%s-%d:_set_control_op:propagate-into-attribute:%s:p:%d",
                             self.__class__.__name__, id(self), k, id(op))
                logger.indent()
                v._set_control_op(op)
                logger.dedent()
    @property
    def output_names(self):
        "Returns `self.output_names_` (list of @see cl Variable or None)."
        return self.output_names_
1510 @output_names.setter
1511 def output_names(self, value):
1512 logger.debug("op:%s:output_names:set(%r)",
1513 self.__class__.__name__, value)
1514 if not isinstance(value, (list, OnnxOperator._InputContainer)):
1515 raise TypeError( # pragma: no cover
1516 f"Value must be a list not {type(value)!r}.")
1517 res = []
1518 for v in value:
1519 if isinstance(v, (Variable, ExistingVariable)):
1520 res.append(v)
1521 elif isinstance(v, str):
1522 res.append(Variable(v))
1523 else:
1524 raise TypeError( # pragma: no cover
1525 "Unexpected type %r for an output_names %r."
1526 "" % type(v))
1527 self.output_names_ = res
1529 def _check(self):
1530 input_types = (Variable, OnnxOperatorBase, numpy.ndarray,
1531 TensorProto)
1532 for o in self.inputs:
1533 if not isinstance(o, input_types):
1534 raise TypeError( # pragma: no cover
1535 f"Wrong type for inputs {self.inputs!r}.")
1536 if self.output_names is not None:
1537 for o in self.output_names:
1538 if not isinstance(o, Variable):
1539 raise TypeError( # pragma: no cover
1540 f"Wrong type for output_names {self.output_names!r}.")
    def _post_process_attributes(self):
        """
        Walks through attributes and replaces them by ONNX values.
        """
        # Looks into attributes if there is any tuple
        # (GraphProto, OnnxOperator). In that case, the function
        # replaces the tuple by the graph proto and keeps
        # in attributes graph_algebra the OnnxOperator
        # which is the source of it.
        updates = {}
        graph_algebra = {}
        for k, v in self.kwargs.items():
            if isinstance(v, tuple) and isinstance(v[0], GraphProto):
                updates[k] = v[0]
                graph_algebra[k] = v[1]

        if len(graph_algebra) > 0:
            self.kwargs.update(updates)
            self.graph_algebra = graph_algebra

        if self.__class__.__name__ == "OnnxConstantOfShape":
            # ConstantOfShape expects attribute 'value' to be a
            # TensorProto holding exactly one element.
            if "value" in self.kwargs:
                value = self.kwargs['value']
                if isinstance(value, TensorProto):
                    return
                if isinstance(value, numpy.ndarray):
                    if value.shape == (1, ):
                        val = value[0]
                    elif len(value.shape) == 0:
                        # scalar array (0-dim)
                        val = value
                    else:
                        raise RuntimeError(  # pragma: no cover
                            "Unexpected shape %r for value, it must be "
                            "an array of one element." % value.shape)
                    self.kwargs['value'] = from_array(
                        numpy.array([val], dtype=value.dtype))
                    return
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for value. It should be an array "
                    "of one element." % type(value))
            return

        if self.__class__.__name__ == "OnnxCast":
            # Cast expects attribute 'to' to be an integer
            # (a TensorProto element type); convert dtype-like values.
            if "to" in self.kwargs:
                value = self.kwargs['to']
                if not isinstance(value, int):
                    try:
                        to = numpy_type_prototype(value)
                    except ValueError as e:  # pragma: no cover
                        raise ValueError(
                            "Unable to convert argument to in operator cast, "
                            "type is %r, value is %r." % (type(value), value)) from e
                    self.kwargs['to'] = to
            return
1597 def update_max_item(self, index):
1598 """
1599 Some operators return a undefined number of outputs.
1600 The method is called when require one of them (with `__getitem__`)
1601 and keeps the greater requested index assuming the node does
1602 not output any result beyond that index.
1604 :param index: requested index
1605 """
1606 if self.max_item_ is None:
1607 self.max_item_ = index
1608 else:
1609 self.max_item_ = max(self.max_item_, index)
1610 if self.expected_outputs is None:
1611 self.expected_outputs = []
1612 while len(self.expected_outputs) <= self.max_item_:
1613 self.expected_outputs.append(
1614 (("NEWOUTPUT", len(self.expected_outputs)), None))
1616 def find_schema(self, op_version):
1617 """
1618 Checks if there is an existing schema for a specific version.
1620 :param op_version: requested version
1621 :return: schema
1622 """
1623 if not hasattr(self.__class__, 'past_version'):
1624 raise RuntimeError( # pragma: no cover
1625 "Missing attribute 'past_version', there is "
1626 "no other available schema.")
1627 found = None
1628 for v in self.past_version.values():
1629 if v.since_version > op_version:
1630 continue
1631 if found is None or v.since_version > found.since_version:
1632 found = v
1633 if found is None:
1634 raise RuntimeError( # pragma: no cover
1635 "Operator '{}': requested version {} < "
1636 "{} schema version (past_version {}).".format(
1637 self.__class__.__name__,
1638 op_version, self.since_version,
1639 [v.since_version for v in self.past_version.values()]))
1640 return found
1642 def __repr__(self):
1643 """
1644 usual
1645 """
1646 return "{}({} in) -> {}".format(
1647 self.__class__.__name__,
1648 len(self.inputs) if self.inputs is not None else 0,
1649 [str(o) for o in self.output_names]
1650 if self.output_names is not None else "?")
    def get_output_result(self, i=0):
        """
        Returns the output name at position *i*.

        :param i: output index, first output by default
        :return: @see cl NodeResultName
        """
        return NodeResultName(self, i)
    def __getitem__(self, index):
        """
        Returns an accessor to one of the output
        of this node.

        :param index: output index
        :return: @see cl OnnxOperatorItem
        """
        # remember the greatest requested index so __iter__ and add_to
        # know how many outputs this node must expose
        self.update_max_item(index)
        return OnnxOperatorItem(self, index, self.op_version)
1666 def __iter__(self):
1667 """
1668 Allows expressions such as ``a, b = OnnxTopK(...)``.
1669 """
1670 n = None
1671 if self.output_names is not None:
1672 n = len(self.output_names)
1673 else:
1674 rg = self.output_range
1675 if rg[0] == rg[1] and rg[0] > 0:
1676 n = rg[0]
1677 if n is None and self.max_item_ is not None:
1678 n = self.max_item_ + 1
1679 if n is None:
1680 raise RuntimeError( # pragma: no cover
1681 "Unable to guess the number of outputs of node type %r. "
1682 "Uses operator [] to select a specific output." %
1683 self.__class__.__name__)
1684 if self.max_item_ is not None:
1685 n = max(n, self.max_item_ + 1)
1686 for i in range(n):
1687 yield self[i]
    def add_to(self, builder):
        """
        Adds to graph builder.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        logger.debug("op:%s-%d.add_to(builder-%d):1",
                     self.__class__.__name__, id(self), id(builder))
        # resolve input names through the builder
        inputs = builder.get_input_names(self, self.inputs)
        # determine how many outputs this node must declare
        if self.output_names is not None:
            n_outputs = len(self.output_names)
        elif self.expected_outputs is not None:
            n_outputs = len(self.expected_outputs)
        else:
            n_outputs = self.output_range[0]
        outputs = [builder.get_unique_output_name(NodeResultName(self, i))
                   for i in range(n_outputs)]
        logger.debug("op:%s-%d.add_to(builder-%d):2:%s:%r:%r",
                     self.__class__.__name__, id(self), id(builder),
                     self.operator_name, inputs, outputs)
        logger.indent()
        builder.add_node(
            self.operator_name,
            builder.get_unique_name(
                '_' + self.operator_name.lower(), reserved=False),
            inputs, outputs, domain=self.domain, opset=self.op_version,
            **self.kwargs)
        logger.dedent()
        logger.debug("op:%s-%d.add_to(builder-%d):3",
                     self.__class__.__name__, id(self), id(builder))
1721 @staticmethod
1722 def _node_to_graph_preprocess_list(inputs):
1723 new_inputs = OrderedDict()
1724 for el in inputs:
1725 if isinstance(el, str):
1726 new_inputs[el] = Variable(el)
1727 elif isinstance(el, Variable):
1728 new_inputs[el.name] = el
1729 elif isinstance(el, tuple) and len(el) == 2:
1730 # sklearn-onnx
1731 new_inputs[el[0]] = Variable(
1732 el[0], guess_numpy_type(el[1]), el[1].shape)
1733 elif isinstance(el, ValueInfoProto):
1734 new_inputs[el.name] = el
1735 else:
1736 raise TypeError( # pragma: no cover
1737 f"Unable to handle input type {type(el)!r} ({el!r}).")
1738 return new_inputs
    @staticmethod
    def _node_to_graph_process_input(processed, inputs, set_inputs, node, inp,
                                     new_inputs, new_stack, inputs_dtype,
                                     as_function=False):
        """
        Processes one input *inp* of *node* during the graph walk:
        queues sub-operators into *new_stack* and registers detected
        graph inputs into *new_inputs*.

        :param processed: dictionary of already processed objects,
            indexed by id()
        :param inputs: requested inputs (None, dict or Variable)
        :param set_inputs: set of input names already registered
        :param node: node owning input *inp*
        :param inp: input to process
        :param new_inputs: _InputContainer receiving detected inputs
        :param new_stack: nodes to explore at the next iteration
        :param inputs_dtype: dtype applied when *inputs* gives no type
        :param as_function: True when building a FunctionProto
        """
        if not as_function and inputs is None and inputs_dtype is None:
            raise RuntimeError(  # pragma: no cover
                "Both inputs and inputs_dtype cannot be None at the same time "
                "for inp=%r." % (inp, ))

        if isinstance(inp, OnnxExisting):
            if inp.inputs[0].output_names is None:
                raise RuntimeError(  # pragma: no cover
                    "output_names cannot be None for OnnxExisting, "
                    "subop is %r." % (inp.inputs[0], ))
            # We need to check that this input was not already added.
            oinp = inp.inputs[0].output_names[0]
            if not new_inputs.has_input(oinp) and id(inp.inputs[0]) not in processed:
                raise RuntimeError(  # pragma: no cover
                    "This node id=%d (%r) was not added yet in the subgraph "
                    "but it must be from node %r." % (
                        id(inp.inputs[0]), inp.inputs[0], node))
        elif isinstance(inp, OnnxOperator):
            # regular node: explore it at the next iteration
            new_stack.append(inp)
            logger.debug("op:static:SG-op:processed[%d]:%s",
                         id(inp), inp.__class__.__name__)
            processed[id(inp)] = inp
        elif isinstance(inp, OnnxOperatorItem):
            # one output of a node: explore both the item and its source
            new_stack.append(inp)
            logger.debug("op:static:SG-it:processed[%d]:%s",
                         id(inp), inp.__class__.__name__)
            processed[id(inp)] = inp
            new_stack.append(inp.onx_op)
            logger.debug("op:static:SG-op:processed[%d]:%s",
                         id(inp.onx_op), inp.onx_op.__class__.__name__)
            processed[id(inp.onx_op)] = inp.onx_op
        elif isinstance(inp, OnnxOperatorTuple):
            # new_stack.append(inp)
            # new_stack.append(inp.onx_op)
            raise NotImplementedError(  # pragma: no cover
                "Unable to guess inputs when one input is OnnxOperatorTuple.")
        elif isinstance(inp, Variable):
            if inp.name in set_inputs:
                # already registered
                return
            if inp.name == '':
                # unnamed (optional) input
                return
            logger.debug("op:static:SG-var:processed[%d]:%s",
                         id(inp), inp.__class__.__name__)
            processed[id(inp)] = inp
            set_inputs.add(inp.name)
            if inputs is None and inputs_dtype is None:
                new_inputs.append(InputDetectedVariable(node, inp))
            elif isinstance(inputs, dict):
                if inp.name in inputs:
                    # the caller specified a type for this input
                    var = InputDetectedVariable(
                        node, inp.copy_merge(inputs[inp.name]))
                    new_inputs.append(var)
                else:
                    external_inputs = {
                        ei.name: ei for ei in node.external_inputs
                        if isinstance(ei, Variable)}
                    if inp.name not in external_inputs:
                        # This happens when an input is used for the first time
                        # inside a sub-sub-graph.
                        var = InputDetectedVariable(node, Variable(inp.name))
                    elif inp.name in set_inputs:
                        var = InputDetectedVariable(
                            node, inp.copy_merge(external_inputs[inp.name]))
                    else:
                        raise ValueError(  # pragma: no cover
                            f"Unable to find input {inp!r} in {inputs!r}, "
                            f"new_inputs={new_inputs!r}, "
                            f"type(node)={type(node)!r}, "
                            f"node.external_inputs={node.external_inputs!r}, "
                            f"node={node!r}.")
                    new_inputs.append(var)
            elif inputs_dtype is not None:
                new_inputs.append(
                    InputDetectedVariable(node, inp.copy_add(inputs_dtype)))
            elif isinstance(inputs, Variable):
                if inp.name == inputs.name:
                    new_inputs.append(
                        InputDetectedVariable(node, inp.copy_merge(inputs)))
                else:
                    new_inputs.append(InputDetectedVariable(node, inp))
            else:
                raise RuntimeError(  # pragma: no cover
                    f"Unable to handle inputs={inputs!r}.")
        elif isinstance(inp, numpy.ndarray):
            # constants are handled later, when the graph is built
            pass
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected input type {type(inp)!r} in node type {type(node)!r}.")
    @staticmethod
    def _node_to_graph_get_type(node, name=None, outputs=None,
                                outputs_dtype=None):
        """
        Guesses the type of one output of *node*.

        :param node: node to inspect
        :param name: output name (str or @see cl Variable) or None
        :param outputs: requested outputs
            (None, @see cl Variable, dict or dtype)
        :param outputs_dtype: fallback dtype
        :return: tuple (type or None, shape or None)
        """
        if outputs is None:
            return outputs_dtype, None
        if isinstance(outputs, Variable):
            if name is None:
                return (outputs.dtype or outputs_dtype, None)
            if isinstance(name, Variable):
                # requested type first, then the type attached to the
                # name, then the fallback dtype
                return (outputs.dtype or name.dtype or outputs_dtype,
                        None)
            raise RuntimeError(  # pragma: no cover
                f"Unable to handle outputs={outputs!r}.")
        if isinstance(outputs, dict):
            if name is None:
                return _infer_node_output(node, outputs)
            if isinstance(name, Variable):
                n = name.name
            else:
                n = name
            if n not in outputs:
                return None, None
            return outputs[n], None
        if isinstance(outputs, (list, OnnxOperator._InputContainer)):
            raise NotImplementedError(  # pragma: no cover
                f"Unexpected type for name={name!r}, outputs={outputs!r}.")
        if is_numpy_dtype(outputs):
            return outputs, None
        raise RuntimeError(  # pragma: no cover
            f"Unable to handle outputs={outputs!r}.")
1864 @staticmethod
1865 def _node_to_graph_reorder_by_name(new_inputs, inputs):
1866 memo = OrderedDict((n.name, n) for n in new_inputs)
1867 done = set()
1868 result = []
1869 for inp in inputs:
1870 if inp.name in memo:
1871 result.append(memo[inp.name])
1872 done.add(inp.name)
1873 for k, v in memo.items():
1874 if k in done:
1875 continue
1876 result.append(v)
1877 return result
1879 class _InputContainer:
1881 def __init__(self):
1882 self._c = []
1883 self._names = set()
1885 def has_input(self, inp):
1886 "Checks that input *inp* is part the list of names."
1887 if isinstance(inp, str):
1888 return inp in self._names
1889 if inp.name in self._names:
1890 return True
1891 return False
1893 def append(self, inp):
1894 "Append one element to the list."
1895 name = inp.var.name
1896 self._c.append(inp)
1897 self._names.add(name)
1899 def __len__(self):
1900 return len(self._c)
1902 def __repr__(self):
1903 return f"{'_InputContainer'}(\n {pprint.pformat(self._c)})"
1905 def __iter__(self):
1906 for inp in self._c:
1907 yield inp
1909 def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None,
1910 as_function=False, processed=None):
1911 """
1912 Builds a graph as a list of nodes to walk through in that order.
1913 """
1914 if processed is None:
1915 raise RuntimeError( # pragma: no cover
1916 "processed cannot be None.")
1917 node_outputs = [self]
1918 if other_outputs is not None:
1919 node_outputs += other_outputs
1921 if inputs is not None:
1922 logger.debug("op:%s-%d._node_to_graph:1:inputs=%r",
1923 self.__class__.__name__, id(self), inputs)
1924 if outputs is not None:
1925 logger.debug("op:%s-%d._node_to_graph:1:outputs=%r",
1926 self.__class__.__name__, id(self), outputs)
1928 # preprocess inputs, outputs
1929 _keep_inputs = None
1930 inputs_dtype = None
1931 if isinstance(inputs, (list, OnnxOperator._InputContainer)):
1932 _keep_inputs = inputs
1933 inputs_dict = self._node_to_graph_preprocess_list(inputs)
1934 elif isinstance(inputs, dict):
1935 inputs_dict = inputs
1936 elif isinstance(inputs, Variable):
1937 inputs = [inputs]
1938 inputs_dict = self._node_to_graph_preprocess_list(inputs)
1939 elif is_numpy_dtype(inputs):
1940 inputs_dtype = inputs
1941 inputs_dict = None
1942 else:
1943 raise TypeError( # pragma: no cover
1944 f"Unexpected type {type(inputs)!r} for inputs.")
1946 _keep_outputs = None
1947 outputs_dtype = None
1948 if isinstance(outputs, (list, OnnxOperator._InputContainer)):
1949 _keep_outputs = outputs
1950 outputs_dict = self._node_to_graph_preprocess_list(outputs)
1951 elif isinstance(outputs, dict):
1952 outputs_dict = outputs
1953 elif isinstance(outputs, Variable):
1954 outputs = [outputs]
1955 outputs_dict = self._node_to_graph_preprocess_list(outputs)
1956 elif is_numpy_dtype(outputs):
1957 outputs_dtype = outputs
1958 outputs_dict = None
1959 else:
1960 raise TypeError( # pragma: no cover
1961 f"Unexpected type {type(outputs)!r} for outputs.")
1963 if inputs is not None:
1964 logger.debug("op:%s-%d._node_to_graph:2:inputs=%r",
1965 self.__class__.__name__, id(self), inputs)
1966 if outputs is not None:
1967 logger.debug("op:%s-%d._node_to_graph:2:outputs=%r",
1968 self.__class__.__name__, id(self), outputs)
1969 if inputs_dict is not None:
1970 logger.debug("op:%s-%d._node_to_graph:2:inputs_dict=%r",
1971 self.__class__.__name__, id(self), inputs_dict)
1972 if outputs_dict is not None:
1973 logger.debug("op:%s-%d._node_to_graph:2:outputs_dict=%r",
1974 self.__class__.__name__, id(self), outputs_dict)
1975 if inputs_dtype is not None:
1976 logger.debug("op:%s-%d._node_to_graph:2:inputs_dtype=%r",
1977 self.__class__.__name__, id(self), inputs_dtype)
1978 if outputs_dtype is not None:
1979 logger.debug("op:%s-%d._node_to_graph:2:outputs_dtype=%r",
1980 self.__class__.__name__, id(self), outputs_dtype)
1982 # walk through graph
1983 stack = list(node_outputs)
1984 new_inputs = self._InputContainer()
1985 set_inputs = set()
1986 memo = []
1987 while len(stack) > 0:
1988 logger.debug("op:%s-%d._node_to_graph:loop:len(memo)=%d",
1989 self.__class__.__name__, id(self), len(memo))
1990 memo.extend(stack)
1991 new_stack = []
1992 for obj in stack:
1993 logger.debug("op:%s-%d._node_to_graph:-node=%r:external_inputs=%r",
1994 self.__class__.__name__, id(self),
1995 obj.__class__.__name__,
1996 getattr(obj, 'external_inputs', "-"))
1997 if isinstance(obj, OnnxExisting):
1998 pass
1999 elif isinstance(obj, OnnxOperatorItem):
2000 # nothing to do, OnnxOperatorItem is created
2001 # by OnnxOperator.__getitem__.
2002 pass
2003 elif isinstance(obj, (OnnxOperator, OnnxOperatorTuple)):
2004 if len(obj.external_inputs) > 0:
2005 # external_inputs are inputs required by a subgraph
2006 # but not necessarily used in the main graph.
2007 # They need to be processed first.
2008 for inp in obj.external_inputs:
2009 self._node_to_graph_process_input(
2010 processed, inputs_dict, set_inputs, obj, inp, new_inputs,
2011 new_stack, inputs_dtype, as_function=as_function)
2012 for inp in obj.inputs:
2013 self._node_to_graph_process_input(
2014 processed, inputs_dict, set_inputs, obj, inp, new_inputs,
2015 new_stack, inputs_dtype, as_function=as_function)
2016 else:
2017 raise TypeError( # pragma: no cover
2018 f"Unexpected type {type(obj)!r}.")
2019 stack = new_stack
2021 # reorder new_inputs to follow inputs initial order
2022 if _keep_inputs is not None:
2023 new_inputs = self._node_to_graph_reorder_by_name(
2024 new_inputs, inputs)
2026 logger.debug("op:%s-%d._node_to_graph:new_inputs=%r",
2027 self.__class__.__name__, id(self), new_inputs)
2029 # eliminate duplicates
2030 done = set()
2031 nodes = []
2032 for node in reversed(memo):
2033 if id(node) in done:
2034 continue
2035 done.add(id(node))
2036 nodes.append(node)
2038 # outputs
2039 set_names = set()
2040 new_outputs = []
2041 run_shape = False
2042 for node in node_outputs:
2043 if node.output_names is None:
2044 n = self.output_range[0]
2045 for i in range(n):
2046 to, shape = self._node_to_graph_get_type(
2047 node, outputs=outputs_dict,
2048 outputs_dtype=outputs_dtype)
2049 if to is None:
2050 run_shape = True
2051 res = f'xop_{id(node)}_{i}'
2052 var = Variable(res, added_dtype=to, shape=shape)
2053 if var.name in set_names:
2054 raise RuntimeError( # pragma: no cover
2055 f"Duplicated output name var={var!r} in "
2056 f"{set_names!r}.")
2057 set_names.add(var.name)
2058 new_outputs.append(OutputDetectedVariable(node, var, i))
2059 else:
2060 for i, o in enumerate(node.output_names):
2061 if isinstance(o, str):
2062 raise TypeError( # pragma: no cover
2063 "Output %d - %r (%r) not allowed in node %r." % (
2064 i, o, node.output_names, node))
2065 to, shape = self._node_to_graph_get_type(
2066 node, o, outputs=outputs_dict,
2067 outputs_dtype=outputs_dtype)
2068 if to is None:
2069 run_shape = True
2070 res = (o, to)
2071 var = o.copy_merge(to, shape=shape)
2072 if var.name in set_names:
2073 raise RuntimeError( # pragma: no cover
2074 f"Duplicated output name o={o!r} var={var!r}.")
2075 set_names.add(var.name)
2076 new_outputs.append(OutputDetectedVariable(node, var, i))
2077 if len(new_outputs) == 0:
2078 raise RuntimeError( # pragma: no cover
2079 f"No detected outputs inputs={inputs_dict!r} outputs={outputs_dict!r}.")
2081 # reorder new_outputs to follow outputs initial order
2082 if _keep_outputs is not None:
2083 new_outputs = self._node_to_graph_reorder_by_name(
2084 new_outputs, outputs)
2086 logger.debug("op:%s-%d._node_to_graph:new_outputs=%r",
2087 self.__class__.__name__, id(self), new_outputs)
2089 return nodes, new_inputs, new_outputs, run_shape
    def to_onnx(self, inputs=None, outputs=None,
                other_outputs=None, target_opset=None,
                optim=True, verbose=0, run_shape=True,
                function_name=None, function_domain=None,
                fLOG=print, processed=None, check_model=True,
                return_builder=False):
        """
        Converts this operator into an ONNX graph.

        :param inputs: information about type, it should not be None
        :param outputs: information about types, if None, the function
            will use shape inference to guess the final output type
            and shape
        :param other_outputs: additional nodes to consider
            as graph outputs but not outputs of this particular
            node
        :param target_opset: dictionary with target opset per domain,
            None for the default one
        :param optim: optimize the model with function
            @see fn onnx_optimisations
        :param run_shape: in case output shapes are not specified,
            the function runs function :epkg:`infer_shapes`
            to guess them, False would disable that
            default behaviour
        :param verbose: prints information
        :param function_name: if not None, returns a :epkg:`FunctionProto`
        :param function_domain: in case of a function, declares the function
            as part of this domain
        :param fLOG: logging function
        :param processed: keeps track of the processed nodes
        :param check_model: checks the output model
        :param return_builder: if True, returns the instance of @see cl GraphBuilder
            used to build the onnx graph.
        :return: ONNX structure

        *inputs* and *outputs* parameters work the same way.
        Here are some possible values:

        - `inputs=numpy.float32`: all inputs are dense tensors of
          unknown shapes sharing the same element type
        - `inputs={'X': numpy.float32, 'Y': numpy.int64}`:
          input `X` is a dense tensor of float32,
          input `Y` is a dense tensor of int64,
        - `{'X': numpy.array(...)}`: input `X` is a dense
          tensor with a precise shape
        - `inputs=[Variable('X', numpy.float32, [1, 2])]`:
          input `X` is a dense tensor of float32 with shape `[1, 2]`
        - `inputs=[Variable('X', numpy.float32, [None, 2])]`:
          input `X` is a dense tensor of float32 with a 2D tensor
          with an unknown dimension (first one)
        - see @see cl Variable

        (OnnxOperator)
        """
        # opsets
        # NOTE(review): the 'as_function=%r' placeholder actually logs
        # *function_name*, not a boolean.
        logger.debug(
            "op:%s-%d.to_onnx(%r, %r, other_outputs=%r, target_opset=%r, as_function=%r)",
            self.__class__.__name__, id(self), inputs, outputs,
            other_outputs, target_opset, function_name)
        # Normalize target_opset to a single integer for this node's domain.
        if isinstance(target_opset, dict):
            dom = self.domain or ''
            target_opset = target_opset.get(dom, None)
        elif isinstance(target_opset, int):
            if self.domain not in ('', None):
                # The target_opset is for the domain '' we ignore it.
                target_opset = None
        elif target_opset is not None:
            raise TypeError(  # pragma: no cover
                "target_opset must be a dictionary {domain: "
                "target_opset} not %r for operator %r." % (
                    target_opset, self.__class__.__name__))

        if self.domain in ('', None) and target_opset == 1:
            raise RuntimeError(  # pragma: no cover
                "target_opset cannot be 1.")
        if (self.op_version is not None and target_opset is not None and
                self.op_version > target_opset):
            raise RuntimeError(  # pragma: no cover
                "target_opset={} is lower than the version={} requested "
                "for this node '{}'.".format(
                    target_opset, self.op_version, self.__class__.__name__))

        # get the graph
        # *processed* is shared across recursive calls to avoid converting
        # the same node twice.
        if processed is None:
            processed = {}
        logger.debug("op:%s-%d:SG-self:processed[%d]:SELF",
                     self.__class__.__name__, id(self), id(self))
        processed[id(self)] = self

        logger.indent()
        nodes, graph_inputs, graph_outputs, run_shape2 = self._node_to_graph(
            other_outputs, inputs, outputs, as_function=function_name is not None,
            processed=processed)
        if hasattr(self, 'subgraph_inputs'):
            # Subgraph inputs are prepended so they come first in the graph.
            if any(map(lambda o: not isinstance(o, Variable),
                       self.subgraph_inputs)):
                raise TypeError(  # pragma: no cover
                    f"Unexpected type, all type should be Variable in "
                    f"{self.subgraph_inputs!r}.")
            graph_inputs = [
                InputDetectedVariable(None, v) for v in self.subgraph_inputs
            ] + graph_inputs
        logger.dedent()

        logger.debug("op:%s.to_onnx:graph_inputs=%r",
                     self.__class__.__name__, graph_inputs)
        logger.debug("op:%s.to_onnx:graph_outputs=%r",
                     self.__class__.__name__, graph_outputs)

        if len(nodes) == 0:
            raise RuntimeError(  # pragma: no cover
                "Node list is empty.")
        if verbose > 1:
            for i, n in enumerate(nodes):  # pragma: no cover
                fLOG("nodes[%d]=%r" % (i, n))
            for i, n in enumerate(graph_inputs):  # pragma: no cover
                fLOG("graph_inputs[%d]=%r" % (i, n))

        # creates a _GraphBuilder
        builder = _GraphBuilder()

        # reserve input names starting by the first one
        for node in reversed(nodes):
            for var in node.inputs:
                if isinstance(var, Variable):
                    logger.debug("op:%s.to_onnx:_add_name(%r)",
                                 self.__class__.__name__, var.name)
                    builder._add_name(var.name)

        # reserve output names starting by the last ones
        for node in reversed(nodes):
            builder.reserve_names(node, node.output_names)

        # adds every node to the builder
        # (this first loop only logs the node list)
        for i, node in enumerate(nodes):
            logger.debug("op:%s-%d.to_onnx:node:%d/%d:%r",
                         self.__class__.__name__, id(self), i, len(nodes), node)

        for node in nodes:
            if isinstance(node, OnnxExisting):
                # OnnxExisting refers to a result produced elsewhere,
                # nothing to add.
                continue
            logger.indent()
            # Convert subgraph attributes first, collecting hidden inputs,
            # inputs used by subgraphs but not declared by this node.
            hidden = node._to_onnx_attributes(
                inputs=graph_inputs, target_opset=target_opset,
                optim=optim, verbose=verbose, run_shape=run_shape, fLOG=fLOG,
                processed=processed)
            logger.dedent()

            if len(hidden) > 0:
                logger.debug(
                    "op:%s-%d.to_onnx:to_onnx:%s-%d:hidden:%r",
                    self.__class__.__name__, id(self),
                    node.__class__.__name__, id(node), hidden)
                builder.get_input_names(node, hidden)
            node.add_to(builder)

        logger.debug(
            "op:%s-%d.to_onnx:to_onnx:a", self.__class__.__name__, id(self))
        logger.indent()

        # fix missing inputs
        # Inputs given as a dictionary may declare names the graph walk
        # did not discover; they are appended here.
        if isinstance(inputs, dict):
            known = set()
            for gi in graph_inputs:
                known.add(gi.var.name)
            for name, dtype in inputs.items():
                if name not in known:
                    logger.debug(
                        "%s-%d.to_onnx:+:%s:%r",
                        self.__class__.__name__, id(self), name, dtype)
                    var = InputDetectedVariable(
                        None, Variable(name, dtype=dtype))
                    graph_inputs.append(var)
                    builder.input_names[name] = var
        for v in graph_inputs:
            if v.var.name not in builder.input_names:
                builder.input_names[v.var.name] = v

        onx = builder.to_onnx(
            inputs=graph_inputs, outputs=graph_outputs,
            target_opset=target_opset, verbose=verbose,
            optim=optim, run_shape=run_shape and run_shape2,
            function_name=function_name, function_domain=function_domain,
            check_model=check_model)
        logger.dedent()

        logger.debug(
            "op:%s-%d.to_onnx:to_onnx:b:%s:%d-nodes",
            self.__class__.__name__, id(self), type(onx).__name__,
            len(onx.graph.node) if hasattr(onx, 'graph') else onx.node)
        if return_builder:
            return onx, builder
        return onx
2285 def _to_onnx_attributes(self, inputs=None, target_opset=None,
2286 optim=True, verbose=0, run_shape=True,
2287 fLOG=print, processed=None):
2288 """
2289 Converts attributes into ONNX.
2290 Returns the hidden inputs.
2291 """
2292 if processed is None:
2293 raise RuntimeError( # pragma: no cover
2294 "processed cannot be None.")
2295 converts = []
2296 for k, v in self.kwargs.items():
2297 if isinstance(v, OnnxOperatorBase):
2298 converts.append(k)
2299 hidden_inputs = []
2300 for name in converts:
2301 if verbose > 0:
2302 fLOG( # pragma: no cover
2303 '[OnnxOperator._to_onnx_attributes] process %r of type %r.'
2304 '' % (name, type(self.kwargs[name])))
2305 model, hidden = self._to_onnx_attribute(
2306 name, self.kwargs[name], inputs=inputs, target_opset=target_opset,
2307 optim=optim, verbose=verbose, run_shape=run_shape, fLOG=fLOG,
2308 processed=processed)
2310 hidden_inputs.extend(hidden)
2311 if len(model.graph.node) == 0:
2312 _, hidden = self._to_onnx_attribute(
2313 name, self.kwargs[name], inputs=inputs, target_opset=target_opset,
2314 optim=False, verbose=verbose, run_shape=run_shape, fLOG=fLOG,
2315 processed=processed)
2316 raise RuntimeError( # pragma: no cover
2317 "Conversion to graph of parameter %r from\nnode=%r "
2318 "and\ninputs=%r\nis empty:\n%s\nHIDDEN\n%r" % (
2319 name, self.kwargs[name], self.kwargs[name].inputs,
2320 model, hidden))
2321 if name in {'else_branch', 'then_branck'}:
2322 if len(model.graph.input) > 0:
2323 # else_branch, then_branch must not have any input.
2324 del model.graph.input[:]
2325 self.kwargs[name] = model.graph
2326 return hidden_inputs
    def _to_onnx_attribute(self, att_name, oxop, inputs=None, target_opset=None,
                           optim=True, verbose=0, run_shape=True,
                           fLOG=print, processed=None):
        """
        Converts one subgraph into ONNX.
        Returns the ONNX graph and the hidden inputs.

        :param att_name: attribute name, only used for logging
        :param oxop: the @see cl OnnxOperatorBase to convert
        :param inputs: detected inputs of the enclosing graph, the
            subgraph only receives the ones it actually uses by name
        :param target_opset: opset for the main domain
        :param optim: only used in the error message below, the
            conversion itself always runs with ``optim=False``
        :param verbose: verbosity, uses *fLOG* when positive
        :param run_shape: run shape inference on the subgraph
        :param fLOG: logging function
        :param processed: keeps track of the processed nodes, cannot be None
        :return: tuple `(model with inferred shapes, hidden inputs)`
        """
        if processed is None:
            raise RuntimeError(  # pragma: no cover
                "processed cannot be None.")
        # NOTE(review): local name *vars* shadows the builtin ``vars``.
        if inputs is None:
            vars = None
        else:
            # Restrict the enclosing graph inputs to the named inputs the
            # subgraph really consumes, without duplicates.
            named_inputs = set(oxop.find_named_inputs())
            vars = []
            added = set()
            for inp in inputs:
                if inp.var.name in named_inputs and inp.var.name not in added:
                    added.add(inp.var.name)
                    vars.append(Variable(
                        inp.var.name, inp.var.dtype or inp.var.added_dtype))
        if verbose > 0:
            fLOG(  # pragma: no cover
                f'[OnnxOperator._to_onnx_attribute] inputs={vars!r}')
        logger.debug("op:%s._to_onnx_attribute:%s:inputs(%r)",
                     self.__class__.__name__, att_name, vars)
        logger.indent()
        onx, att_builder = oxop.to_onnx(
            inputs=vars, target_opset=target_opset, run_shape=run_shape,
            verbose=verbose, fLOG=fLOG, processed=processed, optim=False,
            check_model=False, return_builder=True)
        logger.dedent()
        # Inputs the subgraph needs but does not declare itself.
        hidden_inputs = att_builder.hidden_input
        if len(hidden_inputs) > 0:
            if verbose > 0:
                fLOG(  # pragma: no cover
                    f'[OnnxOperator._to_onnx_attribute] inputs={vars!r}')
            logger.debug("op:%s._to_onnx_attribute:inputs:hidden:%r",
                         self.__class__.__name__, att_builder.hidden_input)
        if len(onx.graph.node) == 0:
            raise RuntimeError(  # pragma: no cover
                "Empty graph (class=%r, optim=%r) from\nnode=%r "
                "and\ninputs=%r\nis empty:\n%s" % (
                    type(oxop), optim, oxop, vars, onx))
        shaped_onx = infer_shapes(onx)
        return shaped_onx, hidden_inputs
2375 def predecessors(self):
2376 """
2377 Returns the list of predecessors.
2379 :return: list of @see cl OnnxOperator
2380 """
2381 stack = [self]
2382 last = 0
2383 while True:
2384 end = len(stack)
2385 if end == last:
2386 break
2387 for i in range(last, end):
2388 node = stack[i]
2389 for inp in node.inputs:
2390 if isinstance(inp, OnnxOperatorBase):
2391 stack.append(inp)
2392 last = end
2393 return stack
2395 def __call__(self, *args, function_name=None, function_domain=None,
2396 **kwargs):
2397 """
2398 Creates an instance of class @see cl OnnxOperatorFunction.
2399 Equivalent to `OnnxOperatorFunction(proto, *args, **kwargs)`.
2401 :param args: see @see cl OnnxOperatorFunction
2402 :param function_name: name to be given to the function
2403 :param function_domain: function domain, if None,
2404 it is given a default value
2405 :param kwargs: see @see cl OnnxOperatorFunction
2406 :return: instance of type @see cl OnnxOperatorFunction
2407 """
2408 if function_name is None:
2409 def clean(name):
2410 if name.startswith("Onnx"):
2411 name = name[4:]
2412 return name
2414 pred = self.predecessors()
2415 cls = [clean(p.__class__.__name__) for p in pred]
2416 function_name = "".join(cls)
2417 onx = self.to_onnx(function_name=function_name,
2418 function_domain=function_domain)
2419 return OnnxOperatorFunction(onx, *args, **kwargs)
2421 def find_named_inputs(self):
2422 """
2423 Retrieves all named inputs in this graph.
2424 """
2425 unique = set()
2426 found = []
2427 for inp in self.inputs:
2428 if isinstance(inp, str):
2429 if inp not in unique:
2430 found.append(inp)
2431 unique.add(inp)
2432 elif isinstance(inp, Variable):
2433 if inp.name not in unique:
2434 found.append(inp.name)
2435 unique.add(inp.name)
2436 elif isinstance(inp, OnnxOperatorBase):
2437 f = inp.find_named_inputs()
2438 for n in f:
2439 if n not in unique:
2440 found.append(n)
2441 unique.add(n)
2442 elif isinstance(inp, numpy.ndarray):
2443 pass
2444 else:
2445 raise RuntimeError( # pragma: no cover
2446 f"Unexpected input type {type(inp)!r}.")
2447 return found
2449 def to_onnx_this(self, evaluated_inputs):
2450 """
2451 Returns a simple ONNX graph corresponding to this node.
2453 :param evaluated_inputs: inputs as a list
2454 :return: ONNX graph
2456 (OnnxOperator)
2457 """
2458 logger.debug('op:%s-%d.to_onnx_this:%r',
2459 self.__class__.__name__, id(self),
2460 evaluated_inputs)
2461 inputs_names = ['I%d' % i for i in range(len(evaluated_inputs))]
2462 if self.output_names is None:
2463 if self.expected_outputs is None:
2464 raise NotImplementedError( # pragma: no cover
2465 "expected_outputs and output_names are not defined.")
2466 output_names = [o[0] for o in self.expected_outputs]
2467 else:
2468 output_names = [o.name for o in self.output_names]
2469 node = make_node(self.op_type, inputs_names, output_names,
2470 domain=self.domain, name="f", **self.kwargs)
2471 onx_inputs = [Variable(name, a.dtype).make_value_info()
2472 for name, a in zip(inputs_names, evaluated_inputs)]
2473 onx_outputs = [make_value_info(name, make_tensor_type_proto(0, []))
2474 for name in output_names]
2475 graph = make_graph([node], 'f', onx_inputs, onx_outputs)
2476 model = make_model(
2477 graph, opset_imports=[make_operatorsetid(
2478 self.domain or '', self.since_version)])
2479 return model
2481 def run(self, *inputs, verbose=0, fLOG=None, clear_cache=False, runtime=None):
2482 """
2483 Other name for
2484 `OnnxInference.f <mlprodict.onnxrt.onnx_inference.OnnxInference.f>`_.
2485 """
2486 return self.f(*inputs, verbose=verbose, fLOG=fLOG,
2487 clear_cache=clear_cache, runtime=runtime)
    def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
          clear_cache=False, runtime=None):
        """
        Computes the predictions for this node.
        Similar to an eager evaluation.

        :param inputs: inputs as dictionary or a list of inputs
            (see below)
        :param verbose: display information while predicting
        :param fLOG: logging function if *verbose > 0*
        :param clear_cache: onnx graph is created once unless
            this parameter is True
        :param runtime: runtime to use for the evaluation,
            see @see cl OnnxInference
        :return: outputs as a dictionary if the input were given as a
            dictionary or a single result or a tuple otherwise

        The inputs refer to the inputs of the graph.
        The method walks through all inputs and finds inputs defined as
        string. It replaces them by the value found in the dictionary.
        If the inputs are specified in a list, the function retrieves the
        list of inputs defined as a string and assigns them a value.
        Logging function can be used to get more insight about it.
        During the evaluation every node is independently converted
        into ONNX. The ONNX graph is cached in the class itself.
        """
        # NOTE(review): *fLOG* defaults to None; calling with verbose > 0
        # assumes the caller provided a logging function.
        # input evaluation
        if len(inputs) == 1 and isinstance(inputs[0], dict):
            # inputs given by name, outputs returned as a dictionary
            dict_inputs = inputs[0]
            as_dict = True
        elif not isinstance(inputs, (tuple, list, OnnxOperator._InputContainer)):
            raise TypeError(  # pragma: no cover
                f"inputs must be a list not {type(inputs)!r}.")
        elif len(inputs) > 0 and isinstance(inputs[0], OnnxOperator):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for inputs[0]: {type(inputs[0])!r}.")
        else:
            # positional inputs: map them onto the graph's named inputs
            as_dict = False
            if verbose > 0:
                fLOG(  # pragma: no cover
                    "[OnnxOperator.f] retrieves named inputs")
            # the named inputs are cached on the instance
            if hasattr(self, "feval_named_inputs_"):
                named_inputs = self.feval_named_inputs_  # pylint: disable=E0203
            else:
                named_inputs = self.find_named_inputs()
                self.feval_named_inputs_ = named_inputs
            if len(named_inputs) != len(inputs):
                raise RuntimeError(
                    "Mismatch between the number of found inputs (%d) and "
                    "the number of given inputs (%d) (found %r)."
                    "" % (
                        len(named_inputs), len(inputs), named_inputs))
            dict_inputs = {
                name: value for name, value in zip(named_inputs, inputs)}
            if verbose > 0:
                fLOG(  # pragma: no cover
                    f"[OnnxOperator.f] found inputs: {named_inputs!r}")

        # conversion
        # Every input of this node is resolved to a concrete value,
        # operator inputs are evaluated recursively.
        evaluated_inputs = []
        for i, inp in enumerate(self.inputs):
            if isinstance(inp, str):
                evaluated_inputs.append(dict_inputs[inp])
            elif isinstance(inp, Variable):
                evaluated_inputs.append(dict_inputs[inp.name])
            elif isinstance(inp, OnnxOperatorBase):
                if verbose > 0:
                    fLOG(  # pragma: no cover
                        "[OnnxOperator.f] evaluate input %d (op_type=%r)" % (
                            i, self.__class__.op_type))
                out = inp.f(dict_inputs, verbose=verbose, fLOG=fLOG)
                if isinstance(out, dict):
                    if len(out) == 1:
                        evaluated_inputs.append(out.popitem()[1])
                    else:
                        raise NotImplementedError(  # pragma: no cover
                            "Not yet implemented in case when there are multiple "
                            "outputs (%r)." % list(out))
                elif isinstance(out, (list, OnnxOperator._InputContainer)):
                    evaluated_inputs.extend(out)
                else:
                    evaluated_inputs.append(out)
            elif isinstance(inp, numpy.ndarray):
                evaluated_inputs.append(inp)
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected type %r for input %d." % (type(inp), i))

        # conversion to ONNX
        # cached per (dtype, shape) signature of the evaluated inputs
        if not hasattr(self, 'feval_onnx_'):
            self.feval_onnx_ = {}
        key = tuple((m.dtype, m.shape) for m in evaluated_inputs)
        if key not in self.feval_onnx_ or clear_cache:
            if verbose > 0:
                fLOG(
                    f"[OnnxOperator.f] creating node {self.op_type!r}, inputs={key!r}")
            from ..onnxrt import OnnxInference
            model = self.to_onnx_this(evaluated_inputs)
            oinf = OnnxInference(model, runtime=runtime)
            self.feval_onnx_[key] = oinf
        else:
            oinf = self.feval_onnx_[key]

        # execution
        if verbose > 0:
            fLOG(f"[OnnxOperator.f] execute node {self.op_type!r}")
        got = oinf.run({k: v for k, v in
                        zip(oinf.input_names, evaluated_inputs)})
        if as_dict:
            return got
        if len(got) == 1:
            return got.popitem()[1]
        return [got[n] for n in oinf.output_names]
2603 @staticmethod
2604 def _merge_op_version(n1, n2, at_least=None):
2605 if isinstance(n2, OnnxOperator):
2606 if n1.op_version is None:
2607 opv = n2.op_version
2608 elif n2.op_version is None:
2609 opv = n1.op_version
2610 elif n1.op_version == n2.op_version:
2611 opv = n1.op_version
2612 else:
2613 opv = max(n1.op_version, n2.op_version)
2614 elif isinstance(n2, OnnxOperatorItem):
2615 opv = OnnxOperator._merge_op_version(n1, n2.onx_op)
2616 elif isinstance(n2, OnnxOperatorTuple):
2617 raise NotImplementedError( # pragma: no cover
2618 "_merge_op_version is not implemented when n2 "
2619 "is OnnxOperatorTuple.")
2620 else:
2621 opv = n1.op_version
2622 if at_least is not None and opv is not None and opv < at_least:
2623 opv = at_least
2624 return opv
2626 def __add__(self, ov):
2627 """
2628 Automatically adds operator `OnnxAdd` to the graph.
2630 :param ov: onnx node
2631 :return: `OnnxAdd(self, ov)`
2632 """
2633 OnnxAdd = loadop('Add')
2634 opv = self._merge_op_version(self, ov, at_least=15)
2635 if isinstance(ov, (int, float)):
2636 OnnxCastLike = loadop('CastLike')
2637 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2638 return OnnxAdd(self, ov, op_version=opv)
2640 def __sub__(self, ov):
2641 """
2642 Automatically adds operator `OnnxSub` to the graph.
2644 :param ov: onnx node
2645 :return: `OnnxSub(self, ov)`
2646 """
2647 OnnxSub = loadop('Sub')
2648 opv = self._merge_op_version(self, ov, at_least=15)
2649 if isinstance(ov, (int, float)):
2650 OnnxCastLike = loadop('CastLike')
2651 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2652 return OnnxSub(self, ov, op_version=opv)
2654 def __mul__(self, ov):
2655 """
2656 Automatically adds operator `OnnxMul` to the graph.
2658 :param ov: onnx node
2659 :return: `OnnxMul(self, ov)`
2660 """
2661 OnnxMul = loadop('Mul')
2662 opv = self._merge_op_version(self, ov, at_least=15)
2663 if isinstance(ov, (int, float)):
2664 OnnxCastLike = loadop('CastLike')
2665 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2666 return OnnxMul(self, ov, op_version=opv)
2668 def __truediv__(self, ov):
2669 """
2670 Automatically adds operator `OnnxDiv` to the graph.
2672 :param ov: onnx node
2673 :return: `OnnxDiv(self, ov)`
2674 """
2675 OnnxDiv = loadop('Div')
2676 opv = self._merge_op_version(self, ov, at_least=15)
2677 if isinstance(ov, (int, float)):
2678 OnnxCastLike = loadop('CastLike')
2679 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2680 return OnnxDiv(self, ov, op_version=opv)
2682 def __pow__(self, ov):
2683 """
2684 Automatically adds operator `OnnxPow` to the graph.
2686 :param ov: onnx node
2687 :return: `OnnPow(self, ov)`
2688 """
2689 OnnxPow = loadop('Pow')
2690 opv = self._merge_op_version(self, ov, at_least=15)
2691 if isinstance(ov, (int, float)):
2692 OnnxCastLike = loadop('CastLike')
2693 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2694 return OnnxPow(self, ov, op_version=opv)
2696 def __mod__(self, ov):
2697 """
2698 Automatically adds operator `OnnxMod` to the graph.
2700 :param ov: onnx node
2701 :return: `OnnxMod(self, ov)`
2702 """
2703 OnnxMod = loadop('Mod')
2704 opv = self._merge_op_version(self, ov, at_least=15)
2705 if isinstance(ov, (int, float)):
2706 OnnxCastLike = loadop('CastLike')
2707 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2708 return OnnxMod(self, ov, op_version=opv)
2710 def __matmul__(self, ov):
2711 """
2712 Automatically adds operator `OnnxMatMul` to the graph.
2714 :param ov: onnx node
2715 :return: `OnnMatMul(self, ov)`
2716 """
2717 OnnxMatMul = loadop('MatMul')
2718 opv = self._merge_op_version(self, ov)
2719 return OnnxMatMul(self, ov, op_version=opv)
2721 def __gt__(self, ov):
2722 """
2723 Automatically adds operator `OnnxGreater` to the graph.
2725 :param ov: onnx node
2726 :return: `OnnxGreater(self, ov)`
2727 """
2728 OnnxGreater = loadop('Greater')
2729 opv = self._merge_op_version(self, ov, at_least=15)
2730 if isinstance(ov, (int, float)):
2731 OnnxCastLike = loadop('CastLike')
2732 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2733 return OnnxGreater(self, ov, op_version=opv)
2735 def __ge__(self, ov):
2736 """
2737 Automatically adds operator `OnnxGreaterOrEqual` to the graph.
2739 :param ov: onnx node
2740 :return: `OnnxGreater(self, ov)`
2741 """
2742 OnnxGreaterOrEqual = loadop('GreaterOrEqual')
2743 opv = self._merge_op_version(self, ov, at_least=15)
2744 if isinstance(ov, (int, float)):
2745 OnnxCastLike = loadop('CastLike')
2746 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2747 return OnnxGreaterOrEqual(self, ov, op_version=opv)
2749 def __lt__(self, ov):
2750 """
2751 Automatically adds operator `OnnxLess` to the graph.
2753 :param ov: onnx node
2754 :return: `OnnxLess(self, ov)`
2755 """
2756 OnnxLess = loadop('Less')
2757 opv = self._merge_op_version(self, ov, at_least=15)
2758 if isinstance(ov, (int, float)):
2759 OnnxCastLike = loadop('CastLike')
2760 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2761 return OnnxLess(self, ov, op_version=opv)
2763 def __le__(self, ov):
2764 """
2765 Automatically adds operator `OnnxLess` to the graph.
2767 :param ov: onnx node
2768 :return: `OnnxLess(self, ov)`
2769 """
2770 OnnxLessOrEqual = loadop('LessOrEqual')
2771 opv = self._merge_op_version(self, ov, at_least=15)
2772 if isinstance(ov, (int, float)):
2773 OnnxCastLike = loadop('CastLike')
2774 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2775 return OnnxLessOrEqual(self, ov, op_version=opv)
2777 def __eq__(self, ov):
2778 """
2779 Automatically adds operator `OnnxEqual` to the graph.
2781 :param ov: onnx node
2782 :return: `OnnxEqual(self, ov)`
2783 """
2784 OnnxEqual = loadop('Equal')
2785 opv = self._merge_op_version(self, ov, at_least=15)
2786 if isinstance(ov, (int, float)):
2787 OnnxCastLike = loadop('CastLike')
2788 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2789 return OnnxEqual(self, ov, op_version=opv)
2791 def and_(self, ov):
2792 """
2793 Automatically adds operator `OnnxAnd` to the graph.
2795 :param ov: onnx node
2796 :return: `OnnxAnd(self, ov)`
2797 """
2798 OnnxAnd = loadop('And')
2799 opv = self._merge_op_version(self, ov)
2800 return OnnxAnd(self, ov, op_version=opv)
2802 def or_(self, ov):
2803 """
2804 Automatically adds operator `OnnxOr` to the graph.
2806 :param ov: onnx node
2807 :return: `OnnxOr(self, ov)`
2808 """
2809 OnnxOr = loadop('Or')
2810 opv = self._merge_op_version(self, ov)
2811 return OnnxOr(self, ov, op_version=opv)
2813 def __ne__(self, ov):
2814 """
2815 Automatically adds operator `OnnxNot x OnnxEqual` to the graph.
2817 :param ov: onnx node
2818 :return: `OnnxNot(OnnxEqual(self, ov))`
2819 """
2820 OnnxNot, OnnxEqual = loadop('Not', 'Equal')
2821 opv = self._merge_op_version(self, ov, at_least=15)
2822 if isinstance(ov, (int, float)):
2823 OnnxCastLike = loadop('CastLike')
2824 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv)
2825 return OnnxNot(OnnxEqual(self, ov, op_version=opv), op_version=opv)
2827 def __abs__(self):
2828 """
2829 Automatically adds operator `OnnxAbs` to the graph.
2831 :param ov: onnx node
2832 :return: `OnnxAbs(self, ov)`
2833 """
2834 OnnxAbs = loadop('Abs')
2835 return OnnxAbs(self, op_version=self.op_version)
2837 def not_(self):
2838 """
2839 Automatically adds operator `OnnxNot` to the graph.
2841 :param ov: onnx node
2842 :return: `OnnxNot(self, ov)`
2843 """
2844 OnnxNot = loadop('Not')
2845 return OnnxNot(self, op_version=self.op_version)
2847 def astype(self, to):
2848 """
2849 Automatically adds operator `OnnxCast` to the graph.
2851 :param ov: onnx node
2852 :return: `OnnxCast(self, ov, to=to)`
2853 """
2854 OnnxCast = loadop('Cast')
2855 return OnnxCast(self, to=to, op_version=self.op_version)
class OnnxOperatorFunction(OnnxOperator):
    """
    This operator is used to insert existing ONNX function into
    the ONNX graph being built.

    :param function_proto: instance of type :epkg:`FunctionProto`
    :param inputs: inputs
    :param output_names: output names
    :param sub_functions: functions called by this one
    """

    since_version = 1
    expected_inputs = None
    expected_outputs = None
    input_range = [1, 1e9]
    output_range = [1, 1e9]
    op_type = 'Function'
    # The previous `domain = 'mlprodict'` assignment was dead code,
    # immediately overridden by the value below; it was removed.
    domain = 'mlprodict.xop'
2878 @staticmethod
2879 def attribute_to_value(att):
2880 """
2881 Converts an attribute into a value using python structures.
2882 """
2883 if isinstance(att, onnx.AttributeProto):
2884 dtype = att.type
2885 else:
2886 raise NotImplementedError( # pragma: no cover
2887 f"Unable to copy attribute type {type(att)!r}.")
2888 if dtype == 1: # .f
2889 value = att.f
2890 elif dtype == 2: # .i
2891 value = att.i
2892 elif dtype == 3: # .s
2893 value = att.s
2894 elif dtype == 4: # .t
2895 value = att.t
2896 elif dtype == 6: # .floats
2897 value = list(att.floats)
2898 elif dtype == 7: # .ints
2899 value = list(att.ints)
2900 elif dtype == 8: # .strings
2901 value = list(att.strings)
2902 elif dtype == 11: # .double_data
2903 value = list(att.double_data)
2904 else:
2905 raise NotImplementedError( # pragma: no cover
2906 f"Unable to copy attribute type {dtype!r} ({att!r}).")
2907 return value
2909 def __init__(self, function_proto, *inputs, output_names=None,
2910 sub_functions=None):
2911 logger.debug("op:Function(ONNX, %d in, output_names=%r)",
2912 len(inputs), output_names)
2913 if function_proto is None:
2914 raise ValueError(
2915 "function_proto cannot be None.") # pragma: no cover
2916 if not isinstance(function_proto, onnx.FunctionProto):
2917 raise TypeError( # pragma: no cover
2918 "function_proto must be of type FunctionProto not %r." %
2919 type(function_proto))
2920 if len(inputs) > len(function_proto.input):
2921 raise RuntimeError( # pragma: no cover
2922 "Unexpected number of inputs %r > expected %r." % (
2923 len(inputs), len(function_proto.input)))
2924 if (output_names is not None and
2925 len(output_names) != len(function_proto.output)):
2926 raise RuntimeError( # pragma: no cover
2927 "Unexpected number of outputs %r != expected %r." % (
2928 len(output_names), len(function_proto.output)))
2929 OnnxOperator.__init__(self, *inputs, output_names=output_names)
2930 self.model = function_proto
2931 self.sub_functions = sub_functions
2933 def __repr__(self):
2934 "usual"
2935 atts = {}
2936 for att in ['output_names']:
2937 value = getattr(self, att, None)
2938 if value is not None:
2939 atts[att] = value
2940 atts.update(self.kwargs)
2941 if self.sub_functions is not None and len(self.sub_functions) > 0:
2942 atts["sub_functions"] = list(range(len(self.sub_functions)))
2943 msg = ", ".join(f"{k}={v!r}" for k, v in atts.items())
2944 if len(atts) > 0:
2945 msg = ", " + msg
2946 return f"{self.__class__.__name__}(...{msg})"
2948 def add_to(self, builder):
2949 """
2950 Adds to graph builder.
2952 :param builder: instance of @see cl _GraphBuilder,
2953 it must have a method `add_node`
2954 """
2955 logger.debug("op:Function.add_to(builder)")
2956 inputs = builder.get_input_names(self, self.inputs)
2957 n_outputs = len(self.model.output)
2958 outputs = [builder.get_unique_output_name(NodeResultName(self, i))
2959 for i in range(n_outputs)]
2961 # linking inputs
2962 logger.indent()
2963 if self.sub_functions is not None:
2964 for sub in self.sub_functions:
2965 builder.add_function(sub)
2966 builder.add_function(self.model)
2967 builder.add_node(
2968 self.model.name, builder.get_unique_name(
2969 '_fct_' + self.model.name, reserved=False),
2970 inputs, outputs, domain=self.model.domain)
2971 logger.dedent()
2974class _GraphBuilder:
2975 """
2976 Graph builder. It takes a graph structure made with
2977 instances of @see cl OnnxOperatorBase.
2978 The main method is `to_onnx`.
2980 * `initializer`: list of initializers to add to the ONNX graph
2981 * `node`: list of nodes to add to the ONNX graph
2982 * `input`: list of inputs to add to the ONNX graph
2983 * `output`: list of inputs to add to the ONNX graph
2984 * `opsets`: opsets of the ONNX graph
2985 * `input_names`: dictionary of input names
2986 `{name: InputDetectedVariable}`
2987 * `node_output_names`: memorizes a name for a node output
2988 when the user did not specify any
2989 `{(id(node), index): OutputDetectedVariable}`
2990 * `reserved_names`: dictionary `{ name : (node, index) }`,
2991 name which should remain unchanged in the ONNX graph
2992 * `names`: list of uniques names
2993 * `functions`: dictionary `{ domain, name: function_proto }`
2994 * `function_hashes`: dictionary `{ domain, name: hash of function_proto }`
2995 """
2997 def __init__(self):
2998 self.initializer = []
2999 self.node = []
3000 self.input = []
3001 self.output = []
3002 self.opsets = {}
3003 self.input_names = {}
3004 self.node_output_names = {}
3005 self.reserved_names = {}
3006 self.names = set()
3007 self.functions = {}
3008 self.function_hashes = {}
3009 logger.debug('_GraphBuilder-%d:new', id(self))
3011 def _add_domain(self, domain, version):
3012 if domain not in self.opsets:
3013 self.opsets[domain] = version
3014 else:
3015 self.opsets[domain] = max(version, self.opsets[domain])
3017 def _add_name(self, name):
3018 self.names.add(name)
3020 @staticmethod
3021 def number2alpha(index):
3022 """
3023 Converts a numbers into a string keeping the same
3024 alphabetical order.
3025 """
3026 dec = str(int(index))
3027 if len(dec) == 1:
3028 return dec
3029 return chr(96 + len(dec)) + dec
3031 def reserve_names(self, node, output_names):
3032 """
3033 Adds names to the list of reserved names.
3034 All must be unique.
3036 :param node: node or None for an input
3037 :param output_names: names of the output
3038 """
3039 if output_names is None:
3040 return
3041 for index, var in enumerate(output_names):
3042 if not isinstance(var, (Variable, ExistingVariable)):
3043 raise TypeError( # pragma: no cover
3044 f"Unexpected type {type(var)!r} for {var!r}.")
3045 self.reserve_name(node, var.name, index)
3047 def reserve_name(self, node, name, index):
3048 """
3049 Reserves a name so that it cannot be changed.
3051 :param node: node or None for an input
3052 :param name: name
3053 :param index: input index
3054 """
3055 if not isinstance(name, str):
3056 raise TypeError( # pragma: no cover
3057 f"Name {name!r} is not a string.")
3058 if name in self.reserved_names:
3059 raise RuntimeError( # pragma: no cover
3060 "Name %r is already reserved from node %r, index=%d." % (
3061 name, node, index))
3062 logger.debug("_GraphBuilder-%d.reserve_name([%s-%d], %r, %r)",
3063 id(self), node.__class__.__name__, id(node),
3064 name, index)
3065 self.reserved_names[name] = (node, index)
3066 self._add_name(name)
3068 def get_unique_output_name(self, result):
3069 """
3070 Returns a unique output_name for a NodeResultName.
3072 :param result: instance of @see cl NodeResultName
3073 """
3074 if not isinstance(result, NodeResultName):
3075 raise TypeError( # pragma: no cover
3076 "Result must be of type NodeResultName not %r (%r)." % (
3077 type(result), result))
3078 if result.node is None:
3079 key = None, result.index
3080 else:
3081 key = id(result.node), result.index
3082 if key in self.node_output_names:
3083 return self.node_output_names[key]
3084 name = result.get_name()
3085 if name in self.reserved_names:
3086 unique = name
3087 else:
3088 unique = self.get_unique_name(name)
3089 self.node_output_names[key] = unique
3090 return unique
3092 def get_unique_name(self, name, reserved=True):
3093 """
3094 Returns a unique name to name an output.
3096 :param name: name
3097 :param reserved: bypass if the name is a reserved one
3098 :return: unique name, may be the same if not taken already
3099 """
3100 if not isinstance(name, str):
3101 raise TypeError( # pragma: no cover
3102 f"name must be a string not {type(name)!r}.")
3103 if reserved and name in self.reserved_names:
3104 logger.debug( # pragma: no cover
3105 "_GraphBuilder-%d.get_unique_name(%r) 1-> %r",
3106 id(self), name, name)
3107 return name
3108 if name not in self.names:
3109 self._add_name(name)
3110 logger.debug("_GraphBuilder-%d.get_unique_name(%r) 2-> %r",
3111 id(self), name, name)
3112 return name
3113 i = 1
3114 new_name = f"{name}_{self.number2alpha(i)}"
3115 while new_name in self.names:
3116 i += 1
3117 new_name = f"{name}_{self.number2alpha(i)}"
3118 self._add_name(new_name)
3119 logger.debug("_GraphBuilder-%d.get_unique_name(%r) 3-> %r",
3120 id(self), name, new_name)
3121 return new_name
3123 def get_input_names(self, node, inputs):
3124 """
3125 Returns input names for node *node* and inputs *inputs*.
3127 :param node: node
3128 :param inputs: inputs
3129 :return: name
3130 """
3131 logger.debug(
3132 "_GraphBuilder-%d.get_input_names:1:%s-%d:%r",
3133 id(self), node.__class__.__name__, id(node), inputs)
3134 names = []
3135 for i in inputs:
3136 if isinstance(i, (Variable, ExistingVariable)):
3137 self._add_name(i.name)
3138 names.append(i.name)
3139 if i.name in self.input_names:
3140 if isinstance(i, Variable):
3141 self.input_names[i.name] = InputDetectedVariable(
3142 None, i)
3143 logger.debug(
3144 "_GraphBuilder-%d.get_input_names:2:a:%d:+input_names:%s",
3145 id(self), id(node), i.name)
3146 else:
3147 logger.debug( # pragma: no cover
3148 "_GraphBuilder-%d.get_input_names:2:a:%d:=input_names:%s",
3149 id(self), id(node), i.name)
3150 else:
3151 self.input_names[i.name] = InputDetectedVariable(None, i)
3152 logger.debug(
3153 "_GraphBuilder-%d.get_input_names:2:b:%d:+input_names:%s",
3154 id(self), id(node), i.name)
3155 elif isinstance(i, InputDetectedVariable):
3156 self._add_name(i.name)
3157 names.append(i.name)
3158 if i.name in self.input_names:
3159 logger.debug( # pragma: no cover
3160 "_GraphBuilder-%d.get_input_names:2:c:%d:=input_names:%s",
3161 id(self), id(node), i.name)
3162 else:
3163 self.input_names[i.name] = i
3164 logger.debug(
3165 "_GraphBuilder-%d.get_input_names:2:c:%d:+input_names:%s",
3166 id(self), id(node), i.name)
3167 elif isinstance(i, OnnxExisting):
3168 inp = i.inputs[0]
3169 n = inp.output_names[0]
3170 self._add_name(n.name)
3171 names.append(n.name)
3172 if n.name in self.input_names:
3173 if isinstance(inp, Variable):
3174 self.input_names[n.name] = InputDetectedVariable(
3175 None, n)
3176 logger.debug( # pragma: no cover
3177 "_GraphBuilder-%d.get_input_names:2:d:%d:+input_names:%s",
3178 id(self), id(node), n.name)
3179 else:
3180 logger.debug(
3181 "_GraphBuilder-%d.get_input_names:2:d:%d:=input_names:%s",
3182 id(self), id(node), n.name)
3183 else:
3184 self.input_names[n.name] = InputDetectedVariable(None, n)
3185 logger.debug(
3186 "_GraphBuilder-%d.get_input_names:2:d:%d:+input_names:%s",
3187 id(self), id(node), n.name)
3188 elif isinstance(i, OnnxOperator):
3189 key = id(i), 0
3190 try:
3191 name = self.node_output_names[key]
3192 except KeyError as e: # pragma: no cover
3193 raise RuntimeError(
3194 "Unable to find key %r for input "
3195 "(type(i) is %r, type(node) is %r) "
3196 "%r in node %r among %r." % (
3197 key, type(i), type(node), i, node,
3198 list(self.node_output_names))) from e
3199 names.append(name)
3200 elif isinstance(i, OnnxOperatorItem):
3201 if isinstance(i.onx_op, OnnxOperatorTuple):
3202 if i.onx_op.values is None:
3203 key = id(i.onx_op.unique), i.index
3204 else:
3205 key = id(i.onx_op[i.index]), 0
3206 elif isinstance(i.onx_op, OnnxOperator):
3207 key = id(i.onx_op), i.index
3208 else:
3209 raise TypeError( # pragma: no cover
3210 f"Unexpected type for OnnxOperatorItem: {type(i.onx_op)!r}.")
3211 try:
3212 name = self.node_output_names[key]
3213 except KeyError as e: # pragma: no cover
3214 raise RuntimeError(
3215 "Unable to find key %r for input %r in node %r." % (
3216 key, i, node)) from e
3217 names.append(name)
3218 elif isinstance(i, OnnxOperatorTuple):
3219 raise NotImplementedError() # pragma: no cover
3220 elif isinstance(i, numpy.ndarray):
3221 # Adding an initializer
3222 name = self.get_unique_name('init', reserved=False)
3223 init = from_array(i, name)
3224 self.initializer.append(init)
3225 names.append(name)
3226 else:
3227 raise TypeError( # pragma: no cover
3228 f"Unexpected type for an input {type(i)!r}.")
3229 logger.debug(
3230 "_GraphBuilder-%d.get_input_names:3:%r", id(self), names)
3231 return names
3233 def add_initializer(self, name, init):
3234 """
3235 Adds an initializer to the graph.
3237 :param name: initializer name
3238 :param init: initializer to copy
3239 :return: created intializer
3240 """
3241 if isinstance(init, onnx.TensorProto):
3242 tensor = to_array(init)
3243 val = from_array(tensor, name)
3244 logger.debug("_GraphBuilder.add_initializer:1(%r, %r, %r)",
3245 name, tensor.dtype, tensor.shape)
3246 elif isinstance(init, numpy.ndarray):
3247 value = to_array(init)
3248 val = from_array(value, name)
3249 logger.debug("_GraphBuilder.add_initializer:2(%r, %r, %r)",
3250 name, init.dtype, init.shape)
3251 else:
3252 raise NotImplementedError( # pragma: no cover
3253 f"Unsupported initializer type {type(init)!r}.")
3254 self.initializer.append(val)
3255 return val
3257 def add_function(self, function_proto,
3258 raise_if_exist=False, check_unique=True,
3259 opset=1):
3260 """
3261 Adds a function to the graph.
3263 :param function_proto: instance of type :epkg:`FunctionProto`
3264 :param raise_if_exist: raises an exception if a function of the
3265 same name was already added
3266 :param check_unique: checks if a function was added twice,
3267 it is the same
3268 :param opset: opset for the domain the function belongs to
3269 """
3270 def _hash(p):
3271 m = hashlib.sha256()
3272 m.update(p.SerializeToString())
3273 return m.hexdigest()[:64]
3275 key = function_proto.domain, function_proto.name
3276 if key in self.functions:
3277 if raise_if_exist:
3278 raise RuntimeError( # pragma: no cover
3279 f"Function {key!r} is added for the second time.")
3280 if check_unique:
3281 hs = _hash(function_proto)
3282 if hs != self.function_hashes[key]:
3283 raise RuntimeError( # pragma: no cover
3284 "Function %r is added for the second time "
3285 "and the content is not the same." % (key, ))
3286 return
3287 self.functions[key] = function_proto
3288 self.function_hashes[key] = _hash(function_proto)
3289 self._add_domain(function_proto.domain, opset)
3291 def add_node(self, op_type, name, inputs, outputs, domain='',
3292 opset=None, **attributes):
3293 """
3294 Adds a node to the graph.
3296 :param op_type: operator type
3297 :param name: node name
3298 :param inputs: inputs name list
3299 :param outputs: outputs name list
3300 :param domain: node domain
3301 :param opset: node opset
3302 :return: created node
3303 """
3304 logger.debug("_GraphBuilder-%d.add_node(%r, %r, "
3305 "inputs=%r, outputs=%r, domain=%r, opset=%r)",
3306 id(self), op_type, name, inputs, outputs, domain, opset)
3307 if not isinstance(inputs, (list, OnnxOperator._InputContainer)):
3308 raise TypeError( # pragma: no cover
3309 f"inputs must be a list not {type(inputs)!r}.")
3310 if not isinstance(outputs, (list, OnnxOperator._InputContainer)):
3311 raise TypeError( # pragma: no cover
3312 f"inputs must be a list not {type(outputs)!r}.")
3313 if any(map(lambda x: not isinstance(x, str), inputs)):
3314 raise TypeError( # pragma: no cover
3315 f"inputs must be all strings not {inputs!r}.")
3316 if any(map(lambda x: not isinstance(x, str), outputs)):
3317 raise TypeError( # pragma: no cover
3318 f"outputs must be all strings not {outputs!r}.")
3319 if opset is not None:
3320 self._add_domain(domain, opset)
3321 node = make_node(op_type, inputs, outputs, name=name,
3322 domain=domain, **attributes)
3323 self.node.append(node)
3324 return node
3326 def _process_io(self, inputs, input_names_):
3327 logger.debug("_GraphBuilder-%d._process_io:1:inputs=%r",
3328 id(self), inputs)
3329 logger.debug("_GraphBuilder-%d._process_io:1:input_names_=%r",
3330 id(self), input_names_)
3331 if input_names_ is None:
3332 input_names = None
3333 else:
3334 input_names = []
3335 for inp in input_names_:
3336 if inp.var.name == '':
3337 continue
3338 input_names.append(inp)
3340 if inputs is None:
3341 logger.debug( # pragma: no cover
3342 "_GraphBuilder-%d._process_io:return:%r",
3343 id(self), self.input_names)
3344 return [
3345 make_tensor_value_info(
3346 'X', TensorProto.FLOAT, None) # pylint: disable=E1101
3347 for name in self.input_names], None
3349 if not isinstance(inputs, (list, OnnxOperator._InputContainer)):
3350 if is_numpy_dtype(inputs):
3351 inputs = [inputs]
3353 logger.debug("_GraphBuilder-%d._process_io:2:input_names=%r",
3354 id(self), input_names)
3355 if input_names is None:
3356 # outputs
3357 set_names = set()
3358 input_names = []
3359 new_inputs = []
3360 for inp in inputs:
3361 if isinstance(inp, OutputDetectedVariable):
3362 if inp.name in set_names:
3363 raise ValueError( # pragma: no cover
3364 f"Names already taken {inp.name!r} in {inputs!r}.")
3365 set_names.add(inp.name)
3366 if isinstance(inp.node, OnnxExisting):
3367 raise NotImplementedError( # pragma: no cover
3368 f"Unexpected name {inp.name!r} type {type(inp.node)!r}.")
3369 # continue
3370 key = id(inp.node), inp.index
3371 if key in self.node_output_names:
3372 new_name = self.node_output_names[key]
3373 new_var = OutputDetectedVariable(
3374 inp.node, inp.var.copy_name(new_name), inp.index)
3375 input_names.append(new_var)
3376 new_inputs.append(new_var)
3377 else:
3378 raise RuntimeError( # pragma: no cover
3379 "Key %r is ambiguous or defined in "
3380 "two nodes %r, id(node)=%d, index=%d." % (
3381 key, inp, id(inp.node), inp.index))
3382 else:
3383 raise TypeError( # pragma: no cover
3384 "Unexpected type %r (it should be "
3385 "OutputDetectedVariable) in %r." % (inp, inputs))
3386 inputs = new_inputs
3387 if len(input_names) == 0:
3388 raise RuntimeError( # pragma: no cover
3389 "Unable to cross %r and %r or %r (set_names=%r)." % (
3390 inputs, self.output_names_rev,
3391 self.node_output_names_rev, set_names))
3392 elif not isinstance(input_names, (list, OnnxOperator._InputContainer)):
3393 raise RuntimeError( # pragma: no cover
3394 f"Unexpected type for input_names {type(input_names)!r}.")
3395 else:
3396 # inputs
3397 pass
3399 # common parts
3400 logger.debug("_GraphBuilder-%d._process_io:3:input_names:%r",
3401 id(self), input_names)
3402 logger.debug("_GraphBuilder-%d._process_io:3:inputs:%r",
3403 id(self), inputs)
3404 no_exists_names = [c for c in input_names if not isinstance(
3405 c.var, (ExistingVariable, OnnxExisting))]
3406 no_exists = [c for c in inputs if not isinstance(
3407 c.var, (ExistingVariable, OnnxExisting))]
3409 if isinstance(input_names, (list, OnnxOperator._InputContainer)):
3410 d_input_names = {}
3411 for inp in input_names:
3412 if inp.name in d_input_names:
3413 raise ValueError( # pragma: no cover
3414 f"Duplicated name {inp.name!r} in {input_names!r}.")
3415 d_input_names[inp.name] = inp
3416 elif isinstance(input_names, dict):
3417 d_input_names = input_names
3418 else:
3419 raise TypeError( # pragma: no cover
3420 "Unexpected type for input_names %r (%r)." % (
3421 type(input_names), input_names))
3423 logger.debug("_GraphBuilder-%d._process_io:4:no_exists_names:%r",
3424 id(self), no_exists_names)
3425 logger.debug("_GraphBuilder-%d._process_io:4:no_exists:%r",
3426 id(self), no_exists)
3428 # mapping
3429 res = []
3430 for inp in no_exists:
3431 if not isinstance(inp, DetectedVariable):
3432 raise TypeError( # pragma: no cover
3433 f"inp not DetectedVariable but {type(inp)!r} ({inp!r}).")
3434 if inp.name.startswith('???'):
3435 raise RuntimeError( # pragma: no cover
3436 f"Issue with variable {inp!r}.")
3437 var = d_input_names[inp.name]
3438 if not isinstance(var, DetectedVariable):
3439 raise TypeError( # pragma: no cover
3440 f"var not Variable but {type(var)!r} ({var!r}).")
3442 # inp: Variable
3443 # var: str
3444 if isinstance(var.var, ExistingVariable):
3445 # It may be an input referenced in a subgraph and not used in the
3446 # main graph.
3447 if inp.var.name != var.var.name:
3448 raise RuntimeError( # pragma: no cover
3449 f"Unexpected {inp!r} != {var!r}.")
3450 elif inp.var != var.var:
3451 if (inp.var.name != var.var.name or (
3452 inp.var.dtype is not None and
3453 var.var.dtype is not None)):
3454 raise RuntimeError( # pragma: no cover
3455 f"Unexpected {inp.var!r} != {var.var!r}.")
3457 if isinstance(inp.var, ExistingVariable):
3458 # The type of ExistingVariable must be known
3459 # to build the subgraph. Let's try unknown.
3460 res.append(make_tensor_value_info(inp.name, 0, None))
3461 else:
3462 res.append(make_tensor_value_info(
3463 inp.name, inp.var.proto_added_type,
3464 inp.var.proto_added_shape))
3466 hidden = [c for c in input_names if isinstance(
3467 c.var, (ExistingVariable, OnnxExisting))]
3468 logger.debug("_GraphBuilder-%d._process_io:4:return:res:%r",
3469 id(self), [n.name for n in res])
3470 logger.debug("_GraphBuilder-%d._process_io:4:return:hidden:%r",
3471 id(self), hidden)
3472 return res, hidden
3474 def to_onnx(self, inputs=None, outputs=None,
3475 target_opset=None, run_shape=False,
3476 optim=True, function_name=None,
3477 function_domain=None, verbose=0,
3478 check_model=True):
3479 """
3480 Converts this operator into an ONNX graph.
3482 :param inputs: specific inputs (as a dictionary) or
3483 default inputs if not specified
3484 :param outputs: specific outputs
3485 :param target_opset: dictionary with target opset per domain,
3486 None for the default one
3487 :param run_shape: run shape inference before returning the model
3488 :param optim: optimize the model with function
3489 @see fn onnx_optimisations
3490 :param function_name: if not None builds a :epkg:`FunctionProto`
3491 use this name
3492 :param function_domain: in case of a function, declares the function
3493 as part of this domain, `'mlprodict'` if None
3494 :param verbose: prints information
3495 :param check_model: checks the output model
3496 :return: onnx graph
3498 (_GraphBuilder)
3499 """
3500 logger.debug("_GraphBuilder-%d.to_onnx:#####:%s",
3501 id(self), str(function_name))
3502 logger.debug("_GraphBuilder-%d.to_onnx(%r, %r, target_opset=%r)",
3503 id(self), inputs, outputs, target_opset)
3504 # inputs and outputs
3505 if not all(map(lambda x: isinstance(x, InputDetectedVariable), inputs)):
3506 raise TypeError( # pragma: no cover
3507 "One of the input is not InputDetectedVariable.")
3508 if not all(map(lambda x: isinstance(x, OutputDetectedVariable), outputs)):
3509 raise TypeError( # pragma: no cover
3510 "One of the outputs is not OutputDetectedVariable.")
3511 logger.indent()
3512 self.input, self.hidden_input = self._process_io(
3513 inputs, list(self.input_names.values()))
3514 logger.dedent()
3515 logger.debug("_GraphBuilder-%d.to_onnx:hidden_input:%r",
3516 id(self), self.hidden_input)
3517 logger.indent()
3518 self.output, self.hidden_output = self._process_io(outputs, None)
3519 logger.dedent()
3520 if len(self.hidden_output) > 0:
3521 raise RuntimeError( # pragma: no cover
3522 f"Unexpected hidden output {self.hidden_output!r}.")
3523 logger.debug("_GraphBuilder-%d.to_onnx:self.input=%r",
3524 id(self), [i.name for i in self.input])
3525 if len(self.hidden_input) > 0:
3526 logger.debug("_GraphBuilder-%d.to_onnx:self.hidden_input=%r",
3527 id(self), [i.name for i in self.hidden_input])
3528 logger.debug("_GraphBuilder-%d.to_onnx:self.output=%r",
3529 id(self), [i.name for i in self.output])
3530 logger.debug("_GraphBuilder-%d.to_onnx:build:n_inputs=%r n_inits=%r n_nodes=%r "
3531 "n_outputs=%r",
3532 id(self), len(self.input), len(self.initializer),
3533 len(self.node), len(self.output))
3535 if function_name is not None:
3536 # function
3537 if function_domain is None:
3538 function_domain = 'mlprodict'
3539 if len(self.initializer) > 0:
3540 nodes = []
3541 for init in self.initializer:
3542 nodes.append(
3543 make_node('Constant', [], [init.name], value=init,
3544 name=f'_init_{init.name}'))
3545 nodes.extend(self.node)
3546 else:
3547 nodes = self.node
3548 fct = make_function(
3549 function_domain, function_name,
3550 [_.name for _ in self.input],
3551 [_.name for _ in self.output],
3552 nodes,
3553 [make_opsetid(k, v) for k, v in self.opsets.items()])
3554 if check_model:
3555 check_onnx(fct)
3556 if optim:
3557 from ..onnx_tools.optim import onnx_optimisations
3558 fct = onnx_optimisations(fct)
3559 if check_model:
3560 check_onnx(fct)
3561 logger.debug("_GraphBuilder-%d:fct:.to_onnx() -> done", id(self))
3562 logger.debug("_GraphBuilder-%d:fct:to_onnx:#####", id(self))
3563 return fct
3564 else:
3565 # graph
3566 graph = make_graph(
3567 self.node, 'XOP', self.input, self.output, self.initializer)
3568 onnx_model = make_model(
3569 graph, functions=list(self.functions.values()))
3570 opv = self.opsets.get('', max_supported_opset())
3571 opset2ir = _default_OPSET_TO_IR_VERSION()
3572 irv = opset2ir.get(opv, max(opset2ir.values()))
3573 onnx_model.ir_version = irv
3575 logger.debug("_GraphBuilder-%d.to_onnx:2onnx:n_inputs=%r n_inits=%r "
3576 "n_nodes=%r n_outputs=%r",
3577 id(self), len(onnx_model.graph.input),
3578 len(onnx_model.graph.initializer),
3579 len(onnx_model.graph.node),
3580 len(onnx_model.graph.output))
3582 del onnx_model.opset_import[:] # pylint: disable=E1101
3583 seen_opset = set()
3584 for k, v in self.opsets.items():
3585 if (k or '') in seen_opset:
3586 raise RuntimeError( # pragma: no cover
3587 f"Duplicated opset ({k!r}, {v!r}).")
3588 op_set = onnx_model.opset_import.add() # pylint: disable=E1101
3589 op_set.domain = k or ''
3590 op_set.version = v
3591 seen_opset.add(op_set.domain)
3593 # optimisation, remove redundant constant, unnecessary
3594 # identity nodes.
3595 if check_model:
3596 check_onnx(onnx_model)
3597 if optim:
3598 from ..onnx_tools.optim import onnx_optimisations
3599 onnx_model = onnx_optimisations(onnx_model)
3600 if check_model:
3601 logger.debug(
3602 "_GraphBuilder-%d.to_onnx:check_onnx", id(self))
3603 check_onnx(onnx_model)
3605 logger.debug("_GraphBuilder-%d.to_onnx:optim:n_inputs=%r n_inits=%r "
3606 "n_nodes=%r n_outputs=%r",
3607 id(self), len(onnx_model.graph.input),
3608 len(onnx_model.graph.initializer),
3609 len(onnx_model.graph.node),
3610 len(onnx_model.graph.output))
3612 if run_shape:
3613 logger.debug("_GraphBuilder-%d.to_onnx:infer_shapes", id(self))
3614 with_shape = infer_shapes(onnx_model)
3615 logger.debug("_GraphBuilder-%d.to_onnx:shape:n_inputs=%r "
3616 "n_inits=%r n_nodes=%r n_outputs=%r",
3617 id(self), len(with_shape.graph.input),
3618 len(with_shape.graph.initializer),
3619 len(with_shape.graph.node),
3620 len(with_shape.graph.output))
3621 return with_shape
3623 logger.debug("_GraphBuilder-%d.to_onnx:mod -> done", id(self))
3624 logger.debug("_GraphBuilder-%d.to_onnx:mod:#####", id(self))
3625 return onnx_model
3628class _StaticVariables:
3629 """
3630 Holds static variables.
3631 """
3633 def __init__(self):
3634 self._all_schemas_ = None
3635 self._all_schemas_versions_ = None
3636 self._all_domains_ = None
3637 self._all_classes_ = None
3639 @property
3640 def all_schemas(self):
3641 "Returns all schemas."
3642 self.populate()
3643 return self._all_schemas_
3645 @property
3646 def all_classes(self):
3647 "Returns all operators wrapped in classes."
3648 self.populate()
3649 return self._all_classes_
3651 @property
3652 def all_schemas_versions(self):
3653 "Returns all operators, domains, versions."
3654 self.populate()
3655 return self._all_schemas_versions_
3657 @property
3658 def all_domains(self):
3659 "Returns all domains."
3660 self.populate()
3661 return self._all_domains_
3663 def populate(self):
3664 "Populates static variables."
3665 if self._all_schemas_ is not None:
3666 return
3667 (self._all_schemas_, self._all_schemas_versions_,
3668 self._all_domains_) = _populate_schemas()
3669 self._all_classes_ = {}
class OnnxExisting(OnnxOperator):
    """
    Wrapper around OnnxIdentity to specify this operator is
    not part of the subgraph it is used in.
    """

    # Names already produced by :meth:`get_unique_name`; shared by all
    # instances so generated names never collide.
    _unique_names = set()

    expected_inputs = ['X']
    expected_outputs = ['Y']
    operator_name = 'Existing'
    input_range = [1, 1]
    output_range = [1, 1]
    domain = ''
    is_deprecated = False
    since_version = 1
    past_version = []
    attr_names = []
    op_type = 'Existing'
    __module__ = __name__

    @staticmethod
    def get_unique_name(var):
        """
        Returns a unique variable name.

        :param var: an instance of OnnxOperator.
        :return: unique variable name
        :raises TypeError: if *var* is not an OnnxOperator
        """
        if isinstance(var, OnnxOperator):
            # Base name built from the node's domain and operator type,
            # e.g. 'mlprodictxop_function'.
            name = "%s_%s" % ((var.domain or "").lower().replace(".", ""),
                              var.op_type.lower())
        else:
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(var)!r} for var.")
        i = 0
        new_name = "_exist_%s_%d" % (name, i)
        # Increment the numeric suffix until the name was never handed out.
        while new_name in OnnxExisting._unique_names:
            i += 1
            new_name = "_exist_%s_%d" % (name, i)
        OnnxExisting._unique_names.add(new_name)
        return new_name

    def __init__(self, *args, **kwargs):  # pylint: disable=W0231
        """
        Expects exactly one input: the variable or node from an enclosing
        graph that this node refers to.  Side effect: the wrapped input's
        ``output_names`` is set so the result can be referenced by name.
        """
        # OnnxIdentity.__init__(self, *args, **kwargs)  # pylint: disable=W0233
        OnnxOperator.__init__(self, *args, **kwargs)  # pylint: disable=W0233
        # Filled by :meth:`_set_control_op` with the operators controlling
        # the subgraph this result appears in.
        self.control_ops_ = None
        if len(self.inputs) != 1:
            raise RuntimeError(  # pragma: no cover
                f"Unexpected number of inputs {len(self.inputs)}.")
        if isinstance(self.inputs[0], Variable):
            # It is one input
            new_names = [
                ExistingVariable(self.inputs[0].name, self.inputs[0])]
            logger.debug("op:OnnxExisting-%d.__init__:set-input:1:%r",
                         id(self), new_names)
            self.inputs[0].output_names = new_names
        else:
            if not isinstance(self.inputs[0], OnnxOperatorBase):
                raise TypeError(  # pragma: no cover
                    f"Only input should a node not {type(self.inputs[0])!r}.")
            if self.inputs[0].output_names is None:
                # Assign a generated name to the upstream node's output so
                # the subgraph can reference it.  Left untouched when the
                # user already named the output.
                new_names = [
                    ExistingVariable(OnnxExisting.get_unique_name(self.inputs[0]),
                                     self.inputs[0])]
                logger.debug("op:OnnxExisting-%d.__init__:set-input:2:%r",
                             id(self), new_names)
                self.inputs[0].output_names = new_names

    def __repr__(self):
        """
        usual
        """
        return "{}({}) -> {}".format(
            self.__class__.__name__,
            self.inputs[0].output_names,
            [str(o) for o in self.output_names]
            if self.output_names is not None else "?")

    def find_named_inputs(self):
        """
        Retrieves all named inputs in this graph.

        :return: list of input names (strings)
        """
        res = []
        for i, inp in enumerate(self.inputs[0].output_names):
            if not isinstance(inp, (Variable, ExistingVariable)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for input %r in node type %r."
                    "" % (type(inp), i, type(self)))
            res.append(inp.name)
        return res

    def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
          clear_cache=False, runtime=None):
        "For the eager mode."
        raise NotImplementedError()  # pragma: no cover

    def _set_control_op(self, op, subgraph_inputs=None):
        # Registers *op* as a controlling operator of this existing result
        # and declares the wrapped input as an external input of *op*.
        # NOTE(review): *op* is presumably a control-flow node (If/Loop/
        # Scan) exposing `add_external_input` — confirm against callers.
        if subgraph_inputs is not None:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented.")
        if op is None:
            raise RuntimeError(  # pragma: no cover
                "op cannot be None in _set_control_op.")
        logger.debug("op:%s-%d:_set_control_op:found:p:%d:%r",
                     self.__class__.__name__, id(self), id(op),
                     self.inputs[0].output_names)
        if self.control_ops_ is None:
            self.control_ops_ = []
        self.control_ops_.append(op)
        op.add_external_input(self.inputs[0])
# Module-level singleton caching operator schemas; populated lazily on
# the first access to one of its properties.
_S = _StaticVariables()
# Public entry point of the Xop API.  NOTE(review): `OnnxLoadFactory` is
# defined earlier in this file; presumably attribute access such as
# `Xop.Add` loads the operator class on demand — confirm with its code.
onnx_load_factory = Xop = OnnxLoadFactory()