Coverage for mlprodict/npy/onnx_sklearn_wrapper.py: 98%

255 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Helpers to use numpy API to easily write converters 

4for :epkg:`scikit-learn` classes for :epkg:`onnx`. 

5 

6.. versionadded:: 0.6 

7""" 

8import logging 

9import numpy 

10from sklearn.base import ( 

11 ClassifierMixin, ClusterMixin, 

12 RegressorMixin, TransformerMixin) 

13from .onnx_numpy_wrapper import _created_classes_inst, wrapper_onnxnumpy_np 

14from .onnx_numpy_annotation import NDArraySameType, NDArrayType 

15from .xop import OnnxOperatorTuple 

16from .xop_variable import Variable 

17from .xop import loadop 

18from ..plotting.text_plot import onnx_simple_text_plot 

19 

20 

# Logger shared by the functions of this module, it is named 'xop'
# and not after this module.
logger = logging.getLogger("xop")

22 

23 

def _skl2onnx_add_to_container(onx, scope, container, outputs):
    """
    Adds ONNX graph to :epkg:`skl2onnx` container and scope.

    Graph inputs keep their names. Every initializer and every node
    output is renamed with a unique name obtained from *scope* before
    being copied into *container*. Every graph output is finally
    connected to the corresponding variable in *outputs* with an
    *Identity* node. Empty names, which ONNX uses for omitted optional
    inputs and outputs, are copied as they are.

    :param onx: onnx model, attributes *graph* and *opset_import*
        are read
    :param scope: scope, it provides unique variable and operator names
    :param container: container receiving initializers and nodes
    :param outputs: variables the graph outputs are linked to, in the
        same order as `onx.graph.output`, every variable must have an
        attribute *onnx_name*
    :raises RuntimeError: if a node input cannot be found or if the
        number of graph outputs is different from `len(outputs)`
    :raises NotImplementedError: if a node has an attribute of
        an unsupported type (graph, list of tensors, sparse tensor, ...)
    """
    logger.debug("_skl2onnx_add_to_container:onx=%r outputs=%r",
                 type(onx), outputs)
    # graph inputs are expected to exist in the main graph
    # under the same names
    mapped_names = {x.name: x.name for x in onx.graph.input}
    opsets = {op.domain: op.version for op in onx.opset_import}

    # AttributeProto.type -> function reading the attribute value
    read_attribute = {
        1: lambda a: a.f,                # FLOAT
        2: lambda a: a.i,                # INT
        3: lambda a: a.s,                # STRING
        4: lambda a: a.t,                # TENSOR
        6: lambda a: list(a.floats),     # FLOATS
        7: lambda a: list(a.ints),       # INTS
        8: lambda a: list(a.strings),    # STRINGS
    }

    # adding initializers
    for init in onx.graph.initializer:
        new_name = scope.get_unique_variable_name(init.name)
        mapped_names[init.name] = new_name
        container.add_initializer(new_name, None, None, init)

    # adding nodes
    for node in onx.graph.node:
        new_inputs = []
        for name in node.input:
            if name == '':
                # an empty name is an omitted optional input,
                # it is not a result and must stay empty
                new_inputs.append(name)
                continue
            if name not in mapped_names:
                raise RuntimeError(  # pragma: no cover
                    f"Unable to find input {name!r} in {mapped_names!r}.")
            new_inputs.append(mapped_names[name])

        new_outputs = []
        for name in node.output:
            if name == '':
                # omitted optional output, no unique name to create
                new_outputs.append(name)
                continue
            new_name = scope.get_unique_variable_name(name)
            mapped_names[name] = new_name
            new_outputs.append(new_name)

        atts = {}
        for att in node.attribute:
            reader = read_attribute.get(att.type, None)
            if reader is None:
                raise NotImplementedError(  # pragma: no cover
                    f"Unable to copy attribute type {att.type!r} ({att!r}).")
            atts[att.name] = reader(att)

        container.add_node(
            node.op_type,
            name=scope.get_unique_operator_name('_sub_' + node.name),
            inputs=new_inputs, outputs=new_outputs, op_domain=node.domain,
            op_version=opsets.get(node.domain, None), **atts)

    # linking outputs
    if len(onx.graph.output) != len(outputs):
        raise RuntimeError(  # pragma: no cover
            "Output size mismatch %r != %r.\n--ONNX--\n%s" % (
                len(onx.graph.output), len(outputs),
                onnx_simple_text_plot(onx)))
    for out, var in zip(onx.graph.output, outputs):
        container.add_node(
            'Identity', name=scope.get_unique_operator_name(
                '_sub_' + out.name),
            inputs=[mapped_names[out.name]], outputs=[var.onnx_name])

97 

98 

99def _common_shape_calculator_t(operator): 

100 if not hasattr(operator, 'onnx_numpy_fct_'): 

101 raise AttributeError( 

102 "operator must have attribute 'onnx_numpy_fct_'.") 

103 X = operator.inputs 

104 if len(X) != 1: 

105 raise RuntimeError( 

106 f"This function only supports one input not {len(X)!r}.") 

107 if len(operator.outputs) != 1: 

108 raise RuntimeError( 

109 f"This function only supports one output not {len(operator.outputs)!r}.") 

110 op = operator.raw_operator 

111 cl = X[0].type.__class__ 

112 dim = [X[0].type.shape[0], getattr(op, 'n_outputs_', None)] 

113 operator.outputs[0].type = cl(dim) 

114 

115 

def _shape_calculator_transformer(operator):
    """
    Default shape calculator for a transformer: one input and one
    output sharing the same tensor type. The work is delegated
    to `_common_shape_calculator_t`.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_t(operator)

