Coverage for mlprodict/onnx_tools/exports/tf2onnx_helper.py: 95%

309 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Helpers to run examples created with function 

4@see fn export2tf2onnx. 

5""" 

6import collections 

7import inspect 

8import numpy 

9from onnx.numpy_helper import from_array 

10from onnx.helper import ( 

11 make_node, make_graph, make_model, set_model_props, make_tensor) 

12from onnx import AttributeProto 

13from ..onnx2py_helper import guess_dtype, guess_proto_dtype 

14from ..onnx_tools import ensure_topological_order 

15 

16 

# Counter appended to names by make_name() and incremented on each call,
# so that successive generated names are distinct.
_make_name_id = 0

18 

19 

def make_tf2onnx_code(opset, name=None, op_type=None, domain='',
                      inputs=None, outputs=None, attributes=None,
                      used=None, context=None, mark_inits=None, indent=8,
                      **unused):
    """
    Converts an ONNX operators into :epkg:`tf2onnx` code.

    :param opset: target opset for the conversion (usually unused)
    :param name: node name
    :param op_type: operator type
    :param domain: domain
    :param inputs: input names, an empty name stands for an omitted
        optional input
    :param outputs: output names
    :param attributes: attributes, an iterable of pairs `(name, code)`,
        *code* is written as is in the generated code,
        None means no attribute
    :param used: dictionary `{k: v}`,
        list of nodes taking *k* as input
    :param context: whole context, `context['initializers_dict']`
        maps initializer names to numpy arrays
    :param mark_inits: dictionary `{name: [values]}` receiving the
        initializers inlined into the generated code, may be None
    :param indent: number of spaces to add on the second
        and following rows
    :return: code as str
    :raises NotImplementedError: when the number of inputs is not
        supported for operators Unsqueeze, Squeeze or Slice
    """
    inputs = [] if inputs is None else inputs
    outputs = [] if outputs is None else outputs

    def simplify(name, kind, force=False):
        # Returns the code for input *name*: a literal when the input is
        # a small int64 initializer consumed by a single node (or when
        # force is True), a reference to variable varx[name] otherwise.
        value = None
        if (used is not None and name in used and
                len(used[name]) == 1 and context is not None):
            inits = context['initializers_dict']
            if name in inits:
                v = inits[name]
                if v.dtype == numpy.int64 and v.size < 10:
                    value = v
                    if mark_inits is not None:
                        mark_inits.setdefault(name, []).append(v)

        if value is None and force:
            inits = {} if context is None else context['initializers_dict']
            if name not in inits:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find init %r in %r value=%r." % (
                        name, list(sorted(inits)), value))
            value = inits[name]
        if kind == 'list':  # pragma: no cover
            if value is None:
                return name
        elif kind == 'list_var':
            if value is None:
                return f"varx[{name!r}]"
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unknown scenario to simplify ({kind!r}).")
        if len(value.shape) == 0:
            return str(value)
        # tolist() returns python scalars: str(list(value)) would print
        # 'np.int64(0)' with numpy>=2, which is not valid generated code.
        return str(value.tolist())

    rows = []
    if op_type == 'Unsqueeze':
        if len(inputs) != 2:
            raise NotImplementedError(  # pragma: no cover
                f"Unable to create code for operator {op_type!r} (opset <= 12).")
        rows.append(
            "node = GraphBuilder(ctx).make_unsqueeze("
            "{'data': varx[%r], 'axes': %s}, return_node=True)"
            "" % (inputs[0], simplify(inputs[1], 'list_var')))
    elif op_type == 'Squeeze':
        if len(inputs) == 1:
            rows.append(  # pragma: no cover
                "node = GraphBuilder(ctx).make_squeeze("
                "{'data': varx[%r]}, return_node=True)"
                "" % (inputs[0], ))
        elif len(inputs) == 2:
            rows.append(
                "node = GraphBuilder(ctx).make_squeeze("
                "{'data': varx[%r], 'axes': %s}, return_node=True)"
                "" % (inputs[0], simplify(inputs[1], 'list_var')))
        else:
            raise NotImplementedError(  # pragma: no cover
                f"Unable to create code for operator {op_type!r} (opset <= 12).")
    elif op_type == 'Slice':
        # An empty name is an omitted optional input (axes skipped while
        # steps is given): it must not be rendered as varx[''].
        atts = [(k, v) for k, v in zip(['starts', 'ends', 'axes', 'steps'],
                                       inputs[1:])
                if v != '']
        text = ", ".join(f"'{k}': {simplify(v, 'list_var')}"
                         for k, v in atts)
        if len(inputs) not in (3, 4, 5):
            raise NotImplementedError(  # pragma: no cover
                f"Unable to create code for operator {op_type!r} (opset <= 12).")
        rows.append(
            "node = GraphBuilder(ctx).make_slice("
            "{'data': varx[%r], %s}, return_node=True)"
            "" % (inputs[0], text))
    else:
        if attributes:
            attributes_str = ", ".join(f"{k}={v}" for k, v in attributes)
            attr = f", attr=dict({attributes_str})"
        else:
            attr = ""
        input_list = ", ".join("varx[%r]" % n for n in inputs)
        rows.append(f"inputs = [{input_list}]")
        sdomain = '' if domain == '' else f"domain={domain!r}, "
        rows.append(
            "node = ctx.make_node(%r, inputs=inputs%s, %s"
            "name=make_name(%r))" % (
                op_type, attr, sdomain, name))
    for i, n in enumerate(outputs):
        rows.append("varx[%r] = node.output[%d]" % (n, i))
    if indent > 0:
        # the first row is inserted by the caller at the right position
        sind = " " * indent
        rows = rows[:1] + [sind + row for row in rows[1:]]
    return "\n".join(rows)

134 

135 

def make_name(name):
    """
    Returns *name* suffixed with '_' and the current value of a
    module-wide counter, then increments that counter so that the
    next call produces a different name.
    """
    global _make_name_id  # pylint: disable=W0603
    suffix = _make_name_id
    _make_name_id = suffix + 1
    return f"{name}_{suffix}"

142 

143 

def get_max_value(np_dtype):
    "Returns the largest value an integer numpy type can represent."
    limits = numpy.iinfo(np_dtype)
    return limits.max

147 

148 

def make_sure(cond, msg, *args):
    """
    Checks *cond* and raises RuntimeError with message `msg % args`
    when it does not hold.
    """
    if cond:
        return
    raise RuntimeError(msg % tuple(args))  # pragma: no cover

153 

154 

def map_onnx_to_numpy_type(onnx_dtype):
    "Maps an ONNX element type to a numpy type, float32 if it is None."
    return numpy.float32 if onnx_dtype is None else guess_dtype(onnx_dtype)

160 

161 

class tf_op:
    """
    Decorator to register any new converter. Every classmethod named
    `version_<N>` of the decorated class is registered as the handler
    of the given operator types for version *N*.

    :param name: type of the operator to rewrite, or a list of types
    :param domain: domain
    :param kwargs: additional keyword arguments stored with the handler
    """
    _OPSETS = collections.OrderedDict()

    def __init__(self, name, domain='', **kwargs):
        self.names = name if isinstance(name, list) else [name]
        self.domain = domain
        self.kwargs = kwargs

    def __call__(self, func):
        prefix = "version_"
        for member_name, method in inspect.getmembers(func, inspect.ismethod):
            if not member_name.startswith(prefix):
                continue
            version = int(member_name.replace(prefix, ""))
            self._register_handler(
                method, version, self.names, self.domain, self.kwargs)
        return func

    def _register_handler(self, func, version, names, domain, kwargs):
        # one list per domain, one dictionary {op_type: handler} per version
        registry = tf_op._OPSETS.get(domain)
        if not registry:
            registry = []
            tf_op._OPSETS[domain] = registry
        missing = version + 1 - len(registry)
        registry.extend({} for _ in range(missing))
        slot = registry[version]
        slot.update({op_name: (func, kwargs) for op_name in names})

195 

196 

class Tf2OnnxConvert:
    """
    Applies the converter on an ONNX graph.

    :param onnx_model: ONNX graph
    :param _tf_op: class holding the registered converters in attribute
        `_OPSETS` (see :class:`tf_op`), :class:`tf_op` if None
    :param verbose: verbosity
    :param target_opset: targetted opsets, an integer (main domain),
        a dictionary `{domain: version}` or None to use the opsets
        declared by *onnx_model*
    :param max_iter: maximum number of passes over the graph in
        :meth:`run`
    """

    def __init__(self, onnx_model, _tf_op=None, verbose=None,
                 target_opset=None, max_iter=5):
        self._onnx_model = onnx_model
        self._tf_op = _tf_op or tf_op
        self.verbose = verbose
        self.max_iter = max_iter
        if isinstance(target_opset, int):
            self.target_opsets = {'': target_opset}  # pragma: no cover
        elif isinstance(target_opset, dict):
            self.target_opsets = target_opset
        elif target_opset is None:  # pragma: no cover
            self.target_opsets = self._opsets_from_model(onnx_model)
        else:
            raise ValueError(  # pragma: no cover
                f"Unexepected value for target_opset={target_opset!r}.")
        self._names = {}
        for node in onnx_model.graph.node:
            self._names[node.name] = node
        for init in onnx_model.graph.initializer:
            self._names[init.name] = init
        # _forbidden_new_names contains current names and deleted names.
        self._forbidden_new_names = set(self._names)
        if '' in self.target_opsets:
            self.opset = self.target_opsets['']
        if not hasattr(self, 'opset'):
            raise RuntimeError(  # pragma: no cover
                f"Attribute opset is missing, target_opset={target_opset!r}.")

    @staticmethod
    def _opsets_from_model(onnx_model):
        """
        Builds `{domain: version}` from the opsets imported by *onnx_model*.
        Every domain other than '' receives the version of domain ''
        whatever the order of the declarations (None if '' is missing).
        """
        # NOTE(review): custom domains get the main opset version, not
        # their own declared version; confirm this is intended.
        main_version = None
        for oimp in onnx_model.opset_import:
            if oimp.domain == '':
                main_version = oimp.version
        return {oimp.domain: (oimp.version if oimp.domain == ''
                              else main_version)
                for oimp in onnx_model.opset_import}

    def get_node_by_name(self, name):  # pragma: no cover
        """
        Retrieves a node by its name.

        :param name: node name
        :return: node or initializer registered under this name
        """
        if name not in self._names:
            raise RuntimeError(
                "Unable to find node name %r among %r." % (
                    name, ", ".join(sorted(self._names))))
        return self._names[name]

    def _add_node_name(self, obj):
        """
        Registers an object in in the graph by its name.
        A name already used, even by a removed object, is rejected.

        :param obj: node or initializer
        """
        if obj.name in self._forbidden_new_names:
            raise RuntimeError(  # pragma: no cover
                f"Name {obj.name!r} is already registered.")
        self._names[obj.name] = obj
        self._forbidden_new_names.add(obj.name)

    def make_node(self, op_type, inputs, attr=None, outputs=None,
                  name=None, domain='', output_count=1,
                  shapes=None, dtypes=None):
        """
        Adds a node to the list of nodes.

        :param op_type: operator type
        :param inputs: list of strings
        :param attr: dictionary of attributes, values may be python
            values or AttributeProto
        :param outputs: None or list of strings
        :param output_count: used if outputs is None to guess
            the number of outputs of this node
        :param name: name of the node
        :param domain: domain
        :param shapes: unused
        :param dtypes: unused
        :return: created node
        """
        if self.verbose:
            print(  # pragma: no cover
                f"[Tf2OnnxConvert.make_node] op_type={op_type!r} inputs={inputs!r}")

        if attr is None:
            attr = {}
        if name is None:
            name = make_name(op_type)
        if name in self._names:
            raise RuntimeError(  # pragma: no cover
                "Node name %r already exists in %r." % (
                    name, ", ".join(sorted(self._names))))

        if outputs is None:
            outputs = [(name + ":" + str(i)) for i in range(output_count)]

        raw_attr = {}
        onnx_attrs = []
        for a, v in attr.items():
            if isinstance(v, AttributeProto):
                onnx_attrs.append(v)  # pragma: no cover
            else:
                raw_attr[a] = v

        onnx_node = make_node(
            op_type, inputs, outputs, name=name, domain=domain, **raw_attr)
        # AttributeProto values cannot go through make_node(**kwargs),
        # they are appended as they are instead of being dropped.
        if onnx_attrs:
            onnx_node.attribute.extend(onnx_attrs)  # pragma: no cover

        self._add_node_name(onnx_node)
        return onnx_node

    def make_const(self, name, np_val, skip_conversion=False, raw=True):
        """
        Make a new constants in the graph.

        :param name: const node name, must be unique.
        :param np_val: value of type numpy ndarray.
        :param skip_conversion:
            bool, indicate whether this created node would be mapped
            during conversion
        :param raw: whether to store data at field of raw_data or the
            specific field according to its dtype
        :return: create initializer
        """
        if name in self._names:
            raise RuntimeError(  # pragma: no cover
                "Initializer name %r already exists in %r." % (
                    name, ", ".join(sorted(self._names))))
        np_val_flat = np_val.flatten()
        # The alias numpy.object does not exist in recent numpy versions,
        # the builtin object compares equal to the object dtype.
        is_bytes = (np_val.dtype == object and len(np_val_flat) > 0 and
                    isinstance(np_val_flat[0], bytes))
        if raw and not is_bytes:
            onnx_tensor = from_array(np_val, name)
        else:  # pragma: no cover
            onnx_tensor = make_tensor(
                name, guess_proto_dtype(np_val.dtype),
                np_val.shape, np_val_flat, raw=False)

        self._add_node_name(onnx_tensor)
        return onnx_tensor

    def get_dtype(self, input_name):
        """
        Returns the element type of one graph input or None if
        *input_name* is not a graph input.

        :param input_name: result name
        :return: ONNX element type
        """
        inputs = self._onnx_model.graph.input
        names = [_.name for _ in inputs]
        if input_name not in names:
            return None  # pragma: no cover
        ind = names.index(input_name)
        return inputs[ind].type.tensor_type.elem_type

    def replace_all_inputs(self, old_name, new_name):
        """
        Every taking *old_name* as inputs will take *new_name* instead.
        Looks in the output as well but in that case, it creates an identity
        node to avoid changing an output name.

        :param old_name: name to replace
        :param new_name: new name
        :return: list of impacted nodes
        """
        if self.verbose:
            print(  # pragma: no cover
                "[Tf2OnnxConvert.replace_all_inputs] replace %r by %r" % (
                    old_name, new_name))
        res = []
        for node in self._names.values():
            if not hasattr(node, 'input'):
                continue
            if old_name not in node.input:
                continue
            new_inputs = [  # pragma: no cover
                new_name if i == old_name else i for i in node.input]
            node.input[:] = new_inputs[:]  # pragma: no cover
            res.append(node)  # pragma: no cover
            if self.verbose:  # pragma: no cover
                print(
                    "[Tf2OnnxConvert.replace_all_inputs] replace %r by %r in node %r" % (
                        old_name, new_name, node.name))
        for o in self._onnx_model.graph.output:
            if o.name != old_name:
                continue  # pragma: no cover
            n = self.make_node("Identity", [new_name], outputs=[old_name],
                               name=make_name("IdOutputReplaced"))
            res.append(n)
            if self.verbose:
                print(  # pragma: no cover
                    "[Tf2OnnxConvert.replace_all_inputs] add id node from %r to %r "
                    "with node %r." % (
                        old_name, new_name, n.name))  # pylint: disable=E1101
        if self.verbose:
            print(  # pragma: no cover
                "[Tf2OnnxConvert.replace_all_inputs] end")
        return res

    def remove_node(self, name):
        """
        Removes a node name from the list. The name cannot be reused
        by a new node or initializer afterwards.
        """
        if name not in self._names:
            raise RuntimeError(  # pragma: no cover
                f"Unable to delete name {name!r} because it does not exists.")
        del self._names[name]
        if self.verbose:
            print(  # pragma: no cover
                f"[Tf2OnnxConvert.remove_node] delete name {name!r}")

    def get_shape(self, input_name):
        """
        Returns the shape of one graph input or None if *input_name*
        is not a graph input.

        :param input_name: result name
        :return: tuple of dimensions (`shape.dim` entries)
        """
        inputs = self._onnx_model.graph.input
        names = [_.name for _ in inputs]
        if input_name not in names:
            return None  # pragma: no cover
        ind = names.index(input_name)
        dims = inputs[ind].type.tensor_type.shape.dim
        return tuple(dims)

    def run(self):
        """
        Calls the registered converters on the graph
        held by this instance. Returns the new onnx graph.

        :return: ONNX graph
        :raises RuntimeError: if no converter is registered or if the
            graph is still modified by the last allowed pass
        """
        if len(self._tf_op._OPSETS) == 0:
            raise RuntimeError(  # pragma: no cover
                "No converter was registered.")
        if self.verbose:
            print("[Tf2OnnxConvert.run]")  # pragma: no cover

        modif = 1
        turn = 0
        while modif > 0 and turn < self.max_iter:
            modif = 0
            turn += 1
            # The converter may alter the current list of nodes, we freeze it.
            current_values = list(self._names.values())
            for node in current_values:
                if not hasattr(node, 'domain'):
                    # initializer
                    continue
                if self._names.get(node.name) is not node:
                    # removed by a converter applied earlier in this pass
                    continue
                domain = node.domain
                if domain not in self._tf_op._OPSETS:
                    continue  # pragma: no cover

                # look for a converter
                rews = self._tf_op._OPSETS[domain]
                target = min(self.target_opsets[domain], len(rews))
                # NOTE(review): the most recent registered version is
                # selected whatever target is, target is only given to
                # the converter; confirm this is intended.
                conv = None
                for i in range(len(rews) - 1, -1, -1):
                    if node.op_type in rews[i]:
                        conv = rews[i][node.op_type]
                        break
                if conv is None:
                    continue

                # applies the converter
                if self.verbose:
                    print(  # pragma: no cover
                        "[Tf2OnnxConvert.run] convert node type=%r opset=%r name=%r"
                        "" % (node.op_type, target, node.name))
                fct, kwargs = conv
                fct(self, node, target_opset=target, **kwargs)
                modif += 1

        if modif > 0:
            # the last allowed pass still modified the graph
            raise RuntimeError(  # pragma: no cover
                "Too many iterations and no stable ONNX was reached, "
                "iter=%d\n%s" % (turn, str(self.make_model())))
        return self.make_model()

    def make_model(self):
        """
        Produces the new ONNX graph with the updated sets of nodes.
        """
        inputs = self._onnx_model.graph.input
        outputs = self._onnx_model.graph.output
        inits = [init[1] for init in sorted(self._names.items())
                 if not hasattr(init[1], 'domain')]
        nodes = [node[1] for node in sorted(self._names.items())
                 if hasattr(node[1], 'domain')]
        nodes = ensure_topological_order(inputs, inits, nodes)

        if self.verbose:
            print(  # pragma: no cover
                "[Tf2OnnxConvert.make_node] %d nodes %d inputs %d "
                "outputs %d initializers"
                "" % (len(nodes), len(inputs), len(outputs), len(inits)))
        graph = make_graph(nodes, self._onnx_model.graph.name,
                           inputs, outputs, inits)
        onnx_model = make_model(graph, functions=self._onnx_model.functions)
        onnx_model.ir_version = self._onnx_model.ir_version
        onnx_model.producer_name = self._onnx_model.producer_name + "-mlprodict"
        onnx_model.producer_version = self._onnx_model.producer_version
        onnx_model.domain = self._onnx_model.domain
        onnx_model.model_version = self._onnx_model.model_version
        onnx_model.doc_string = self._onnx_model.doc_string
        metadata = {p.key: p.value for p in self._onnx_model.metadata_props}
        set_model_props(onnx_model, metadata)

        # opsets
        del onnx_model.opset_import[:]  # pylint: disable=E1101
        for dom, value in self.target_opsets.items():
            op_set = onnx_model.opset_import.add()  # pylint: disable=E1101
            op_set.domain = dom
            op_set.version = value
        return onnx_model

519 

520 

class GraphBuilder:
    """
    Helpers to build graph.

    :param graph: graph receiving the new nodes, usually a
        :class:`Tf2OnnxConvert`, it must provide attribute `opset`
        and methods `make_node`, `make_const`, `get_dtype`
    """

    def __init__(self, graph):
        self._g = graph

    @property
    def graph(self):
        "Returns the graph."
        return self._g

    @staticmethod
    def _drop_none_attributes(attr):
        "Returns a copy of *attr* without the attributes set to None."
        return {key: val for key, val in attr.items() if val is not None}

    @staticmethod
    def _clean_inputs(inputs):
        """
        Replaces missing inputs (None) by an empty string, which means
        no connection in ONNX, and removes the trailing empty inputs.
        """
        cleaned = ["" if val is None else val for val in inputs]
        while cleaned and cleaned[-1] == "":
            cleaned.pop()
        return cleaned

    def make_slice(self, kwargs, name=None, shapes=None, dtypes=None,
                   return_node=False):
        """
        slice changes its schema at opset 10: it treats some
        attributes as dynamic input so this function has to process
        inputs according to graph's opset version
        to get "inputs" and "attr" to feed "make_node"
        kwargs: key could be `["data", "starts", "ends",
        "axes", "steps", "outputs"]`, the keys are popped from *kwargs*.
        """
        outputs = kwargs.pop("outputs", None)

        if self.graph.opset < 10:
            # "data" is string
            # "starts", "ends" and "axes" are attributes,
            # and "axes" is optional.
            data = kwargs.pop("data")  # pragma: no cover
            starts = self._convert_to_attribute(  # pragma: no cover
                kwargs.pop("starts"))
            ends = self._convert_to_attribute(  # pragma: no cover
                kwargs.pop("ends"))
            axes = self._convert_to_attribute(  # pragma: no cover
                kwargs.pop("axes", None), is_optional=True)
            attr = {"starts": starts, "ends": ends,
                    "axes": axes}  # pragma: no cover
            inputs = [data]  # pragma: no cover
        else:
            # slice-10 has 3 required inputs "data", "starts", "ends"
            # and 2 optional inputs "axes", "steps",
            # input sequence should be "data", "starts", "ends",
            # "axes", "steps"
            attr = {}
            data = kwargs.pop("data")
            starts = self._convert_to_input(
                kwargs.pop("starts"), "const_starts", dtype=numpy.int64)
            ends = self._convert_to_input(
                kwargs.pop("ends"), "const_ends", dtype=numpy.int64)
            axes = self._convert_to_input(
                kwargs.pop("axes", None), "const_axes",
                is_optional=True, dtype=numpy.int64)
            steps = self._convert_to_input(
                kwargs.pop("steps", None), "const_steps",
                is_optional=True, dtype=numpy.int64)
            inputs = [data, starts, ends, axes, steps]

        make_sure(not kwargs, "kwargs contains un-used key")
        attr = self._drop_none_attributes(attr)
        inputs = self._clean_inputs(inputs)

        if self.graph.opset >= 10:
            # all index inputs must share the same element type
            dtype = self.graph.get_dtype(inputs[1])
            for input_data in inputs[1:]:
                if input_data != "":
                    make_sure(dtype == self.graph.get_dtype(
                        input_data), "dtype should be same")

        node = self.graph.make_node(op_type="Slice", inputs=inputs, attr=attr,
                                    name=name, outputs=outputs, shapes=shapes,
                                    dtypes=dtypes)
        if return_node:
            return node
        raise NotImplementedError(  # pragma: no cover
            "return_node must be True")

    def make_squeeze(self, kwargs, name=None, shapes=None, dtypes=None,
                     return_node=False, op_name_scope=None):
        """
        Squeeze changes its schema at opset 13: it treats axes as a dynamic input
        kwargs: key could be ["data", "axes", "outputs"].
        *op_name_scope* is unused.
        """
        return self._make_axes_op("Squeeze", kwargs, name, shapes, dtypes,
                                  return_node)

    def make_unsqueeze(self, kwargs, name=None, shapes=None, dtypes=None,
                       return_node=False, op_name_scope=None):
        """
        Unsqueeze changes its schema at opset 13: it treats axes as a dynamic input
        kwargs: key could be ["data", "axes", "outputs"].
        *op_name_scope* is unused.
        """
        return self._make_axes_op("Unsqueeze", kwargs, name, shapes, dtypes,
                                  return_node)

    def _make_axes_op(self, op_type, kwargs, name, shapes, dtypes,
                      return_node):
        "Shared implementation of make_squeeze and make_unsqueeze."
        outputs = kwargs.pop("outputs", None)

        if self.graph.opset < 13:  # pragma: no cover
            # axes is an attribute before opset 13
            data = kwargs.pop("data")
            axes = self._convert_to_attribute(
                kwargs.pop("axes", None), is_optional=True)
            attr = {"axes": axes}
            inputs = [data]
        else:
            data = kwargs.pop("data")
            axes = self._convert_to_input(
                kwargs.pop("axes", None), "const_axes",
                is_optional=True, dtype=numpy.int64)
            attr = {}
            inputs = [data, axes]

        make_sure(not kwargs, "kwargs contains un-used key")

        node = self.graph.make_node(
            op_type=op_type, inputs=self._clean_inputs(inputs),
            attr=self._drop_none_attributes(attr), name=name,
            outputs=outputs, shapes=shapes, dtypes=dtypes)
        if return_node:
            return node
        raise NotImplementedError(  # pragma: no cover
            "return_node must be True")

    def _convert_to_input(self, tensor, const_name,  # pragma: no cover
                          is_optional=False, dtype=None):
        """
        In ONNX, an input must come from a node, so it must be a string:
        a list is stored as a new constant whose name is returned.
        """
        if is_optional and tensor is None:
            return None

        make_sure(tensor is not None,
                  "input is required so it couldn't be None")

        res = tensor
        if isinstance(tensor, list):
            res = self.graph.make_const(
                make_name(const_name), numpy.array(tensor, dtype)).name
        return res

    def _convert_to_attribute(self, tensor, is_optional=False):
        """
        Converts *tensor* into a list usable as an attribute value,
        a string is the name of a constant whose value is retrieved.
        """
        if is_optional and tensor is None:
            return None

        make_sure(tensor is not None,
                  "input is required so it couldn't be None")

        res = tensor
        if isinstance(tensor, str):
            res = self._constant_as_list(tensor)

        make_sure(isinstance(res, list),
                  "input is an attr, so a list is needed")

        return res

    def _constant_as_list(self, const_name):
        """
        Returns the value of constant *const_name* as a python list
        (a scalar for a 0-d tensor).
        """
        if hasattr(self.graph, 'get_node_by_output'):
            # graph exposing the tf2onnx Graph API
            const_node = self.graph.get_node_by_output(const_name)
            return const_node.get_tensor_value(as_list=True)
        # Tf2OnnxConvert has no get_node_by_output, initializers are
        # registered by name and retrieved with get_node_by_name.
        from onnx.numpy_helper import to_array  # pylint: disable=C0415
        return to_array(self.graph.get_node_by_name(const_name)).tolist()