Coverage for mlprodict/onnxrt/onnx_inference_exports.py: 99%
325 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Extensions to class @see cl OnnxInference.
4"""
5import os
6import json
7import re
8from io import BytesIO
9import pickle
10import textwrap
11from onnx import numpy_helper
12from ..onnx_tools.onnx2py_helper import _var_as_dict, _type_to_string
13from ..tools.graphs import onnx2bigraph
14from ..plotting.text_plot import onnx_simple_text_plot
class OnnxInferenceExport:
    """
    Exporters for @see cl OnnxInference.

    Gathers the conversion methods turning an instance of
    @see cl OnnxInference into :epkg:`json`, :epkg:`dot`,
    *text* or *python* representations.
    """

    def __init__(self, oinf):
        """
        :param oinf: @see cl OnnxInference instance to export
        """
        # Keep a reference to the wrapped OnnxInference object;
        # every to_* method reads the ONNX graph from it.
        self.oinf = oinf
    def to_dot(self, recursive=False, prefix='',  # pylint: disable=R0914
               add_rt_shapes=False, use_onnx=False,
               add_functions=True, **params):
        """
        Produces a :epkg:`DOT` language string for the graph.

        :param recursive: also show subgraphs inside operator like
            @see cl Scan
        :param prefix: prefix for every node name
        :param add_rt_shapes: adds shapes infered from the python runtime
        :param use_onnx: use :epkg:`onnx` dot format instead of this one
        :param add_functions: add functions to the graph
        :param params: additional params to draw the graph
        :return: string

        Default options for the graph are:

        ::

            options = {
                'orientation': 'portrait',
                'ranksep': '0.25',
                'nodesep': '0.05',
                'width': '0.5',
                'height': '0.1',
                'size': '7',
            }

        See an example of representation in notebook
        :ref:`onnxvisualizationrst`.
        """
        # Regexes matching escaped unicode sequences (``\x{..}``, ``\p{..}``)
        # which graphviz cannot render; dot_label replaces them with '_'.
        clean_label_reg1 = re.compile("\\\\x\\{[0-9A-F]{1,6}\\}")
        clean_label_reg2 = re.compile("\\\\p\\{[0-9P]{1,6}\\}")

        def dot_name(text):
            # Sanitizes an ONNX name into a valid DOT identifier:
            # '/', ':' and '.' are not allowed in node ids.
            return text.replace("/", "_").replace(
                ":", "__").replace(".", "_")

        def dot_label(text):
            # Strips escaped unicode sequences from a DOT label.
            for reg in [clean_label_reg1, clean_label_reg2]:
                fall = reg.findall(text)
                for f in fall:
                    text = text.replace(f, "_")  # pragma: no cover
            return text

        # Drawing defaults, overridable through **params (None values ignored).
        options = {
            'orientation': 'portrait',
            'ranksep': '0.25',
            'nodesep': '0.05',
            'width': '0.5',
            'height': '0.1',
            'size': '7',
        }
        options.update({k: v for k, v in params.items() if v is not None})

        if use_onnx:
            # Delegate entirely to onnx's own drawer.
            from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

            pydot_graph = GetPydotGraph(
                self.oinf.obj.graph, name=self.oinf.obj.graph.name,
                rankdir=params.get('rankdir', "TB"),
                node_producer=GetOpNodeProducer(
                    "docstring", fillcolor="orange", style="filled",
                    shape="box"))
            return pydot_graph.to_string()

        # inter_vars records every variable already declared as a DOT node
        # so intermediate results are only declared once.
        inter_vars = {}
        exp = ["digraph{"]
        for opt in {'orientation', 'pad', 'nodesep', 'ranksep', 'size'}:
            if opt in options:
                exp.append(f" {opt}={options[opt]};")
        fontsize = 10

        # Optional runtime shapes, added to node labels.
        shapes = {}
        if add_rt_shapes:
            if not hasattr(self.oinf, 'shapes_'):
                raise RuntimeError(  # pragma: no cover
                    "No information on shapes, check the runtime '{}'."
                    "".format(self.oinf.runtime))
            for name, shape in self.oinf.shapes_.items():
                va = str(shape.shape)
                shapes[name] = va
                if name in self.oinf.inplaces_:
                    shapes[name] += "\\ninplace"

        # inputs
        exp.append("")
        # self.oinf.obj may be a ModelProto (has .graph) or a
        # GraphProto/FunctionProto directly.
        graph = (
            self.oinf.obj.graph if hasattr(self.oinf.obj, 'graph')
            else self.oinf.obj)
        for obj in graph.input:
            if isinstance(obj, str):
                # FunctionProto inputs are plain strings.
                exp.append(
                    ' {2}{0} [shape=box color=red label="{0}" fontsize={1}];'
                    ''.format(obj, fontsize, prefix))
                inter_vars[obj] = obj
            else:
                dobj = _var_as_dict(obj)
                sh = shapes.get(dobj['name'], '')
                if sh:
                    sh = f"\\nshape={sh}"
                exp.append(
                    ' {3}{0} [shape=box color=red label="{0}\\n{1}{4}" fontsize={2}];'
                    ''.format(
                        dot_name(dobj['name']), _type_to_string(dobj['type']),
                        fontsize, prefix, dot_label(sh)))
                inter_vars[obj.name] = obj

        # outputs
        exp.append("")
        for obj in graph.output:
            if isinstance(obj, str):
                exp.append(
                    ' {2}{0} [shape=box color=green label="{0}" fontsize={1}];'.format(
                        obj, fontsize, prefix))
                inter_vars[obj] = obj
            else:
                dobj = _var_as_dict(obj)
                sh = shapes.get(dobj['name'], '')
                if sh:
                    sh = f"\\nshape={sh}"
                exp.append(
                    ' {3}{0} [shape=box color=green label="{0}\\n{1}{4}" fontsize={2}];'
                    ''.format(
                        dot_name(dobj['name']), _type_to_string(dobj['type']),
                        fontsize, prefix, dot_label(sh)))
                inter_vars[obj.name] = obj

        # initializer
        exp.append("")
        if hasattr(self.oinf.obj, 'graph'):
            inits = (
                list(self.oinf.obj.graph.initializer) +
                list(self.oinf.obj.graph.sparse_initializer))
            for obj in inits:
                dobj = _var_as_dict(obj)
                val = dobj['value']
                flat = val.flatten()
                # NOTE(review): both branches currently compute str(val);
                # the size-9 threshold presumably once selected a shorter
                # rendering for small tensors — confirm intent.
                if flat.shape[0] < 9:
                    st = str(val)
                else:
                    st = str(val)
                    if len(st) > 50:
                        st = st[:50] + '...'
                st = st.replace('\n', '\\n')
                kind = ""
                exp.append(
                    ' {6}{0} [shape=box label="{0}\\n{4}{1}({2})\\n{3}" fontsize={5}];'
                    ''.format(
                        dot_name(dobj['name']), dobj['value'].dtype,
                        dobj['value'].shape, dot_label(st), kind, fontsize, prefix))
                inter_vars[obj.name] = obj

        # nodes
        fill_names = {}
        if hasattr(self.oinf.obj, 'graph'):
            # static_inputs: every name known before node execution
            # (graph inputs + initializers); passed down to subgraphs.
            static_inputs = [n.name for n in self.oinf.obj.graph.input]
            static_inputs.extend(
                n.name for n in self.oinf.obj.graph.initializer)
            static_inputs.extend(
                n.name for n in self.oinf.obj.graph.sparse_initializer)
            nodes = list(self.oinf.obj.graph.node)
        else:
            static_inputs = list(self.oinf.obj.input)
            nodes = self.oinf.obj.node
        for node in nodes:
            exp.append("")
            # Declare this node's outputs (intermediate results) once.
            for out in node.output:
                if len(out) > 0 and out not in inter_vars:
                    inter_vars[out] = out
                    sh = shapes.get(out, '')
                    if sh:
                        sh = f"\\nshape={sh}"
                    exp.append(
                        ' {2}{0} [shape=box label="{0}{3}" fontsize={1}];'.format(
                            dot_name(out), fontsize, dot_name(prefix),
                            dot_label(sh)))
                    static_inputs.append(out)

            dobj = _var_as_dict(node)
            if dobj['name'].strip() == '':  # pragma: no cover
                # Unnamed node: synthesize a unique name from the op type.
                name = node.op_type
                iname = 1
                while name in fill_names:
                    name = "%s%d" % (name, iname)
                    iname += 1
                dobj['name'] = name
                node.name = name
                fill_names[name] = node

            # Node attributes, truncated, rendered inside the label.
            atts = []
            if 'atts' in dobj:
                for k, v in sorted(dobj['atts'].items()):
                    val = None
                    if 'value' in v:
                        val = str(v['value']).replace(
                            "\n", "\\n").replace('"', "'")
                        sl = max(30 - len(k), 10)
                        if len(val) > sl:
                            val = val[:sl] + "..."
                    if val is not None:
                        atts.append(f'{k}={val}')
            satts = "" if len(atts) == 0 else ("\\n" + "\\n".join(atts))

            # connects collects (sub-node-name, cluster) pairs linking the
            # outer node to a recursive subgraph's cluster.
            connects = []
            if recursive and node.op_type in {'Scan', 'Loop', 'If'}:
                fields = (['then_branch', 'else_branch']
                          if node.op_type == 'If' else ['body'])
                for field in fields:
                    if field not in dobj['atts']:
                        continue  # pragma: no cover

                    # creates the subgraph
                    body = dobj['atts'][field]['value']
                    oinf = self.oinf.__class__(
                        body, runtime=self.oinf.runtime,
                        skip_run=self.oinf.skip_run,
                        static_inputs=static_inputs)
                    subprefix = prefix + "B_"
                    subdot = oinf.to_dot(recursive=recursive, prefix=subprefix,
                                         add_rt_shapes=add_rt_shapes)
                    # Strip everything before the first node declaration
                    # (the 'digraph{' header and graph-level options).
                    lines = subdot.split("\n")
                    start = 0
                    for i, line in enumerate(lines):
                        if '[' in line:
                            start = i
                            break
                    subgraph = "\n".join(lines[start:])

                    # connecting the subgraph
                    cluster = f"cluster_{node.op_type}{id(node)}_{id(field)}"
                    exp.append(f" subgraph {cluster} {{")
                    exp.append(' label="{0}\\n({1}){2}";'.format(
                        dobj['op_type'], dot_name(dobj['name']), satts))
                    exp.append(f' fontsize={fontsize};')
                    exp.append(' color=black;')
                    exp.append(
                        '\n'.join(map(lambda s: ' ' + s, subgraph.split('\n'))))

                    node0 = body.node[0]
                    connects.append((
                        f"{dot_name(subprefix)}{dot_name(node0.name)}",
                        cluster))

                    # Wire outer inputs/outputs to the subgraph's own ones.
                    for inp1, inp2 in zip(node.input, body.input):
                        exp.append(
                            " {0}{1} -> {2}{3};".format(
                                dot_name(prefix), dot_name(inp1),
                                dot_name(subprefix), dot_name(inp2.name)))
                    for out1, out2 in zip(body.output, node.output):
                        if len(out2) == 0:
                            # Empty output, it cannot be used.
                            continue
                        exp.append(
                            " {0}{1} -> {2}{3};".format(
                                dot_name(subprefix), dot_name(out1.name),
                                dot_name(prefix), dot_name(out2)))
            else:
                # Regular node: one orange rounded box.
                exp.append(' {4}{1} [shape=box style="filled,rounded" color=orange '
                           'label="{0}\\n({1}){2}" fontsize={3}];'.format(
                               dobj['op_type'], dot_name(
                                   dobj['name']), satts, fontsize,
                               dot_name(prefix)))

            if connects is not None and len(connects) > 0:
                for name, cluster in connects:
                    exp.append(
                        " {0}{1} -> {2} [lhead={3}];".format(
                            dot_name(prefix), dot_name(node.name),
                            name, cluster))

            # Edges: inputs -> node -> outputs.
            for inp in node.input:
                exp.append(
                    " {0}{1} -> {0}{2};".format(
                        dot_name(prefix), dot_name(inp), dot_name(node.name)))
            for out in node.output:
                if len(out) == 0:
                    # Empty output, it cannot be used.
                    continue
                exp.append(
                    " {0}{1} -> {0}{2};".format(
                        dot_name(prefix), dot_name(node.name), dot_name(out)))

        # Local functions rendered as blue clusters; add_functions=False in
        # the recursive call prevents infinite nesting.
        if add_functions and len(self.oinf.functions_) > 0:
            for i, (k, v) in enumerate(self.oinf.functions_.items()):
                dot = v.to_dot(recursive=recursive, prefix=prefix + v.obj.name,
                               add_rt_shapes=add_rt_shapes,
                               use_onnx=use_onnx, add_functions=False,
                               **params)
                # Drop the 'digraph{' line; the trailing '}' of the sub-dot
                # closes the cluster opened below.
                spl = dot.split('\n')[1:]
                exp.append('')
                exp.append(' subgraph cluster_%d {' % i)
                exp.append(f' label="{v.obj.name}";')
                exp.append(' color=blue;')
                #exp.append(' style=filled;')
                exp.extend((' ' + line) for line in spl)

        exp.append('}')
        return "\n".join(exp)
    def to_json(self, indent=2):
        """
        Converts an :epkg:`ONNX` model into :epkg:`JSON`.

        @param      indent      indentation
        @return                 string

        .. exref::
            :title: Convert ONNX into JSON

            An example on how to convert an :epkg:`ONNX`
            graph into :epkg:`JSON`.

            .. runpython::
                :showcode:
                :warningout: DeprecationWarning

                import numpy
                from mlprodict.npy.xop import loadop
                from mlprodict.onnxrt import OnnxInference

                OnnxAiOnnxMlLinearRegressor = loadop(
                    ('ai.onnx.ml', 'LinearRegressor'))

                pars = dict(coefficients=numpy.array([1., 2.]),
                            intercepts=numpy.array([1.]),
                            post_transform='NONE')
                onx = OnnxAiOnnxMlLinearRegressor(
                    'X', output_names=['Y'], **pars)
                model_def = onx.to_onnx(
                    {'X': pars['coefficients'].astype(numpy.float32)},
                    outputs={'Y': numpy.float32},
                    target_opset=12)
                oinf = OnnxInference(model_def)
                print(oinf.to_json())
        """

        def _to_json(obj):
            # Converts one protobuf object into a JSON-compatible dict by
            # parsing its textual repr (str(obj)) line by line and rewriting
            # it as a JSON string, then loading it.
            s = str(obj)
            rows = ['{']
            # leave tracks an open repeated field ('floats'/'ints') whose
            # values span several consecutive lines (rendered as a JSON list).
            leave = None
            for line in s.split('\n'):
                if line.endswith("{"):
                    # Nested message: open a JSON object.
                    rows.append('"%s": {' % line.strip('{ '))
                elif ':' in line:
                    spl = line.strip().split(':')
                    if len(spl) != 2:
                        raise RuntimeError(  # pragma: no cover
                            f"Unable to interpret line '{line}'.")

                    if spl[0].strip() in ('type', ):
                        # Enum names must be quoted to be valid JSON.
                        st = spl[1].strip()
                        if st in {'INT', 'INTS', 'FLOAT', 'FLOATS',
                                  'STRING', 'STRINGS', 'TENSOR'}:
                            spl[1] = f'"{st}"'

                    # NOTE(review): spl[0] is compared unstripped here while
                    # every other branch strips it — verify repeated fields
                    # are emitted without indentation by protobuf's repr.
                    if spl[0] in ('floats', 'ints'):
                        if leave:
                            # Continuing an already-open list.
                            rows.append(f"{spl[1]},")
                        else:
                            # First element: open the JSON list.
                            rows.append(f'"{spl[0]}": [{spl[1].strip()},')
                        leave = spl[0]
                    elif leave:
                        # A different field follows: close the open list.
                        rows[-1] = rows[-1].strip(',')
                        rows.append('],')
                        rows.append(f'"{spl[0].strip()}": {spl[1].strip()},')
                        leave = None
                    else:
                        rows.append(f'"{spl[0].strip()}": {spl[1].strip()},')
                elif line.strip() == "}":
                    # Closing a nested message: drop the trailing comma first.
                    rows[-1] = rows[-1].rstrip(",")
                    rows.append(line + ",")
                elif line:
                    raise RuntimeError(  # pragma: no cover
                        f"Unable to interpret line '{line}'.")
            rows[-1] = rows[-1].rstrip(',')
            rows.append("}")
            js = "\n".join(rows)

            try:
                content = json.loads(js)
            except json.decoder.JSONDecodeError as e:  # pragma: no cover
                # Re-raise with numbered lines to ease debugging.
                js2 = "\n".join("%04d %s" % (i + 1, line)
                                for i, line in enumerate(js.split("\n")))
                raise RuntimeError(
                    f"Unable to parse JSON\n{js2}") from e
            return content

        # meta data
        final_obj = {}
        for k in {'ir_version', 'producer_name', 'producer_version',
                  'domain', 'model_version', 'doc_string'}:
            if hasattr(self.oinf.obj, k):
                final_obj[k] = getattr(self.oinf.obj, k)

        # inputs
        inputs = []
        for obj in self.oinf.obj.graph.input:
            st = _to_json(obj)
            inputs.append(st)
        final_obj['inputs'] = inputs

        # outputs
        outputs = []
        for obj in self.oinf.obj.graph.output:
            st = _to_json(obj)
            outputs.append(st)
        final_obj['outputs'] = outputs

        # init: initializers are serialized as plain nested lists.
        inits = {}
        for obj in self.oinf.obj.graph.initializer:
            value = numpy_helper.to_array(obj).tolist()
            inits[obj.name] = value
        final_obj['initializers'] = inits

        # nodes
        nodes = []
        for obj in list(self.oinf.obj.graph.node):
            node = dict(name=obj.name, op_type=obj.op_type, domain=obj.domain,
                        inputs=[str(_) for _ in obj.input],
                        outputs=[str(_) for _ in obj.output],
                        attributes={})
            for att in obj.attribute:
                st = _to_json(att)
                # Key attributes by name, then drop the redundant field.
                node['attributes'][st['name']] = st
                del st['name']
            nodes.append(node)
        final_obj['nodes'] = nodes

        return json.dumps(final_obj, indent=indent)
493 def to_python(self, prefix="onnx_pyrt_", dest=None, inline=True):
494 """
495 Converts the ONNX runtime into independant python code.
496 The function creates multiple files starting with
497 *prefix* and saved to folder *dest*.
499 @param prefix file prefix
500 @param dest destination folder
501 @param inline constant matrices are put in the python file itself
502 as byte arrays
503 @return file dictionary
505 The function does not work if the chosen runtime
506 is not *python*.
508 .. runpython::
509 :showcode:
510 :warningout: DeprecationWarning
512 import numpy
513 from mlprodict.npy.xop import loadop
514 from mlprodict.onnxrt import OnnxInference
516 OnnxAdd = loadop('Add')
518 idi = numpy.identity(2).astype(numpy.float32)
519 onx = OnnxAdd('X', idi, output_names=['Y'],
520 op_version=12)
521 model_def = onx.to_onnx({'X': idi},
522 target_opset=12)
523 X = numpy.array([[1, 2], [3, 4]], dtype=numpy.float32)
524 oinf = OnnxInference(model_def, runtime='python')
525 res = oinf.to_python()
526 print(res['onnx_pyrt_main.py'])
527 """
528 if not isinstance(prefix, str):
529 raise TypeError( # pragma: no cover
530 f"prefix must be a string not {type(prefix)!r}.")
532 def clean_args(args):
533 new_args = []
534 for v in args:
535 # remove python keywords
536 if v.startswith('min='):
537 av = 'min_=' + v[4:]
538 elif v.startswith('max='):
539 av = 'max_=' + v[4:]
540 else:
541 av = v
542 new_args.append(av)
543 return new_args
545 if self.oinf.runtime not in ('python', None):
546 raise ValueError(
547 f"The runtime must be 'python' not '{self.oinf.runtime}'.")
549 # metadata
550 obj = {}
551 for k in {'ir_version', 'producer_name', 'producer_version',
552 'domain', 'model_version', 'doc_string'}:
553 if hasattr(self.oinf.obj, k):
554 obj[k] = getattr(self.oinf.obj, k)
555 code_begin = ["# coding: utf-8",
556 "'''",
557 "Python code equivalent to an ONNX graph.",
558 "It was was generated by module *mlprodict*.",
559 "'''"]
560 code_imports = ["from io import BytesIO",
561 "import pickle",
562 "from numpy import array, float32, ndarray"]
563 code_lines = ["class OnnxPythonInference:", "",
564 " def __init__(self):",
565 " self._load_inits()", "",
566 " @property",
567 " def metadata(self):",
568 f" return {obj!r}", ""]
570 # inputs
571 if hasattr(self.oinf.obj, 'graph'):
572 inputs = [obj.name for obj in self.oinf.obj.graph.input]
573 outputs = [obj.name for obj in self.oinf.obj.graph.output]
574 else:
575 inputs = list(self.oinf.obj.input)
576 outputs = list(self.oinf.obj.output)
578 code_lines.extend([
579 " @property", " def inputs(self):",
580 f" return {inputs!r}",
581 ""
582 ])
584 # outputs
585 code_lines.extend([
586 " @property", " def outputs(self):",
587 f" return {outputs!r}",
588 ""
589 ])
591 # init
592 code_lines.extend([" def _load_inits(self):",
593 " self._inits = {}"])
594 file_data = {}
595 if hasattr(self.oinf.obj, 'graph'):
596 for obj in self.oinf.obj.graph.initializer:
597 value = numpy_helper.to_array(obj)
598 bt = BytesIO()
599 pickle.dump(value, bt)
600 name = f'{prefix}{obj.name}.pkl'
601 if inline:
602 code_lines.extend([
603 f" iocst = {bt.getvalue()!r}",
604 f" self._inits['{obj.name}'] = pickle.loads(iocst)"
605 ])
606 else:
607 file_data[name] = bt.getvalue()
608 code_lines.append(
609 f" self._inits['{obj.name}'] = pickle.loads('{name}')")
610 code_lines.append('')
612 # inputs, outputs
613 inputs = self.oinf.input_names
615 # nodes
616 code_lines.extend([f" def run(self, {', '.join(inputs)}):"])
617 ops = {}
618 if hasattr(self.oinf.obj, 'graph'):
619 code_lines.append(' # constant')
620 for obj in self.oinf.obj.graph.initializer:
621 code_lines.append(
622 " {0} = self._inits['{0}']".format(obj.name))
623 code_lines.append('')
624 code_lines.append(' # graph code')
625 for node in self.oinf.sequence_:
626 fct = 'pyrt_' + node.name
627 if fct not in ops:
628 ops[fct] = node
629 args = []
630 args.extend(node.inputs)
631 margs = node.modified_args
632 if margs is not None:
633 args.extend(clean_args(margs))
634 code_lines.append(" {0} = {1}({2})".format(
635 ', '.join(node.outputs), fct, ', '.join(args)))
636 code_lines.append('')
637 code_lines.append(' # return')
638 code_lines.append(f" return {', '.join(outputs)}")
639 code_lines.append('')
641 # operator code
642 code_nodes = []
643 for name, op in ops.items():
644 inputs_args = clean_args(op.inputs_args)
646 code_nodes.append(f"def {name}({', '.join(inputs_args)}):")
647 imps, code = op.to_python(op.python_inputs)
648 if imps is not None:
649 if not isinstance(imps, list):
650 imps = [imps]
651 code_imports.extend(imps)
652 code_nodes.append(textwrap.indent(code, ' '))
653 code_nodes.extend(['', ''])
655 # end
656 code_imports = list(sorted(set(code_imports)))
657 code_imports.extend(['', ''])
658 file_data[prefix + 'main.py'] = "\n".join(
659 code_begin + code_imports + code_nodes + code_lines)
661 # saves as files
662 if dest is not None:
663 for k, v in file_data.items():
664 ext = os.path.splitext(k)[-1]
665 kf = os.path.join(dest, k)
666 if ext == '.py':
667 with open(kf, "w", encoding="utf-8") as f:
668 f.write(v)
669 elif ext == '.pkl': # pragma: no cover
670 with open(kf, "wb") as f:
671 f.write(v)
672 else:
673 raise NotImplementedError( # pragma: no cover
674 f"Unknown extension for file '{k}'.")
675 return file_data
677 def to_text(self, recursive=False, grid=5, distance=5, kind='bi'):
678 """
679 It calls function @see fn onnx2bigraph to return
680 the ONNX graph as text.
682 :param recursive: dig into subgraphs too
683 :param grid: align text to this grid
684 :param distance: distance to the text
685 :param kind: see below
686 :return: text
688 Possible values for format:
689 * `'bi'`: use @see fn onnx2bigraph
690 * `'seq'`: use @see fn onnx_simple_text_plot
691 """
692 if kind == 'bi':
693 bigraph = onnx2bigraph(self.oinf.obj, recursive=recursive)
694 graph = bigraph.display_structure(grid=grid, distance=distance)
695 return graph.to_text()
696 if kind == 'seq':
697 return onnx_simple_text_plot(self.oinf.obj)
698 raise ValueError( # pragma: no cover
699 f"Unexpected value for format={format!r}.")
701 def to_onnx_code(self):
702 """
703 Exports the ONNX graph into an :epkg:`onnx` code
704 which replicates it.
706 :return: string
707 """
708 # Lazy import as it is not a common use.
709 from ..onnx_tools.onnx_export import export2onnx
710 return export2onnx(self.oinf.obj)