Coverage for mlprodict/onnxrt/ops_cpu/__init__.py: 81%
115 statements
coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_cpu*.
5"""
6import inspect
7import textwrap
8from onnx import FunctionProto
9from onnx.reference.ops import load_op as onnx_load_op
10from onnx.defs import get_schema
11from ..excs import MissingOperatorError
12from ._op import OpRunCustom, OpFunction
13from ._op_list import __dict__ as d_op_list
# Registry of custom runtime operators: operator name -> runtime class.
# Filled by register_operator and searched first by load_op.
_additional_ops = {}


def register_operator(cls, name=None, overwrite=True):
    """
    Registers a new runtime operator.

    @param      cls             class
    @param      name            by default ``cls.__name__``,
                                or *name* if defined
    @param      overwrite       if True, an operator already registered
                                under the same name is replaced by *cls*,
                                if False, an exception is raised instead
    @raises     RuntimeError    *name* is already registered and
                                *overwrite* is False
    """
    if name is None:
        name = cls.__name__
    if name in _additional_ops and not overwrite:
        raise RuntimeError(  # pragma: no cover
            f"Unable to overwrite existing operator '{name}': "
            f"{_additional_ops[name]} by {cls}")
    # New name, or existing name with overwrite=True: (re)register it.
    _additional_ops[name] = cls
38def load_op(onnx_node, desc=None, options=None, runtime=None):
39 """
40 Gets the operator related to the *onnx* node.
42 :param onnx_node: :epkg:`onnx` node
43 :param desc: internal representation
44 :param options: runtime options
45 :param runtime: runtime
46 :param existing_functions: existing functions
47 :return: runtime class
48 """
49 from ... import __max_supported_opset__
50 if desc is None:
51 raise ValueError("desc should not be None.") # pragma no cover
52 name = onnx_node.op_type
53 opset = options.get('target_opset', None) if options is not None else None
54 current_opset = __max_supported_opset__
55 chosen_opset = opset or current_opset
56 if opset is not None:
57 if not isinstance(opset, int):
58 raise TypeError( # pragma no cover
59 f"opset must be an integer not {type(opset)}")
60 name_opset = name + "_" + str(opset)
61 for op in range(opset, 0, -1):
62 nop = name + "_" + str(op)
63 if nop in d_op_list:
64 name_opset = nop
65 chosen_opset = op
66 break
67 else:
68 name_opset = name
70 onnx_op = False
71 if name_opset in _additional_ops:
72 cl = _additional_ops[name_opset]
73 elif name in _additional_ops:
74 cl = _additional_ops[name]
75 elif name_opset in d_op_list:
76 cl = d_op_list[name_opset]
77 elif name in d_op_list:
78 cl = d_op_list[name]
79 else:
80 # finish
81 try:
82 cl = onnx_load_op(options.get('domain', ''),
83 name, opset)
84 except ValueError as e:
85 raise MissingOperatorError(
86 f"Unable to load class for operator name={name}, "
87 f"opset={opset}, options={options}, "
88 f"_additional_ops={_additional_ops}.") from e
89 onnx_op = True
90 if cl is None:
91 raise MissingOperatorError( # pragma no cover
92 "Operator '{}' from domain '{}' has no runtime yet. "
93 "Available list:\n"
94 "{} - {}".format(
95 name, onnx_node.domain,
96 "\n".join(sorted(_additional_ops)),
97 "\n".join(textwrap.wrap(
98 " ".join(
99 _ for _ in sorted(d_op_list)
100 if "_" not in _ and _ not in {
101 'cl', 'clo', 'name'})))))
103 class _Wrapper:
105 def _log(self, *args, **kwargs):
106 pass
108 @property
109 def base_class(self):
110 "Returns the parent class."
111 return self.__class__.__bases__[0]
113 def _onnx_run(self, *args, **kwargs):
114 cl = self.base_class
115 new_kws = {}
116 for k, v in kwargs.items():
117 if k not in {'attributes', 'verbose', 'fLOG'}:
118 new_kws[k] = v
119 attributes = kwargs.get('attributes', None)
120 if attributes is not None and len(attributes) > 0:
121 raise NotImplementedError(
122 f"attributes is not empty but not implemented yet, "
123 f"attribures={attributes}.")
124 return cl.run(self, *args, **new_kws) # pylint: disable=E1101
126 def _onnx__run(self, *args, attributes=None, **kwargs):
127 """
128 Wraps ONNX call to OpRun._run.
129 """
130 cl = self.base_class
131 if attributes is not None and len(attributes) > 0:
132 raise NotImplementedError( # pragma: no cover
133 f"Linked attributes are not yet implemented for class "
134 f"{self.__class__!r}.")
135 return cl._run(self, *args, **kwargs) # pylint: disable=E1101
137 def _onnx_need_context(self):
138 cl = self.base_class
139 return cl.need_context(self) # pylint: disable=E1101
141 def __init__(self, onnx_node, desc=None, **options):
142 cl = self.__class__.__bases__[0]
143 run_params = {'log': _Wrapper._log,
144 'opsets': {'': opset},
145 'new_ops': None}
146 cl.__init__(self, onnx_node, run_params)
148 # wrapping the original class
149 if inspect.isfunction(cl):
150 domain = options.get('domain', '')
151 if domain != '':
152 raise TypeError(
153 f"Unable to create a class for operator {name!r} and "
154 f"opset {opset} based on {cl} of type={type(cl)}.")
155 schema = get_schema(name, opset, domain)
156 if schema.has_function:
157 from mlprodict.onnxrt import OnnxInference
158 body = schema.function_body
159 sess = OnnxInference(body)
160 new_cls = lambda *args, sess=sess: OpFunction(
161 args[0], impl=sess)
162 elif schema.has_context_dependent_function:
163 input_types = options.get('input_types', '')
164 if onnx_node is None or input_types is None:
165 raise RuntimeError(
166 f"No registered implementation for operator {onnx_node.op_type!r} "
167 f"and domain {domain!r}, the operator has a context dependent function. "
168 f"but argument node or input_types is not defined.")
169 from mlprodict.onnxrt import OnnxInference
170 body = schema.get_context_dependent_function(
171 onnx_node.SerializeToString(),
172 [it.SerializeToString() for it in input_types])
173 proto = FunctionProto()
174 proto.ParseFromString(body)
175 sess = OnnxInference(proto)
176 new_cls = lambda *args, sess=sess: OpFunction(
177 args[0], impl=sess)
178 else:
179 raise TypeError(
180 f"Unable to create a class for operator {name!r} and "
181 f"opset {opset} based on {cl} of type={type(cl)}.")
182 else:
183 try:
184 new_cls = type(f"{name}_{opset}", (cl, ),
185 {'__init__': _Wrapper.__init__,
186 '_run': _Wrapper._onnx__run,
187 'base_class': _Wrapper.base_class,
188 'run': _Wrapper._onnx_run,
189 'need_context': _Wrapper._onnx_need_context})
190 except TypeError as e:
191 raise TypeError(
192 f"Unable to create a class for operator {name!r} and "
193 f"opset {opset} based on {cl} of type={type(cl)}.") from e
194 cl = new_cls
196 if hasattr(cl, 'version_higher_than'):
197 opv = min(current_opset, chosen_opset)
198 if cl.version_higher_than > opv:
199 # The chosen implementation does not support
200 # the opset version, we need to downgrade it.
201 if ('target_opset' in options and
202 options['target_opset'] is not None): # pragma: no cover
203 raise RuntimeError(
204 "Supported version {} > {} (opset={}) required version, "
205 "unable to find an implementation version {} found "
206 "'{}'\n--ONNX--\n{}\n--AVAILABLE--\n{}".format(
207 cl.version_higher_than, opv, opset,
208 options['target_opset'], cl.__name__, onnx_node,
209 "\n".join(
210 _ for _ in sorted(d_op_list)
211 if "_" not in _ and _ not in {'cl', 'clo', 'name'})))
212 options = options.copy()
213 options['target_opset'] = current_opset
214 return load_op(onnx_node, desc=desc, options=options)
216 if options is None:
217 options = {} # pragma: no cover
218 if onnx_op:
219 try:
220 return cl(onnx_node, {'log': None})
221 except TypeError as e:
222 raise TypeError( # pragma: no cover
223 f"Unexpected issue with class {cl}.") from e
224 try:
225 return cl(onnx_node, desc=desc, runtime=runtime, **options)
226 except TypeError as e:
227 raise TypeError( # pragma: no cover
228 f"Unexpected issue with class {cl}.") from e