Coverage for mlprodict/onnxrt/ops_cpu/_op_helper.py: 100%
40 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Runtime operator.
4"""
5import numpy
8def _get_typed_class_attribute(self, k, atts):
9 """
10 Converts an attribute into a C++ value.
11 """
12 ty = atts[k]
13 if isinstance(ty, numpy.ndarray):
14 v = getattr(self, k)
15 return v if v.dtype == ty.dtype else v.astype(ty.dtype)
16 if isinstance(ty, bytes):
17 return getattr(self, k).decode()
18 if isinstance(ty, list):
19 v = getattr(self, k)
20 if isinstance(v, numpy.ndarray):
21 return v
22 return [_.decode() for _ in getattr(self, k)]
23 if isinstance(ty, int):
24 return getattr(self, k)
25 raise NotImplementedError( # pragma: no cover
26 f"Unable to convert '{k}' ({getattr(self, k)}).")
def proto2dtype(proto_type):
    """
    Maps an ONNX tensor element type onto the matching :epkg:`numpy` dtype.

    :param proto_type: ONNX element type, for example
        ``onnx.TensorProto.FLOAT``
    :return: the :epkg:`numpy` dtype computed by ``guess_dtype``
    """
    # The helper is imported inside the function rather than at module
    # level; presumably to avoid an import cycle at load time - TODO confirm.
    from ...onnx_tools.onnx2py_helper import guess_dtype as _guess_dtype
    numpy_dtype = _guess_dtype(proto_type)
    return numpy_dtype
# Ordered (numpy type, name) pairs checked by dtype_name with ``==``;
# the order is the one in which the types are tested.
_DTYPE_NAMES = (
    (numpy.float32, "float32"),
    (numpy.float64, "float64"),
    (numpy.float16, "float16"),
    (numpy.int32, "int32"),
    (numpy.uint32, "uint32"),
    (numpy.int64, "int64"),
    (numpy.int8, "int8"),
    (numpy.uint8, "uint8"),
    (numpy.str_, "str"),
    (numpy.bool_, "bool"),
)


def dtype_name(dtype):
    """
    Returns the name of a numpy dtype.

    :param dtype: a numpy scalar type (``numpy.float32``) or a
        :class:`numpy.dtype` instance
    :return: one of ``"float32"``, ``"float64"``, ``"float16"``,
        ``"int32"``, ``"uint32"``, ``"int64"``, ``"int8"``, ``"uint8"``,
        ``"str"``, ``"bool"``
    :raises ValueError: if *dtype* is not one of the supported types
    """
    for candidate, name in _DTYPE_NAMES:
        if dtype == candidate:
            return name
    # ``numpy.str_`` only compares equal to the zero-width unicode dtype
    # ('<U0'); fixed-width string dtypes such as '<U5' (what
    # ``numpy.array(['abcde'])`` produces) must be recognized by their kind.
    if getattr(dtype, "kind", None) == "U":
        return "str"
    raise ValueError(
        f"Unexpected dtype {dtype}.")