Coverage for mlprodict/npy/onnx_numpy_annotation.py: 97%

174 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief :epkg:`numpy` annotations. 

4 

5.. versionadded:: 0.6 

6""" 

7import inspect 

8from collections import OrderedDict 

9from typing import TypeVar, Generic 

10import numpy 

11from .onnx_version import FctVersion 

12 

# numpy.bool_ / numpy.str_ are looked up defensively: the builtin type is
# used instead when the installed numpy does not expose the attribute.
numpy_bool = getattr(numpy, "bool_", bool)
numpy_str = getattr(numpy, "str_", str)

# Type variables used to parametrize the NDArray annotation class.
Shape = TypeVar("Shape")
DType = TypeVar("DType")

# dtypes selected by the "all" shortcut in signatures.
all_dtypes = (
    numpy.float32, numpy.float64,
    numpy.int32, numpy.int64,
    numpy.uint32, numpy.uint64,
)

30 

31 

def get_args_kwargs(fct, n_optional):
    """
    Extracts arguments and optional parameters of a function.

    :param fct: function
    :param n_optional: number of arguments to consider as
        optional arguments and not parameters, the first
        *n_optional* parameters having a default value are
        returned among the arguments instead of the parameters
    :return: arguments (list of names), parameters
        (OrderedDict name -> default value)

    Parameter ``op_version`` is never returned among the parameters,
    key ``'op_'`` is always added to them with value ``None``.
    A leading argument ``self`` is removed from the arguments.
    """
    # Identity is the right test for the sentinel: ``==`` would call the
    # ``__eq__`` of the default value (ambiguous for arrays for example).
    empty = inspect.Parameter.empty
    params = inspect.signature(fct).parameters
    if n_optional == 0:
        items = list(params.items())
        args = [name for name, p in items if p.default is empty]
    else:
        items = []
        args = []
        for name, p in params.items():
            if p.default is empty:
                args.append(name)
            elif n_optional > 0:
                # a parameter with a default value promoted to argument
                args.append(name)
                n_optional -= 1
            else:
                items.append((name, p))

    kwargs = OrderedDict(
        (name, p.default) for name, p in items
        if p.default is not empty and name != 'op_version')
    # *args* is empty when every parameter has a default value.
    if args and args[0] == 'self':
        args = args[1:]
    kwargs['op_'] = None
    return args, kwargs

69 

70 

class NDArray(numpy.ndarray, Generic[Shape, DType]):
    """
    Used to annotate ONNX numpy functions.

    Subscripting the class (``NDArray[...]``) returns an instance of
    :class:`NDArray.ShapeType` which stores the subscript as a tuple.

    .. versionadded:: 0.6
    """
    class ShapeType:
        "Stores shape information."

        def __init__(self, params):
            self.__args__ = params

    def __class_getitem__(cls, params):  # pylint: disable=W0221,W0237
        "Overwrites this method."
        # a single subscript is wrapped into a one-element tuple
        as_tuple = params if isinstance(params, tuple) else (params, )
        return NDArray.ShapeType(as_tuple)

88 

89 

90class _NDArrayAlias: 

91 """ 

92 Ancestor to custom signature. 

93 

94 :param dtypes: input dtypes 

95 :param dtypes_out: output dtypes 

96 :param n_optional: number of optional parameters, 0 by default 

97 :param nvars: True if the function allows an infinite number of inputs, 

98 this is incompatible with parameter *n_optional*. 

99 

100 *dtypes*, *dtypes_out* by default are a tuple of tuple: 

101 

102 * first dimension: type of every input 

103 * second dimension: list of types for one input 

104 

105 .. versionadded:: 0.6 

