Coverage for mlprodict/npy/onnx_numpy_annotation.py: 97%
174 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief :epkg:`numpy` annotations.
5.. versionadded:: 0.6
6"""
7import inspect
8from collections import OrderedDict
9from typing import TypeVar, Generic
10import numpy
11from .onnx_version import FctVersion
# Scalar aliases: use the numpy types when numpy exposes them, otherwise
# fall back to the matching Python builtins.
numpy_bool = getattr(numpy, "bool_", bool)
numpy_str = getattr(numpy, "str_", str)

# Type variables used to parametrize the generic NDArray annotation.
Shape = TypeVar("Shape")
DType = TypeVar("DType")


# Numeric types selected by the "all" shortcut.
all_dtypes = (
    numpy.float32, numpy.float64,
    numpy.int32, numpy.int64,
    numpy.uint32, numpy.uint64,
)
def get_args_kwargs(fct, n_optional):
    """
    Extracts arguments and optional parameters of a function.

    :param fct: function
    :param n_optional: number of arguments to consider as
        optional arguments and not parameters, the first
        *n_optional* parameters with a default value are
        returned as arguments
    :return: arguments, OrderedDict

    Parameter ``op_version`` is never returned among the optional
    parameters. If the first argument is ``self``, it is removed
    from the arguments and key ``'op_'`` (value ``None``) is added
    to the optional parameters.
    """
    empty = inspect.Parameter.empty
    # Local counter so that the caller's value is never modified.
    remaining = n_optional
    args = []
    items = []
    for name, p in inspect.signature(fct).parameters.items():
        # Identity check: ``==`` would call the default value's own
        # ``__eq__``, which may not return a plain boolean.
        if p.default is empty:
            args.append(name)
        elif remaining > 0:
            # A parameter with a default value counted as an optional input.
            args.append(name)
            remaining -= 1
        else:
            items.append((name, p))

    kwargs = OrderedDict((name, p.default) for name, p in items
                         if name != 'op_version')
    # ``args`` is empty when every parameter has a default value.
    if args and args[0] == 'self':
        args = args[1:]
        kwargs['op_'] = None
    return args, kwargs
class NDArray(numpy.ndarray, Generic[Shape, DType]):
    """
    Annotation type for ONNX numpy functions.

    Subscripting the class (``NDArray[...]``) returns an instance of
    @see cl NDArray.ShapeType holding the parameters as a tuple.

    .. versionadded:: 0.6
    """
    class ShapeType:
        "Stores shape information."

        def __init__(self, params):
            self.__args__ = params

    def __class_getitem__(cls, params):  # pylint: disable=W0221,W0237
        "Wraps the subscript parameters into a *ShapeType*."
        as_tuple = params if isinstance(params, tuple) else (params,)
        return NDArray.ShapeType(as_tuple)
class _NDArrayAlias:
    """
    Ancestor to custom signature.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    *dtypes*, *dtypes_out* by default are a tuple of tuple:

    * first dimension: type of every input
    * second dimension: list of types for one input

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None,
                 nvars=False):
        "constructor"
        if dtypes is None:
            raise ValueError("dtypes cannot be None.")  # pragma: no cover
        if isinstance(dtypes, tuple) and len(dtypes) == 0:
            raise TypeError("dtypes must not be empty.")  # pragma: no cover
        if isinstance(dtypes, tuple) and not isinstance(dtypes[0], tuple):
            # One entry per input: every non-string entry becomes a 1-tuple.
            dtypes = tuple(t if isinstance(t, str) else (t,) for t in dtypes)
        if isinstance(dtypes, str) and '_' in dtypes:
            # Shortcut "<inputs>_<outputs>" overrides *dtypes_out*.
            dtypes, dtypes_out = dtypes.split('_')
        if not isinstance(dtypes, (tuple, list)):
            dtypes = (dtypes, )

        # Maps a type name to ``(shortcut, index of the input declaring it)``.
        self.mapped_types = {}
        self.dtypes = _NDArrayAlias._process_type(
            dtypes, self.mapped_types, 0)
        if dtypes_out is None:
            # By default, the output has the same types as the first input.
            self.dtypes_out = (self.dtypes[0], )
        elif isinstance(dtypes_out, int):
            # An integer selects the input whose types the output follows.
            self.dtypes_out = (self.dtypes[dtypes_out], )
        else:
            if not isinstance(dtypes_out, (tuple, list)):
                dtypes_out = (dtypes_out, )
            self.dtypes_out = _NDArrayAlias._process_type(
                dtypes_out, self.mapped_types, 0)
        self.n_optional = 0 if n_optional is None else n_optional
        self.n_variables = nvars

        if not isinstance(self.dtypes, tuple):
            raise TypeError(  # pragma: no cover
                f"self.dtypes must be a tuple not {self.dtypes}.")
        if (len(self.dtypes) == 0 or
                not isinstance(self.dtypes[0], tuple)):
            raise TypeError(  # pragma: no cover
                f"Type mismatch in self.dtypes: {self.dtypes}.")
        if (len(self.dtypes[0]) == 0 or
                isinstance(self.dtypes[0][0], tuple)):
            raise TypeError(  # pragma: no cover
                f"Type mismatch in self.dtypes: {self.dtypes}.")

        if not isinstance(self.dtypes_out, tuple):
            raise TypeError(  # pragma: no cover
                f"self.dtypes_out must be a tuple not {self.dtypes_out}.")
        if (len(self.dtypes_out) == 0 or
                not isinstance(self.dtypes_out[0], tuple)):
            raise TypeError(  # pragma: no cover
                "Type mismatch in self.dtypes_out={}, "
                "self.dtypes={}.".format(self.dtypes_out, self.dtypes))
        if (len(self.dtypes_out[0]) == 0 or
                isinstance(self.dtypes_out[0][0], tuple)):
            raise TypeError(  # pragma: no cover
                f"Type mismatch in self.dtypes_out: {self.dtypes_out}.")

        if self.n_variables and self.n_optional > 0:
            raise RuntimeError(  # pragma: no cover
                "n_variables and n_optional cannot be positive at "
                "the same time.")

    @staticmethod
    def _process_type(dtypes, mapped_types, index):
        """
        Nicknames such as `floats`, `int`, `ints`, `all`
        can be used to describe multiple inputs for
        a signature. This function interprets that.

        :param dtypes: a shortcut (string), a type or a tuple or list
            of those; a string ``name:shortcut`` registers *name*
            in *mapped_types*
        :param mapped_types: dictionary of named types, modified inplace
        :param index: index stored with a newly registered name
        :return: tuple of types, or the type itself if *dtypes* is
            one of the types in ``all_dtypes``

        .. runpython::
            :showcode:

            from mlprodict.npy.onnx_numpy_annotation import _NDArrayAlias
            for name in ['all', 'int', 'ints', 'floats', 'T']:
                print(name, _NDArrayAlias._process_type(name, {'T': 0}, 0))
        """
        if isinstance(dtypes, str):
            if ":" in dtypes:
                name, dtypes = dtypes.split(':')
                if name in mapped_types:
                    # Registered values are (shortcut, index): only the
                    # shortcut is compared so that a consistent
                    # redeclaration is accepted.
                    known = mapped_types[name]
                    known_dtypes = (
                        known[0] if isinstance(known, tuple) else known)
                    if dtypes != known_dtypes:
                        raise RuntimeError(  # pragma: no cover
                            "Type name mismatch for '%s:%s' in %r." % (
                                name, dtypes, list(sorted(mapped_types))))
                else:
                    # The first declaration defines the input index.
                    mapped_types[name] = (dtypes, index)
            if dtypes == "all":
                dtypes = all_dtypes
            elif dtypes in ("int", "int64"):
                dtypes = (numpy.int64, )
            elif dtypes == "bool":
                dtypes = (numpy_bool, )
            elif dtypes == "floats":
                dtypes = (numpy.float32, numpy.float64)
            elif dtypes == "ints":
                dtypes = (numpy.int32, numpy.int64)
            elif dtypes == "float32":
                dtypes = (numpy.float32, )
            elif dtypes == "float64":
                dtypes = (numpy.float64, )
            elif dtypes not in mapped_types:
                raise ValueError(  # pragma: no cover
                    f"Unexpected shortcut for dtype {dtypes!r}.")
            elif not isinstance(dtypes, tuple):
                # A registered name is kept as is, wrapped into a tuple.
                dtypes = (dtypes, )
            return dtypes

        if isinstance(dtypes, (tuple, list)):
            return tuple(
                _NDArrayAlias._process_type(dt, mapped_types, index + d)
                for d, dt in enumerate(dtypes))

        if dtypes in all_dtypes:
            return dtypes

        raise NotImplementedError(  # pragma: no cover
            f"Unexpected input dtype {dtypes!r}.")

    def __repr__(self):
        "usual"
        return "%s(%r, %r, %r)" % (
            self.__class__.__name__, self.dtypes, self.dtypes_out,
            self.n_optional)

    def _get_output_types(self, key):
        """
        Tries to infer output types.

        :param key: sequence of input types
        :return: tuple with one type per output
        :raises TypeError: if an output description is not a tuple
        :raises RuntimeError: if an output type cannot be guessed
        """
        res = []
        for i, o in enumerate(self.dtypes_out):
            if not isinstance(o, tuple):
                raise TypeError(  # pragma: no cover
                    "All outputs must be tuple, output %d is %r."
                    "" % (i, o))
            if (len(o) == 1 and (o[0] in all_dtypes or
                                 o[0] in (bool, numpy_bool, str, numpy_str))):
                # The output type is fixed.
                res.append(o[0])
            elif len(o) == 1 and o[0] in self.mapped_types:
                # Named type: same type as the input which declared it.
                info = self.mapped_types[o[0]]
                res.append(key[info[1]])
            elif len(key) > 0 and key[0] in o:
                # Otherwise the output follows the first input.
                res.append(key[0])
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to guess output type for output %d, "
                    "input types are %r, expected output is %r."
                    "" % (i, key, o))
        return tuple(res)

    def get_inputs_outputs(self, args, kwargs, version):
        """
        Returns the list of inputs, outputs.

        :param args: list of arguments
        :param kwargs: list of optional arguments
        :param version: required version
        :return: *tuple(inputs, kwargs, outputs, optional, n_variables)*,
            inputs and outputs are lists of tuples *(name, type)*,
            kwargs are the arguments as received,
            *optional* is the number of optional arguments,
            *n_variables* is the value of parameter *nvars*
            given to the constructor
        """
        if not isinstance(version, FctVersion):
            raise TypeError("Version must be of type 'FctVersion' not "
                            "%s, version=%s." % (type(version), version))
        if args == ['args', 'kwargs']:
            raise RuntimeError(  # pragma: no cover
                f"Issue with signature {args!r}.")
        for k, v in kwargs.items():
            if isinstance(v, type):
                raise RuntimeError(  # pragma: no cover
                    f"Default value for argument {k!r} must not be of type {v!r}.")
        if (not self.n_variables and
                len(args) > len(self.dtypes)):
            raise RuntimeError(
                "Unexpected number of inputs version=%s.\n"
                "Given: args=%s dtypes=%s." % (
                    version, args, self.dtypes))

        def _possible_names():
            yield 'y'
            yield 'z'  # pragma: no cover
            yield 'o'  # pragma: no cover
            for i in range(0, 10000):  # pragma: no cover
                yield 'o%d' % i

        # Only used to check the number of inputs and parameters below.
        new_kwargs = OrderedDict(zip(kwargs, version.kwargs or tuple()))
        if self.n_variables:
            # undefined number of inputs
            optional = 0
        else:
            optional = len(self.dtypes) - len(version.args)
            if optional > self.n_optional:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of optional parameters %d, at most "
                    "%d are expected, version=%s, args=%s, dtypes=%s." % (
                        optional, self.n_optional, version, args, self.dtypes))
            optional = self.n_optional - optional

        onnx_types = list(version.args)
        inputs = list(zip(args[:len(version.args)], onnx_types))
        if self.n_variables and len(inputs) < len(version.args):
            # Complete the list of inputs
            last_name = inputs[-1][0]
            while len(inputs) < len(onnx_types):
                inputs.append((f'{last_name}{len(inputs)}',
                               onnx_types[len(inputs)]))

        key_out = self._get_output_types(version.args)
        onnx_types_out = key_out

        names_out = []
        names_in = {inp[0] for inp in inputs}
        for _ in key_out:
            for name in _possible_names():
                if name not in names_in and name not in names_out:
                    name_out = name
                    break
            names_out.append(name_out)
            # Reserves the name so that the next output gets another one.
            names_in.add(name_out)

        outputs = list(zip(names_out, onnx_types_out))
        if optional < 0:
            raise RuntimeError(  # pragma: no cover
                "optional cannot be negative %r (self.n_optional=%r, "
                "len(self.dtypes)=%r, len(inputs)=%r) "
                "names_in=%r, names_out=%r." % (
                    optional, self.n_optional, len(self.dtypes),
                    len(inputs), names_in, names_out))

        if (not self.n_variables and
                len(inputs) + len(new_kwargs) > len(version)):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of inputs and arguments for version=%s.\n"
                "Given: args=%s kwargs=%s.\n"
                "Returned: inputs=%s new_kwargs=%s.\n" % (
                    version, args, kwargs, inputs, new_kwargs))
        if not self.n_variables and len(inputs) > len(self.dtypes):
            raise RuntimeError(  # pragma: no cover
                "Mismatch number of inputs for version=%s.\n"
                "Given: args=%s.\n"
                "Expected: dtypes=%s\n"
                "Returned: inputs=%s.\n" % (
                    version, args, self.dtypes, inputs))

        return inputs, kwargs, outputs, optional, self.n_variables

    def shape_calculator(self, dims):
        """
        Returns expected dimensions given the input dimensions.

        :param dims: input dimensions
        :return: None if *dims* is empty, otherwise a list with the
            first dimension followed by None for every other dimension
        """
        if len(dims) == 0:
            return None
        return [dims[0]] + [None] * (len(dims) - 1)
