Coverage for mlprodict/npy/onnx_numpy_wrapper.py: 97%
109 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Wraps :epkg:`numpy` functions into :epkg:`onnx`.
5.. versionadded:: 0.6
6"""
7import warnings
8from .onnx_version import FctVersion
9from .onnx_numpy_annotation import get_args_kwargs
10from .onnx_numpy_compiler import OnnxNumpyCompiler
class _created_classes:
    """
    Registry of the classes created on the fly by the decorators of
    this module. Every registered class is also published in the
    module namespace (`globals()`) so that instances of these dynamic
    classes can be pickled.
    """

    def __init__(self):
        # class name -> dynamically created class
        self.stored = {}

    def append(self, name, cl):
        """
        Registers class *cl* under *name* and exposes it through
        `globals()` to enable pickling of dynamic classes.
        A `RuntimeWarning` is emitted when *name* is already registered;
        the new class then replaces the previous one.

        :param name: name of the class
        :param cl: class to register
        """
        already_there = name in self.stored
        if already_there:
            known_names = ", ".join(sorted(self.stored))
            warnings.warn(  # pragma: no cover
                "Class %r overwritten in\n%r\n---\n%r" % (
                    name, known_names, cl),
                RuntimeWarning)
        self.stored[name] = cl
        globals()[name] = cl


# Module-level registry shared by the decorators defined below.
_created_classes_inst = _created_classes()
class wrapper_onnxnumpy:
    """
    Intermediate wrapper to store a pointer
    on the compiler (type: @see cl OnnxNumpyCompiler).

    :param compiled: instance of @see cl OnnxNumpyCompiler

    Dynamic subclasses may define a class attribute `__fct__` holding
    the original python function; it is used as a fallback when the
    compiled version fails on @see cl OnnxVar arguments.

    .. versionadded:: 0.6
    """

    def __init__(self, compiled):
        self.compiled = compiled

    def __call__(self, *args, **kwargs):
        """
        Calls the compiled function with arguments `args`.

        If the compiled call raises `TypeError`, `RuntimeError` or
        `ValueError` and at least one positional argument is an
        @see cl OnnxVar, the python function stored in the class
        attribute `__fct__` is called instead with the same arguments.

        :raises RuntimeError: if the compiled call fails and no
            positional argument is an @see cl OnnxVar
        """
        try:
            return self.compiled(*args, **kwargs)
        except (TypeError, RuntimeError, ValueError) as e:
            # Only needed on the failure path.
            # NOTE(review): the function-level import presumably avoids a
            # circular import; confirm before moving it to module level.
            from .onnx_variable import OnnxVar
            if any(isinstance(a, OnnxVar) for a in args):
                # Classes without `__fct__` (e.g. the ones built by
                # wrapper_onnxnumpy_np._populate) raise AttributeError here.
                return self.__class__.__fct__(  # pylint: disable=E1101
                    *args, **kwargs)
            raise RuntimeError(
                "Unable to call the compiled version, args is %r. "
                "kwargs=%r." % ([type(a) for a in args], kwargs)) from e

    def __getstate__(self):
        """
        Serializes everything but the function which generates
        the ONNX graph, not needed anymore.
        Only the compiled object is kept.
        """
        return dict(compiled=self.compiled)

    def __setstate__(self, state):
        """
        Restores the compiled object from the serialized state
        produced by `__getstate__`.
        """
        self.compiled = state['compiled']

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.
        The call is delegated to the compiled object.

        :return: ONNX graph
        """
        return self.compiled.to_onnx(**kwargs)
