Coverage for mlprodict/npy/numpy_onnx_impl.py: 98%

255 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief :epkg:`numpy` functions implemented with :epkg:`onnx`. 

4 

5.. versionadded:: 0.6 

6 

7.. versionchanged:: 0.7 

8""" 

9import warnings 

10import numpy 

11from onnx import onnx_pb as onnx_proto # pylint: disable=E1101 

12from onnx.helper import make_tensor 

13from .onnx_variable import OnnxVar, MultiOnnxVar as xtuple 

14from .xop import loadop 

15from .numpy_onnx_impl_body import if_then_else, OnnxVarGraph 

16 

17 

def abs(x):
    """
    Builds an ONNX ``Abs`` node applied to *x*, see :func:`numpy.abs`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Abs'))

22 

23 

def acos(x):
    """
    Builds an ONNX ``Acos`` node applied to *x*,
    see :func:`numpy.arccos`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    OnnxAcos = loadop('Acos')
    return OnnxVar(x, op=OnnxAcos)

28 

29 

def acosh(x):
    """
    Builds an ONNX ``Acosh`` node applied to *x*,
    see :func:`numpy.arccosh`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    OnnxAcosh = loadop('Acosh')
    return OnnxVar(x, op=OnnxAcosh)

34 

35 

def amax(x, axis=None, keepdims=0):
    """
    See :func:`numpy.amax`, implemented with ONNX ``ReduceMax``.

    :param x: input variable
    :param axis: None (the ``axes`` attribute is omitted), an integer,
        or a list/tuple of integers
    :param keepdims: keep the reduced dimensions (1) or not (0)
    :return: @see cl OnnxVar
    """
    OnnxReduceMax = loadop('ReduceMax')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceMax, keepdims=keepdims)
    # numpy accepts a tuple of axes; the ONNX attribute must be a flat
    # list of integers, so a tuple must not be wrapped into a list.
    axes = list(axis) if isinstance(axis, (list, tuple)) else [axis]
    return OnnxVar(x, op=OnnxReduceMax, keepdims=keepdims, axes=axes)

44 

45 

def amin(x, axis=None, keepdims=0):
    """
    See :func:`numpy.amin`, implemented with ONNX ``ReduceMin``.

    :param x: input variable
    :param axis: None (the ``axes`` attribute is omitted), an integer,
        or a list/tuple of integers
    :param keepdims: keep the reduced dimensions (1) or not (0)
    :return: @see cl OnnxVar
    """
    OnnxReduceMin = loadop('ReduceMin')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceMin, keepdims=keepdims)
    # numpy accepts a tuple of axes; the ONNX attribute must be a flat
    # list of integers, so a tuple must not be wrapped into a list.
    axes = list(axis) if isinstance(axis, (list, tuple)) else [axis]
    return OnnxVar(x, op=OnnxReduceMin, keepdims=keepdims, axes=axes)

54 

55 

