Coverage for mlprodict/onnxrt/ops_cpu/op_global_average_pool.py: 100%

27 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operators *GlobalAveragePool* and *GlobalMaxPool*. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

11def _global_average_pool(x): 

12 spatial_shape = numpy.ndim(x) - 2 

13 y = numpy.average( 

14 x, axis=tuple(range(spatial_shape, spatial_shape + 2))) 

15 for _ in range(spatial_shape): 

16 y = numpy.expand_dims(y, -1) 

17 return y 

18 

19 

20def _global_max_pool(x): 

21 spatial_shape = numpy.ndim(x) - 2 

22 y = x.max(axis=tuple(range(spatial_shape, spatial_shape + 2))) 

23 for _ in range(spatial_shape): 

24 y = numpy.expand_dims(y, -1) 

25 return y 

26 

27 

class GlobalAveragePool(OpRun):
    """
    Python runtime for the ONNX operator *GlobalAveragePool*.
    The numerical work is delegated to the module-level helper
    ``_global_average_pool``.
    """

    def __init__(self, onnx_node, desc=None, **options):
        # Nothing operator-specific to set up: defer entirely to OpRun.
        OpRun.__init__(
            self, onnx_node, desc=desc, **options)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Pools the single input *x*. *attributes*, *verbose* and *fLOG*
        are accepted to match the runtime calling convention but are
        not used here.

        :return: a one-element tuple holding the pooled tensor
        """
        pooled = _global_average_pool(x)
        return pooled,

37 

38 

class GlobalMaxPool(OpRun):
    """
    Python runtime for the ONNX operator *GlobalMaxPool*.
    The numerical work is delegated to the module-level helper
    ``_global_max_pool``.
    """

    def __init__(self, onnx_node, desc=None, **options):
        # Nothing operator-specific to set up: defer entirely to OpRun.
        OpRun.__init__(
            self, onnx_node, desc=desc, **options)

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Pools the single input *x*. *attributes*, *verbose* and *fLOG*
        are accepted to match the runtime calling convention but are
        not used here.

        :return: a one-element tuple holding the pooled tensor
        """
        pooled = _global_max_pool(x)
        return pooled,