Coverage for mlprodict/onnx_tools/onnx2py_helper.py: 92%

475 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Functions which converts :epkg:`ONNX` object into 

4readable :epkg:`python` objects. 

5""" 

6import pprint 

7import warnings 

8import numpy 

9from scipy.sparse import coo_matrix 

10from onnx.defs import get_schema, get_function_ops, onnx_opset_version 

11from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE, TENSOR_TYPE_TO_NP_TYPE 

12from onnx import TensorProto, ValueInfoProto, TypeProto, TensorShapeProto 

13from onnx.helper import make_tensor_type_proto 

14from onnx.numpy_helper import to_array, from_array as onnx_from_array 

15 

16 

def get_tensor_shape(obj):
    """
    Extracts the tensor shape stored in a type description.

    :param obj: `ValueInfoProto` or `TypeProto`
    :return: None if the tensor type carries no shape, otherwise
        a list with one entry per dimension: the positive
        `dim_value`, else the `dim_param` string, else None when
        neither is set
    :raises TypeError: if *obj* is neither of the supported types
    """
    if isinstance(obj, ValueInfoProto):
        # A value info only wraps a type, unwrap it and start over.
        return get_tensor_shape(obj.type)
    if not isinstance(obj, TypeProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(obj)!r}.")
    tensor = obj.tensor_type
    if not tensor.HasField('shape'):
        return None
    dims = []
    for dim in tensor.shape.dim:
        value = dim.dim_value if dim.dim_value > 0 else dim.dim_param
        # An empty dim_param means the dimension is unknown.
        dims.append(None if value in (0, '') else value)
    return dims

35 

36 

def get_tensor_elem_type(obj):
    """
    Extracts the element type stored in a type description.

    :param obj: `ValueInfoProto` or `TypeProto`
    :return: value of `tensor_type.elem_type`
    :raises TypeError: if *obj* is not a supported type or if its
        tensor type is empty
    """
    if isinstance(obj, ValueInfoProto):
        # A value info only wraps a type, unwrap it and start over.
        return get_tensor_elem_type(obj.type)
    if not isinstance(obj, TypeProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(obj)!r}.")
    tensor = obj.tensor_type
    if tensor.ByteSize() == 0:
        # Nothing was serialized in the tensor type: no element type.
        raise TypeError(  # pragma: no cover
            f"Unable to guess element type for {obj!r}.")
    return tensor.elem_type

50 

51 

def to_bytes(val):
    """
    Serializes an array: the array is first converted into
    a protobuf tensor, then into bytes.

    :param val: numpy array, any other object is assumed to
        already expose `SerializeToString`
    :return: bytes

    .. exref::
        :title: Converts an array into bytes (serialization)

        Useful to serialize.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            import numpy
            from mlprodict.onnx_tools.onnx2py_helper import to_bytes

            data = numpy.array([[0, 1], [2, 3], [4, 5]], dtype=numpy.float32)
            pb = to_bytes(data)
            print(len(pb), data.size * data.itemsize, pb[:10])
    """
    proto = from_array(val) if isinstance(val, numpy.ndarray) else val
    return proto.SerializeToString()

80 

81 

def from_array(value, name=None):
    """
    Builds an ONNX tensor from an array.

    :param value: numpy array, a `TensorProto` is returned unchanged
    :param name: optional name given to the tensor
    :return: ONNX tensor
    :raises NotImplementedError: if *value* cannot be converted
    """
    if not isinstance(value, numpy.ndarray):
        if isinstance(value, TensorProto):  # pragma: no cover
            return value
        raise NotImplementedError(  # pragma: no cover
            f"Unable to convert type {type(value)!r} into an ONNX tensor.")

    try:
        return onnx_from_array(value, name=name)
    except NotImplementedError as e:  # pragma: no cover
        if value.dtype != numpy.dtype('O'):
            raise NotImplementedError(
                "Unable to convert type %r (dtype=%r) into an ONNX tensor "
                "due to %r." % (type(value), value.dtype, e)) from e
        # Object arrays are stored as a string tensor, every element
        # is converted with str() and encoded in utf-8.
        tensor = TensorProto()
        tensor.data_type = TensorProto.STRING  # pylint: disable=E1101
        if name is not None:
            tensor.name = name
        tensor.dims.extend(value.shape)  # pylint: disable=E1101
        tensor.string_data.extend(  # pylint: disable=E1101
            [str(item).encode('utf-8') for item in value.ravel()])
        return tensor

110 

111 

def from_bytes(b):
    """
    Deserializes an array: bytes are parsed into a protobuf
    tensor which is then converted into an array.

    :param b: bytes, any other object is handed to `to_array` as is
    :return: array

    .. exref::
        :title: Converts bytes into an array (serialization)

        Useful to deserialize.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            import numpy
            from mlprodict.onnx_tools.onnx2py_helper import to_bytes, from_bytes

            data = numpy.array([[0, 1], [2, 3], [4, 5]], dtype=numpy.float32)
            pb = to_bytes(data)
            data2 = from_bytes(pb)
            print(data2)
    """
    if not isinstance(b, bytes):
        return to_array(b)  # pragma: no cover
    proto = TensorProto()
    proto.ParseFromString(b)
    return to_array(proto)

142 

143 

def _numpy_array(data, dtype=None, copy=True):
    """
    Single function to create an array.

    @param      data        data, an existing :epkg:`numpy` array is
                            returned unchanged (*dtype* and *copy*
                            are ignored in that case)
    @param      dtype       dtype
    @param      copy        if True, always copies the data, if False,
                            copies only when it cannot be avoided
    @return                 numpy array
    """
    if isinstance(data, numpy.ndarray):
        return data
    if copy:
        return numpy.array(data, dtype=dtype)
    # numpy.array(..., copy=False) raises ValueError with numpy >= 2
    # whenever a copy is required (lists, protobuf repeated fields...).
    # numpy.asarray copies only if needed with every numpy version.
    return numpy.asarray(data, dtype=dtype)

158 

159 

def _sparse_array(shape, data, indices, dtype=None, copy=True):
    """
    Single function to create an sparse array
    (:epkg:`coo_matrix`).

    @param      shape       shape of the matrix, two dimensions
    @param      data        non null values
    @param      indices     flattened positions of the values, position
                            *i* is stored at row ``i // shape[1]``
                            and column ``i % shape[1]``
    @param      dtype       dtype
    @param      copy        copy, only used if *data* is not
                            a :epkg:`numpy` array
    @return                 :epkg:`coo_matrix` of shape *shape*
    """
    if len(shape) != 2:
        raise ValueError(  # pragma: no cover
            f"Only matrices are allowed or sparse matrices but shape is {shape}.")
    n_rows, n_cols = int(shape[0]), int(shape[1])
    flat = numpy.asarray(indices, dtype=numpy.int64)
    rows = flat // n_cols
    cols = flat % n_cols
    if not isinstance(data, numpy.ndarray):
        # copy=False means "copy only if needed", numpy.array refuses
        # that request with numpy >= 2 when a copy is unavoidable.
        if copy:
            data = numpy.array(data, dtype=dtype)
        else:
            data = numpy.asarray(data, dtype=dtype)
    # The shape must be given explicitly, otherwise scipy infers it
    # from the highest index and trailing empty rows or columns are
    # lost (and an empty list of indices cannot be handled at all).
    return coo_matrix((data, (rows, cols)), shape=(n_rows, n_cols),
                      dtype=dtype)

184 

185 

# Names of element types and the numpy type they stand for. The
# unsigned and complex names are the ones produced by
# _elem_type_as_str.
_NUMPY_TYPE_FROM_NAME = {
    'float': numpy.float32,
    'float32': numpy.float32,
    'double': numpy.float64,
    'float64': numpy.float64,
    'float16': numpy.float16,
    'int64': numpy.int64,
    'int32': numpy.int32,
    'int16': numpy.int16,
    'int8': numpy.int8,
    'uint64': numpy.uint64,
    'uint32': numpy.uint32,
    'uint16': numpy.uint16,
    'uint8': numpy.uint8,
    'complex64': numpy.complex64,
    'complex128': numpy.complex128,
    'bool': numpy.bool_,
    'str': numpy.str_,
}


def guess_numpy_type_from_string(name):
    """
    Converts a string (such as `'float'`) into a
    numpy dtype.

    :param name: name of the type, `'float'` means `float32`
        and `'double'` means `float64`
    :return: numpy type
    :raises ValueError: if the name is unknown
    """
    try:
        return _NUMPY_TYPE_FROM_NAME[name]
    except (KeyError, TypeError):
        # TypeError: unhashable name, it cannot be a known name either.
        raise ValueError(  # pragma: no cover
            f"Unable to guess numpy dtype from {name!r}.") from None

213 

214 

def guess_numpy_type_from_dtype(dt):
    """
    Converts a dtype (such as `dtype('float32')`) into a
    numpy type.

    :param dt: numpy type, numpy dtype or anything `numpy.dtype`
        understands
    :return: *dt* itself if it is already one of the supported types,
        otherwise the numpy scalar type of the corresponding dtype
    :raises ValueError: if no supported type matches
    """
    supported = {
        numpy.int8, numpy.uint8, numpy.float16, numpy.float32,
        numpy.float64, numpy.int32, numpy.int64, numpy.int16,
        numpy.uint16, numpy.uint32, numpy.bool_, numpy.str_,
        numpy.uint64, bool, str, }
    if dt in supported:
        return dt
    # numpy.dtype(None) silently means float64, None is not a type.
    if dt is not None:
        try:
            scalar_type = numpy.dtype(dt).type
        except (TypeError, ValueError):
            scalar_type = None
        # Every supported dtype instance is mapped to its scalar type,
        # not only float32, float64, int64, int8 and uint8.
        if scalar_type in supported:
            return scalar_type
    raise ValueError(  # pragma: no cover
        f"Unable to guess numpy dtype from {dt!r}.")

237 

238 

def _elem_type_as_str(elem_type):
    """
    Converts an element type into something readable.

    :param elem_type: an integer code of `TensorProto` or an object
        whose string representation starts with `tensor_type`,
        `optional_type` or `map_type`
    :return: a string for an integer code (`'unk'` for 0),
        a dictionary describing the type otherwise
    :raises NotImplementedError: if the type is not recognized
    """
    scalar_names = (
        (TensorProto.FLOAT, 'float'),  # pylint: disable=E1101
        (TensorProto.BOOL, 'bool'),  # pylint: disable=E1101
        (TensorProto.DOUBLE, 'double'),  # pylint: disable=E1101
        (TensorProto.STRING, 'str'),  # pylint: disable=E1101
        (TensorProto.INT64, 'int64'),  # pylint: disable=E1101
        (TensorProto.INT32, 'int32'),  # pylint: disable=E1101
        (TensorProto.UINT32, 'uint32'),  # pylint: disable=E1101
        (TensorProto.UINT64, 'uint64'),  # pylint: disable=E1101
        (TensorProto.INT16, 'int16'),  # pylint: disable=E1101
        (TensorProto.UINT16, 'uint16'),  # pylint: disable=E1101
        (TensorProto.UINT8, 'uint8'),  # pylint: disable=E1101
        (TensorProto.INT8, 'int8'),  # pylint: disable=E1101
        (TensorProto.FLOAT16, 'float16'),  # pylint: disable=E1101
        (TensorProto.COMPLEX64, 'complex64'),  # pylint: disable=E1101
        (TensorProto.COMPLEX128, 'complex128'),  # pylint: disable=E1101
        (0, 'unk'),
    )
    # Equality (and not a dictionary lookup) is used on purpose:
    # elem_type may be an object which cannot be hashed.
    for code, label in scalar_names:
        if elem_type == code:
            return label

    # The following code should be refactored.
    text = str(elem_type)

    if text.startswith("tensor_type"):
        tensor = elem_type.tensor_type
        return {'kind': 'tensor',
                'elem': _elem_type_as_str(tensor.elem_type),
                'shape': tensor.shape}

    if text.startswith("optional_type"):
        optional = elem_type.optional_type
        # NOTE(review): the shape is read from the optional type itself,
        # confirm this attribute exists on that object.
        return {'kind': 'tensor',
                'elem': _elem_type_as_str(optional.elem_type),
                'shape': optional.shape,
                'optional_type': True}

    if text.startswith("map_type"):
        mapping = elem_type.map_type
        return {'kind': 'map',
                'key': _elem_type_as_str(mapping.key_type),
                'value': _elem_type_as_str(mapping.value_type)}

    raise NotImplementedError(  # pragma: no cover
        "elem_type '{}' is unknown\nfields:\n{}\n-----\n{}.".format(
            elem_type, pprint.pformat(dir(elem_type)), type(elem_type)))

306 

307 

def _to_array(var):
    """
    Converts a tensor into an array. `to_array` is tried first,
    if it raises `ValueError`, the array is rebuilt from the typed
    field matching `var.data_type` and reshaped with `var.dims`.

    :param var: tensor
    :return: array
    :raises NotImplementedError: if `var.data_type` has no fallback
    """
    try:
        return to_array(var)
    except ValueError as e:  # pragma: no cover
        dims = list(var.dims)
        # data_type -> (typed field, numpy dtype, whether to_array is
        # tried a second time when the reshape fails)
        typed_fields = {
            1: ('float_data', numpy.float32, True),
            2: ('uint8_data', numpy.uint8, False),
            3: ('int8_data', numpy.int8, False),
            4: ('uint16_data', numpy.uint16, False),
            5: ('int16_data', numpy.int16, False),
            6: ('int32_data', numpy.int32, False),
            7: ('int64_data', numpy.int64, False),
            11: ('double_data', numpy.float64, True),
            16: ('float16_data', numpy.float16, False),
        }
        spec = typed_fields.get(var.data_type)
        raw = None if spec is None else getattr(var, spec[0])
        if raw is None:
            raise NotImplementedError(
                f"Iniatilizer {var} cannot be converted into a dictionary.") from e
        _, np_dtype, retry = spec
        try:
            return _numpy_array(raw, dtype=np_dtype, copy=False).reshape(dims)
        except ValueError:
            if not retry:
                raise
            return _numpy_array(to_array(var))

350 

351 

def _var_as_dict(var):  # pylint: disable=R0912
    """
    Converts a protobuf object into something readable,
    a dictionary most of the time.

    The function looks at the attributes the object exposes
    to decide how to convert it:

    * an object with a non empty `type` becomes
      `dict(name=..., type=...)` with a key `'value'` or
      `'ref_attr_name'` when one can be extracted,
    * an object with `op_type` becomes
      `dict(name=..., op_type=..., domain=..., atts=...)`,
    * an object with `dims` or a positive `data_type` becomes
      `dict(name=..., value=array)`,
    * a string becomes `dict(name=var)`,
    * an object whose string representation is empty becomes None,
    * `ValueInfoProto`, `TensorShapeProto`, `TypeProto` and objects
      whose class is named `Tensor` or `Optional` have a dedicated
      conversion.

    :param var: object to convert
    :return: dictionary or None
    :raises NotImplementedError: if the object is not recognized
    """
    if hasattr(var, 'type') and str(var.type) != '':
        # variable
        if var.type is not None:
            if hasattr(var, 'sparse_tensor') and var.type == 11:
                # sparse tensor
                # NOTE(review): 11 is presumably the attribute type code
                # of a sparse tensor, confirm against the onnx enum.
                t = var.sparse_tensor
                # values is not used in this branch, it is computed
                # again when the value is extracted below.
                values = _var_as_dict(t.values)
                dims = list(t.dims)
                dtype = dict(kind='sparse_tensor', shape=tuple(dims), elem=1)
            elif (hasattr(var.type, 'tensor_type') and
                    var.type.tensor_type.elem_type > 0):
                t = var.type.tensor_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    # no dimension: the shape becomes ('?', )
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims))
            elif (hasattr(var.type, 'optional_type') and
                    var.type.tensor_type.elem_type > 0):
                # NOTE(review): this condition tests
                # tensor_type.elem_type > 0 again, the previous branch
                # already wins when var.type has a tensor_type.
                # Confirm whether optional_type was meant to be tested.
                t = var.type.optional_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims), optional_type=True)
            elif (hasattr(var.type, 'real') and var.type.real == 5 and
                    hasattr(var, 'g')):
                # var.type is a number here (it has attribute real),
                # 5 goes with field g
                dtype = dict(kind='graph', elem=var.type.real)
            elif (hasattr(var.type, 'real') and var.type.real == 4 and
                    hasattr(var, 't')):
                # 4 goes with field t
                dtype = dict(kind='tensor', elem=var.type.real)
            elif hasattr(var.type, 'real'):
                # any other number, the value is picked below
                # according to this number
                dtype = dict(kind='real', elem=var.type.real)
            elif (hasattr(var.type, "sequence_type") and
                    var.type.sequence_type is not None and
                    str(var.type.sequence_type.elem_type) != ''):
                t = var.type.sequence_type
                elem_type = _elem_type_as_str(t.elem_type)
                dtype = dict(kind='sequence', elem=elem_type)
            elif (hasattr(var.type, "map_type") and
                    var.type.map_type is not None and
                    str(var.type.map_type.key_type) != '' and
                    str(var.type.map_type.value_type) != ''):
                t = var.type.map_type
                key_type = _elem_type_as_str(t.key_type)
                value_type = _elem_type_as_str(t.value_type)
                dtype = dict(kind='map', key=key_type, value=value_type)
            elif (hasattr(var.type, 'tensor_type') and
                    var.type.tensor_type.elem_type == 0):
                # tensor with an undefined element type,
                # elem_type is 'unk' in that case
                if hasattr(var.type, 'optional_type'):
                    optional = var.type.optional_type
                else:
                    optional = None
                t = var.type.tensor_type
                elem_type = _elem_type_as_str(t.elem_type)
                shape = t.shape
                dim = shape.dim
                dims = [d.dim_value for d in dim]
                if len(dims) == 0:
                    dims = '?'
                dtype = dict(kind='tensor', elem=elem_type,
                             shape=tuple(dims))
                if optional is not None:
                    dtype['optional'] = _var_as_dict(optional)
            else:
                raise NotImplementedError(  # pragma: no cover
                    "Unable to convert a type into a dictionary for '{}'. "
                    "Available fields: {}.".format(
                        var.type, pprint.pformat(dir(var.type))))
        else:
            raise NotImplementedError(  # pragma: no cover
                "Unable to convert variable into a dictionary for '{}'. "
                "Available fields: {}.".format(
                    var, pprint.pformat(dir(var.type))))

        res = dict(name=var.name, type=dtype)

        # The value is read from the field matching dtype['elem'].
        # NOTE(review): the numbers below look like the codes of the
        # attribute types (1=f, 2=i, 3=s, 4=t, 5=g, 6=floats, 7=ints,
        # 8=strings, 11=sparse_tensor), confirm against the onnx enum.
        if (hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 1 and
                dtype['kind'] == 'sparse_tensor'):
            # sparse matrix
            t = var.sparse_tensor
            try:
                values = _var_as_dict(t.values)
            except NotImplementedError as e:  # pragma: no cover
                raise NotImplementedError(
                    f"Issue with\n{var}\n---") from e
            indices = _var_as_dict(t.indices)
            res['value'] = _sparse_array(
                dtype['shape'], values['value'], indices['value'], dtype=numpy.float32)
        elif hasattr(var, 'floats') and dtype.get('elem', None) == 6:
            res['value'] = _numpy_array(var.floats, dtype=numpy.float32)
        elif hasattr(var, 'strings') and dtype.get('elem', None) == 8:
            res['value'] = _numpy_array(var.strings)
        elif hasattr(var, 'ints') and dtype.get('elem', None) == 7:
            res['value'] = _numpy_array(var.ints)
        elif hasattr(var, 'f') and dtype.get('elem', None) == 1:
            res['value'] = var.f
        elif hasattr(var, 's') and dtype.get('elem', None) == 3:
            res['value'] = var.s
        elif hasattr(var, 'i') and dtype.get('elem', None) == 2:
            res['value'] = var.i
        elif hasattr(var, 'g') and dtype.get('elem', None) == 5:
            res['value'] = var.g
        elif hasattr(var, 't') and dtype.get('elem', None) == 4:
            if hasattr(var, 'ref_attr_name') and var.ref_attr_name:
                # the value is defined by an attribute of the function
                # this node belongs to, only its name is stored
                res['ref_attr_name'] = var.ref_attr_name
            else:
                ts = _var_as_dict(var.t)
                res['value'] = ts['value']
        elif hasattr(var, 'sparse_tensor') and dtype.get('elem', None) == 11:
            ts = _var_as_dict(var.sparse_tensor)
            if hasattr(var, 'ref_attr_name') and var.ref_attr_name:
                res['ref_attr_name'] = var.ref_attr_name
            else:
                # NOTE(review): ts computed above from var.sparse_tensor
                # is replaced by the conversion of var.t, confirm which
                # one is expected.
                ts = _var_as_dict(var.t)
                res['value'] = ts['value']
        elif "'value'" in str(var):
            warnings.warn("No value: {} -- {}".format(  # pragma: no cover
                dtype, str(var).replace("\n", "").replace(" ", "")))
        return res

    if hasattr(var, 'op_type'):
        # node: every attribute is converted
        if hasattr(var, 'attribute'):
            atts = {}
            for att in var.attribute:
                atts[att.name] = _var_as_dict(att)
        return dict(name=var.name, op_type=var.op_type,
                    domain=var.domain, atts=atts)
    if hasattr(var, 'dims') and len(var.dims) > 0:
        # initializer
        data = _to_array(var)
        return dict(name=var.name, value=data)
    if hasattr(var, 'data_type') and var.data_type > 0:
        # same conversion for an object with no dimension
        # but a defined data type
        data = _to_array(var)
        return dict(name=var.name, value=data)
    if isinstance(var, str):
        return dict(name=var)
    if str(var) == '':
        return None
    if isinstance(var, ValueInfoProto):
        # no usable type was found earlier, everything is unknown
        return dict(name=var.name,
                    type=dict(elem='unk', kind='tensor', shape=('?', )))
    if isinstance(var, TensorShapeProto):
        # one dictionary per dimension, only the fields
        # which are set are kept
        ds = []
        for dim in var.dim:
            d = {}
            if dim.dim_value:
                d['dim_value'] = dim.dim_value
            if dim.dim_param:
                d['dim_param'] = dim.dim_param
            ds.append(d)
        return dict(dim=ds)
    if isinstance(var, TypeProto):
        # every attribute whose name ends with '_type' is converted
        d = dict(denotation=var.denotation)
        for n in dir(var):
            if n.endswith('_type'):
                at = getattr(var, n)
                d[n] = _var_as_dict(at)
        return d
    if var.__class__.__name__ == "Tensor":
        return dict(elem_type=var.elem_type, shape=_var_as_dict(var.shape))
    if var.__class__.__name__ == "Optional":
        return dict(optional=True, elem_type=_var_as_dict(var.elem_type))

    raise NotImplementedError(  # pragma: no cover
        "Unable to guess which object it is type is %r value is %r "
        "(hasattr(var,'type')=%r, var.type=%s\n%s"
        "" % (type(var), str(var), hasattr(var, 'type'),
              str(getattr(var, 'type', None)),
              '\n'.join(dir(var))))

534 

535 

def get_dtype_shape(obj):
    """
    Returns the element type and the shape of a tensor.

    :param obj: onnx object
    :return: None if *obj* has no attribute `type` or if its type
        has no attribute `tensor_type`, `(dtype, None)` if the tensor
        type has no attribute `shape`, `(dtype, shape)` otherwise
        where shape is a tuple, every dimension is `dim_value`
        if not null, else `dim_param` if not empty, else None
    """
    if not hasattr(obj, 'type'):
        return None
    described = obj.type
    if not hasattr(described, 'tensor_type'):
        return None
    tensor = described.tensor_type
    elem = tensor.elem_type
    if not hasattr(tensor, 'shape'):
        return elem, None

    def _readable(dim):
        # a null dim_value means the dimension is named or unknown
        if dim.dim_value != 0:
            return dim.dim_value
        return None if dim.dim_param == '' else dim.dim_param

    return elem, tuple(_readable(dim) for dim in tensor.shape.dim)

565 

566 

def onnx_model_opsets(onnx_model):
    """
    Collects the opsets a model imports.

    :param onnx_model: ONNX graph, only `opset_import` is read
    :return: dictionary `{domain: version}`
    """
    return {imported.domain: imported.version
            for imported in onnx_model.opset_import}

578 

579 

def _type_to_string(dtype):
    """
    Builds a short readable string describing a type.

    :param dtype: dictionary with a key `'kind'` (`'tensor'`,
        `'sequence'` or `'map'`), any other object is first
        converted with `_var_as_dict`
    :return: string
    :raises NotImplementedError: if the kind is not supported
    """
    if isinstance(dtype, dict):
        dtype_ = dtype
    else:
        dtype_ = _var_as_dict(dtype)  # pragma: no cover
    kind = dtype_["kind"]
    if kind == 'tensor':
        return f"{dtype_['elem']}({dtype_['shape']})"
    if kind == 'sequence':
        # the element type is itself a type, same conversion
        return f"[{_type_to_string(dtype_['elem'])}]"
    if kind == 'map':
        return f"{{{dtype_['key']}, {dtype_['value']}}}"
    raise NotImplementedError(  # pragma: no cover
        f"Unable to convert into string {dtype} or {dtype_}.")

596 

597 

def numpy_min(x):
    """
    Computes the minimum of an array, the array
    may contain text.

    :param x: array, an object with a method `todense`
        is converted into a dense array first
    :return: `x.min()` unless the kind of the dtype is one of
        `'cUC'`, in that case, the representation of the smallest
        string (truncated after 10 characters) or `nan` if there is
        no string, `'?'` if a `ValueError` or a `TypeError` is raised
    """
    try:
        dense = x.todense() if hasattr(x, 'todense') else x
        if dense.dtype.kind not in 'cUC':
            return dense.min()
        try:  # pragma: no cover
            flat = dense.ravel()
        except AttributeError:  # pragma: no cover
            flat = dense
        texts = [item for item in flat if isinstance(item, str)]
        if not texts:  # pragma: no cover
            return numpy.nan
        smallest = sorted(texts)[0]
        if len(smallest) > 10:  # pragma: no cover
            smallest = smallest[:10] + '...'
        return repr(smallest)
    except (ValueError, TypeError):  # pragma: no cover
        return '?'

622 

623 

def numpy_max(x):
    """
    Computes the maximum of an array, the array
    may contain text.

    :param x: array, an object with a method `todense`
        is converted into a dense array first
    :return: `x.max()` unless the kind of the dtype is one of
        `'cUC'`, in that case, the representation of the greatest
        string (truncated after 10 characters) or `nan` if there is
        no string, `'?'` if a `ValueError` or a `TypeError` is raised
    """
    try:
        dense = x.todense() if hasattr(x, 'todense') else x
        if dense.dtype.kind not in 'cUC':
            return dense.max()
        try:  # pragma: no cover
            flat = dense.ravel()
        except AttributeError:  # pragma: no cover
            flat = dense
        texts = [item for item in flat if isinstance(item, str)]
        if not texts:  # pragma: no cover
            return numpy.nan
        greatest = sorted(texts)[-1]
        if len(greatest) > 10:  # pragma: no cover
            greatest = greatest[:10] + '...'
        return repr(greatest)
    except (ValueError, TypeError):  # pragma: no cover
        return '?'

648 

649 

def guess_proto_dtype(dtype):
    """
    Returns the ONNX element type equivalent to a numpy dtype.

    :param dtype: numpy dtype
    :return: proto type
    :raises RuntimeError: if no ONNX type matches
    """
    numeric_types = (
        (numpy.float32, TensorProto.FLOAT),  # pylint: disable=E1101
        (numpy.float64, TensorProto.DOUBLE),  # pylint: disable=E1101
        (numpy.int64, TensorProto.INT64),  # pylint: disable=E1101
        (numpy.int32, TensorProto.INT32),  # pylint: disable=E1101
        (numpy.int16, TensorProto.INT16),  # pylint: disable=E1101
        (numpy.int8, TensorProto.INT8),  # pylint: disable=E1101
        (numpy.uint64, TensorProto.UINT64),  # pylint: disable=E1101
        (numpy.uint32, TensorProto.UINT32),  # pylint: disable=E1101
        (numpy.uint16, TensorProto.UINT16),  # pylint: disable=E1101
        (numpy.uint8, TensorProto.UINT8),  # pylint: disable=E1101
        (numpy.float16, TensorProto.FLOAT16),  # pylint: disable=E1101
    )
    # Equality is used so that dtype instances match as well.
    for numpy_type, proto_type in numeric_types:
        if dtype == numpy_type:
            return proto_type
    if dtype in (bool, numpy.bool_):
        return TensorProto.BOOL  # pylint: disable=E1101
    if dtype in (str, numpy.str_):
        return TensorProto.STRING  # pylint: disable=E1101
    raise RuntimeError(  # pragma: no cover
        f"Unable to guess type for dtype={dtype}.")

685 

686 

def guess_proto_dtype_name(onnx_dtype):
    """
    Returns a string equivalent to `onnx_dtype`.

    :param onnx_dtype: onnx dtype
    :return: name of the proto type prefixed by `TensorProto.`,
        `"TensorProto.FLOAT"` for example
    :raises RuntimeError: if the type is unknown
    """
    # Every type guess_proto_dtype may return is listed (INT8, UINT16,
    # UINT32, UINT64 were missing) so that the name of the type of any
    # supported numpy dtype can be obtained.
    names = (
        "FLOAT", "DOUBLE", "INT64", "INT32", "INT16", "INT8",
        "UINT64", "UINT32", "UINT16", "UINT8",
        "FLOAT16", "BFLOAT16", "BOOL", "STRING",
        "COMPLEX64", "COMPLEX128")
    for name in names:
        if onnx_dtype == getattr(TensorProto, name):
            return f"TensorProto.{name}"
    raise RuntimeError(  # pragma: no cover
        f"Unable to guess type for dtype={onnx_dtype}.")

716 

717 

def guess_dtype(proto_type):
    """
    Returns the :epkg:`numpy` type equivalent to a proto type.

    :param proto_type: example ``onnx.TensorProto.FLOAT``
    :return: :epkg:`numpy` dtype
    :raises ValueError: if no numpy type matches
    """
    known_types = (
        (TensorProto.FLOAT, numpy.float32),  # pylint: disable=E1101
        (TensorProto.BOOL, numpy.bool_),  # pylint: disable=E1101
        (TensorProto.DOUBLE, numpy.float64),  # pylint: disable=E1101
        (TensorProto.STRING, numpy.str_),  # pylint: disable=E1101
        (TensorProto.INT64, numpy.int64),  # pylint: disable=E1101
        (TensorProto.INT32, numpy.int32),  # pylint: disable=E1101
        (TensorProto.INT8, numpy.int8),  # pylint: disable=E1101
        (TensorProto.INT16, numpy.int16),  # pylint: disable=E1101
        (TensorProto.UINT64, numpy.uint64),  # pylint: disable=E1101
        (TensorProto.UINT32, numpy.uint32),  # pylint: disable=E1101
        (TensorProto.UINT8, numpy.uint8),  # pylint: disable=E1101
        (TensorProto.UINT16, numpy.uint16),  # pylint: disable=E1101
        (TensorProto.FLOAT16, numpy.float16),  # pylint: disable=E1101
    )
    for code, numpy_type in known_types:
        if proto_type == code:
            return numpy_type
    raise ValueError(
        f"Unable to convert proto_type {proto_type} to numpy type.")

753 

754 

def to_skl2onnx_type(name, elem_type, shape):
    """
    Builds a :epkg:`sklearn-onnx` type from *name*,
    *elem_type*, *shape*.

    :param name: string
    :param elem_type: name of the element type, see
        `guess_numpy_type_from_string`
    :param shape: expected shape, a null dimension is
        replaced by None
    :return: tuple `(name, data type)`
    """
    from skl2onnx.common.data_types import _guess_numpy_type  # delayed
    numpy_type = guess_numpy_type_from_string(elem_type)
    dims = [None if d == 0 else d for d in shape]
    return (name, _guess_numpy_type(numpy_type, dims))

769 

770 

def from_pb(obj):
    """
    Describes a tensor stored in a protobuf.

    :param obj: initializer, tensor, or a container of them
        (any object with a method `extend`), every element
        is then converted
    :return: (name, type, shape) or a list of such tuples
    :raises NotImplementedError: if the type is not a tensor type
        or if its element type has no numpy equivalent
    """
    if hasattr(obj, 'extend'):
        return [from_pb(item) for item in obj]

    def _dimension(d):
        text = str(d)
        if "dim_param" in text:
            return None
        if d.dim_value != 0:
            return d.dim_value
        # dim_value is 0 when it is 0 or undefined
        return 0 if "0" in text else None

    name = obj.name
    tensor = obj.type.tensor_type
    if not tensor:
        raise NotImplementedError(  # pragma: no cover
            f"Unsupported type '{type(obj)}' as a string ({obj}).")

    elem = tensor.elem_type
    shape = [_dimension(d) for d in tensor.shape.dim]
    if elem not in TENSOR_TYPE_TO_NP_TYPE:
        raise NotImplementedError(
            f"Unsupported type '{type(obj.type.tensor_type)}' (elem_type={elem}).")
    return (name, TENSOR_TYPE_TO_NP_TYPE[elem].type, shape)

808 

809 

def numpy_type_prototype(dtype):
    """
    Returns the TensorProto dtype equivalent to a numpy dtype.

    :param dtype: dtype, looked up as is in `NP_TYPE_TO_TENSOR_TYPE`,
        then after a conversion with `numpy.dtype`
    :return: proto dtype
    :raises ValueError: if the dtype is unknown in both forms
    """
    known = NP_TYPE_TO_TENSOR_TYPE
    if dtype not in known:
        # second chance with the canonical numpy dtype
        canonical = numpy.dtype(dtype)
        if canonical not in known:
            raise ValueError(  # pragma: no cover
                f"Unable to convert dtype {dtype!r} into ProtoType.")
        return known[canonical]
    return known[dtype]

824 

825 

def make_value_info(name, dtype, shape):
    """
    Creates an instance of `onnx.ValueInfoProto` describing
    a tensor from its name, type and shape.

    :param name: name
    :param dtype: numpy element type
    :param shape: shape
    :return: instance of `onnx.ValueInfoProto`
    """
    info = ValueInfoProto()
    info.name = name
    proto_elem_type = numpy_type_prototype(dtype)
    info.type.CopyFrom(  # pylint: disable=E1101
        make_tensor_type_proto(proto_elem_type, shape))
    return info

842 

843 

def copy_value_info(info, name=None):
    """
    Duplicates an instance of `onnx.ValueInfoProto`,
    the name and the type are copied.

    :param info: instance of `onnx.ValueInfoProto` to copy
    :param name: new name, the name of *info* is kept
        if it is None or empty
    :return: instance of `onnx.ValueInfoProto`
    """
    duplicate = ValueInfoProto()
    duplicate.name = name if name else info.name
    duplicate.type.CopyFrom(info.type)  # pylint: disable=E1101
    return duplicate

855 

856 

# Functions defined in ONNX package indexed by (domain, name),
# None until _get_onnx_function is called for the first time.
_get_onnx_function_cache = None


def _get_onnx_function():
    """
    Returns the list of functions defined in ONNX package.
    The list is built once and cached.

    :return: dictionary `{(domain, name): function}`
    :raises RuntimeError: if two functions share the same key
    """
    global _get_onnx_function_cache  # pylint: disable=W0603
    if _get_onnx_function_cache is None:
        # The registry is built aside and published only once it is
        # complete, a failure must not leave an empty or partial
        # cache the next calls would silently return.
        registry = {}
        for fct in get_function_ops():
            key = fct.domain, fct.name
            if key in registry:
                raise RuntimeError(  # pragma: no cover
                    f"Function {key!r} is already registered.")
            registry[key] = fct
        _get_onnx_function_cache = registry
    return _get_onnx_function_cache

875 

876 

def get_onnx_schema(opname, domain='', opset=None, load_function=False):
    """
    Retrieves the schema of an operator.

    :param opname: operator name
    :param domain: operator domain
    :param opset: opset or version, None for the latest
    :param load_function: if True, the functions defined in ONNX
        package are searched first for `(domain, opname)`,
        the operator schema is returned if none matches,
        opset must be None in that case
    :return: :epkg:`OpSchema` or the function if one was found
    :raises ValueError: if *load_function* is True and *opset*
        is not None
    """
    if load_function:
        if opset is not None:
            raise ValueError(
                "opset must be None if load_function is True for "
                "operator (%r,%r)." % (domain, opname))
        functions = _get_onnx_function()
        wanted = domain, opname
        if wanted in functions:
            return functions[wanted]
    # Not a function (or functions were not requested):
    # the schema of the operator is returned.
    version = onnx_opset_version() if opset is None else opset
    return get_schema(opname, version, domain)