def arange(start, stop, step=1):
    """
    See :func:`numpy.arange`, *start*, *stop* must be specified.

    The sequence is built as ``CumSum(ConstantOfShape(n, value=step))``
    and then shifted so that its first element equals *start*, with
    ``n = ceil((stop - start) / step)`` elements as in numpy.

    :param start: first value, an integer or @see cl OnnxVar
    :param stop: excluded upper bound, an integer or @see cl OnnxVar
    :param step: integer increment
    :return: @see cl OnnxVar
    :raises TypeError: if *step* is not an integer
    """
    int_types = (int, numpy.int64, numpy.int32)
    if not isinstance(step, int_types):
        raise TypeError(  # pragma: no cover
            f"step must be an integer not {type(step)!r}.")
    if isinstance(start, int_types):
        start = numpy.array([start], dtype=numpy.int64)
        zero = start == 0
    else:
        zero = False
    if isinstance(stop, int_types):
        stop = numpy.array([stop], dtype=numpy.int64)
    value = make_tensor(
        "value", onnx_proto.TensorProto.INT64, (1, ), [step])  # pylint: disable=E1101

    OnnxAdd, OnnxCumSum, OnnxConstantOfShape, OnnxSub = loadop(
        'Add', 'CumSum', 'ConstantOfShape', 'Sub')
    if step == 1:
        if zero:
            shape = stop
        else:
            shape = stop - start
        if isinstance(shape, OnnxVar):
            shape = shape.reshape(numpy.array([-1], dtype=numpy.int64))
        _cst = OnnxVar(shape, op=OnnxConstantOfShape, value=value)
        cs = OnnxVar(_cst, numpy.array([0], dtype=numpy.int64),
                     op=OnnxCumSum)
        diff = start - numpy.array([step], dtype=numpy.int64)
        return OnnxVar(cs, diff, op=OnnxAdd)

    # numpy.arange returns ceil((stop - start) / step) elements. A plain
    # floor division drops the last element whenever step does not divide
    # the range, e.g. arange(0, 5, 2) must be [0, 2, 4], not [0, 2].
    # Adding step - 1 (step + 1 for a negative step) before dividing turns
    # the floor division into a ceiling division.
    adjust = numpy.array([step - 1 if step > 0 else step + 1],
                         dtype=numpy.int64)
    step = numpy.array([step], dtype=numpy.int64)
    if zero:
        shape = (stop + adjust) // step
    else:
        shape = (stop - start + adjust) // step
    if isinstance(shape, OnnxVar):
        shape = shape.reshape(numpy.array([-1], dtype=numpy.int64))
    _cst = OnnxVar(shape, op=OnnxConstantOfShape, value=value)

    # [step, 2 * step, ...] + start - step == [start, start + step, ...]
    cs = OnnxVar(_cst, numpy.array([0], dtype=numpy.int64),
                 op=OnnxCumSum)
    add = OnnxVar(cs, start, op=OnnxAdd)
    return OnnxVar(add, step, op=OnnxSub)

104 

105 

def argmax(x, axis=0, keepdims=0):
    """
    See :func:`numpy.argmax`, implemented with ONNX ``ArgMax``.

    :param x: input variable
    :param axis: axis along which the index of the maximum is searched
    :param keepdims: keep the reduced dimension (1) or not (0)
    :return: @see cl OnnxVar

    .. warning::
        ONNX does not implement default value axis=None.
    """
    if axis is not None:
        return OnnxVar(x, op=loadop('ArgMax'), axis=axis,
                       keepdims=keepdims)
    raise NotImplementedError(  # pragma: no cover
        "ONNX does not allow axis=None.")

118 

119 

def argmin(x, axis=0, keepdims=0):
    """
    See :func:`numpy.argmin`, implemented with ONNX ``ArgMin``.

    :param x: input variable
    :param axis: axis along which the index of the minimum is searched
    :param keepdims: keep the reduced dimension (1) or not (0)
    :return: @see cl OnnxVar

    .. warning::
        ONNX does not implement default value axis=None.
    """
    if axis is not None:
        return OnnxVar(x, op=loadop('ArgMin'), axis=axis,
                       keepdims=keepdims)
    raise NotImplementedError(  # pragma: no cover
        "ONNX does not allow axis=None.")

132 

133 

def asin(x):
    """
    Builds an ONNX ``Asin`` node applied to *x*,
    see :func:`numpy.arcsin`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    OnnxAsin = loadop('Asin')
    return OnnxVar(x, op=OnnxAsin)

138 

139 

def asinh(x):
    """
    Builds an ONNX ``Asinh`` node applied to *x*,
    see :func:`numpy.arcsinh`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    OnnxAsinh = loadop('Asinh')
    return OnnxVar(x, op=OnnxAsinh)

144 

145 

def atan(x):
    """
    Builds an ONNX ``Atan`` node applied to *x*,
    see :func:`numpy.arctan`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    OnnxAtan = loadop('Atan')
    return OnnxVar(x, op=OnnxAtan)

150 

151 

def atanh(x):
    """
    Builds an ONNX ``Atanh`` node applied to *x*,
    see :func:`numpy.arctanh`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    OnnxAtanh = loadop('Atanh')
    return OnnxVar(x, op=OnnxAtanh)

