Coverage for mlprodict/onnxrt/ops_onnxruntime/_op.py: 96%

142 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_onnxruntime*.
"""
import numpy
import onnx.defs
from onnx.helper import make_tensor
from onnx.onnx_cpp2py_export.shape_inference import InferenceError  # pylint: disable=E0401,E0611
from ...tools.ort_wrapper import InferenceSession
from ...onnx_tools.onnx2py_helper import guess_proto_dtype
from ...onnx_tools.optim.graph_schema_helper import (
    get_defined_inputs, get_defined_outputs, proto2vars)


_schemas = {
    schema.name: schema for schema in onnx.defs.get_all_schemas_with_history()}

class OpRunOnnxRuntime:
    """
    Unique operator which calls :epkg:`onnxruntime`
    to compute predictions for one operator.
    """

    def __init__(self, onnx_node, desc=None, variables=None,
                 dtype=None, runtime=None, **options):
        """
        :param onnx_node: :epkg:`onnx` node
        :param desc: internal representation
        :param variables: registered variables created by previous operators
        :param dtype: float computation type
        :param options: runtime options
        :param runtime: `onnxruntime1`, `onnxruntime1-cuda`, ...
        """
        self._provider = 'onnxruntime'
        self.onnx_node = onnx_node
        self.desc = desc
        self.runtime = runtime
        self._schema = _schemas.get(onnx_node.op_type, None)
        if desc is not None:
            if 'atts' in desc:
                for a, b in desc['atts'].items():
                    if not isinstance(b, dict) or 'value' not in b:
                        raise ValueError(  # pragma: no cover
                            f"Unexpected value {b}.")
                    options[a] = b['value']

        self.options = options
        self.dtype = dtype
        self._init(variables)

        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            InvalidArgument as OrtInvalidArgument)
        self.OrtInvalidArgument = OrtInvalidArgument

    def _name_mapping(self, inputs):
        mapping = {}
        new_inputs = []
        for name in inputs:
            if name in mapping:
                i = 0
                new_name = f"{name}_{i}"
                while new_name in mapping:
                    i += 1  # pragma: no cover
                    new_name = f"{name}_{i}"  # pragma: no cover
                mapping[new_name] = name
                new_inputs.append(new_name)
            else:
                new_inputs.append(name)
                mapping[name] = name
        return mapping, new_inputs
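        # For illustration (hypothetical names): _name_mapping(['X', 'X', 'W'])
        # returns ({'X': 'X', 'X_0': 'X', 'W': 'W'}, ['X', 'X_0', 'W']):
        # duplicated ONNX input names receive unique aliases while the mapping
        # keeps track of the original names.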

    def _guess_proto_type(self, dtype):
        return guess_proto_dtype(dtype)

    def _init(self, variables=None):
        """
        Initializes the node.

        :param variables: registered variables created by previous operators

        The current implementation for operator *Scan*
        only works for matrices.
        """
        custom_nodes = self.options.get('nodes', None)
        if (custom_nodes is not None and
                self.onnx_node.op_type in custom_nodes):
            self.alg_class = custom_nodes[self.onnx_node.op_type]
        else:
            try:
                import mlprodict.onnx_conv.onnx_ops as alg0
                self.alg_class = getattr(alg0, 'Onnx' + self.onnx_node.op_type)
            except AttributeError:
                import skl2onnx.algebra.custom_ops as alg2  # delayed
                try:
                    self.alg_class = getattr(
                        alg2, 'Onnx' + self.onnx_node.op_type)
                except AttributeError:
                    import skl2onnx.algebra.onnx_ops as alg  # delayed
                    self.alg_class = getattr(
                        alg, 'Onnx' + self.onnx_node.op_type)

        inputs = list(self.onnx_node.input)
        self.mapping, self.inputs = self._name_mapping(inputs)
        self.outputs = list(self.onnx_node.output)

        options = self.options.copy()
        options.pop('nodes', None)
        target_opset = options.pop('target_opset', None)
        domain = options.pop('domain', None)
        disable_optimisation = options.pop('disable_optimisation', False)
        session_options = options.pop('session_options', False)
        ir_version = options.pop('ir_version', None)

        if domain == '' and target_opset < 9:
            # target_opset should be >= 9 for the main domain.
            # We assume that was the case when the graph was created.
            pass
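
        # Whatever remains in *options* is treated as ONNX attributes and is
        # forwarded to the skl2onnx operator class built in the branches below.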

        if self.onnx_node.op_type == 'ZipMap':
            from skl2onnx.common.data_types import (  # delayed
                DictionaryType, FloatTensorType, Int64TensorType, StringTensorType)
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            otype = (Int64TensorType if 'classlabels_int64s' in options
                     else StringTensorType)
            outvar = [(name, DictionaryType(otype([1]), FloatTensorType([1])))]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ArrayFeatureExtractor':
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            name = (self.outputs[0] if len(self.outputs) == 1
                    else self.inst_.expected_outputs[0][0])
            otype = inputs[0][1].__class__
            outvar = [(name, otype())]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outvar)
            forced = True
        elif self.onnx_node.op_type == 'ConstantOfShape':
            for k in options:  # pylint: disable=C0206
                v = options[k]
                if isinstance(v, numpy.ndarray):
                    options[k] = make_tensor(
                        k, self._guess_proto_type(v.dtype),
                        v.shape, v.tolist())

            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            try:
                self.onnx_ = self.inst_.to_onnx(inputs, target_opset=target_opset,
                                                domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        f"Probable issue as one dimension is null.\n--\n{self.onnx_}")
            except AttributeError as e:  # pragma: no cover
                # older version of skl2onnx
                self.onnx_ = self.inst_.to_onnx(inputs)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from e
            forced = False
        elif self.onnx_node.op_type == 'Scan':
            self.inst_ = self.alg_class(
                *self.inputs, output_names=self.outputs,
                op_version=target_opset, **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype)
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype)
            inputs = [(name, cl.__class__([None, None]))
                      for (name, cl) in inputs]
            outputs = [(name, cl.__class__([None, None]))
                       for (name, cl) in outputs]
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    f"Probable issue as one dimension is null.\n--\n{self.onnx_}")
            forced = True
        else:
            self.inst_ = self.alg_class(*self.inputs, output_names=self.outputs,
                                        op_version=target_opset, domain=domain,
                                        **options)
            inputs = get_defined_inputs(
                self.inputs, variables, dtype=self.dtype,
                schema=self.alg_class.expected_inputs)

            try:
                self.onnx_ = self.inst_.to_onnx(
                    inputs, target_opset=target_opset, domain=domain)
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}\n---\n{}".format(
                            self.onnx_, inputs))
                forced = False
            except (RuntimeError, ValueError, InferenceError) as eo:
                # Let's try again by forcing output types.
                forced = True
                outputs = get_defined_outputs(
                    self.outputs, self.onnx_node, inputs, variables,
                    dtype=self.dtype, schema=self.alg_class.expected_outputs,
                    schema_inputs=self.alg_class.expected_inputs)
                try:
                    self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                                    target_opset=target_opset,
                                                    domain=domain)
                except NotImplementedError as e:  # pragma: no cover
                    raise NotImplementedError(
                        "Unable to instantiate node {} inputs={} "
                        "self.inputs={} outputs={} variables={} "
                        "dtype={} e={} eo={}".format(
                            self.alg_class, inputs, self.inputs,
                            outputs, variables, self.dtype, e, eo)) from e
                if "dim_value: 0" in str(self.onnx_):
                    raise RuntimeError(  # pragma: no cover
                        "Probable issue as one dimension is null.\n--\n{}".format(
                            self.onnx_)) from eo

        if len(self.onnx_.graph.output) > len(self.outputs):  # pragma: no cover
            # Something is wrong, fall back to the default plan.
            forced = True
            outputs = get_defined_outputs(
                self.outputs, self.onnx_node, inputs, variables,
                dtype=self.dtype, schema=self.alg_class.expected_outputs)
            self.onnx_ = self.inst_.to_onnx(inputs, outputs=outputs,
                                            target_opset=target_opset,
                                            domain=domain)
            if "dim_value: 0" in str(self.onnx_):
                raise RuntimeError(  # pragma: no cover
                    f"Probable issue as one dimension is null.\n--\n{self.onnx_}")
        else:
            lo = list(self.onnx_.graph.output)
            outputs = proto2vars(lo)

        from onnxruntime import (  # pylint: disable=E0611
            SessionOptions, RunOptions, GraphOptimizationLevel)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
            NotImplemented as OrtNotImplemented)

        sess_options = session_options or SessionOptions()
        self.run_options = RunOptions()

        if session_options is None:
            try:
                sess_options.session_log_severity_level = 3
                # sess_options.sessions_log_verbosity_level = 0
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            try:
                self.run_options.run_log_severity_level = 3
                # self.run_options.run_log_verbosity_level = 0
            except AttributeError:  # pragma: no cover
                # onnxruntime not recent enough.
                pass
            if disable_optimisation:
                sess_options.graph_optimization_level = (  # pragma: no cover
                    GraphOptimizationLevel.ORT_DISABLE_ALL)
        elif disable_optimisation:
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined "
                "at the same time.")

        if ir_version is not None:
            self.onnx_.ir_version = ir_version
        try:
            self.sess_ = InferenceSession(
                self.onnx_.SerializeToString(), sess_options=sess_options,
                runtime=self.runtime)
        except (RuntimeError, OrtNotImplemented, OrtInvalidGraph, OrtFail) as e:
            raise RuntimeError(
                "Unable to load node '{}' (output type was {}) inputs={} "
                "self.inputs={} self.onnx_node.input={} "
                "variables={} mapping={} "
                "expected_inputs={}\n{}".format(
                    self.onnx_node.op_type,
                    "guessed" if forced else "inferred",
                    inputs, self.inputs, self.onnx_node.input,
                    variables, self.mapping,
                    self.alg_class.expected_inputs,
                    self.onnx_)) from e
        self.typed_outputs_ = outputs

    def run(self, *args, **kwargs):
        """
        Should be overwritten.
        """
        inputs = {name: val for name, val in zip(self.inputs, args)}

        try:
            res = self.sess_.run(None, inputs, self.run_options)
        except (RuntimeError, self.OrtInvalidArgument) as e:  # pragma: no cover
            dtypes = {k: v.dtype for k, v in inputs.items()}
            shapes = {k: v.shape for k, v in inputs.items()}
            exp = [_.name for _ in self.sess_.get_inputs()]
            exp_types = [_.type for _ in self.sess_.get_inputs()]
            raise RuntimeError(
                "Predictions failed. List of inputs: {}, class={}"
                "\ndtypes={}\nshapes={}\nexpected names={}\nexpected types={}\n"
                "exception={}\n--ONNX--\n{}".format(
                    list(sorted(inputs)), self.alg_class, dtypes,
                    shapes, exp, exp_types, e, self.onnx_)) from e
        return tuple(res)
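        # For illustration (hypothetical values): if self.inputs is ['X', 'Y'],
        # run(x, y) feeds {'X': x, 'Y': y} to onnxruntime and returns the
        # computed outputs as a tuple, in the order of self.outputs.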

    def need_context(self):
        """
        Tells the runtime if this node needs the context
        (all the results produced so far) as it may silently access
        one of them (operator Loop).
        The default answer is `False`.
        """
        return False
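

# A minimal, hedged usage sketch, not from the original module: it shows the
# pattern _init implements for a single operator: build the node with the
# skl2onnx algebra, convert it to an ONNX graph, then load it into an
# onnxruntime InferenceSession. The operator (Add), the names 'X', 'Y', 'Z',
# the shapes and the opset are assumptions chosen for the example.
if __name__ == '__main__':
    from skl2onnx.algebra.onnx_ops import OnnxAdd  # delayed
    from skl2onnx.common.data_types import FloatTensorType  # delayed
    from onnxruntime import InferenceSession as OrtInferenceSession  # delayed

    # Build a single Add operator and convert it to an ONNX graph.
    node = OnnxAdd('X', 'Y', output_names=['Z'], op_version=15)
    onx = node.to_onnx([('X', FloatTensorType([None, 2])),
                        ('Y', FloatTensorType([None, 2]))],
                       target_opset=15)
    # Load the graph and compute the prediction with onnxruntime.
    sess = OrtInferenceSession(onx.SerializeToString(),
                               providers=['CPUExecutionProvider'])
    res = sess.run(None, {'X': numpy.ones((2, 2), dtype=numpy.float32),
                          'Y': numpy.ones((2, 2), dtype=numpy.float32)})
    print(res[0])  # [[2. 2.] [2. 2.]]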