Coverage for mlprodict/npy/onnx_variable.py: 94%

417 statements  

1""" 

2@file 

3@brief Intermediate class between :epkg:`numpy` and :epkg:`onnx`. 

4 

5.. versionadded:: 0.6 

6""" 

7import logging 

8import numpy 

9from onnx.helper import make_tensor 

10from ..onnx_tools.onnx2py_helper import guess_proto_dtype 

11from .xop_variable import Variable 

12from .xop import loadop, OnnxOperatorItem, OnnxOperatorTuple 

13from .xop_variable import guess_numpy_type 

14 

15logger = logging.getLogger('xop') 

16 

17 

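# Fall back to the builtin types when the numpy scalar aliases
# numpy.bool_ or numpy.str_ are not available.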
try:
    numpy_bool = numpy.bool_
except AttributeError:  # pragma: no cover
    numpy_bool = bool
try:
    numpy_str = numpy.str_
except AttributeError:  # pragma: no cover
    numpy_str = str


class OnnxVar:
    """
    Variables used in :epkg:`onnx` computations.

    :param inputs: variable name or object
    :param op: :epkg:`ONNX` operator
    :param select_output: if multiple outputs are returned by
        ONNX operator *op*, it takes only the one specified by this
        argument
    :param dtype: specifies the type of the variable
        held by this class (*op* is None in that case)
    :param kwargs: additional arguments to give operator *op*

    .. versionadded:: 0.6
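
    A minimal usage sketch (the input name ``'X'`` is an illustrative
    choice, any graph input name works)::

        import numpy
        from mlprodict.npy.xop import loadop
        from mlprodict.npy.onnx_variable import OnnxVar

        OnnxAbs = loadop('Abs')
        x = OnnxVar('X', dtype=numpy.float32)
        y = OnnxVar(x, op=OnnxAbs) + x
        onx = y.to_algebra()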
    """
    __array_ufunc__ = None

    def __init__(self, *inputs, op=None, select_output=None,
                 dtype=None, **kwargs):
        logger.debug('OnnxVar(%d in, dtype=%r, op=%r, select_output=%r)',
                     len(inputs), dtype, op, select_output)
        self.inputs = inputs
        self.select_output = select_output
        self.onnx_op = op
        self.alg_ = None
        self.onnx_op_kwargs = kwargs
        if dtype is not None and (op is not None or len(inputs) != 1):
            raise RuntimeError(  # pragma: no cover
                "dtype can only be used if op is None and len(inputs) == 1.")
        for i, inp in enumerate(self.inputs):
            if isinstance(inp, type):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for input %d - %r." % (i, inp))
            if not isinstance(inp, numpy.ndarray):
                continue
            if (inp.size > 0 and
                    isinstance(inp.ravel()[0], (numpy.ndarray, OnnxVar))):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for input %d: %r, %r, "
                    "op=%r" % (i, type(inp), inp.ravel()[0], op))
        self.dtype = self._guess_dtype(dtype, from_init=True)

    def _guess_dtype(self, dtype, from_init=False):
        "Guesses dtype when not specified."
        if dtype is not None:
            return dtype
        dtypes = []
        for i, inp in enumerate(self.inputs):
            if isinstance(inp, str):
                return None
            if isinstance(inp, numpy.ndarray):
                dtypes.append(inp.dtype)
            elif isinstance(inp, Variable):
                dtypes.append(inp.dtype)
            elif isinstance(inp, OnnxVar):
                dtypes.append(inp.dtype)
            elif isinstance(inp, MultiOnnxVar):
                dtypes.append(inp._guess_dtype(dtype))
            elif isinstance(inp, (numpy.float32, numpy.float64,
                                  numpy.int32, numpy.int64)):
                dtypes.append(inp.dtype)
            elif isinstance(inp, numpy_str):
                dtypes.append(numpy_str)
            elif isinstance(inp, numpy_bool):
                dtypes.append(numpy_bool)
            elif isinstance(inp, int):
                dtypes.append(numpy.int64)  # pragma: no cover
            elif isinstance(inp, float):
                dtypes.append(numpy.float64)
            elif hasattr(inp, 'fit'):
                # scikit-learn model
                continue
            elif hasattr(inp, '_guess_dtype'):
                dtypes.append(inp._guess_dtype(dtype))
            else:
                try:
                    dtype = guess_numpy_type(inp)
                except NotImplementedError as e:  # pragma: no cover
                    raise TypeError(
                        "Unexpected type for input %i type=%r." % (
                            i, type(inp))) from e
                dtypes.append(dtype)
        dtypes = [_ for _ in dtypes if _ is not None]
        unique = set(dtypes)
        if len(unique) != 1:
            return None
        return dtypes[0]

    def __repr__(self):
        "usual"
        args = []
        for inp in self.inputs:
            args.append(repr(inp))
        if self.onnx_op is not None:
            if isinstance(self.onnx_op, str):
                args.append(f"op={self.onnx_op!r}")
            else:
                args.append(f"op={self.onnx_op.__name__}")
        if self.select_output is not None:
            args.append(f"select_output={self.select_output!r}")
        if self.dtype is not None and self.dtype != self._guess_dtype(None):
            args.append(f"dtype={self.dtype!r}")
        for k, v in sorted(self.onnx_op_kwargs.items()):
            args.append(f"{k}={v!r}")
        res = f"{self.__class__.__name__}({', '.join(args)})"
        return res

    def set_onnx_name(self, name_type):
        """
        Forces this variable to get this name during the conversion
        to :epkg:`ONNX`.

        :param name_type: a tuple *(name, type)*
        """
        self.onnx_input_type_ = name_type

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.
        """
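        # The conversion is cached in *alg_*: the subgraph is built only
        # once even if the method is called several times.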
        if self.alg_ is not None:
            return self.alg_

        if self.onnx_op is None:
            logger.debug('OnnxVar.to_algebra:1(op_version=%r)', op_version)
            if len(self.inputs) != 1:
                raise RuntimeError(  # pragma: no cover
                    "Unexpected number of inputs, 1 expected, "
                    "got {} instead.".format(self.inputs))
            if self.dtype is None or hasattr(self.inputs[0], 'onnx_name'):
                self.alg_ = Variable.from_skl2onnx(self.inputs[0])
            elif isinstance(self.inputs[0], Variable):
                self.alg_ = self.inputs[0]
            else:
                self.alg_ = Variable(self.inputs[0], self.dtype)
        else:
            logger.debug('OnnxVar.to_algebra:2(op_version=%r) - onnx_op=%r',
                         op_version, self.onnx_op)
            if isinstance(self.onnx_op, str):
                var = self._custom_op(*self.inputs, op_version=op_version,
                                      **self.onnx_op_kwargs)
                alg = var.to_algebra(op_version=op_version)
                if not hasattr(self, 'alg_'):
                    raise RuntimeError(  # pragma: no cover
                        "Missing attribute 'alg_'.")
                self.alg_ = alg
                return alg

            new_inputs = []
            for inp in self.inputs:
                if hasattr(inp, 'fit'):
                    # scikit-learn model
                    new_inputs.append(inp)
                elif isinstance(inp, (
                        int, float, str, numpy.ndarray, numpy.int32,
                        numpy.int64, numpy.float32, numpy.float64,
                        numpy_bool, numpy_str, numpy.int8, numpy.uint8,
                        numpy.int16, numpy.uint16, numpy.uint32,
                        numpy.uint64)):
                    if (isinstance(inp, numpy.ndarray) and inp.size > 0 and
                            isinstance(
                                inp.ravel()[0],  # pylint: disable=E1101
                                (numpy.ndarray, OnnxVar))):
                        raise TypeError(  # pragma: no cover
                            "Unexpected type for an input %r, %r."
                            "" % (type(inp), inp.ravel()[0]))  # pylint: disable=E1101
                    new_inputs.append(inp)
                else:
                    new_inputs.append(
                        inp.to_algebra(op_version=op_version))

            res = self.onnx_op(*new_inputs, op_version=op_version,
                               **self.onnx_op_kwargs)
            if self.select_output is None:
                self.alg_ = res
            else:
                self.alg_ = res[self.select_output]
        return self.alg_

    def _custom_op(self, *args, op_version=None, runtime=None, **kwargs):
        """
        This could be handled before a call to this method
        but this method can change the conversion of a non-existing
        operator depending on the given opset.
        """
        if self.onnx_op == 'filter':
            return self._custom_op_filter(*args, op_version=op_version,
                                          runtime=runtime, **kwargs)
        raise NotImplementedError(  # pragma: no cover
            f"Unexpected custom operator {self.onnx_op!r}.")

    def _custom_op_filter(self, *args, op_version=None, runtime=None,
                          **kwargs):
        """
        This could be handled before a call to this method
        but this method can change the conversion of a non-existing
        operator depending on the given opset.
        """
        OnnxSqueeze, OnnxTopK, OnnxGather, OnnxReduceSum = loadop(
            'Squeeze', 'TopK', 'Gather', 'ReduceSum')
        if len(args) != 2:
            raise RuntimeError(  # pragma: no cover
                f"Custom op 'filter' expects two inputs not {len(args)!r}.")
        if len(kwargs) != 0:
            raise RuntimeError(  # pragma: no cover
                f"Custom op 'filter' expects no arguments but got {kwargs!r}.")
        mat, index = args
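        # Boolean-mask selection has no single ONNX operator, hence this
        # decomposition: the mask is cast to int64 and squeezed, ReduceSum
        # counts the selected rows (k), TopK returns the indices of the
        # k ones and Gather picks the matching rows of *mat*.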
        cast = OnnxVar(index.astype(numpy.int64), op=OnnxSqueeze)
        n1 = OnnxVar(cast, op=OnnxReduceSum, keepdims=1)
        indices = OnnxVar(cast, n1, op=OnnxTopK, select_output=1)
        return OnnxVar(mat, indices, op=OnnxGather)

    @property
    def T(self):
        "Transpose."
        OnnxTranspose = loadop('Transpose')
        return OnnxVar(self, op=OnnxTranspose)

    def astype(self, dtype):
        "Cast."
        OnnxCast = loadop('Cast')
        return OnnxVar(self, op=OnnxCast, to=guess_proto_dtype(dtype))

    @property
    def shape(self):
        "Shape."
        OnnxShape = loadop('Shape')
        return OnnxVar(self, op=OnnxShape)

    @property
    def size(self):
        "Size."
        OnnxSize = loadop('Size')
        return OnnxVar(self, op=OnnxSize)

    def reshape(self, shape):
        "Reshape."
        OnnxReshape = loadop('Reshape')
        if isinstance(shape, (tuple, list)):
            shape = numpy.array(shape, dtype=numpy.int64)
        return OnnxVar(self, shape, op=OnnxReshape)

    def _make_array(self, y):
        """Converts *y* into an array if it is not one already."""
        if isinstance(y, (numpy.ndarray, OnnxVar)):
            return y
        if hasattr(y, 'dtype'):
            return numpy.full((1, ), y, dtype=y.dtype)
        if isinstance(y, str):
            return numpy.array([y])
        if isinstance(y, float):
            return numpy.array([y], dtype=numpy.float32)
        if isinstance(y, int):
            return numpy.array([y], dtype=numpy.int64)
        return y

    def __add__(self, y):
        "Addition."
        y = self._make_array(y)
        OnnxAdd = loadop('Add')
        return OnnxVar(self, y, op=OnnxAdd)

    def __radd__(self, y):
        "Right addition."
        y = self._make_array(y)
        OnnxIdentity, OnnxAdd = loadop('Identity', 'Add')
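        # *y* is wrapped with an Identity node so that the first input of
        # the Add node is an OnnxVar and the operand order of the
        # expression is preserved.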
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxAdd)

    def __sub__(self, y):
        "Subtraction."
        y = self._make_array(y)
        OnnxSub = loadop('Sub')
        return OnnxVar(self, y, op=OnnxSub)

    def __rsub__(self, y):
        "Right subtraction."
        y = self._make_array(y)
        OnnxIdentity, OnnxSub = loadop('Identity', 'Sub')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxSub)

    def __mul__(self, y):
        "Multiplication."
        y = self._make_array(y)
        OnnxMul = loadop('Mul')
        return OnnxVar(self, y, op=OnnxMul)

    def __rmul__(self, y):
        "Right multiplication."
        y = self._make_array(y)
        OnnxIdentity = loadop('Identity')
        return OnnxVar(y, op=OnnxIdentity) * self

    def __pow__(self, y):
        "Power."
        y = self._make_array(y)
        OnnxPow = loadop('Pow')
        return OnnxVar(self, y, op=OnnxPow)

    def __mod__(self, y):
        "Modulo."
        y = self._make_array(y)
        OnnxMod = loadop('Mod')
        return OnnxVar(self, y, op=OnnxMod)

    def __matmul__(self, y):
        "Matrix multiplication."
        y = self._make_array(y)
        OnnxMatMul = loadop('MatMul')
        return OnnxVar(self, y, op=OnnxMatMul)

    def __truediv__(self, y):
        "Division, no difference between `/` and `//`."
        y = self._make_array(y)
        OnnxDiv = loadop('Div')
        return OnnxVar(self, y, op=OnnxDiv)

    def __rtruediv__(self, y):
        "Right division, no difference between `/` and `//`."
        y = self._make_array(y)
        OnnxIdentity, OnnxDiv = loadop('Identity', 'Div')
        return OnnxVar(OnnxVar(y, op=OnnxIdentity), self, op=OnnxDiv)

    def __floordiv__(self, y):
        "Division, no difference between `/` and `//`."
        y = self._make_array(y)
        OnnxDiv = loadop('Div')
        return OnnxVar(self, y, op=OnnxDiv)

    def __eq__(self, y):
        "Equality."
        y = self._make_array(y)
        OnnxEqual = loadop('Equal')
        return OnnxVar(self, y, op=OnnxEqual)

    def __ne__(self, y):
        "Inequality."
        y = self._make_array(y)
        OnnxEqual, OnnxNot = loadop('Equal', 'Not')
        return OnnxVar(OnnxVar(self, y, op=OnnxEqual), op=OnnxNot)

    def __ge__(self, y):
        "Greater or equal."
        y = self._make_array(y)
        OnnxGreaterOrEqual = loadop('GreaterOrEqual')
        return OnnxVar(self, y, op=OnnxGreaterOrEqual)

    def __gt__(self, y):
        "Greater."
        y = self._make_array(y)
        OnnxGreater = loadop('Greater')
        return OnnxVar(self, y, op=OnnxGreater)

    def __invert__(self):
        "Not."
        OnnxNot = loadop('Not')
        return OnnxVar(self, op=OnnxNot)

    def __le__(self, y):
        "Less or equal."
        y = self._make_array(y)
        OnnxLessOrEqual = loadop('LessOrEqual')
        return OnnxVar(self, y, op=OnnxLessOrEqual)

    def __lt__(self, y):
        "Less."
        y = self._make_array(y)
        OnnxLess = loadop('Less')
        return OnnxVar(self, y, op=OnnxLess)

    def __and__(self, y):
        "And."
        y = self._make_array(y)
        OnnxAnd = loadop('And')
        return OnnxVar(self, y, op=OnnxAnd)

    def __or__(self, y):
        "Or."
        y = self._make_array(y)
        OnnxOr = loadop('Or')
        return OnnxVar(self, y, op=OnnxOr)

    def not_(self):
        "Not."
        OnnxNot = loadop('Not')
        return OnnxVar(self, op=OnnxNot)

    def __neg__(self):
        "Neg."
        OnnxNeg = loadop('Neg')
        return OnnxVar(self, op=OnnxNeg)

    def __getitem__(self, index):
        """
        Deals with multiple scenarios.

        * *index* is an integer, a slice, or a tuple of integers and
          slices, example: `[0, 1]`, `[:5, :6]`, `[::2]` (**scenario 1**)
        * *index* is an *ONNX* object (more precisely an instance of
          @see cl OnnxVar), then the method assumes it is an array of
          booleans used to select a subset of the tensor along the first
          axis, example: `mat[mat == 0]` (**scenario 2**)
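
        A sketch of both scenarios, assuming ``x`` is an @see cl OnnxVar::

            sub = x[1:4]    # scenario 1, translated into a Slice node
            pos = x[x > 2]  # scenario 2, custom 'filter' conversion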
        """
        if isinstance(index, OnnxVar):
            # scenario 2
            return OnnxVar(self, index, op='filter')

        if isinstance(index, int):
            # Use Gather instead.
            OnnxGather = loadop('Gather')
            return OnnxVar(
                self, numpy.array(index, dtype=numpy.int64),
                axis=0, op=OnnxGather)

        if not isinstance(index, tuple):
            index = (index, )

        # only one integer?
        ni = None
        ax = None
        for i, a in enumerate(index):
            if isinstance(a, int):
                if ni is None:
                    ni = i
                    ax = a
                else:
                    ax = None
                    ni = None
                    break
            if (isinstance(a, slice) and a.start is None and
                    a.stop is None and a.step is None):
                continue
            ax = None
            ni = None
            break
        if ni is not None and ax is not None:
            # Use Gather instead.
            OnnxGather = loadop('Gather')
            return OnnxVar(
                self, numpy.array(ni, dtype=numpy.int64),
                axis=ax, op=OnnxGather)

        # scenario 1
        starts = []
        ends = []
        axes = []
        steps = []
        axis_squeeze = []
        needs_shape = []
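        # *needs_shape* records the bounds only known at runtime: a slice
        # without an explicit stop takes its end from the Shape of the
        # input tensor.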
        for i, ind in enumerate(index):
            if isinstance(ind, int):
                starts.append(ind)
                ends.append(ind + 1)
                axes.append(i)
                steps.append(1)
                axis_squeeze.append(i)
                continue
            if isinstance(ind, slice):
                if ind.start is None and ind.stop is None and ind.step is None:
                    continue
                start = 0 if ind.start is None else ind.start
                end = (None, i) if ind.stop is None else ind.stop
                step = 1 if ind.step is None else ind.step
                starts.append(start)
                ends.append(end)
                axes.append(i)
                steps.append(step)
                if isinstance(end, tuple):
                    needs_shape.append(len(ends) - 1)
                elif isinstance(end, OnnxVar):
                    needs_shape.append(end)
                continue
            raise NotImplementedError(  # pragma: no cover
                f"Not implemented for type {type(ind)!r}.")

        if max(steps) == min(steps) == 1:
            steps = None
        else:
            steps = numpy.array(steps, dtype=numpy.int64)

        starts = numpy.array(starts, dtype=numpy.int64)
        axes = numpy.array(axes, dtype=numpy.int64)

        OnnxGather, OnnxSlice, OnnxSqueeze, OnnxConcat = loadop(
            'Gather', 'Slice', 'Squeeze', 'Concat')
        if len(needs_shape) > 0:
            shape = self.shape
            conc = []
            for e in ends:
                if isinstance(e, tuple):
                    conc.append(
                        OnnxVar(shape, numpy.array([e[1]], numpy.int64),
                                op=OnnxGather))
                elif isinstance(e, OnnxVar):
                    conc.append(
                        e.reshape(numpy.array([-1], dtype=numpy.int64)))
                else:
                    conc.append(numpy.array([e], dtype=numpy.int64))
            if len(conc) > 1:
                ends = OnnxVar(*conc, op=OnnxConcat, axis=0)
            else:
                ends = conc[0]
        else:
            ends = numpy.array(ends, dtype=numpy.int64)

        if steps is None:
            sliced = OnnxVar(self, starts, ends, axes, op=OnnxSlice)
        else:
            sliced = OnnxVar(self, starts, ends, axes, steps, op=OnnxSlice)
        if len(axis_squeeze) > 0:
            return OnnxVar(
                sliced, numpy.array(axis_squeeze, dtype=numpy.int64),
                op=OnnxSqueeze)
        return sliced

    def __setitem__(self, index, value):
        """
        Only supports vectors (1D tensors).

        * *index* is an integer or a slice, a tuple of integers and slices,
          example: `[0]`, `[:5]`, `[::2]` (**scenario 1**)
        * *index* is an *ONNX* object (more precisely an instance of
          @see cl OnnxVar), then the method assumes it is an array of
          booleans used to select a subset of the tensor along the first
          axis, example: `mat[mat == 0]` (**scenario 2**)

        This processing is applied before the operator it contains.
        A copy should be made first (Identity node or method *copy*).
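
        A sketch of scenario 2, assuming ``x`` is an @see cl OnnxVar::

            cp = x.copy()   # mandatory copy before any assignment
            cp[x < 0] = -1  # translated into a Where node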
552 """ 

553 OnnxIdentity = loadop('Identity') 

554 if self.onnx_op is not None and self.onnx_op is not OnnxIdentity: 

555 raise RuntimeError( # pragma: no cover 

556 "A copy should be made before setting new values on a matrix. " 

557 "Method copy() would do that.") 

558 

559 if isinstance(index, OnnxVar): 

560 # scenario 2, example: cp[x < 0] = -1 

561 return self._setitem2i_(index, value) 

562 elif not isinstance(index, tuple): 

563 index = (index, ) 

564 

565 for i in index: 

566 if isinstance(i, OnnxVar): 

567 raise NotImplementedError( # pragma: no cover 

568 "Unable to handle case such as cp[0, x < 0] = -1.") 

569 

570 # scenario 1 

571 if len(index) == 1: 

572 return self._setitem1i_(index[0], value) 

573 raise NotImplementedError( # pragma: no cover 

574 f"Indices in {len(index)} dimensions are not implemented yet.") 


    def _setitem1i_(self, index, value):
        sl = None
        if isinstance(index, slice):
            start = 0 if index.start is None else index.start
            stop = index.stop
            step = index.step
            sl = [start, stop, step]
        elif isinstance(index, int):
            sl = [index, index + 1, 1]
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unable to assign new values due to unexpected type {type(index)!r}.")

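        # If the stop is unknown but an array is assigned, its size gives
        # the length to overwrite; if the stop remains unknown, the tensor
        # is rebuilt with ConstantOfShape (filled with the new value) and
        # the preserved prefix is written back with ScatterElements.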
        if sl[1] is None and isinstance(value, numpy.ndarray):
            sl[1] = sl[0] + value.size
        OnnxConstantOfShape, OnnxScatterElements = loadop(
            'ConstantOfShape', 'ScatterElements')
        if sl[1] is None:
            if sl[2] is not None and sl[2] != 1:
                raise NotImplementedError(  # pragma: no cover
                    "If the length is not known, step must be 1 not %d." % sl[2])
            value = make_tensor(
                "value", guess_proto_dtype(value.dtype), (1, ), [value])  # pylint: disable=E1101
            inp = self.inputs[0]
            if not isinstance(inp, OnnxVar):
                raise RuntimeError(  # pragma: no cover
                    f"Input must be an instance of OnnxVar not {type(inp)!r}.")
            cst = OnnxVar(inp.shape, op=OnnxConstantOfShape, value=value)
            ext = inp[:sl[0]]
            indices = numpy.arange(0, sl[0]).astype(numpy.int64)
            add_step = OnnxVar(cst, indices, ext,
                               op=OnnxScatterElements, axis=0)
        else:
            indices = numpy.arange(sl[0], sl[1], sl[2]).astype(numpy.int64)
            if isinstance(value, numpy.ndarray):
                values = value
            else:
                values = numpy.full(indices.shape, value)
            add_step = OnnxVar(self.inputs[0], indices, values,
                               op=OnnxScatterElements, axis=0)

        self.inputs = [add_step]
        return self

    def _setitem2i_(self, index, value):
        OnnxWhere = loadop('Where')
        add_step = OnnxVar(index, value, self.inputs[0], op=OnnxWhere)
        self.inputs = [add_step]
        return self

    def copy(self):
        """
        Returns a copy of self (use of Identity node).
        """
        OnnxIdentity = loadop('Identity')
        return OnnxVar(self, op=OnnxIdentity)

    def flatten(self, axis=0):
        """
        Flattens a matrix (see :epkg:`numpy:ndarray:flatten`).

        :param axis: only flatten from axis to the end.
        :return: @see cl OnnxVar.
        """
        OnnxFlatten, OnnxSqueeze = loadop('Flatten', 'Squeeze')
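        # ONNX Flatten always produces a 2D tensor; with axis=0 the leading
        # dimension is 1 and is removed with Squeeze so that the result is
        # one-dimensional, as numpy.ndarray.flatten would return.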
        fl = OnnxVar(self, op=OnnxFlatten, axis=axis)
        if axis == 0:
            return OnnxVar(fl, numpy.array([0], dtype=numpy.int64),
                           op=OnnxSqueeze)
        return fl


class MultiOnnxVar:
    """
    Class used to return multiple @see cl OnnxVar
    at the same time.
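
    A sketch with an operator returning two results (the choice of
    *TopK* and of the constant ``k`` is illustrative)::

        import numpy
        from mlprodict.npy.xop import loadop

        OnnxTopK = loadop('TopK')
        x = OnnxVar('X', dtype=numpy.float32)
        k = numpy.array([2], dtype=numpy.int64)
        res = MultiOnnxVar(x, k, op=OnnxTopK)
        values, indices = res[0], res[1]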
    """

    def __init__(self, *inputs, op=None, dtype=None, **kwargs):
        "constructor"
        logger.debug('MultiOnnxVar(%d in, dtype=%r, op=%r)',
                     len(inputs), dtype, op)
        self.onxvar = OnnxVar(*inputs, op=op, dtype=None, **kwargs)
        self.alg_ = None

    def _guess_dtype(self, dtype):
        "Guesses dtype when not specified."
        return self.onxvar._guess_dtype(dtype)

    @property
    def inputs(self):
        "Returns `self.onxvar.inputs`."
        return self.onxvar.inputs

    @property
    def onnx_op(self):
        "Returns `self.onxvar.onnx_op`."
        return self.onxvar.onnx_op

    @property
    def onnx_op_kwargs(self):
        "Returns `self.onxvar.onnx_op_kwargs`."
        return self.onxvar.onnx_op_kwargs

    def to_algebra(self, op_version=None):
        """
        Converts the variable into an operator.
        """
        if self.alg_ is None:
            logger.debug('MultiOnnxVar.to_algebra(op_version=%r)',
                         op_version)
            new_inputs = []
            for inp in self.inputs:
                if isinstance(inp, (
                        int, float, str, numpy.ndarray, numpy.int32,
                        numpy.int64, numpy.float32, numpy.float64,
                        numpy_bool, numpy_str, numpy.int8, numpy.uint8,
                        numpy.int16, numpy.uint16, numpy.uint32,
                        numpy.uint64)):
                    new_inputs.append(inp)
                elif hasattr(inp, 'fit'):
                    # scikit-learn models
                    new_inputs.append(inp)
                else:
                    new_inputs.append(
                        inp.to_algebra(op_version=op_version))

            if self.onnx_op is None:
                if len(new_inputs) == 1:
                    logger.debug('MultiOnnxVar.to_algebra:1:new_inputs[0]=%r',
                                 new_inputs[0])
                    self.alg_ = OnnxOperatorTuple(new_inputs[0])
                else:
                    logger.debug('MultiOnnxVar.to_algebra:2:new_inputs=%r',
                                 new_inputs)
                    self.alg_ = OnnxOperatorTuple(
                        new_inputs[0], *(new_inputs[1:]))
            else:
                logger.debug('MultiOnnxVar.to_algebra:%s:new_inputs=%r',
                             self.onnx_op.__class__.__name__, new_inputs)
                res = self.onnx_op(  # pylint: disable=E1102
                    *new_inputs, op_version=op_version, **self.onnx_op_kwargs)
                self.alg_ = OnnxOperatorTuple(res)
        return self.alg_

    def __getitem__(self, index):
        """
        Returns the i-th element.
        """
        return OnnxVar(self, index=index, op=OnnxOperatorItem)