106 """ 

107 

108 def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, 

109 nvars=False): 

110 "constructor" 

111 if dtypes is None: 

112 raise ValueError("dtypes cannot be None.") # pragma: no cover 

113 if isinstance(dtypes, tuple) and len(dtypes) == 0: 

114 raise TypeError("dtypes must not be empty.") # pragma: no cover 

115 if isinstance(dtypes, tuple) and not isinstance(dtypes[0], tuple): 

116 dtypes = tuple(t if isinstance(t, str) else (t,) for t in dtypes) 

117 if isinstance(dtypes, str) and '_' in dtypes: 

118 dtypes, dtypes_out = dtypes.split('_') 

119 if not isinstance(dtypes, (tuple, list)): 

120 dtypes = (dtypes, ) 

121 

122 self.mapped_types = {} 

123 self.dtypes = _NDArrayAlias._process_type( 

124 dtypes, self.mapped_types, 0) 

125 if dtypes_out is None: 

126 self.dtypes_out = (self.dtypes[0], ) 

127 elif isinstance(dtypes_out, int): 

128 self.dtypes_out = (self.dtypes[dtypes_out], ) 

129 else: 

130 if not isinstance(dtypes_out, (tuple, list)): 

131 dtypes_out = (dtypes_out, ) 

132 self.dtypes_out = _NDArrayAlias._process_type( 

133 dtypes_out, self.mapped_types, 0) 

134 self.n_optional = 0 if n_optional is None else n_optional 

135 self.n_variables = nvars 

136 

137 if not isinstance(self.dtypes, tuple): 

138 raise TypeError( # pragma: no cover 

139 f"self.dtypes must be a tuple not {self.dtypes}.") 

140 if (len(self.dtypes) == 0 or 

141 not isinstance(self.dtypes[0], tuple)): 

142 raise TypeError( # pragma: no cover 

143 f"Type mismatch in self.dtypes: {self.dtypes}.") 

144 if (len(self.dtypes[0]) == 0 or 

145 isinstance(self.dtypes[0][0], tuple)): 

146 raise TypeError( # pragma: no cover 

147 f"Type mismatch in self.dtypes: {self.dtypes}.") 

148 

149 if not isinstance(self.dtypes_out, tuple): 

150 raise TypeError( # pragma: no cover 

151 f"self.dtypes_out must be a tuple not {self.dtypes_out}.") 

152 if (len(self.dtypes_out) == 0 or 

153 not isinstance(self.dtypes_out[0], tuple)): 

154 raise TypeError( # pragma: no cover 

155 "Type mismatch in self.dtypes_out={}, " 

156 "self.dtypes={}.".format(self.dtypes_out, self.dtypes)) 

157 if (len(self.dtypes_out[0]) == 0 or 

158 isinstance(self.dtypes_out[0][0], tuple)): 

159 raise TypeError( # pragma: no cover 

160 f"Type mismatch in self.dtypes_out: {self.dtypes_out}.") 

161 

162 if self.n_variables and self.n_optional > 0: 

163 raise RuntimeError( # pragma: no cover 

164 "n_variables and n_optional cannot be positive at " 

165 "the same type.") 

166 

167 @staticmethod 

168 def _process_type(dtypes, mapped_types, index): 

169 """ 

170 Nicknames such as `floats`, `int`, `ints`, `all` 

171 can be used to describe multiple inputs for 

172 a signature. This function intreprets that. 

173 

174 .. runpython:: 

175 :showcode: 

176 

177 from mlprodict.npy.onnx_numpy_annotation import _NDArrayAlias 

178 for name in ['all', 'int', 'ints', 'floats', 'T']: 

179 print(name, _NDArrayAlias._process_type(name, {'T': 0}, 0)) 

180 """ 

181 if isinstance(dtypes, str): 

182 if ":" in dtypes: 

183 name, dtypes = dtypes.split(':') 

184 if name in mapped_types and dtypes != mapped_types[name]: 

185 raise RuntimeError( # pragma: no cover 

186 "Type name mismatch for '%s:%s' in %r." % ( 

187 name, dtypes, list(sorted(mapped_types)))) 

188 mapped_types[name] = (dtypes, index) 

189 if dtypes == "all": 

190 dtypes = all_dtypes 

191 elif dtypes in ("int", "int64"): 

192 dtypes = (numpy.int64, ) 

193 elif dtypes == "bool": 

194 dtypes = (numpy_bool, ) 

195 elif dtypes == "floats": 

196 dtypes = (numpy.float32, numpy.float64) 

197 elif dtypes == "ints": 

198 dtypes = (numpy.int32, numpy.int64) 

199 elif dtypes == "int64": 

200 dtypes = (numpy.int64, ) 

201 elif dtypes == "float32": 

202 dtypes = (numpy.float32, ) 

203 elif dtypes == "float64": 

204 dtypes = (numpy.float64, ) 

205 elif dtypes not in mapped_types: 

206 raise ValueError( # pragma: no cover 

207 f"Unexpected shortcut for dtype {dtypes!r}.") 

208 elif not isinstance(dtypes, tuple): 

209 dtypes = (dtypes, ) 

210 return dtypes 

211 

212 if isinstance(dtypes, (tuple, list)): 

213 insig = [_NDArrayAlias._process_type(dt, mapped_types, index + d) 

214 for d, dt in enumerate(dtypes)] 

215 return tuple(insig) 

216 

217 if dtypes in all_dtypes: 

218 return dtypes 

219 

220 raise NotImplementedError( # pragma: no cover 

221 f"Unexpected input dtype {dtypes!r}.") 

222 

223 def __repr__(self): 

224 "usual" 

225 return "%s(%r, %r, %r)" % ( 

226 self.__class__.__name__, self.dtypes, self.dtypes_out, 

227 self.n_optional) 

228 

229 def _get_output_types(self, key): 

230 """ 

