Coverage for mlprodict/onnxrt/ops_cpu/op_loop.py: 83%

63 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6 

7.. versionadded:: 0.7 

8""" 

9import numpy 

10from ._op import OpRun 

11 

12 

class Loop(OpRun):
    """
    Runtime for operator *Loop*. The subgraph stored in attribute
    *body* is executed repeatedly: its first two inputs receive the
    iteration number and the current condition, its first output is
    the condition for the next iteration and the following outputs
    are fed back into the inputs starting at the third one.
    """

    atts = {'body': None}

    def __init__(self, onnx_node, desc=None, **options):
        """
        @param      onnx_node   node describing the operator
        @param      desc        forwarded to @see cl OpRun
        @param      options     forwarded to @see cl OpRun

        Raises a *RuntimeError* if attribute *body* has no method *run*.
        """
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Loop.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError(  # pragma: no cover
                f"Parameter 'body' must have a method 'run', type {type(self.body)}.")

        # Method 'run_in_scan' is preferred when the subgraph exposes it.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)
        # Names the subgraph reads without receiving them as inputs,
        # they are looked up in the context when the operator runs.
        self.additional_inputs = self.body.static_inputs

    def need_context(self):
        """
        The operator Loop needs to know all results produced
        so far as the loop may silently access one of them.
        Some results are not always listed among the inputs
        (kind of static variables).
        """
        return len(self.additional_inputs) > 0

    def _run(self, M, cond,  # pylint: disable=W0221
             *args, callback=None, context=None,  # pylint: disable=W0221
             attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        """
        Runs the subgraph while the condition holds and the maximum
        number of iterations is not reached.

        @param      M           maximum number of iterations, None means
                                no limit
        @param      cond        initial condition, None means the
                                condition does not stop the loop before
                                the first iteration
        @param      args        the first value is assigned to the third
                                input of the subgraph, the remaining ones
                                to its last inputs
        @param      callback    if not None, called after every iteration
                                as ``callback(inputs, context=context)``
        @param      context     results computed so far, mandatory if the
                                subgraph has additional inputs
        @param      attributes  unused
        @param      verbose     verbosity, the subgraph receives
                                ``max(verbose - 1, 0)``
        @param      fLOG        logging function, only used if
                                ``verbose > 1``
        @return                 tuple with every output of the subgraph
                                except the first one (the condition)

        Raises a *RuntimeError* if an additional input cannot be found
        or if the subgraph returns None for the condition, a *TypeError*
        if one of the returned values is None.
        """
        loop_inputs = self.body.input_names
        output_names = self.body.output_names
        cond_name = output_names[0]

        inputs = {name: None for name in loop_inputs}
        if args:
            v_initial, args = args[0], args[1:]
            if v_initial is not None:
                inputs[loop_inputs[2]] = v_initial
        if args:
            # The remaining values are mapped onto the last inputs.
            begin = len(loop_inputs) - len(args)
            for name, val in zip(loop_inputs[begin:], args):
                inputs[name] = val

        if len(self.additional_inputs) > 0:
            if context is None:
                raise RuntimeError(
                    f"Additional inputs {self.additional_inputs!r} "
                    f"are missing and context is None.")
            for a in self.additional_inputs:
                if a not in context:
                    known = "\n".join(sorted(map(str, context)))
                    raise RuntimeError(
                        f"Additional inputs {a!r} not found "
                        f"in context\n{known}.")
                inputs[a] = context[a]

        # Names of the two inputs the operator feeds by itself.
        iter_name = loop_inputs[0] if len(loop_inputs) > 0 else None
        cond_in_name = loop_inputs[1] if len(loop_inputs) > 1 else None
        # M may be None or may not be an array: int64 is then used,
        # the type ONNX gives to the iteration number.
        iter_dtype = getattr(M, 'dtype', numpy.int64)
        # fLOG defaults to None, nothing is logged in that case.
        log = fLOG if verbose > 1 and fLOG is not None else None

        it = 0
        # M and cond are optional in ONNX, None means no constraint.
        while (cond is None or cond) and (M is None or it < M):
            if log is not None:
                log(f'-- Loop-Begin-{it}<{M}')
            if iter_name is not None:
                inputs[iter_name] = numpy.array(it, dtype=iter_dtype)
            if cond_in_name is not None:
                # A missing condition is equivalent to True, the
                # subgraph receives an explicit value rather than None.
                inputs[cond_in_name] = (
                    numpy.array(True) if cond is None else cond)
            outputs = self._run_meth(
                inputs, verbose=max(verbose - 1, 0), fLOG=fLOG)
            cond = outputs[cond_name]
            if cond is None:
                raise RuntimeError(
                    f"Condition {cond_name!r} returned by the "
                    "subgraph cannot be None.")
            # Outputs following the condition become the inputs
            # of the next iteration.
            for i, o in zip(loop_inputs[2:], output_names[1:]):
                inputs[i] = outputs[o]
            if callback is not None:
                callback(inputs, context=context)
            if log is not None:
                log(f'-- Loop-End-{it}<{M}')
            it += 1

        if it == 0:
            # The subgraph never ran: the first returned output is seeded
            # with the condition, then the current inputs are returned
            # for every output paired with an input.
            outputs = {output_names[1]: cond}
            for i, o in zip(loop_inputs[2:], output_names[1:]):
                outputs[o] = inputs[i]
        # Outputs nobody produced are replaced by an empty array.
        for o in output_names:
            if o not in outputs:
                outputs[o] = numpy.empty(shape=tuple())
        res = tuple(outputs[name] for name in output_names[1:])
        if any(r is None for r in res):
            raise TypeError(  # pragma: no cover
                "Operator Loop produces a None value.")
        return res