Coverage for mlprodict/onnxrt/ops_cpu/_op.py: 90%

299 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_cpu*. 

5""" 

6import pprint 

7import numpy 

8import onnx 

9import onnx.defs 

10from onnx import GraphProto 

11from ._new_ops import OperatorSchema 

12 

13 

14def _build_schemas(): 

15 res = {} 

16 for schema in onnx.defs.get_all_schemas_with_history(): 

17 # Multiple versions can coexist. The most recent one is kept.

18 if schema.name in res: 

19 if schema.since_version > res[schema.name].since_version: 

20 # We keep the most recent one. 

21 res[schema.name] = schema 

22 else: 

23 res[schema.name] = schema 

24 res[schema.name + '_' + str(schema.since_version)] = schema 

25 return res 

26 

27 

28_schemas = _build_schemas() 

29_at_least_one = {'Constant'} 
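A minimal sketch (not part of the covered module) of how the _schemas table built above is keyed; the printed since_version depends on the installed onnx release:

    from mlprodict.onnxrt.ops_cpu._op import _schemas

    add = _schemas['Add']                # the most recent 'Add' schema wins
    print(add.name, add.since_version)   # e.g. Add 14 with onnx 1.13
    # Versioned keys ('<name>_<since_version>') are registered as well.
    print([k for k in _schemas if k.startswith('Add_')])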

30 

31 

32class RuntimeTypeError(RuntimeError): 

33 """ 

34 Raised when a type of a variable is unexpected. 

35 """ 

36 pass 

37 

38 

39class DefaultNone: 

40 """ 

41 Default value for parameters when the parameter is not set 

42 but the operator has a default behaviour for it. 

43 """ 

44 pass 

45 

46 

47class RefAttrName: 

48 """ 

49 Implements a link between a parameter of a function 

50 and an attribute in node. 

51 

52 :param name: name of the input 

53 """ 

54 

55 def __init__(self, name): 

56 self.name = name 

57 

58 def __repr__(self): 

59 "usual" 

60 return f"{self.__class__.__name__}({self.name!r})" 
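An illustrative sketch of how a reference attribute becomes a RefAttrName option: the desc fragment below is hypothetical, and the loop mirrors what OpRun.__init__ does with 'ref_attr_name' entries.

    # Hypothetical desc fragment for a node inside a function body whose
    # 'axis' attribute refers to one of the function's attributes.
    desc = {'atts': {'axis': {'ref_attr_name': 'axis'}}}

    options = {}
    for a, b in desc['atts'].items():
        if 'ref_attr_name' in b:
            options[a] = RefAttrName(b['ref_attr_name'])

    print(options['axis'])   # RefAttrName('axis')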

61 

62 

63class OpRun: 

64 """ 

65 Ancestor to all operators in this subfolder. 

66 The runtime for every node can be checked against the

67 `ONNX unit tests 

68 <https://github.com/onnx/onnx/tree/master/onnx/backend/test/case/node>`_. 

69 

70 :param onnx_node: :epkg:`onnx` node 

71 :param desc: internal representation 

72 :param expected_attributes: expected attributes for this node 

73 :param options: runtime options 

74 """ 

75 

76 def __init__(self, onnx_node, desc=None, expected_attributes=None, 

77 **options): 

78 self._provider = 'python' 

79 self.onnx_node = onnx_node 

80 self.desc = desc 

81 self.inplaces = {} 

82 

83 if onnx_node.op_type in _schemas: 

84 self._schema = _schemas[onnx_node.op_type] 

85 else: 

86 self._schema = self._find_custom_operator_schema(onnx_node.op_type) 

87 if self._schema is None: 

88 raise RuntimeError( # pragma: no cover 

89 "Unable to find class name '{}' in available schemas:" 

90 "(onnx.__version__='{}')\n{}".format( 

91 self.__class__.__name__, 

92 onnx.__version__, 

93 "\n".join(sorted(_schemas)))) 

94 

95 if desc is not None: 

96 if 'atts' in desc: 

97 for a, b in desc['atts'].items(): 

98 if not isinstance(b, dict) or ( 

99 'value' not in b and 'ref_attr_name' not in b): 

100 raise ValueError( # pragma: no cover 

101 f"Unexpected value {b}.") 

102 if 'ref_attr_name' in b: 

103 options[a] = RefAttrName(b['ref_attr_name']) 

104 else: 

105 options[a] = (b['value_rt'] if 'value_rt' in b 

106 else b['value']) 

107 if expected_attributes is not None: 

108 if onnx_node.op_type in _at_least_one: 

109 done = 0 

110 for a, b in expected_attributes.items(): 

111 if a in options: 

112 setattr(self, a, b) 

113 done += 1 

114 if done == 0: 

115 raise RuntimeError( # pragma: no cover 

116 "All parameters '{}' are missing from operator '{}', " 

117 "given {}.".format( 

118 a, onnx_node.op_type, list(sorted(options)))) 

119 else: 

120 for a, b in expected_attributes.items(): 

121 if a not in options: 

122 if b is DefaultNone: 

123 setattr(self, a, None) 

124 elif b is None: 

125 raise RuntimeError( # pragma: no cover 

126 "Parameter '{}' is missing from operator '{}' " 

127 "(class='{}'), given {}.".format( 

128 a, onnx_node.op_type, 

129 self.__class__.__name__, 

130 list(sorted(options)))) 

131 else: 

132 setattr(self, a, b) 

133 for k, v in options.items(): 

134 setattr(self, k, v) 

135 

136 if onnx_node.op_type not in _at_least_one: 

137 for k, v in self._schema.attributes.items(): 

138 if not hasattr(self, k) and getattr(v, 'required', True): 

139 raise RuntimeError( # pragma: no cover 

140 "Attribute '{}' is expected based on ONNX specifications " 

141 "for node '{}' and options {}.".format( 

142 k, onnx_node.op_type, pprint.pformat(options))) 

143 

144 @staticmethod 

145 def local_inputs(graph): 

146 """ 

147 Returns all variables not registered as inputs and not produced by

148 a node inside the graph. These inputs are part of the context

149 provided by the graph calling this one.

150 """ 

151 if not isinstance(graph, GraphProto): 

152 raise TypeError( 

153 f"Unexpected type {type(graph)!r}.") 

154 local = set() 

155 known = set() 

156 for init in graph.initializer: 

157 known.add(init.name) 

158 for init in graph.input: 

159 known.add(init.name) 

160 for node in graph.node: 

161 for o in node.output: 

162 known.add(o) 

163 for i in node.input: 

164 if i not in known: 

165 local.add(i) 

166 return list(local) 
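A minimal sketch using a tiny hand-built graph: 'outer' is neither a graph input, an initializer, nor a node output, so local_inputs reports it as coming from the calling context.

    from onnx import TensorProto, helper

    node = helper.make_node('Add', ['X', 'outer'], ['Y'])
    graph = helper.make_graph(
        [node], 'body',
        [helper.make_tensor_value_info('X', TensorProto.FLOAT, [1])],
        [helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1])])

    print(OpRun.local_inputs(graph))   # ['outer']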

167 

168 def need_context(self): 

169 """ 

170 Tells the runtime if this node needs the context 

171 (all the results produced so far) as it may silently access 

172 one of them (operator Loop). 

173 The default answer is `False`. 

174 """ 

175 return False 

176 

177 def _find_custom_operator_schema(self, op_name): 

178 raise NotImplementedError( # pragma: no cover 

179 f"This method should be overwritten for operator '{op_name}'.") 

180 

181 def __str__(self): 

182 """ 

183 usual 

184 """ 

185 atts = [self.__class__.__name__ + '(', 

186 f" op_type={self.onnx_node.op_type}"] 

187 for k, v in sorted(self.__dict__.items()): 

188 if k in {'desc', 'onnx_node'}: 

189 continue 

190 if 'a' <= k[0] <= 'z' and k[-1] != '_': 

191 atts.append(f' {k}={v},') 

192 atts.append(')') 

193 return "\n".join(atts) 

194 

195 def _run(self, *args, **kwargs): 

196 """ 

197 Should be overwritten. 

198 """ 

199 raise NotImplementedError( # pragma: no cover 

200 "Method '_run' or 'to_python' should be overwritten for operator %s." 

201 "" % self.__class__.__name__) 

202 

203 def run(self, *args, **kwargs): # pylint: disable=E0202 

204 """ 

205 Calls method ``_run``. 

206 """ 

207 try: 

208 res = self._run(*args, **kwargs) 

209 except TypeError as e: 

210 raise TypeError( # pragma: no cover 

211 "Issues with types {} (operator {}).".format( 

212 ", ".join(str(type(_)) for _ in args), 

213 self.__class__.__name__)) from e 

214 except AttributeError as e: 

215 raise AttributeError( # pragma: no cover 

216 "Issues with types {} (operator {}).".format( 

217 ", ".join(str(type(_)) for _ in args), 

218 self.__class__.__name__)) from e 

219 return res 

220 

221 def switch_initializers_dtype(self, dtype_in=numpy.float32, 

222 dtype_out=numpy.float64): 

223 """ 

224 Switches all stored arrays from *dtype_in* to *dtype_out*

225 (``numpy.float32`` to ``numpy.float64`` by default) with a simple cast.

226 

227 @param dtype_in previous type 

228 @param dtype_out next type 

229 @return done operations 

230 """ 

231 done = [] 

232 for k, v in sorted(self.__dict__.items()): 

233 if k in {'desc', 'onnx_node'}: 

234 continue 

235 if isinstance(v, numpy.ndarray): 

236 if v.dtype == dtype_in: 

237 v = v.astype(dtype_out) 

238 setattr(self, k, v) 

239 done.append(("+", "att", k, getattr(self, k))) 

240 else: 

241 done.append(("-", "att", k, getattr(self, k))) 

242 if hasattr(self, '_run_no_checks_') and hasattr(self, 'run'): 

243 self.run = self._run_no_checks_ # pylint: disable=E0202,E1101 

244 return done 

245 

246 def enable_inplace_compute(self, index): 

247 """ 

248 Tells the node that one input can be overwritten. 

249 

250 @param index input index 

251 """ 

252 self.inplaces[index] = True 

253 

254 @property 

255 def args_default(self): 

256 """ 

257 Returns the list of arguments as well as 

258 the list of parameters with the default values 

259 (close to the signature). 

260 """ 

261 inps = [] 

262 if hasattr(self, 'atts'): 

263 for k, v in self.atts.items(): # pylint: disable=E1101 

264 if isinstance(v, (list, tuple, dict)) and len(v) == 0: 

265 v = None 

266 inps.append(f'{k}={v!r}') 

267 return inps 

268 

269 @property 

270 def args_default_modified(self): 

271 """ 

272 Returns the list of modified parameters. 

273 """ 

274 if not hasattr(self, 'atts'): 

275 return None 

276 

277 inps = [] 

278 for k, v in self.atts.items(): # pylint: disable=E1101 

279 val = getattr(self, k, None) 

280 if isinstance(val, numpy.ndarray) and isinstance(v, list): 

281 val = list(val) 

282 try: 

283 if val != v: 

284 inps.append(f'{k}={val!r}') 

285 except ValueError as e: # pragma: no cover 

286 raise ValueError( 

287 f"Unexpected value for v={v!r} and val={val!r}.") from e 

288 return inps 

289 

290 @property 

291 def args_optional(self): 

292 """ 

293 Returns the list of optional arguments. 

294 """ 

295 inps = [] 

296 if hasattr(self, 'optional_inputs'): 

297 for k, v in self.optional_inputs.items(): # pylint: disable=E1101 

298 inps.append(f'{k}={v!r}') 

299 return inps 

300 

301 @property 

302 def args_mandatory(self): 

303 """ 

304 Returns the list of mandatory arguments.

305 """ 

306 if hasattr(self, 'mandatory_inputs'): 

307 return self.mandatory_inputs # pylint: disable=E1101 

308 return None 

309 

310 def to_python(self, inputs): 

311 """ 

312 Returns a python code equivalent to this operator. 

313 

314 @param inputs inputs name 

315 @return imports, python code, both as strings 

316 """ 

317 raise NotImplementedError( 

318 f"Operator '{self.__class__.__name__}' has no equivalent python code.") # pragma: no cover 

319 

320 def _to_python_numpy(self, inputs, numpy_name): 

321 return ("import numpy", 

322 f"return numpy.{numpy_name}({', '.join(inputs)})") 

323 

324 @property 

325 def atts_value(self): 

326 "Returns all parameters in a dictionary." 

327 if hasattr(self, 'atts'): 

328 return {k: getattr(self, k) 

329 for k in self.atts} # pylint: disable=E1101 

330 return None 
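To make the contract concrete, here is a minimal, hypothetical subclass (the library's real operators follow the same pattern but usually derive from the more specific ancestors below): the constructor resolves the ONNX schema, run dispatches to _run, and _run returns a tuple of outputs.

    import numpy
    from onnx import helper

    class Neg(OpRun):
        def __init__(self, onnx_node, desc=None, **options):
            OpRun.__init__(self, onnx_node, desc=desc,
                           expected_attributes={}, **options)

        def _run(self, x, attributes=None, verbose=0, fLOG=None):
            return (numpy.negative(x), )

    op = Neg(helper.make_node('Neg', ['X'], ['Y']))
    print(op.run(numpy.array([1., -2.], dtype=numpy.float32)))
    # (array([-1.,  2.], dtype=float32),)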

331 

332 

333class OpRunUnary(OpRun): 

334 """ 

335 Ancestor to all unary operators in this subfolder. 

336 Checks that input types are the same.

337 """ 

338 

339 def __init__(self, onnx_node, desc=None, expected_attributes=None, **options): 

340 OpRun.__init__(self, onnx_node, desc=desc, 

341 expected_attributes=expected_attributes, 

342 **options) 

343 

344 def run(self, x, attributes=None, verbose=0, fLOG=None): # pylint: disable=E0202,W0221 

345 """ 

346 Calls method ``_run``. 

347 """ 

348 try: 

349 res = self._run(x, attributes=attributes, 

350 verbose=verbose, fLOG=fLOG) 

351 except TypeError as e: 

352 raise TypeError( # pragma: no cover 

353 "Issues with types {} (binary operator {}).".format( 

354 ", ".join(str(type(_)) for _ in [x]), 

355 self.__class__.__name__)) from e 

356 return res 

357 

358 

359class OpRunArg(OpRunUnary): 

360 """ 

361 Ancestor to all unary operators in this subfolder 

362 that produce the position of an extremum (ArgMax, ...).

363 Checks that input types are the same.

364 The class must have attributes *axis* and *keepdims*.

365 """ 

366 

367 def __init__(self, onnx_node, desc=None, expected_attributes=None, 

368 **options): 

369 OpRunUnary.__init__(self, onnx_node, desc=desc, 

370 expected_attributes=expected_attributes, 

371 **options) 

372 if not hasattr(self, 'keepdims'): 

373 raise AttributeError( # pragma: no cover 

374 "Attribute 'keepdims' is missing.") 

375 if not hasattr(self, 'axis'): 

376 raise AttributeError( # pragma: no cover 

377 "Attribute 'axis' is missing.") 

378 

379 def run(self, x, attributes=None, verbose=0, fLOG=None): # pylint: disable=E0202 

380 """ 

381 Calls method ``_run``. 

382 """ 

383 res = OpRunUnary.run(self, x, attributes=attributes, 

384 verbose=verbose, fLOG=fLOG) 

385 if res[0].dtype != numpy.int64: 

386 raise RuntimeTypeError( # pragma: no cover 

387 "Output type mismatch: should be '{}' != output '{}' " 

388 "(operator '{}')".format( 

389 numpy.int64, res[0].dtype, self.__class__.__name__)) 

390 return res 

391 

392 def _run_no_checks_(self, x, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221 

393 return OpRunUnary.run(self, x, attributes=attributes, verbose=verbose, fLOG=fLOG) 

394 

395 

396class OpRunUnaryNum(OpRunUnary): 

397 """ 

398 Ancestor to all unary and numerical operators 

399 in this subfolder. Checks that inputs type 

400 are the same. 

401 """ 

402 

403 def __init__(self, onnx_node, desc=None, expected_attributes=None, **options): 

404 OpRunUnary.__init__(self, onnx_node, desc=desc, 

405 expected_attributes=expected_attributes, 

406 **options) 

407 

408 def run(self, x, attributes=None, verbose=0, fLOG=None): # pylint: disable=E0202 

409 """ 

410 Calls method ``_run``. 

411 """ 

412 res = OpRunUnary.run(self, x, attributes=attributes, 

413 verbose=verbose, fLOG=fLOG) 

414 if len(res) == 0 or res[0] is None: 

415 return res 

416 if not isinstance(res[0], list) and res[0].dtype != x.dtype: 

417 raise RuntimeTypeError( # pragma: no cover 

418 "Output type mismatch: input '{}' != output '{}' " 

419 "(operator '{}')".format( 

420 x.dtype, res[0].dtype, self.__class__.__name__)) 

421 return res 

422 

423 def _run_no_checks_(self, x, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221 

424 return OpRunUnary.run(self, x, attributes=attributes, verbose=verbose, fLOG=fLOG) 

425 

426 

427class OpRunClassifierProb(OpRunUnary): 

428 """ 

429 Ancestor to all classifiers in this subfolder predicting probabilities.

430 Checks that input types are the same.

431 """ 

432 

433 def __init__(self, onnx_node, desc=None, expected_attributes=None, 

434 **options): 

435 OpRunUnary.__init__(self, onnx_node, desc=desc, 

436 expected_attributes=expected_attributes, 

437 **options) 

438 

439 def run(self, x, attributes=None, verbose=0, fLOG=None): # pylint: disable=E0202 

440 """ 

441 Calls method ``_run``. 

442 """ 

443 res = OpRunUnary.run(self, x, attributes=attributes, 

444 verbose=verbose, fLOG=fLOG) 

445 if x.dtype in (numpy.float32, numpy.float64) and res[1].dtype != x.dtype: 

446 raise RuntimeTypeError( # pragma: no cover 

447 "Output type mismatch: {} != {} (operator '{}')".format( 

448 x.dtype, res[1].dtype, self.__class__.__name__)) 

449 return res 

450 

451 @property 

452 def nb_classes(self): 

453 """ 

454 Returns the number of expected classes. 

455 """ 

456 return max(len(getattr(self, 'classlabels_ints', [])), 

457 len(getattr(self, 'classlabels_int64s', [])), 

458 len(self.classlabels_strings)) # pylint: disable=E1101 

459 

460 def _run_no_checks_(self, x, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221 

461 return OpRunUnary.run(self, x, attributes=attributes, verbose=verbose, fLOG=fLOG) 

462 

463 

464class OpRunBinary(OpRun): 

465 """ 

466 Ancestor to all binary operators in this subfolder. 

467 Checks that input types are the same.

468 """ 

469 

470 def __init__(self, onnx_node, desc=None, expected_attributes=None, 

471 **options): 

472 OpRun.__init__(self, onnx_node, desc=desc, 

473 expected_attributes=expected_attributes, 

474 **options) 

475 

476 def run(self, x, y, attributes=None, verbose=0, fLOG=None): # pylint: disable=E0202,W0221 

477 """ 

478 Calls method ``_run``. 

479 """ 

480 if x is None or y is None: 

481 raise RuntimeError( # pragma: no cover 

482 f"x and y have different dtype: {type(x)} != {type(y)} ({type(self)})") 

483 if x.dtype != y.dtype: 

484 raise RuntimeTypeError( 

485 "Input type mismatch: {} != {} (operator '{}', shapes {}, {})".format( 

486 x.dtype, y.dtype, self.__class__.__name__, 

487 x.shape, y.shape)) 

488 try: 

489 res = self._run(x, y, attributes=attributes, 

490 verbose=verbose, fLOG=fLOG) 

491 except (TypeError, ValueError) as e: # pragma: no cover 

492 raise TypeError( 

493 "Issues with types {} (binary operator {}).".format( 

494 ", ".join(str(type(_)) for _ in [x, y]), 

495 self.__class__.__name__)) from e 

496 return res 

497 

498 def _run_no_checks_(self, x, y, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221 

499 """ 

500 Calls method ``_run``. 

501 """ 

502 try: 

503 res = self._run(x, y, attributes=attributes, 

504 verbose=verbose, fLOG=fLOG) 

505 except TypeError as e: # pragma: no cover 

506 raise TypeError( 

507 "Issues with types {} (binary operator {}).".format( 

508 ", ".join(str(type(_)) for _ in [x, y]), 

509 self.__class__.__name__)) from e 

510 return res 

511 

512 

513class OpRunBinaryComparison(OpRunBinary): 

514 """ 

515 Ancestor to all binary operators in this subfolder 

516 comparing tensors. 

517 """ 

518 

519 def __init__(self, onnx_node, desc=None, expected_attributes=None, 

520 **options): 

521 OpRunBinary.__init__(self, onnx_node, desc=desc, 

522 expected_attributes=expected_attributes, 

523 **options) 

524 

525 

526class OpRunBinaryNum(OpRunBinary): 

527 """ 

528 Ancestor to all binary operators in this subfolder. 

529 Checks that input types are the same.

530 """ 

531 

532 def __init__(self, onnx_node, desc=None, expected_attributes=None, 

533 **options): 

534 OpRunBinary.__init__(self, onnx_node, desc=desc, 

535 expected_attributes=expected_attributes, 

536 **options) 

537 

538 def run(self, x, y, attributes=None, verbose=0, fLOG=None): # pylint: disable=E0202 

539 """ 

540 Calls method ``_run``. 

541 """ 

542 res = OpRunBinary.run( 

543 self, x, y, attributes=attributes, verbose=verbose, fLOG=fLOG) 

544 if res[0].dtype != x.dtype: 

545 raise RuntimeTypeError( 

546 "Output type mismatch: {} != {} or {} (operator '{}')" 

547 " type(x)={} type(y)={}".format( 

548 x.dtype, res[0].dtype, y.dtype, 

549 self.__class__.__name__, type(x), type(y))) 

550 return res 

551 

552 def _run_no_checks_(self, x, y, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221 

553 """ 

554 Calls method ``_run``. 

555 """ 

556 return OpRunBinary._run_no_checks_( 

557 self, x, y, attributes=attributes, verbose=verbose, fLOG=fLOG) 

558 

559 

560class OpRunBinaryNumpy(OpRunBinaryNum): 

561 """ 

562 Implements the in-place logic.

563 *numpy_fct* is a binary numpy function which

564 takes two matrices and has an argument *out*

565 for in-place operations.

566 """ 

567 

568 def __init__(self, numpy_fct, onnx_node, desc=None, 

569 expected_attributes=None, **options): 

570 OpRunBinaryNum.__init__(self, onnx_node, desc=desc, 

571 expected_attributes=expected_attributes, 

572 **options) 

573 self.numpy_fct = numpy_fct 

574 self._cannot_inplace_int = self.numpy_fct in ( 

575 numpy.divide, numpy.true_divide) 

576 

577 def _run(self, a, b, attributes=None, verbose=0, fLOG=None): # pylint: disable=W0221 

578 if (self._cannot_inplace_int and 

579 numpy.issubdtype(a.dtype, numpy.integer)): 

580 return (self.numpy_fct(a, b), ) 

581 if self.inplaces.get(0, False) and a.flags['WRITEABLE'] and a.size >= b.size: 

582 if len(a.shape) == 1 and b.shape == (1, 1): 

583 a = a.reshape(1, a.shape[0]) 

584 try: 

585 self.numpy_fct(a, b, out=a) 

586 return (a, ) 

587 except (ValueError, TypeError): 

588 return (self.numpy_fct(a, b), ) 

589 if self.inplaces.get(1, False) and b.flags['WRITEABLE'] and a.size <= b.size: 

590 if len(b.shape) == 1 and a.shape == (1, 1): 

591 b = b.reshape(b.shape[0], 1) 

592 try: 

593 self.numpy_fct(a, b, out=b) 

594 return (b, ) 

595 except (ValueError, TypeError): 

596 return (self.numpy_fct(a, b), ) 

597 return (self.numpy_fct(a, b), ) 

598 

599 def to_python(self, inputs): 

600 """ 

601 Returns a python code equivalent to this operator. 

602 

603 @param inputs inputs name 

604 @return imports, python code, both as strings 

605 """ 

606 lines = [ 

607 "# inplaces not take into account {}-{}".format( 

608 self.inplaces.get(0, False), self.inplaces.get(1, False)), 

609 f"return numpy.{self.numpy_fct.__name__}({', '.join(inputs)})" 

610 ] 

611 return "import numpy", "\n".join(lines) 
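A sketch of how a concrete binary operator builds on OpRunBinaryNumpy; the library ships its own Add operator, and this simplified version only illustrates the pattern and the in-place flag.

    import numpy
    from onnx import helper

    class Add(OpRunBinaryNumpy):
        def __init__(self, onnx_node, desc=None, **options):
            OpRunBinaryNumpy.__init__(self, numpy.add, onnx_node, desc=desc,
                                      expected_attributes={}, **options)

    op = Add(helper.make_node('Add', ['X', 'Y'], ['Z']))
    op.enable_inplace_compute(0)   # the first input may be overwritten
    x = numpy.array([1., 2.], dtype=numpy.float32)
    y = numpy.array([3., 4.], dtype=numpy.float32)
    print(op.run(x, y))            # (array([4., 6.], dtype=float32),)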

612 

613 

614class OpRunReduceNumpy(OpRunUnaryNum): 

615 """ 

616 Implements the reduce logic. 

617 It must have a parameter *axes*. 

618 """ 

619 

620 def __init__(self, onnx_node, desc=None, 

621 expected_attributes=None, **options): 

622 if ('noop_with_empty_axes' not in expected_attributes and 

623 'axes' not in expected_attributes): 

624 raise RuntimeError( # pragma: no cover 

625 "Parameter 'axes' is expected but not found in {} " 

626 "from class {}".format(expected_attributes, type(self))) 

627 if (expected_attributes.get('noop_with_empty_axes', 0) and 

628 (expected_attributes['axes'] is None or 

629 len(expected_attributes['axes']) == 0)): 

630 raise RuntimeError( # pragma: no cover 

631 "Parameter 'axes' cannot be empty as {} (noop_with_empty_axes=1) " 

632 "from class {}".format(expected_attributes, type(self))) 

633 OpRunUnaryNum.__init__(self, onnx_node, desc=desc, 

634 expected_attributes=expected_attributes, 

635 **options) 

636 if isinstance(self.axes, numpy.ndarray): # pylint: disable=E0203 

637 if (len(self.axes.shape) == 0 or # pylint: disable=E0203,E1101 

638 self.axes.shape[0] == 0): # pylint: disable=E0203,E1101 

639 self.axes = None 

640 else: 

641 self.axes = tuple(self.axes) 

642 elif self.axes in [[], tuple()]: # pylint: disable=E0203 

643 self.axes = None 

644 elif isinstance(self.axes, list): # pylint: disable=E0203 

645 self.axes = tuple(self.axes) 

646 

647 

648class OpRunCustom(OpRun): 

649 """ 

650 Automates some methods for custom operators defined 

651 outside *mlprodict*. 

652 """ 

653 

654 class OpRunCustomSchema(OperatorSchema): 

655 """ 

656 Custom schema. 

657 """ 

658 

659 def __init__(self, cls): 

660 OperatorSchema.__init__(self, cls.__name__) 

661 self.attributes = cls.atts 

662 

663 def __init__(self, onnx_node, desc=None, 

664 expected_attributes=None, **options): 

665 OpRun.__init__(self, onnx_node, desc=desc, 

666 expected_attributes=expected_attributes, 

667 **options) 

668 

669 def _find_custom_operator_schema(self, op_name): 

670 """ 

671 Finds a custom operator defined by this runtime. 

672 """ 

673 if (op_name == self.__class__.__name__ or 

674 (hasattr(self.__class__, 'op_name') and 

675 self.__class__.op_name == op_name)): # pylint: disable=E1101 

676 return OpRunCustom.OpRunCustomSchema(self.__class__) 

677 raise RuntimeError( # pragma: no cover 

678 f"Unable to find a schema for operator '{op_name}'.") 

679 

680 

681class OpFunction(OpRun): 

682 """ 

683 Runs a custom function. 

684 """ 

685 

686 def __init__(self, onnx_node, impl): 

687 if impl is None: 

688 raise RuntimeError( 

689 f"impl cannot be None for node type {onnx_node.op_type!r} " 

690 f"from domain {onnx_node.domain!r}.") 

691 OpRun.__init__(self, onnx_node) 

692 self.impl_ = impl 

693 # The function implementation is the same whenever the function is called 

694 # but the attributes may be different at every call. 

695 self.attributes_ = { 

696 name: getattr(self, name) 

697 for name in self.impl_.attributes_} 

698 

699 def _run(self, *inputs, **kwargs): 

700 if len(self.impl_.input_names) != len(inputs): 

701 raise RuntimeError( 

702 f"Mismatch lengths between the number of inputs {len(inputs)} " 

703 f"and the expected number of inputs {len(self.impl_.inputs)} " 

704 f"for node {self.onnx_node.op_type!r} from domain " 

705 f"{self.onnx_node.domain!r}.") 

706 feeds = dict(zip(self.impl_.input_names, inputs)) 

707 attributes = self.attributes_.copy() 

708 attributes.update(kwargs) 

709 results = self.impl_.run(feeds, attributes=attributes) 

710 if len(self.impl_.output_names) != len(results): 

711 raise RuntimeError( 

712 f"Mismatch lengths between the number of outputs {len(results)} " 

713 f"and the expected number of outputs {len(self.impl_.output_names)} " 

714 f"for node {self.onnx_node.op_type!r} " 

715 f"from domain {self.onnx_node.domain!r}.") 

716 return tuple(results[n] for n in self.impl_.output_names) 

717 

718 def to_python(self, inputs): 

719 """ 

720 Returns a python code equivalent to this operator. 

721 

722 @param inputs inputs name 

723 @return imports, python code, both as strings 

724 """ 

725 res = self.impl_.to_python() 

726 sinp = ", ".join(inputs) 

727 code = [res[list(res.keys())[0]], "", "", 

728 "return OnnxPythonInference().run(" + sinp + ")"] 

729 return "", "\n".join(code)