Coverage for mlprodict/onnxrt/onnx_inference_node.py: 87%

271 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief OnnxInferenceNode definition. 

4""" 

5import sys 

6import numpy 

7from onnx import GraphProto, onnx_pb as onnx_proto 

8from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E0401,E0611 

9from ..onnx_tools.onnx2py_helper import get_onnx_schema 

10from .excs import MissingOperatorError 

11from .ops import load_op 

12 

13 

class OnnxInferenceNode:
    """
    A node to execute.

    :param onnx_node: onnx_node
    :param desc: internal description
    :param global_index: it is a function which returns a unique index
        for the output this operator generates
    """

    class OnnxInferenceWrapper:
        """
        Wraps @see cl OnnxInference in a wrapper and exposes
        the necessary function.

        :param oinf: instance of @see cl OnnxInference
        """

        def __init__(self, oinf):
            # A wrapper around nothing is useless: fail early.
            if oinf is None:
                raise ValueError(  # pragma: no cover
                    "oinf cannot be None.")
            self.oinf = oinf

        @property
        def args_default(self):
            "Returns the list of default arguments."
            return []

        @property
        def args_default_modified(self):
            "Returns the list of modified arguments."
            return []

        @property
        def args_mandatory(self):
            "Returns the list of mandatory arguments."
            return self.oinf.input_names

        @property
        def args_optional(self):
            "Returns the list of optional arguments."
            return []

        @property
        def obj(self):
            "Returns the ONNX graph."
            return self.oinf.obj

        def run(self, *args, **kwargs):
            "Calls run."
            return self.oinf.run(*args, **kwargs)

        def to_python(self, inputs, *args, **kwargs):
            "Calls to_python."
            res = self.oinf.to_python(*args, **kwargs)
            if len(res) != 1:
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented if the code has multiple files.")
            # Single generated file: grab its source text.
            value = res[next(iter(res))]
            lines = value.split('\n')
            # Lines before the first 'def ' only matter for their imports.
            last = 0
            for pos, line in enumerate(lines):
                if line.startswith('def '):
                    last = pos - 1
                    break
            imports = '\n'.join(
                line for line in lines[:last] if 'import ' in line)
            lines.append('')
            lines.append(
                f"return OnnxPythonInference().run({', '.join(inputs)})")
            return imports, '\n'.join(lines[last:])

        def need_context(self):
            "Needs context?"
            return False

        def enable_inplace_compute(self, index):
            "Not implemented."
            pass

95 

96 def __init__(self, onnx_node, desc, global_index): 

97 if desc is None: 

98 raise ValueError("desc should not be None.") # pragma: no cover 

99 self.desc = desc 

100 self.onnx_node = onnx_node 

101 self._init(global_index) 

102 

103 @property 

104 def name(self): 

105 "Returns the ONNX name." 

106 return "_".join( 

107 [self.desc['domain'], self.onnx_node.op_type]).replace( 

108 ".", "_").replace('__', '_').strip('_') 

109 

110 def _init(self, global_index): 

111 """ 

112 Prepares the node. 

113 """ 

114 self.op_type = self.onnx_node.op_type 

115 self.order = -1 

116 self.variable_to_clean = [] 

117 self.inputs = list(self.onnx_node.input) 

118 self.outputs = list(self.onnx_node.output) 

119 self.inplaces = [] 

120 self.inputs_indices = [global_index(name) for name in self.inputs] 

121 self.outputs_indices = [global_index(name) for name in self.outputs] 

122 self._global_index = global_index 

123 

124 def set_order(self, order): 

125 """ 

126 Defines the order of execution. 

127 """ 

128 self.order = order 

129 

130 def add_variable_to_clean(self, name): 

131 """ 

132 Adds a variable which can be cleaned after the node 

133 execution. 