156 

157 

def ceil(x):
    """
    Builds an ONNX ``Ceil`` node applied to *x*, see :func:`numpy.ceil`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Ceil'))

162 

163 

def clip(x, a_min=None, a_max=None):
    """
    See :func:`numpy.clip`, implemented with ONNX ``Clip``.

    ONNX ``Clip`` reads its optional bounds positionally
    (input, min, max). An upper bound given alone would therefore be
    taken as the lower bound, so that case is built with ``Min`` instead.

    :param x: input variable
    :param a_min: lower bound or None
    :param a_max: upper bound or None
    :return: @see cl OnnxVar
    """
    if a_min is None and a_max is not None:
        # clip(x, None, a_max) == minimum(x, a_max)
        OnnxMin = loadop('Min')
        return OnnxVar(x, a_max, op=OnnxMin)
    args = [x]
    if a_min is not None:
        args.append(a_min)
    if a_max is not None:
        args.append(a_max)
    OnnxClip = loadop('Clip')
    return OnnxVar(*args, op=OnnxClip)

173 

174 

def compress(condition, x, axis=None):
    """
    See :func:`numpy.compress`, implemented with ONNX ``Compress``.

    Arguments follow numpy's order ``compress(condition, x)``, while the
    ONNX node receives them as ``(x, condition)``.

    :param condition: selector applied to *x*
    :param x: input variable
    :param axis: axis to select along; None omits the ``axis`` attribute
    :return: @see cl OnnxVar
    """
    extra = {} if axis is None else {'axis': axis}
    return OnnxVar(x, condition, op=loadop('Compress'), **extra)

184 

185 

def cos(x):
    """
    Builds an ONNX ``Cos`` node applied to *x*, see :func:`numpy.cos`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Cos'))

190 

191 

def cosh(x):
    """
    Builds an ONNX ``Cosh`` node applied to *x*, see :func:`numpy.cosh`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Cosh'))

196 

197 

def concat(*x, axis=0):
    """
    Concatenates all inputs along *axis* with ONNX ``Concat``; it is the
    building block behind :func:`numpy.vstack` and :func:`numpy.hstack`.

    :param x: at least two input variables
    :param axis: concatenation axis
    :return: @see cl OnnxVar
    :raises RuntimeError: if fewer than two inputs are given
    """
    OnnxConcat = loadop('Concat')
    n_inputs = len(x)
    if n_inputs > 1:
        return OnnxVar(*x, op=OnnxConcat, axis=axis)
    raise RuntimeError(  # pragma: no cover
        f"N={n_inputs}<=1 elements to concatenate.")

208 

209 

def cumsum(x, axis):
    """
    See :func:`numpy.cumsum`, implemented with ONNX ``CumSum``.

    :param x: input variable
    :param axis: axis along which values are accumulated, either an
        integer (numpy convention) or a tensor (numpy array or
        @see cl OnnxVar)
    :return: @see cl OnnxVar
    """
    OnnxCumSum = loadop('CumSum')
    if isinstance(axis, (int, numpy.integer)):
        # ONNX CumSum receives the axis as an input tensor, not as an
        # attribute, so a Python/numpy integer must be turned into one.
        axis = numpy.array([axis], dtype=numpy.int64)
    return OnnxVar(x, axis, op=OnnxCumSum)

214 

215 

def cst(x, dtype=None):
    """
    Wraps a value into a constant @see cl OnnxVar (an ONNX ``Identity``
    node). `log(x) + numpy.float32(1)` works but
    `numpy.float32(32) + log(x)` fails because Python calls
    `numpy.float32.__add__` instead of `OnnxVar.__add__`; writing
    `cst(1.) + log(x)` avoids the issue.

    :param x: float, int, numpy array, or any object with a ``dtype``
        attribute (such as a numpy scalar)
    :param dtype: overrides the default dtype (`numpy.float32` for floats,
        `numpy.int64` for ints); ignored for numpy arrays
    :return: @see cl OnnxVar
    :raises RuntimeError: if *dtype* is given together with an object
        handled through its ``dtype`` attribute
    :raises NotImplementedError: for any other type

    .. note::
        ``numpy.float64`` is a subclass of ``float`` and therefore goes
        through the float branch (``float32`` unless *dtype* is given).
    """
    OnnxIdentity = loadop('Identity')
    if isinstance(x, float):
        value = numpy.array([x], dtype=dtype or numpy.float32)
    elif isinstance(x, int):
        value = numpy.array([x], dtype=dtype or numpy.int64)
    elif isinstance(x, numpy.ndarray):
        value = x
    elif hasattr(x, 'dtype'):
        if dtype is not None:
            raise RuntimeError(  # pragma: no cover
                f"dtype is not used because x is of type {type(x)!r}.")
        value = numpy.array([x], dtype=x.dtype)
    else:
        raise NotImplementedError(  # pragma: no cover
            f"Unable to convert type {type(x)!r} into a constant.")
    return OnnxVar(value, op=OnnxIdentity)

243 

244 

def det(x):
    """
    Builds an ONNX ``Det`` node applied to *x*,
    see :func:`numpy.linalg.det`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    OnnxDet = loadop('Det')
    return OnnxVar(x, op=OnnxDet)

