Coverage for mlprodict/onnxrt/ops_cpu/op_max_pool.py: 91%
47 statements
coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import itertools
8import numpy
9from ._op import OpRun
10from .op_max_pool_ import MaxPoolFloat, MaxPoolDouble # pylint: disable=E0611,E0401
13def _pool_get_output_shape(auto_pad, input_spatial_shape, kernel_spatial_shape,
14 strides_spatial):
15 out_shape = [0] * len(input_spatial_shape)
16 if auto_pad in (b'SAME_UPPER', b'SAME_LOWER'):
17 for i in range(len(input_spatial_shape)): # pylint: disable=C0200
18 out_shape[i] = int(
19 numpy.ceil(
20 float(input_spatial_shape[i]) / float(strides_spatial[i])))
21 elif auto_pad == b'VALID':
22 for i in range(len(input_spatial_shape)): # pylint: disable=C0200
23 out_shape[i] = int(
24 numpy.ceil(float(input_spatial_shape[i] - (kernel_spatial_shape[i] - 1)) /
25 float(strides_spatial[i])))
26 return out_shape
29def _pool_impl(padded, x_shape, kernel_shape, strides_shape,
30 out_shape, pad_shape, pooling_type,
31 count_include_pad=0):
32 spatial_size = len(x_shape) - 2
33 y = numpy.zeros([x_shape[0], x_shape[1]] + list(out_shape))
35 for shape in itertools.product(
36 range(x_shape[0]), range(x_shape[1]),
37 *[range(int((x_shape[i + 2] + pad_shape[i] - kernel_shape[i]) /
38 strides_shape[i] + 1))
39 for i in range(spatial_size)]):
40 window = padded[shape[0], shape[1]]
41 window_vals = numpy.array(
42 [window[i] for i in list(
43 itertools.product(
44 *[range(strides_shape[i] * shape[i + 2],
45 strides_shape[i] * shape[i + 2] + kernel_shape[i])
46 for i in range(spatial_size)]))])
47 if pooling_type == b'AVG':
48 f = numpy.average
49 elif pooling_type == b'MAX':
50 f = numpy.max
51 else:
52 raise NotImplementedError( # pragma: no cover
53 f"Pooling type '{pooling_type}' does not support. Should be AVG, MAX.")
55 if count_include_pad == 1 and pooling_type == b'AVG':
56 y[shape] = f(window_vals)
57 else:
58 y[shape] = f(window_vals[numpy.where(~numpy.isnan(window_vals))])
59 return y.astype(numpy.float32)
class MaxPool(OpRun):
    """
    Runtime for operator *MaxPool*. The computation is delegated to
    compiled kernels: ``MaxPoolFloat`` for float32 inputs and
    ``MaxPoolDouble`` for every other dtype.
    """

    atts = {'auto_pad': b'NOTSET', 'ceil_mode': 0, 'dilations': [],
            'kernel_shape': [], 'pads': [], 'storage_order': 0,
            'strides': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=MaxPool.atts, **options)
        # text copy of the bytes attribute auto_pad
        self.auto_pad_ = self.auto_pad.decode('ascii')
        # number of outputs declared by the node; drives the slicing in _run
        self.nb_outputs = len(onnx_node.output)
        self._init()

    def _init(self):
        """
        Instantiates one compiled kernel per supported dtype and
        configures both with the same attribute values.
        """
        self.rt32_ = MaxPoolFloat()
        self.rt64_ = MaxPoolDouble()

        def as_int64(values):
            # a fresh int64 array per call, as the kernels expect
            return numpy.array(values, dtype=numpy.int64)

        for kernel in (self.rt32_, self.rt64_):
            kernel.init(
                self.auto_pad, as_int64(self.dilations), self.ceil_mode,
                self.storage_order, as_int64(self.kernel_shape),
                as_int64(self.pads), as_int64(self.strides))

    def _run(self, X, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runs the kernel matching the dtype of *X*: float32 uses the float
        kernel, anything else the double one.
        """
        runtime = self.rt32_ if X.dtype == numpy.float32 else self.rt64_
        res = runtime.compute(X)
        # the kernel returns a sequence; keep only its first element
        # when the node declares a single output
        return res if self.nb_outputs != 1 else res[:1]