134 """ 

135 self.variable_to_clean.append(name) 

136 

137 def __str__(self): 

138 "usual" 

139 return "Onnx-{}({}) -> {}{}".format( 

140 self.op_type, ", ".join(self.inputs), ", ".join(self.outputs), 

141 " (name=%r)" % self.onnx_node.name 

142 if self.onnx_node.name else "") 

143 

144 def __repr__(self): 

145 "usual" 

146 return self.__str__() 

147 

    def setup_runtime(self, runtime=None, variables=None, rt_class=None,
                      target_opset=None, dtype=None, domain=None,
                      ir_version=None, runtime_options=None,
                      build_inference_node_function=None,
                      existing_functions=None):
        """
        Loads runtime.

        :param runtime: runtime options
        :param variables: registered variables created by previous operators
        :param rt_class: runtime class used to compute
            prediction of subgraphs
        :param target_opset: use a specific target opset
        :param dtype: float computational type
        :param domain: node domain
        :param ir_version: if not None, changes the default value
            given by :epkg:`ONNX`
        :param runtime_options: runtime options
        :param build_inference_node_function: function creating an inference
            runtime from an ONNX graph
        :param existing_functions: existing function as a dictionary
            `{ (domain, name): fct }`

        .. versionchanged:: 0.9
            Parameters *build_inference_node_function* and *existing_functions*
            were added.
        """
        if self.desc is None:
            raise AttributeError(
                "desc should not be None.")  # pragma: no cover
        if rt_class is None:
            # path used when this operator is a function.
            # *runtime* then carries an OnnxInference instance, not a string.
            self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(runtime)
            self.ops_ = None
            return

        self.function_ = None
        # Instantiates runtimes for any subgraph stored in the attributes
        # (If/Loop/Scan bodies) before loading the operator itself.
        self.preprocess_parameters(
            runtime, rt_class, ir_version=ir_version,
            target_opset=target_opset, existing_functions=existing_functions)
        options = {'provider': runtime} if runtime else {}
        if domain is not None:
            options['domain'] = domain
        if target_opset is not None:
            options['target_opset'] = target_opset
        if ir_version is not None:
            options['ir_version'] = ir_version
        if runtime_options is not None:
            # log_severity_level is an onnxruntime session option,
            # not an operator option.
            options.update({
                k: v for k, v in runtime_options.items()
                if k not in {'log_severity_level'}})

        # existing functions?
        # NOTE(review): the lookup key uses the node *name* while the
        # fallback below resolves schemas by *op_type* — confirm
        # existing_functions is keyed by node name.
        key = (self.onnx_node.domain, self.onnx_node.name)
        if existing_functions is not None and key in existing_functions:
            self.ops_ = existing_functions[key]
            return

        # regular node
        try:
            if runtime is not None and runtime.startswith('onnxruntime2'):
                self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                    options=options if options else None,
                                    variables=variables, dtype=dtype,
                                    runtime=runtime)
            elif runtime in ('python_compiled', 'python_compiled_debug'):
                # Compiled python still executes with the python kernels.
                options['provider'] = 'python'
                self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                    options=options if options else None,
                                    variables=variables, dtype=dtype,
                                    runtime=runtime)
            else:
                self.ops_ = load_op(self.onnx_node, desc=self.desc,
                                    options=options if options else None,
                                    variables=variables, dtype=dtype,
                                    runtime=runtime)
        except MissingOperatorError as e:
            # No kernel for this operator: fall back on the function body
            # published with its ONNX schema, when there is one.
            try:
                onnx_schema = get_onnx_schema(
                    self.onnx_node.op_type, self.onnx_node.domain,
                    opset=target_opset)
            except SchemaError:
                fct_names = (
                    list(existing_functions.keys()) if existing_functions
                    else [])
                raise MissingOperatorError(
                    "Unable to find runtime for node (%r, %r), "
                    "available functions=%r." % (
                        self.onnx_node.domain, self.onnx_node.op_type,
                        fct_names)) from e
            if onnx_schema is None or not onnx_schema.has_function:
                raise e
            self.function_ = OnnxInferenceNode.OnnxInferenceWrapper(
                build_inference_node_function(onnx_schema.function_body))
            self.ops_ = None

243 

244 @staticmethod 

245 def _find_static_inputs(body): 

246 """ 

