Coverage for mlprodict/onnxrt/ops_cpu/op_if.py: 81%
63 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2# pylint: disable=E0203,E1101,C0111
3"""
4@file
5@brief Runtime operator.
6"""
7from ...onnx_tools.onnx2py_helper import guess_dtype
8from ._op import OpRun
class If(OpRun):
    """
    Runtime implementation of ONNX operator *If*: evaluates either
    attribute ``then_branch`` or attribute ``else_branch`` depending on
    the condition and returns the outputs of the selected branch.
    """

    atts = {
        'then_branch': None,
        'else_branch': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=If.atts,
                       **options)
        # Both branches are expected to be runnable sub-graph objects.
        if not hasattr(self.then_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'then_branch' must have a method 'run', "
                "type {}.".format(type(self.then_branch)))
        if not hasattr(self.else_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'else_branch' must have a method 'run', "
                "type {}.".format(type(self.else_branch)))

        # Prefer method ``run_in_scan`` when the sub-graph object provides
        # it, otherwise fall back on ``run``.
        self._run_meth_then = (self.then_branch.run_in_scan
                               if hasattr(self.then_branch, 'run_in_scan')
                               else self.then_branch.run)
        self._run_meth_else = (self.else_branch.run_in_scan
                               if hasattr(self.else_branch, 'run_in_scan')
                               else self.else_branch.run)
        # Union of the static inputs declared by both branches
        # (order is not meaningful, it comes from a set).
        self.additional_inputs = list(
            set(self.then_branch.static_inputs) |
            set(self.else_branch.static_inputs))

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far).
        Operator If always answers `True`: both branches are executed
        with the outer context and may read any result from it.
        """
        return True

    @staticmethod
    def _check_branch_inputs(graph, context):
        """
        Raises RuntimeError if sub-graph *graph* declares inputs
        which cannot be found in *context* (which may be None or empty).
        """
        if not graph.input_names:
            return
        if not context:
            raise RuntimeError(  # pragma: no cover
                "context is empty but the graph needs {}, "
                "sub-graphs for node If must not have any inputs.".format(
                    graph.input_names))
        for k in graph.input_names:
            if k not in context:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find named input '{}' in\n{}.".format(
                        k, "\n".join(sorted(context))))

    def _run_branch(self, branch, named_inputs, context, attributes,
                    verbose, fLOG):
        """
        Runs one branch and returns its outputs as a tuple ordered
        like the branch's ``output_names``.

        :param branch: ``'then'`` or ``'else'``
        :return: tuple of results
        """
        if branch == 'then':
            graph, run_meth = self.then_branch, self._run_meth_then
        else:
            graph, run_meth = self.else_branch, self._run_meth_else
        log = verbose > 0 and fLOG is not None
        if log:
            fLOG(  # pragma: no cover
                f' -- {branch}> {list(context or [])!r}')
        outputs = run_meth(named_inputs, context=context,
                           attributes=attributes,
                           verbose=verbose, fLOG=fLOG)
        if log:
            fLOG(f' -- {branch}<')  # pragma: no cover
        return tuple(outputs[name] for name in graph.output_names)

    def _run(self, cond, named_inputs=None, context=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):
        """
        Evaluates the branch selected by *cond*.

        :param cond: condition; when it has at least one dimension the
            *then* branch runs only if every element is true, otherwise
            its truth value decides
        :param named_inputs: inputs given to the branch (defaults to {})
        :param context: results produced so far by the outer graph,
            the branches read their declared inputs from it
        :param attributes: forwarded to the branch
        :param verbose: verbosity, logging happens only if > 0 and
            *fLOG* is not None
        :param fLOG: logging function
        :return: tuple with the outputs of the selected branch
        :raises RuntimeError: if *cond* is None, if a branch input is
            missing from *context*, if the selected branch produces
            no output or one output is None
        """
        if cond is None:
            raise RuntimeError(  # pragma: no cover
                "cond cannot be None")
        if named_inputs is None:
            named_inputs = {}
        self._check_branch_inputs(self.then_branch, context)
        self._check_branch_inputs(self.else_branch, context)

        # ``cond.all()`` reduces over every dimension; the builtin
        # ``all`` fails on arrays with more than one dimension.
        if len(cond.shape) > 0:
            take_then = bool(cond.all())
        else:
            take_then = bool(cond)
        branch = 'then' if take_then else 'else'
        final = self._run_branch(branch, named_inputs, context,
                                 attributes, verbose, fLOG)

        if len(final) == 0:
            raise RuntimeError(  # pragma: no cover
                f"Operator If ({self.onnx_node.name!r}) does not have any output.")
        for index, value in enumerate(final):
            if value is None:
                br = self.then_branch if branch == 'then' else self.else_branch
                names = br.output_names
                inits = [init.name for init in br.obj.graph.initializer]
                raise RuntimeError(  # pragma: no cover
                    "Output %d (branch=%r, name=%r) is None, available inputs=%r, "
                    "initializers=%r." % (
                        index, branch, names[index], sorted(named_inputs),
                        inits))
        return final

    def _pick_type(self, res, name):
        """
        Returns the type of result *name*: taken from *res* when present,
        otherwise guessed from the element type declared by the
        then-branch graph output with that name.

        :raises ValueError: if *name* is neither in *res* nor an output
            of the then-branch graph
        """
        if name in res:
            return res[name]
        out = {o.name: o for o in self.then_branch.obj.graph.output}
        if name not in out:
            raise ValueError(  # pragma: no cover
                "Unable to find name=%r in %r or %r." % (
                    name, sorted(res), sorted(out)))
        dt = out[name].type.tensor_type.elem_type
        return guess_dtype(dt)