249 

250 

def dot(a, b):
    """
    See :func:`numpy.dot`.

    Implemented with ONNX ``MatMul``, so it behaves like
    :func:`numpy.matmul`. The two functions agree on 1-D and 2-D inputs
    but differ when an input has more than two dimensions (and for
    scalars), which is why every call emits a warning.

    :param a: first input variable
    :param b: second input variable
    :return: @see cl OnnxVar
    """
    warnings.warn(
        "npnx.dot is equivalent to npnx.matmul == numpy.matmul "
        "!= numpy.dot with arrays with more than 2 dimensions.",
        stacklevel=2)
    OnnxMatMul = loadop('MatMul')
    return OnnxVar(a, b, op=OnnxMatMul)

258 

259 

def matmul(a, b):
    """
    Builds an ONNX ``MatMul`` node, see :func:`numpy.matmul`.

    :param a: first input variable
    :param b: second input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(a, b, op=loadop('MatMul'))

264 

265 

def einsum(*x, equation=None):
    """
    See :func:`numpy.einsum`, implemented with ONNX ``Einsum``.

    :param x: input variables
    :param equation: einsum equation (for example ``'ij,jk->ik'``),
        forwarded as the ``equation`` attribute
    :return: @see cl OnnxVar
    """
    op_einsum = loadop('Einsum')
    return OnnxVar(*x, op=op_einsum, equation=equation)

270 

271 

def erf(x):
    """
    Builds an ONNX ``Erf`` node applied to *x*,
    see :epkg:`scipy:special:erf`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Erf'))

276 

277 

def exp(x):
    """
    Builds an ONNX ``Exp`` node applied to *x*, see :func:`numpy.exp`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Exp'))

282 

283 

def expand_dims(x, axis):
    """
    See :func:`numpy.expand_dims`, implemented with ONNX ``Unsqueeze``.

    :param x: input variable
    :param axis: position of the inserted dimension (Python int only)
    :return: @see cl OnnxVar
    :raises NotImplementedError: if *axis* is not an int
    """
    if isinstance(axis, int):
        OnnxUnsqueeze = loadop('Unsqueeze')
        axes = numpy.array([axis], dtype=numpy.int64)
        return OnnxVar(x, axes, op=OnnxUnsqueeze)
    raise NotImplementedError(  # pragma: no cover
        f"This function only allows integer for axis not {type(axis)!r}.")

292 

293 

def expit(x):
    """
    Builds an ONNX ``Sigmoid`` node applied to *x*,
    see :epkg:`scipy:special:expit`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Sigmoid'))