124 

125 

def _shape_calculator_regressor(operator):
    """
    Default shape calculator for a regressor: one input and one
    output sharing the same tensor type. The work is delegated
    to `_common_shape_calculator_t`.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_t(operator)

134 

135 

136def _common_shape_calculator_int_t(operator): 

137 if not hasattr(operator, 'onnx_numpy_fct_'): 

138 raise AttributeError( 

139 "operator must have attribute 'onnx_numpy_fct_'.") 

140 X = operator.inputs 

141 if len(X) != 1: 

142 raise RuntimeError( 

143 f"This function only supports one input not {len(X)!r}.") 

144 if len(operator.outputs) != 2: 

145 raise RuntimeError( 

146 f"This function only supports two outputs not {len(operator.outputs)!r}.") 

147 from skl2onnx.common.data_types import Int64TensorType # delayed 

148 op = operator.raw_operator 

149 cl = X[0].type.__class__ 

150 dim = [X[0].type.shape[0], getattr(op, 'n_outputs_', None)] 

151 operator.outputs[0].type = Int64TensorType(dim[:1]) 

152 operator.outputs[1].type = cl(dim) 

153 

154 

def _shape_calculator_classifier(operator):
    """
    Default shape calculator for a classifier: one input and two
    outputs, the label (int64) and the probabilities (same type as
    the input). The work is delegated to `_common_shape_calculator_int_t`.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_int_t(operator)

163 

164 

def _shape_calculator_cluster(operator):
    """
    Default shape calculator for a clustering: one input and two
    outputs, the label (int64) and the distances (same type as the
    input). The work is delegated to `_common_shape_calculator_int_t`.

    .. versionadded:: 0.6
    """
    return _common_shape_calculator_int_t(operator)

173 

174 

175def _common_converter_begin(scope, operator, container, n_outputs): 

176 if not hasattr(operator, 'onnx_numpy_fct_'): 

177 raise AttributeError( 

178 "operator must have attribute 'onnx_numpy_fct_'.") 

179 X = operator.inputs 

180 if len(X) != 1: 

181 raise RuntimeError( 

182 f"This function only supports one input not {len(X)!r}.") 

183 if len(operator.outputs) != n_outputs: 

184 raise RuntimeError( 

185 "This function only supports %d output not %r." % ( 

186 n_outputs, len(operator.outputs))) 

187 

188 # First conversion of the model to onnx 

189 # Then addition of the onnx graph to the main graph. 