247 Determines the loop inputs. It is any defined inputs 

248 by the subgraphs + any result used as a constant 

249 in the subgraphs. 

250 """ 

251 inputs_set = set(i.name for i in body.input) 

252 for init in body.initializer: 

253 inputs_set.add(init.name) 

254 for node in body.node: 

255 for i in node.output: 

256 inputs_set.add(i) 

257 add_inputs = [] 

258 for node in body.node: 

259 for i in node.input: 

260 if i not in inputs_set: 

261 # no graph input or output node matches 

262 # it must be a constant from the below graph 

263 add_inputs.append(i) 

264 inputs_set.add(i) 

265 for att in node.attribute: 

266 if (att.type == onnx_proto.AttributeProto.GRAPH and # pylint: disable=E1101 

267 hasattr(att, 'g') and att.g is not None): 

268 inside = OnnxInferenceNode._find_static_inputs(att.g) 

269 for i in inside: 

270 if i not in inputs_set: 

271 add_inputs.append(i) 

272 inputs_set.add(i) 

273 # If there is no node, we add the outputs as well. 

274 if len(body.node) == 0: 

275 for o in body.output: 

276 i = o.name 

277 if i not in inputs_set: 

278 add_inputs.append(i) 

279 inputs_set.add(i) 

280 return add_inputs 

281 

282 @staticmethod 

283 def _find_local_inputs(graph): 

284 """ 

285 Determines the local inputs. It is any defined input 

286 used by the subgraph and defined in the parent graph. 

287 """ 

288 if not isinstance(graph, GraphProto): 

289 raise TypeError( 

290 f"Unexpected type {type(graph)!r}.") 

291 local = set() 

292 known = set() 

293 for init in graph.initializer: 

294 known.add(init.name) 

295 for init in graph.input: 

296 known.add(init.name) 

297 for node in graph.node: 

298 for o in node.output: 

299 known.add(o) 

300 for i in node.input: 

301 if i not in known: 

302 local.add(i) 

303 return list(local) 

304 

305 def get_local_inputs(self): 

306 """ 

307 Returns any local input used by this node in a subgraph 

308 defined as an attribute and not declared as an input of this subgraph. 