298 

299 

def floor(x):
    """
    Builds an ONNX ``Floor`` node applied to *x*,
    see :func:`numpy.floor`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Floor'))

304 

305 

def hstack(*x):
    """
    See :func:`numpy.hstack`: concatenates the inputs along the last
    axis (``axis=-1``) with ONNX ``Concat``.

    This matches numpy for 1-D and 2-D inputs; numpy concatenates along
    axis 1 when inputs have more than two dimensions.

    :param x: at least two input variables
    :return: @see cl OnnxVar
    :raises RuntimeError: if fewer than two inputs are given
    """
    count = len(x)
    if count <= 1:
        raise RuntimeError(  # pragma: no cover
            f"N={count}<=1 elements to concatenate.")
    return OnnxVar(*x, op=loadop('Concat'), axis=-1)

313 

314 

def isnan(x):
    """
    Builds an ONNX ``IsNaN`` node applied to *x*,
    see :func:`numpy.isnan`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('IsNaN'))

319 

320 

def identity(x):
    """
    Builds an ONNX ``Identity`` node applied to *x*.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Identity'))

325 

326 

def log(x):
    """
    Builds an ONNX ``Log`` node applied to *x*, see :func:`numpy.log`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Log'))

331 

332 

def log1p(x):
    """
    See :func:`numpy.log1p`, computed as ``Log(Add(x, 1))``.

    Because ``1 + x`` is rounded before the logarithm, this does not
    provide numpy's extra precision for very small *x*.

    :param x: input variable; it must expose a ``dtype`` attribute,
        which is used to type the constant 1
    :return: @see cl OnnxVar
    """
    OnnxLog, OnnxAdd = loadop('Log', 'Add')
    one = numpy.array([1], dtype=x.dtype)
    shifted = OnnxVar(x, one, op=OnnxAdd)
    return OnnxVar(shifted, op=OnnxLog)

339 

340 

def mean(x, axis=None, keepdims=0):
    """
    See :func:`numpy.mean`, implemented with ONNX ``ReduceMean``.

    :param x: input variable
    :param axis: None (the ``axes`` attribute is omitted), an integer,
        or a list/tuple of integers
    :param keepdims: keep the reduced dimensions (1) or not (0)
    :return: @see cl OnnxVar
    """
    OnnxReduceMean = loadop('ReduceMean')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceMean, keepdims=keepdims)
    # numpy accepts a tuple of axes; the ONNX attribute must be a flat
    # list of integers, so a tuple must not be wrapped into a list.
    axes = list(axis) if isinstance(axis, (list, tuple)) else [axis]
    return OnnxVar(x, op=OnnxReduceMean, keepdims=keepdims, axes=axes)

349 

350 

def onnx_if(condition, then_branch, else_branch):
    """
    Implements a test with onnx syntax.

    A numpy array given as a branch is first wrapped into
    @see cl if_then_else.

    :param condition: condition (@see cl OnnxVar)
    :param then_branch: then branch, of type @see cl if_then_else
        or a numpy array
    :param else_branch: else branch, of type @see cl if_then_else
        or a numpy array
    :return: result (@see cl OnnxVarGraph)
    :raises TypeError: if a branch is neither a numpy array nor an
        @see cl if_then_else
    """
    OnnxIf = loadop('If')
    if isinstance(then_branch, numpy.ndarray):
        then_branch = if_then_else(then_branch)
    if not isinstance(then_branch, if_then_else):
        raise TypeError(
            "Parameter then_branch is not of type "
            "'if_then_else' but %r." % type(then_branch))
    if isinstance(else_branch, numpy.ndarray):
        else_branch = if_then_else(else_branch)
    if not isinstance(else_branch, if_then_else):
        raise TypeError(
            "Parameter else_branch is not of type "
            "'if_then_else' but %r." % type(else_branch))
    return OnnxVarGraph(
        condition, then_branch=then_branch,
        else_branch=else_branch, op=OnnxIf)