190 from .onnx_variable import OnnxVar 

191 new_var = Variable.from_skl2onnx(X[0]) 

192 xvar = OnnxVar(new_var) 

193 fct_cl = operator.onnx_numpy_fct_ 

194 

195 opv = container.target_opset 

196 logger.debug("_common_converter_begin:xvar=%r op=%s", 

197 xvar, type(operator.raw_operator)) 

198 inst = fct_cl.fct(xvar, op_=operator.raw_operator) 

199 logger.debug("_common_converter_begin:inst=%r opv=%r fct_cl.fct=%r", 

200 type(inst), opv, fct_cl.fct) 

201 onx = inst.to_algebra(op_version=opv) 

202 logger.debug("_common_converter_begin:end:onx=%r", type(onx)) 

203 return new_var, onx 

204 

205 

def _common_converter_t(scope, operator, container):
    """
    Shared converter for models with one input and one output
    (transformers, regressors). The graph returned by
    `_common_converter_begin` is exported into an ONNX model ending
    with an *Identity* node named after the expected output, and this
    model is copied into *container*.

    :param scope: scope
    :param operator: operator with an attribute *onnx_numpy_fct_*
    :param container: container
    """
    logger.debug("_common_converter_t:op=%r -> %r",
                 operator.inputs, operator.outputs)
    identity_cls = loadop('Identity')
    target = container.target_opset
    input_var, graph = _common_converter_begin(scope, operator, container, 1)
    result_name = operator.outputs[0].full_name
    last_node = identity_cls(graph, op_version=target,
                             output_names=[result_name])
    declared_outputs = [Variable.from_skl2onnx(o) for o in operator.outputs]
    onx_model = last_node.to_onnx(
        [input_var], declared_outputs, target_opset=target)
    _skl2onnx_add_to_container(onx_model, scope, container, operator.outputs)
    logger.debug("_common_converter_t:end")

219 

220 

def _converter_transformer(scope, operator, container):
    """
    Default converter for a transformer: one input and one output of
    the same type. *operator* must hold an attribute *onnx_numpy_fct_*
    coming from a function decorated with :func:`onnxsklearn_transformer
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_transformer>`.
    The work is delegated to `_common_converter_t`.

    .. versionadded:: 0.6
    """
    return _common_converter_t(scope, operator, container)

232 

233 

def _converter_regressor(scope, operator, container):
    """
    Default converter for a regressor: one input and one output of
    the same type. *operator* must hold an attribute *onnx_numpy_fct_*
    coming from a function decorated with :func:`onnxsklearn_regressor
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_regressor>`.
    The work is delegated to `_common_converter_t`.

    .. versionadded:: 0.6
    """
    return _common_converter_t(scope, operator, container)

245 

246 

def _common_converter_int_t(scope, operator, container):
    """
    Shared converter for models with one input and two outputs
    (classifiers, clusterings).

    `_common_converter_begin` returns either a tuple of graphs
    (:class:`OnnxOperatorTuple`), one per output, or a single graph.
    In the first case, each graph is followed by an *Identity* node
    named after the corresponding output. In the second case, only the
    first output is named that way. The exported ONNX model is then
    copied into *container*.

    :param scope: scope
    :param operator: operator with an attribute *onnx_numpy_fct_*
    :param container: container
    :raises RuntimeError: if the tuple size differs from the number
        of outputs
    :raises TypeError: if an element of the tuple has no method *add_to*
    """
    logger.debug("_common_converter_int_t:op=%r -> %r",
                 operator.inputs, operator.outputs)
    identity_cls = loadop('Identity')
    target = container.target_opset
    input_var, graph = _common_converter_begin(scope, operator, container, 2)

    export_kwargs = {}
    if isinstance(graph, OnnxOperatorTuple):
        n_expected, n_got = len(operator.outputs), len(graph)
        if n_expected != n_got:
            raise RuntimeError(  # pragma: no cover
                "Mismatched number of outputs expected %d, got %d." % (
                    n_expected, n_got))
        last_nodes = []
        for out, sub_graph in zip(operator.outputs, graph):
            if not hasattr(sub_graph, 'add_to'):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for onnx graph %r, inst=%r." % (
                        type(sub_graph), type(operator.raw_operator)))
            last_nodes.append(identity_cls(
                sub_graph, op_version=target, output_names=[out.full_name]))
        head = last_nodes[0]
        export_kwargs['other_outputs'] = last_nodes[1:]
        end_message = "_common_converter_int_t:1:end"
    else:
        head = identity_cls(graph, op_version=target,
                            output_names=[operator.outputs[0].full_name])
        end_message = "_common_converter_int_t:2:end"

    onx_model = head.to_onnx(
        [input_var],
        [Variable.from_skl2onnx(o) for o in operator.outputs],
        target_opset=target, **export_kwargs)
    _skl2onnx_add_to_container(
        onx_model, scope, container, operator.outputs)
    logger.debug(end_message)

