Coverage for mlprodict/onnxrt/ops_cpu/op_scan.py: 100%
52 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7import numpy
8from ._op import OpRun
class Scan(OpRun):
    """
    Runtime for operator *Scan*.

    The operator receives its inputs as positional arguments: the first
    ones are the loop state variables, the last ``num_scan_inputs`` ones
    are the scan inputs. The subgraph stored in attribute *body* is
    executed once per slice of the scan inputs taken along axis 0.
    Only forward scans along axis 0 are implemented, any other
    direction or axis raises an exception.
    """

    atts = {
        'body': None,
        'num_scan_inputs': None,
        'scan_input_axes': [],
        'scan_input_directions': [],
        'scan_output_axes': [],
        'scan_output_directions': []
    }

    def __init__(self, onnx_node, desc=None, **options):
        """
        Validates the attributes and caches the names of the inputs
        and outputs of the body.

        @param      onnx_node   node forwarded to @see cl OpRun
        @param      desc        description forwarded to @see cl OpRun
        @param      options     other options forwarded to @see cl OpRun

        Raises *RuntimeError* if the body has no method *run* or
        if an input direction or an input axis is not 0.
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Scan.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                f"Parameter 'body' must have a method 'run', type {type(self.body)}.")

        self.input_directions_ = self._pad_with_zeros(
            self.scan_input_directions, self.num_scan_inputs)
        if any(direction != 0 for direction in self.input_directions_):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output input_direction than 0.")

        self.input_axes_ = self._pad_with_zeros(
            self.scan_input_axes, self.num_scan_inputs)
        # Every axis is checked: comparing only the maximum to 0 would
        # accept a mix of negative and null axes such as [-1, 0].
        if any(axis != 0 for axis in self.input_axes_):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other input axes than 0.")

        self.input_names = self.body.input_names
        self.output_names = self.body.output_names
        # The body may expose a method dedicated to scan loops,
        # method *run* is the fallback.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)

    @staticmethod
    def _pad_with_zeros(values, size):
        "Returns the first *size* items of *values*, missing ones are replaced by 0."
        return [values[i] if i < len(values) else 0 for i in range(size)]

    def _common_run_shape(self, *args):
        """
        Splits the inputs and the names of the body between
        loop state variables and scanned tensors.

        @param      args    inputs of the operator, state variables
                            first, scan inputs last
        @return             tuple *(num_loop_state_vars, num_scan_outputs,
                            output_directions, max_dir_out, output_axes,
                            max_axe_out, state_names_in, state_names_out,
                            scan_names_in, scan_names_out, scan_values,
                            states)*

        Raises *RuntimeError* if there are fewer inputs than
        ``num_scan_inputs`` or if an output direction or an
        output axis is not 0.
        """
        num_loop_state_vars = len(args) - self.num_scan_inputs
        if num_loop_state_vars < 0:
            raise RuntimeError(
                f"Scan expects at least {self.num_scan_inputs} inputs "
                f"but received {len(args)}.")
        # The body returns the updated states first, every remaining
        # output is a scan output. That number may differ from the
        # number of scan inputs.
        num_scan_outputs = len(self.output_names) - num_loop_state_vars

        output_directions = self._pad_with_zeros(
            self.scan_output_directions, num_scan_outputs)
        # default=0 covers a body without any scan output.
        max_dir_out = max(output_directions, default=0)
        if any(direction != 0 for direction in output_directions):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output output_direction than 0.")

        output_axes = self._pad_with_zeros(
            self.scan_output_axes, num_scan_outputs)
        max_axe_out = max(output_axes, default=0)
        if any(axis != 0 for axis in output_axes):
            raise RuntimeError(  # pragma: no cover
                "Scan is not implemented for other output axes than 0.")

        # State names are the first *num_loop_state_vars* names of the body.
        # Slicing with *num_scan_inputs* was only right when both counts
        # were equal.
        state_names_in = self.input_names[:num_loop_state_vars]
        state_names_out = self.output_names[:num_loop_state_vars]
        scan_names_in = self.input_names[num_loop_state_vars:]
        scan_names_out = self.output_names[num_loop_state_vars:]
        scan_values = args[num_loop_state_vars:]
        states = args[:num_loop_state_vars]

        return (num_loop_state_vars, num_scan_outputs, output_directions,
                max_dir_out, output_axes, max_axe_out, state_names_in,
                state_names_out, scan_names_in, scan_names_out,
                scan_values, states)

    def _run(self, *args, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runs the body once per slice of the scan inputs.

        @param      args        inputs of the operator, state variables
                                first, scan inputs last
        @param      attributes  unused by this implementation
        @param      verbose     unused by this implementation
        @param      fLOG        unused by this implementation
        @return                 tuple made of the final states followed
                                by one stacked array per scan output

        Raises *TypeError* if the body cannot be called
        with a dictionary of inputs.
        """
        (num_loop_state_vars, _, _, _, _, _, state_names_in,
         state_names_out, scan_names_in, scan_names_out,
         scan_values, states) = self._common_run_shape(*args)

        # The number of iterations is the length of the first scan input
        # along its scan axis.
        max_iter = args[num_loop_state_vars].shape[self.input_axes_[0]]
        results = [[] for _ in scan_names_out]
        # A list is needed to append the scan outputs
        # even if the loop below never runs.
        states = list(states)

        for step in range(max_iter):
            inputs = dict(zip(state_names_in, states))
            for name, value in zip(scan_names_in, scan_values):
                inputs[name] = value[step]

            try:
                outputs = self._run_meth(inputs)
            except TypeError as e:  # pragma: no cover
                raise TypeError(
                    f"Unable to call 'run' for type '{type(self.body)}'.") from e

            # Updated states feed the next iteration, scan outputs receive
            # a new leading axis so that they can be stacked afterwards.
            states = [outputs[name] for name in state_names_out]
            for collected, name in zip(results, scan_names_out):
                collected.append(numpy.expand_dims(outputs[name], axis=0))

        for collected in results:
            states.append(numpy.vstack(collected))
        return tuple(states)