Coverage for mlprodict/onnx_tools/optim/onnx_optimisation_redundant.py: 99%

110 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Optimisation of :epkg:`ONNX` graphs. 

4""" 

5import copy 

6import hashlib 

7import logging 

8from onnx import FunctionProto 

9from onnx.helper import make_graph, make_function 

10from ._onnx_optimisation_common import ( # pylint: disable=E0611 

11 _rename_node_input, 

12 _rename_node_output, 

13 _apply_optimisation_on_graph, 

14 _apply_remove_node_fct_node) 

15 

16 

# Logger shared by the ONNX optimisation functions of this module.
logger = logging.getLogger(name='onnx:optim')

18 

19 

20def _hash_obj_content(obj, max_size=1000): 

21 """ 

22 Hash the content of an object. 

23 """ 

24 m = hashlib.sha256() 

25 if hasattr(obj, 'op_type'): 

26 # An operator. 

27 m.update(obj.op_type.encode('ascii')) 

28 m.update(len(obj.output).to_bytes(8, byteorder='big')) 

29 for i in obj.input: 

30 m.update(i.encode('ascii')) 

31 if hasattr(obj, 'attribute'): 

32 for att in obj.attribute: 

33 m.update(att.name.encode('ascii')) 

34 m.update(_hash_obj_content(att)) 

35 else: 

36 # An initializer. 

37 obj = copy.deepcopy(obj) 

38 obj.name = "" 

39 obj.doc_string = "" 

40 m.update(obj.SerializeToString()) 

41 

42 content = m.digest() 

43 if len(content) > max_size: 

44 content = content[:max_size] 

45 return content 

46 

47 

def _rename_inputs_in_nodes(nodes, rename):
    """
    Replaces, in place, the inputs of the nodes according to *rename*.

    @param      nodes       list of nodes, None for a removed node
    @param      rename      dictionary ``{old name: new name}``
    @return                 set of the indices of the modified nodes
    """
    modified = set()
    if not rename:
        return modified
    for i, node in enumerate(nodes):
        if node is None or rename.keys().isdisjoint(node.input):
            continue
        nodes[i] = _rename_node_input(node, rename)
        modified.add(i)
    return modified


def _remove_duplicated_initializers(initializers, protected, max_hash_size):
    """
    Keeps the first initializer of every group of initializers sharing
    the same content (names and doc strings are ignored).

    @param      initializers    initializers of the graph
    @param      protected       names never merged, graph inputs and outputs
                                are visible from outside the graph
    @param      max_hash_size   see @see fn _hash_obj_content
    @return                     ``(kept initializers, {removed: kept name})``
    """
    first_name = {}
    rename = {}
    kept = []
    for init in initializers:
        if init.name in protected:
            kept.append(init)
            continue
        key = _hash_obj_content(init, max_size=max_hash_size)
        if key in first_name:
            # Already seen.
            rename[init.name] = first_name[key]  # pragma: no cover
        else:
            # New.
            first_name[key] = init.name
            kept.append(init)
    return kept, rename


def _plan_output_renames(removed_outputs, kept_outputs, graph_outputs):
    """
    Computes the renaming which allows a node to be removed because an
    identical node (same type, inputs and attributes) is kept instead.

    @param      removed_outputs     output names of the node to remove
    @param      kept_outputs        output names of the node kept instead
    @param      graph_outputs       names of the graph (or function) outputs
    @return                         ``(rename, kept_renames)`` where *rename*
                                    maps the names to replace in the inputs
                                    of the other nodes and *kept_renames*
                                    lists ``(current, new)`` output renamings
                                    for the kept node, or None if the node
                                    cannot be removed without breaking
                                    the graph
    """
    rename = {}
    kept_renames = []
    for removed_name, kept_name in zip(removed_outputs, kept_outputs):
        if not removed_name:
            # Optional output the removed node does not produce,
            # nothing can consume it.
            continue
        if not kept_name:
            # The kept node does not produce this output.
            return None
        if removed_name in graph_outputs:
            if kept_name in graph_outputs:
                # Both names must be produced, a single node cannot do it.
                return None
            # A graph output keeps its name: the kept node takes it and
            # the consumers of the kept node follow.
            kept_renames.append((kept_name, removed_name))
            rename[kept_name] = removed_name
        else:
            rename[removed_name] = kept_name
    return rename, kept_renames


def _remove_duplicated_nodes(nodes, graph_outputs, max_hash_size):
    """
    Removes every node identical to a previous one (same digest
    computed by @see fn _hash_obj_content) and redirects its consumers
    to the outputs of the node kept instead.

    @param      nodes           list of nodes
    @param      graph_outputs   names of the graph (or function) outputs
    @param      max_hash_size   see @see fn _hash_obj_content
    @return                     new list of nodes, None for a removed node
    """
    nodes = list(nodes)
    changed = True
    while changed:
        # Renaming inputs may reveal new duplicates, the loop goes on
        # until a pass removes nothing.
        changed = False
        # The table is rebuilt at every pass so that it only holds indices
        # lower than the current one: the kept node always comes before the
        # removed one and the order of the nodes remains topological.
        node_hashes = {}
        # enumerate reads the list lazily, in-place updates are seen.
        for i, node in enumerate(nodes):
            if node is None:
                # Already removed.
                continue
            key = _hash_obj_content(node, max_size=max_hash_size)
            ni = node_hashes.get(key)
            if ni is None:
                node_hashes[key] = i
                continue

            plan = _plan_output_renames(
                node.output, nodes[ni].output, graph_outputs)
            if plan is None:
                # Removing it would break the graph, the node stays.
                continue
            rename, kept_renames = plan
            for current, new in kept_renames:
                nodes[ni] = _rename_node_output(nodes[ni], current, new)
            nodes[i] = None
            modified = _rename_inputs_in_nodes(nodes, rename)

            # A modified node has a new digest: its stale entry is dropped,
            # the node is hashed again during the next pass.
            stale = [k for k, v in node_hashes.items() if v in modified]
            for k in stale:
                del node_hashes[k]
            changed = True
    return nodes


def onnx_remove_node_redundant(onnx_model, recursive=True, debug_info=None,
                               max_hash_size=1000, **options):
    """
    Removes redundant part of the graph. A redundant part is
    a set of nodes which takes the same inputs and produces
    the same outputs. It first starts by looking into duplicated
    initializers, then looks into nodes taking the same inputs
    and sharing the same type and parameters.

    A node is only replaced by an identical node placed before it so that
    the node order stays valid. It is kept when removing it would break the
    graph (both nodes produce a graph output, or only the removed node
    produces an optional output). Initializers listed among the graph
    inputs or outputs are never merged.

    NOTE(review): consumers are renamed through ``_rename_node_input`` on
    the nodes of this graph only; whether outer-scope names referenced
    inside subgraphs are renamed depends on that helper -- confirm.

    @param      onnx_model      onnx model, graph or function
    @param      recursive       looks into subgraphs
    @param      debug_info      debug information (private)
    @param      max_hash_size   limit the size of a hash used to detect
                                identical subgraphs
    @param      options         additional options (unused)
    @return                     new onnx model, graph or function
    """
    type_name = str(type(onnx_model)).rsplit('.', maxsplit=1)[-1].strip("'>")
    if debug_info is None:
        debug_info = [type_name]
    else:
        debug_info = debug_info + [type_name]

    if hasattr(onnx_model, 'graph'):
        return _apply_optimisation_on_graph(
            onnx_remove_node_redundant, onnx_model,
            recursive=recursive, debug_info=debug_info,
            max_hash_size=max_hash_size, **options)

    graph = onnx_model
    logger.debug("onnx_remove_node_redundant:begin with %d nodes.",
                 len(graph.node))
    is_function = isinstance(graph, FunctionProto)

    # Detects duplicated initializers (a function has none).
    if is_function:
        graph_outputs = set(graph.output)
        new_inits, rename = [], {}
    else:
        graph_outputs = {o.name for o in graph.output}
        protected = graph_outputs | {i.name for i in graph.input}
        new_inits, rename = _remove_duplicated_initializers(
            graph.initializer, protected, max_hash_size)

    # Renames node inputs.
    new_nodes = list(graph.node)
    _rename_inputs_in_nodes(new_nodes, rename)

    # Detects duplicated operators.
    new_nodes = _remove_duplicated_nodes(
        new_nodes, graph_outputs, max_hash_size)

    if recursive:
        # Handles subgraphs.
        for i, node in enumerate(new_nodes):
            if node is None or not node.attribute:
                continue
            new_nodes[i] = _apply_remove_node_fct_node(
                onnx_remove_node_redundant,
                node, recursive=True, debug_info=debug_info + [node.name])

    # Finally create the new graph.
    nodes = [n for n in new_nodes if n is not None]
    if is_function:
        logger.debug("onnx_remove_node_redundant:end function with %d nodes.",
                     len(nodes))
        return make_function(
            onnx_model.domain, onnx_model.name,
            onnx_model.input, onnx_model.output, nodes,
            opset_imports=onnx_model.opset_import,
            attributes=onnx_model.attribute,
            doc_string=onnx_model.doc_string)

    graph = make_graph(nodes, onnx_model.name,
                       onnx_model.input, onnx_model.output,
                       new_inits, doc_string=onnx_model.doc_string,
                       sparse_initializer=onnx_model.sparse_initializer)

    graph.value_info.extend(onnx_model.value_info)  # pylint: disable=E1101
    logger.debug("onnx_remove_node_redundant:end graph with %d nodes.",
                 len(nodes))
    return graph