290 

291 

def _converter_classifier(scope, operator, container):
    """
    Default converter for a classifier: one input and two outputs,
    the label and the probabilities, the latter with the input type.
    *operator* must hold an attribute *onnx_numpy_fct_* coming from a
    function decorated with :func:`onnxsklearn_classifier
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_classifier>`.
    The work is delegated to `_common_converter_int_t`.

    .. versionadded:: 0.6
    """
    return _common_converter_int_t(scope, operator, container)

304 

305 

def _converter_cluster(scope, operator, container):
    """
    Default converter for a clustering: one input and two outputs,
    the label and the distances, the latter with the input type.
    *operator* must hold an attribute *onnx_numpy_fct_* coming from a
    function decorated with :func:`onnxsklearn_cluster
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_cluster>`.
    The work is delegated to `_common_converter_int_t`.

    .. versionadded:: 0.6
    """
    return _common_converter_int_t(scope, operator, container)

318 

319 

# Default (shape calculator, converter) pair for every supported
# scikit-learn mixin. Iterating over the dictionary yields the mixins
# in the order they are listed here.
_default_cvt = dict([
    (ClassifierMixin, (_shape_calculator_classifier, _converter_classifier)),
    (ClusterMixin, (_shape_calculator_cluster, _converter_cluster)),
    (RegressorMixin, (_shape_calculator_regressor, _converter_regressor)),
    (TransformerMixin,
     (_shape_calculator_transformer, _converter_transformer)),
])

326 

327 

def update_registered_converter_npy(
        model, alias, convert_fct, shape_fct=None, overwrite=True,
        parser=None, options=None):
    """
    Registers or updates a converter for a new model so that
    it can be converted when inserted in a *scikit-learn* pipeline.
    This function assumes the converter is written as a function
    decorated with :func:`onnxsklearn_transformer
    <mlprodict.npy.onnx_sklearn_wrapper.onnxsklearn_transformer>`.

    The default shape calculator and converter are chosen from the kind
    of *model*. The mixins are tested in this order: transformer,
    regressor, classifier, clustering. The first one *model* derives
    from wins.

    :param model: model class
    :param alias: alias used to register the model
    :param convert_fct: function which converts a model, it must have
        an attribute *compiled* or *signed_compiled*
    :param shape_fct: must be None, custom shape calculators are not
        implemented yet
    :param overwrite: False to raise exception if a converter
        already exists
    :param parser: overwrites the parser as well if not empty
    :param options: registered options for this converter
    :raises AttributeError: if *convert_fct* has no attribute
        *compiled* or *signed_compiled*
    :raises NotImplementedError: if *shape_fct* is not None
    :raises KeyError: if *model* derives from none of the supported
        mixins

    The alias is usually the library name followed by the model name.

    .. versionadded:: 0.6
    """
    if not (hasattr(convert_fct, "compiled") or
            hasattr(convert_fct, 'signed_compiled')):
        raise AttributeError(  # pragma: no cover
            "Class %r must have attribute 'compiled' or 'signed_compiled' "
            "(object=%r)." % (type(convert_fct), convert_fct))
    # type is wrapper_onnxnumpy or wrapper_onnxnumpy_np
    obj = convert_fct

    def addattr(operator, obj):
        operator.onnx_numpy_fct_ = obj
        return operator

    defcl = None
    for mixin in (TransformerMixin, RegressorMixin,
                  ClassifierMixin, ClusterMixin):
        if issubclass(model, mixin):
            defcl = mixin
            break

    if shape_fct is not None:
        raise NotImplementedError(  # pragma: no cover
            "Custom shape calculator are not implemented yet.")

    if defcl is None:
        # KeyError: same exception type as the failing lookup
        # _default_cvt[None], with an explicit message
        raise KeyError(
            "Unable to find a default converter for class %r, it must "
            "derive from one of %s." % (
                model, ", ".join(c.__name__ for c in _default_cvt)))

    shc, cvtc = _default_cvt[defcl]

    def local_shape_fct(operator):
        return shc(addattr(operator, obj))

    def local_convert_fct(scope, operator, container):
        return cvtc(scope, addattr(operator, obj), container)

    from skl2onnx import update_registered_converter  # delayed
    update_registered_converter(
        model, alias, convert_fct=local_convert_fct,
        shape_fct=local_shape_fct, overwrite=overwrite,
        parser=parser, options=options)