class NDArrayType(_NDArrayAlias):
    """
    Shortcut to simplify signature description,
    every parameter is forwarded to the parent constructor.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        forwarded = dict(dtypes=dtypes, dtypes_out=dtypes_out,
                         n_optional=n_optional, nvars=nvars)
        _NDArrayAlias.__init__(self, **forwarded)
class NDArrayTypeSameShape(NDArrayType):
    """
    Shortcut to simplify signature description,
    every parameter is forwarded to the parent constructor.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        forwarded = dict(dtypes=dtypes, dtypes_out=dtypes_out,
                         n_optional=n_optional, nvars=nvars)
        NDArrayType.__init__(self, **forwarded)
class NDArraySameType(NDArrayType):
    """
    Shortcut to simplify signature description,
    the signature is built from one single type description.

    :param dtypes: input dtypes, it cannot be None, a tuple or
        a string containing ``'_'``

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        # Collects the reason for rejecting *dtypes*, if any.
        issue = None
        if dtypes is None:
            issue = "dtypes cannot be None."  # pragma: no cover
        elif isinstance(dtypes, str):
            if "_" in dtypes:
                issue = (  # pragma: no cover
                    "dtypes cannot include '_' meaning two different types.")
        elif isinstance(dtypes, tuple):
            issue = "dtypes must be a single type."  # pragma: no cover
        if issue is not None:
            raise ValueError(issue)  # pragma: no cover
        NDArrayType.__init__(self, dtypes=(dtypes, ))

    def __repr__(self):
        "usual"
        return "{}({!r})".format(self.__class__.__name__, self.dtypes)
class NDArraySameTypeSameShape(NDArraySameType):
    """
    Shortcut to simplify signature description,
    the parameter is forwarded to the parent constructor.

    :param dtypes: input dtypes

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        forwarded = dict(dtypes=dtypes)
        NDArraySameType.__init__(self, **forwarded)