376 

377 

def pad(x, pads, constant_value=None, mode='constant'):
    """
    It does not implement :func:`numpy.pad` but the ONNX version
    :func:`onnx_pad <mlprodict.onnxrt.ops_cpu.op_pad.onnx_pad>`.

    :param x: input variable
    :param pads: padding amounts, forwarded as the second ONNX input
    :param constant_value: optional third input; omitted when None
    :param mode: forwarded as the ``mode`` attribute
    :return: @see cl OnnxVar
    """
    # NOTE(review): ('', 'Pad') presumably selects Pad from the default
    # ONNX domain - confirm against loadop.
    OnnxPad = loadop(('', 'Pad'))
    inputs = ((x, pads) if constant_value is None
              else (x, pads, constant_value))
    return OnnxVar(*inputs, op=OnnxPad, mode=mode)

387 

388 

def prod(x, axis=None, keepdims=0):
    """
    See :func:`numpy.prod`, implemented with ONNX ``ReduceProd``.

    :param x: input variable
    :param axis: None (the ``axes`` attribute is omitted), an integer,
        or a list/tuple of integers
    :param keepdims: keep the reduced dimensions (1) or not (0)
    :return: @see cl OnnxVar
    """
    OnnxReduceProd = loadop('ReduceProd')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceProd, keepdims=keepdims)
    # numpy accepts a tuple of axes; the ONNX attribute must be a flat
    # list of integers, so a tuple must not be wrapped into a list.
    axes = list(axis) if isinstance(axis, (list, tuple)) else [axis]
    return OnnxVar(x, op=OnnxReduceProd, keepdims=keepdims, axes=axes)

397 

398 

def relu(x):
    """
    Builds an ONNX ``Relu`` node applied to *x*.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Relu'))

403 

404 

def reciprocal(x):
    """
    Builds an ONNX ``Reciprocal`` node applied to *x*,
    see :func:`numpy.reciprocal`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Reciprocal'))

409 

410 

def round(x):
    """
    Builds an ONNX ``Round`` node applied to *x*,
    see :func:`numpy.round` (no ``decimals`` argument here).

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Round'))

415 

416 

def sigmoid(x):
    """
    Builds an ONNX ``Sigmoid`` node applied to *x*,
    see :epkg:`scipy:special:expit`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Sigmoid'))

421 

422 

def sign(x):
    """
    Builds an ONNX ``Sign`` node applied to *x*, see :func:`numpy.sign`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Sign'))

427 

428 

def sin(x):
    """
    Builds an ONNX ``Sin`` node applied to *x*, see :func:`numpy.sin`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Sin'))

433 

434 

def sinh(x):
    """
    Builds an ONNX ``Sinh`` node applied to *x*, see :func:`numpy.sinh`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Sinh'))

439 

440 

def sqrt(x):
    """
    Builds an ONNX ``Sqrt`` node applied to *x*, see :func:`numpy.sqrt`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Sqrt'))

445 

446 

def squeeze(x, axis=None):
    """
    See :func:`numpy.squeeze`, implemented with ONNX ``Squeeze``.

    :param x: input variable
    :param axis: None to remove every dimension of size one, an integer,
        a list/tuple of integers, or a tensor (numpy array or
        @see cl OnnxVar) holding int64 axes
    :return: @see cl OnnxVar
    """
    OnnxSqueeze = loadop('Squeeze')
    if axis is None:
        # Squeeze-13 takes axes as an optional input: when it is missing,
        # all dimensions of size one are removed, as numpy does.
        return OnnxVar(x, op=OnnxSqueeze)
    if isinstance(axis, (int, numpy.integer)):
        axis = numpy.array([axis], dtype=numpy.int64)
    elif isinstance(axis, (list, tuple)):
        axis = numpy.array(axis, dtype=numpy.int64)
    return OnnxVar(x, axis, op=OnnxSqueeze)

458 

459 