396 

397 

def _internal_decorator(fct, op_version=None, runtime=None, signature=None,
                        register_class=None, overwrite=True, options=None):
    """
    Wraps *fct* into an instance of a new class deriving from
    `wrapper_onnxnumpy_np`. The new class is recorded in
    `_created_classes_inst`. If *register_class* is not None, the
    instance is also registered as the converter of that class with
    alias `'Sklearn' + register_class.__name__`.

    :param fct: function written with the numpy API for ONNX
    :param op_version: ONNX opset version
    :param runtime: runtime used to execute the ONNX function
    :param signature: signature of the ONNX function
    :param register_class: class to register the converter for, or None
    :param overwrite: overwrite an existing registered converter
    :param options: options registered with the converter
    :return: the wrapped function
    """
    cls_name = f"onnxsklearn_parser_{fct.__name__}_{op_version!s}_{runtime}"
    members = {
        '__doc__': fct.__doc__,
        '__name__': cls_name,
        '__getstate__': wrapper_onnxnumpy_np.__getstate__,
        '__setstate__': wrapper_onnxnumpy_np.__setstate__,
    }
    wrapper_cls = type(cls_name, (wrapper_onnxnumpy_np,), members)
    _created_classes_inst.append(cls_name, wrapper_cls)
    wrapped = wrapper_cls(
        fct=fct, op_version=op_version, runtime=runtime,
        signature=signature)
    if register_class is None:
        return wrapped
    alias = f"Sklearn{getattr(register_class, '__name__', 'noname')}"
    update_registered_converter_npy(
        register_class, alias, wrapped, shape_fct=None,
        overwrite=overwrite, options=options)
    return wrapped

416 

417 

def onnxsklearn_transformer(op_version=None, runtime=None, signature=None,
                            register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a transformer implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by a standard signature
        for transformer ``NDArraySameType("all")``
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`
    :param overwrite: overwrite existing registered function if any

    .. versionadded:: 0.6
    """
    sig = NDArraySameType("all") if signature is None else signature

    def decorator_fct(fct):
        return _internal_decorator(
            fct, op_version=op_version, runtime=runtime, signature=sig,
            register_class=register_class, overwrite=overwrite)

    return decorator_fct

445 

446 

def onnxsklearn_regressor(op_version=None, runtime=None, signature=None,
                          register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a regressor implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by the standard
        signature for regressors ``NDArraySameType("all")``
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`
    :param overwrite: overwrite existing registered function if any

    .. versionadded:: 0.6
    """
    sig = NDArraySameType("all") if signature is None else signature

    def decorator_fct(fct):
        return _internal_decorator(
            fct, op_version=op_version, runtime=runtime, signature=sig,
            register_class=register_class, overwrite=overwrite)

    return decorator_fct

474 

475 

def onnxsklearn_classifier(op_version=None, runtime=None, signature=None,
                           register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a classifier implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by the standard
        signature for classifiers
        ``NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))``:
        one input of any type *T*, a first output (labels) of type int64
        and a second output (probabilities) of type *T*
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`
    :param overwrite: overwrite existing registered function if any

    The converter is registered with options *zipmap*
    (`False`, `True`, `'columns'`) and *nocl* (`False`, `True`).

    .. versionadded:: 0.6
    """
    if signature is None:
        signature = NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))

    def decorator_fct(fct):
        return _internal_decorator(fct, signature=signature,
                                   op_version=op_version,
                                   runtime=runtime,
                                   register_class=register_class,
                                   overwrite=overwrite,
                                   options={'zipmap': [False, True, 'columns'],
                                            'nocl': [False, True]})
    return decorator_fct

