Coverage for mlprodict/onnxrt/ops_cpu/op_scan.py: 100%

52 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class Scan(OpRun):
    """
    Runtime for operator *Scan*.

    The operator receives the initial values of the loop state variables
    followed by ``num_scan_inputs`` scan inputs. The subgraph stored in
    attribute *body* is executed once per slice taken along the first
    axis of the scan inputs. The operator returns the final state
    variables followed by the stacked scan outputs.

    Only forward iteration (direction 0) along axis 0 is implemented,
    for the inputs as well as for the outputs.
    """

    atts = {
        'body': None,
        'num_scan_inputs': None,
        'scan_input_axes': [],
        'scan_input_directions': [],
        'scan_output_axes': [],
        'scan_output_directions': []
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Checks the attributes and caches what every call to *_run* needs.

        :param onnx_node: forwarded to ``OpRun.__init__``
        :param desc: forwarded to ``OpRun.__init__``
        :param options: forwarded to ``OpRun.__init__``
        :raises RuntimeError: if *body* has no method ``run`` or if a
            scan input uses a direction or an axis different from 0
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Scan.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                f"Parameter 'body' must have a method 'run', type {type(self.body)}.")
        self.input_directions_ = self._pad_with_zeros(
            self.scan_input_directions, self.num_scan_inputs)
        max_dir_in = max(self.input_directions_)
        if max_dir_in != 0:
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output input_direction than 0.")
        self.input_axes_ = self._pad_with_zeros(
            self.scan_input_axes, self.num_scan_inputs)
        max_axe_in = max(self.input_axes_)
        if max_axe_in != 0:
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other input axes than 0.")
        self.input_names = self.body.input_names
        self.output_names = self.body.output_names
        # 'run_in_scan' is preferred over 'run' when the body provides it.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)

    @staticmethod
    def _pad_with_zeros(values, size):
        "Returns the first *size* items of *values*, missing ones become 0."
        return [values[i] if i < len(values) else 0 for i in range(size)]

    def _common_run_shape(self, *args):
        """
        Splits the inputs and the names of the body between loop state
        variables and scan inputs / outputs.

        :param args: initial state variables followed by the scan inputs
        :return: tuple *(num_loop_state_vars, num_scan_outputs,
            output_directions, max_dir_out, output_axes, max_axe_out,
            state_names_in, state_names_out, scan_names_in,
            scan_names_out, scan_values, states)*
        :raises RuntimeError: if a scan output uses a direction or an
            axis different from 0
        """
        num_loop_state_vars = len(args) - self.num_scan_inputs
        # Outputs of the body which are not state variables are scan
        # outputs, their number does not depend on the number of scan inputs.
        num_scan_outputs = len(self.output_names) - num_loop_state_vars

        output_directions = self._pad_with_zeros(
            self.scan_output_directions, num_scan_outputs)
        # default=0 covers a body which only returns state variables.
        max_dir_out = max(output_directions, default=0)
        if max_dir_out != 0:
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output output_direction than 0.")
        output_axes = self._pad_with_zeros(
            self.scan_output_axes, num_scan_outputs)
        max_axe_out = max(output_axes, default=0)
        if max_axe_out != 0:
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output axes than 0.")

        # The state variables come first, among the inputs as well as among
        # the outputs of the body, and there are num_loop_state_vars of them.
        state_names_in = self.input_names[:num_loop_state_vars]
        state_names_out = self.output_names[:num_loop_state_vars]
        scan_names_in = self.input_names[num_loop_state_vars:]
        scan_names_out = self.output_names[num_loop_state_vars:]
        scan_values = args[num_loop_state_vars:]

        states = args[:num_loop_state_vars]

        return (num_loop_state_vars, num_scan_outputs, output_directions,
                max_dir_out, output_axes, max_axe_out, state_names_in,
                state_names_out, scan_names_in, scan_names_out,
                scan_values, states)

    def _run(self, *args, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runs the body once per slice of the scan inputs.

        :param args: initial state variables followed by the scan inputs
        :param attributes: not used by this method
        :param verbose: not used by this method
        :param fLOG: not used by this method
        :return: tuple made of the final state variables followed by
            one stacked array per scan output
        :raises TypeError: if the body cannot be called with a dictionary
        """
        (num_loop_state_vars, _num_scan_outputs, _output_directions,
         _max_dir_out, _output_axes, _max_axe_out, state_names_in,
         state_names_out, scan_names_in, scan_names_out,
         scan_values, states) = self._common_run_shape(*args)

        # The first scan input gives the number of iterations.
        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        results = [[] for _ in scan_names_out]

        for step in range(max_iter):
            inputs = dict(zip(state_names_in, states))
            for name, value in zip(scan_names_in, scan_values):
                inputs[name] = value[step]

            try:
                outputs = self._run_meth(inputs)
            except TypeError as e:  # pragma: no cover
                raise TypeError(
                    f"Unable to call 'run' for type '{type(self.body)}'.") from e

            # The states computed by this iteration feed the next one.
            states = [outputs[name] for name in state_names_out]
            for collected, name in zip(results, scan_names_out):
                # A leading axis is added so that the slices can be stacked.
                collected.append(numpy.expand_dims(outputs[name], axis=0))

        final = list(states)
        final.extend(numpy.vstack(res) for res in results)
        return tuple(final)