231 Tries to infer output types. 

232 """ 

233 res = [] 

234 for i, o in enumerate(self.dtypes_out): 

235 if not isinstance(o, tuple): 

236 raise TypeError( # pragma: no cover 

237 "All outputs must be tuple, output %d is %r." 

238 "" % (i, o)) 

239 if (len(o) == 1 and (o[0] in all_dtypes or 

240 o[0] in (bool, numpy_bool, str, numpy_str))): 

241 res.append(o[0]) 

242 elif len(o) == 1 and o[0] in self.mapped_types: 

243 info = self.mapped_types[o[0]] 

244 res.append(key[info[1]]) 

245 elif key[0] in o: 

246 res.append(key[0]) 

247 else: 

248 raise RuntimeError( # pragma: no cover 

249 "Unable to guess output type for output %d, " 

250 "input types are %r, expected output is %r." 

251 "" % (i, key, o)) 

252 return tuple(res) 

253 

254 def get_inputs_outputs(self, args, kwargs, version): 

255 """ 

256 Returns the list of inputs, outputs. 

257 

258 :param args: list of arguments 

259 :param kwargs: list of optional arguments 

260 :param version: required version 

261 :return: *tuple(inputs, kwargs, outputs, optional)*, 

262 inputs and outputs are tuple, kwargs are the arguments, 

263 *optional* is the number of optional arguments 