def onnxnumpy(op_version=None, runtime=None, signature=None):
    """
    Decorator to declare a function written with :epkg:`numpy`
    syntax but executed with :epkg:`ONNX` operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by
        @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: a decorator which compiles a function with
        @see cl OnnxNumpyCompiler and returns an instance of a
        dynamically created subclass of @see cl wrapper_onnxnumpy

    Usage: `onnxnumpy(...)(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        compiler = OnnxNumpyCompiler(
            fct, op_version=op_version, runtime=runtime,
            signature=signature)
        class_name = f"onnxnumpy_{fct.__name__}_{op_version!s}_{runtime}"
        namespace = {
            '__doc__': fct.__doc__,
            '__name__': class_name,
            # python function kept as a fallback for OnnxVar inputs
            '__fct__': fct,
        }
        dynamic_class = type(class_name, (wrapper_onnxnumpy,), namespace)
        # registered at module level so that instances can be pickled
        _created_classes_inst.append(class_name, dynamic_class)
        return dynamic_class(compiler)

    return decorator_fct
def onnxnumpy_default(fct):
    """
    Decorator declaring a function implemented with :epkg:`numpy`
    syntax but executed with :epkg:`ONNX` operators, using the
    default options of @see fn onnxnumpy (no opset, runtime or
    signature given).

    :param fct: function to wrap
    :return: result of `onnxnumpy()(fct)`

    .. versionadded:: 0.6
    """
    default_decorator = onnxnumpy()
    return default_decorator(fct)
class wrapper_onnxnumpy_np:
    """
    Intermediate wrapper to store a pointer
    on the compiler (type: @see cl OnnxNumpyCompiler)
    supporting multiple signatures.

    One @see cl wrapper_onnxnumpy instance is compiled per
    @see cl FctVersion key, on first use, and cached in
    `signed_compiled`.

    .. versionadded:: 0.6
    """

    def __init__(self, **kwargs):
        # Required keys: 'fct', 'signature'; optional: 'fctsig'.
        # 'op_version' and 'runtime' are read later by `_populate`.
        self.fct = kwargs['fct']
        self.signature = kwargs['signature']
        self.fctsig = kwargs.get('fctsig', None)
        self.args, self.kwargs = get_args_kwargs(
            self.fct,
            0 if self.signature is None else self.signature.n_optional)
        # all constructor options, kept to compile new signatures later
        self.data = kwargs
        # FctVersion -> wrapper_onnxnumpy instance
        self.signed_compiled = {}

    def __getstate__(self):
        """
        Serializes everything but the function which generates
        the ONNX graph, not needed anymore.
        """
        data_copy = {k: v for k, v in self.data.items() if k != 'fct'}
        return dict(signature=self.signature, args=self.args,
                    kwargs=self.kwargs, data=data_copy,
                    signed_compiled=self.signed_compiled)

    def __setstate__(self, state):
        """
        Restores serialized data.

        `fctsig` is not part of the serialized state; it is restored
        from `data` so that the attribute set by `__init__` exists again.
        The python function is not serialized, so a signature that was
        not compiled before serialization cannot be compiled afterwards.
        """
        for k, v in state.items():
            setattr(self, k, v)
        if not hasattr(self, 'fctsig'):
            self.fctsig = getattr(self, 'data', {}).get('fctsig', None)

    def __getitem__(self, dtype):
        """
        Returns the instance of @see cl wrapper_onnxnumpy
        mapped to *dtype*, compiling it if needed.

        :param dtype: numpy dtype corresponding to the input dtype
            of the function
        :return: instance of @see cl wrapper_onnxnumpy
        :raises TypeError: if *dtype* is not a @see cl FctVersion
        """
        if not isinstance(dtype, FctVersion):
            raise TypeError(  # pragma: no cover
                f"dtype must be of type 'FctVersion' not {type(dtype)}: {dtype}.")
        if dtype not in self.signed_compiled:
            self._populate(dtype)
        return self.signed_compiled[dtype]

    def __call__(self, *args, **kwargs):
        """
        Calls the compiled version matching the arguments.

        The @see cl FctVersion key is built from the dtype of every
        positional argument (`None` and objects exposing a `fit`
        attribute are used as they are) and from the optional keyword
        arguments, taking the value given in *kwargs* or else the one
        stored in `self.kwargs`. Only *args* are forwarded to the
        compiled version.

        :raises RuntimeError: if an `AttributeError` occurs and no
            positional argument is an @see cl OnnxVar
        """
        if len(self.kwargs) == 0:
            others = None
        else:
            others = tuple(kwargs.get(k, self.kwargs[k]) for k in self.kwargs)
        try:
            key = FctVersion(
                tuple(a if (a is None or hasattr(a, 'fit'))
                      else a.dtype.type for a in args),
                others)
            return self[key](*args)
        except AttributeError as e:
            # NOTE(review): this also catches AttributeError raised while
            # compiling or running the selected version, not only a
            # missing `dtype` on an argument.
            from .onnx_variable import OnnxVar
            if any(isinstance(a, OnnxVar) for a in args):
                return self.__class__.__fct__(  # pylint: disable=E1101
                    *args, **kwargs)
            raise RuntimeError(
                "Unable to call the compiled version, args is %r. "
                "kwargs=%r." % ([type(a) for a in args], kwargs)) from e

    def _populate(self, version):
        """
        Compiles function *fct* for signature *version* and stores the
        resulting @see cl wrapper_onnxnumpy instance in
        `signed_compiled`.

        :param version: @see cl FctVersion
        """
        compiled = OnnxNumpyCompiler(
            fct=self.data["fct"], op_version=self.data["op_version"],
            runtime=self.data["runtime"], signature=self.data["signature"],
            version=version, fctsig=self.data.get('fctsig', None))
        name = "onnxnumpy_np_%s_%s_%s_%s" % (
            self.data["fct"].__name__, str(self.data["op_version"]),
            self.data["runtime"], version.as_string())
        newclass = type(
            name, (wrapper_onnxnumpy,),
            {'__doc__': self.data["fct"].__doc__, '__name__': name})

        self.signed_compiled[version] = newclass(compiled)

    def _validate_onnx_data(self, X):
        """
        Returns *X* unchanged.
        NOTE(review): not called in this module; presumably a hook kept
        for API compatibility, confirm before removing.
        """
        return X

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.

        Without argument, the only compiled graph is returned. With
        `key=...`, the graph is searched by exact @see cl FctVersion
        key, then by argument types: `key` may be the tuple of input
        types or a single type shared by all inputs.

        :return: ONNX graph
        :raises RuntimeError: if nothing was compiled
        :raises ValueError: if the graph cannot be unambiguously selected
        """
        if len(self.signed_compiled) == 0:
            raise RuntimeError(  # pragma: no cover
                "No ONNX graph was compiled.")
        if len(kwargs) == 0:
            if len(self.signed_compiled) == 1:
                # We take the only one.
                cpl = next(iter(self.signed_compiled.values()))
                return cpl.to_onnx()
            raise ValueError(
                "There are multiple compiled ONNX graphs associated "
                "with keys %r (add key=...)." % list(self.signed_compiled))
        if list(kwargs) != ['key']:
            raise ValueError(
                f"kwargs should contain one parameter key=... but it is {kwargs!r}.")
        key = kwargs['key']
        if key in self.signed_compiled:
            return self.signed_compiled[key].compiled.onnx_
        # Match either the full tuple of input types or a single type
        # repeated for every input.
        found = [(k, v) for k, v in self.signed_compiled.items()
                 if k.args == key or k.args == (key, ) * len(k.args)]
        if len(found) == 1:
            return found[0][1].compiled.onnx_
        raise ValueError(
            "Unable to find signature with key=%r among %r found=%r." % (
                key, list(self.signed_compiled), found))
