Coverage for mlprodict/onnxrt/ops_cpu/op_max_pool.py: 91%

47 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import itertools 

8import numpy 

9from ._op import OpRun 

10from .op_max_pool_ import MaxPoolFloat, MaxPoolDouble # pylint: disable=E0611,E0401 

11 

12 

13def _pool_get_output_shape(auto_pad, input_spatial_shape, kernel_spatial_shape, 

14 strides_spatial): 

15 out_shape = [0] * len(input_spatial_shape) 

16 if auto_pad in (b'SAME_UPPER', b'SAME_LOWER'): 

17 for i in range(len(input_spatial_shape)): # pylint: disable=C0200 

18 out_shape[i] = int( 

19 numpy.ceil( 

20 float(input_spatial_shape[i]) / float(strides_spatial[i]))) 

21 elif auto_pad == b'VALID': 

22 for i in range(len(input_spatial_shape)): # pylint: disable=C0200 

23 out_shape[i] = int( 

24 numpy.ceil(float(input_spatial_shape[i] - (kernel_spatial_shape[i] - 1)) / 

25 float(strides_spatial[i]))) 

26 return out_shape 

27 

28 

29def _pool_impl(padded, x_shape, kernel_shape, strides_shape, 

30 out_shape, pad_shape, pooling_type, 

31 count_include_pad=0): 

32 spatial_size = len(x_shape) - 2 

33 y = numpy.zeros([x_shape[0], x_shape[1]] + list(out_shape)) 

34 

35 for shape in itertools.product( 

36 range(x_shape[0]), range(x_shape[1]), 

37 *[range(int((x_shape[i + 2] + pad_shape[i] - kernel_shape[i]) / 

38 strides_shape[i] + 1)) 

39 for i in range(spatial_size)]): 

40 window = padded[shape[0], shape[1]] 

41 window_vals = numpy.array( 

42 [window[i] for i in list( 

43 itertools.product( 

44 *[range(strides_shape[i] * shape[i + 2], 

45 strides_shape[i] * shape[i + 2] + kernel_shape[i]) 

46 for i in range(spatial_size)]))]) 

47 if pooling_type == b'AVG': 

48 f = numpy.average 

49 elif pooling_type == b'MAX': 

50 f = numpy.max 

51 else: 

52 raise NotImplementedError( # pragma: no cover 

53 f"Pooling type '{pooling_type}' does not support. Should be AVG, MAX.") 

54 

55 if count_include_pad == 1 and pooling_type == b'AVG': 

56 y[shape] = f(window_vals) 

57 else: 

58 y[shape] = f(window_vals[numpy.where(~numpy.isnan(window_vals))]) 

59 return y.astype(numpy.float32) 

60 

61 

class MaxPool(OpRun):
    """
    Runtime for operator *MaxPool*. The computation is delegated to
    compiled implementations: ``MaxPoolFloat`` handles float32 inputs,
    ``MaxPoolDouble`` handles every other input type.
    """

    atts = {'auto_pad': b'NOTSET', 'ceil_mode': 0, 'dilations': [],
            'kernel_shape': [], 'pads': [], 'storage_order': 0,
            'strides': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=MaxPool.atts,
                       **options)
        # auto_pad is stored as bytes; keep a decoded copy as well.
        self.auto_pad_ = self.auto_pad.decode('ascii')
        # Number of outputs declared by the node, used to trim results.
        self.nb_outputs = len(onnx_node.output)
        self._init()

    def _init(self):
        """
        Creates both compiled runtimes and initializes them with
        the node attributes, integer lists being converted to int64 arrays.
        """
        def as_int64(values):
            return numpy.array(values, dtype=numpy.int64)

        self.rt32_ = MaxPoolFloat()
        self.rt64_ = MaxPoolDouble()
        for runtime in (self.rt32_, self.rt64_):
            runtime.init(
                self.auto_pad, as_int64(self.dilations), self.ceil_mode,
                self.storage_order, as_int64(self.kernel_shape),
                as_int64(self.pads), as_int64(self.strides))

    def _run(self, X, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runs the compiled runtime matching the input type and keeps
        only the first result when the node declares a single output.
        """
        runtime = self.rt32_ if X.dtype == numpy.float32 else self.rt64_
        res = runtime.compute(X)
        return res[:1] if self.nb_outputs == 1 else res