Coverage for mlprodict/cli/optimize.py: 100%

41 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Command line about model optimisation. 

4""" 

5import os 

6import onnx 

7 

8 

def onnx_stats(name, optim=False, kind=None):
    """
    Computes statistics on an ONNX model.

    :param name: filename of the ONNX model
    :param optim: computes statistics before and after the optimisation
        was done (only used when *kind* is empty or `node`)
    :param kind: kind of statistics; if left empty (None or ""),
        returns the general statistics, possible values:

        * `io`: prints input and output name, type, shapes
        * `node`: prints the distribution of node types
        * `text`: prints a text summary
    :return: the result of the helper selected by *kind*
    :raises FileNotFoundError: if *name* does not exist
    :raises ValueError: if *kind* is not one of the values above

    .. cmdref::
        :title: Computes statistics on an ONNX graph
        :cmd: -m mlprodict onnx_stats --help
        :lid: l-cmd-onnx_stats

        The command computes statistics on an ONNX model.
    """
    if os.path.exists(name):
        with open(name, 'rb') as stream:
            onx = onnx.load(stream)
    else:
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find file '{name}'.")

    # The default summary and the per-node distribution share one helper;
    # `node` only adds the node_type flag.
    if kind in (None, "", 'node'):
        from ..onnx_tools.optim import onnx_statistics
        extra = {'node_type': True} if kind == 'node' else {}
        return onnx_statistics(onx, optim=optim, **extra)

    if kind == 'text':
        from ..plotting.plotting import onnx_simple_text_plot
        return onnx_simple_text_plot(onx)

    if kind == 'io':
        from ..plotting.plotting import onnx_text_plot_io
        return onnx_text_plot_io(onx)

    raise ValueError(  # pragma: no cover
        f"Unexpected kind={kind!r}.")

47 

48 

def onnx_optim(name, outfile=None, recursive=True, options=None, verbose=0, fLOG=None):
    """
    Optimizes an ONNX model.

    :param name: filename of the ONNX model to optimize
    :param outfile: output filename; when not None the optimized model
        is saved there (an empty string is treated as None)
    :param recursive: processes the main graph and the subgraphs
    :param options: options, kind of optimization to do (an empty string
        is treated as None); NOTE(review): this value is currently not
        forwarded to `onnx_optimisations` - confirm whether it should be
    :param verbose: when >= 1 and *fLOG* is given, logs statistics
        before and after the optimisation
    :param fLOG: logging function, may be None
    :return: the optimized model returned by `onnx_optimisations`
    :raises FileNotFoundError: if *name* does not exist

    .. cmdref::
        :title: Optimizes an ONNX graph
        :cmd: -m mlprodict onnx_optim --help
        :lid: l-cmd-onnx_optim

        The command optimizes an ONNX model.
    """
    if not os.path.exists(name):
        raise FileNotFoundError(  # pragma: no cover
            f"Unable to find file '{name}'.")
    # Imported after the existence check so a missing file is reported
    # before the optimisation package is loaded.
    from ..onnx_tools.optim import onnx_statistics, onnx_optimisations

    if outfile == "":
        outfile = None  # pragma: no cover
    if options == "":
        options = None  # pragma: no cover

    # Progress and statistics need both a verbosity level and a sink.
    show = verbose >= 1 and fLOG is not None

    if show:
        fLOG(f"[onnx_optim] read file '{name}'.")
    with open(name, 'rb') as f:
        model = onnx.load(f)
    if show:
        for k, v in sorted(onnx_statistics(model, optim=False).items()):
            fLOG(f' before.{k}={v}')

    new_model = onnx_optimisations(model, recursive=recursive)
    if show:
        # Measure the optimized model (previously the input model was
        # measured again, so "after" never reflected the optimisation).
        for k, v in sorted(onnx_statistics(new_model, optim=False).items()):
            fLOG(f' after.{k}={v}')

    if outfile is not None:
        # fLOG defaults to None: only log when a logger was provided,
        # otherwise calling it would raise TypeError.
        if fLOG is not None:
            fLOG(f"[onnx_optim] write '{outfile}'.")
        with open(outfile, 'wb') as f:
            onnx.save(new_model, f)
    return new_model