Coverage for mlprodict/onnx_tools/onnx2py_helper.py: 92%
475 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Functions which converts :epkg:`ONNX` object into
4readable :epkg:`python` objects.
5"""
6import pprint
7import warnings
8import numpy
9from scipy.sparse import coo_matrix
10from onnx.defs import get_schema, get_function_ops, onnx_opset_version
11from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE, TENSOR_TYPE_TO_NP_TYPE
12from onnx import TensorProto, ValueInfoProto, TypeProto, TensorShapeProto
13from onnx.helper import make_tensor_type_proto
14from onnx.numpy_helper import to_array, from_array as onnx_from_array
def get_tensor_shape(obj):
    """
    Returns the shape if that makes sense for this object.
    """
    if isinstance(obj, ValueInfoProto):
        # unwrap the ValueInfoProto and look at its TypeProto
        return get_tensor_shape(obj.type)
    if not isinstance(obj, TypeProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(obj)!r}.")
    if not obj.tensor_type.HasField('shape'):
        return None
    # a dimension is either a fixed integer (dim_value) or a symbol (dim_param)
    dims = [d.dim_value if d.dim_value > 0 else d.dim_param
            for d in obj.tensor_type.shape.dim]
    if not dims:
        return dims
    # 0 or '' means the dimension is unknown
    return [None if d in (0, '') else d for d in dims]
def get_tensor_elem_type(obj):
    """
    Returns the element type if that makes sense for this object.
    """
    if isinstance(obj, ValueInfoProto):
        # unwrap the ValueInfoProto and look at its TypeProto
        return get_tensor_elem_type(obj.type)
    if not isinstance(obj, TypeProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(obj)!r}.")
    if obj.tensor_type.ByteSize() == 0:
        # an empty tensor_type carries no element type information
        raise TypeError(  # pragma: no cover
            f"Unable to guess element type for {obj!r}.")
    return obj.tensor_type.elem_type
def to_bytes(val):
    """
    Converts an array into protobuf and then into bytes.
    Useful to serialize an array, the reverse operation
    is implemented by :func:`from_bytes`.

    :param val: array (or an already-built protobuf tensor)
    :return: bytes
    """
    # numpy arrays go through from_array first,
    # anything else is assumed to already be a protobuf message
    proto = from_array(val) if isinstance(val, numpy.ndarray) else val
    return proto.SerializeToString()
def from_array(value, name=None):
    """
    Converts an array into an ONNX tensor.

    :param value: numpy array
    :param name: optional tensor name
    :return: ONNX tensor
    """
    if isinstance(value, TensorProto):  # pragma: no cover
        # already converted
        return value
    if not isinstance(value, numpy.ndarray):
        raise NotImplementedError(  # pragma: no cover
            f"Unable to convert type {type(value)!r} into an ONNX tensor.")
    try:
        return onnx_from_array(value, name=name)
    except NotImplementedError as e:  # pragma: no cover
        # onnx does not handle dtype object; strings are stored manually
        if value.dtype != numpy.dtype('O'):
            raise NotImplementedError(
                "Unable to convert type %r (dtype=%r) into an ONNX tensor "
                "due to %r." % (type(value), value.dtype, e)) from e
        tensor = TensorProto()
        tensor.data_type = TensorProto.STRING  # pylint: disable=E1101
        if name is not None:
            tensor.name = name
        tensor.dims.extend(value.shape)  # pylint: disable=E1101
        tensor.string_data.extend(  # pylint: disable=E1101
            [str(item).encode('utf-8') for item in value.ravel()])
        return tensor
def from_bytes(b):
    """
    Retrieves an array from bytes then protobuf.
    Reverse operation of :func:`to_bytes`.

    :param b: bytes (or an already-parsed protobuf tensor)
    :return: array
    """
    if not isinstance(b, bytes):  # pragma: no cover
        # assumed to already be a protobuf message
        return to_array(b)
    proto = TensorProto()
    proto.ParseFromString(b)
    return to_array(proto)
144def _numpy_array(data, dtype=None, copy=True):
145 """
146 Single function to create an array.
148 @param data data
149 @param dtype dtype
150 @param copy copy
151 @return numpy array
152 """
153 if isinstance(data, numpy.ndarray):
154 res = data
155 else:
156 res = numpy.array(data, dtype=dtype, copy=copy)
157 return res
160def _sparse_array(shape, data, indices, dtype=None, copy=True):
161 """
162 Single function to create an sparse array
163 (:epkg:`coo_matrix`).
165 @param shape shape
166 @param data data
167 @param indices indices
168 @param dtype dtype
169 @param copy copy
170 @return :epkg:`coo_matrix`
171 """
172 if len(shape) != 2:
173 raise ValueError( # pragma: no cover
174 f"Only matrices are allowed or sparse matrices but shape is {shape}.")
175 rows = numpy.array([i // shape[1] for i in indices])
176 cols = numpy.array([i % shape[1] for i in indices])
177 if isinstance(data, numpy.ndarray):
178 res = coo_matrix((data, (rows, cols)), dtype=dtype)
179 else:
180 res = coo_matrix( # pragma: no cover
181 (numpy.array(data, dtype=dtype, copy=copy),
182 (rows, cols)), dtype=dtype)
183 return res
def guess_numpy_type_from_string(name):
    """
    Converts a string (such as `'float'`) into a
    numpy dtype.
    """
    # name -> numpy scalar type; 'float' and 'double' follow ONNX naming
    mapping = {
        'float': numpy.float32,
        'float32': numpy.float32,
        'double': numpy.float64,
        'float64': numpy.float64,
        'float16': numpy.float16,
        'int64': numpy.int64,
        'int32': numpy.int32,
        'int16': numpy.int16,
        'int8': numpy.int8,
        'uint8': numpy.uint8,
        'bool': numpy.bool_,
        'str': numpy.str_,
    }
    if name in mapping:
        return mapping[name]
    raise ValueError(  # pragma: no cover
        f"Unable to guess numpy dtype from {name!r}.")
def guess_numpy_type_from_dtype(dt):
    """
    Converts a dtype (such as `numpy.dtype('float32')`) into a
    numpy scalar type.

    :param dt: numpy scalar type or dtype
    :return: numpy scalar type (or the input itself when it is
        already one of the supported scalar types)
    :raises ValueError: if the dtype cannot be mapped
    """
    supported = {numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
                 numpy.float64, numpy.int32, numpy.int64, numpy.int16,
                 numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
                 numpy.uint64, bool, str}
    if dt in supported:
        # already a scalar type, returned unchanged (including bool/str)
        return dt
    # normalize dtype objects (numpy.dtype('int32'), ...) to their scalar
    # type; this generalizes the previous explicit list which only covered
    # float32, float64, int64, int8 and uint8
    try:
        dtt = numpy.dtype(dt)
    except TypeError as e:
        raise ValueError(
            f"Unable to guess numpy dtype from {dt!r}.") from e
    if dtt.type in supported:
        return dtt.type
    raise ValueError(  # pragma: no cover
        f"Unable to guess numpy dtype from {dt!r}.")
def _elem_type_as_str(elem_type):
    """
    Converts an element type into a readable string (for atomic
    integer codes) or a dictionary (for composite ``TypeProto``
    fields such as tensors, optionals, maps).
    """
    # atomic TensorProto codes -> readable names (0 means undefined)
    atomic = {
        TensorProto.FLOAT: 'float',  # pylint: disable=E1101
        TensorProto.BOOL: 'bool',  # pylint: disable=E1101
        TensorProto.DOUBLE: 'double',  # pylint: disable=E1101
        TensorProto.STRING: 'str',  # pylint: disable=E1101
        TensorProto.INT64: 'int64',  # pylint: disable=E1101
        TensorProto.INT32: 'int32',  # pylint: disable=E1101
        TensorProto.UINT32: 'uint32',  # pylint: disable=E1101
        TensorProto.UINT64: 'uint64',  # pylint: disable=E1101
        TensorProto.INT16: 'int16',  # pylint: disable=E1101
        TensorProto.UINT16: 'uint16',  # pylint: disable=E1101
        TensorProto.UINT8: 'uint8',  # pylint: disable=E1101
        TensorProto.INT8: 'int8',  # pylint: disable=E1101
        TensorProto.FLOAT16: 'float16',  # pylint: disable=E1101
        TensorProto.COMPLEX64: 'complex64',  # pylint: disable=E1101
        TensorProto.COMPLEX128: 'complex128',  # pylint: disable=E1101
        0: 'unk',
    }
    if isinstance(elem_type, int) and elem_type in atomic:
        return atomic[elem_type]

    # The following code should be refactored.
    selem = str(elem_type)
    if selem.startswith("tensor_type"):
        sub = elem_type.tensor_type
        return {'kind': 'tensor',
                'elem': _elem_type_as_str(sub.elem_type),
                'shape': sub.shape}

    if selem.startswith("optional_type"):
        sub = elem_type.optional_type
        return {'kind': 'tensor',
                'elem': _elem_type_as_str(sub.elem_type),
                'shape': sub.shape,
                'optional_type': True}

    if selem.startswith("map_type"):
        sub = elem_type.map_type
        return {'kind': 'map',
                'key': _elem_type_as_str(sub.key_type),
                'value': _elem_type_as_str(sub.value_type)}

    raise NotImplementedError(  # pragma: no cover
        "elem_type '{}' is unknown\nfields:\n{}\n-----\n{}.".format(
            elem_type, pprint.pformat(dir(elem_type)), type(elem_type)))
def _to_array(var):
    """
    Converts a ``TensorProto`` into a numpy array.

    Tries :func:`onnx.numpy_helper.to_array` first and, when it raises
    ``ValueError``, rebuilds the array from the typed proto fields
    (``float_data``, ``int32_data``, ...) selected by ``data_type``.
    """
    try:
        data = to_array(var)
    except ValueError as e:  # pragma: no cover
        dims = [d for d in var.dims]
        # NOTE(review): protobuf repeated fields are never None, so the
        # `is not None` guards below always hold; branch selection is
        # effectively driven by `data_type` alone -- confirm before
        # simplifying.
        if var.data_type == 1 and var.float_data is not None:
            # data_type 1 = FLOAT; falls back to to_array if the raw
            # field cannot be reshaped (e.g. data stored elsewhere)
            try:
                data = _numpy_array(var.float_data, dtype=numpy.float32,
                                    copy=False).reshape(dims)
            except ValueError:
                data = _numpy_array(to_array(var))
        elif var.data_type == 2 and var.uint8_data is not None:
            # data_type 2 = UINT8
            data = _numpy_array(var.uint8_data, dtype=numpy.uint8,
                                copy=False).reshape(dims)
        elif var.data_type == 3 and var.int8_data is not None:
            # data_type 3 = INT8
            data = _numpy_array(var.int8_data, dtype=numpy.int8,
                                copy=False).reshape(dims)
        elif var.data_type == 4 and var.uint16_data is not None:
            # data_type 4 = UINT16
            data = _numpy_array(var.uint16_data, dtype=numpy.uint16,
                                copy=False).reshape(dims)
        elif var.data_type == 5 and var.int16_data is not None:
            # data_type 5 = INT16
            data = _numpy_array(var.int16_data, dtype=numpy.int16,
                                copy=False).reshape(dims)
        elif var.data_type == 6 and var.int32_data is not None:
            # data_type 6 = INT32
            data = _numpy_array(var.int32_data, dtype=numpy.int32,
                                copy=False).reshape(dims)
        elif var.data_type == 7 and var.int64_data is not None:
            # data_type 7 = INT64
            data = _numpy_array(var.int64_data, dtype=numpy.int64,
                                copy=False).reshape(dims)
        elif var.data_type == 11 and var.double_data is not None:
            # data_type 11 = DOUBLE; same to_array fallback as FLOAT
            try:
                data = _numpy_array(var.double_data, dtype=numpy.float64,
                                    copy=False).reshape(dims)
            except ValueError:
                data = _numpy_array(to_array(var))
        elif var.data_type == 16 and var.float16_data is not None:
            # NOTE(review): in ONNX, data_type 16 is BFLOAT16 and FLOAT16
            # is 10; TensorProto also has no `float16_data` field in the
            # upstream schema -- confirm this branch is reachable.
            data = _numpy_array(var.float16_data, dtype=numpy.float16,
                                copy=False).reshape(dims)
        else:
            raise NotImplementedError(
                f"Iniatilizer {var} cannot be converted into a dictionary.") from e
    return data
def _var_as_dict(var):  # pylint: disable=R0912
    """
    Converts a protobuf object into something readable.
    The current implementation relies on :epkg:`json`.
    That's not the most efficient way.

    Handles several protobuf flavours by duck typing: typed variables
    and attributes (objects with a ``type`` field), nodes (``op_type``),
    initializers (``dims``/``data_type``), plain strings,
    ``ValueInfoProto``, ``TensorShapeProto``, ``TypeProto`` and the
    ``Tensor``/``Optional`` sub-messages of ``TypeProto``.
    """
    if hasattr(var, 'type') and str(var.type) != '':
        # variable (or attribute): first build a `dtype` description,
        # then extract the value when one is stored on the object
        if var.type is not None:
            if hasattr(var, 'sparse_tensor') and var.type == 11:
                # sparse tensor (var.type is an integer code here,
                # e.g. for an AttributeProto)
                t = var.sparse_tensor
                values = _var_as_dict(t.values)
                dims = list(t.dims)
                dtype = dict(kind='sparse_tensor', shape=tuple(dims), elem=1)
            elif (hasattr(var.type, 'tensor_type') and
                    var.type.tensor_type.elem_type > 0):
                # regular typed tensor
                t = var.type.tensor_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims))
            elif (hasattr(var.type, 'optional_type') and
                    var.type.tensor_type.elem_type > 0):
                # NOTE(review): the guard tests tensor_type.elem_type but
                # the branch reads optional_type; the previous branch has
                # the same guard, which makes this one unreachable when
                # tensor_type is filled -- confirm the intended condition.
                t = var.type.optional_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims), optional_type=True)
            elif (hasattr(var.type, 'real') and var.type.real == 5 and
                    hasattr(var, 'g')):
                # attribute holding a graph
                dtype = dict(kind='graph', elem=var.type.real)
            elif (hasattr(var.type, 'real') and var.type.real == 4 and
                    hasattr(var, 't')):
                # attribute holding a tensor
                dtype = dict(kind='tensor', elem=var.type.real)
            elif hasattr(var.type, 'real'):
                # any other scalar attribute kind
                dtype = dict(kind='real', elem=var.type.real)
            elif (hasattr(var.type, "sequence_type") and
                    var.type.sequence_type is not None and
                    str(var.type.sequence_type.elem_type) != ''):
                # sequence type
                t = var.type.sequence_type
                elem_type = _elem_type_as_str(t.elem_type)
                dtype = dict(kind='sequence', elem=elem_type)
            elif (hasattr(var.type, "map_type") and
                    var.type.map_type is not None and
                    str(var.type.map_type.key_type) != '' and
                    str(var.type.map_type.value_type) != ''):
                # map type
                t = var.type.map_type
                key_type = _elem_type_as_str(t.key_type)
                value_type = _elem_type_as_str(t.value_type)
                dtype = dict(kind='map', key=key_type, value=value_type)
            elif (hasattr(var.type, 'tensor_type') and
                    var.type.tensor_type.elem_type == 0):
                # tensor without a declared element type;
                # optional_type, when present, is described recursively
                if hasattr(var.type, 'optional_type'):
                    optional = var.type.optional_type
                else:
                    optional = None
                t = var.type.tensor_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims))
                if optional is not None:
                    dtype['optional'] = _var_as_dict(optional)
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Unable to convert a type into a dictionary for '{}'. "
                    "Available fields: {}.".format(
                        var.type, pprint.pformat(dir(var.type))))
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to convert variable into a dictionary for '{}'. "
                "Available fields: {}.".format(
                    var, pprint.pformat(dir(var.type))))

        res = dict(name=var.name, type=dtype)

        # second stage: extract the stored value, dispatching on the
        # attribute kind recorded in dtype['elem']
        if (hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 1 and
                dtype['kind'] == 'sparse_tensor'):
            # sparse matrix
            t = var.sparse_tensor
            try:
                values = _var_as_dict(t.values)
            except NotImplementedError as e:  # pragma: no cover
                raise NotImplementedError(
                    f"Issue with\n{var}\n---") from e
            indices = _var_as_dict(t.indices)
            res['value'] = _sparse_array(
                dtype['shape'], values['value'], indices['value'], dtype=numpy.float32)
        elif hasattr(var, 'floats') and dtype.get('elem', None) == 6:
            # list of floats
            res['value'] = _numpy_array(var.floats, dtype=numpy.float32)
        elif hasattr(var, 'strings') and dtype.get('elem', None) == 8:
            # list of strings
            res['value'] = _numpy_array(var.strings)
        elif hasattr(var, 'ints') and dtype.get('elem', None) == 7:
            # list of ints
            res['value'] = _numpy_array(var.ints)
        elif hasattr(var, 'f') and dtype.get('elem', None) == 1:
            # single float
            res['value'] = var.f
        elif hasattr(var, 's') and dtype.get('elem', None) == 3:
            # single string (bytes)
            res['value'] = var.s
        elif hasattr(var, 'i') and dtype.get('elem', None) == 2:
            # single int
            res['value'] = var.i
        elif hasattr(var, 'g') and dtype.get('elem', None) == 5:
            # subgraph
            res['value'] = var.g
        elif hasattr(var, 't') and dtype.get('elem', None) == 4:
            # tensor attribute, possibly a reference to a function attribute
            if hasattr(var, 'ref_attr_name') and var.ref_attr_name:
                res['ref_attr_name'] = var.ref_attr_name
            else:
                ts = _var_as_dict(var.t)
                res['value'] = ts['value']
        elif hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 11:
            # NOTE(review): `ts` built from var.sparse_tensor is discarded
            # and recomputed from var.t in the else branch below -- looks
            # suspicious, confirm against the sparse attribute layout.
            ts = _var_as_dict(var.sparse_tensor)
            if hasattr(var, 'ref_attr_name') and var.ref_attr_name:
                res['ref_attr_name'] = var.ref_attr_name
            else:
                ts = _var_as_dict(var.t)
                res['value'] = ts['value']
        elif "'value'" in str(var):
            warnings.warn("No value: {} -- {}".format(  # pragma: no cover
                dtype, str(var).replace("\n", "").replace(" ", "")))
        return res

    if hasattr(var, 'op_type'):
        # node: convert every attribute recursively
        if hasattr(var, 'attribute'):
            atts = {}
            for att in var.attribute:
                atts[att.name] = _var_as_dict(att)
            return dict(name=var.name, op_type=var.op_type,
                        domain=var.domain, atts=atts)
    if hasattr(var, 'dims') and len(var.dims) > 0:
        # initializer
        data = _to_array(var)
        return dict(name=var.name, value=data)
    if hasattr(var, 'data_type') and var.data_type > 0:
        # scalar initializer (no dims)
        data = _to_array(var)
        return dict(name=var.name, value=data)
    if isinstance(var, str):
        # plain name
        return dict(name=var)
    if str(var) == '':
        # empty protobuf message
        return None
    if isinstance(var, ValueInfoProto):
        # untyped variable
        return dict(name=var.name,
                    type=dict(elem='unk', kind='tensor', shape=('?', )))
    if isinstance(var, TensorShapeProto):
        # shape: keep dim_value/dim_param only when they are set
        ds = []
        for dim in var.dim:
            d = {}
            if dim.dim_value:
                d['dim_value'] = dim.dim_value
            if dim.dim_param:
                d['dim_param'] = dim.dim_param
            ds.append(d)
        return dict(dim=ds)
    if isinstance(var, TypeProto):
        # generic type: convert every *_type field recursively
        d = dict(denotation=var.denotation)
        for n in dir(var):
            if n.endswith('_type'):
                at = getattr(var, n)
                d[n] = _var_as_dict(at)
        return d
    if var.__class__.__name__ == "Tensor":
        return dict(elem_type=var.elem_type, shape=_var_as_dict(var.shape))
    if var.__class__.__name__ == "Optional":
        return dict(optional=True, elem_type=_var_as_dict(var.elem_type))

    raise NotImplementedError(  # pragma: no cover
        "Unable to guess which object it is type is %r value is %r "
        "(hasattr(var,'type')=%r, var.type=%s\n%s"
        "" % (type(var), str(var), hasattr(var, 'type'),
              str(getattr(var, 'type', None)),
              '\n'.join(dir(var))))
def get_dtype_shape(obj):
    """
    Returns the element type and the shape of a tensor.

    :param obj: onnx object exposing ``type.tensor_type``
    :return: ``(dtype, shape)``, ``(dtype, None)`` when the tensor type
        carries no shape, or ``None`` when no tensor type is available
    """
    if not hasattr(obj, 'type'):
        return None
    tensor = getattr(obj.type, 'tensor_type', None)
    if tensor is None:
        return None
    elem = tensor.elem_type
    if not hasattr(tensor, 'shape'):
        return elem, None
    dims = []
    for d in tensor.shape.dim:
        if d.dim_value != 0:
            # fixed dimension
            dims.append(d.dim_value)
        elif d.dim_param != '':
            # symbolic dimension
            dims.append(d.dim_param)
        else:
            # unknown dimension
            dims.append(None)
    return elem, tuple(dims)
def onnx_model_opsets(onnx_model):
    """
    Extracts opsets in a dictionary.

    :param onnx_model: ONNX graph
    :return: dictionary `{domain: version}`
    """
    return {imp.domain: imp.version for imp in onnx_model.opset_import}
580def _type_to_string(dtype):
581 """
582 Converts a type into a readable string.
583 """
584 if not isinstance(dtype, dict):
585 dtype_ = _var_as_dict(dtype) # pragma: no cover
586 else:
587 dtype_ = dtype
588 if dtype_["kind"] == 'tensor':
589 return f"{dtype_['elem']}({dtype_['shape']})"
590 if dtype_['kind'] == 'sequence':
591 return f"[{_type_to_string(dtype_['elem'])}]"
592 if dtype_["kind"] == 'map':
593 return f"{{{dtype_['key']}, {dtype_['value']}}}"
594 raise NotImplementedError( # pragma: no cover
595 f"Unable to convert into string {dtype} or {dtype_}.")
def numpy_min(x):
    """
    Returns the minimum of an array.
    Deals with text as well.
    """
    try:
        if hasattr(x, 'todense'):
            # sparse input: densify before reduction
            x = x.todense()
        if x.dtype.kind not in 'cUC':
            # numeric array, numpy handles it directly
            return x.min()
        try:
            flat = x.ravel()
        except AttributeError:  # pragma: no cover
            flat = x
        strings = [s for s in flat if isinstance(s, str)]
        if not strings:  # pragma: no cover
            return numpy.nan
        smallest = min(strings)
        if len(smallest) > 10:  # pragma: no cover
            # long strings are truncated for display
            smallest = smallest[:10] + '...'
        return f"{smallest!r}"
    except (ValueError, TypeError):  # pragma: no cover
        return '?'
def numpy_max(x):
    """
    Returns the maximum of an array.
    Deals with text as well.
    """
    try:
        if hasattr(x, 'todense'):
            # sparse input: densify before reduction
            x = x.todense()
        if x.dtype.kind not in 'cUC':
            # numeric array, numpy handles it directly
            return x.max()
        try:
            flat = x.ravel()
        except AttributeError:  # pragma: no cover
            flat = x
        strings = [s for s in flat if isinstance(s, str)]
        if not strings:  # pragma: no cover
            return numpy.nan
        largest = max(strings)
        if len(largest) > 10:  # pragma: no cover
            # long strings are truncated for display
            largest = largest[:10] + '...'
        return f"{largest!r}"
    except (ValueError, TypeError):  # pragma: no cover
        return '?'
def guess_proto_dtype(dtype):
    """
    Guesses the ONNX dtype given a numpy dtype.

    :param dtype: numpy dtype
    :return: proto type
    """
    # comparison relies on `==` (not hashing) so that numpy dtype objects
    # match their scalar types; the order mirrors the original checks
    conversions = (
        (numpy.float32, TensorProto.FLOAT),  # pylint: disable=E1101
        (numpy.float64, TensorProto.DOUBLE),  # pylint: disable=E1101
        (numpy.int64, TensorProto.INT64),  # pylint: disable=E1101
        (numpy.int32, TensorProto.INT32),  # pylint: disable=E1101
        (numpy.int16, TensorProto.INT16),  # pylint: disable=E1101
        (numpy.int8, TensorProto.INT8),  # pylint: disable=E1101
        (numpy.uint64, TensorProto.UINT64),  # pylint: disable=E1101
        (numpy.uint32, TensorProto.UINT32),  # pylint: disable=E1101
        (numpy.uint16, TensorProto.UINT16),  # pylint: disable=E1101
        (numpy.uint8, TensorProto.UINT8),  # pylint: disable=E1101
        (numpy.float16, TensorProto.FLOAT16),  # pylint: disable=E1101
        (bool, TensorProto.BOOL),  # pylint: disable=E1101
        (numpy.bool_, TensorProto.BOOL),  # pylint: disable=E1101
        (str, TensorProto.STRING),  # pylint: disable=E1101
        (numpy.str_, TensorProto.STRING),  # pylint: disable=E1101
    )
    for np_dtype, proto_dtype in conversions:
        if dtype == np_dtype:
            return proto_dtype
    raise RuntimeError(
        f"Unable to guess type for dtype={dtype}.")  # pragma: no cover
def guess_proto_dtype_name(onnx_dtype):
    """
    Returns a string equivalent to `onnx_dtype`.

    :param onnx_dtype: onnx dtype
    :return: string naming the proto type
    """
    # proto codes are plain integers, so a dict lookup is safe here
    names = {
        TensorProto.FLOAT: "TensorProto.FLOAT",  # pylint: disable=E1101
        TensorProto.DOUBLE: "TensorProto.DOUBLE",  # pylint: disable=E1101
        TensorProto.INT64: "TensorProto.INT64",  # pylint: disable=E1101
        TensorProto.INT32: "TensorProto.INT32",  # pylint: disable=E1101
        TensorProto.INT16: "TensorProto.INT16",  # pylint: disable=E1101
        TensorProto.UINT8: "TensorProto.UINT8",  # pylint: disable=E1101
        TensorProto.FLOAT16: "TensorProto.FLOAT16",  # pylint: disable=E1101
        TensorProto.BFLOAT16: "TensorProto.BFLOAT16",  # pylint: disable=E1101
        TensorProto.BOOL: "TensorProto.BOOL",  # pylint: disable=E1101
        TensorProto.STRING: "TensorProto.STRING",  # pylint: disable=E1101
    }
    if onnx_dtype in names:
        return names[onnx_dtype]
    raise RuntimeError(  # pragma: no cover
        f"Unable to guess type for dtype={onnx_dtype}.")
def guess_dtype(proto_type):
    """
    Converts a proto type into a :epkg:`numpy` type.

    :param proto_type: example ``onnx.TensorProto.FLOAT``
    :return: :epkg:`numpy` dtype
    """
    # proto codes are plain integers, so a dict lookup is safe here
    proto_to_numpy = {
        TensorProto.FLOAT: numpy.float32,  # pylint: disable=E1101
        TensorProto.BOOL: numpy.bool_,  # pylint: disable=E1101
        TensorProto.DOUBLE: numpy.float64,  # pylint: disable=E1101
        TensorProto.STRING: numpy.str_,  # pylint: disable=E1101
        TensorProto.INT64: numpy.int64,  # pylint: disable=E1101
        TensorProto.INT32: numpy.int32,  # pylint: disable=E1101
        TensorProto.INT8: numpy.int8,  # pylint: disable=E1101
        TensorProto.INT16: numpy.int16,  # pylint: disable=E1101
        TensorProto.UINT64: numpy.uint64,  # pylint: disable=E1101
        TensorProto.UINT32: numpy.uint32,  # pylint: disable=E1101
        TensorProto.UINT8: numpy.uint8,  # pylint: disable=E1101
        TensorProto.UINT16: numpy.uint16,  # pylint: disable=E1101
        TensorProto.FLOAT16: numpy.float16,  # pylint: disable=E1101
    }
    if proto_type in proto_to_numpy:
        return proto_to_numpy[proto_type]
    raise ValueError(
        f"Unable to convert proto_type {proto_type} to numpy type.")
def to_skl2onnx_type(name, elem_type, shape):
    """
    Converts *name*, *elem_type*, *shape* into a
    :epkg:`sklearn-onnx` type.

    :param name: string
    :param elem_type: tensor of elements of this type
    :param shape: expected shape
    :return: data type
    """
    from skl2onnx.common.data_types import _guess_numpy_type  # delayed
    numpy_elem = guess_numpy_type_from_string(elem_type)
    # dimension 0 means unknown and becomes None
    clean_shape = [None if d == 0 else d for d in shape]
    return (name, _guess_numpy_type(numpy_elem, clean_shape))
def from_pb(obj):
    """
    Extracts tensor description from a protobuf.

    :param obj: initializer, tensor
    :return: (name, type, shape)
    """
    if hasattr(obj, 'extend'):
        # repeated container: convert every element
        return [from_pb(o) for o in obj]

    def _dim_value(d):
        value = d.dim_value
        if "dim_param" in str(d):
            # symbolic dimension -> unknown
            return None
        if value == 0:
            # dim_value is 0 when it is 0 or undefined
            return 0 if "0" in str(d) else None
        return value

    name = obj.name
    if not obj.type.tensor_type:
        raise NotImplementedError(  # pragma: no cover
            f"Unsupported type '{type(obj)}' as a string ({obj}).")
    tensor = obj.type.tensor_type
    elem = tensor.elem_type
    shape = [_dim_value(d) for d in tensor.shape.dim]
    if elem not in TENSOR_TYPE_TO_NP_TYPE:
        raise NotImplementedError(
            f"Unsupported type '{type(obj.type.tensor_type)}' (elem_type={elem}).")
    return (name, TENSOR_TYPE_TO_NP_TYPE[elem].type, shape)
def numpy_type_prototype(dtype):
    """
    Converts a numpy dtype into a TensorProto dtype.

    :param dtype: dtype
    :return: proto dtype
    """
    if dtype in NP_TYPE_TO_TENSOR_TYPE:
        return NP_TYPE_TO_TENSOR_TYPE[dtype]
    # scalar types do not hash like dtype objects, normalize first
    normalized = numpy.dtype(dtype)
    if normalized in NP_TYPE_TO_TENSOR_TYPE:
        return NP_TYPE_TO_TENSOR_TYPE[normalized]
    raise ValueError(  # pragma: no cover
        f"Unable to convert dtype {dtype!r} into ProtoType.")
def make_value_info(name, dtype, shape):
    """
    Converts a variable defined by its name, type and shape
    into `onnx.ValueInfoProto`.

    :param name: name
    :param dtype: numpy element type
    :param shape: shape
    :return: instance of `onnx.ValueInfoProto`
    """
    info = ValueInfoProto()
    info.name = name
    info.type.CopyFrom(  # pylint: disable=E1101
        make_tensor_type_proto(numpy_type_prototype(dtype), shape))
    return info
def copy_value_info(info, name=None):
    """
    Makes a copy of `onnx.ValueInfoProto`.

    :param info: `onnx.ValueInfoProto` to copy
    :param name: if defined (non-empty), changes the name
    :return: instance of `onnx.ValueInfoProto`
    """
    clone = ValueInfoProto()
    clone.name = name if name else info.name
    clone.type.CopyFrom(info.type)  # pylint: disable=E1101
    return clone
# Module-level cache used by `_get_onnx_function`: lazily filled on
# first call, maps (domain, name) to the corresponding onnx function.
_get_onnx_function_cache = None
def _get_onnx_function():
    """
    Returns the list of functions defined in ONNX package,
    cached in `_get_onnx_function_cache` after the first call.
    """
    global _get_onnx_function_cache  # pylint: disable=W0603
    if _get_onnx_function_cache is None:
        _get_onnx_function_cache = {}
        for fct in get_function_ops():
            key = fct.domain, fct.name
            if key in _get_onnx_function_cache:
                raise RuntimeError(  # pragma: no cover
                    f"Function {key!r} is already registered.")
            _get_onnx_function_cache[key] = fct
    return _get_onnx_function_cache
def get_onnx_schema(opname, domain='', opset=None, load_function=False):
    """
    Returns the operator schema for a specific operator.

    :param domain: operator domain
    :param opname: operator name
    :param opset: opset or version, None for the latest
    :param load_function: loads the function, if True, the function
        looks into the list of function if one of them has the same name,
        opset must be None in that case
    :return: :epkg:`OpSchema`
    """
    if load_function:
        if opset is not None:
            raise ValueError(
                "opset must be None if load_function is True for "
                "operator (%r,%r)." % (domain, opname))
        functions = _get_onnx_function()
        key = domain, opname
        if key in functions:
            return functions[key]
        # not a function: fall through to the regular schema lookup
    if opset is None:
        opset = onnx_opset_version()
    return get_schema(opname, opset, domain)