def sum(x, axis=None, keepdims=0):
    """
    See :func:`numpy.sum`, implemented with ONNX ``ReduceSum``.

    :param x: input variable
    :param axis: None (no axes input), an integer, or a list/tuple/array
        of integers
    :param keepdims: keep the reduced dimensions (1) or not (0)
    :return: @see cl OnnxVar
    """
    OnnxReduceSum = loadop('ReduceSum')
    if axis is None:
        return OnnxVar(x, op=OnnxReduceSum, keepdims=keepdims)
    # ReduceSum receives axes as a 1-D int64 input. Wrapping a list or
    # tuple in another list would produce an invalid 2-D tensor.
    axes = numpy.array(axis, dtype=numpy.int64).reshape((-1, ))
    return OnnxVar(x, axes, op=OnnxReduceSum, keepdims=keepdims)

467 

468 

def tan(x):
    """
    Builds an ONNX ``Tan`` node applied to *x*, see :func:`numpy.tan`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Tan'))

473 

474 

def tanh(x):
    """
    Builds an ONNX ``Tanh`` node applied to *x*, see :func:`numpy.tanh`.

    :param x: input variable
    :return: @see cl OnnxVar
    """
    return OnnxVar(x, op=loadop('Tanh'))

479 

480 

def topk(x, k, axis=-1, largest=1, sorted=1):
    """
    Builds an ONNX ``TopK`` node, related to :func:`numpy.argsort`.

    :param x: input variable
    :param k: number of elements to keep
    :param axis: axis along which elements are selected
    :param largest: 1 keeps the largest elements, 0 the smallest
    :param sorted: forwarded as the ``sorted`` attribute
    :return: a ``MultiOnnxVar`` (imported as ``xtuple``), presumably
        holding both TopK outputs (values, indices) - confirm
    """
    op_topk = loadop('TopK')
    return xtuple(x, k, op=op_topk, axis=axis,
                  largest=largest, sorted=sorted)

486 

487 

def transpose(x, perm=(1, 0)):
    """
    See :func:`numpy.transpose`, implemented with ONNX ``Transpose``.

    :param x: input variable
    :param perm: permutation of the axes. The default ``(1, 0)`` swaps
        the two dimensions of a matrix. ``None`` omits the attribute, and
        ONNX then reverses every dimension, like numpy's default.
    :return: @see cl OnnxVar
    """
    OnnxTranspose = loadop('Transpose')
    if perm is None:
        return OnnxVar(x, op=OnnxTranspose)
    return OnnxVar(x, op=OnnxTranspose, perm=list(perm))

492 

493 

def unsqueeze(x, axes):
    """
    Builds an ONNX ``Unsqueeze`` node, see :func:`numpy.expand_dims`.

    :param x: input variable
    :param axes: a Python int (converted into a 1-D int64 array) or an
        axes tensor forwarded unchanged
    :return: @see cl OnnxVar
    """
    OnnxUnsqueeze = loadop('Unsqueeze')
    axes_input = (numpy.array([axes], dtype=numpy.int64)
                  if isinstance(axes, int) else axes)
    return OnnxVar(x, axes_input, op=OnnxUnsqueeze)

500 

501 

def vstack(*x):
    """
    See :func:`numpy.vstack`: concatenates the inputs along axis 0 with
    ONNX ``Concat``.

    :param x: at least two input variables
    :return: @see cl OnnxVar
    :raises RuntimeError: if fewer than two inputs are given
    """
    OnnxConcat = loadop('Concat')
    if len(x) > 1:
        return OnnxVar(*x, op=OnnxConcat, axis=0)
    raise RuntimeError(  # pragma: no cover
        f"N={len(x)}<=1 elements to concatenate.")

509 

510 

def where(cond, x, y):
    """
    Builds an ONNX ``Where`` node, see :func:`numpy.where`
    (three-argument form).

    :param cond: condition variable
    :param x: values taken where *cond* holds
    :param y: values taken elsewhere
    :return: @see cl OnnxVar
    """
    return OnnxVar(cond, x, y, op=loadop('Where'))