Coverage for mlprodict/onnxrt/ops_cpu/op_squeeze.py: 84%

37 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRunUnaryNum, OpRun 

10 

11 

class Squeeze_1(OpRunUnaryNum):
    """
    Runtime for the Squeeze operator when *axes* is a node attribute.

    The constructor normalises ``self.axes`` once: an empty list or tuple
    becomes ``None``, a list or a numpy array becomes a tuple.
    """

    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator and normalises ``self.axes``.

        :param onnx_node: forwarded to ``OpRunUnaryNum.__init__``
        :param desc: forwarded to ``OpRunUnaryNum.__init__``
        :param options: forwarded to ``OpRunUnaryNum.__init__``
        """
        # NOTE(review): self.axes is presumably set by the parent
        # constructor from *expected_attributes* -- confirm in ._op.
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Squeeze_1.atts,
                               **options)
        if isinstance(self.axes, numpy.ndarray):
            self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:
            # No axis given: numpy.squeeze(axis=None) drops every
            # dimension of length one.
            self.axes = None
        elif isinstance(self.axes, list):
            self.axes = tuple(self.axes)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Removes the dimensions listed in ``self.axes`` from *data*.

        :param data: array to squeeze
        :param attributes: unused here
        :param verbose: unused here
        :param fLOG: unused here
        :return: one-element tuple holding the squeezed array
        """
        axes = self.axes
        if isinstance(axes, list):
            # numpy.squeeze expects None, an int or a tuple of ints.
            axes = tuple(axes)
        # All axes are squeezed in one call so that every index refers to
        # the original shape. Squeezing them one at a time shifts the
        # remaining indices and fails for unsorted or negative axes,
        # e.g. axes=(2, 0) on shape (1, 3, 1).
        return (numpy.squeeze(data, axis=axes), )

35 

36 

class Squeeze_11(Squeeze_1):
    """
    Squeeze runtime selected for opset 11 and 12; it inherits everything
    from :class:`Squeeze_1` without any change.
    """

39 

40 

class Squeeze_13(OpRun):
    """
    Runtime for the Squeeze operator when *axes* is received as an
    optional second input of ``_run`` rather than stored on the instance.
    """

    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator; ``self.axes`` is always reset to ``None``.

        :param onnx_node: forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        """
        OpRun.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Squeeze_13.atts, **options)
        self.axes = None

    def _run(self, data, axes=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Squeezes *data* along *axes*.

        :param data: array to squeeze
        :param axes: ``None`` to drop every dimension of length one,
            otherwise a single axis or an iterable of axes
        :param attributes: unused here
        :param verbose: unused here
        :param fLOG: unused here
        :return: one-element tuple holding the squeezed array
        """
        if axes is None:
            return (numpy.squeeze(data), )
        # Iterables are turned into a tuple, anything else is passed on
        # to numpy.squeeze untouched.
        selected = tuple(axes) if hasattr(axes, '__iter__') else axes
        return (numpy.squeeze(data, axis=selected), )

60 

61 

# Expose as ``Squeeze`` the implementation matching the opset version
# reported by the installed onnx package.
if onnx_opset_version() < 11:  # pragma: no cover
    Squeeze = Squeeze_1
elif onnx_opset_version() < 13:  # pragma: no cover
    Squeeze = Squeeze_11
else:
    Squeeze = Squeeze_13