Coverage for mlprodict/onnxrt/ops_cpu/op_reshape.py: 97%

29 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from onnx.defs import onnx_opset_version 

9from ._op import OpRun 

10 

11 

def reshape_reference_implementation(data, shape, allowzero=0):
    """
    Reshapes *data* following the semantics of ONNX operator Reshape.

    :param data: input array
    :param shape: requested shape (sequence or array of integers),
        ``-1`` lets numpy infer the dimension, ``0`` copies the dimension
        found at the same position in *data* unless *allowzero* is true
    :param allowzero: when true, a ``0`` in *shape* is kept as an empty
        dimension instead of being copied from *data*
    :return: reshaped array
    :raises RuntimeError: when a ``0`` in *shape* has no matching
        dimension in *data*
    """
    # Converting first makes ``shape == 0`` an element-wise comparison
    # even when *shape* is given as a plain list or tuple.
    shape = numpy.asarray(shape, dtype=numpy.int64)
    if allowzero or (len(data.shape) == 1 and data.shape[0] == 0):
        # Zeros are taken literally: either requested through allowzero,
        # or the input is an empty vector (special case kept unchanged).
        return numpy.reshape(data, shape)
    new_shape = numpy.copy(shape)
    zeros_index = numpy.where(shape == 0)
    try:
        new_shape[zeros_index] = numpy.array(data.shape)[zeros_index]
    except IndexError as e:
        # a 0 was requested at a position beyond the rank of data
        raise RuntimeError(
            "Unable to reshape from shape %r to shape %r (or %r)."
            "" % (data.shape, shape, new_shape)) from e
    return numpy.reshape(data, new_shape)

26 

27 

class CommonReshape(OpRun):
    """
    Shared implementation of operator Reshape, the opset-specific
    classes below derive from it.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=expected_attributes,
                       **options)

    def _run(self, data, shape, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # attributes, verbose and fLOG belong to the OpRun signature,
        # they are not used by this computation
        result = reshape_reference_implementation(data, shape)
        return (result, )

37 

38 

class Reshape_5(CommonReshape):
    """
    Reshape as defined before opset 14, no attribute is expected.
    """

    def __init__(self, onnx_node, desc=None, expected_attributes=None, **options):
        # expected_attributes is accepted for signature compatibility
        # but deliberately not forwarded: this version has no attribute.
        CommonReshape.__init__(
            self, onnx_node,
            desc=desc, **options)

43 

44 

class Reshape_13(Reshape_5):
    """
    Reshape for opset 13, same runtime behaviour as :class:`Reshape_5`.
    """

47 

48 

class Reshape_14(CommonReshape):
    """
    Reshape from opset 14, which introduces attribute *allowzero*.
    When *allowzero* is true, a ``0`` in the requested shape means an
    empty dimension instead of "copy the input dimension".
    """

    # default attribute values when the node does not define them
    atts = {'allowzero': 0}

    def __init__(self, onnx_node, desc=None, **options):
        CommonReshape.__init__(
            self, onnx_node, desc=desc,
            expected_attributes=Reshape_14.atts, **options)

    def _run(self, data, shape, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # NOTE(review): assumes OpRun exposes the attribute as self.allowzero
        # (node value or default from atts); getattr keeps the previous
        # behaviour if it is missing.
        if getattr(self, 'allowzero', 0):
            # zeros in shape are literal dimensions, no substitution
            return (numpy.reshape(data, shape), )
        return CommonReshape._run(
            self, data, shape, attributes=attributes,
            verbose=verbose, fLOG=fLOG)

57 

58 

# Pick the most recent implementation supported by the installed onnx.
if onnx_opset_version() < 14:
    Reshape = Reshape_5  # pragma: no cover
else:
    Reshape = Reshape_14