Coverage for mlprodict/cli/onnx_code.py: 96%

51 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

"""
@file
@brief Command lines to export an ONNX graph into python code
or to plot it as text.
"""

5 

6 

def onnx_code(filename, format="onnx", output=None, verbose=0, name=None,
              opset=None, fLOG=print):
    """
    Exports an ONNX graph into a python code creating
    the same graph.

    :param filename: onnx file
    :param format: format to export to (`onnx`, `tf2onnx`, `numpy`)
    :param output: output file to produce or None to print it on stdout
    :param verbose: verbosity level
    :param name: rewrite the graph name
    :param opset: overwrite the opset (may not work depending on the format)
    :param fLOG: logging function

    .. cmdref::
        :title: Exports an ONNX graph into a python code creating the same graph.
        :cmd: -m mlprodict onnx_code --help
        :lid: l-cmd-onnx_code

        The command converts an ONNX graph into a python code generating
        the same graph. The python code may use onnx syntax, numpy syntax
        or tf2onnx syntax.

        Example::

            python -m mlprodict onnx_code --filename="something.onnx" --format=onnx
    """
    from ..onnx_tools.onnx_export import (  # pylint: disable=E0402
        export2onnx, export2tf2onnx, export2numpy)

    # An empty string (what a command line gives for an omitted option)
    # means "no new graph name".
    name = None if name == '' else name
    if opset == '':
        opset = None  # pragma: no cover
    # Any value int() rejects (None included) falls back to None,
    # i.e. no opset override is forwarded to the exporter.
    try:
        opset = int(opset)
    except (ValueError, TypeError):
        opset = None

    # Pick the exporter matching the requested format.
    choices = (('onnx', export2onnx),
               ('tf2onnx', export2tf2onnx),
               ('numpy', export2numpy))
    for key, exporter in choices:
        if format == key:
            break
    else:
        raise ValueError(  # pragma: no cover
            f"Unknown format {format!r}.")
    code = exporter(filename, verbose=verbose, name=name, opset=opset)

    # No output file: hand the generated code to the logging function.
    if output in ('', None):
        fLOG(code)  # pragma: no cover
    else:
        with open(output, "w", encoding="utf-8") as f:
            f.write(code)

64 

65 

def dynamic_doc(verbose=0, fLOG=print):
    """
    Generates the documentation for ONNX operators.
    It delegates to `_dynamic_class_creation` with `cache=True`.

    :param verbose: displays the list of operators
    :param fLOG: logging function
    """
    from ..npy.xop import _dynamic_class_creation as create_classes
    create_classes(cache=True, verbose=verbose, fLOG=fLOG)

75 

76 

def plot_onnx(filename, format="simple", verbose=0, output=None, fLOG=print):
    """
    Plots an ONNX graph as text, on the standard output (through *fLOG*)
    or into a file.

    :param filename: onnx file (a path given as `str` or as a path-like
        object), or an object already in memory (e.g. a loaded model)
    :param format: format to export to (`simple`, `tree`, `dot`,
        `io`, `mat`, `raw`)
    :param output: output file to produce or None to print it on stdout
    :param verbose: verbosity level (not used by this function)
    :param fLOG: logging function
    :raises ValueError: if *format* is not one of the supported formats

    .. cmdref::
        :title: Plots an ONNX graph as text
        :cmd: -m mlprodict plot_onnx --help
        :lid: l-cmd-plot_onnx

        The command shows the ONNX graphs as a text on the standard output.

        Example::

            python -m mlprodict plot_onnx --filename="something.onnx" --format=simple
    """
    import os

    # Paths (str or pathlib-like) are loaded; anything else is used as is.
    if isinstance(filename, (str, os.PathLike)):
        from onnx import load
        content = load(filename)
    else:
        content = filename

    if format == 'dot':
        from ..onnxrt import OnnxInference
        # Reuse the model loaded above rather than reading the file again.
        code = OnnxInference(content).to_dot()
    elif format == 'simple':
        from mlprodict.plotting.text_plot import onnx_simple_text_plot
        code = onnx_simple_text_plot(content)
    elif format == 'io':
        from mlprodict.plotting.text_plot import onnx_text_plot_io
        code = onnx_text_plot_io(content)
    elif format == 'mat':
        from mlprodict.plotting.text_plot import onnx_text_plot
        code = onnx_text_plot(content)
    elif format == 'raw':
        code = str(content)
    elif format == 'tree':
        from mlprodict.plotting.plotting import onnx_text_plot_tree
        rows = []
        # Only tree-ensemble nodes get a tree rendering; a header line
        # identifying the node precedes each rendering.
        for node in content.graph.node:
            if node.op_type.startswith("TreeEnsemble"):
                rows.append(f'Node type={node.op_type!r} name={node.name!r}')
                rows.append(onnx_text_plot_tree(node))
        code = "\n".join(rows)
    else:
        raise ValueError(  # pragma: no cover
            f"Unknown format {format!r}.")

    if output not in ('', None):
        with open(output, "w", encoding="utf-8") as f:
            f.write(code)
    else:
        fLOG(code)  # pragma: no cover