Coverage for mlprodict/cli/optimize.py: 100%
41 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Command line about model optimisation.
4"""
5import os
6import onnx
def onnx_stats(name, optim=False, kind=None):
    """
    Loads an ONNX model from a file and returns statistics or a
    textual description of it.

    :param name: path to the ONNX file
    :param optim: forwarded to `onnx_statistics`, computes statistics
        before and after an optimisation was done
    :param kind: kind of statistics; None or an empty string returns
        the global statistics, other possible values:

        * `io`: input and output names, types, shapes
        * `node`: distribution of node types
        * `text`: text summary of the graph

    :return: the value returned by the selected helper
    :raises FileNotFoundError: if *name* does not exist
    :raises ValueError: if *kind* is not one of the values above

    .. cmdref::
        :title: Computes statistics on an ONNX graph
        :cmd: -m mlprodict onnx_stats --help
        :lid: l-cmd-onnx_stats

        The command computes statistics on an ONNX model.
    """
    if not os.path.exists(name):
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find file '{name}'.")
    with open(name, 'rb') as stream:
        onx = onnx.load(stream)

    # Both statistics-based kinds rely on the same helper; only 'node'
    # asks for the per-node-type breakdown.
    if kind in (None, "", 'node'):
        from ..onnx_tools.optim import onnx_statistics
        if kind == 'node':
            return onnx_statistics(onx, optim=optim, node_type=True)
        return onnx_statistics(onx, optim=optim)

    # Text renderings of the graph.
    if kind == 'text':
        from ..plotting.plotting import onnx_simple_text_plot
        return onnx_simple_text_plot(onx)
    if kind == 'io':
        from ..plotting.plotting import onnx_text_plot_io
        return onnx_text_plot_io(onx)

    raise ValueError(  # pragma: no cover
        f"Unexpected kind={kind!r}.")
def onnx_optim(name, outfile=None, recursive=True, options=None, verbose=0, fLOG=None):
    """
    Optimizes an ONNX model.

    :param name: filename
    :param outfile: output filename, the optimized model is saved there
        unless it is None or an empty string
    :param recursive: processes the main graph and the subgraphs
    :param options: options, kind of optimize to do; an empty string is
        treated as None (the value is currently not forwarded to
        `onnx_optimisations`)
    :param verbose: display statistics before and after the optimisation
        (only when *fLOG* is given)
    :param fLOG: logging function, nothing is logged if None
    :return: the optimized model
    :raises FileNotFoundError: if *name* does not exist

    .. cmdref::
        :title: Optimizes an ONNX graph
        :cmd: -m mlprodict onnx_optim --help
        :lid: l-cmd-onnx_optim

        The command optimizes an ONNX model.
    """
    if not os.path.exists(name):
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find file '{name}'.")
    # Imported after the path check so a missing file is reported
    # before the optimisation module is loaded.
    from ..onnx_tools.optim import onnx_statistics, onnx_optimisations

    if outfile == "":
        outfile = None  # pragma: no cover
    if options == "":
        # NOTE(review): *options* is normalised but not used below;
        # confirm whether it should be passed to onnx_optimisations.
        options = None  # pragma: no cover

    # Verbose messages need both a verbosity level and a logger.
    verbose_log = verbose >= 1 and fLOG is not None

    if verbose_log:
        fLOG(f"[onnx_optim] read file '{name}'.")
    with open(name, 'rb') as f:
        model = onnx.load(f)
    if verbose_log:
        stats = onnx_statistics(model, optim=False)
        for k, v in sorted(stats.items()):
            fLOG(f' before.{k}={v}')

    new_model = onnx_optimisations(model, recursive=recursive)
    if verbose_log:
        # Statistics after optimisation must describe the optimized
        # model, not the original one.
        stats = onnx_statistics(new_model, optim=False)
        for k, v in sorted(stats.items()):
            fLOG(f' after.{k}={v}')

    if outfile is not None:
        # fLOG defaults to None: only log when a logger was provided.
        if fLOG is not None:
            fLOG(f"[onnx_optim] write '{outfile}'.")
        with open(outfile, 'wb') as f:
            onnx.save(new_model, f)
    return new_model