Coverage for mlprodict/onnxrt/ops_empty/_op.py: 71%

86 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_onnxruntime*. 

5""" 

6import numpy 

7import onnx.defs 

8from onnx.helper import make_tensor 

9import skl2onnx.algebra.onnx_ops as alg 

10try: 

11 import skl2onnx.algebra.custom_ops as alg2 

12except ImportError: # pragma: no cover 

13 # older version of skl2onnx 

14 alg2 = alg 

15from ...onnx_tools.onnx2py_helper import guess_proto_dtype 

16from ...onnx_tools.optim.graph_schema_helper import ( 

17 get_defined_inputs, get_defined_outputs, proto2vars) 

18 

19 

# Every operator schema known by onnx, indexed by operator name.
# When several schemas share the same name, the dictionary keeps the
# last one returned by get_all_schemas_with_history.
_schemas = {
    op_schema.name: op_schema
    for op_schema in onnx.defs.get_all_schemas_with_history()}

22 

23 

class OpRunOnnxEmpty:
    """
    Unique operator for an empty runtime.

    The node is converted into a standalone ONNX model (through the
    *skl2onnx* algebra class matching its operator type when one exists).
    The typed outputs of that model are stored in ``typed_outputs_``.
    Nothing is ever computed: :meth:`run` always raises.
    """

    def __init__(self, onnx_node, desc=None, variables=None,
                 dtype=None, **options):
        """
        :param onnx_node: :epkg:`onnx` node
        :param desc: internal representation; when it contains key
            ``'atts'``, every attribute must be a dictionary with a key
            ``'value'``, and that value is copied into *options*
        :param variables: registered variables created by previous operators
        :param dtype: float computation type
        :param options: runtime options
        :raises ValueError: if an attribute in ``desc['atts']`` is not a
            dictionary holding a ``'value'`` key
        """
        self._provider = 'empty'
        self.onnx_node = onnx_node
        self.desc = desc
        self._schema = _schemas.get(onnx_node.op_type, None)
        if desc is not None and 'atts' in desc:
            for name, att in desc['atts'].items():
                if not isinstance(att, dict) or 'value' not in att:
                    raise ValueError(  # pragma: no cover
                        f"Unexpected value {att}.")
                options[name] = att['value']

        self.options = options
        self.dtype = dtype
        self._init(variables)

    def _name_mapping(self, inputs):
        """
        Renames duplicated input names so that every name is unique.

        :param inputs: list of input names
        :return: tuple ``(mapping, new_inputs)``. *new_inputs* holds the
            unique names in the same order as *inputs*. *mapping* maps
            every unique name to its original name.
        """
        mapping = {}
        new_inputs = []
        for name in inputs:
            if name in mapping:
                # Duplicated name: look for a suffix which is not used yet.
                i = 0
                new_name = f"{name}_{i}"
                while new_name in mapping:
                    i += 1  # pragma: no cover
                    new_name = f"{name}_{i}"  # pragma: no cover
                mapping[new_name] = name
                new_inputs.append(new_name)
            else:
                new_inputs.append(name)
                mapping[name] = name
        return mapping, new_inputs

    def _guess_proto_type(self, dtype):
        """
        Returns the ONNX element type for *dtype* (delegates to
        *guess_proto_dtype*).
        """
        return guess_proto_dtype(dtype)

    def _check_null_dimension(self, extra=None, cause=None):
        """
        Raises an exception if the textual representation of ``self.onnx_``
        contains a null dimension (``dim_value: 0``).

        :param extra: optional object appended to the error message
        :param cause: optional exception the error is chained to
        :raises RuntimeError: if a null dimension is found
        """
        text = str(self.onnx_)
        if "dim_value: 0" not in text:
            return
        msg = f"Probable issue as one dimension is null.\n--\n{text}"
        if extra is not None:
            msg += f"\n---\n{extra}"
        if cause is None:
            raise RuntimeError(msg)
        raise RuntimeError(msg) from cause

    def _init(self, variables=None):
        """
        Initializes the node.

        Looks for the *skl2onnx* algebra class ``Onnx<op_type>``, first in
        the custom operators and then in the standard ones. It then
        converts the node into a standalone ONNX model stored in
        ``self.onnx_`` and fills ``self.typed_outputs_``. If no algebra
        class exists, ``self.onnx_`` is the node itself.

        :param variables: registered variables created by previous operators
        :raises RuntimeError: if the converted model has a null dimension

        The current implementation for operator *Scan*
        only works for matrices.
        """
        try:
            self.alg_class = getattr(alg2, 'Onnx' + self.onnx_node.op_type)
        except AttributeError:
            try:
                self.alg_class = getattr(alg, 'Onnx' + self.onnx_node.op_type)
            except AttributeError:
                self.alg_class = None
        inputs = list(self.onnx_node.input)
        self.mapping, self.inputs = self._name_mapping(inputs)
        self.outputs = list(self.onnx_node.output)

        # Options consumed here must not be forwarded to the algebra class.
        options = self.options.copy()
        target_opset = options.pop('target_opset', None)
        domain = options.pop('domain', None)

        if self.alg_class is None:
            self.onnx_ = self.onnx_node
        elif self.onnx_node.op_type == 'ConstantOfShape':
            # numpy arrays given as attributes are converted into
            # TensorProto before being given to the algebra class.
            for k, v in list(options.items()):
                if isinstance(v, numpy.ndarray):
                    options[k] = make_tensor(
                        k, self._guess_proto_type(v.dtype),
                        v.shape, v.tolist())

            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            try:
                self.onnx_ = self.inst_.to_onnx(inputs, target_opset=target_opset,
                                                domain=domain)
                self._check_null_dimension()
            except AttributeError as e:  # pragma: no cover
                # older version of skl2onnx
                self.onnx_ = self.inst_.to_onnx(inputs)
                self._check_null_dimension(cause=e)
        elif self.onnx_node.op_type == 'Scan':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype)
            # Every input and output is declared with shape [None, None],
            # hence this branch only supports matrices.
            inputs = [(name, cl.__class__([None, None]))
                      for (name, cl) in inputs]
            outputs = [(name, cl.__class__([None, None]))
                       for (name, cl) in outputs]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            self._check_null_dimension()
        else:
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, domain=domain,
                                        **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)

            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                # Raised inside the try on purpose: a null dimension
                # triggers the second attempt below.
                self._check_null_dimension(extra=inputs)
            except (RuntimeError, ValueError) as e:  # pragma: no cover
                # Let's try again by forcing output types.
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype)
                self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                                target_opset=target_opset,
                                                domain=domain)
                self._check_null_dimension(cause=e)

        if hasattr(self.onnx_, 'graph'):
            if len(self.onnx_.graph.output) != len(self.outputs):  # pragma: no cover
                # Something is wrong, falls back to default plan.
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype)
                self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                                target_opset=target_opset,
                                                domain=domain)
                self._check_null_dimension()
            else:
                outputs = proto2vars(list(self.onnx_.graph.output))
        else:
            # self.onnx_ is the node itself, output types are unknown.
            outputs = [(o, None) for o in self.onnx_.output]

        self.typed_outputs_ = outputs

    def run(self, *args, **kwargs):
        """
        Should be overwritten.

        :raises RuntimeError: always, this runtime does not compute anything
        """
        raise RuntimeError(  # pragma: no cover
            "This runtime does nothing. Running it is useless.")

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False