Coverage for mlprodict/npy/onnx_numpy_compiler.py: 97%

204 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Implements :epkg:`numpy` functions with onnx and a runtime. 

4 

5.. versionadded:: 0.6 

6""" 

7import inspect 

8import logging 

9from typing import Any 

10import numpy 

11from ..onnx_tools.optim._main_onnx_optim import onnx_optimisations 

12from .onnx_version import FctVersion 

13from .onnx_numpy_annotation import get_args_kwargs 

14from .xop_variable import Variable 

15from .xop import OnnxOperator, OnnxOperatorTuple 

16 

17 

# Module-level logger registered under the name 'xop'
# (presumably shared with the `.xop` module -- TODO confirm).
logger = logging.getLogger('xop')

19 

20 

class OnnxNumpyFunction:
    """
    Class wrapping a function built with
    @see cl OnnxNumpyCompiler.

    :param compiler: compiler which created this wrapper, stored as is
    :param rt: runtime object, stored as is in attribute `rt`
    :param inputs: list of @see cl Variable describing the inputs
    :param outputs: list of @see cl Variable describing the outputs
    :param n_optional: number of optional inputs, it must verify
        `0 <= n_optional < len(inputs)`
    :param n_variables: when positive, method `_check_` does not
        validate the number of arguments

    .. versionadded:: 0.6
    """

    def __init__(self, compiler, rt, inputs, outputs,
                 n_optional, n_variables):
        if any(not isinstance(n, Variable) for n in inputs):
            raise TypeError(  # pragma: no cover
                f"All inputs must be of type Variable: {inputs!r}.")
        if any(not isinstance(n, Variable) for n in outputs):
            raise TypeError(  # pragma: no cover
                f"All outputs must be of type Variable: {outputs!r}.")
        self.compiler = compiler
        self.inputs = inputs
        self.outputs = outputs
        self.rt = rt
        self.n_optional = n_optional
        self.n_variables = n_variables
        if n_optional < 0:
            raise RuntimeError(  # pragma: no cover
                f"Wrong configuration, n_optional {n_optional!r} must be >= 0.")
        if n_optional >= len(inputs):
            # The message used to say ">=" although this is precisely
            # the rejected configuration.
            raise RuntimeError(  # pragma: no cover
                "Wrong configuration, n_optional %r must be < %r, "
                "the number of inputs." % (n_optional, len(inputs)))

    def _check_(self, *args, **kwargs):
        """
        Checks the number of positional arguments is in
        `[len(inputs) - n_optional, len(inputs)]`.
        Nothing is checked if `n_variables` is positive.

        :param args: positional arguments received by the function
        :param kwargs: named arguments, only used in the error message
        :raises RuntimeError: if the number of arguments is unexpected
        """
        if self.n_variables > 0:
            return
        n_inputs = len(self.inputs)
        n_required = n_inputs - self.n_optional
        if not n_required <= len(args) <= n_inputs:
            # Values are listed in the order of the placeholders:
            # received, [lower, upper], len(args), n_optional, n_variables.
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of inputs %d. It should be in "
                "[%r, %r] len(args)=%d n_optional=%d n_variables=%d"
                "\nargs=%s\nkwargs=%s\ninputs=%s" % (
                    len(args), n_required, n_inputs,
                    len(args), self.n_optional, self.n_variables,
                    args, kwargs, self.inputs))

63 

64 

class OnnxNumpyFunctionOnnxInference(OnnxNumpyFunction):
    """
    Specialization of @see cl OnnxNumpyFunction which executes
    the graph with an instance of @see cl OnnxInference.

    .. versionadded:: 0.6
    """

    def __call__(self, *args, **kwargs):
        """
        Runs the runtime on the positional arguments.

        :param args: values, paired with `self.inputs` in order
        :param kwargs: forwarded to method `run` of the runtime
        :return: tuple of results ordered as `self.outputs`
        """
        self._check_(*args, **kwargs)
        feeds = {}
        for var, value in zip(self.inputs, args):
            feeds[var.name] = value
        results = self.rt.run(feeds, **kwargs)
        n_got = len(results)
        n_expected = len(self.outputs)
        if n_got != n_expected:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %d instead of %d." % (
                    n_got, n_expected))
        # results are indexed by output name
        return tuple(results[var.name] for var in self.outputs)

82 

83 

class OnnxNumpyFunctionInferenceSession(OnnxNumpyFunction):
    """
    Specialization of @see cl OnnxNumpyFunction which executes
    the graph with an instance of `InferenceSession`
    from :epkg:`onnxruntime`.

    .. versionadded:: 0.6
    """

    def __call__(self, *args, **kwargs):
        """
        Runs the runtime on the positional arguments.

        :param args: values, paired with `self.inputs` in order
        :param kwargs: must be empty, named arguments are not supported
        :return: tuple of results in the order returned by the runtime
        """
        self._check_(*args, **kwargs)
        if kwargs:
            raise RuntimeError(  # pragma: no cover
                f"kwargs is not used but it is not empty: {kwargs!r}.")
        feeds = {var.name: value for var, value in zip(self.inputs, args)}
        # None requests every output of the session
        results = self.rt.run(None, feeds)
        n_got = len(results)
        n_expected = len(self.outputs)
        if n_got != n_expected:
            raise RuntimeError(  # pragma: no cover
                "Unexpected number of outputs %d instead of %d." % (
                    n_got, n_expected))
        return tuple(results)

105 

106 

class OnnxNumpyCompiler:
    """
    Implements a class which runs onnx graph.

    :param fct: a function with annotations which returns an ONNX graph,
        it can also be an ONNX graph.
    :param op_version: :epkg:`ONNX` opset to use, None
        for the latest one
    :param runtime: runtime to choose to execute the onnx graph,
        `python`, `onnxruntime`, `onnxruntime1`
    :param signature: used when the function is not annotated
    :param version: the same function can be instantiated with
        different type, this parameter is None or a numpy type
        if the signature allows multiple types, it must be an instance
        of type @see cl FctVersion
    :param fctsig: function used to overwrite the fct signature
        in case this one is using `*args, **kwargs`

    .. versionadded:: 0.6
    """

    def __init__(self, fct, op_version=None, runtime=None, signature=None,
                 version=None, fctsig=None):
        if version is not None and not isinstance(version, FctVersion):
            raise TypeError(  # pragma: no cover
                "version must be of Type 'FctVersion' not %s - %s"
                "." % (type(version), version))
        self.fctsig = fctsig
        if op_version is None:
            from .. import __max_supported_opset__
            op_version = __max_supported_opset__
        if hasattr(fct, 'SerializeToString'):
            # fct is already an ONNX graph, there is nothing to convert
            self.fct_ = None
            self.onnx_ = fct
        else:
            self.fct_ = fct
            if not inspect.isfunction(fct):
                raise TypeError(  # pragma: no cover
                    f"Unexpected type for fct={type(fct)!r}, it must be a function.")
            # _to_onnx only converts when onnx_ is None
            self.onnx_ = None
            self.onnx_ = self._to_onnx(
                op_version=op_version, signature=signature,
                version=version)
        self.runtime_ = self._build_runtime(
            op_version=op_version, runtime=runtime,
            signature=signature, version=version)
        ann = self._parse_annotation(signature=signature, version=version)
        inputs, outputs, kwargs, n_optional, n_variables = ann
        n_opt = 0 if signature is None else signature.n_optional
        args, kwargs2 = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        # fct_ is None when an ONNX graph was given: there is no
        # annotation to store in that case.
        self.meta_ = dict(op_version=op_version, runtime=runtime,
                          signature=signature, version=version,
                          inputs=inputs, outputs=outputs,
                          kwargs=kwargs, n_optional=n_optional,
                          n_variables=n_variables,
                          args=args, kwargs2=kwargs2,
                          annotations=getattr(
                              self.fct_, '__annotations__', None))

    def __getstate__(self):
        """
        Serializes everything but function `fct_`.
        Function `fct_` is used to build the onnx graph
        and is not needed anymore.
        """
        return dict(onnx_=self.onnx_, meta_=self.meta_)

    def __setstate__(self, state):
        """
        Restores serialized data and rebuilds the runtime.
        """
        # fct_ and fctsig are not serialized, they are reset so that
        # __repr__ and error messages still work on a restored instance.
        self.fct_ = None
        self.fctsig = None
        for k, v in state.items():
            setattr(self, k, v)
        self.runtime_ = self._build_runtime(
            op_version=self.meta_['op_version'],
            runtime=self.meta_['runtime'],
            signature=self.meta_['signature'],
            version=self.meta_['version'])

    def __repr__(self):
        "usual"
        fct = getattr(self, 'fct_', None)
        if fct is not None:
            return f"{self.__class__.__name__}({fct!r})"
        if getattr(self, 'onnx_', None) is not None:
            return f"{self.__class__.__name__}({'... ONNX ... '})"
        raise NotImplementedError(  # pragma: no cover
            "fct_ and onnx_ are empty.")

    def _to_onnx_shape(self, shape):
        """
        Converts an annotated shape into an ONNX shape.

        :param shape: `Any`, `Ellipsis` or a tuple of dimensions
        :return: None if the shape is unknown, otherwise a list
            where `Any` and `Ellipsis` are replaced by None
        :raises RuntimeError: for any other type
        """
        if shape is Any or shape is Ellipsis:
            return None
        if isinstance(shape, tuple):
            return [None if s is Any or s is Ellipsis else s
                    for s in shape]
        raise RuntimeError(  # pragma: no cover
            f"Unexpected annotated shape {shape!r}.")

    def _parse_annotation(self, signature, version):
        """
        Returns the annotations for function `fct_`.

        :param signature: needed if the annotation is missing,
            then version might be needed to specify which type
            to use if the signature allows many
        :param version: version inside the many signatures possible
        :return: *tuple(inputs, outputs, kwargs, n_optional, n_variables)*,
            *inputs* and *outputs* are lists of @see cl Variable,
            *kwargs* is the dictionary of additional parameters
        """
        n_opt = 0 if signature is None else signature.n_optional
        if hasattr(self, 'meta_'):
            # restored or fully built instance: the function may be gone
            args, kwargs = self.meta_['args'], self.meta_['kwargs2']
        else:
            args, kwargs = get_args_kwargs(self.fctsig or self.fct_, n_opt)
        if version is not None:
            nv = len(version) - len(args) - n_opt
            if (signature is not None and not
                    signature.n_variables and nv > len(kwargs)):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch (%d - %d - %d ? %d) between version=%r and kwargs=%r for "
                    "function %r, optional argument is %d, "
                    "signature=%r." % (
                        len(version), len(args), n_opt, len(kwargs),
                        version, kwargs, self.fct_,
                        signature.n_variables, signature))
            vvers = {} if version.kwargs is None else version.kwargs
            # values stored in the version overwrite the default values,
            # the original dictionary is left untouched
            up = dict(zip(kwargs, vvers))
            kwargs = kwargs.copy()
            kwargs.update(up)

        for k, v in kwargs.items():
            if isinstance(v, (type, numpy.dtype)):
                raise RuntimeError(  # pragma: no cover
                    f"Unexpected value for argument {k!r}: {v!r} from {kwargs!r}.")

        if signature is not None:
            inputs, kwargs, outputs, n_optional, n_variables = (
                signature.get_inputs_outputs(args, kwargs, version))
            inputs = [Variable(i[0], i[1]) for i in inputs]
            outputs = [Variable(i[0], i[1]) for i in outputs]
            return inputs, outputs, kwargs, n_optional, n_variables

        def _possible_names():
            yield 'y'
            yield 'z'  # pragma: no cover
            yield 'o'  # pragma: no cover
            for i in range(0, 10000):  # pragma: no cover
                yield 'o%d' % i

        def _first_free_name(*taken):
            # first candidate which does not belong to any set in taken
            for name in _possible_names():
                if all(name not in names for names in taken):
                    return name
            return None  # pragma: no cover

        # no signature: everything comes from the annotations
        if hasattr(self, 'meta_'):
            annotations = self.meta_['annotations']
        else:
            annotations = self.fct_.__annotations__
        inputs = []
        outputs = []
        for a in args:
            if a == "op_version":
                continue
            if a not in annotations:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find annotation for argument %r. "
                    "You should annotate the arguments and the results "
                    "or specify a signature." % a)
            shape, dtype = annotations[a].__args__
            shape = self._to_onnx_shape(shape)
            inputs.append(Variable(a, dtype, shape=shape))

        ret = annotations['return']
        names_in = set(inp.name for inp in inputs)

        if isinstance(ret, tuple):
            # multiple outputs
            names_none = set()
            for shape_dtype in ret:
                shape, dtype = shape_dtype.__args__
                shape = self._to_onnx_shape(shape)
                name_out = _first_free_name(names_in, names_none)
                outputs.append(Variable(name_out, dtype, shape=shape))
                names_none.add(name_out)
        else:
            # single output
            shape, dtype = ret.__args__
            shape = self._to_onnx_shape(shape)
            name_out = _first_free_name(names_in)
            outputs.append(Variable(name_out, dtype, shape=shape))
        return (inputs, outputs, kwargs, 0,
                signature.n_variables if signature is not None else False)

    def _find_hidden_algebras(self, onx_var, onx_algebra):
        """
        Subgraph are using inputs not linked to the others nodes.
        This function retrieves them as they are stored in
        attributes `alg_hidden_var_`. The function looks into every
        node linked to the inputs and their predecessors.

        :param onx_var: @see cl OnnxVar
        :param onx_algebra: OnnxOperator, not used by this method
        :return: tuple(dictionary `{id(obj): (var, obj)}`,
            all instance of @see cl OnnxVarGraph)
        """
        keep_hidden = {}
        var_graphs = []
        stack = [onx_var]
        while stack:
            var = stack.pop()
            hidden = getattr(var, 'alg_hidden_var_', None)
            if hidden is not None:
                if any(len(x) > 0
                       for x in var.alg_hidden_var_inputs.values()):
                    keep_hidden.update(hidden)
                    # NOTE(review): nesting of this append under the test
                    # above was reconstructed -- confirm.
                    var_graphs.append(var)
            if hasattr(var, 'inputs'):
                stack.extend(var.inputs)
        return keep_hidden, var_graphs

    def _to_onnx(self, op_version=None, signature=None, version=None):
        """
        Returns the onnx graph produced by function `fct_`.
        The conversion only happens if `onnx_` is None,
        the result is cached in `onnx_`.

        :param op_version: :epkg:`ONNX` opset to use
        :param signature: used when the function is not annotated
        :param version: see the constructor
        :return: ONNX graph
        """
        if self.onnx_ is None and self.fct_ is not None:
            from .onnx_variable import OnnxVar
            logger.debug('OnnxNumpyCompiler._to_onnx(op_version=%r, '
                         'signature=%r, version=%r)',
                         op_version, signature, version)
            inputs, outputs, kwargs, n_optional, _ = (
                self._parse_annotation(
                    signature=signature, version=version))
            if ((signature is None or not signature.n_variables) and
                    isinstance(version, tuple) and
                    len(inputs) > len(version)):
                raise NotImplementedError(  # pragma: no cover
                    "Mismatch between additional parameters %r "
                    "(n_optional=%r) and version %r for function %r from %r."
                    "" % (kwargs, n_optional, version, self.fct_,
                          getattr(self.fct_, '__module__', None)))
            names_in = [oi.name for oi in inputs]
            names_out = [oi.name for oi in outputs]
            names_var = [OnnxVar(n, dtype=dt.dtype)
                         for n, dt in zip(names_in, inputs)]

            logger.debug('OnnxNumpyCompiler._to_onnx:names_in=%r', names_in)
            logger.debug('OnnxNumpyCompiler._to_onnx:names_out=%r', names_out)

            if 'op_version' in self.fct_.__code__.co_varnames:
                # the function receives the input names and directly
                # returns the algebra
                onx_var = None
                onx_algebra = self.fct_(
                    *names_in, op_version=op_version, **kwargs)
            else:
                onx_var = self.fct_(*names_var, **kwargs)
                if not hasattr(onx_var, 'to_algebra'):
                    raise TypeError(  # pragma: no cover
                        "The function %r to convert must return an instance of "
                        "OnnxVar but returns type %r." % (self.fct_, type(onx_var)))
                onx_algebra = onx_var.to_algebra(op_version=op_version)

            logger.debug('OnnxNumpyCompiler._to_onnx:onx_var=%r',
                         type(onx_var))
            logger.debug('OnnxNumpyCompiler._to_onnx:onx_algebra=%r',
                         type(onx_algebra))

            if not isinstance(onx_algebra, (OnnxOperator, OnnxOperatorTuple)):
                raise TypeError(  # pragma: no cover
                    "Unexpected type for onx_algebra %r "
                    "(It should be OnnxOperator or OnnxOperatorTuple), "
                    "function is %r." % (type(onx_algebra), self.fct_))
            hidden_algebras, var_graphs = self._find_hidden_algebras(
                onx_var, onx_algebra)
            if len(hidden_algebras) > 0:
                logger.debug(  # pragma: no cover
                    'OnnxNumpyCompiler._to_onnx:len(hidden_algebras)=%r',
                    len(hidden_algebras))
                raise NotImplementedError(  # pragma: no cover
                    "Subgraphs only support constants (operator If, Loop, "
                    "Scan). hidden_algebras=%r var_graphs=%r" % (
                        hidden_algebras, var_graphs))

            if isinstance(onx_algebra, str):
                raise RuntimeError(  # pragma: no cover
                    f"Unexpected str type {onx_algebra!r}.")
            if isinstance(onx_algebra, tuple):
                raise NotImplementedError(  # pragma: no cover
                    "Not implemented when the function returns multiple results.")
            if hasattr(onx_algebra, 'to_onnx'):
                onx_algebra.output_names = [Variable(n) for n in names_out]
                onx = onx_algebra.to_onnx(
                    inputs=inputs, target_opset=op_version, outputs=outputs)
                # optimisation
                self.onnx_ = onnx_optimisations(onx)

        if self.onnx_ is None:
            raise RuntimeError(  # pragma: no cover
                "Unable to get the ONNX graph (class %r, fct_=%r)" % (
                    type(self), self.fct_))
        return self.onnx_

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple type.

        :return: ONNX graph
        """
        if len(kwargs) > 0:
            raise NotImplementedError(  # pragma: no cover
                "kwargs is not empty, this case is not implemented. "
                "kwargs=%r." % kwargs)
        if hasattr(self, 'onnx_'):
            return self.onnx_
        raise NotImplementedError(  # pragma: no cover
            "Attribute 'onnx_' is missing.")

    def _build_runtime(self, op_version=None, runtime=None,
                       signature=None, version=None):
        """
        Creates the runtime for the :epkg:`ONNX` graph.

        :param op_version: :epkg:`ONNX` opset to use, None
            for the latest one
        :param runtime: runtime to choose to execute the onnx graph,
            `python`, `onnxruntime`, `onnxruntime1`
        :param signature: used when the function is not annotated
        :param version: see the constructor
        :return: the callable wrapping the runtime,
            also stored in attribute `rt_fct_`
        """
        onx = self._to_onnx(op_version=op_version, signature=signature,
                            version=version)
        inputs, outputs, _, n_optional, n_variables = self._parse_annotation(
            signature=signature, version=version)
        if runtime not in ('onnxruntime', 'onnxruntime-cuda'):
            from ..onnxrt import OnnxInference
            rt = OnnxInference(onx, runtime=runtime)
            self.rt_fct_ = OnnxNumpyFunctionOnnxInference(
                self, rt, inputs=inputs, outputs=outputs,
                n_optional=n_optional, n_variables=n_variables)
        else:
            from ..tools.ort_wrapper import InferenceSession
            rt = InferenceSession(onx.SerializeToString(), runtime=runtime)
            self.rt_fct_ = OnnxNumpyFunctionInferenceSession(
                self, rt, inputs=inputs, outputs=outputs,
                n_optional=n_optional, n_variables=n_variables)
        return self.rt_fct_

    def __call__(self, *args, **kwargs):
        """
        Executes the function and returns the results.

        :param args: arguments
        :param kwargs: named arguments forwarded to the runtime wrapper
        :return: the single result if there is only one,
            the tuple of results otherwise
        """
        res = self.rt_fct_(*args, **kwargs)
        if len(res) == 1:
            return res[0]
        return res

482 return res