Coverage for mlprodict/npy/numpy_onnx_impl_body.py: 74%

70 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Helpers designed to pass a graph as an operator parameter (attribute). 

4 

5.. versionadded:: 0.8 

6""" 

7import logging 

8import numpy 

9from .onnx_variable import OnnxVar 

10from .xop import loadop 

11 

12 

# Module-level logger used by the debug traces below; it is registered
# under the fixed name 'xop' rather than under this module's own name.
logger = logging.getLogger('xop')

14 

15 

class AttributeGraph:
    """
    Class wrapping a function to make it simple as
    a parameter. A :class:`numpy.ndarray` given without any input
    is stored as a constant instead of a function.

    :param fct: function taking the list of inputs defined
        as @see cl OnnxVar, the function returns an @see cl OnnxVar,
        or a numpy array (a constant) when no input is given
    :param inputs: list of input as @see cl OnnxVar

    The instance exposes the following attributes:

    * `cst`: the constant array, or None if *fct* is a function
    * `fct`: the wrapped function, or None if a constant was given
    * `inputs`: tuple of inputs received by the constructor
    * `alg_`: cached result of :meth:`to_algebra`, None before the call
    * `alg_inputs_`: placeholders built by :meth:`to_algebra` for
      the inputs, None before the call or for a constant

    .. versionadded:: 0.8
    """

    def __init__(self, fct, *inputs):
        logger.debug('AttributeGraph(%r, %d in)', type(fct), len(inputs))
        if isinstance(fct, numpy.ndarray) and len(inputs) == 0:
            # A bare array is a constant, there is no function to call.
            self.cst = fct
            fct = None
        else:
            self.cst = None
        self.fct = fct
        self.inputs = inputs
        self.alg_ = None
        # Defined here so that the attribute always exists, even
        # before the first call to method to_algebra.
        self.alg_inputs_ = None

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}(...)"

    def _graph_guess_dtype(self, i, var):
        """
        Builds the placeholder standing for one input of the graph.

        :param i: attribute index (integer)
        :param var: the input (@see cl OnnxVar)
        :return: a new @see cl OnnxVar named ``graph_<id(self)>_<i>``
            whose dtype is the one guessed by *var*,
            or ``numpy.float32`` if *var* cannot guess it
        """
        dtype = var._guess_dtype(None)
        if dtype is None:
            dtype = numpy.float32

        input_name = 'graph_%d_%d' % (id(self), i)
        return OnnxVar(input_name, dtype=dtype)

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.
        The result is cached in attribute `alg_` and returned
        as is by the following calls.

        :param op_version: opset forwarded to the operators
        :return: the operator stored in `alg_`
        :raises RuntimeError: if the wrapped function does not
            return an @see cl OnnxVar
        """
        if self.alg_ is not None:
            return self.alg_

        logger.debug('AttributeGraph.to_algebra(op_version=%r)',
                     op_version)
        if self.cst is not None:
            # A constant is exposed through an Identity operator.
            OnnxIdentity = loadop('Identity')
            self.alg_ = OnnxIdentity(self.cst, op_version=op_version)
            self.alg_inputs_ = None
            logger.debug('AttributeGraph.to_algebra:end:1:%r', type(self.alg_))
            return self.alg_

        new_inputs = [self._graph_guess_dtype(i, inp)
                      for i, inp in enumerate(self.inputs)]
        self.alg_inputs_ = new_inputs
        # NOTE(review): _graph_guess_dtype returns a single OnnxVar, yet
        # every element is indexed with [1] here, as if it were a pair.
        # Presumably a leftover of a previous signature, the indexing is
        # kept unchanged -- TODO confirm against OnnxVar.__getitem__.
        fct_args = [v[1] for v in new_inputs]
        var = self.fct(*fct_args)
        if not isinstance(var, OnnxVar):
            raise RuntimeError(  # pragma: no cover
                f"var is not from type OnnxVar but {type(var)!r}.")

        self.alg_ = var.to_algebra(op_version=op_version)
        logger.debug('AttributeGraph.to_algebra:end:2:%r', type(self.alg_))
        return self.alg_

86 

87 

class OnnxVarGraph(OnnxVar):
    """
    Specialization of @see cl OnnxVar for operators receiving
    graphs as attributes (instances of @see cl AttributeGraph).

    :param inputs: variable name or object
    :param op: :epkg:`ONNX` operator
    :param select_output: if multiple output are returned by
        ONNX operator *op*, it takes only one specified by this
        argument
    :param dtype: specifies the type of the variable
        held by this class (*op* is None) in that case
    :param kwargs: additional arguments to give operator *op*

    .. versionadded:: 0.8
    """

    def __init__(self, *inputs, op=None, select_output=None,
                 dtype=None, **kwargs):
        # Every argument is forwarded untouched to the base class.
        OnnxVar.__init__(
            self, *inputs, op=op, select_output=select_output,
            dtype=dtype, **kwargs)

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.
        Every attribute of type @see cl AttributeGraph is first
        converted into an ONNX graph, the original values are saved
        in `onnx_op_kwargs_before`, then the conversion is left
        to @see cl OnnxVar. The result is cached in `alg_`.

        :param op_version: opset used for the conversion
        :return: the operator stored in `alg_`
        """
        if self.alg_ is not None:
            return self.alg_

        logger.debug('OnnxVarGraph.to_algebra(op_version=%r)',
                     op_version)
        self.alg_hidden_var_ = {}
        self.alg_hidden_var_inputs = {}
        # Graphs replacing the AttributeGraph attributes, by name.
        graphs = {}
        for name, value in self.onnx_op_kwargs.items():
            if isinstance(value, AttributeGraph):
                sub_alg = value.to_algebra(op_version=op_version)
                # alg_inputs_ is only read once to_algebra was called.
                placeholders = value.alg_inputs_
                sub_inputs = (
                    [] if placeholders is None
                    else [p[0] for p in placeholders])
                model = sub_alg.to_onnx(sub_inputs, target_opset=op_version)
                graphs[name] = model.graph
                key = id(value)
                self.alg_hidden_var_[key] = value
                self.alg_hidden_var_inputs[key] = sub_inputs

        # Keeps the original attributes before they are overwritten.
        self.onnx_op_kwargs_before = {
            name: self.onnx_op_kwargs[name] for name in graphs}
        self.onnx_op_kwargs.update(graphs)
        self.alg_ = OnnxVar.to_algebra(self, op_version=op_version)
        logger.debug('OnnxVarGraph.to_algebra:end:%r', type(self.alg_))
        return self.alg_

143 

144 

class if_then_else(AttributeGraph):
    """
    Subclass of @see cl AttributeGraph, it does not add
    any attribute or method to its parent.

    :param fct: see @see cl AttributeGraph
    :param inputs: see @see cl AttributeGraph
    """

    def __init__(self, fct, *inputs):
        super().__init__(fct, *inputs)