Coverage for mlprodict/onnxrt/ops_cpu/op_loop.py: 83%
63 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
7.. versionadded:: 0.7
8"""
9import numpy
10from ._op import OpRun
class Loop(OpRun):
    # Runtime for the operator Loop: the subgraph stored in attribute
    # *body* is executed repeatedly. The subgraph must expose a method
    # ``run`` (``run_in_scan`` is used instead when it exists) and the
    # attributes ``input_names``, ``output_names``, ``static_inputs``.

    atts = {'body': None}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Loop.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                f"Parameter 'body' must have a method 'run', type {type(self.body)}.")

        # run_in_scan is preferred over run when the subgraph offers it.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)
        # Names the subgraph reads without declaring them as inputs,
        # they are fetched from the context in method _run.
        self.additional_inputs = self.body.static_inputs

    def need_context(self):
        """
        The operator Loop needs to know all results produced
        so far as the loop may silently access one of them.
        Some information is not always listed among the inputs
        (kind of static variables).
        """
        return len(self.additional_inputs) > 0

    def _run(self, M, cond,  # pylint: disable=W0221
             *args, callback=None, context=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runs the subgraph while *cond* is true and fewer than *M*
        iterations were executed.

        :param M: maximum number of iterations, its dtype is reused
            for the iteration counter given to the subgraph,
            None means no limit (the counter is then int64)
        :param cond: initial value of the condition
        :param args: the first value, if not None, is assigned to the
            third input of the subgraph, the remaining ones are
            assigned to the last inputs of the subgraph
        :param callback: if not None, called after every iteration
            as ``callback(inputs, context=context)``
        :param context: dictionary of known results, required if the
            subgraph has static inputs (see @see me need_context)
        :param attributes: not used by this method
        :param verbose: logs the beginning and the end of every
            iteration with *fLOG* if > 1, the subgraph receives
            ``max(verbose - 1, 0)``
        :param fLOG: logging function, nothing is logged by this
            method if None
        :return: tuple with the value of every output of the subgraph
            except the first one (the condition)
        :raises RuntimeError: if a static input cannot be found in
            *context* or if the subgraph returns None as a condition
        :raises TypeError: if one of the returned values is None
        """
        loop_inputs = self.body.input_names
        loop_outputs = self.body.output_names
        cond_name = loop_outputs[0]

        inputs = {name: None for name in loop_inputs}
        if len(args) > 0:
            v_initial = args[0]
            args = args[1:]
            if v_initial is not None:
                inputs[loop_inputs[2]] = v_initial
        if len(args) > 0:
            # Remaining values feed the last inputs of the subgraph.
            begin = len(loop_inputs) - len(args)
            for name, val in zip(loop_inputs[begin:], args):
                inputs[name] = val

        if len(self.additional_inputs) > 0:
            if context is None:
                raise RuntimeError(
                    "Additional inputs %r are missing and context is None."
                    "" % (self.additional_inputs, ))
            for a in self.additional_inputs:
                if a in context:
                    inputs[a] = context[a]
                else:
                    raise RuntimeError(
                        "Additional inputs %r not found in context\n%s." % (
                            a, "\n".join(sorted(map(str, context)))))

        # Loop invariants.
        iter_name = loop_inputs[0] if len(loop_inputs) > 0 else None
        cond_input = loop_inputs[1] if len(loop_inputs) > 1 else None
        carried = list(zip(loop_inputs[2:], loop_outputs[1:]))
        # M is optional: without it, the counter falls back to int64.
        it_dtype = getattr(M, 'dtype', numpy.int64)
        log = verbose > 1 and fLOG is not None

        outputs = None
        it = 0
        while cond and (M is None or it < M):
            if log:
                fLOG(f'-- Loop-Begin-{it}<{M}')
            if iter_name is not None:
                inputs[iter_name] = numpy.array(it, dtype=it_dtype)
            if cond_input is not None:
                inputs[cond_input] = cond
            outputs = self._run_meth(
                inputs, verbose=max(verbose - 1, 0), fLOG=fLOG)
            cond = outputs[cond_name]
            if cond is None:
                raise RuntimeError(
                    f"Condition {cond_name!r} returned by the "
                    f"subgraph cannot be None.")
            # Outputs of this iteration become inputs of the next one.
            for i, o in carried:
                inputs[i] = outputs[o]
            if callback is not None:
                callback(inputs, context=context)
            if log:
                fLOG(f'-- Loop-End-{it}<{M}')
            it += 1

        if it == 0:
            # The subgraph never ran: the condition is stored under its
            # own name (it is not returned) and the initial values are
            # forwarded unchanged.
            outputs = {cond_name: cond}
            for i, o in carried:
                outputs[o] = inputs[i]
        # Any output still missing is replaced by an array of shape ().
        for o in loop_outputs:
            if o not in outputs:
                outputs[o] = numpy.empty(shape=tuple())
        res = tuple(outputs[name] for name in loop_outputs[1:])
        if any(r is None for r in res):
            raise TypeError(  # pragma: no cover
                "Operator Loop produces a None value.")
        return res