Coverage for mlprodict/onnxrt/ops_cpu/op_if.py: 81%

63 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7from ...onnx_tools.onnx2py_helper import guess_dtype 

8from ._op import OpRun 

9 

10 

class If(OpRun):
    """
    Runtime for the ONNX operator *If*: evaluates the condition and runs
    either the ``then_branch`` or the ``else_branch`` sub-graph, returning
    the outputs of the branch that was executed.

    Both attributes must expose a ``run`` method (checked in the
    constructor) and are also used through ``run_in_scan`` (optional),
    ``input_names``, ``output_names``, ``static_inputs`` and ``obj.graph``.
    NOTE(review): presumably these are OnnxInference instances built by
    OpRun from the graph attributes -- confirm.
    """

    atts = {
        'then_branch': None,
        'else_branch': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=If.atts,
                       **options)
        if not hasattr(self.then_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'then_branch' must have a method 'run', "
                "type {}.".format(type(self.then_branch)))
        if not hasattr(self.else_branch, 'run'):
            raise RuntimeError(  # pragma: no cover
                "Parameter 'else_branch' must have a method 'run', "
                "type {}.".format(type(self.else_branch)))

        # Prefer ``run_in_scan`` when the sub-graph runtime provides it,
        # otherwise fall back on the regular ``run`` method.
        self._run_meth_then = (self.then_branch.run_in_scan
                               if hasattr(self.then_branch, 'run_in_scan')
                               else self.then_branch.run)
        self._run_meth_else = (self.else_branch.run_in_scan
                               if hasattr(self.else_branch, 'run_in_scan')
                               else self.else_branch.run)
        # Union of both branches' static inputs (order is unspecified).
        # Presumably the outer-scope results the branches read -- verify.
        self.additional_inputs = list(
            set(self.then_branch.static_inputs) |
            set(self.else_branch.static_inputs))

    def need_context(self):
        """
        Tells the runtime this node needs the context (all the results
        produced so far): :meth:`_run` looks up every input name declared
        by the branches in it and forwards it to the executed branch.
        Always returns `True` for operator If.
        """
        return True

    @staticmethod
    def _check_branch_inputs(branch, context, empty_message):
        """
        Checks that every input name declared by sub-graph *branch* is
        available in *context*.

        :param branch: sub-graph runtime exposing ``input_names``
        :param context: mapping of the results produced so far, may be None
        :param empty_message: message (with one ``{}`` placeholder filled
            with ``branch.input_names``) raised when *context* is empty
        :raises RuntimeError: if *context* is empty/None or misses a name
        """
        if len(branch.input_names) == 0:
            return
        # ``not context`` also covers context=None, which used to fail
        # with a TypeError on ``len(None)``.
        if not context:
            raise RuntimeError(  # pragma: no cover
                empty_message.format(branch.input_names))
        for k in branch.input_names:
            if k not in context:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find named input '{}' in\n{}.".format(
                        k, "\n".join(sorted(context))))

    def _execute_branch(self, branch, named_inputs, context, attributes,
                        verbose, fLOG):
        """
        Runs sub-graph *branch* (``'then'`` or ``'else'``) and returns its
        outputs as a tuple ordered like the sub-graph ``output_names``.
        """
        if branch == 'then':
            graph, run_meth = self.then_branch, self._run_meth_then
        else:
            graph, run_meth = self.else_branch, self._run_meth_else
        if verbose > 0 and fLOG is not None:
            fLOG(  # pragma: no cover
                f' -- {branch}> {list(context or ())!r}')
        outputs = run_meth(named_inputs, context=context,
                           attributes=attributes,
                           verbose=verbose, fLOG=fLOG)
        if verbose > 0 and fLOG is not None:
            fLOG(f' -- {branch}<')  # pragma: no cover
        return tuple(outputs[name] for name in graph.output_names)

    def _run(self, cond, named_inputs=None, context=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):
        """
        Runs the branch selected by *cond*.

        :param cond: condition; when it has a non-empty shape it is reduced
            with the builtin ``all`` (ONNX requires a single element),
            otherwise (0-d tensor or Python bool) its truth value is used
        :param named_inputs: inputs forwarded to the branch, ``{}`` if None
        :param context: results produced so far; every input name declared
            by a branch must be found in it
        :param attributes: forwarded to the branch runtime
        :param verbose: verbosity level, logs through *fLOG* when > 0
        :param fLOG: logging function
        :return: tuple of the executed branch outputs, in the order of the
            branch ``output_names``
        :raises RuntimeError: if *cond* is None, a branch input is missing
            from *context*, the branch has no output or one output is None
        """
        if cond is None:
            raise RuntimeError(  # pragma: no cover
                "cond cannot be None")
        if named_inputs is None:
            named_inputs = {}

        self._check_branch_inputs(
            self.then_branch, context,
            "named_inputs is empty but the graph needs {}, "
            "sub-graphs for node If must not have any inputs.")
        # Reports the else-branch names (previously the then-branch ones
        # were reported by mistake).
        self._check_branch_inputs(
            self.else_branch, context,
            "context is empty but the graph needs {}.")

        # getattr lets plain Python booleans through as scalar conditions.
        if len(getattr(cond, 'shape', ())) > 0:
            branch = 'then' if all(cond) else 'else'
        else:
            branch = 'then' if cond else 'else'

        final = self._execute_branch(
            branch, named_inputs, context, attributes, verbose, fLOG)

        if len(final) == 0:
            raise RuntimeError(  # pragma: no cover
                f"Operator If ({self.onnx_node.name!r}) does not have any output.")
        for i, f in enumerate(final):
            if f is None:
                ni = named_inputs if named_inputs else []  # pragma: no cover
                br = self.then_branch if branch == 'then' else self.else_branch
                names = br.output_names
                inits = [init.name for init in br.obj.graph.initializer]
                raise RuntimeError(  # pragma: no cover
                    "Output %d (branch=%r, name=%r) is None, available inputs=%r, "
                    "initializers=%r." % (
                        i, branch, names[i], sorted(ni), inits))
        return final

    def _pick_type(self, res, name):
        """
        Returns the type associated with *name*: ``res[name]`` when present,
        otherwise the element type declared for output *name* of the
        then-branch graph, converted with ``guess_dtype``.

        :raises ValueError: if *name* is neither in *res* nor a then-branch
            output
        """
        if name in res:
            return res[name]
        out = {o.name: o for o in self.then_branch.obj.graph.output}
        if name not in out:
            raise ValueError(  # pragma: no cover
                "Unable to find name=%r in %r or %r." % (
                    name, list(sorted(res)), list(sorted(out))))
        dt = out[name].type.tensor_type.elem_type
        return guess_dtype(dt)