Coverage for mlprodict/onnxrt/ops_cpu/_op_helper.py: 100%

40 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Runtime operator. 

4""" 

5import numpy 

6 

7 

def _get_typed_class_attribute(self, k, atts):
    """
    Converts an attribute into a C++ value.

    The expected type is taken from the default value stored in *atts*
    and the current value is read from ``getattr(self, k)``.

    :param self: object holding the attribute (presumably an operator
        instance)
    :param k: attribute name
    :param atts: dictionary mapping attribute names to default values;
        the type of the default drives the conversion
    :return: the attribute value converted as follows:

        * default is a :epkg:`numpy` array: the value is cast to the
          default's dtype (returned unchanged if the dtypes already match)
        * default is ``bytes``: the value is decoded into a ``str``
        * default is a list: a :epkg:`numpy` array value is returned
          as is, otherwise every element is decoded into a ``str``
        * default is an ``int`` (``bool`` included): the value is
          returned as is
    :raises NotImplementedError: for any other default type
    """
    expected = atts[k]
    # Read the attribute once; every branch below needs it.
    value = getattr(self, k)
    if isinstance(expected, numpy.ndarray):
        # Avoid a copy when the stored array already has the right dtype.
        if value.dtype == expected.dtype:
            return value
        return value.astype(expected.dtype)
    if isinstance(expected, bytes):
        return value.decode()
    if isinstance(expected, list):
        if isinstance(value, numpy.ndarray):
            return value
        # Elements are expected to be bytes here; decode each one.
        return [item.decode() for item in value]
    if isinstance(expected, int):
        return value
    raise NotImplementedError(  # pragma: no cover
        f"Unable to convert '{k}' ({value}).")

27 

28 

def proto2dtype(proto_type):
    """
    Maps an ONNX tensor element type onto the matching :epkg:`numpy` dtype.

    The lookup itself is delegated to ``guess_dtype`` from
    ``onnx_tools.onnx2py_helper``.

    :param proto_type: element type, for example ``onnx.TensorProto.FLOAT``
    :return: :epkg:`numpy` dtype
    """
    # Imported inside the function, as originally written (presumably to
    # avoid an import cycle at module load time — not verified here).
    from ...onnx_tools.onnx2py_helper import guess_dtype as _guess_dtype
    numpy_dtype = _guess_dtype(proto_type)
    return numpy_dtype

38 

39 

def dtype_name(dtype):
    """
    Gives the short textual name of a :epkg:`numpy` dtype.

    :param dtype: numpy scalar type or dtype, e.g. ``numpy.float32`` or
        ``numpy.dtype('int64')``
    :return: one of ``'float32'``, ``'float64'``, ``'float16'``,
        ``'int32'``, ``'uint32'``, ``'int64'``, ``'int8'``, ``'uint8'``,
        ``'str'``, ``'bool'``
    :raises ValueError: when *dtype* compares equal to none of them
    """
    # Candidates are checked in this exact order; the first one that
    # compares equal to *dtype* determines the returned name.
    known_types = (
        (numpy.float32, "float32"),
        (numpy.float64, "float64"),
        (numpy.float16, "float16"),
        (numpy.int32, "int32"),
        (numpy.uint32, "uint32"),
        (numpy.int64, "int64"),
        (numpy.int8, "int8"),
        (numpy.uint8, "uint8"),
        (numpy.str_, "str"),
        (numpy.bool_, "bool"),
    )
    for candidate, label in known_types:
        if dtype == candidate:
            return label
    raise ValueError(
        f"Unexpected dtype {dtype}.")