def onnxnumpy_np(op_version=None, runtime=None, signature=None):
    """
    Decorator to declare a function written with :epkg:`numpy`
    syntax but executed with :epkg:`ONNX` operators. Unlike
    @see fn onnxnumpy, one ONNX graph is compiled per input
    signature, on first use (see @see cl wrapper_onnxnumpy_np).

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: a decorator turning a function into an instance of a
        dynamically created subclass of @see cl wrapper_onnxnumpy_np

    Usage: `onnxnumpy_np(...)(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        class_name = f"onnxnumpy_nb_{fct.__name__}_{op_version!s}_{runtime}"
        namespace = {
            '__doc__': fct.__doc__,
            '__name__': class_name,
            # explicit copies of the (de)serialization hooks
            '__getstate__': wrapper_onnxnumpy_np.__getstate__,
            '__setstate__': wrapper_onnxnumpy_np.__setstate__,
            # python function kept as a fallback for OnnxVar inputs
            '__fct__': fct,
        }
        dynamic_class = type(
            class_name, (wrapper_onnxnumpy_np,), namespace)
        # registered at module level so that instances can be pickled
        _created_classes_inst.append(class_name, dynamic_class)
        return dynamic_class(
            fct=fct, op_version=op_version, runtime=runtime,
            signature=signature)

    return decorator_fct