Coverage for mlprodict/onnxrt/ops_cpu/op_unsqueeze.py: 89%

35 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRunUnaryNum, OpRun 

10 

11 

class Unsqueeze_1(OpRunUnaryNum):
    """
    Runtime for operator *Unsqueeze* with the axes given as a node
    attribute: inserts dimensions of size one into the input tensor.
    """

    # Expected node attributes and their default values.
    atts = {'axes': [], 'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator and normalizes ``self.axes`` into a tuple,
        or ``None`` when it is an empty list or tuple.

        :param onnx_node: forwarded to ``OpRunUnaryNum.__init__``
        :param desc: forwarded to ``OpRunUnaryNum.__init__``
        :param options: forwarded to ``OpRunUnaryNum.__init__``
        """
        OpRunUnaryNum.__init__(self, onnx_node, desc=desc,
                               expected_attributes=Unsqueeze_1.atts,
                               **options)
        # NOTE(review): self.axes is presumably set by the base
        # constructor from the node attributes -- confirm in ._op.
        if isinstance(self.axes, numpy.ndarray):
            self.axes = tuple(self.axes)
        elif self.axes in [[], tuple()]:
            self.axes = None
        elif isinstance(self.axes, list):
            self.axes = tuple(self.axes)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Inserts one dimension of size one per value in ``self.axes``.

        :param data: array to expand
        :param attributes: not used by this implementation
        :param verbose: not used by this implementation
        :param fLOG: not used by this implementation
        :return: tuple holding the expanded array
        :raises RuntimeError: if ``self.axes`` is neither a tuple nor a list
        """
        if not isinstance(self.axes, (tuple, list)):
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_1).")
        # All axes are handed to numpy at once so each one is read as a
        # position in the *output* tensor. Inserting them one by one
        # resolves every axis against an intermediate rank, which fails
        # or yields the wrong shape for unsorted or negative axes.
        sq = numpy.expand_dims(data, axis=tuple(self.axes))
        return (sq, )

36 

37 

class Unsqueeze_11(Unsqueeze_1):
    """
    Runtime for operator *Unsqueeze* registered under the name
    ``Unsqueeze_11``; everything is inherited from ``Unsqueeze_1``.
    """

40 

41 

class Unsqueeze_13(OpRun):
    """
    Runtime for operator *Unsqueeze* where the axes are received as a
    second input of ``_run`` instead of a node attribute.
    """

    # Expected node attributes and their default values.
    atts = {'keepdims': 1}

    def __init__(self, onnx_node, desc=None, **options):
        """
        Builds the operator; ``self.axes`` is set to ``None`` because
        the axes are passed to ``_run``.

        :param onnx_node: forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Unsqueeze_13.atts,
                       **options)
        self.axes = None

    def _run(self, data, axes=None, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Inserts dimensions of size one into *data* at positions *axes*.

        :param data: array to expand
        :param axes: a single axis, or an array or sequence of axes
        :param attributes: not used by this implementation
        :param verbose: not used by this implementation
        :param fLOG: not used by this implementation
        :return: tuple holding the expanded array
        :raises RuntimeError: if *axes* is None
        """
        if axes is None:
            raise RuntimeError(  # pragma: no cover
                "axes cannot be None for operator Unsqueeze (Unsqueeze_13).")
        # numpy.ndim also handles plain lists and tuples, which have no
        # ``shape`` attribute; a scalar or 0-d array goes through the
        # single-axis call below.
        if numpy.ndim(axes) > 0:
            sq = numpy.expand_dims(data, axis=tuple(axes))
        else:
            sq = numpy.expand_dims(data, axis=axes)
        return (sq, )

62 

63 

# Bind ``Unsqueeze`` to the implementation matching the opset version
# reported by the installed onnx package.
if onnx_opset_version() < 11:  # pragma: no cover
    Unsqueeze = Unsqueeze_1
elif onnx_opset_version() < 13:  # pragma: no cover
    Unsqueeze = Unsqueeze_11
else:
    Unsqueeze = Unsqueeze_13