Coverage for mlprodict/onnxrt/onnx_inference_exports.py: 99%

325 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Extensions to class @see cl OnnxInference. 

4""" 

5import os 

6import json 

7import re 

8from io import BytesIO 

9import pickle 

10import textwrap 

11from onnx import numpy_helper 

12from ..onnx_tools.onnx2py_helper import _var_as_dict, _type_to_string 

13from ..tools.graphs import onnx2bigraph 

14from ..plotting.text_plot import onnx_simple_text_plot 

15 

16 

class OnnxInferenceExport:
    """
    Implements methods to export an instance of
    @see cl OnnxInference into :epkg:`json`, :epkg:`dot`,
    *text*, *python*.
    """

    def __init__(self, oinf):
        """
        Remembers the inference session to export.

        @param oinf @see cl OnnxInference
        """
        # Every export method reads the ONNX protobuf through
        # ``self.oinf.obj``.
        self.oinf = oinf

29 

    def to_dot(self, recursive=False, prefix='',  # pylint: disable=R0914
               add_rt_shapes=False, use_onnx=False,
               add_functions=True, **params):
        """
        Produces a :epkg:`DOT` language string for the graph.

        :param params: additional params to draw the graph
        :param recursive: also show subgraphs inside operator like
            @see cl Scan
        :param prefix: prefix for every node name
        :param add_rt_shapes: adds shapes inferred from the python runtime
        :param use_onnx: use :epkg:`onnx` dot format instead of this one
        :param add_functions: add functions to the graph
        :return: string

        Default options for the graph are:

        ::

            options = {
                'orientation': 'portrait',
                'ranksep': '0.25',
                'nodesep': '0.05',
                'width': '0.5',
                'height': '0.1',
                'size': '7',
            }

        One example:

        .. exref::
            :title: Convert ONNX into DOT

            An example on how to convert an :epkg:`ONNX`
            graph into :epkg:`DOT`.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from mlprodict.npy.xop import loadop
                from mlprodict.onnxrt import OnnxInference

                OnnxAiOnnxMlLinearRegressor = loadop(
                    ('ai.onnx.ml', 'LinearRegressor'))

                pars = dict(coefficients=numpy.array([1., 2.]),
                            intercepts=numpy.array([1.]),
                            post_transform='NONE')
                onx = OnnxAiOnnxMlLinearRegressor(
                    'X', output_names=['Y'], **pars)
                model_def = onx.to_onnx(
                    {'X': pars['coefficients'].astype(numpy.float32)},
                    outputs={'Y': numpy.float32},
                    target_opset=12)
                oinf = OnnxInference(model_def)
                print(oinf.to_dot())

        See an example of representation in notebook
        :ref:`onnxvisualizationrst`.
        """
        # Escaped sequences ``\x{...}`` / ``\p{...}`` cannot be rendered by
        # graphviz; they are replaced by '_' in labels by *dot_label*.
        clean_label_reg1 = re.compile("\\\\x\\{[0-9A-F]{1,6}\\}")
        clean_label_reg2 = re.compile("\\\\p\\{[0-9P]{1,6}\\}")

        def dot_name(text):
            # Makes a name usable as a DOT identifier.
            return text.replace("/", "_").replace(
                ":", "__").replace(".", "_")

        def dot_label(text):
            # Strips unrenderable escaped sequences from a label.
            for reg in [clean_label_reg1, clean_label_reg2]:
                fall = reg.findall(text)
                for f in fall:
                    text = text.replace(f, "_")  # pragma: no cover
            return text

        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
            'size': '7',
        }
        # caller-provided options override defaults, None means keep default
        options.update({k: v for k, v in params.items() if v is not None})

        if use_onnx:
            # Delegates the rendering to onnx's own drawer.
            from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

            pydot_graph = GetPydotGraph(
                self.oinf.obj.graph, name=self.oinf.obj.graph.name,
                rankdir=params.get('rankdir', "TB"),
                node_producer=GetOpNodeProducer(
                    "docstring", fillcolor="orange", style="filled",
                    shape="box"))
            return pydot_graph.to_string()

        # *inter_vars* records every variable already declared as a DOT node.
        inter_vars = {}
        exp = ["digraph{"]
        for opt in {'orientation', 'pad', 'nodesep', 'ranksep', 'size'}:
            if opt in options:
                exp.append(f"  {opt}={options[opt]};")
        fontsize = 10

        # shapes computed by the python runtime, if requested
        shapes = {}
        if add_rt_shapes:
            if not hasattr(self.oinf, 'shapes_'):
                raise RuntimeError(  # pragma: no cover
                    "No information on shapes, check the runtime '{}'."
                    "".format(self.oinf.runtime))
            for name, shape in self.oinf.shapes_.items():
                va = str(shape.shape)
                shapes[name] = va
                if name in self.oinf.inplaces_:
                    shapes[name] += "\\ninplace"

        # inputs (red boxes)
        exp.append("")
        graph = (
            self.oinf.obj.graph if hasattr(self.oinf.obj, 'graph')
            else self.oinf.obj)
        for obj in graph.input:
            if isinstance(obj, str):
                # FunctionProto inputs are plain strings
                exp.append(
                    '  {2}{0} [shape=box color=red label="{0}" fontsize={1}];'
                    ''.format(obj, fontsize, prefix))
                inter_vars[obj] = obj
            else:
                dobj = _var_as_dict(obj)
                sh = shapes.get(dobj['name'], '')
                if sh:
                    sh = f"\\nshape={sh}"
                exp.append(
                    '  {3}{0} [shape=box color=red label="{0}\\n{1}{4}" fontsize={2}];'
                    ''.format(
                        dot_name(dobj['name']), _type_to_string(dobj['type']),
                        fontsize, prefix, dot_label(sh)))
                inter_vars[obj.name] = obj

        # outputs (green boxes)
        exp.append("")
        for obj in graph.output:
            if isinstance(obj, str):
                exp.append(
                    '  {2}{0} [shape=box color=green label="{0}" fontsize={1}];'.format(
                        obj, fontsize, prefix))
                inter_vars[obj] = obj
            else:
                dobj = _var_as_dict(obj)
                sh = shapes.get(dobj['name'], '')
                if sh:
                    sh = f"\\nshape={sh}"
                exp.append(
                    '  {3}{0} [shape=box color=green label="{0}\\n{1}{4}" fontsize={2}];'
                    ''.format(
                        dot_name(dobj['name']), _type_to_string(dobj['type']),
                        fontsize, prefix, dot_label(sh)))
                inter_vars[obj.name] = obj

        # initializers (dense + sparse), label shows dtype, shape and a
        # truncated preview of the values
        exp.append("")
        if hasattr(self.oinf.obj, 'graph'):
            inits = (
                list(self.oinf.obj.graph.initializer) +
                list(self.oinf.obj.graph.sparse_initializer))
            for obj in inits:
                dobj = _var_as_dict(obj)
                val = dobj['value']
                flat = val.flatten()
                if flat.shape[0] < 9:
                    # small tensors are shown in full
                    st = str(val)
                else:
                    st = str(val)
                    if len(st) > 50:
                        st = st[:50] + '...'
                st = st.replace('\n', '\\n')
                kind = ""
                exp.append(
                    '  {6}{0} [shape=box label="{0}\\n{4}{1}({2})\\n{3}" fontsize={5}];'
                    ''.format(
                        dot_name(dobj['name']), dobj['value'].dtype,
                        dobj['value'].shape, dot_label(st), kind, fontsize, prefix))
                inter_vars[obj.name] = obj

        # nodes
        fill_names = {}
        if hasattr(self.oinf.obj, 'graph'):
            static_inputs = [n.name for n in self.oinf.obj.graph.input]
            static_inputs.extend(
                n.name for n in self.oinf.obj.graph.initializer)
            static_inputs.extend(
                n.name for n in self.oinf.obj.graph.sparse_initializer)
            nodes = list(self.oinf.obj.graph.node)
        else:
            static_inputs = list(self.oinf.obj.input)
            nodes = self.oinf.obj.node
        for node in nodes:
            exp.append("")
            # declare every intermediate result produced by this node
            for out in node.output:
                if len(out) > 0 and out not in inter_vars:
                    inter_vars[out] = out
                    sh = shapes.get(out, '')
                    if sh:
                        sh = f"\\nshape={sh}"
                    exp.append(
                        '  {2}{0} [shape=box label="{0}{3}" fontsize={1}];'.format(
                            dot_name(out), fontsize, dot_name(prefix),
                            dot_label(sh)))
                    static_inputs.append(out)

            dobj = _var_as_dict(node)
            if dobj['name'].strip() == '':  # pragma: no cover
                # unnamed node: forge a unique name from the operator type
                name = node.op_type
                iname = 1
                while name in fill_names:
                    name = "%s%d" % (name, iname)
                    iname += 1
                dobj['name'] = name
                node.name = name
                fill_names[name] = node

            # attributes displayed in the node label, truncated to keep the
            # graph readable
            atts = []
            if 'atts' in dobj:
                for k, v in sorted(dobj['atts'].items()):
                    val = None
                    if 'value' in v:
                        val = str(v['value']).replace(
                            "\n", "\\n").replace('"', "'")
                        sl = max(30 - len(k), 10)
                        if len(val) > sl:
                            val = val[:sl] + "..."
                    if val is not None:
                        atts.append(f'{k}={val}')
            satts = "" if len(atts) == 0 else ("\\n" + "\\n".join(atts))

            # edges linking a node to its subgraph clusters (recursive mode)
            connects = []
            if recursive and node.op_type in {'Scan', 'Loop', 'If'}:
                fields = (['then_branch', 'else_branch']
                          if node.op_type == 'If' else ['body'])
                for field in fields:
                    if field not in dobj['atts']:
                        continue  # pragma: no cover

                    # creates the subgraph
                    body = dobj['atts'][field]['value']
                    oinf = self.oinf.__class__(
                        body, runtime=self.oinf.runtime,
                        skip_run=self.oinf.skip_run,
                        static_inputs=static_inputs)
                    subprefix = prefix + "B_"
                    subdot = oinf.to_dot(recursive=recursive, prefix=subprefix,
                                         add_rt_shapes=add_rt_shapes)
                    # drop the header of the sub-dot up to the first node
                    lines = subdot.split("\n")
                    start = 0
                    for i, line in enumerate(lines):
                        if '[' in line:
                            start = i
                            break
                    subgraph = "\n".join(lines[start:])

                    # connecting the subgraph
                    cluster = f"cluster_{node.op_type}{id(node)}_{id(field)}"
                    exp.append(f"  subgraph {cluster} {{")
                    exp.append('    label="{0}\\n({1}){2}";'.format(
                        dobj['op_type'], dot_name(dobj['name']), satts))
                    exp.append(f'    fontsize={fontsize};')
                    exp.append('    color=black;')
                    exp.append(
                        '\n'.join(map(lambda s: '  ' + s, subgraph.split('\n'))))

                    node0 = body.node[0]
                    connects.append((
                        f"{dot_name(subprefix)}{dot_name(node0.name)}",
                        cluster))

                    for inp1, inp2 in zip(node.input, body.input):
                        exp.append(
                            "  {0}{1} -> {2}{3};".format(
                                dot_name(prefix), dot_name(inp1),
                                dot_name(subprefix), dot_name(inp2.name)))
                    for out1, out2 in zip(body.output, node.output):
                        if len(out2) == 0:
                            # Empty output, it cannot be used.
                            continue
                        exp.append(
                            "  {0}{1} -> {2}{3};".format(
                                dot_name(subprefix), dot_name(out1.name),
                                dot_name(prefix), dot_name(out2)))
            else:
                # regular operator node (orange rounded box)
                exp.append('  {4}{1} [shape=box style="filled,rounded" color=orange '
                           'label="{0}\\n({1}){2}" fontsize={3}];'.format(
                               dobj['op_type'], dot_name(
                                   dobj['name']), satts, fontsize,
                               dot_name(prefix)))

            if connects is not None and len(connects) > 0:
                for name, cluster in connects:
                    exp.append(
                        "  {0}{1} -> {2} [lhead={3}];".format(
                            dot_name(prefix), dot_name(node.name),
                            name, cluster))

            # edges: inputs -> node -> outputs
            for inp in node.input:
                exp.append(
                    "  {0}{1} -> {0}{2};".format(
                        dot_name(prefix), dot_name(inp), dot_name(node.name)))
            for out in node.output:
                if len(out) == 0:
                    # Empty output, it cannot be used.
                    continue
                exp.append(
                    "  {0}{1} -> {0}{2};".format(
                        dot_name(prefix), dot_name(node.name), dot_name(out)))

        # local functions rendered as blue clusters (one level only,
        # add_functions=False on the recursive call)
        if add_functions and len(self.oinf.functions_) > 0:
            for i, (k, v) in enumerate(self.oinf.functions_.items()):
                dot = v.to_dot(recursive=recursive, prefix=prefix + v.obj.name,
                               add_rt_shapes=add_rt_shapes,
                               use_onnx=use_onnx, add_functions=False,
                               **params)
                # skip the 'digraph{' header of the nested dot
                spl = dot.split('\n')[1:]
                exp.append('')
                exp.append('  subgraph cluster_%d {' % i)
                exp.append(f'    label="{v.obj.name}";')
                exp.append('    color=blue;')
                exp.extend(('  ' + line) for line in spl)

        exp.append('}')
        return "\n".join(exp)

360 

    def to_json(self, indent=2):
        """
        Converts an :epkg:`ONNX` model into :epkg:`JSON`.

        @param indent indentation
        @return string

        .. exref::
            :title: Convert ONNX into JSON

            An example on how to convert an :epkg:`ONNX`
            graph into :epkg:`JSON`.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from mlprodict.npy.xop import loadop
                from mlprodict.onnxrt import OnnxInference

                OnnxAiOnnxMlLinearRegressor = loadop(
                    ('ai.onnx.ml', 'LinearRegressor'))

                pars = dict(coefficients=numpy.array([1., 2.]),
                            intercepts=numpy.array([1.]),
                            post_transform='NONE')
                onx = OnnxAiOnnxMlLinearRegressor(
                    'X', output_names=['Y'], **pars)
                model_def = onx.to_onnx(
                    {'X': pars['coefficients'].astype(numpy.float32)},
                    outputs={'Y': numpy.float32},
                    target_opset=12)
                oinf = OnnxInference(model_def)
                print(oinf.to_json())
        """

        def _to_json(obj):
            # Converts the text representation of a protobuf object
            # (``str(obj)``) into a dictionary by rewriting it line by line
            # into valid JSON syntax, then parsing it.
            s = str(obj)
            rows = ['{']
            # *leave* remembers the name of the repeated numeric field
            # ('floats' or 'ints') currently being accumulated into a list.
            leave = None
            for line in s.split('\n'):
                if line.endswith("{"):
                    # opening of a nested message
                    rows.append('"%s": {' % line.strip('{ '))
                elif ':' in line:
                    spl = line.strip().split(':')
                    if len(spl) != 2:
                        raise RuntimeError(  # pragma: no cover
                            f"Unable to interpret line '{line}'.")

                    # enum values of field 'type' must be quoted to be JSON
                    if spl[0].strip() in ('type', ):
                        st = spl[1].strip()
                        if st in {'INT', 'INTS', 'FLOAT', 'FLOATS',
                                  'STRING', 'STRINGS', 'TENSOR'}:
                            spl[1] = f'"{st}"'

                    if spl[0] in ('floats', 'ints'):
                        # repeated numeric fields become one JSON list
                        if leave:
                            rows.append(f"{spl[1]},")
                        else:
                            rows.append(f'"{spl[0]}": [{spl[1].strip()},')
                        leave = spl[0]
                    elif leave:
                        # first scalar field after a repeated one:
                        # close the pending list
                        rows[-1] = rows[-1].strip(',')
                        rows.append('],')
                        rows.append(f'"{spl[0].strip()}": {spl[1].strip()},')
                        leave = None
                    else:
                        rows.append(f'"{spl[0].strip()}": {spl[1].strip()},')
                elif line.strip() == "}":
                    # closing of a nested message: drop the trailing comma
                    rows[-1] = rows[-1].rstrip(",")
                    rows.append(line + ",")
                elif line:
                    raise RuntimeError(  # pragma: no cover
                        f"Unable to interpret line '{line}'.")
            rows[-1] = rows[-1].rstrip(',')
            rows.append("}")
            js = "\n".join(rows)

            try:
                content = json.loads(js)
            except json.decoder.JSONDecodeError as e:  # pragma: no cover
                # prefix every line with its number to ease debugging
                js2 = "\n".join("%04d %s" % (i + 1, line)
                                for i, line in enumerate(js.split("\n")))
                raise RuntimeError(
                    f"Unable to parse JSON\n{js2}") from e
            return content

        # meta data
        final_obj = {}
        for k in {'ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'}:
            if hasattr(self.oinf.obj, k):
                final_obj[k] = getattr(self.oinf.obj, k)

        # inputs
        inputs = []
        for obj in self.oinf.obj.graph.input:
            st = _to_json(obj)
            inputs.append(st)
        final_obj['inputs'] = inputs

        # outputs
        outputs = []
        for obj in self.oinf.obj.graph.output:
            st = _to_json(obj)
            outputs.append(st)
        final_obj['outputs'] = outputs

        # initializers, stored as plain (nested) python lists
        inits = {}
        for obj in self.oinf.obj.graph.initializer:
            value = numpy_helper.to_array(obj).tolist()
            inits[obj.name] = value
        final_obj['initializers'] = inits

        # nodes
        nodes = []
        for obj in list(self.oinf.obj.graph.node):
            node = dict(name=obj.name, op_type=obj.op_type, domain=obj.domain,
                        inputs=[str(_) for _ in obj.input],
                        outputs=[str(_) for _ in obj.output],
                        attributes={})
            for att in obj.attribute:
                st = _to_json(att)
                # attributes are indexed by name; the name is then removed
                # from the value to avoid duplication
                node['attributes'][st['name']] = st
                del st['name']
            nodes.append(node)
        final_obj['nodes'] = nodes

        return json.dumps(final_obj, indent=indent)

492 

    def to_python(self, prefix="onnx_pyrt_", dest=None, inline=True):
        """
        Converts the ONNX runtime into independent python code.
        The function creates multiple files starting with
        *prefix* and saved to folder *dest*.

        @param prefix file prefix
        @param dest destination folder
        @param inline constant matrices are put in the python file itself
            as byte arrays
        @return file dictionary

        The function does not work if the chosen runtime
        is not *python*.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning

            import numpy
            from mlprodict.npy.xop import loadop
            from mlprodict.onnxrt import OnnxInference

            OnnxAdd = loadop('Add')

            idi = numpy.identity(2).astype(numpy.float32)
            onx = OnnxAdd('X', idi, output_names=['Y'],
                          op_version=12)
            model_def = onx.to_onnx({'X': idi},
                                    target_opset=12)
            X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float32)
            oinf = OnnxInference(model_def, runtime='python')
            res = oinf.to_python()
            print(res['onnx_pyrt_main.py'])
        """
        if not isinstance(prefix, str):
            raise TypeError(  # pragma: no cover
                f"prefix must be a string not {type(prefix)!r}.")

        def clean_args(args):
            # Renames generated arguments colliding with the python builtins
            # ``min`` and ``max``.
            new_args = []
            for v in args:
                # remove python keywords
                if v.startswith('min='):
                    av = 'min_=' + v[4:]
                elif v.startswith('max='):
                    av = 'max_=' + v[4:]
                else:
                    av = v
                new_args.append(av)
            return new_args

        if self.oinf.runtime not in ('python', None):
            raise ValueError(
                f"The runtime must be 'python' not '{self.oinf.runtime}'.")

        # metadata, exposed by the generated *metadata* property
        obj = {}
        for k in {'ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'}:
            if hasattr(self.oinf.obj, k):
                obj[k] = getattr(self.oinf.obj, k)
        # NOTE(review): 'It was was generated' duplicates 'was'; kept as-is
        # because fixing it would change the produced files.
        code_begin = ["# coding: utf-8",
                      "'''",
                      "Python code equivalent to an ONNX graph.",
                      "It was was generated by module *mlprodict*.",
                      "'''"]
        code_imports = ["from io import BytesIO",
                        "import pickle",
                        "from numpy import array, float32, ndarray"]
        code_lines = ["class OnnxPythonInference:", "",
                      "    def __init__(self):",
                      "        self._load_inits()", "",
                      "    @property",
                      "    def metadata(self):",
                      f"        return {obj!r}", ""]

        # inputs
        if hasattr(self.oinf.obj, 'graph'):
            inputs = [obj.name for obj in self.oinf.obj.graph.input]
            outputs = [obj.name for obj in self.oinf.obj.graph.output]
        else:
            # FunctionProto: input/output are plain name lists
            inputs = list(self.oinf.obj.input)
            outputs = list(self.oinf.obj.output)

        code_lines.extend([
            "    @property", "    def inputs(self):",
            f"        return {inputs!r}",
            ""
        ])

        # outputs
        code_lines.extend([
            "    @property", "    def outputs(self):",
            f"        return {outputs!r}",
            ""
        ])

        # initializers: pickled, either inlined as byte literals or saved
        # to side files next to the main module
        code_lines.extend(["    def _load_inits(self):",
                           "        self._inits = {}"])
        file_data = {}
        if hasattr(self.oinf.obj, 'graph'):
            for obj in self.oinf.obj.graph.initializer:
                value = numpy_helper.to_array(obj)
                bt = BytesIO()
                pickle.dump(value, bt)
                name = f'{prefix}{obj.name}.pkl'
                if inline:
                    code_lines.extend([
                        f"        iocst = {bt.getvalue()!r}",
                        f"        self._inits['{obj.name}'] = pickle.loads(iocst)"
                    ])
                else:
                    file_data[name] = bt.getvalue()
                    # NOTE(review): the generated line calls pickle.loads on a
                    # file *name*, not on its content; the non-inline mode
                    # looks broken — confirm before relying on it.
                    code_lines.append(
                        f"        self._inits['{obj.name}'] = pickle.loads('{name}')")
        code_lines.append('')

        # inputs, outputs
        inputs = self.oinf.input_names

        # nodes
        code_lines.extend([f"    def run(self, {', '.join(inputs)}):"])
        ops = {}
        if hasattr(self.oinf.obj, 'graph'):
            code_lines.append('        # constant')
            for obj in self.oinf.obj.graph.initializer:
                code_lines.append(
                    "        {0} = self._inits['{0}']".format(obj.name))
            code_lines.append('')
        code_lines.append('        # graph code')
        for node in self.oinf.sequence_:
            # one generated function per distinct operator instance name
            fct = 'pyrt_' + node.name
            if fct not in ops:
                ops[fct] = node
            args = []
            args.extend(node.inputs)
            margs = node.modified_args
            if margs is not None:
                args.extend(clean_args(margs))
            code_lines.append("        {0} = {1}({2})".format(
                ', '.join(node.outputs), fct, ', '.join(args)))
        code_lines.append('')
        code_lines.append('        # return')
        code_lines.append(f"        return {', '.join(outputs)}")
        code_lines.append('')

        # operator code: one top-level function per operator
        code_nodes = []
        for name, op in ops.items():
            inputs_args = clean_args(op.inputs_args)

            code_nodes.append(f"def {name}({', '.join(inputs_args)}):")
            imps, code = op.to_python(op.python_inputs)
            if imps is not None:
                if not isinstance(imps, list):
                    imps = [imps]
                code_imports.extend(imps)
            code_nodes.append(textwrap.indent(code, '    '))
            code_nodes.extend(['', ''])

        # end: deduplicate and sort the collected import lines
        code_imports = list(sorted(set(code_imports)))
        code_imports.extend(['', ''])
        file_data[prefix + 'main.py'] = "\n".join(
            code_begin + code_imports + code_nodes + code_lines)

        # saves as files
        if dest is not None:
            for k, v in file_data.items():
                ext = os.path.splitext(k)[-1]
                kf = os.path.join(dest, k)
                if ext == '.py':
                    with open(kf, "w", encoding="utf-8") as f:
                        f.write(v)
                elif ext == '.pkl':  # pragma: no cover
                    with open(kf, "wb") as f:
                        f.write(v)
                else:
                    raise NotImplementedError(  # pragma: no cover
                        f"Unknown extension for file '{k}'.")
        return file_data

676 

677 def to_text(self, recursive=False, grid=5, distance=5, kind='bi'): 

678 """ 

679 It calls function @see fn onnx2bigraph to return 

680 the ONNX graph as text. 

681 

682 :param recursive: dig into subgraphs too 

683 :param grid: align text to this grid 

684 :param distance: distance to the text 

685 :param kind: see below 

686 :return: text 

687 

688 Possible values for format: 

689 * `'bi'`: use @see fn onnx2bigraph 

690 * `'seq'`: use @see fn onnx_simple_text_plot 

691 """ 

692 if kind == 'bi': 

693 bigraph = onnx2bigraph(self.oinf.obj, recursive=recursive) 

694 graph = bigraph.display_structure(grid=grid, distance=distance) 

695 return graph.to_text() 

696 if kind == 'seq': 

697 return onnx_simple_text_plot(self.oinf.obj) 

698 raise ValueError( # pragma: no cover 

699 f"Unexpected value for format={format!r}.") 

700 

701 def to_onnx_code(self): 

702 """ 

703 Exports the ONNX graph into an :epkg:`onnx` code 

704 which replicates it. 

705 

706 :return: string 

707 """ 

708 # Lazy import as it is not a common use. 

709 from ..onnx_tools.onnx_export import export2onnx 

710 return export2onnx(self.oinf.obj)