Coverage for mlprodict/onnxrt/ops_cpu/op_shape.py: 97%

32 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10 

11 

class Shape_1(OpRun):
    """
    Runtime for ONNX operator *Shape* (opset 1): the output is the
    complete shape of the input, as a 1-D int64 array.
    """

    def __init__(self, onnx_node, desc=None, **options):
        super().__init__(onnx_node, desc=desc, **options)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # every dimension of the input, converted to int64
        dims = numpy.array(data.shape, dtype=numpy.int64)
        return (dims, )

19 

20 

class Shape_15(Shape_1):
    """
    Runtime for ONNX operator *Shape* (opset 15). Compared to opset 1,
    the optional attributes *start* and *end* select a slice of the
    shape instead of the whole shape.
    """

    # Default attribute values. ``numpy.nan`` for *end* means the
    # attribute was not set: every dimension from *start* is kept.
    atts = {'start': 0, 'end': numpy.nan}

    def __init__(self, onnx_node, desc=None, **options):
        Shape_1.__init__(self, onnx_node, desc=desc,
                         expected_attributes=Shape_15.atts, **options)

    def _interval(self, n):
        """
        Computes the bounds of the slice of the shape to return.

        :param n: rank of the input tensor
        :return: None if the whole shape must be returned, otherwise
            a tuple ``(start, end)`` with both bounds clamped to ``[0, n]``
        """
        end_missing = numpy.isnan(self.end)
        if self.start == 0 and end_missing:
            # nothing to slice, the full shape is returned
            return None
        start = int(self.start)
        end = n if end_missing else int(self.end)
        # Negative values count from the back. Out-of-range values must
        # be clamped to [0, n]; without the clamp, an *end* below -n
        # stays negative and a Python slice would wrap it around again.
        if start < 0:
            start += n
        if end < 0:
            end += n
        start = min(max(start, 0), n)
        end = min(max(end, 0), n)
        return (start, end)

    def _run(self, data, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        ab = self._interval(len(data.shape))
        if ab is None:
            return (numpy.array(data.shape, dtype=numpy.int64), )
        # an empty slice (start >= end) yields an empty int64 array
        return (numpy.array(data.shape[ab[0]: ab[1]], dtype=numpy.int64), )

47 

48 

# Pick the most recent implementation supported by the installed onnx
# package: opset 15 introduced the *start* / *end* attributes.
Shape = Shape_15 if onnx_opset_version() >= 15 else Shape_1