505 

506 

def onnxsklearn_cluster(op_version=None, runtime=None, signature=None,
                        register_class=None, overwrite=True):
    """
    Decorator to declare a converter for a cluster implemented using
    :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by the standard
        signature for clusterings
        ``NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))``:
        one input of any type *T*, a first output (labels) of type int64
        and a second output (distances) of type *T*
    :param register_class: automatically register this converter
        for this class to :epkg:`sklearn-onnx`
    :param overwrite: overwrite existing registered function if any

    .. versionadded:: 0.6
    """
    if signature is None:
        signature = NDArrayType(("T:all", ), dtypes_out=((numpy.int64, ), 'T'))

    def decorator_fct(fct):
        return _internal_decorator(fct, signature=signature,
                                   op_version=op_version,
                                   runtime=runtime,
                                   register_class=register_class,
                                   overwrite=overwrite)
    return decorator_fct

534 

535 

536def _call_validate(self, X): 

537 if hasattr(self, "_validate_onnx_data"): 

538 return self._validate_onnx_data(X) 

539 return X 

540 

541 

def _resolve_default_api(register_class, signature, method_names, options):
    """
    Completes the arguments of `_internal_method_decorator` left to None
    with defaults depending on the kind of *register_class*.

    The mixins are tested in this order: transformer, regressor,
    classifier, clustering. The first one *register_class* derives from
    gives the defaults. Values which are not None are kept. Nothing is
    changed if *register_class* derives from none of them.

    :return: tuple `(signature, method_names, options)`
    """
    if issubclass(register_class, TransformerMixin):
        if signature is None:
            signature = NDArraySameType("all")
        if method_names is None:
            method_names = ("transform", )
    elif issubclass(register_class, RegressorMixin):
        if signature is None:
            signature = NDArraySameType("all")
        if method_names is None:
            method_names = ("predict", )
    elif issubclass(register_class, ClassifierMixin):
        if signature is None:
            signature = NDArrayType(
                ("T:all", ), dtypes_out=((numpy.int64, ), 'T'))
        if method_names is None:
            method_names = ("predict", "predict_proba")
        if options is None:
            options = {'zipmap': [False, True, 'columns'],
                       'nocl': [False, True]}
    elif issubclass(register_class, ClusterMixin):
        if signature is None:
            signature = NDArrayType(
                ("T:all", ), dtypes_out=((numpy.int64, ), 'T'))
        if method_names is None:
            method_names = ("predict", "transform")
    return signature, method_names, options