264 """ 

265 if not isinstance(version, FctVersion): 

266 raise TypeError("Version must be of type 'FctVersion' not " 

267 "%s, version=%s." % (type(version), version)) 

268 if args == ['args', 'kwargs']: 

269 raise RuntimeError( # pragma: no cover 

270 f"Issue with signature {args!r}.") 

271 for k, v in kwargs.items(): 

272 if isinstance(v, type): 

273 raise RuntimeError( # pragma: no cover 

274 f"Default value for argument {k!r} must not be of type {v!r}.") 

275 if (not self.n_variables and 

276 len(args) > len(self.dtypes)): 

277 raise RuntimeError( 

278 "Unexpected number of inputs version=%s.\n" 

279 "Given: args=%s dtypes=%s." % ( 

280 version, args, self.dtypes)) 

281 

282 def _possible_names(): 

283 yield 'y' 

284 yield 'z' # pragma: no cover 

285 yield 'o' # pragma: no cover 

286 for i in range(0, 10000): # pragma: no cover 

287 yield 'o%d' % i 

288 

289 new_kwargs = OrderedDict( 

290 (k, v) for k, v in zip(kwargs, version.kwargs or tuple())) 

291 if self.n_variables: 

292 # undefined number of inputs 

293 optional = 0 

294 else: 

295 optional = len(self.dtypes) - len(version.args) 

296 if optional > self.n_optional: 

297 raise RuntimeError( # pragma: no cover 

298 "Unexpected number of optional parameters %d, at most " 

299 "%d are expected, version=%s, args=%s, dtypes=%s." % ( 

300 optional, self.n_optional, version, args, self.dtypes)) 

301 optional = self.n_optional - optional 

302 

303 onnx_types = [k for k in version.args] 

304 inputs = list(zip(args[:len(version.args)], onnx_types)) 

305 if self.n_variables and len(inputs) < len(version.args): 

306 # Complete the list of inputs 

307 last_name = inputs[-1][0] 

308 while len(inputs) < len(onnx_types): 

309 inputs.append((f'{last_name}{len(inputs)}', 

310 onnx_types[len(inputs)])) 

311 

312 key_out = self._get_output_types(version.args) 

313 onnx_types_out = key_out 

314 

315 names_out = [] 

316 names_in = set(inp[0] for inp in inputs) 

317 for _ in key_out: 

318 for name in _possible_names(): 

319 if name not in names_in and name not in names_out: 

320 name_out = name 

321 break 

322 names_out.append(name_out) 

323 names_in.add(name_out) 

324 

325 outputs = list(zip(names_out, onnx_types_out)) 

326 if optional < 0: 

327 raise RuntimeError( # pragma: no cover 

328 "optional cannot be negative %r (self.n_optional=%r, " 

329 "len(self.dtypes)=%r, len(inputs)=%r) " 

330 "names_in=%r, names_out=%r." % ( 

331 optional, self.n_optional, len(self.dtypes), 

332 len(inputs), names_in, names_out)) 

333 

334 if (not self.n_variables and 

335 len(inputs) + len(new_kwargs) > len(version)): 

336 raise RuntimeError( # pragma: no cover 

337 "Mismatch number of inputs and arguments for version=%s.\n" 

338 "Given: args=%s kwargs=%s.\n" 

339 "Returned: inputs=%s new_kwargs=%s.\n" % ( 

340 version, args, kwargs, inputs, new_kwargs)) 

341 if not self.n_variables and len(inputs) > len(self.dtypes): 

342 raise RuntimeError( # pragma: no cover 

343 "Mismatch number of inputs for version=%s.\n" 

344 "Given: args=%s.\n" 

345 "Expected: dtypes=%s\n" 

346 "Returned: inputs=%s.\n" % ( 

347 version, args, self.dtypes, inputs)) 

348 

349 return inputs, kwargs, outputs, optional, self.n_variables 

350 

351 def shape_calculator(self, dims): 

352 """ 

353 Returns expected dimensions given the input dimensions. 

354 """ 

355 if len(dims) == 0: 

356 return None 

357 res = [dims[0]] 

358 for _ in dims[1:]: 

359 res.append(None) 

360 return res 

361 

362 

class NDArrayType(_NDArrayAlias):
    """
    Shortcut to simplify signature description.
    Every parameter is forwarded unchanged to the parent constructor.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        super().__init__(
            dtypes=dtypes, dtypes_out=dtypes_out,
            n_optional=n_optional, nvars=nvars)

379 

380 

class NDArrayTypeSameShape(NDArrayType):
    """
    Shortcut to simplify signature description.
    Every parameter is forwarded unchanged to the parent constructor.

    :param dtypes: input dtypes
    :param dtypes_out: output dtypes
    :param n_optional: number of optional parameters, 0 by default
    :param nvars: True if the function allows an infinite number of inputs,
        this is incompatible with parameter *n_optional*.

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None, dtypes_out=None, n_optional=None, nvars=False):
        super().__init__(
            dtypes=dtypes, dtypes_out=dtypes_out,
            n_optional=n_optional, nvars=nvars)

397 

398 

class NDArraySameType(NDArrayType):
    """
    Shortcut to simplify signature description.
    The signature is built from a single type description.

    :param dtypes: input dtypes, it cannot be None, a tuple
        or a string including ``'_'``

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        # the first failing check gives the error message
        issue = None
        if dtypes is None:
            issue = "dtypes cannot be None."
        elif isinstance(dtypes, str) and "_" in dtypes:
            issue = "dtypes cannot include '_' meaning two different types."
        elif isinstance(dtypes, tuple):
            issue = "dtypes must be a single type."
        if issue is not None:
            raise ValueError(issue)  # pragma: no cover
        super().__init__(dtypes=(dtypes, ))

    def __repr__(self):
        "usual"
        return "{}({!r})".format(self.__class__.__name__, self.dtypes)

422 

423 

class NDArraySameTypeSameShape(NDArraySameType):
    """
    Shortcut to simplify signature description.
    The parameter is forwarded unchanged to the parent constructor.

    :param dtypes: input dtypes

    .. versionadded:: 0.6
    """

    def __init__(self, dtypes=None):
        super().__init__(dtypes=dtypes)