309 """ 

310 req = set() 

311 for att in self.onnx_node.attribute: 

312 if hasattr(att, 'g') and att.g is not None: 

313 req |= set(self._find_local_inputs(att.g)) 

314 return req 

315 

    def preprocess_parameters(self, runtime, rt_class, ir_version=None,
                              target_opset=None, existing_functions=None):
        """
        Preprocesses the parameters, loads *GraphProto*
        (equivalent to :epkg:`ONNX` graph with less metadata).

        :param runtime: runtime options
        :param rt_class: runtime class used to compute
            prediction of subgraphs
        :param ir_version: if not None, overwrites the default value
        :param target_opset: use a specific target opset
        :param existing_functions: existing functions
        """
        if 'atts' not in self.desc:
            return  # pragma: no cover
        # Loop bodies may read results produced outside the subgraph.
        inside_loop = self.onnx_node.op_type in {'Loop'}
        for _, v in self.desc['atts'].items():
            if 'value' not in v:
                continue  # pragma: no cover
            value = v['value']
            # Only graph-valued attributes (If/Loop/Scan bodies) need
            # their own inference session.
            if isinstance(value, onnx_proto.GraphProto):
                static_inputs = OnnxInferenceNode._find_static_inputs(value)
                if len(value.node) > 0:
                    try:
                        sess = rt_class(value, runtime=runtime,
                                        ir_version=ir_version,
                                        target_opset=target_opset,
                                        inside_loop=inside_loop,
                                        static_inputs=static_inputs,
                                        existing_functions=existing_functions)
                    except RuntimeError as e:  # pragma: no cover
                        raise RuntimeError(
                            "Unable to instantiate a node of type %r and name %r."
                            "" % (self.onnx_node.op_type, self.onnx_node.name)) from e
                else:
                    # outputs already exists, usually branch then of else for If node
                    sess = rt_class(value, runtime=runtime,
                                    ir_version=ir_version,
                                    target_opset=target_opset,
                                    inside_loop=inside_loop,
                                    static_inputs=static_inputs,
                                    existing_functions=existing_functions)
                # Cache the session next to the attribute value.
                v['value_rt'] = sess

359 

360 def _build_context(self, values, input_list): 

361 context = {} 

362 # input_list does not need to be sorted but when 

363 # an input is not found, the returned error is always 

364 # related to the same input. 

365 for n in sorted(input_list): 

366 try: 

367 v = values[self._global_index(n)] 

368 except IndexError as e: # pragma: no cover 

369 raise IndexError( 

370 f"Unable to find an index for result {n!r} in onnx object.") from e 

371 if v is None: 

372 raise ValueError( # pragma: no cover 

373 f"Input {n!r} is None.") 

374 context[n] = v 

375 return context 

376 

    def run(self, values, attributes=None, verbose=0, fLOG=None):
        """
        Runs the node.
        The function updates values with outputs.

        :param values: list of existing values
        :param attributes: attributes known at function level
        :param verbose: verbosity
        :param fLOG: logging function
        """
        # This code takes time if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.inputs_indices is None:
            args = list(values[k] for k in self.inputs)
        else:
            args = list(values[k] for k in self.inputs_indices)

        if self.ops_ is None:
            # Then a function.
            if 'atts' in self.desc:
                # attributes of a function
                # Local attributes take precedence over inherited ones.
                if attributes is None:
                    attributes = {}
                else:
                    attributes = attributes.copy()
                attributes.update(self.desc['atts'])

            # Map positional arguments onto the function input names.
            feeds = {}
            for name, val in zip(self.function_.obj.input, args):
                if val is None:
                    raise ValueError(  # pragma: no cover
                        f"Input name {name!r} is None.")
                feeds[name] = val

            if verbose == 0 or fLOG is None:
                outputs = self.function_.run(feeds, attributes=attributes)
            else:
                if verbose > 0:
                    fLOG('-- >%s[%s](%s) -- len(feeds)=%d' %
                         (self.function_.obj.name, self.function_.obj.domain,
                          ", ".join(self.function_.obj.input), len(feeds)))
                outputs = self.function_.run(
                    feeds, attributes=attributes, verbose=verbose, fLOG=fLOG)
                if verbose > 0:
                    fLOG('-- <%s[%s][%s]' %
                         (self.function_.obj.name, self.function_.obj.domain,
                          ", ".join(self.function_.obj.output)))

            # NOTE(review): here *res* is a list (not a tuple) — the tuple
            # check below only applies to the operator path.
            res = [outputs[k] for k in self.function_.obj.output]
        else:
            # Or an operator.
            try:
                if self.ops_.need_context():
                    context = self._build_context(values,
                                                  self.ops_.additional_inputs)
                    res = self.ops_.run(*args, context=context,
                                        attributes=attributes,
                                        verbose=verbose, fLOG=fLOG)
                else:
                    res = self.ops_.run(
                        *args, attributes=attributes,
                        verbose=verbose, fLOG=fLOG)
            except (ValueError, TypeError) as e:
                raise RuntimeError(  # pragma: no cover
                    "Unable to run operator %r, inputs=%r."
                    "" % (type(self.ops_), self.inputs)) from e
            except OverflowError as e:
                raise RuntimeError(  # pragma: no cover
                    "Unable to run operator %r, inputs=%r."
                    "" % (type(self.ops_), self.inputs)) from e

            if not isinstance(res, tuple):
                raise RuntimeError(  # pragma: no cover
                    f"Results of operator {type(self.ops_)!r} should be a tuple.")

        if len(self.outputs) < len(res):
            raise RuntimeError(  # pragma: no cover
                f"Mismatch number of outputs got {len(res)} "
                f"for names {list(self.outputs)} "
                f"for class {self.name!r})."
                f"\n{self.desc}")

        # This code takes times if the graph contains many nodes.
        # Maybe a C++ container would help in that case (to skip GIL).
        if self.outputs_indices is None:
            for name, value in zip(self.outputs, res):
                values[name] = value
        else:
            for i, r in enumerate(res):
                values[self.outputs_indices[i]] = r

467 

468 def switch_initializers_dtype(self, dtype_in=numpy.float32, 

469 dtype_out=numpy.float64): 

470 """ 

