Coverage for mlprodict/onnxrt/ops_cpu/op_array_feature_extractor.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9from ._op_onnx_numpy import ( # pylint: disable=E0611,E0401 

10 array_feature_extractor_double, 

11 array_feature_extractor_int64, 

12 array_feature_extractor_float) 

13 

14 

15def _array_feature_extrator(data, indices): 

16 """ 

17 Implementation of operator *ArrayFeatureExtractor* 

18 with :epkg:`numpy`. 

19 """ 

20 if len(indices.shape) == 2 and indices.shape[0] == 1: 

21 index = indices.ravel().tolist() 

22 add = len(index) 

23 elif len(indices.shape) == 1: 

24 index = indices.tolist() 

25 add = len(index) 

26 else: 

27 add = 1 

28 for s in indices.shape: 

29 add *= s 

30 index = indices.ravel().tolist() 

31 if len(data.shape) == 1: 

32 new_shape = (1, add) 

33 else: 

34 new_shape = list(data.shape[:-1]) + [add] 

35 tem = data[..., index] 

36 res = tem.reshape(new_shape) 

37 return res 

38 

39 

def sizeof_dtype(dty):
    """
    Returns the size in bytes of one element of type *dty*.

    :param dty: numpy scalar type or dtype; it is compared with ``==``,
        so ``numpy.float32`` and ``numpy.dtype('float32')`` are both
        accepted
    :return: 8 for float64 and int64, 4 for float32
    :raises ValueError: for any other type
    """
    if dty == numpy.float64:
        return 8
    if dty == numpy.float32:
        return 4
    if dty == numpy.int64:
        return 8
    # Report the type that was actually passed in, not the numpy.dtype class.
    raise ValueError(
        f"Unable to get bytes size for type {dty}.")

49 

50 

class ArrayFeatureExtractor(OpRun):
    # Runtime for the ONNX-ML operator ArrayFeatureExtractor. Elements are
    # gathered along the last axis of the input, using compiled kernels for
    # float64 / float32 / int64 and a numpy fallback for other dtypes.

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(
            self, onnx_node, desc=desc, **options)

    def _run(self, data, indices, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runtime for operator *ArrayFeatureExtractor*.

        .. warning::
            ONNX specifications may be imprecise in some cases.
            When the input data is a vector (one dimension),
            the output still has two dimensions, like a matrix with a
            single row. The implementation follows what :epkg:`onnxruntime`
            does in `array_feature_extractor.cc
            <https://github.com/microsoft/onnxruntime/blob/master/
            onnxruntime/core/providers/cpu/ml/array_feature_extractor.cc#L84>`_.
        """
        # Checked in order with ``==`` so that dtype objects and scalar
        # types compare the same way as a chain of if/elif tests.
        compiled_kernels = (
            (numpy.float64, array_feature_extractor_double),
            (numpy.float32, array_feature_extractor_float),
            (numpy.int64, array_feature_extractor_int64),
        )
        for supported_type, kernel in compiled_kernels:
            if data.dtype == supported_type:
                return (kernel(data, indices), )
        # No compiled kernel for this dtype (strings for instance).
        return (_array_feature_extrator(data, indices), )