Coverage for mlprodict/npy/xop.py: 92%

# pylint: disable=E1101,C0302
"""
@file
@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`.

.. versionadded:: 0.9
"""
import os
import pprint
import logging
import hashlib
import json
from collections import OrderedDict
import numpy
from scipy.sparse.coo import coo_matrix
import onnx
from onnx import GraphProto, TensorProto, ValueInfoProto
from onnx.helper import (
    make_node, make_graph, make_model, make_value_info,
    make_tensor_value_info, make_function, make_opsetid,
    make_tensor_type_proto, make_operatorsetid)
from onnx.numpy_helper import from_array, to_array
from onnx.shape_inference import infer_shapes
from ..onnx_tools.model_checker import check_onnx
from ._cache import cache_folder
from .xop_variable import (
    Variable, is_numpy_dtype, numpy_type_prototype, max_supported_opset,
    DetectedVariable, InputDetectedVariable, OutputDetectedVariable,
    NodeResultName, guess_numpy_type, ExistingVariable)
from .xop_auto import get_rst_doc
from .xop_helper import _infer_node_output


class _WrapperLogger:
    """
    Wrapper around class :class:`logging.Logger`
    to take indentation into account.
    """

    def __init__(self, lg):
        "constructor"
        self._logger = lg
        self._indent = 0

    def debug(self, msg, *args):
        "debug"
        self._logger.debug("%s" + msg, " " * self._indent, *args)

    def indent(self):
        "indent"
        self._indent += 1

    def dedent(self):
        "unindent"
        self._indent -= 1
        if self._indent < 0:
            raise RuntimeError(  # pragma: no cover
                "Indentation cannot be negative.")


class _WrapperPrint(_WrapperLogger):
    """
    Wrapper around print to help debugging.
    """

    def __init__(self):
        "constructor"
        _WrapperLogger.__init__(self, None)

    def debug(self, msg, *args, indent=None):
        "debug"
        sign = ""
        if indent is not None:
            if not indent:
                self.dedent()
                sign = '< '
            else:
                sign = '> '
        print(f"{' ' * self._indent}{sign}{msg} {' '.join(map(str, args))}")
        if indent is not None:
            if indent:
                self.indent()


logger = _WrapperLogger(logging.getLogger('xop'))
local_print = _WrapperPrint().debug


def _default_OPSET_TO_IR_VERSION():
    """
    Returns the default mapping between opset and ir_version.

    .. runpython::
        :showcode:

        import pprint
        from mlprodict.npy.xop import _default_OPSET_TO_IR_VERSION
        pprint.pprint(_default_OPSET_TO_IR_VERSION())
    """
    return {
        1: 3, 2: 3, 3: 3, 4: 3, 5: 3, 6: 3,
        7: 3, 8: 4, 9: 4, 10: 5, 11: 6, 12: 7,
        13: 7, 14: 7, 15: 8, 16: 8, 17: 8}


def _domain_to_class_name(domain):
    """
    Converts a domain into a name.

    :param domain: domain name such as `ai.onnx.ml`
    :return: string

    .. runpython::
        :showcode:

        from mlprodict.npy.xop import _domain_to_class_name
        print(_domain_to_class_name('ai.onnx.ml'))
    """
    if domain == 'ai.onnx':
        return ''
    dom = domain.split('.')
    res = []
    for d in dom:
        if len(d) == 0:
            res.append(d)
        elif len(d) == 1:
            res.append(d.upper())
        else:
            res.append(d[0].upper() + d[1:])
    return "".join(res)


class _CustomSchema:
    """
    For operators defined outside onnx.
    """

    class _empty:
        "dummy class"

        @staticmethod
        def from_attribute(data):
            "Creates an instance of `_CustomSchema._attribute`."
            if not isinstance(data, dict):
                raise TypeError(  # pragma: no cover
                    f"Unexpected type {type(data)!r}.")
            self = _CustomSchema._empty()
            setattr(self, 'name', data['name'])
            setattr(self, 'description', data['description'])
            setattr(self, 'required', data['required'])
            setattr(self, 'type', _CustomSchema._empty())
            setattr(self.type, 'value', data['type'])
            setattr(self, 'default_value', '?')
            return self

        @staticmethod
        def from_io(data):
            "Creates an instance of `_CustomSchema._io`."
            if not isinstance(data, dict):
                raise TypeError(  # pragma: no cover
                    f"Unexpected type {type(data)!r}.")
            self = _CustomSchema._empty()
            setattr(self, 'name', data['name'])
            setattr(self, 'typeStr', data['typeStr'])
            setattr(self, 'description', data['description'])
            setattr(self, 'option', _CustomSchema._empty())
            setattr(self.option, 'value', data['option'])
            setattr(self, 'isHomogeneous', data['isHomogeneous'])
            return self

    class _io:
        "input, output"

        def __init__(self, t):
            self.name = t.name
            self.typeStr = t.typeStr
            if isinstance(t.option, int):
                self.option = t.option
            else:
                self.option = t.option.value
            self.description = t.description
            self.isHomogeneous = t.isHomogeneous

        def data(self):
            "Returns all data in that class in a dictionary."
            return {'name': self.name, 'typeStr': self.typeStr,
                    'description': self.description,
                    'isHomogeneous': self.isHomogeneous,
                    'option': self.option}

        def __eq__(self, ot):
            return self.name == ot.name and self.typeStr == ot.typeStr

    class _attribute:
        "attribute"

        def __init__(self, att):
            self.name = att.name
            if isinstance(att.type, int):
                self.type = att.type
            else:
                self.type = att.type.value
            self.default_value = '?'
            self.description = att.description
            self.required = att.required

        def data(self):
            "Returns all data in that class in a dictionary."
            return {'name': self.name, 'type': self.type,
                    'description': self.description,
                    'required': self.required}

        def __eq__(self, ot):
            return self.name == ot.name and self.type == ot.type

    def __init__(self, schema):
        self._schema = schema
        self.domain = schema.domain
        self.name = schema.name
        self.since_version = schema.since_version
        try:
            self.inputs = [_CustomSchema._io(t) for t in schema.inputs]
        except AttributeError as e:  # pragma: no cover
            raise AttributeError(
                "Issue with operator=%r domain=%r since_version=%r, "
                "type(schema)=%r" % (
                    schema.name, schema.domain, schema.since_version,
                    type(schema))) from e
        try:
            self.outputs = [_CustomSchema._io(t) for t in schema.outputs]
        except AttributeError as e:  # pragma: no cover
            raise AttributeError(
                "Issue with operator=%r domain=%r since_version=%r, "
                "type(schema)=%r" % (
                    schema.name, schema.domain, schema.since_version,
                    type(schema))) from e
        self.attributes = {a.name: _CustomSchema._attribute(a)
                           for a in schema.attributes.values()}
        self.min_input = schema.min_input
        self.max_input = schema.max_input
        self.min_output = schema.min_output
        self.max_output = schema.max_output
        self.doc = schema.doc

    _atts = ['domain', 'name', 'since_version', 'inputs', 'outputs',
             'attributes', 'min_input', 'max_input',
             'min_output', 'max_output', 'doc']

    def __eq__(self, ot):
        for k in _CustomSchema._atts:
            if getattr(self, k) == getattr(ot, k):
                continue
            return False
        return True

    def data(self):
        "Returns all data in that class in a dictionary."
        def _(x):
            if x is None:
                return None
            if isinstance(x, (str, int)):
                return x
            if isinstance(x, list):
                return [_(e) for e in x]
            if isinstance(x, dict):
                return {k: _(v) for k, v in x.items()}
            if hasattr(x, 'data'):
                return x.data()
            raise TypeError(  # pragma: no cover
                f"Unable to handle type {type(x)!r} - {x!r}.")

        return {k: _(getattr(self, k)) for k in _CustomSchema._atts}

    def SerializeToString(self):
        "Serializes this class into json."
        return json.dumps(self.data())

    @staticmethod
    def ParseFromString(s):
        "Parses this class from a json string."
        obj = json.loads(s)
        e = _CustomSchema._empty()
        for k in _CustomSchema._atts:
            if k == 'attributes':
                setattr(e, k, {a['name']: _CustomSchema._empty.from_attribute(a)
                               for a in obj[k].values()})
            elif k in ('inputs', 'outputs'):
                setattr(e, k, [_CustomSchema._empty.from_io(o)
                               for o in obj[k]])
            else:
                setattr(e, k, obj[k])
        return _CustomSchema(e)

    def __repr__(self):
        return f"_CustomSchema(**{pprint.pformat(self.data())})"


def _get_all_operator_schema():
    data = os.path.join(os.path.dirname(__file__),
                        "ort_get_all_operator_schema.tmpl")
    with open(data, 'r', encoding='utf-8') as f:
        js = f.readlines()
    return [_CustomSchema.ParseFromString(j) for j in js[1:]]


def _populate_schemas():
    """
    Populates all schemas.
    """
    def _populate_schema(schema):
        # Multiple versions can coexist. The most recent one is kept.
        key = schema.domain, schema.name
        if key in res:
            if schema.since_version > res[key].since_version:
                # We keep the most recent one.
                res[key] = schema
        else:
            res[key] = schema
        full_name = schema.name + '_' + str(schema.since_version)
        res[schema.domain, full_name] = schema
        if key not in versions:
            versions[key] = set()
        if schema.name not in domains:
            domains[schema.name] = set()
        domains[schema.name].add(schema.domain)
        versions[key].add(full_name)

    res = {}
    versions = {}
    domains = {}
    for schema in onnx.defs.get_all_schemas_with_history():
        if schema.support_level == schema.SupportType.EXPERIMENTAL:
            # Skips experimental operators.
            continue
        _populate_schema(schema)

    try:
        import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
    except ImportError:  # pragma: no cover
        rtpy = None

    if rtpy is not None:
        # If onnxruntime is available, its operators are added as well.
        try:
            get_schemas = rtpy.get_all_operator_schema
        except AttributeError:
            # onnxruntime must be compiled with flag --gen_doc;
            # otherwise a local copy is used.
            get_schemas = _get_all_operator_schema
        for op in get_schemas():
            if (op.domain, op.name) in res:
                # an existing onnx schema
                continue
            sch = _CustomSchema(op)
            _populate_schema(sch)

    return res, versions, domains


def _find_operator_domain(name):
    """
    Determines the domain of an operator.
    Raises an exception if not found or if there is an ambiguity.

    :param name: operator name
    :return: domain
    """
    if name not in _S.all_domains:
        raise ValueError(
            "Unable to guess domain for operator %r. "
            "Not found in %r." % (name, list(_S.all_domains)))
    domains = _S.all_domains[name]
    if len(domains) == 1:
        return list(domains)[0]
    raise ValueError(  # pragma: no cover
        f"Unable to guess domain of operator {name!r}, found domains {domains!r}.")


def _split_op_name(name):
    spl = name.split('_')
    try:
        i = int(spl[-1])
    except ValueError:
        return name, None
    return "_".join(spl[:-1]), i
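

# A short illustration (added comment, not in the original source):
# `_split_op_name` separates a trailing version number from a class name.
#
#   _split_op_name('TopK_11')  # -> ('TopK', 11)
#   _split_op_name('TopK')     # -> ('TopK', None)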


def ClassFactory(class_name, op_name, inputs, outputs,
                 input_range, output_range,
                 domain, attr_names, doc,
                 deprecated, since_version,
                 past_version):
    """
    Dynamically creates a class for a specific operator.

    :param class_name: class name
    :param op_name: operator type
    :param inputs: expected inputs
    :param outputs: expected outputs
    :param input_range: input range
    :param output_range: output range
    :param domain: domain
    :param attr_names: attribute names
    :param doc: docstring
    :param deprecated: is the operator deprecated
    :param since_version: available since version
    :param past_version: list of versions
    """

    def __init__(self, *args, **kwargs):

        op_version = kwargs.pop('op_version', None)

        if op_version is None:
            if len(args) == 0 and input_range[0] == input_range[1]:
                args = [_[0] for _ in self.__class__.expected_inputs]
            if not (input_range[0] <= len(args) <= input_range[1]):
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of inputs, "
                    "got {}, expecting {} for operator "
                    "'{}'.".format(
                        len(args), len(inputs), op_name))

        attr_names = self.attr_names
        _, op_version_class = _split_op_name(self.__class__.__name__)
        if op_version_class is not None:
            if op_version is None:
                op_version = op_version_class
            try:
                op_version = min(op_version, op_version_class)
            except TypeError:  # pragma: no cover
                raise TypeError(  # pylint: disable=W0707
                    "Could not compare versions {} ? {} for "
                    "class '{}' since_version {}. Parameter 'op_version' "
                    "is probably missing when the class "
                    "is instantiated.".format(
                        op_version, op_version_class, class_name,
                        since_version))
        else:
            op_version_class = None

        # By default, op_version is None.
        # None means the latest available.
        if op_version is None:
            op_version = since_version

        found = None
        if op_version is not None:
            # attr_names refers to the most recent version of
            # this operator. We may need an older one.
            for op in range(op_version, 0, -1):
                name = f'{self.__class__.__name__}_{op}'
                if name in self.past_version:
                    found = (name, op)
                    attr_names = self.past_version[name].attr_names
                    if len(attr_names) > 0 and not isinstance(attr_names[0], str):
                        raise TypeError(  # pragma: no cover
                            "attr_names must be a list of string not a list of %r for "
                            "operator %r and domain %r." % (
                                type(attr_names[0]), name, domain))
                    break
        if (op_version_class is not None and found is not None and
                found[-1] != op_version_class):
            raise RuntimeError(  # pragma: no cover
                "op_version={} does not refer to the same opset as the class "
                "name ('{}').".format(op_version, self.__class__.__name__))
        for key in kwargs:
            if key in {'output_names', 'op_version', 'domain', 'ir_version',
                       'global_context', 'clear_subgraph_inputs'}:
                continue
            if key not in attr_names:
                raise TypeError(  # pragma: no cover
                    "Argument '%s' not valid for '%s' domain=%r opset=%s "
                    "(should be in %r, type(self)=%r)." % (
                        key, op_name, domain, op_version, attr_names,
                        type(self)))

        if op_version is not None:
            kwargs['op_version'] = op_version
        if 'domain' not in kwargs:
            kwargs['domain'] = domain
        # This class can only be created by a user. Let's check
        # types are either a variable, an operator or an array.
        for i, a in enumerate(args):
            if isinstance(a, tuple):
                if len(a) != 2:
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple for class %r, it must have two "
                        "elements (name, type) not %r." % (i, class_name, a))
                if not isinstance(a[0], str):
                    raise TypeError(  # pragma: no cover
                        "Input %r is a tuple for class %r, it must be a tuple "
                        "(name, type) not %r." % (i, class_name, a))
                continue
            if not isinstance(a, (
                    Variable, OnnxOperator, numpy.ndarray, str,
                    OnnxOperatorItem, coo_matrix)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type %r for input %r of operator %r. "
                    "It must be an instance of Variable (or a string), "
                    "OnnxOperator, OnnxOperatorItem, numpy.ndarray, "
                    "coo_matrix)." % (
                        type(a), i, class_name))
        OnnxOperator.__init__(self, *args, **kwargs)

    newclass = type(class_name, (OnnxOperator,),
                    {"__init__": __init__, '__doc__': doc,
                     'expected_inputs': inputs,
                     'expected_outputs': outputs,
                     'operator_name': op_name,
                     'input_range': input_range,
                     'output_range': output_range,
                     'domain': domain,
                     'is_deprecated': deprecated,
                     'since_version': since_version,
                     'past_version': past_version,
                     'attr_names': attr_names,
                     'op_type': op_name,
                     '__module__': __name__})
    return newclass
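

# A sketch of what the factory produces (illustrative only, the argument
# values are hypothetical): the generated class behaves like any other
# OnnxOperator subclass. `_dynamic_class_creation` below is the usual caller.
#
#   OnnxAdd = ClassFactory(
#       'OnnxAdd', 'Add', [('A', 'T'), ('B', 'T')], [('C', 'T')],
#       [2, 2], [1, 1], '', [], "doc", False, 14, {})
#   node = OnnxAdd('X', 'Y', output_names=['Z'])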


def _dynamic_class_creation(operator_names=None, cache=False, include_past=False,
                            verbose=0, fLOG=print):
    """
    Automatically generates classes for each of the operators
    module *onnx* defines and described at
    `Operators
    <https://github.com/onnx/onnx/blob/master/docs/Operators.md>`_
    and `Operators
    <https://github.com/onnx/onnx/blob/master/docs/
    Operators-ml.md>`_.

    :param operator_names: list of operators to request or None for all
    :param cache: extracts the documentation from the onnx package and
        saves it on disk if True
    :param include_past: includes past versions if operator_names is None
    :param verbose: display some progress
    :param fLOG: logging function
    :return: list of requested operators as a tuple
    """
    def _c(obj, label, i):
        name = '%s%d' % (obj.name or label, i)
        tys = obj.typeStr or ''
        return (name, tys)

    cache_dir = cache_folder()
    if operator_names is None:
        operator_names = list(_S.all_schemas_versions)
        if include_past:
            add = []
            for domain, op in operator_names:
                add.extend(
                    [(domain, k)
                     for k in _S.all_schemas_versions[domain, op]])
            operator_names.extend(add)
        operator_names.sort()

    # type verification
    ops = []
    for name in operator_names:
        if isinstance(name, str):
            if name.startswith('Onnx'):
                raise ValueError(
                    f"Operator name cannot start with Onnx: {name!r}.")
            n_name, _ = _split_op_name(name)
            domain = _find_operator_domain(n_name)
            ops.append((domain, name))
        elif isinstance(name, tuple) and len(name) == 2:
            if name[1].startswith('Onnx'):
                raise ValueError(  # pragma: no cover
                    f"Operator name cannot start with Onnx: {name!r}.")
            ops.append(name)
        else:
            raise ValueError(  # pragma: no cover
                "Operator to fetch must be a string or a "
                "`tuple(domain, name)` not %r." % (name))
    operator_names = ops

    # versions
    res = _S.all_schemas
    cls = {}
    set_names = dict()
    set_skip = set()
    for pos, (op_domain, op_name) in enumerate(operator_names):
        if op_domain == 'ai.onnx':
            op_domain = ''
        set_names[op_domain, op_name] = pos
        n, v = _split_op_name(op_name)
        if v is not None and not include_past:
            set_skip.add((op_domain, n))
            if n not in set_names:
                set_names[op_domain, n] = -1

    if verbose > 1 and fLOG is not None:  # pragma: no cover
        fLOG(f"[_dynamic_class_creation] set_names={set_names!r}")
        fLOG(f"[_dynamic_class_creation] set_skip={set_skip!r}")

    returned_classes = []
    positions = {}

    for (op_domain, op_name), position in set_names.items():
        cl_name = 'Onnx' + _domain_to_class_name(op_domain) + op_name
        if verbose > 3 and fLOG is not None:
            fLOG(  # pragma: no cover
                '[_dynamic_class_creation] cl_name=%r op_domain=%r op_name=%r (in=%d) '
                'position=%r' % (
                    cl_name, op_domain, op_name,
                    1 if cl_name in _S.all_classes else 0,
                    position))
        if cl_name in _S.all_classes:
            if cl_name not in set_skip:
                if position >= 0:
                    returned_classes.append(
                        (position, _S.all_classes[cl_name]))
            continue

        # operator name without domain
        n, v = _split_op_name(op_name)
        if v is not None:
            names = [op_name]
        else:
            try:
                names = _S.all_schemas_versions[op_domain, op_name].copy()
            except KeyError as e:  # pragma: no cover
                raise ValueError(
                    "Operator %r (domain=%r) does not exist." % (
                        op_name, op_domain)) from e
            names.add(op_name)

        if verbose > 0 and fLOG is not None:
            fLOG(  # pragma: no cover
                "[_dynamic_class_creation] op_domain=%r op_name=%r, cl_name=%r names=%r"
                "" % (op_domain, op_name, cl_name, names))

        for name in names:
            try:
                schema = res[op_domain, name]
            except KeyError as e:
                raise ValueError(
                    "Operator (%r, %r) does not exist (available=%r)" % (
                        op_domain, name, pprint.pformat(list(res)))) from e
            inputs = [_c(o, 'I', i) for i, o in enumerate(schema.inputs)]
            outputs = [_c(o, 'O', i) for i, o in enumerate(schema.outputs)]
            args = [p if isinstance(p, str) else p.name
                    for p in schema.attributes]
            if len(args) > 0 and not isinstance(args[0], str):
                raise TypeError(  # pragma: no cover
                    "args must be a list of string not a list of %r for "
                    "operator %r and domain %r." % (
                        type(args[0]), name, op_domain))

            n_name, v = _split_op_name(name)

            if v is not None:
                if op_domain == 'com.microsoft' and name in {
                        'SoftmaxGrad_13', 'LogSoftmaxGrad_13'}:
                    # exception
                    pass
                elif v != schema.since_version:
                    raise ValueError(  # pragma: no cover
                        "Inconsistent version number %d != %d for operator "
                        " %r, %r (%r)." % (
                            v, schema.since_version, schema.domain,
                            schema.name, name))
                class_name = "Onnx" + _domain_to_class_name(op_domain) + name
            else:
                class_name = (
                    "Onnx" + _domain_to_class_name(op_domain) + schema.name)

            if verbose > 0 and fLOG is not None:
                fLOG(  # pragma: no cover
                    "[_dynamic_class_creation] op_name=%r, cl_name=%r cache=%r v=%r"
                    "" % (op_name, class_name, cache, v))

            filename = os.path.join(
                cache_dir,
                schema.name + '_' + str(schema.since_version) + ".rst")
            if not cache and os.path.exists(filename):
                with open(filename, "r", encoding="utf-8") as f:  # pragma: no cover
                    doc = f.read()
            else:
                doc = get_rst_doc(schema.name, domain=schema.domain,
                                  version=schema.since_version)
                if cache:  # pragma: no cover
                    with open(filename, 'w', encoding='utf-8') as f:
                        f.write(doc)

            cl = ClassFactory(class_name, schema.name, inputs, outputs,
                              [schema.min_input, schema.max_input],
                              [schema.min_output, schema.max_output],
                              schema.domain, args,
                              "**Version**" + doc.split('**Version**')[-1],
                              getattr(schema, 'deprecated', False),
                              schema.since_version, {})
            cls[class_name] = cl
            if name == op_name:
                positions[class_name] = position

    # Retrieves past classes.
    for name in cls:  # pylint: disable=C0206
        main, v = _split_op_name(name)
        if v is None:
            continue
        if main in cls:  # pylint: disable=R1715
            last = cls[main]
        else:
            last = _S.all_classes[main]
        last.past_version[name] = cls[name]

    # final
    _S.all_classes.update(cls)
    for cl_name, v in cls.items():
        if v not in set_skip and positions.get(cl_name, -1) >= 0:
            returned_classes.append((positions[cl_name], v))

    returned_classes.sort()
    return tuple(e[1] for e in returned_classes)


def loadop(*names, cache=False, verbose=0, fLOG=print):
    """
    Dynamically creates a class for every operator type in
    the given list.
    """
    res = _dynamic_class_creation(
        names, cache=cache, verbose=verbose, fLOG=fLOG)
    if len(res) == 1:
        return res[0]
    return res
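

# A usage sketch (added for illustration): `loadop` returns one class per
# requested operator, so both forms below are equivalent ways to build a
# small graph. Names such as 'X' refer to graph inputs.
#
#   OnnxAdd = loadop('Add')
#   OnnxAdd, OnnxMul = loadop('Add', 'Mul')
#   node = OnnxMul(OnnxAdd('X', 'Y'), 'Z', output_names=['W'])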


class OnnxLoadFactory:
    """
    Automatically creating all operators from onnx packages
    takes time. That's why function @see cl loadop only creates
    classes for the requested operators. This class does the same
    when an attribute is requested.

    ::

        cl = OnnxLoadOperators()
        x = cl.Add(...)

    It is equivalent to:

    ::

        OnnxAdd = loadop('Add')
        x = OnnxAdd(...)
    """

    def __init__(self):
        self._loaded_classes = {}

    def __getattr__(self, name):
        """
        Enables expressions such as:

        ::

            ops = OnnxLoadFactory()
            op = ops.Abs('X')
        """
        if name == '_loaded_classes':
            return self._loaded_classes
        if name in self._loaded_classes:
            return self._loaded_classes[name]
        cl = loadop(name)
        self._loaded_classes[name] = cl
        self._loaded_classes[cl.__name__] = cl
        return cl


class OnnxOperatorBase:
    """
    Base class for @see cl OnnxOperator, @see cl OnnxOperatorItem,
    @see cl OnnxOperatorTuple.
    """

    def __init__(self):
        pass

    def add_to(self, builder):
        "This method should be overwritten."
        raise NotImplementedError(  # pragma: no cover
            f"Not overwritten for class {type(self)!r}.")

    @property
    def output_names(self):
        "This method should be overwritten."
        raise NotImplementedError(  # pragma: no cover
            f"Not overwritten for class {type(self)!r}.")

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        raise NotImplementedError(  # pragma: no cover
            f"Method 'find_named_inputs' must be overloaded for type {type(self)}.")

    def f(self, *args, **kwargs):
        """
        Evaluates this node.
        """
        raise NotImplementedError(  # pragma: no cover
            f"Method 'f' must be overloaded for type {type(self)}.")

    def _set_control_op(self, op, subgraph_inputs=None):
        """
        Tells this operator it is part of a subgraph.
        """
        raise NotImplementedError(  # pragma: no cover
            f"Method '_set_control_op' must be overloaded for type {type(self)}.")

    def add_external_input(self, op):
        """
        Tells a subgraph this node comes from the main graph.
        It may be used only by the subgraph but it must be processed as well.
        """
        raise NotImplementedError(  # pragma: no cover
            f"Method 'add_external_input' must be overloaded for type {type(self)}.")


class OnnxOperatorItem(OnnxOperatorBase):
    """
    Accessor to one of the outputs returned by an @see cl OnnxOperator.

    :param onx_op: @see cl OnnxOperator
    :param index: integer
    :param op_version: defines the opset version
    """

    def __init__(self, onx_op, index, op_version=None):
        OnnxOperatorBase.__init__(self)
        if not isinstance(index, int):
            raise TypeError(  # pragma: no cover
                f"index must be an integer not {type(index)!r}.")
        logger.debug("op:%s-%d(%r, %d, op_version=%r)",
                     self.__class__.__name__, id(self), onx_op, index, op_version)
        if not isinstance(onx_op, OnnxOperatorBase):
            raise TypeError(  # pragma: no cover
                f"onx_op must be an OnnxOperator not {type(onx_op)!r}.")
        self.onx_op = onx_op
        self.index = index
        self.op_version = op_version

    @property
    def output_names(self):
        "Returns None."
        return None

    @property
    def inputs(self):
        "Returns the only input in a list."
        return [NodeResultName(self.onx_op, self.index)]

    def add_to(self, builder):
        """
        Adds to graph builder.
        Does nothing because the original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        pass

    def __str__(self):
        "usual"
        return "%s[%d]" % (str(self.onx_op), self.index)

    def __repr__(self):
        "usual"
        return "%s(%s[%d])" % (
            self.__class__.__name__,
            self.onx_op.__class__.__name__,
            self.index)

    def get_output_result(self, i=0):
        """
        Returns the output name at position *i*.
        """
        if i != 0:
            raise IndexError(  # pragma: no cover
                "Can only return the first item.")
        return self.onx_op.get_output_result(self.index)

    def _to_onnx_attributes(self, inputs=None, target_opset=None,
                            optim=True, verbose=0, run_shape=True,
                            fLOG=print, processed=None):
        """
        Calls `self.onx_op._to_onnx_attributes`.
        """
        return self.onx_op._to_onnx_attributes(
            inputs=inputs, target_opset=target_opset, optim=optim,
            run_shape=run_shape, verbose=verbose, fLOG=fLOG,
            processed=processed)

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        return self.onx_op.find_named_inputs()

    def f(self, *inputs, verbose=0, fLOG=None,  # pylint: disable=W0221
          clear_cache=False, runtime=None):
        """
        Computes the predictions for this node.
        Similar to an eager evaluation.

        :param inputs: inputs as a dictionary or a list of inputs
            (see below)
        :param verbose: display information while predicting
        :param fLOG: logging function if *verbose > 0*
        :param clear_cache: the onnx graph is created once unless
            this parameter is True
        :param runtime: runtime to use for the evaluation,
            see @see cl OnnxInference
        :return: outputs as a dictionary if the inputs were given as a
            dictionary, a single result or a tuple otherwise

        The inputs refer to the inputs of the graph.
        The method walks through all inputs and finds inputs defined as
        strings. It replaces them by the value found in the dictionary.
        If the inputs are specified in a list, the function retrieves the
        list of inputs defined as strings and assigns them a value.
        The logging function can be used to get more insight about it.
        During the evaluation every node is independently converted
        into ONNX. The ONNX graph is cached in the class itself.
        """
        res = self.onx_op.f(*inputs, verbose=verbose, fLOG=fLOG,
                            clear_cache=clear_cache, runtime=runtime)
        if isinstance(res, dict):
            names = self.onx_op.output_names
            if names is None:
                names = self.onx_op.expected_outputs
                name = names[self.index][0]
            else:
                name = names[self.index]
            return {name: res[name]}
        return res[self.index]
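

# A minimal eager-evaluation sketch (illustrative only, assuming the standard
# 'TopK' operator is available through `loadop`): indexing an operator yields
# an OnnxOperatorItem whose `f` method evaluates only that output.
#
#   import numpy
#   OnnxTopK = loadop('TopK')
#   values = OnnxTopK('X', numpy.array([2], dtype=numpy.int64))[0]
#   print(values.f({'X': numpy.array([[0., 1., 2.]], dtype=numpy.float32)}))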


class OnnxOperatorTuple(OnnxOperatorBase):
    """
    Class used to return multiple @see cl OnnxVar
    at the same time.
    """

    def __init__(self, first, *args):
        OnnxOperatorBase.__init__(self)
        logger.debug("op:%s-%d([%r], %d in)",
                     self.__class__.__name__, id(self), type(first),
                     len(args))
        if isinstance(first, (list, tuple)):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for first {type(first)!r}.")
        logger.debug('op:%s-%d(%d in)', self.__class__.__name__,
                     id(self), 1 + len(args))
        if len(args) > 0:
            self.values = (first,) + args
            self.unique = None
        else:
            self.values = None
            self.unique = first
        if self.values is not None and self.unique is not None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected configuration. One member (values or unique) must be "
                "null, unique=%r, values=%r" % (self.unique, self.values))
        if self.values is None and self.unique is None:
            raise RuntimeError(  # pragma: no cover
                "Unexpected configuration. One member (values or unique) must be "
                "not null.")

    def __repr__(self):
        "usual"
        if self.values is None:
            return f"{self.__class__.__name__}({type(self.unique)!r})"
        return "%s(%s)" % (self.__class__.__name__, ", ".join(
            "%r" % type(v) for v in self.values))

    @property
    def inputs(self):
        "Returns the only input in a list."
        if self.values is None:
            return [self.unique]
        raise NotImplementedError(  # pragma: no cover
            "OnnxOperatorTuple.inputs is missing.")

    @property
    def external_inputs(self):
        """
        Returns the list of implicit inputs the subgraph
        assumes to be existing even if they are not referenced as
        explicit inputs of the graph.
        """
        if self.values is None:
            return self.unique.external_inputs
        res = []
        for op in self.values:
            res.extend(op.external_inputs)
        return res

    def add_to(self, builder):
        """
        Adds to graph builder.
        Does nothing because the original node is already added.

        :param builder: instance of @see cl _GraphBuilder,
            it must have a method `add_node`
        """
        pass

    def __len__(self):
        "usual"
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case unique=%r, "
                "values=%r." % (self.unique, self.values))
        return len(self.values)

    def __iter__(self):
        "Iterates on the outputs."
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not yet implemented in this case.")
        for v in self.values:
            yield v

    def __getitem__(self, i):
        "usual"
        if self.values is None:
            return self.unique[i]
        return self.values[i]

    @property
    def outputs(self):
        "Returns 'outputs' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.outputs
        raise NotImplementedError(  # pragma: no cover
            f"Not implemented yet unique={self.unique!r} values={self.values!r}.")

    @property
    def output_names(self):
        "Returns 'output_names' of attribute 'unique'."
        if self.values is None:
            if hasattr(self.unique, 'to_onnx'):
                return self.unique.output_names
        raise NotImplementedError(  # pragma: no cover
            f"Not implemented yet unique={self.unique!r} values={self.values!r}.")

    @output_names.setter
    def output_names(self, value):
        """
        Updates 'output_names' of attribute 'unique'
        or every output name of attribute 'values'.
        """
        logger.debug("op:%s:output_names:set(%r)",
                     self.__class__.__name__, value)
        OnnxIdentity = loadop('Identity')  # pylint: disable=W0621
        if self.values is None:
            if (hasattr(self.unique, 'to_onnx') or
                    hasattr(self.unique, 'add_to')):
                if len(value) > 1:
                    self.values = tuple(
                        OnnxIdentity(
                            self.unique[i], output_names=value[i:i + 1],
                            op_version=self.unique.op_version)
                        for i in range(0, len(value)))
                    self.unique = None
                    return
                self.unique.output_names = [Variable(v) for v in value]
                return
            raise NotImplementedError(  # pragma: no cover
                "Not implemented yet, value=%r, unique=%r values=%r." % (
                    value, self.unique, self.values))
        if self.values is not None and len(self.values) == len(value):
            for name, v in zip(value, self.values):
                v.output_names = [Variable(name)]
            return
        raise NotImplementedError(  # pragma: no cover
            "Not implemented yet, value=%r, unique=%r values=%r." % (
                value, self.unique, self.values))

    def _to_onnx_attributes(self, inputs=None, target_opset=None,
                            optim=True, verbose=0, run_shape=True,
                            fLOG=print, processed=None):
        """
        Calls `_to_onnx_attributes` on attribute 'unique' or
        on every element of attribute 'values'.
        """
        if self.values is None:
            return self.unique._to_onnx_attributes(
                inputs=inputs, target_opset=target_opset, optim=optim,
                run_shape=run_shape, verbose=verbose, fLOG=fLOG,
                processed=processed)
        res = []
        for v in self.values:
            res.append(v._to_onnx_attributes(
                inputs=inputs, target_opset=target_opset, optim=optim,
                run_shape=run_shape, verbose=verbose, fLOG=fLOG,
                processed=processed))
        return res

    def to_onnx(self, inputs=None, outputs=None,
                other_outputs=None, target_opset=None,
                optim=True, verbose=0, run_shape=True,
                processed=None, check_model=True,
                return_builder=False, fLOG=None):
        """
        Converts this operator into an ONNX graph.
        It follows the same signature as :meth:`OnnxOperator.to_onnx
        <mlprodict.npy.xop.OnnxOperator.to_onnx>` and calls this
        method of the unique input object or the first one
        if there are several. In that case, other inputs in
        attribute `values` are moved into container
        `other_outputs`.

        (OnnxOperatorTuple)
        """
        logger.debug('op:%s-%d.to_onnx:%r:%r:%r',
                     self.__class__.__name__, id(self),
                     inputs, outputs, other_outputs)
        logger.indent()
        if self.values is None:
            res = self.unique.to_onnx(
                inputs=inputs, outputs=outputs, other_outputs=other_outputs,
                target_opset=target_opset, optim=optim, verbose=verbose,
                run_shape=run_shape, processed=processed, check_model=check_model,
                fLOG=fLOG, return_builder=return_builder)
            logger.dedent()
            return res
        # self.values is a tuple, it must be converted into a list
        # before other outputs can be appended to it.
        new_other_outputs = list(self.values[1:])
        if other_outputs is not None:
            new_other_outputs.extend(other_outputs)
        res = self.values[0].to_onnx(
            inputs=inputs, outputs=outputs, other_outputs=new_other_outputs,
            target_opset=target_opset, optim=optim, verbose=verbose,
            run_shape=run_shape, processed=processed, check_model=check_model,
            fLOG=fLOG, return_builder=return_builder)
        logger.dedent()
        return res

    def find_named_inputs(self):
        """
        Returns all inputs to the graph.
        """
        if self.values is None:
            return self.unique.find_named_inputs()
        named = []
        for value in self.values:
            tmp = value.find_named_inputs()
            named.extend(tmp)
        return named

    def _set_control_op(self, op, subgraph_inputs=None):
        """
        Tells this operator it is part of a subgraph.
        """
        logger.debug('op:%s-%d._set_control_op:%r',
                     self.__class__.__name__, id(self), op)
        logger.indent()
        if self.values is None:
            raise NotImplementedError(  # pragma: no cover
                "Not implemented yet.")
        for value in self.values:
            value._set_control_op(op, subgraph_inputs)
        logger.dedent()


class OnnxOperator(OnnxOperatorBase):
    """
    Ancestor to every *ONNX* operator exposed in
    :mod:`mlprodict.npy.xops` and :mod:`mlprodict.npy.xops_ml`.

    :param inputs: list of inputs expected by the operator
    :param op_version: to select a specific version of the operator
    :param output_names: user-defined names for the outputs
    :param domain: to overwrite the default domain
    :param global_context: operator *If* executes one subgraph
        whose nodes may use one existing output in the current
        context. If not used in the main graph, these operators
        are not linked to the output and cannot be retrieved.
        *global_context* is a dictionary mapping the subgraph input
        names to these operators.
    :param kwargs: additional parameters of the operator

    .. versionadded:: 0.9
    """
    @classmethod
    def __class_getitem__(cls, opset):
        """
        Enables expression `cls[opset]`. It returns the appropriate class
        `cls_opset`. Parameter *op_version* should be specified.
        """
        if not isinstance(opset, int):
            raise ValueError(
                f"opset must be an integer not {type(opset)!r}.")
        best = None
        for _, v in cls.past_version.items():
            if v.since_version == opset:
                return lambda *args, **kwargs: v(
                    *args, op_version=opset, **kwargs)
            if v.since_version <= opset and (
                    best is None or best.since_version < v.since_version):
                best = v
        if best is None:
            raise ValueError(
                "Unable to find a version of operator %r and opset %r." % (
                    cls.__name__, opset))
        return lambda *args, **kwargs: best(
            *args, op_version=opset, **kwargs)
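
    # A short sketch of `__class_getitem__` in action (illustrative only,
    # assuming an opset-13 signature of 'Add' is present in `past_version`):
    #
    #   OnnxAdd = loadop('Add')
    #   node = OnnxAdd[13]('X', 'Y')  # same as OnnxAdd(..., op_version=13)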

    def __init__(self, *inputs, op_version=None, output_names=None,
                 domain=None, global_context=None, **kwargs):

        OnnxOperatorBase.__init__(self)
        logger.debug("op:%s-%d(%d in, op_version=%r, output_names=%r)",
                     self.__class__.__name__, id(self),
                     len(inputs), op_version,
                     output_names)
        if (output_names is None and
                self.__class__.__name__.startswith("OnnxScan")):
            raise NotImplementedError(  # pragma: no cover
                "The class cannot infer the number of variables "
                "for node '{}' yet. output_names must be specified"
                ".".format(self.__class__.__name__))
        if isinstance(output_names, (str, Variable)):
            output_names = [output_names]
            if isinstance(output_names[0], str):
                output_names[0] = Variable(output_names[0])
        elif isinstance(output_names, (list, OnnxOperator._InputContainer)):
            if len(output_names) == 0:
                raise ValueError(  # pragma: no cover
                    "output_names cannot be empty (operator %r)."
                    "" % self.__class__.__name__)
            output_names = output_names.copy()
            for i in range(len(output_names)):  # pylint: disable=C0200
                if isinstance(output_names[i], str):
                    output_names[i] = Variable(output_names[i])
        elif output_names is not None:
            raise TypeError(  # pragma: no cover
                f"output_names must be a string or a list not {type(output_names)!r}.")

        if op_version is None:
            if domain == '':
                self.op_version = max_supported_opset()
            else:
                self.op_version = None
        else:
            self.op_version = op_version
        self.since_version = self.__class__.since_version

        if (self.op_version is not None and
                self.op_version < self.since_version):
            schema = self.find_schema(self.op_version)
            self.since_version = schema.since_version
            self.expected_inputs = schema.expected_inputs.copy()
            self.expected_outputs = schema.expected_outputs.copy()
            self.input_range = schema.input_range
            self.output_range = schema.output_range
        else:
            self.expected_inputs = (
                None if self.__class__.expected_inputs is None
                else self.__class__.expected_inputs.copy())
            self.expected_outputs = (
                None if self.__class__.expected_outputs is None
                else self.__class__.expected_outputs.copy())
            self.input_range = self.__class__.input_range
            self.output_range = self.__class__.output_range
            if self.__class__.__name__ not in {
                    'OnnxScan', 'OnnxLoop', 'OnnxIf'}:
                # The minimum opset depends on the embedded graph.
                # By default, it takes the given op_version but the
                # optimal value could be lower.
                self.op_version = self.since_version
            if self.op_version is None:
                self.op_version = self.since_version

        if (self.op_version is not None and
                self.op_version < self.since_version):
            raise RuntimeError(  # pragma: no cover
                "Operator '{}': requested version {} < "
                "{} schema version.".format(
                    self.__class__.__name__,
                    self.op_version, self.since_version))

        self.state = None
        self.domain = domain
        self.kwargs = kwargs
        self.max_item_ = None

        # check inputs
        self.inputs = []
        if len(inputs) > 0:
            for inp in inputs:
                if isinstance(inp, str):
                    self.inputs.append(Variable(inp))
                elif isinstance(inp, tuple):
                    if len(inp) != 2:
                        raise RuntimeError(  # pragma: no cover
                            f"Unexpected tuple {inp!r}.")
                    self.inputs.append(
                        Variable(inp[0], dtype=guess_numpy_type(inp[1]),
                                 shape=inp[1].shape))
                elif isinstance(inp, (OnnxOperatorBase, Variable)):
                    self.inputs.append(inp)
                elif isinstance(inp, (numpy.ndarray, coo_matrix, TensorProto)):
                    self.inputs.append(inp)
                elif isinstance(inp, ValueInfoProto):
                    self.inputs.append(inp.type.tensor_type)
                else:
                    raise TypeError(  # pragma: no cover
                        "Unable to interpret the input name for type {} in "
                        "operator '{}' (value={}).".format(
                            type(inp), self.__class__.__name__, inp))

        if (self.inputs is not None and
                (len(self.inputs) < self.input_range[0] or
                 len(self.inputs) > self.input_range[1])):
            raise RuntimeError(  # pragma: no cover
                "Operator '{}' expects a number of inputs in [{}, {}] not {} "
                "(expected opset={}, class opset={})".format(
                    getattr(self, 'operator_name', '?'), *self.input_range,
                    len(self.inputs), op_version, self.op_version))
        # global context
        if global_context is None:
            self.global_context = None
        else:
            if not isinstance(global_context, dict):
                raise TypeError(  # pragma: no cover
                    "global_context must be a dictionary not %r."
                    "" % type(global_context))
            for k, v in global_context.items():
                if not isinstance(v, OnnxOperatorBase):
                    raise TypeError(  # pragma: no cover
                        f"Value {k!r} must be an OnnxOperatorBase not {type(v)!r}.")
            self.global_context = global_context

        # check output
        self.output_names_ = output_names
        self.output_variables = None

        if self.output_names is not None:
            if len(self.output_names) == 0:
                raise ValueError(  # pragma: no cover
                    "output_names can be None but cannot be empty for "
                    "operator %r." % self)
            if self.output_variables is None:
                self.output_variables = [None for o in self.output_names]
            for i in range(len(self.output_names)):  # pylint: disable=C0200
                name = self.output_names[i]
                if isinstance(name, Variable):
                    self.output_variables[i] = name
                else:
                    raise TypeError(  # pragma: no cover
                        "output_names must be a list of strings "
                        "and element %r is %r (%r)" % (
                            i, type(name), name))
            if all(map(lambda x: x is None, self.output_variables)):
                self.output_variables = None

        if (self.output_names is not None and (
                self.expected_outputs is None or
                len(self.output_names) > len(self.expected_outputs))):
            if self.expected_outputs is None:
                self.expected_outputs = []
            for i in range(len(self.expected_outputs),
                           len(self.output_names)):
                self.expected_outputs.append((self.output_names[i], None))

        if (self.expected_inputs is None or
                len(self.inputs) > len(self.expected_inputs)):
            if self.expected_inputs is None:
                self.expected_inputs = []
            for i in range(len(self.expected_inputs),
                           len(self.inputs)):
                inp = self.inputs[i]
                if isinstance(inp, str):
                    inp = (inp, None)
                elif hasattr(inp, 'add_to'):
                    # OnnxOperator
                    existing = set(_[0] for _ in self.expected_inputs)
                    i = 10
                    name = "input%d" % (10 + i)
                    while name in existing:
                        i += 1
                        name = "input%d" % (10 + i)
                    inp = (name, None)
                self.expected_inputs.append(inp)

        self._post_process_attributes()
        self._check()
        self.external_inputs = []

    def add_external_input(self, op):
        """
        Tells a subgraph this node comes from a graph calling this one.
        """
        logger.debug("op:%s.add_external_input:%r",
                     self.__class__.__name__, op)
        self.external_inputs.append(op)

    def do(self, body, subgraph_inputs=None):
        """
        Fills attribute *body*.

        :param body: onnx graph or @see cl OnnxOperator
        :param subgraph_inputs: additional parameter to convert
            the subgraph into ONNX
        :return: self
        """
        if (isinstance(body, (onnx.GraphProto, onnx.ModelProto)) and
                subgraph_inputs is not None):
            raise RuntimeError(  # pragma: no cover
                "inputs cannot be defined if body is a "
                "GraphProto or a ModelProto.")
        return self._add_subgraph(
            'body', body, subgraph_inputs=subgraph_inputs)

    def then_do(self, branch):
        """
        Fills attribute *then_branch*.

        :param branch: onnx graph or @see cl OnnxOperator
        :return: self
        """
        if isinstance(branch, onnx.GraphProto) and len(branch.input) > 0:
            raise RuntimeError(  # pragma: no cover
                "then_branch subgraph cannot have any input.")
        return self._add_subgraph('then_branch', branch)

    def else_do(self, branch):
        """
        Fills attribute *else_branch*.

        :param branch: onnx graph or @see cl OnnxOperator
        :return: self
        """
        if isinstance(branch, onnx.GraphProto) and len(branch.input) > 0:
            raise RuntimeError(  # pragma: no cover
                "else_branch subgraph cannot have any input.")
        return self._add_subgraph('else_branch', branch)
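
    # A minimal sketch of `then_do` / `else_do` (illustrative only, assuming
    # the standard 'If', 'Greater' and 'Constant' operators):
    #
    #   import numpy
    #   OnnxIf, OnnxGreater, OnnxConstant = loadop('If', 'Greater', 'Constant')
    #   cond = OnnxGreater('X', numpy.array([0.], dtype=numpy.float32))
    #   node = OnnxIf(cond, output_names=['Y']).then_do(
    #       OnnxConstant(value_floats=numpy.array([1.], dtype=numpy.float32))
    #   ).else_do(
    #       OnnxConstant(value_floats=numpy.array([-1.], dtype=numpy.float32)))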

1445 

1446 def _add_subgraph(self, attribute, branch, subgraph_inputs=None): 

1447 """ 

1448 Fills attribute *attribute*. 

1449 

1450 :param attribute: attribute name 

1451 :param branch: onnx graph or @see cl OnnxOperator 

1452 :param subgraph_inputs: additional parameter to convert 

1453 the subgraph into ONNX 

1454 :return: self 

1455 """ 

1456 if isinstance(branch, str): 

1457 # branch is an input. 

1458 OnnxIdentity = loadop('Identity') 

1459 branch = OnnxIdentity(OnnxExisting(branch), 

1460 op_version=self.op_version) 

1461 logger.debug("op:%s:_add_subgraph:%s=type(branch)=%r", 

1462 self.__class__.__name__, attribute, type(branch)) 

1463 if isinstance(branch, onnx.ModelProto): 

1464 return self._add_subgraph(attribute, branch.graph) 

1465 if isinstance(branch, onnx.GraphProto): 

1466 self.kwargs[attribute] = branch 

1467 return self 

1468 if isinstance(branch, (OnnxOperator, OnnxOperatorTuple)): 

1469 self.kwargs[attribute] = branch 

1470 branch._set_control_op(self, subgraph_inputs=subgraph_inputs) 

1471 return self 

1472 raise TypeError( # pragma: no cover 

1473 "Unexpected type %r for a subgraph, attribute %r " 

1474 "and class %r." % ( 

1475 type(branch), attribute, self.__class__.__name__)) 

1476 

1477 def _set_control_op(self, op, subgraph_inputs=None): 

1478 """ 

1479 Sets *control_op* for every instance of @see cl OnnxExisting node. 

1480 

1481 :param op: operator calling the subgraph. 

1482 :param inputs: additional parameters to convert 

1483 into ONNX 

1484 """ 

1485 if subgraph_inputs is not None: 

1486 self.subgraph_inputs = subgraph_inputs 

1487 

1488 for i, inp in enumerate(self.inputs): 

1489 if isinstance(inp, OnnxOperatorBase): 

1490 logger.debug("op:%s-%d:_set_control_op:propagate-into-input:%d:p:%d", 

1491 self.__class__.__name__, id(self), i, id(op)) 

1492 logger.indent() 

1493 inp._set_control_op(op) 

1494 logger.dedent() 

1495 if self.kwargs is None: 

1496 return 

1497 for k, v in self.kwargs.items(): 

1498 if isinstance(v, OnnxOperatorBase): 

1499 logger.debug("op:%s-%d:_set_control_op:propagate-into-attribute:%s:p:%d", 

1500 self.__class__.__name__, id(self), k, id(op)) 

1501 logger.indent() 

1502 v._set_control_op(op) 

1503 logger.dedent() 

1504 

1505 @property 

1506 def output_names(self): 

1507 "Returns `self.output_names_`." 

1508 return self.output_names_ 

1509 

1510 @output_names.setter 

1511 def output_names(self, value): 

1512 logger.debug("op:%s:output_names:set(%r)", 

1513 self.__class__.__name__, value) 

1514 if not isinstance(value, (list, OnnxOperator._InputContainer)): 

1515 raise TypeError( # pragma: no cover 

1516 f"Value must be a list not {type(value)!r}.") 

1517 res = [] 

1518 for v in value: 

1519 if isinstance(v, (Variable, ExistingVariable)): 

1520 res.append(v) 

1521 elif isinstance(v, str): 

1522 res.append(Variable(v)) 

1523 else: 

1524 raise TypeError( # pragma: no cover 

1525 "Unexpected type %r for an output_names %r." 

1526 "" % type(v)) 

1527 self.output_names_ = res 

1528 

1529 def _check(self): 

1530 input_types = (Variable, OnnxOperatorBase, numpy.ndarray, 

1531 TensorProto) 

1532 for o in self.inputs: 

1533 if not isinstance(o, input_types): 

1534 raise TypeError( # pragma: no cover 

1535 f"Wrong type for inputs {self.inputs!r}.") 

1536 if self.output_names is not None: 

1537 for o in self.output_names: 

1538 if not isinstance(o, Variable): 

1539 raise TypeError( # pragma: no cover 

1540 f"Wrong type for output_names {self.output_names!r}.") 

1541 

1542 def _post_process_attributes(self): 

1543 """ 

1544 Walks through attributes and replaces them by ONNX values. 

1545 """ 

1546 # Looks into attributes if there is any tuple 

1547 # (GraphProto, OnnxOperator). In that case, the function 

1548 # replaces the tuple by the graph proto and keeps 

1549 # in attributes graph_algebra the OnnxOperator 

1550 # which is the source of it. 

1551 updates = {} 

1552 graph_algebra = {} 

1553 for k, v in self.kwargs.items(): 

1554 if isinstance(v, tuple) and isinstance(v[0], GraphProto): 

1555 updates[k] = v[0] 

1556 graph_algebra[k] = v[1] 

1557 

1558 if len(graph_algebra) > 0: 

1559 self.kwargs.update(updates) 

1560 self.graph_algebra = graph_algebra 

1561 

1562 if self.__class__.__name__ == "OnnxConstantOfShape": 

1563 if "value" in self.kwargs: 

1564 value = self.kwargs['value'] 

1565 if isinstance(value, TensorProto): 

1566 return 

1567 if isinstance(value, numpy.ndarray): 

1568 if value.shape == (1, ): 

1569 val = value[0] 

1570 elif len(value.shape) == 0: 

1571 val = value 

1572 else: 

1573 raise RuntimeError( # pragma: no cover 

1574 "Unexpected shape %r for value, it must be " 

1575 "an array of one element." % value.shape) 

1576 self.kwargs['value'] = from_array( 

1577 numpy.array([val], dtype=value.dtype)) 

1578 return 

1579 raise TypeError( # pragma: no cover 

1580 "Unexpected type %r for value. It should be an array " 

1581 "of one element." % type(value)) 

1582 return 

1583 

1584 if self.__class__.__name__ == "OnnxCast": 

1585 if "to" in self.kwargs: 

1586 value = self.kwargs['to'] 

1587 if not isinstance(value, int): 

1588 try: 

1589 to = numpy_type_prototype(value) 

1590 except ValueError as e: # pragma: no cover 

1591 raise ValueError( 

1592 "Unable to convert argument to in operator cast, " 

1593 "type is %r, value is %r." % (type(value), value)) from e 

1594 self.kwargs['to'] = to 

1595 return 

1596 

1597 def update_max_item(self, index): 

1598 """ 

1599 Some operators return a undefined number of outputs. 

1600 The method is called when require one of them (with `__getitem__`) 

1601 and keeps the greater requested index assuming the node does 

1602 not output any result beyond that index. 

1603 

1604 :param index: requested index 

1605 """ 

1606 if self.max_item_ is None: 

1607 self.max_item_ = index 

1608 else: 

1609 self.max_item_ = max(self.max_item_, index) 

1610 if self.expected_outputs is None: 

1611 self.expected_outputs = [] 

1612 while len(self.expected_outputs) <= self.max_item_: 

1613 self.expected_outputs.append( 

1614 (("NEWOUTPUT", len(self.expected_outputs)), None)) 

1615 

1616 def find_schema(self, op_version): 

1617 """ 

1618 Checks if there is an existing schema for a specific version. 

1619 

1620 :param op_version: requested version 

1621 :return: schema 

1622 """ 

1623 if not hasattr(self.__class__, 'past_version'): 

1624 raise RuntimeError( # pragma: no cover 

1625 "Missing attribute 'past_version', there is " 

1626 "no other available schema.") 

1627 found = None 

1628 for v in self.past_version.values(): 

1629 if v.since_version > op_version: 

1630 continue 

1631 if found is None or v.since_version > found.since_version: 

1632 found = v 

1633 if found is None: 

1634 raise RuntimeError( # pragma: no cover 

1635 "Operator '{}': requested version {} < " 

1636 "{} schema version (past_version {}).".format( 

1637 self.__class__.__name__, 

1638 op_version, self.since_version, 

1639 [v.since_version for v in self.past_version.values()])) 

1640 return found 

1641 

1642 def __repr__(self): 

1643 """ 

1644 usual 

1645 """ 

1646 return "{}({} in) -> {}".format( 

1647 self.__class__.__name__, 

1648 len(self.inputs) if self.inputs is not None else 0, 

1649 [str(o) for o in self.output_names] 

1650 if self.output_names is not None else "?") 

1651 

1652 def get_output_result(self, i=0): 

1653 """ 

1654 Returns the output name at position *i*. 

1655 """ 

1656 return NodeResultName(self, i) 

1657 

1658 def __getitem__(self, index): 

1659 """ 

1660 Returns an accessor to one of the outputs 

1661 of this node. 

1662 """ 

1663 self.update_max_item(index) 

1664 return OnnxOperatorItem(self, index, self.op_version) 

1665 

1666 def __iter__(self): 

1667 """ 

1668 Allows expressions such as ``a, b = OnnxTopK(...)``. 

1669 """ 

1670 n = None 

1671 if self.output_names is not None: 

1672 n = len(self.output_names) 

1673 else: 

1674 rg = self.output_range 

1675 if rg[0] == rg[1] and rg[0] > 0: 

1676 n = rg[0] 

1677 if n is None and self.max_item_ is not None: 

1678 n = self.max_item_ + 1 

1679 if n is None: 

1680 raise RuntimeError( # pragma: no cover 

1681 "Unable to guess the number of outputs of node type %r. " 

1682 "Uses operator [] to select a specific output." % 

1683 self.__class__.__name__) 

1684 if self.max_item_ is not None: 

1685 n = max(n, self.max_item_ + 1) 

1686 for i in range(n): 

1687 yield self[i] 
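
# Illustration (not part of the original source): a minimal sketch of how
# ``__getitem__`` and ``__iter__`` cooperate; selecting an output with []
# updates ``max_item_`` so iteration yields enough results. ``loadop`` is
# this module's factory; the input name 'X' and values are arbitrary.
import numpy
from mlprodict.npy.xop import loadop

OnnxTopK = loadop('TopK')
node = OnnxTopK('X', numpy.array([2], dtype=numpy.int64))
values, indices = node   # __iter__ uses output_range / max_item_
second = node[1]         # __getitem__ returns an OnnxOperatorItem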

1688 

1689 def add_to(self, builder): 

1690 """ 

1691 Adds to graph builder. 

1692 

1693 :param builder: instance of @see cl _GraphBuilder, 

1694 it must have a method `add_node` 

1695 """ 

1696 logger.debug("op:%s-%d.add_to(builder-%d):1", 

1697 self.__class__.__name__, id(self), id(builder)) 

1698 inputs = builder.get_input_names(self, self.inputs) 

1699 if self.output_names is not None: 

1700 n_outputs = len(self.output_names) 

1701 elif self.expected_outputs is not None: 

1702 n_outputs = len(self.expected_outputs) 

1703 else: 

1704 n_outputs = self.output_range[0] 

1705 outputs = [builder.get_unique_output_name(NodeResultName(self, i)) 

1706 for i in range(n_outputs)] 

1707 logger.debug("op:%s-%d.add_to(builder-%d):2:%s:%r:%r", 

1708 self.__class__.__name__, id(self), id(builder), 

1709 self.operator_name, inputs, outputs) 

1710 logger.indent() 

1711 builder.add_node( 

1712 self.operator_name, 

1713 builder.get_unique_name( 

1714 '_' + self.operator_name.lower(), reserved=False), 

1715 inputs, outputs, domain=self.domain, opset=self.op_version, 

1716 **self.kwargs) 

1717 logger.dedent() 

1718 logger.debug("op:%s-%d.add_to(builder-%d):3", 

1719 self.__class__.__name__, id(self), id(builder)) 

1720 

1721 @staticmethod 

1722 def _node_to_graph_preprocess_list(inputs): 

1723 new_inputs = OrderedDict() 

1724 for el in inputs: 

1725 if isinstance(el, str): 

1726 new_inputs[el] = Variable(el) 

1727 elif isinstance(el, Variable): 

1728 new_inputs[el.name] = el 

1729 elif isinstance(el, tuple) and len(el) == 2: 

1730 # sklearn-onnx 

1731 new_inputs[el[0]] = Variable( 

1732 el[0], guess_numpy_type(el[1]), el[1].shape) 

1733 elif isinstance(el, ValueInfoProto): 

1734 new_inputs[el.name] = el 

1735 else: 

1736 raise TypeError( # pragma: no cover 

1737 f"Unable to handle input type {type(el)!r} ({el!r}).") 

1738 return new_inputs 

1739 

1740 @staticmethod 

1741 def _node_to_graph_process_input(processed, inputs, set_inputs, node, inp, 

1742 new_inputs, new_stack, inputs_dtype, 

1743 as_function=False): 

1744 if not as_function and inputs is None and inputs_dtype is None: 

1745 raise RuntimeError( # pragma: no cover 

1746 "Both inputs and inputs_dtype cannot be None at the same time " 

1747 "for inp=%r." % (inp, )) 

1748 

1749 if isinstance(inp, OnnxExisting): 

1750 if inp.inputs[0].output_names is None: 

1751 raise RuntimeError( # pragma: no cover 

1752 "output_names cannot be None for OnnxExisting, " 

1753 "subop is %r." % (inp.inputs[0], )) 

1754 # We need to check that this input was not already added. 

1755 oinp = inp.inputs[0].output_names[0] 

1756 if not new_inputs.has_input(oinp) and id(inp.inputs[0]) not in processed: 

1757 raise RuntimeError( # pragma: no cover 

1758 "This node id=%d (%r) was not added yet in the subgraph " 

1759 "but it must be from node %r." % ( 

1760 id(inp.inputs[0]), inp.inputs[0], node)) 

1761 elif isinstance(inp, OnnxOperator): 

1762 new_stack.append(inp) 

1763 logger.debug("op:static:SG-op:processed[%d]:%s", 

1764 id(inp), inp.__class__.__name__) 

1765 processed[id(inp)] = inp 

1766 elif isinstance(inp, OnnxOperatorItem): 

1767 new_stack.append(inp) 

1768 logger.debug("op:static:SG-it:processed[%d]:%s", 

1769 id(inp), inp.__class__.__name__) 

1770 processed[id(inp)] = inp 

1771 new_stack.append(inp.onx_op) 

1772 logger.debug("op:static:SG-op:processed[%d]:%s", 

1773 id(inp.onx_op), inp.onx_op.__class__.__name__) 

1774 processed[id(inp.onx_op)] = inp.onx_op 

1775 elif isinstance(inp, OnnxOperatorTuple): 

1776 # new_stack.append(inp) 

1777 # new_stack.append(inp.onx_op) 

1778 raise NotImplementedError( # pragma: no cover 

1779 "Unable to guess inputs when one input is OnnxOperatorTuple.") 

1780 elif isinstance(inp, Variable): 

1781 if inp.name in set_inputs: 

1782 return 

1783 if inp.name == '': 

1784 return 

1785 logger.debug("op:static:SG-var:processed[%d]:%s", 

1786 id(inp), inp.__class__.__name__) 

1787 processed[id(inp)] = inp 

1788 set_inputs.add(inp.name) 

1789 if inputs is None and inputs_dtype is None: 

1790 new_inputs.append(InputDetectedVariable(node, inp)) 

1791 elif isinstance(inputs, dict): 

1792 if inp.name in inputs: 

1793 var = InputDetectedVariable( 

1794 node, inp.copy_merge(inputs[inp.name])) 

1795 new_inputs.append(var) 

1796 else: 

1797 external_inputs = { 

1798 ei.name: ei for ei in node.external_inputs 

1799 if isinstance(ei, Variable)} 

1800 if inp.name not in external_inputs: 

1801 # This happens when an input is used for the first time 

1802 # inside a sub-sub-graph. 

1803 var = InputDetectedVariable(node, Variable(inp.name)) 

1804 elif inp.name in set_inputs: 

1805 var = InputDetectedVariable( 

1806 node, inp.copy_merge(external_inputs[inp.name])) 

1807 else: 

1808 raise ValueError( # pragma: no cover 

1809 f"Unable to find input {inp!r} in {inputs!r}, " 

1810 f"new_inputs={new_inputs!r}, " 

1811 f"type(node)={type(node)!r}, " 

1812 f"node.external_inputs={node.external_inputs!r}, " 

1813 f"node={node!r}.") 

1814 new_inputs.append(var) 

1815 elif inputs_dtype is not None: 

1816 new_inputs.append( 

1817 InputDetectedVariable(node, inp.copy_add(inputs_dtype))) 

1818 elif isinstance(inputs, Variable): 

1819 if inp.name == inputs.name: 

1820 new_inputs.append( 

1821 InputDetectedVariable(node, inp.copy_merge(inputs))) 

1822 else: 

1823 new_inputs.append(InputDetectedVariable(node, inp)) 

1824 else: 

1825 raise RuntimeError( # pragma: no cover 

1826 f"Unable to handle inputs={inputs!r}.") 

1827 elif isinstance(inp, numpy.ndarray): 

1828 pass 

1829 else: 

1830 raise TypeError( # pragma: no cover 

1831 f"Unexpected input type {type(inp)!r} in node type {type(node)!r}.") 

1832 

1833 @staticmethod 

1834 def _node_to_graph_get_type(node, name=None, outputs=None, 

1835 outputs_dtype=None): 

1836 if outputs is None: 

1837 return outputs_dtype, None 

1838 if isinstance(outputs, Variable): 

1839 if name is None: 

1840 return (outputs.dtype or outputs_dtype, None) 

1841 if isinstance(name, Variable): 

1842 return (outputs.dtype or name.dtype or outputs_dtype, 

1843 None) 

1844 raise RuntimeError( # pragma: no cover 

1845 f"Unable to handle outputs={outputs!r}.") 

1846 if isinstance(outputs, dict): 

1847 if name is None: 

1848 return _infer_node_output(node, outputs) 

1849 if isinstance(name, Variable): 

1850 n = name.name 

1851 else: 

1852 n = name 

1853 if n not in outputs: 

1854 return None, None 

1855 return outputs[n], None 

1856 if isinstance(outputs, (list, OnnxOperator._InputContainer)): 

1857 raise NotImplementedError( # pragma: no cover 

1858 f"Unexpected type for name={name!r}, outputs={outputs!r}.") 

1859 if is_numpy_dtype(outputs): 

1860 return outputs, None 

1861 raise RuntimeError( # pragma: no cover 

1862 f"Unable to handle outputs={outputs!r}.") 

1863 

1864 @staticmethod 

1865 def _node_to_graph_reorder_by_name(new_inputs, inputs): 

1866 memo = OrderedDict((n.name, n) for n in new_inputs) 

1867 done = set() 

1868 result = [] 

1869 for inp in inputs: 

1870 if inp.name in memo: 

1871 result.append(memo[inp.name]) 

1872 done.add(inp.name) 

1873 for k, v in memo.items(): 

1874 if k in done: 

1875 continue 

1876 result.append(v) 

1877 return result 

1878 

1879 class _InputContainer: 

1880 

1881 def __init__(self): 

1882 self._c = [] 

1883 self._names = set() 

1884 

1885 def has_input(self, inp): 

1886 "Checks that input *inp* is part the list of names." 

1887 if isinstance(inp, str): 

1888 return inp in self._names 

1889 if inp.name in self._names: 

1890 return True 

1891 return False 

1892 

1893 def append(self, inp): 

1894 "Append one element to the list." 

1895 name = inp.var.name 

1896 self._c.append(inp) 

1897 self._names.add(name) 

1898 

1899 def __len__(self): 

1900 return len(self._c) 

1901 

1902 def __repr__(self): 

1903 return f"{'_InputContainer'}(\n {pprint.pformat(self._c)})" 

1904 

1905 def __iter__(self): 

1906 for inp in self._c: 

1907 yield inp 
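
# Illustration (not part of the original source): _InputContainer keeps
# detected inputs in insertion order and answers name-membership queries
# through an internal set. Variable and InputDetectedVariable are the
# classes imported at the top of this module.
from mlprodict.npy.xop import OnnxOperator
from mlprodict.npy.xop_variable import Variable, InputDetectedVariable

c = OnnxOperator._InputContainer()
c.append(InputDetectedVariable(None, Variable('X')))
assert c.has_input('X') and len(c) == 1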

1908 

1909 def _node_to_graph(self, other_outputs=None, inputs=None, outputs=None, 

1910 as_function=False, processed=None): 

1911 """ 

1912 Builds a graph as a list of nodes to walk through in that order. 

1913 """ 

1914 if processed is None: 

1915 raise RuntimeError( # pragma: no cover 

1916 "processed cannot be None.") 

1917 node_outputs = [self] 

1918 if other_outputs is not None: 

1919 node_outputs += other_outputs 

1920 

1921 if inputs is not None: 

1922 logger.debug("op:%s-%d._node_to_graph:1:inputs=%r", 

1923 self.__class__.__name__, id(self), inputs) 

1924 if outputs is not None: 

1925 logger.debug("op:%s-%d._node_to_graph:1:outputs=%r", 

1926 self.__class__.__name__, id(self), outputs) 

1927 

1928 # preprocess inputs, outputs 

1929 _keep_inputs = None 

1930 inputs_dtype = None 

1931 if isinstance(inputs, (list, OnnxOperator._InputContainer)): 

1932 _keep_inputs = inputs 

1933 inputs_dict = self._node_to_graph_preprocess_list(inputs) 

1934 elif isinstance(inputs, dict): 

1935 inputs_dict = inputs 

1936 elif isinstance(inputs, Variable): 

1937 inputs = [inputs] 

1938 inputs_dict = self._node_to_graph_preprocess_list(inputs) 

1939 elif is_numpy_dtype(inputs): 

1940 inputs_dtype = inputs 

1941 inputs_dict = None 

1942 else: 

1943 raise TypeError( # pragma: no cover 

1944 f"Unexpected type {type(inputs)!r} for inputs.") 

1945 

1946 _keep_outputs = None 

1947 outputs_dtype = None 

1948 if isinstance(outputs, (list, OnnxOperator._InputContainer)): 

1949 _keep_outputs = outputs 

1950 outputs_dict = self._node_to_graph_preprocess_list(outputs) 

1951 elif isinstance(outputs, dict): 

1952 outputs_dict = outputs 

1953 elif isinstance(outputs, Variable): 

1954 outputs = [outputs] 

1955 outputs_dict = self._node_to_graph_preprocess_list(outputs) 

1956 elif is_numpy_dtype(outputs): 

1957 outputs_dtype = outputs 

1958 outputs_dict = None 

1959 else: 

1960 raise TypeError( # pragma: no cover 

1961 f"Unexpected type {type(outputs)!r} for outputs.") 

1962 

1963 if inputs is not None: 

1964 logger.debug("op:%s-%d._node_to_graph:2:inputs=%r", 

1965 self.__class__.__name__, id(self), inputs) 

1966 if outputs is not None: 

1967 logger.debug("op:%s-%d._node_to_graph:2:outputs=%r", 

1968 self.__class__.__name__, id(self), outputs) 

1969 if inputs_dict is not None: 

1970 logger.debug("op:%s-%d._node_to_graph:2:inputs_dict=%r", 

1971 self.__class__.__name__, id(self), inputs_dict) 

1972 if outputs_dict is not None: 

1973 logger.debug("op:%s-%d._node_to_graph:2:outputs_dict=%r", 

1974 self.__class__.__name__, id(self), outputs_dict) 

1975 if inputs_dtype is not None: 

1976 logger.debug("op:%s-%d._node_to_graph:2:inputs_dtype=%r", 

1977 self.__class__.__name__, id(self), inputs_dtype) 

1978 if outputs_dtype is not None: 

1979 logger.debug("op:%s-%d._node_to_graph:2:outputs_dtype=%r", 

1980 self.__class__.__name__, id(self), outputs_dtype) 

1981 

1982 # walk through graph 

1983 stack = list(node_outputs) 

1984 new_inputs = self._InputContainer() 

1985 set_inputs = set() 

1986 memo = [] 

1987 while len(stack) > 0: 

1988 logger.debug("op:%s-%d._node_to_graph:loop:len(memo)=%d", 

1989 self.__class__.__name__, id(self), len(memo)) 

1990 memo.extend(stack) 

1991 new_stack = [] 

1992 for obj in stack: 

1993 logger.debug("op:%s-%d._node_to_graph:-node=%r:external_inputs=%r", 

1994 self.__class__.__name__, id(self), 

1995 obj.__class__.__name__, 

1996 getattr(obj, 'external_inputs', "-")) 

1997 if isinstance(obj, OnnxExisting): 

1998 pass 

1999 elif isinstance(obj, OnnxOperatorItem): 

2000 # nothing to do, OnnxOperatorItem is created 

2001 # by OnnxOperator.__getitem__. 

2002 pass 

2003 elif isinstance(obj, (OnnxOperator, OnnxOperatorTuple)): 

2004 if len(obj.external_inputs) > 0: 

2005 # external_inputs are inputs required by a subgraph 

2006 # but not necessarily used in the main graph. 

2007 # They need to be processed first. 

2008 for inp in obj.external_inputs: 

2009 self._node_to_graph_process_input( 

2010 processed, inputs_dict, set_inputs, obj, inp, new_inputs, 

2011 new_stack, inputs_dtype, as_function=as_function) 

2012 for inp in obj.inputs: 

2013 self._node_to_graph_process_input( 

2014 processed, inputs_dict, set_inputs, obj, inp, new_inputs, 

2015 new_stack, inputs_dtype, as_function=as_function) 

2016 else: 

2017 raise TypeError( # pragma: no cover 

2018 f"Unexpected type {type(obj)!r}.") 

2019 stack = new_stack 

2020 

2021 # reorder new_inputs to follow inputs initial order 

2022 if _keep_inputs is not None: 

2023 new_inputs = self._node_to_graph_reorder_by_name( 

2024 new_inputs, inputs) 

2025 

2026 logger.debug("op:%s-%d._node_to_graph:new_inputs=%r", 

2027 self.__class__.__name__, id(self), new_inputs) 

2028 

2029 # eliminate duplicates 

2030 done = set() 

2031 nodes = [] 

2032 for node in reversed(memo): 

2033 if id(node) in done: 

2034 continue 

2035 done.add(id(node)) 

2036 nodes.append(node) 

2037 

2038 # outputs 

2039 set_names = set() 

2040 new_outputs = [] 

2041 run_shape = False 

2042 for node in node_outputs: 

2043 if node.output_names is None: 

2044 n = self.output_range[0] 

2045 for i in range(n): 

2046 to, shape = self._node_to_graph_get_type( 

2047 node, outputs=outputs_dict, 

2048 outputs_dtype=outputs_dtype) 

2049 if to is None: 

2050 run_shape = True 

2051 res = f'xop_{id(node)}_{i}' 

2052 var = Variable(res, added_dtype=to, shape=shape) 

2053 if var.name in set_names: 

2054 raise RuntimeError( # pragma: no cover 

2055 f"Duplicated output name var={var!r} in " 

2056 f"{set_names!r}.") 

2057 set_names.add(var.name) 

2058 new_outputs.append(OutputDetectedVariable(node, var, i)) 

2059 else: 

2060 for i, o in enumerate(node.output_names): 

2061 if isinstance(o, str): 

2062 raise TypeError( # pragma: no cover 

2063 "Output %d - %r (%r) not allowed in node %r." % ( 

2064 i, o, node.output_names, node)) 

2065 to, shape = self._node_to_graph_get_type( 

2066 node, o, outputs=outputs_dict, 

2067 outputs_dtype=outputs_dtype) 

2068 if to is None: 

2069 run_shape = True 

2070 res = (o, to) 

2071 var = o.copy_merge(to, shape=shape) 

2072 if var.name in set_names: 

2073 raise RuntimeError( # pragma: no cover 

2074 f"Duplicated output name o={o!r} var={var!r}.") 

2075 set_names.add(var.name) 

2076 new_outputs.append(OutputDetectedVariable(node, var, i)) 

2077 if len(new_outputs) == 0: 

2078 raise RuntimeError( # pragma: no cover 

2079 f"No detected outputs inputs={inputs_dict!r} outputs={outputs_dict!r}.") 

2080 

2081 # reorder new_outputs to follow outputs initial order 

2082 if _keep_outputs is not None: 

2083 new_outputs = self._node_to_graph_reorder_by_name( 

2084 new_outputs, outputs) 

2085 

2086 logger.debug("op:%s-%d._node_to_graph:new_outputs=%r", 

2087 self.__class__.__name__, id(self), new_outputs) 

2088 

2089 return nodes, new_inputs, new_outputs, run_shape 

2090 

2091 def to_onnx(self, inputs=None, outputs=None, 

2092 other_outputs=None, target_opset=None, 

2093 optim=True, verbose=0, run_shape=True, 

2094 function_name=None, function_domain=None, 

2095 fLOG=print, processed=None, check_model=True, 

2096 return_builder=False): 

2097 """ 

2098 Converts this operator into an ONNX graph. 

2099 

2100 :param inputs: information about types, it should not be None 

2101 :param outputs: information about types, if None, the function 

2102 will use shape inference to guess the final output type 

2103 and shape 

2104 :param other_outputs: additional nodes to consider 

2105 as graph outputs but not outputs of this particular 

2106 node 

2107 :param target_opset: dictionary with target opset per domain, 

2108 None for the default one 

2109 :param optim: optimize the model with function 

2110 @see fn onnx_optimisations 

2111 :param run_shape: in case output shapes are not specified, 

2112 the function runs function :epkg:`infer_shapes` 

2113 to guess them, False would disable that 

2114 default behaviour 

2115 :param verbose: prints information 

2116 :param function_name: if not None, returns a :epkg:`FunctionProto` 

2117 :param function_domain: in case of a function, declares the function 

2118 as part of this domain 

2119 :param fLOG: logging function 

2120 :param processed: keeps track of the processed nodes 

2121 :param check_model: checks the output model 

2122 :param return_builder: if True, returns the instance of @see cl GraphBuilder 

2123 used to build the onnx graph. 

2124 :return: ONNX structure 

2125 

2126 *inputs* and *outputs* parameters work the same way. 

2127 Here are some possible values: 

2128 

2129 - `inputs=numpy.float32`: all inputs are dense tensors of 

2130 unknown shapes sharing the same element type 

2131 - `inputs={'X': numpy.float32, 'Y': numpy.int64}`: 

2132 input `X` is a dense tensor of float32, 

2133 input `Y` is a dense tensor of int64, 

2134 - `inputs={'X': numpy.array(...)}`: input `X` is a dense 

2135 tensor with a precise shape 

2136 - `inputs=[Variable('X', numpy.float32, [1, 2])]`: 

2137 input `X` is a dense tensor of float32 with shape `[1, 2]` 

2138 - `inputs=[Variable('X', numpy.float32, [None, 2])]`: 

2139 input `X` is a 2D dense tensor of float32 

2140 with an unknown first dimension 

2141 - see @see cl Variable 

2142 

2143 (OnnxOperator) 

2144 """ 

2145 # opsets 

2146 logger.debug( 

2147 "op:%s-%d.to_onnx(%r, %r, other_outputs=%r, target_opset=%r, as_function=%r)", 

2148 self.__class__.__name__, id(self), inputs, outputs, 

2149 other_outputs, target_opset, function_name) 

2150 if isinstance(target_opset, dict): 

2151 dom = self.domain or '' 

2152 target_opset = target_opset.get(dom, None) 

2153 elif isinstance(target_opset, int): 

2154 if self.domain not in ('', None): 

2155 # The target_opset applies to domain '' only; we ignore it. 

2156 target_opset = None 

2157 elif target_opset is not None: 

2158 raise TypeError( # pragma: no cover 

2159 "target_opset must be a dictionary {domain: " 

2160 "target_opset} not %r for operator %r." % ( 

2161 target_opset, self.__class__.__name__)) 

2162 

2163 if self.domain in ('', None) and target_opset == 1: 

2164 raise RuntimeError( # pragma: no cover 

2165 "target_opset cannot be 1.") 

2166 if (self.op_version is not None and target_opset is not None and 

2167 self.op_version > target_opset): 

2168 raise RuntimeError( # pragma: no cover 

2169 "target_opset={} is lower than the version={} requested " 

2170 "for this node '{}'.".format( 

2171 target_opset, self.op_version, self.__class__.__name__)) 

2172 

2173 # get the graph 

2174 if processed is None: 

2175 processed = {} 

2176 logger.debug("op:%s-%d:SG-self:processed[%d]:SELF", 

2177 self.__class__.__name__, id(self), id(self)) 

2178 processed[id(self)] = self 

2179 

2180 logger.indent() 

2181 nodes, graph_inputs, graph_outputs, run_shape2 = self._node_to_graph( 

2182 other_outputs, inputs, outputs, as_function=function_name is not None, 

2183 processed=processed) 

2184 if hasattr(self, 'subgraph_inputs'): 

2185 if any(map(lambda o: not isinstance(o, Variable), 

2186 self.subgraph_inputs)): 

2187 raise TypeError( # pragma: no cover 

2188 f"Unexpected type, all type should be Variable in " 

2189 f"{self.subgraph_inputs!r}.") 

2190 graph_inputs = [ 

2191 InputDetectedVariable(None, v) for v in self.subgraph_inputs 

2192 ] + graph_inputs 

2193 logger.dedent() 

2194 

2195 logger.debug("op:%s.to_onnx:graph_inputs=%r", 

2196 self.__class__.__name__, graph_inputs) 

2197 logger.debug("op:%s.to_onnx:graph_outputs=%r", 

2198 self.__class__.__name__, graph_outputs) 

2199 

2200 if len(nodes) == 0: 

2201 raise RuntimeError( # pragma: no cover 

2202 "Node list is empty.") 

2203 if verbose > 1: 

2204 for i, n in enumerate(nodes): # pragma: no cover 

2205 fLOG("nodes[%d]=%r" % (i, n)) 

2206 for i, n in enumerate(graph_inputs): # pragma: no cover 

2207 fLOG("graph_inputs[%d]=%r" % (i, n)) 

2208 

2209 # creates a _GraphBuilder 

2210 builder = _GraphBuilder() 

2211 

2212 # reserve input names starting by the first one 

2213 for node in reversed(nodes): 

2214 for var in node.inputs: 

2215 if isinstance(var, Variable): 

2216 logger.debug("op:%s.to_onnx:_add_name(%r)", 

2217 self.__class__.__name__, var.name) 

2218 builder._add_name(var.name) 

2219 

2220 # reserve output names starting by the last ones 

2221 for node in reversed(nodes): 

2222 builder.reserve_names(node, node.output_names) 

2223 

2224 # adds every node to the builder 

2225 for i, node in enumerate(nodes): 

2226 logger.debug("op:%s-%d.to_onnx:node:%d/%d:%r", 

2227 self.__class__.__name__, id(self), i, len(nodes), node) 

2228 

2229 for node in nodes: 

2230 if isinstance(node, OnnxExisting): 

2231 continue 

2232 logger.indent() 

2233 hidden = node._to_onnx_attributes( 

2234 inputs=graph_inputs, target_opset=target_opset, 

2235 optim=optim, verbose=verbose, run_shape=run_shape, fLOG=fLOG, 

2236 processed=processed) 

2237 logger.dedent() 

2238 

2239 if len(hidden) > 0: 

2240 logger.debug( 

2241 "op:%s-%d.to_onnx:to_onnx:%s-%d:hidden:%r", 

2242 self.__class__.__name__, id(self), 

2243 node.__class__.__name__, id(node), hidden) 

2244 builder.get_input_names(node, hidden) 

2245 node.add_to(builder) 

2246 

2247 logger.debug( 

2248 "op:%s-%d.to_onnx:to_onnx:a", self.__class__.__name__, id(self)) 

2249 logger.indent() 

2250 

2251 # fix missing inputs 

2252 if isinstance(inputs, dict): 

2253 known = set() 

2254 for gi in graph_inputs: 

2255 known.add(gi.var.name) 

2256 for name, dtype in inputs.items(): 

2257 if name not in known: 

2258 logger.debug( 

2259 "%s-%d.to_onnx:+:%s:%r", 

2260 self.__class__.__name__, id(self), name, dtype) 

2261 var = InputDetectedVariable( 

2262 None, Variable(name, dtype=dtype)) 

2263 graph_inputs.append(var) 

2264 builder.input_names[name] = var 

2265 for v in graph_inputs: 

2266 if v.var.name not in builder.input_names: 

2267 builder.input_names[v.var.name] = v 

2268 

2269 onx = builder.to_onnx( 

2270 inputs=graph_inputs, outputs=graph_outputs, 

2271 target_opset=target_opset, verbose=verbose, 

2272 optim=optim, run_shape=run_shape and run_shape2, 

2273 function_name=function_name, function_domain=function_domain, 

2274 check_model=check_model) 

2275 logger.dedent() 

2276 

2277 logger.debug( 

2278 "op:%s-%d.to_onnx:to_onnx:b:%s:%d-nodes", 

2279 self.__class__.__name__, id(self), type(onx).__name__, 

2280 len(onx.graph.node) if hasattr(onx, 'graph') else onx.node) 

2281 if return_builder: 

2282 return onx, builder 

2283 return onx 
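
# Usage sketch (illustration, not part of the original source): a minimal
# call to ``to_onnx`` following the docstring above; the input name 'X',
# the constant and the dtypes are arbitrary examples.
import numpy
from mlprodict.npy.xop import loadop

OnnxAdd = loadop('Add')
node = OnnxAdd('X', numpy.array([1], dtype=numpy.float32),
               output_names=['Y'])
onx = node.to_onnx(inputs={'X': numpy.float32})  # returns a ModelProto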

2284 

2285 def _to_onnx_attributes(self, inputs=None, target_opset=None, 

2286 optim=True, verbose=0, run_shape=True, 

2287 fLOG=print, processed=None): 

2288 """ 

2289 Converts attributes into ONNX. 

2290 Returns the hidden inputs. 

2291 """ 

2292 if processed is None: 

2293 raise RuntimeError( # pragma: no cover 

2294 "processed cannot be None.") 

2295 converts = [] 

2296 for k, v in self.kwargs.items(): 

2297 if isinstance(v, OnnxOperatorBase): 

2298 converts.append(k) 

2299 hidden_inputs = [] 

2300 for name in converts: 

2301 if verbose > 0: 

2302 fLOG( # pragma: no cover 

2303 '[OnnxOperator._to_onnx_attributes] process %r of type %r.' 

2304 '' % (name, type(self.kwargs[name]))) 

2305 model, hidden = self._to_onnx_attribute( 

2306 name, self.kwargs[name], inputs=inputs, target_opset=target_opset, 

2307 optim=optim, verbose=verbose, run_shape=run_shape, fLOG=fLOG, 

2308 processed=processed) 

2309 

2310 hidden_inputs.extend(hidden) 

2311 if len(model.graph.node) == 0: 

2312 _, hidden = self._to_onnx_attribute( 

2313 name, self.kwargs[name], inputs=inputs, target_opset=target_opset, 

2314 optim=False, verbose=verbose, run_shape=run_shape, fLOG=fLOG, 

2315 processed=processed) 

2316 raise RuntimeError( # pragma: no cover 

2317 "Conversion to graph of parameter %r from\nnode=%r " 

2318 "and\ninputs=%r\nis empty:\n%s\nHIDDEN\n%r" % ( 

2319 name, self.kwargs[name], self.kwargs[name].inputs, 

2320 model, hidden)) 

2321 if name in {'else_branch', 'then_branch'}: 

2322 if len(model.graph.input) > 0: 

2323 # else_branch, then_branch must not have any input. 

2324 del model.graph.input[:] 

2325 self.kwargs[name] = model.graph 

2326 return hidden_inputs 

2327 

2328 def _to_onnx_attribute(self, att_name, oxop, inputs=None, target_opset=None, 

2329 optim=True, verbose=0, run_shape=True, 

2330 fLOG=print, processed=None): 

2331 """ 

2332 Converts one subgraph into ONNX. 

2333 Returns the ONNX graph and the hidden inputs. 

2334 """ 

2335 if processed is None: 

2336 raise RuntimeError( # pragma: no cover 

2337 "processed cannot be None.") 

2338 if inputs is None: 

2339 vars = None 

2340 else: 

2341 named_inputs = set(oxop.find_named_inputs()) 

2342 vars = [] 

2343 added = set() 

2344 for inp in inputs: 

2345 if inp.var.name in named_inputs and inp.var.name not in added: 

2346 added.add(inp.var.name) 

2347 vars.append(Variable( 

2348 inp.var.name, inp.var.dtype or inp.var.added_dtype)) 

2349 if verbose > 0: 

2350 fLOG( # pragma: no cover 

2351 f'[OnnxOperator._to_onnx_attribute] inputs={vars!r}') 

2352 logger.debug("op:%s._to_onnx_attribute:%s:inputs(%r)", 

2353 self.__class__.__name__, att_name, vars) 

2354 logger.indent() 

2355 onx, att_builder = oxop.to_onnx( 

2356 inputs=vars, target_opset=target_opset, run_shape=run_shape, 

2357 verbose=verbose, fLOG=fLOG, processed=processed, optim=False, 

2358 check_model=False, return_builder=True) 

2359 logger.dedent() 

2360 hidden_inputs = att_builder.hidden_input 

2361 if len(hidden_inputs) > 0: 

2362 if verbose > 0: 

2363 fLOG( # pragma: no cover 

2364 f'[OnnxOperator._to_onnx_attribute] inputs={vars!r}') 

2365 logger.debug("op:%s._to_onnx_attribute:inputs:hidden:%r", 

2366 self.__class__.__name__, att_builder.hidden_input) 

2367 if len(onx.graph.node) == 0: 

2368 raise RuntimeError( # pragma: no cover 

2369 "Empty graph (class=%r, optim=%r) from\nnode=%r " 

2370 "and\ninputs=%r\nis empty:\n%s" % ( 

2371 type(oxop), optim, oxop, vars, onx)) 

2372 shaped_onx = infer_shapes(onx) 

2373 return shaped_onx, hidden_inputs 

2374 

2375 def predecessors(self): 

2376 """ 

2377 Returns the list of predecessors. 

2378 

2379 :return: list of @see cl OnnxOperator 

2380 """ 

2381 stack = [self] 

2382 last = 0 

2383 while True: 

2384 end = len(stack) 

2385 if end == last: 

2386 break 

2387 for i in range(last, end): 

2388 node = stack[i] 

2389 for inp in node.inputs: 

2390 if isinstance(inp, OnnxOperatorBase): 

2391 stack.append(inp) 

2392 last = end 

2393 return stack 

2394 

2395 def __call__(self, *args, function_name=None, function_domain=None, 

2396 **kwargs): 

2397 """ 

2398 Creates an instance of class @see cl OnnxOperatorFunction. 

2399 Equivalent to `OnnxOperatorFunction(proto, *args, **kwargs)`. 

2400 

2401 :param args: see @see cl OnnxOperatorFunction 

2402 :param function_name: name to be given to the function 

2403 :param function_domain: function domain, if None, 

2404 it is given a default value 

2405 :param kwargs: see @see cl OnnxOperatorFunction 

2406 :return: instance of type @see cl OnnxOperatorFunction 

2407 """ 

2408 if function_name is None: 

2409 def clean(name): 

2410 if name.startswith("Onnx"): 

2411 name = name[4:] 

2412 return name 

2413 

2414 pred = self.predecessors() 

2415 cls = [clean(p.__class__.__name__) for p in pred] 

2416 function_name = "".join(cls) 

2417 onx = self.to_onnx(function_name=function_name, 

2418 function_domain=function_domain) 

2419 return OnnxOperatorFunction(onx, *args, **kwargs) 

2420 

2421 def find_named_inputs(self): 

2422 """ 

2423 Retrieves all named inputs in this graph. 

2424 """ 

2425 unique = set() 

2426 found = [] 

2427 for inp in self.inputs: 

2428 if isinstance(inp, str): 

2429 if inp not in unique: 

2430 found.append(inp) 

2431 unique.add(inp) 

2432 elif isinstance(inp, Variable): 

2433 if inp.name not in unique: 

2434 found.append(inp.name) 

2435 unique.add(inp.name) 

2436 elif isinstance(inp, OnnxOperatorBase): 

2437 f = inp.find_named_inputs() 

2438 for n in f: 

2439 if n not in unique: 

2440 found.append(n) 

2441 unique.add(n) 

2442 elif isinstance(inp, numpy.ndarray): 

2443 pass 

2444 else: 

2445 raise RuntimeError( # pragma: no cover 

2446 f"Unexpected input type {type(inp)!r}.") 

2447 return found 
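
# Illustration (not part of the original source): named inputs are collected
# recursively, in order of first appearance and without duplicates.
from mlprodict.npy.xop import loadop

OnnxAdd, OnnxAbs = loadop('Add', 'Abs')
node = OnnxAdd(OnnxAbs('X'), 'X')
assert node.find_named_inputs() == ['X']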

2448 

2449 def to_onnx_this(self, evaluated_inputs): 

2450 """ 

2451 Returns a simple ONNX graph corresponding to this node. 

2452 

2453 :param evaluated_inputs: inputs as a list 

2454 :return: ONNX graph 

2455 

2456 (OnnxOperator) 

2457 """ 

2458 logger.debug('op:%s-%d.to_onnx_this:%r', 

2459 self.__class__.__name__, id(self), 

2460 evaluated_inputs) 

2461 inputs_names = ['I%d' % i for i in range(len(evaluated_inputs))] 

2462 if self.output_names is None: 

2463 if self.expected_outputs is None: 

2464 raise NotImplementedError( # pragma: no cover 

2465 "expected_outputs and output_names are not defined.") 

2466 output_names = [o[0] for o in self.expected_outputs] 

2467 else: 

2468 output_names = [o.name for o in self.output_names] 

2469 node = make_node(self.op_type, inputs_names, output_names, 

2470 domain=self.domain, name="f", **self.kwargs) 

2471 onx_inputs = [Variable(name, a.dtype).make_value_info() 

2472 for name, a in zip(inputs_names, evaluated_inputs)] 

2473 onx_outputs = [make_value_info(name, make_tensor_type_proto(0, [])) 

2474 for name in output_names] 

2475 graph = make_graph([node], 'f', onx_inputs, onx_outputs) 

2476 model = make_model( 

2477 graph, opset_imports=[make_operatorsetid( 

2478 self.domain or '', self.since_version)]) 

2479 return model 

2480 

2481 def run(self, *inputs, verbose=0, fLOG=None, clear_cache=False, runtime=None): 

2482 """ 

2483 Other name for 

2484 `OnnxInference.f <mlprodict.onnxrt.onnx_inference.OnnxInference.f>`_. 

2485 """ 

2486 return self.f(*inputs, verbose=verbose, fLOG=fLOG, 

2487 clear_cache=clear_cache, runtime=runtime) 

2488 

2489 def f(self, *inputs, verbose=0, fLOG=None, # pylint: disable=W0221 

2490 clear_cache=False, runtime=None): 

2491 """ 

2492 Computes the predictions for this node. 

2493 Similar to an eager evaluation. 

2494 

2495 :param inputs: inputs as dictionary or a list of inputs 

2496 (see below) 

2497 :param verbose: display information while predicting 

2498 :param fLOG: logging function if *verbose > 0* 

2499 :param clear_cache: onnx graph is created once unless 

2500 this parameter is True 

2501 :param runtime: runtime to use for the evaluation, 

2502 see @see cl OnnxInference 

2503 :return: outputs as a dictionary if the inputs were given as a 

2504 dictionary or a single result or a tuple otherwise 

2505 

2506 The inputs refer to the inputs of the graph. 

2507 The method walks through all inputs and finds inputs defined as 

2508 strings. It replaces them with the values found in the dictionary. 

2509 If the inputs are specified in a list, the function retrieves the 

2510 list of inputs defined as strings and assigns them a value. 

2511 Logging function can be used to get more insight about it. 

2512 During the evaluation every node is independently converted 

2513 into ONNX. The ONNX graph is cached in the class itself. 

2514 """ 

2515 # input evaluation 

2516 if len(inputs) == 1 and isinstance(inputs[0], dict): 

2517 dict_inputs = inputs[0] 

2518 as_dict = True 

2519 elif not isinstance(inputs, (tuple, list, OnnxOperator._InputContainer)): 

2520 raise TypeError( # pragma: no cover 

2521 f"inputs must be a list not {type(inputs)!r}.") 

2522 elif len(inputs) > 0 and isinstance(inputs[0], OnnxOperator): 

2523 raise TypeError( # pragma: no cover 

2524 f"Unexpected type for inputs[0]: {type(inputs[0])!r}.") 

2525 else: 

2526 as_dict = False 

2527 if verbose > 0: 

2528 fLOG( # pragma: no cover 

2529 "[OnnxOperator.f] retrieves named inputs") 

2530 if hasattr(self, "feval_named_inputs_"): 

2531 named_inputs = self.feval_named_inputs_ # pylint: disable=E0203 

2532 else: 

2533 named_inputs = self.find_named_inputs() 

2534 self.feval_named_inputs_ = named_inputs 

2535 if len(named_inputs) != len(inputs): 

2536 raise RuntimeError( 

2537 "Mismatch between the number of found inputs (%d) and " 

2538 "the number of given inputs (%d) (found %r)." 

2539 "" % ( 

2540 len(named_inputs), len(inputs), named_inputs)) 

2541 dict_inputs = { 

2542 name: value for name, value in zip(named_inputs, inputs)} 

2543 if verbose > 0: 

2544 fLOG( # pragma: no cover 

2545 f"[OnnxOperator.f] found inputs: {named_inputs!r}") 

2546 

2547 # conversion 

2548 evaluated_inputs = [] 

2549 for i, inp in enumerate(self.inputs): 

2550 if isinstance(inp, str): 

2551 evaluated_inputs.append(dict_inputs[inp]) 

2552 elif isinstance(inp, Variable): 

2553 evaluated_inputs.append(dict_inputs[inp.name]) 

2554 elif isinstance(inp, OnnxOperatorBase): 

2555 if verbose > 0: 

2556 fLOG( # pragma: no cover 

2557 "[OnnxOperator.f] evaluate input %d (op_type=%r)" % ( 

2558 i, self.__class__.op_type)) 

2559 out = inp.f(dict_inputs, verbose=verbose, fLOG=fLOG) 

2560 if isinstance(out, dict): 

2561 if len(out) == 1: 

2562 evaluated_inputs.append(out.popitem()[1]) 

2563 else: 

2564 raise NotImplementedError( # pragma: no cover 

2565 "Not yet implemented in case when there are multiple " 

2566 "outputs (%r)." % list(out)) 

2567 elif isinstance(out, (list, OnnxOperator._InputContainer)): 

2568 evaluated_inputs.extend(out) 

2569 else: 

2570 evaluated_inputs.append(out) 

2571 elif isinstance(inp, numpy.ndarray): 

2572 evaluated_inputs.append(inp) 

2573 else: 

2574 raise RuntimeError( # pragma: no cover 

2575 "Unexpected type %r for input %d." % (type(inp), i)) 

2576 

2577 # conversion to ONNX 

2578 if not hasattr(self, 'feval_onnx_'): 

2579 self.feval_onnx_ = {} 

2580 key = tuple((m.dtype, m.shape) for m in evaluated_inputs) 

2581 if key not in self.feval_onnx_ or clear_cache: 

2582 if verbose > 0: 

2583 fLOG( 

2584 f"[OnnxOperator.f] creating node {self.op_type!r}, inputs={key!r}") 

2585 from ..onnxrt import OnnxInference 

2586 model = self.to_onnx_this(evaluated_inputs) 

2587 oinf = OnnxInference(model, runtime=runtime) 

2588 self.feval_onnx_[key] = oinf 

2589 else: 

2590 oinf = self.feval_onnx_[key] 

2591 

2592 # execution 

2593 if verbose > 0: 

2594 fLOG(f"[OnnxOperator.f] execute node {self.op_type!r}") 

2595 got = oinf.run({k: v for k, v in 

2596 zip(oinf.input_names, evaluated_inputs)}) 

2597 if as_dict: 

2598 return got 

2599 if len(got) == 1: 

2600 return got.popitem()[1] 

2601 return [got[n] for n in oinf.output_names] 
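
# Eager-evaluation sketch (illustration, not part of the original source):
# ``f`` converts the node to ONNX once per input signature, caches an
# OnnxInference runtime in ``feval_onnx_`` and runs it on the given arrays.
import numpy
from mlprodict.npy.xop import loadop

OnnxAbs = loadop('Abs')
x = numpy.array([-2., 3.], dtype=numpy.float32)
y = OnnxAbs('X').f(x)   # one output -> returned directly, not as a dict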

2602 

2603 @staticmethod 

2604 def _merge_op_version(n1, n2, at_least=None): 

2605 if isinstance(n2, OnnxOperator): 

2606 if n1.op_version is None: 

2607 opv = n2.op_version 

2608 elif n2.op_version is None: 

2609 opv = n1.op_version 

2610 elif n1.op_version == n2.op_version: 

2611 opv = n1.op_version 

2612 else: 

2613 opv = max(n1.op_version, n2.op_version) 

2614 elif isinstance(n2, OnnxOperatorItem): 

2615 opv = OnnxOperator._merge_op_version(n1, n2.onx_op) 

2616 elif isinstance(n2, OnnxOperatorTuple): 

2617 raise NotImplementedError( # pragma: no cover 

2618 "_merge_op_version is not implemented when n2 " 

2619 "is OnnxOperatorTuple.") 

2620 else: 

2621 opv = n1.op_version 

2622 if at_least is not None and opv is not None and opv < at_least: 

2623 opv = at_least 

2624 return opv 
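
# Illustration (not part of the original source): the merged version is the
# highest explicit op_version of both nodes, raised to ``at_least`` when
# given (the operators below use at_least=15 because of CastLike).
from mlprodict.npy.xop import loadop, OnnxOperator

OnnxAbs = loadop('Abs')
n1 = OnnxAbs('X', op_version=13)
n2 = OnnxAbs('Y', op_version=14)
assert OnnxOperator._merge_op_version(n1, n2, at_least=15) == 15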

2625 

2626 def __add__(self, ov): 

2627 """ 

2628 Automatically adds operator `OnnxAdd` to the graph. 

2629 

2630 :param ov: onnx node 

2631 :return: `OnnxAdd(self, ov)` 

2632 """ 

2633 OnnxAdd = loadop('Add') 

2634 opv = self._merge_op_version(self, ov, at_least=15) 

2635 if isinstance(ov, (int, float)): 

2636 OnnxCastLike = loadop('CastLike') 

2637 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2638 return OnnxAdd(self, ov, op_version=opv) 

2639 

2640 def __sub__(self, ov): 

2641 """ 

2642 Automatically adds operator `OnnxSub` to the graph. 

2643 

2644 :param ov: onnx node 

2645 :return: `OnnxSub(self, ov)` 

2646 """ 

2647 OnnxSub = loadop('Sub') 

2648 opv = self._merge_op_version(self, ov, at_least=15) 

2649 if isinstance(ov, (int, float)): 

2650 OnnxCastLike = loadop('CastLike') 

2651 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2652 return OnnxSub(self, ov, op_version=opv) 

2653 

2654 def __mul__(self, ov): 

2655 """ 

2656 Automatically adds operator `OnnxMul` to the graph. 

2657 

2658 :param ov: onnx node 

2659 :return: `OnnxMul(self, ov)` 

2660 """ 

2661 OnnxMul = loadop('Mul') 

2662 opv = self._merge_op_version(self, ov, at_least=15) 

2663 if isinstance(ov, (int, float)): 

2664 OnnxCastLike = loadop('CastLike') 

2665 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2666 return OnnxMul(self, ov, op_version=opv) 

2667 

2668 def __truediv__(self, ov): 

2669 """ 

2670 Automatically adds operator `OnnxDiv` to the graph. 

2671 

2672 :param ov: onnx node 

2673 :return: `OnnxDiv(self, ov)` 

2674 """ 

2675 OnnxDiv = loadop('Div') 

2676 opv = self._merge_op_version(self, ov, at_least=15) 

2677 if isinstance(ov, (int, float)): 

2678 OnnxCastLike = loadop('CastLike') 

2679 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2680 return OnnxDiv(self, ov, op_version=opv) 

2681 

2682 def __pow__(self, ov): 

2683 """ 

2684 Automatically adds operator `OnnxPow` to the graph. 

2685 

2686 :param ov: onnx node 

2687 :return: `OnnxPow(self, ov)` 

2688 """ 

2689 OnnxPow = loadop('Pow') 

2690 opv = self._merge_op_version(self, ov, at_least=15) 

2691 if isinstance(ov, (int, float)): 

2692 OnnxCastLike = loadop('CastLike') 

2693 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2694 return OnnxPow(self, ov, op_version=opv) 

2695 

2696 def __mod__(self, ov): 

2697 """ 

2698 Automatically adds operator `OnnxMod` to the graph. 

2699 

2700 :param ov: onnx node 

2701 :return: `OnnxMod(self, ov)` 

2702 """ 

2703 OnnxMod = loadop('Mod') 

2704 opv = self._merge_op_version(self, ov, at_least=15) 

2705 if isinstance(ov, (int, float)): 

2706 OnnxCastLike = loadop('CastLike') 

2707 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2708 return OnnxMod(self, ov, op_version=opv) 

2709 

2710 def __matmul__(self, ov): 

2711 """ 

2712 Automatically adds operator `OnnxMatMul` to the graph. 

2713 

2714 :param ov: onnx node 

2715 :return: `OnnxMatMul(self, ov)` 

2716 """ 

2717 OnnxMatMul = loadop('MatMul') 

2718 opv = self._merge_op_version(self, ov) 

2719 return OnnxMatMul(self, ov, op_version=opv) 

2720 

2721 def __gt__(self, ov): 

2722 """ 

2723 Automatically adds operator `OnnxGreater` to the graph. 

2724 

2725 :param ov: onnx node 

2726 :return: `OnnxGreater(self, ov)` 

2727 """ 

2728 OnnxGreater = loadop('Greater') 

2729 opv = self._merge_op_version(self, ov, at_least=15) 

2730 if isinstance(ov, (int, float)): 

2731 OnnxCastLike = loadop('CastLike') 

2732 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2733 return OnnxGreater(self, ov, op_version=opv) 

2734 

2735 def __ge__(self, ov): 

2736 """ 

2737 Automatically adds operator `OnnxGreaterOrEqual` to the graph. 

2738 

2739 :param ov: onnx node 

2740 :return: `OnnxGreaterOrEqual(self, ov)` 

2741 """ 

2742 OnnxGreaterOrEqual = loadop('GreaterOrEqual') 

2743 opv = self._merge_op_version(self, ov, at_least=15) 

2744 if isinstance(ov, (int, float)): 

2745 OnnxCastLike = loadop('CastLike') 

2746 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2747 return OnnxGreaterOrEqual(self, ov, op_version=opv) 

2748 

2749 def __lt__(self, ov): 

2750 """ 

2751 Automatically adds operator `OnnxLess` to the graph. 

2752 

2753 :param ov: onnx node 

2754 :return: `OnnxLess(self, ov)` 

2755 """ 

2756 OnnxLess = loadop('Less') 

2757 opv = self._merge_op_version(self, ov, at_least=15) 

2758 if isinstance(ov, (int, float)): 

2759 OnnxCastLike = loadop('CastLike') 

2760 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2761 return OnnxLess(self, ov, op_version=opv) 

2762 

2763 def __le__(self, ov): 

2764 """ 

2765 Automatically adds operator `OnnxLessOrEqual` to the graph. 

2766 

2767 :param ov: onnx node 

2768 :return: `OnnxLessOrEqual(self, ov)` 

2769 """ 

2770 OnnxLessOrEqual = loadop('LessOrEqual') 

2771 opv = self._merge_op_version(self, ov, at_least=15) 

2772 if isinstance(ov, (int, float)): 

2773 OnnxCastLike = loadop('CastLike') 

2774 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2775 return OnnxLessOrEqual(self, ov, op_version=opv) 

2776 

2777 def __eq__(self, ov): 

2778 """ 

2779 Automatically adds operator `OnnxEqual` to the graph. 

2780 

2781 :param ov: onnx node 

2782 :return: `OnnxEqual(self, ov)` 

2783 """ 

2784 OnnxEqual = loadop('Equal') 

2785 opv = self._merge_op_version(self, ov, at_least=15) 

2786 if isinstance(ov, (int, float)): 

2787 OnnxCastLike = loadop('CastLike') 

2788 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2789 return OnnxEqual(self, ov, op_version=opv) 

2790 

2791 def and_(self, ov): 

2792 """ 

2793 Automatically adds operator `OnnxAnd` to the graph. 

2794 

2795 :param ov: onnx node 

2796 :return: `OnnxAnd(self, ov)` 

2797 """ 

2798 OnnxAnd = loadop('And') 

2799 opv = self._merge_op_version(self, ov) 

2800 return OnnxAnd(self, ov, op_version=opv) 

2801 

2802 def or_(self, ov): 

2803 """ 

2804 Automatically adds operator `OnnxOr` to the graph. 

2805 

2806 :param ov: onnx node 

2807 :return: `OnnxOr(self, ov)` 

2808 """ 

2809 OnnxOr = loadop('Or') 

2810 opv = self._merge_op_version(self, ov) 

2811 return OnnxOr(self, ov, op_version=opv) 

2812 

2813 def __ne__(self, ov): 

2814 """ 

2815 Automatically adds operators `OnnxNot` and `OnnxEqual` to the graph. 

2816 

2817 :param ov: onnx node 

2818 :return: `OnnxNot(OnnxEqual(self, ov))` 

2819 """ 

2820 OnnxNot, OnnxEqual = loadop('Not', 'Equal') 

2821 opv = self._merge_op_version(self, ov, at_least=15) 

2822 if isinstance(ov, (int, float)): 

2823 OnnxCastLike = loadop('CastLike') 

2824 ov = OnnxCastLike(numpy.array([ov]), self, op_version=opv) 

2825 return OnnxNot(OnnxEqual(self, ov, op_version=opv), op_version=opv) 

2826 

2827 def __abs__(self): 

2828 """ 

2829 Automatically adds operator `OnnxAbs` to the graph. 

2830 

2831 :return: `OnnxAbs(self)` 

2832 

2833 """ 

2834 OnnxAbs = loadop('Abs') 

2835 return OnnxAbs(self, op_version=self.op_version) 

2836 

2837 def not_(self): 

2838 """ 

2839 Automatically adds operator `OnnxNot` to the graph. 

2840 

2841 :return: `OnnxNot(self)` 

2842 

2843 """ 

2844 OnnxNot = loadop('Not') 

2845 return OnnxNot(self, op_version=self.op_version) 

2846 

2847 def astype(self, to): 

2848 """ 

2849 Automatically adds operator `OnnxCast` to the graph. 

2850 

2851 :param to: output type (numpy dtype or onnx type) 

2852 :return: `OnnxCast(self, to=to)` 

2853 """ 

2854 OnnxCast = loadop('Cast') 

2855 return OnnxCast(self, to=to, op_version=self.op_version) 
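
# Operator-overloading sketch (illustration, not part of the original
# source): python operators on nodes build new nodes; scalar operands are
# wrapped with CastLike as in __add__ above. Names are arbitrary examples.
import numpy
from mlprodict.npy.xop import loadop

OnnxIdentity = loadop('Identity')
x = OnnxIdentity('X')
expr = (abs(x) + 1.0).astype(numpy.int64)  # Cast(Add(Abs(X), CastLike(1., X)))
onx = expr.to_onnx(inputs={'X': numpy.float32})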

2856 

2857 

2858class OnnxOperatorFunction(OnnxOperator): 

2859 """ 

2860 This operator is used to insert existing ONNX function into 

2861 the ONNX graph being built. 

2862 

2863 :param function_proto: instance of type :epkg:`FunctionProto` 

2864 :param inputs: inputs 

2865 :param output_names: output names 

2866 :param sub_functions: functions called by this one 

2867 """ 

2868 

2869 domain = 'mlprodict' 

2870 since_version = 1 

2871 expected_inputs = None 

2872 expected_outputs = None 

2873 input_range = [1, 1e9] 

2874 output_range = [1, 1e9] 

2875 op_type = 'Function' 

2876 domain = 'mlprodict.xop' 

2877 

2878 @staticmethod 

2879 def attribute_to_value(att): 

2880 """ 

2881 Converts an attribute into a value using python structures. 

2882 """ 

2883 if isinstance(att, onnx.AttributeProto): 

2884 dtype = att.type 

2885 else: 

2886 raise NotImplementedError( # pragma: no cover 

2887 f"Unable to copy attribute type {type(att)!r}.") 

2888 if dtype == 1: # .f 

2889 value = att.f 

2890 elif dtype == 2: # .i 

2891 value = att.i 

2892 elif dtype == 3: # .s 

2893 value = att.s 

2894 elif dtype == 4: # .t 

2895 value = att.t 

2896 elif dtype == 6: # .floats 

2897 value = list(att.floats) 

2898 elif dtype == 7: # .ints 

2899 value = list(att.ints) 

2900 elif dtype == 8: # .strings 

2901 value = list(att.strings) 

2902 elif dtype == 11: # .double_data 

2903 value = list(att.double_data) 

2904 else: 

2905 raise NotImplementedError( # pragma: no cover 

2906 f"Unable to copy attribute type {dtype!r} ({att!r}).") 

2907 return value 
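
# Illustration (not part of the original source): AttributeProto.type selects
# the field carrying the value (1=FLOAT -> .f, 2=INT -> .i, 3=STRING -> .s, ...).
from onnx.helper import make_attribute
from mlprodict.npy.xop import OnnxOperatorFunction

att = make_attribute('alpha', 0.5)   # a FLOAT attribute, type 1
assert OnnxOperatorFunction.attribute_to_value(att) == 0.5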

2908 

2909 def __init__(self, function_proto, *inputs, output_names=None, 

2910 sub_functions=None): 

2911 logger.debug("op:Function(ONNX, %d in, output_names=%r)", 

2912 len(inputs), output_names) 

2913 if function_proto is None: 

2914 raise ValueError( 

2915 "function_proto cannot be None.") # pragma: no cover 

2916 if not isinstance(function_proto, onnx.FunctionProto): 

2917 raise TypeError( # pragma: no cover 

2918 "function_proto must be of type FunctionProto not %r." % 

2919 type(function_proto)) 

2920 if len(inputs) > len(function_proto.input): 

2921 raise RuntimeError( # pragma: no cover 

2922 "Unexpected number of inputs %r > expected %r." % ( 

2923 len(inputs), len(function_proto.input))) 

2924 if (output_names is not None and 

2925 len(output_names) != len(function_proto.output)): 

2926 raise RuntimeError( # pragma: no cover 

2927 "Unexpected number of outputs %r != expected %r." % ( 

2928 len(output_names), len(function_proto.output))) 

2929 OnnxOperator.__init__(self, *inputs, output_names=output_names) 

2930 self.model = function_proto 

2931 self.sub_functions = sub_functions 

2932 

2933 def __repr__(self): 

2934 "usual" 

2935 atts = {} 

2936 for att in ['output_names']: 

2937 value = getattr(self, att, None) 

2938 if value is not None: 

2939 atts[att] = value 

2940 atts.update(self.kwargs) 

2941 if self.sub_functions is not None and len(self.sub_functions) > 0: 

2942 atts["sub_functions"] = list(range(len(self.sub_functions))) 

2943 msg = ", ".join(f"{k}={v!r}" for k, v in atts.items()) 

2944 if len(atts) > 0: 

2945 msg = ", " + msg 

2946 return f"{self.__class__.__name__}(...{msg})" 

2947 

2948 def add_to(self, builder): 

2949 """ 

2950 Adds to graph builder. 

2951 

2952 :param builder: instance of @see cl _GraphBuilder, 

2953 it must have a method `add_node` 

2954 """ 

2955 logger.debug("op:Function.add_to(builder)") 

2956 inputs = builder.get_input_names(self, self.inputs) 

2957 n_outputs = len(self.model.output) 

2958 outputs = [builder.get_unique_output_name(NodeResultName(self, i)) 

2959 for i in range(n_outputs)] 

2960 

2961 # linking inputs 

2962 logger.indent() 

2963 if self.sub_functions is not None: 

2964 for sub in self.sub_functions: 

2965 builder.add_function(sub) 

2966 builder.add_function(self.model) 

2967 builder.add_node( 

2968 self.model.name, builder.get_unique_name( 

2969 '_fct_' + self.model.name, reserved=False), 

2970 inputs, outputs, domain=self.model.domain) 

2971 logger.dedent() 

2972 

2973 

2974class _GraphBuilder: 

2975 """ 

2976 Graph builder. It takes a graph structure made with 

2977 instances of @see cl OnnxOperatorBase. 

2978 The main method is `to_onnx`. 

2979 

2980 * `initializer`: list of initializers to add to the ONNX graph 

2981 * `node`: list of nodes to add to the ONNX graph 

2982 * `input`: list of inputs to add to the ONNX graph 

2983 * `output`: list of outputs to add to the ONNX graph 

2984 * `opsets`: opsets of the ONNX graph 

2985 * `input_names`: dictionary of input names 

2986 `{name: InputDetectedVariable}` 

2987 * `node_output_names`: memorizes a name for a node output 

2988 when the user did not specify any 

2989 `{(id(node), index): OutputDetectedVariable}` 

2990 * `reserved_names`: dictionary `{ name : (node, index) }`, 

2991 name which should remain unchanged in the ONNX graph 

2992 * `names`: list of uniques names 

2993 * `functions`: dictionary `{ (domain, name): function_proto }` 

2994 * `function_hashes`: dictionary `{ (domain, name): hash of function_proto }` 

2995 """ 

2996 

2997 def __init__(self): 

2998 self.initializer = [] 

2999 self.node = [] 

3000 self.input = [] 

3001 self.output = [] 

3002 self.opsets = {} 

3003 self.input_names = {} 

3004 self.node_output_names = {} 

3005 self.reserved_names = {} 

3006 self.names = set() 

3007 self.functions = {} 

3008 self.function_hashes = {} 

3009 logger.debug('_GraphBuilder-%d:new', id(self)) 

3010 

3011 def _add_domain(self, domain, version): 

3012 if domain not in self.opsets: 

3013 self.opsets[domain] = version 

3014 else: 

3015 self.opsets[domain] = max(version, self.opsets[domain]) 

3016 

3017 def _add_name(self, name): 

3018 self.names.add(name) 

3019 

3020 @staticmethod 

3021 def number2alpha(index): 

3022 """ 

3023 Converts a number into a string keeping the same 

3024 alphabetical order. 

3025 """ 

3026 dec = str(int(index)) 

3027 if len(dec) == 1: 

3028 return dec 

3029 return chr(96 + len(dec)) + dec 
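
# Illustration (not part of the original source): prefixing the decimal
# length as a letter keeps lexicographic order aligned with numeric order.
from mlprodict.npy.xop import _GraphBuilder

assert _GraphBuilder.number2alpha(2) == '2'
assert _GraphBuilder.number2alpha(12) == 'b12'
assert _GraphBuilder.number2alpha(102) == 'c102'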

3030 

3031 def reserve_names(self, node, output_names): 

3032 """ 

3033 Adds names to the list of reserved names. 

3034 All must be unique. 

3035 

3036 :param node: node or None for an input 

3037 :param output_names: names of the output 

3038 """ 

3039 if output_names is None: 

3040 return 

3041 for index, var in enumerate(output_names): 

3042 if not isinstance(var, (Variable, ExistingVariable)): 

3043 raise TypeError( # pragma: no cover 

3044 f"Unexpected type {type(var)!r} for {var!r}.") 

3045 self.reserve_name(node, var.name, index) 

3046 

3047 def reserve_name(self, node, name, index): 

3048 """ 

3049 Reserves a name so that it cannot be changed. 

3050 

3051 :param node: node or None for an input 

3052 :param name: name 

3053 :param index: output index 

3054 """ 

3055 if not isinstance(name, str): 

3056 raise TypeError( # pragma: no cover 

3057 f"Name {name!r} is not a string.") 

3058 if name in self.reserved_names: 

3059 raise RuntimeError( # pragma: no cover 

3060 "Name %r is already reserved from node %r, index=%d." % ( 

3061 name, node, index)) 

3062 logger.debug("_GraphBuilder-%d.reserve_name([%s-%d], %r, %r)", 

3063 id(self), node.__class__.__name__, id(node), 

3064 name, index) 

3065 self.reserved_names[name] = (node, index) 

3066 self._add_name(name) 

3067 

3068 def get_unique_output_name(self, result): 

3069 """ 

3070 Returns a unique output_name for a NodeResultName. 

3071 

3072 :param result: instance of @see cl NodeResultName 

3073 """ 

3074 if not isinstance(result, NodeResultName): 

3075 raise TypeError( # pragma: no cover 

3076 "Result must be of type NodeResultName not %r (%r)." % ( 

3077 type(result), result)) 

3078 if result.node is None: 

3079 key = None, result.index 

3080 else: 

3081 key = id(result.node), result.index 

3082 if key in self.node_output_names: 

3083 return self.node_output_names[key] 

3084 name = result.get_name() 

3085 if name in self.reserved_names: 

3086 unique = name 

3087 else: 

3088 unique = self.get_unique_name(name) 

3089 self.node_output_names[key] = unique 

3090 return unique 

3091 

3092 def get_unique_name(self, name, reserved=True): 

3093 """ 

3094 Returns a unique name to name an output. 

3095 

3096 :param name: name 

3097 :param reserved: bypass if the name is a reserved one 

3098 :return: unique name, possibly the given name if not already taken 

3099 """ 

3100 if not isinstance(name, str): 

3101 raise TypeError( # pragma: no cover 

3102 f"name must be a string not {type(name)!r}.") 

3103 if reserved and name in self.reserved_names: 

3104 logger.debug( # pragma: no cover 

3105 "_GraphBuilder-%d.get_unique_name(%r) 1-> %r", 

3106 id(self), name, name) 

3107 return name 

3108 if name not in self.names: 

3109 self._add_name(name) 

3110 logger.debug("_GraphBuilder-%d.get_unique_name(%r) 2-> %r", 

3111 id(self), name, name) 

3112 return name 

3113 i = 1 

3114 new_name = f"{name}_{self.number2alpha(i)}" 

3115 while new_name in self.names: 

3116 i += 1 

3117 new_name = f"{name}_{self.number2alpha(i)}" 

3118 self._add_name(new_name) 

3119 logger.debug("_GraphBuilder-%d.get_unique_name(%r) 3-> %r", 

3120 id(self), name, new_name) 

3121 return new_name 
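
# Illustration (not part of the original source): colliding names receive a
# number2alpha suffix until a free name is found.
from mlprodict.npy.xop import _GraphBuilder

b = _GraphBuilder()
assert b.get_unique_name('init') == 'init'
assert b.get_unique_name('init') == 'init_1'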

3122 

3123 def get_input_names(self, node, inputs): 

3124 """ 

3125 Returns input names for node *node* and inputs *inputs*. 

3126 

3127 :param node: node 

3128 :param inputs: inputs 

3129 :return: list of names 

3130 """ 

3131 logger.debug( 

3132 "_GraphBuilder-%d.get_input_names:1:%s-%d:%r", 

3133 id(self), node.__class__.__name__, id(node), inputs) 

3134 names = [] 

3135 for i in inputs: 

3136 if isinstance(i, (Variable, ExistingVariable)): 

3137 self._add_name(i.name) 

3138 names.append(i.name) 

3139 if i.name in self.input_names: 

3140 if isinstance(i, Variable): 

3141 self.input_names[i.name] = InputDetectedVariable( 

3142 None, i) 

3143 logger.debug( 

3144 "_GraphBuilder-%d.get_input_names:2:a:%d:+input_names:%s", 

3145 id(self), id(node), i.name) 

3146 else: 

3147 logger.debug( # pragma: no cover 

3148 "_GraphBuilder-%d.get_input_names:2:a:%d:=input_names:%s", 

3149 id(self), id(node), i.name) 

3150 else: 

3151 self.input_names[i.name] = InputDetectedVariable(None, i) 

3152 logger.debug( 

3153 "_GraphBuilder-%d.get_input_names:2:b:%d:+input_names:%s", 

3154 id(self), id(node), i.name) 

3155 elif isinstance(i, InputDetectedVariable): 

3156 self._add_name(i.name) 

3157 names.append(i.name) 

3158 if i.name in self.input_names: 

3159 logger.debug( # pragma: no cover 

3160 "_GraphBuilder-%d.get_input_names:2:c:%d:=input_names:%s", 

3161 id(self), id(node), i.name) 

3162 else: 

3163 self.input_names[i.name] = i 

3164 logger.debug( 

3165 "_GraphBuilder-%d.get_input_names:2:c:%d:+input_names:%s", 

3166 id(self), id(node), i.name) 

3167 elif isinstance(i, OnnxExisting): 

3168 inp = i.inputs[0] 

3169 n = inp.output_names[0] 

3170 self._add_name(n.name) 

3171 names.append(n.name) 

3172 if n.name in self.input_names: 

3173 if isinstance(inp, Variable): 

3174 self.input_names[n.name] = InputDetectedVariable( 

3175 None, n) 

3176 logger.debug( # pragma: no cover 

3177 "_GraphBuilder-%d.get_input_names:2:d:%d:+input_names:%s", 

3178 id(self), id(node), n.name) 

3179 else: 

3180 logger.debug( 

3181 "_GraphBuilder-%d.get_input_names:2:d:%d:=input_names:%s", 

3182 id(self), id(node), n.name) 

3183 else: 

3184 self.input_names[n.name] = InputDetectedVariable(None, n) 

3185 logger.debug( 

3186 "_GraphBuilder-%d.get_input_names:2:d:%d:+input_names:%s", 

3187 id(self), id(node), n.name) 

3188 elif isinstance(i, OnnxOperator): 

3189 key = id(i), 0 

3190 try: 

3191 name = self.node_output_names[key] 

3192 except KeyError as e: # pragma: no cover 

3193 raise RuntimeError( 

3194 "Unable to find key %r for input " 

3195 "(type(i) is %r, type(node) is %r) " 

3196 "%r in node %r among %r." % ( 

3197 key, type(i), type(node), i, node, 

3198 list(self.node_output_names))) from e 

3199 names.append(name) 

3200 elif isinstance(i, OnnxOperatorItem): 

3201 if isinstance(i.onx_op, OnnxOperatorTuple): 

3202 if i.onx_op.values is None: 

3203 key = id(i.onx_op.unique), i.index 

3204 else: 

3205 key = id(i.onx_op[i.index]), 0 

3206 elif isinstance(i.onx_op, OnnxOperator): 

3207 key = id(i.onx_op), i.index 

3208 else: 

3209 raise TypeError( # pragma: no cover 

3210 f"Unexpected type for OnnxOperatorItem: {type(i.onx_op)!r}.") 

3211 try: 

3212 name = self.node_output_names[key] 

3213 except KeyError as e: # pragma: no cover 

3214 raise RuntimeError( 

3215 "Unable to find key %r for input %r in node %r." % ( 

3216 key, i, node)) from e 

3217 names.append(name) 

3218 elif isinstance(i, OnnxOperatorTuple): 

3219 raise NotImplementedError() # pragma: no cover 

3220 elif isinstance(i, numpy.ndarray): 

3221 # Adding an initializer 

3222 name = self.get_unique_name('init', reserved=False) 

3223 init = from_array(i, name) 

3224 self.initializer.append(init) 

3225 names.append(name) 

3226 else: 

3227 raise TypeError( # pragma: no cover 

3228 f"Unexpected type for an input {type(i)!r}.") 

3229 logger.debug( 

3230 "_GraphBuilder-%d.get_input_names:3:%r", id(self), names) 

3231 return names 

3232 
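As the last branch above shows, a raw numpy array passed as an input is promoted to an initializer with a generated name. A condensed sketch of that promotion outside the builder, where the literal name 'init' stands in for what get_unique_name('init', reserved=False) would return:

    import numpy
    from onnx.numpy_helper import from_array

    initializer, names = [], []
    value = numpy.array([1.0, 2.0], dtype=numpy.float32)
    name = 'init'  # placeholder for the unique name chosen by the builder
    initializer.append(from_array(value, name))
    names.append(name)
    print(names, initializer[0].dims)  # ['init'] [2]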

3233 def add_initializer(self, name, init): 

3234 """ 

3235 Adds an initializer to the graph. 

3236 

3237 :param name: initializer name 

3238 :param init: initializer to copy 

3239 :return: created initializer

3240 """ 

3241 if isinstance(init, onnx.TensorProto): 

3242 tensor = to_array(init) 

3243 val = from_array(tensor, name) 

3244 logger.debug("_GraphBuilder.add_initializer:1(%r, %r, %r)", 

3245 name, tensor.dtype, tensor.shape) 

3246 elif isinstance(init, numpy.ndarray): 

3247 value = init  # init is already a numpy array, no conversion needed

3248 val = from_array(value, name) 

3249 logger.debug("_GraphBuilder.add_initializer:2(%r, %r, %r)", 

3250 name, init.dtype, init.shape) 

3251 else: 

3252 raise NotImplementedError( # pragma: no cover 

3253 f"Unsupported initializer type {type(init)!r}.") 

3254 self.initializer.append(val) 

3255 return val 

3256 
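Both branches normalize the value through onnx.numpy_helper. A quick round-trip shows why the to_array/from_array pair can rename a TensorProto without altering its content:

    import numpy
    from onnx.numpy_helper import from_array, to_array

    value = numpy.array([[0, 1], [2, 3]], dtype=numpy.float32)
    tensor = from_array(value, 'W')               # numpy -> TensorProto named 'W'
    renamed = from_array(to_array(tensor), 'W2')  # same content, new name
    assert renamed.name == 'W2'
    assert to_array(renamed).shape == (2, 2)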

3257 def add_function(self, function_proto, 

3258 raise_if_exist=False, check_unique=True, 

3259 opset=1): 

3260 """ 

3261 Adds a function to the graph. 

3262 

3263 :param function_proto: instance of type :epkg:`FunctionProto` 

3264 :param raise_if_exist: raises an exception if a function of the 

3265 same name was already added 

3266 :param check_unique: if a function was added twice, checks 

3267 that its content is the same 

3268 :param opset: opset for the domain the function belongs to 

3269 """ 

3270 def _hash(p): 

3271 m = hashlib.sha256() 

3272 m.update(p.SerializeToString()) 

3273 return m.hexdigest()[:64] 

3274 

3275 key = function_proto.domain, function_proto.name 

3276 if key in self.functions: 

3277 if raise_if_exist: 

3278 raise RuntimeError( # pragma: no cover 

3279 f"Function {key!r} is added for the second time.") 

3280 if check_unique: 

3281 hs = _hash(function_proto) 

3282 if hs != self.function_hashes[key]: 

3283 raise RuntimeError( # pragma: no cover 

3284 "Function %r is added for the second time " 

3285 "and the content is not the same." % (key, )) 

3286 return 

3287 self.functions[key] = function_proto 

3288 self.function_hashes[key] = _hash(function_proto) 

3289 self._add_domain(function_proto.domain, opset) 

3290 
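Duplicate detection keys on (domain, name) and compares SHA-256 digests of the serialized protos; note that sha256's hexdigest is already 64 characters long, so the [:64] slice in _hash changes nothing. A sketch of the same check over a hypothetical registry (the two dictionaries mirror self.functions and self.function_hashes):

    import hashlib

    functions, function_hashes = {}, {}  # hypothetical registry

    def register(function_proto):
        hs = hashlib.sha256(function_proto.SerializeToString()).hexdigest()
        key = function_proto.domain, function_proto.name
        if key in functions:
            if hs != function_hashes[key]:
                raise RuntimeError(
                    f"Function {key!r} added twice with different content.")
            return  # same content, nothing to do
        functions[key] = function_proto
        function_hashes[key] = hs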

3291 def add_node(self, op_type, name, inputs, outputs, domain='', 

3292 opset=None, **attributes): 

3293 """ 

3294 Adds a node to the graph. 

3295 

3296 :param op_type: operator type 

3297 :param name: node name 

3298 :param inputs: inputs name list 

3299 :param outputs: outputs name list 

3300 :param domain: node domain 

3301 :param opset: node opset; remaining keyword arguments become node attributes 

3302 :return: created node 

3303 """ 

3304 logger.debug("_GraphBuilder-%d.add_node(%r, %r, " 

3305 "inputs=%r, outputs=%r, domain=%r, opset=%r)", 

3306 id(self), op_type, name, inputs, outputs, domain, opset) 

3307 if not isinstance(inputs, (list, OnnxOperator._InputContainer)): 

3308 raise TypeError( # pragma: no cover 

3309 f"inputs must be a list not {type(inputs)!r}.") 

3310 if not isinstance(outputs, (list, OnnxOperator._InputContainer)): 

3311 raise TypeError( # pragma: no cover 

3312 f"inputs must be a list not {type(outputs)!r}.") 

3313 if any(map(lambda x: not isinstance(x, str), inputs)): 

3314 raise TypeError( # pragma: no cover 

3315 f"inputs must be all strings not {inputs!r}.") 

3316 if any(map(lambda x: not isinstance(x, str), outputs)): 

3317 raise TypeError( # pragma: no cover 

3318 f"outputs must be all strings not {outputs!r}.") 

3319 if opset is not None: 

3320 self._add_domain(domain, opset) 

3321 node = make_node(op_type, inputs, outputs, name=name, 

3322 domain=domain, **attributes) 

3323 self.node.append(node) 

3324 return node 

3325 
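The node itself is a plain NodeProto assembled by onnx.helper.make_node, for instance:

    from onnx.helper import make_node

    node = make_node('MatMul', ['X', 'W'], ['Y'], name='mm', domain='')
    print(node.op_type, list(node.input), list(node.output))
    # MatMul ['X', 'W'] ['Y']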

3326 def _process_io(self, inputs, input_names_): 

3327 logger.debug("_GraphBuilder-%d._process_io:1:inputs=%r", 

3328 id(self), inputs) 

3329 logger.debug("_GraphBuilder-%d._process_io:1:input_names_=%r", 

3330 id(self), input_names_) 

3331 if input_names_ is None: 

3332 input_names = None 

3333 else: 

3334 input_names = [] 

3335 for inp in input_names_: 

3336 if inp.var.name == '': 

3337 continue 

3338 input_names.append(inp) 

3339 

3340 if inputs is None: 

3341 logger.debug( # pragma: no cover 

3342 "_GraphBuilder-%d._process_io:return:%r", 

3343 id(self), self.input_names) 

3344 return [ 

3345 make_tensor_value_info( 

3346 'X', TensorProto.FLOAT, None) # pylint: disable=E1101 

3347 for name in self.input_names], None 

3348 

3349 if not isinstance(inputs, (list, OnnxOperator._InputContainer)): 

3350 if is_numpy_dtype(inputs): 

3351 inputs = [inputs] 

3352 

3353 logger.debug("_GraphBuilder-%d._process_io:2:input_names=%r", 

3354 id(self), input_names) 

3355 if input_names is None: 

3356 # outputs 

3357 set_names = set() 

3358 input_names = [] 

3359 new_inputs = [] 

3360 for inp in inputs: 

3361 if isinstance(inp, OutputDetectedVariable): 

3362 if inp.name in set_names: 

3363 raise ValueError( # pragma: no cover 

3364 f"Names already taken {inp.name!r} in {inputs!r}.") 

3365 set_names.add(inp.name) 

3366 if isinstance(inp.node, OnnxExisting): 

3367 raise NotImplementedError( # pragma: no cover 

3368 f"Unexpected name {inp.name!r} type {type(inp.node)!r}.") 

3369 # continue 

3370 key = id(inp.node), inp.index 

3371 if key in self.node_output_names: 

3372 new_name = self.node_output_names[key] 

3373 new_var = OutputDetectedVariable( 

3374 inp.node, inp.var.copy_name(new_name), inp.index) 

3375 input_names.append(new_var) 

3376 new_inputs.append(new_var) 

3377 else: 

3378 raise RuntimeError( # pragma: no cover 

3379 "Key %r is ambiguous or defined in " 

3380 "two nodes %r, id(node)=%d, index=%d." % ( 

3381 key, inp, id(inp.node), inp.index)) 

3382 else: 

3383 raise TypeError( # pragma: no cover 

3384 "Unexpected type %r (it should be " 

3385 "OutputDetectedVariable) in %r." % (inp, inputs)) 

3386 inputs = new_inputs 

3387 if len(input_names) == 0: 

3388 raise RuntimeError( # pragma: no cover 

3389 "Unable to cross %r and %r or %r (set_names=%r)." % ( 

3390 inputs, self.output_names_rev, 

3391 self.node_output_names_rev, set_names)) 

3392 elif not isinstance(input_names, (list, OnnxOperator._InputContainer)): 

3393 raise RuntimeError( # pragma: no cover 

3394 f"Unexpected type for input_names {type(input_names)!r}.") 

3395 else: 

3396 # inputs 

3397 pass 

3398 

3399 # common parts 

3400 logger.debug("_GraphBuilder-%d._process_io:3:input_names:%r", 

3401 id(self), input_names) 

3402 logger.debug("_GraphBuilder-%d._process_io:3:inputs:%r", 

3403 id(self), inputs) 

3404 no_exists_names = [c for c in input_names if not isinstance( 

3405 c.var, (ExistingVariable, OnnxExisting))] 

3406 no_exists = [c for c in inputs if not isinstance( 

3407 c.var, (ExistingVariable, OnnxExisting))] 

3408 

3409 if isinstance(input_names, (list, OnnxOperator._InputContainer)): 

3410 d_input_names = {} 

3411 for inp in input_names: 

3412 if inp.name in d_input_names: 

3413 raise ValueError( # pragma: no cover 

3414 f"Duplicated name {inp.name!r} in {input_names!r}.") 

3415 d_input_names[inp.name] = inp 

3416 elif isinstance(input_names, dict): 

3417 d_input_names = input_names 

3418 else: 

3419 raise TypeError( # pragma: no cover 

3420 "Unexpected type for input_names %r (%r)." % ( 

3421 type(input_names), input_names)) 

3422 

3423 logger.debug("_GraphBuilder-%d._process_io:4:no_exists_names:%r", 

3424 id(self), no_exists_names) 

3425 logger.debug("_GraphBuilder-%d._process_io:4:no_exists:%r", 

3426 id(self), no_exists) 

3427 

3428 # mapping 

3429 res = [] 

3430 for inp in no_exists: 

3431 if not isinstance(inp, DetectedVariable): 

3432 raise TypeError( # pragma: no cover 

3433 f"inp not DetectedVariable but {type(inp)!r} ({inp!r}).") 

3434 if inp.name.startswith('???'): 

3435 raise RuntimeError( # pragma: no cover 

3436 f"Issue with variable {inp!r}.") 

3437 var = d_input_names[inp.name] 

3438 if not isinstance(var, DetectedVariable): 

3439 raise TypeError( # pragma: no cover 

3440 f"var not Variable but {type(var)!r} ({var!r}).") 

3441 

3442 # inp: DetectedVariable (wraps a Variable) 

3443 # var: DetectedVariable found under the same name 

3444 if isinstance(var.var, ExistingVariable): 

3445 # It may be an input referenced in a subgraph and not used in the 

3446 # main graph. 

3447 if inp.var.name != var.var.name: 

3448 raise RuntimeError( # pragma: no cover 

3449 f"Unexpected {inp!r} != {var!r}.") 

3450 elif inp.var != var.var: 

3451 if (inp.var.name != var.var.name or ( 

3452 inp.var.dtype is not None and 

3453 var.var.dtype is not None)): 

3454 raise RuntimeError( # pragma: no cover 

3455 f"Unexpected {inp.var!r} != {var.var!r}.") 

3456 

3457 if isinstance(inp.var, ExistingVariable): 

3458 # The type of an ExistingVariable would be needed to build 

3459 # the subgraph but is unknown here, so declare it as undefined. 

3460 res.append(make_tensor_value_info(inp.name, 0, None)) 

3461 else: 

3462 res.append(make_tensor_value_info( 

3463 inp.name, inp.var.proto_added_type, 

3464 inp.var.proto_added_shape)) 

3465 

3466 hidden = [c for c in input_names if isinstance( 

3467 c.var, (ExistingVariable, OnnxExisting))] 

3468 logger.debug("_GraphBuilder-%d._process_io:4:return:res:%r", 

3469 id(self), [n.name for n in res]) 

3470 logger.debug("_GraphBuilder-%d._process_io:4:return:hidden:%r", 

3471 id(self), hidden) 

3472 return res, hidden 

3473 
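The ValueInfoProto entries returned here come from onnx.helper.make_tensor_value_info; passing element type 0 (UNDEFINED) with shape None, as done for ExistingVariable above, leaves the declaration completely unconstrained:

    from onnx import TensorProto
    from onnx.helper import make_tensor_value_info

    typed = make_tensor_value_info('X', TensorProto.FLOAT, [None, 4])
    untyped = make_tensor_value_info('H', 0, None)  # type left undefined
    print(typed.type.tensor_type.elem_type)    # 1 (FLOAT)
    print(untyped.type.tensor_type.elem_type)  # 0 (UNDEFINED)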

3474 def to_onnx(self, inputs=None, outputs=None, 

3475 target_opset=None, run_shape=False, 

3476 optim=True, function_name=None, 

3477 function_domain=None, verbose=0, 

3478 check_model=True): 

3479 """ 

3480 Converts this operator into an ONNX graph. 

3481 

3482 :param inputs: specific inputs (as a dictionary) or 

3483 default inputs if not specified 

3484 :param outputs: specific outputs 

3485 :param target_opset: dictionary with target opset per domain, 

3486 None for the default one 

3487 :param run_shape: run shape inference before returning the model 

3488 :param optim: optimize the model with function 

3489 @see fn onnx_optimisations 

3490 :param function_name: if not None, builds a :epkg:`FunctionProto` 

3491 with this name 

3492 :param function_domain: in case of a function, declares the function 

3493 as part of this domain, `'mlprodict'` if None 

3494 :param verbose: verbosity level (prints progress information if positive) 

3495 :param check_model: checks the output model 

3496 :return: onnx graph 

3497 

3498 (_GraphBuilder) 

3499 """ 

3500 logger.debug("_GraphBuilder-%d.to_onnx:#####:%s", 

3501 id(self), str(function_name)) 

3502 logger.debug("_GraphBuilder-%d.to_onnx(%r, %r, target_opset=%r)", 

3503 id(self), inputs, outputs, target_opset) 

3504 # inputs and outputs 

3505 if not all(map(lambda x: isinstance(x, InputDetectedVariable), inputs)): 

3506 raise TypeError( # pragma: no cover 

3507 "One of the input is not InputDetectedVariable.") 

3508 if not all(map(lambda x: isinstance(x, OutputDetectedVariable), outputs)): 

3509 raise TypeError( # pragma: no cover 

3510 "One of the outputs is not OutputDetectedVariable.") 

3511 logger.indent() 

3512 self.input, self.hidden_input = self._process_io( 

3513 inputs, list(self.input_names.values())) 

3514 logger.dedent() 

3515 logger.debug("_GraphBuilder-%d.to_onnx:hidden_input:%r", 

3516 id(self), self.hidden_input) 

3517 logger.indent() 

3518 self.output, self.hidden_output = self._process_io(outputs, None) 

3519 logger.dedent() 

3520 if len(self.hidden_output) > 0: 

3521 raise RuntimeError( # pragma: no cover 

3522 f"Unexpected hidden output {self.hidden_output!r}.") 

3523 logger.debug("_GraphBuilder-%d.to_onnx:self.input=%r", 

3524 id(self), [i.name for i in self.input]) 

3525 if len(self.hidden_input) > 0: 

3526 logger.debug("_GraphBuilder-%d.to_onnx:self.hidden_input=%r", 

3527 id(self), [i.name for i in self.hidden_input]) 

3528 logger.debug("_GraphBuilder-%d.to_onnx:self.output=%r", 

3529 id(self), [i.name for i in self.output]) 

3530 logger.debug("_GraphBuilder-%d.to_onnx:build:n_inputs=%r n_inits=%r n_nodes=%r " 

3531 "n_outputs=%r", 

3532 id(self), len(self.input), len(self.initializer), 

3533 len(self.node), len(self.output)) 

3534 

3535 if function_name is not None: 

3536 # function 

3537 if function_domain is None: 

3538 function_domain = 'mlprodict' 

3539 if len(self.initializer) > 0: 

3540 nodes = [] 

3541 for init in self.initializer: 

3542 nodes.append( 

3543 make_node('Constant', [], [init.name], value=init, 

3544 name=f'_init_{init.name}')) 

3545 nodes.extend(self.node) 

3546 else: 

3547 nodes = self.node 

3548 fct = make_function( 

3549 function_domain, function_name, 

3550 [_.name for _ in self.input], 

3551 [_.name for _ in self.output], 

3552 nodes, 

3553 [make_opsetid(k, v) for k, v in self.opsets.items()]) 

3554 if check_model: 

3555 check_onnx(fct) 

3556 if optim: 

3557 from ..onnx_tools.optim import onnx_optimisations 

3558 fct = onnx_optimisations(fct) 

3559 if check_model: 

3560 check_onnx(fct) 

3561 logger.debug("_GraphBuilder-%d:fct:.to_onnx() -> done", id(self)) 

3562 logger.debug("_GraphBuilder-%d:fct:to_onnx:#####", id(self)) 

3563 return fct 

3564 else: 

3565 # graph 

3566 graph = make_graph( 

3567 self.node, 'XOP', self.input, self.output, self.initializer) 

3568 onnx_model = make_model( 

3569 graph, functions=list(self.functions.values())) 

3570 opv = self.opsets.get('', max_supported_opset()) 

3571 opset2ir = _default_OPSET_TO_IR_VERSION() 

3572 irv = opset2ir.get(opv, max(opset2ir.values())) 

3573 onnx_model.ir_version = irv 

3574 

3575 logger.debug("_GraphBuilder-%d.to_onnx:2onnx:n_inputs=%r n_inits=%r " 

3576 "n_nodes=%r n_outputs=%r", 

3577 id(self), len(onnx_model.graph.input), 

3578 len(onnx_model.graph.initializer), 

3579 len(onnx_model.graph.node), 

3580 len(onnx_model.graph.output)) 

3581 

3582 del onnx_model.opset_import[:] # pylint: disable=E1101 

3583 seen_opset = set() 

3584 for k, v in self.opsets.items(): 

3585 if (k or '') in seen_opset: 

3586 raise RuntimeError( # pragma: no cover 

3587 f"Duplicated opset ({k!r}, {v!r}).") 

3588 op_set = onnx_model.opset_import.add() # pylint: disable=E1101 

3589 op_set.domain = k or '' 

3590 op_set.version = v 

3591 seen_opset.add(op_set.domain) 

3592 

3593 # optimisation, remove redundant constant, unnecessary 

3594 # identity nodes. 

3595 if check_model: 

3596 check_onnx(onnx_model) 

3597 if optim: 

3598 from ..onnx_tools.optim import onnx_optimisations 

3599 onnx_model = onnx_optimisations(onnx_model) 

3600 if check_model: 

3601 logger.debug( 

3602 "_GraphBuilder-%d.to_onnx:check_onnx", id(self)) 

3603 check_onnx(onnx_model) 

3604 

3605 logger.debug("_GraphBuilder-%d.to_onnx:optim:n_inputs=%r n_inits=%r " 

3606 "n_nodes=%r n_outputs=%r", 

3607 id(self), len(onnx_model.graph.input), 

3608 len(onnx_model.graph.initializer), 

3609 len(onnx_model.graph.node), 

3610 len(onnx_model.graph.output)) 

3611 

3612 if run_shape: 

3613 logger.debug("_GraphBuilder-%d.to_onnx:infer_shapes", id(self)) 

3614 with_shape = infer_shapes(onnx_model) 

3615 logger.debug("_GraphBuilder-%d.to_onnx:shape:n_inputs=%r " 

3616 "n_inits=%r n_nodes=%r n_outputs=%r", 

3617 id(self), len(with_shape.graph.input), 

3618 len(with_shape.graph.initializer), 

3619 len(with_shape.graph.node), 

3620 len(with_shape.graph.output)) 

3621 return with_shape 

3622 

3623 logger.debug("_GraphBuilder-%d.to_onnx:mod -> done", id(self)) 

3624 logger.debug("_GraphBuilder-%d.to_onnx:mod:#####", id(self)) 

3625 return onnx_model 

3626 
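In the graph branch, ir_version is derived from the default-domain opset through _default_OPSET_TO_IR_VERSION and opset_import is rebuilt from self.opsets. A condensed sketch of those final steps for a model with a single empty-domain opset:

    from onnx.helper import make_graph, make_model

    opsets = {'': 15}  # hypothetical builder state

    graph = make_graph([], 'XOP', [], [])
    onnx_model = make_model(graph)
    onnx_model.ir_version = 8  # _default_OPSET_TO_IR_VERSION()[15]

    del onnx_model.opset_import[:]
    for domain, version in opsets.items():
        op_set = onnx_model.opset_import.add()
        op_set.domain = domain or ''
        op_set.version = version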

3627 

3628class _StaticVariables: 

3629 """ 

3630 Holds static variables. 

3631 """ 

3632 

3633 def __init__(self): 

3634 self._all_schemas_ = None 

3635 self._all_schemas_versions_ = None 

3636 self._all_domains_ = None 

3637 self._all_classes_ = None 

3638 

3639 @property 

3640 def all_schemas(self): 

3641 "Returns all schemas." 

3642 self.populate() 

3643 return self._all_schemas_ 

3644 

3645 @property 

3646 def all_classes(self): 

3647 "Returns all operators wrapped in classes." 

3648 self.populate() 

3649 return self._all_classes_ 

3650 

3651 @property 

3652 def all_schemas_versions(self): 

3653 "Returns all operators, domains, versions." 

3654 self.populate() 

3655 return self._all_schemas_versions_ 

3656 

3657 @property 

3658 def all_domains(self): 

3659 "Returns all domains." 

3660 self.populate() 

3661 return self._all_domains_ 

3662 

3663 def populate(self): 

3664 "Populates static variables." 

3665 if self._all_schemas_ is not None: 

3666 return 

3667 (self._all_schemas_, self._all_schemas_versions_, 

3668 self._all_domains_) = _populate_schemas() 

3669 self._all_classes_ = {} 

3670 
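Every property funnels through populate(), a lazy-initialization pattern that fills the caches at most once. The same pattern in isolation:

    class LazyCache:
        "Computes an expensive value once, on first property access."

        def __init__(self):
            self._data_ = None

        @property
        def data(self):
            self.populate()
            return self._data_

        def populate(self):
            if self._data_ is not None:
                return  # already populated
            self._data_ = {'Add': '...'}  # stand-in for _populate_schemas()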

3671 

3672class OnnxExisting(OnnxOperator): 

3673 """ 

3674 Wrapper around OnnxIdentity to specify that this operator is 

3675 not part of the subgraph it is used in. 

3676 """ 

3677 

3678 _unique_names = set() 

3679 

3680 expected_inputs = ['X'] 

3681 expected_outputs = ['Y'] 

3682 operator_name = 'Existing' 

3683 input_range = [1, 1] 

3684 output_range = [1, 1] 

3685 domain = '' 

3686 is_deprecated = False 

3687 since_version = 1 

3688 past_version = [] 

3689 attr_names = [] 

3690 op_type = 'Existing' 

3691 __module__ = __name__ 

3692 

3693 @staticmethod 

3694 def get_unique_name(var): 

3695 """ 

3696 Returns a unique variable name. 

3697 

3698 :param var: an instance of OnnxOperator. 

3699 :return: unique variable name 

3700 """ 

3701 if isinstance(var, OnnxOperator): 

3702 name = "%s_%s" % ((var.domain or "").lower().replace(".", ""), 

3703 var.op_type.lower()) 

3704 else: 

3705 raise TypeError( # pragma: no cover 

3706 f"Unexpected type {type(var)!r} for var.") 

3707 i = 0 

3708 new_name = "_exist_%s_%d" % (name, i) 

3709 while new_name in OnnxExisting._unique_names: 

3710 i += 1 

3711 new_name = "_exist_%s_%d" % (name, i) 

3712 OnnxExisting._unique_names.add(new_name) 

3713 return new_name 

3714 
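The generated name strips dots from the domain and lowercases the operator type, so the first name for an ai.onnx.ml TreeEnsembleRegressor is '_exist_aionnxml_treeensembleregressor_0'. A standalone sketch mirroring (not importing) the static method:

    _unique_names = set()

    def get_unique_name(domain, op_type):
        name = "%s_%s" % ((domain or "").lower().replace(".", ""),
                          op_type.lower())
        i = 0
        new_name = "_exist_%s_%d" % (name, i)
        while new_name in _unique_names:
            i += 1
            new_name = "_exist_%s_%d" % (name, i)
        _unique_names.add(new_name)
        return new_name

    print(get_unique_name('ai.onnx.ml', 'TreeEnsembleRegressor'))
    # '_exist_aionnxml_treeensembleregressor_0'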

3715 def __init__(self, *args, **kwargs): # pylint: disable=W0231 

3716 # OnnxIdentity.__init__(self, *args, **kwargs) # pylint: disable=W0233 

3717 OnnxOperator.__init__(self, *args, **kwargs) # pylint: disable=W0233 

3718 self.control_ops_ = None 

3719 if len(self.inputs) != 1: 

3720 raise RuntimeError( # pragma: no cover 

3721 f"Unexpected number of inputs {len(self.inputs)}.") 

3722 if isinstance(self.inputs[0], Variable): 

3723 # the single input is already a Variable 

3724 new_names = [ 

3725 ExistingVariable(self.inputs[0].name, self.inputs[0])] 

3726 logger.debug("op:OnnxExisting-%d.__init__:set-input:1:%r", 

3727 id(self), new_names) 

3728 self.inputs[0].output_names = new_names 

3729 else: 

3730 if not isinstance(self.inputs[0], OnnxOperatorBase): 

3731 raise TypeError( # pragma: no cover 

3732 f"Only input should a node not {type(self.inputs[0])!r}.") 

3733 if self.inputs[0].output_names is None: 

3734 new_names = [ 

3735 ExistingVariable(OnnxExisting.get_unique_name(self.inputs[0]), 

3736 self.inputs[0])] 

3737 logger.debug("op:OnnxExisting-%d.__init__:set-input:2:%r", 

3738 id(self), new_names) 

3739 self.inputs[0].output_names = new_names 

3740 

3741 def __repr__(self): 

3742 """ 

3743 usual 

3744 """ 

3745 return "{}({}) -> {}".format( 

3746 self.__class__.__name__, 

3747 self.inputs[0].output_names, 

3748 [str(o) for o in self.output_names] 

3749 if self.output_names is not None else "?") 

3750 

3751 def find_named_inputs(self): 

3752 """ 

3753 Retrieves all named inputs in this graph. 

3754 """ 

3755 res = [] 

3756 for i, inp in enumerate(self.inputs[0].output_names): 

3757 if not isinstance(inp, (Variable, ExistingVariable)): 

3758 raise TypeError( # pragma: no cover 

3759 "Unexpected type %r for input %r in node type %r." 

3760 "" % (type(inp), i, type(self))) 

3761 res.append(inp.name) 

3762 return res 

3763 

3764 def f(self, *inputs, verbose=0, fLOG=None, # pylint: disable=W0221 

3765 clear_cache=False, runtime=None): 

3766 "For the eager mode." 

3767 raise NotImplementedError() # pragma: no cover 

3768 

3769 def _set_control_op(self, op, subgraph_inputs=None): 

3770 if subgraph_inputs is not None: 

3771 raise NotImplementedError( # pragma: no cover 

3772 "Not implemented.") 

3773 if op is None: 

3774 raise RuntimeError( # pragma: no cover 

3775 "op cannot be None in _set_control_op.") 

3776 logger.debug("op:%s-%d:_set_control_op:found:p:%d:%r", 

3777 self.__class__.__name__, id(self), id(op), 

3778 self.inputs[0].output_names) 

3779 if self.control_ops_ is None: 

3780 self.control_ops_ = [] 

3781 self.control_ops_.append(op) 

3782 op.add_external_input(self.inputs[0]) 

3783 

3784 

3785_S = _StaticVariables() 

3786onnx_load_factory = Xop = OnnxLoadFactory()
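Xop is the public entry point: attribute access on the factory (or the loadop helper from the same module) returns the class wrapping a given ONNX operator. A typical use, assuming the usual Xop API of this module:

    import numpy
    from mlprodict.npy.xop import loadop

    OnnxAdd = loadop('Add')
    op = OnnxAdd('X', numpy.array([1.5], dtype=numpy.float32),
                 output_names=['Y'])
    onx = op.to_onnx(numpy.float32, numpy.float32)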