471 Switches all initializers to ``numpy.float64``. 

472 This only works if the runtime is ``'python'``. 

473 

474 @param dtype_in previous type 

475 @param dtype_out next type 

476 @return done operations 

477 """ 

478 done = [] 

479 for k, v in self.desc['atts'].items(): 

480 if 'value_rt' not in v: 

481 continue 

482 if isinstance(v['value_rt'], numpy.ndarray): 

483 if v['value_rt'].dtype == dtype_in: 

484 v['value_rt'] = v['value_rt'].astype(dtype_out) 

485 done.append(("+", "desc", k, v['value_rt'])) 

486 else: 

487 done.append(("-", "desc", k, v['value_rt'])) 

488 if hasattr(self, 'ops_') and self.ops_ is not None: 

489 res = self.ops_.switch_initializers_dtype(dtype_in, dtype_out) 

490 for r in res: 

491 done.append(("ops_", ) + r) 

492 return done 

493 

494 def enable_inplace_compute(self, name): 

495 """ 

496 Let the node know that one input can be overwritten. 

497 

498 @param name input name 

499 """ 

500 self.inplaces.append(name) 

501 (self.ops_ or self.function_).enable_inplace_compute( 

502 self.inputs.index(name)) 

503 

504 @property 

505 def inputs_args(self): 

506 """ 

507 Returns the list of arguments as well as 

508 the list of parameters with the default values 

509 (close to the signature). 

510 """ 

511 if not hasattr(self, 'ops_'): 

512 raise AttributeError( 

513 "Attribute 'ops_' is missing.") # pragma: no cover 

514 sigs = [] 

515 ops_or_function = self.function_ if self.ops_ is None else self.ops_ 

516 mand = ops_or_function.args_mandatory 

517 if mand is None: 

518 mand = self.python_inputs 

519 sigs.extend(mand) 

520 if len(ops_or_function.args_optional) > 0: 

521 sigs.extend(ops_or_function.args_optional) 

522 if sys.version_info[:2] >= (3, 8): 

523 sigs.append('/') 

524 sigs.extend(ops_or_function.args_default) 

525 return sigs 

526 

527 @property 

528 def python_inputs(self): 

529 """ 

530 Returns the python arguments. 

531 """ 

532 if not hasattr(self, 'ops_'): 

533 raise AttributeError( 

534 "Attribute 'ops_' is missing.") # pragma: no cover 

535 if hasattr(self.ops_, 'python_inputs'): 

536 return self.ops_.python_inputs 

537 return self.inputs 

538 

539 @property 

540 def modified_args(self): 

541 """ 

542 Returns the list of modified parameters. 

543 """ 

544 if not hasattr(self, 'ops_'): 

545 raise AttributeError( 

546 "Attribute 'ops_' is missing.") # pragma: no cover 

547 if self.ops_ is None: 

548 return self.function_.args_default_modified 

549 return self.ops_.args_default_modified 

550 

551 def to_python(self, inputs): 

552 """ 

553 Returns a python code for this operator. 

554 

555 @param inputs inputs name 

556 @return imports, python code, both as strings 

557 """ 

558 if not hasattr(self, 'ops_'): 

559 raise AttributeError( 

560 "Attribute 'ops_' is missing.") # pragma: no cover 

561 if self.ops_ is None: 

562 return self.function_.to_python(inputs) 

563 return self.ops_.to_python(inputs)