def _internal_method_decorator(register_class, method, op_version=None,
                               runtime=None, signature=None,
                               method_names=None, overwrite=True,
                               options=None):
    """
    Turns *method*, written with the numpy API for ONNX, into an ONNX
    function, adds the inference method(s) to *register_class* and
    registers the converter of the class to :epkg:`sklearn-onnx`.

    With one name in *method_names*, that method returns the result of
    the ONNX function. With several names, each method returns the
    output at the same position, and a method named
    `method.__name__ + "_"` returns all outputs.

    :param register_class: class to update and register
    :param method: function implementing the inference with the numpy
        API, it is called with the model as first argument
    :param op_version: ONNX opset version
    :param runtime: runtime used to execute the ONNX function
    :param signature: signature of the ONNX function, guessed from
        the class kind if None
    :param method_names: name or names of the methods to add to the
        class, guessed from the class kind if None
    :param overwrite: overwrite an existing registered converter
    :param options: options registered with the converter, classifiers
        get options *zipmap* and *nocl* if None
    :return: instance of the new class deriving from `wrapper_onnxnumpy_np`
    :raises RuntimeError: if the API cannot be guessed, if
        *method_names* is empty or contains a name which already exists
        in the class (or appears twice)
    """
    if isinstance(method_names, str):
        method_names = (method_names, )

    signature, method_names, options = _resolve_default_api(
        register_class, signature, method_names, options)

    if method_names is None:  # pragma: no cover
        # register_class derives from none of the known mixins
        raise RuntimeError(
            "No obvious API was detected (one among %s), "
            "then 'method_names' must be specified and not left "
            "empty." % (", ".join(map(lambda s: s.__name__, _default_cvt))))
    if signature is None:
        raise RuntimeError(  # pragma: no cover
            "Signature is not known for class %r and method %r, "
            "it must be specified." % (register_class, method))
    if len(method_names) == 0:
        raise RuntimeError("No available method.")  # pragma: no cover

    # Every name is checked before the class is modified so that
    # a failure does not leave the class partially updated.
    seen = set()
    for name in method_names:
        if name in seen or hasattr(register_class, name):
            raise RuntimeError(  # pragma: no cover
                "Cannot overwrite method %r because it already exists in "
                "class %r." % (name, register_class))
        seen.add(name)

    cls_name = (f"onnxsklearn_parser_{register_class.__name__}_"
                f"{str(op_version)}_{runtime}")
    newclass = type(
        cls_name, (wrapper_onnxnumpy_np,), {
            '__doc__': method.__doc__,
            '__name__': cls_name,
            '__getstate__': wrapper_onnxnumpy_np.__getstate__,
            '__setstate__': wrapper_onnxnumpy_np.__setstate__})
    _created_classes_inst.append(cls_name, newclass)

    def _check_(op):
        if isinstance(op, str):
            raise TypeError(  # pragma: no cover
                f"Unexpected type: {type(op)!r}: {op!r}.")
        return op

    res = newclass(
        fct=lambda *args, op_=None, **kwargs: method(
            _check_(op_), *args, **kwargs),
        op_version=op_version, runtime=runtime, signature=signature,
        fctsig=method)

    if len(method_names) == 1:
        setattr(register_class, method_names[0],
                lambda self, X: res(_call_validate(self, X), op_=self))
    else:
        setattr(register_class, method.__name__ + "_",
                lambda self, X: res(_call_validate(self, X), op_=self))
        for index, name in enumerate(method_names):
            # index is bound as a default value to avoid late binding
            setattr(register_class, name,
                    lambda self, X, index_output=index:
                    res(_call_validate(self, X), op_=self)[index_output])

    update_registered_converter_npy(
        register_class, f"Sklearn{getattr(register_class, '__name__', 'noname')}",
        res, shape_fct=None, overwrite=overwrite,
        options=options)
    return res

637 

638 

def onnxsklearn_class(method_name, op_version=None, runtime=None,
                      signature=None, method_names=None,
                      overwrite=True):
    """
    Decorator to declare a converter for a class derived from
    :epkg:`scikit-learn`, implementing an inference method
    with :epkg:`numpy` syntax but executed with
    :epkg:`ONNX` operators.

    :param method_name: name of the method implementing the
        inference with :epkg:`numpy` API for ONNX
    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: if None, the signature is replaced by a standard signature
        depending on the model kind, otherwise, it is the signature of the
        ONNX function
    :param method_names: if None, method names are guessed based on
        the class kind (transformer, regressor, classifier, clusterer)
    :param overwrite: overwrite existing registered function if any
    :return: a class decorator returning the decorated class itself

    .. versionadded:: 0.6
    """
    def decorator_class(objclass):
        onnx_method = getattr(objclass, method_name)
        _internal_method_decorator(
            objclass, method=onnx_method, op_version=op_version,
            runtime=runtime, signature=signature,
            method_names=method_names, overwrite=overwrite)
        return objclass

    return decorator_class