Hide keyboard shortcuts

Hot-keys on this page

r m x p   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

1""" 

2@file 

3@brief Shorten code in notebook :ref:`onnxsklearnconsortiumrst`. 

4""" 

5import os 

6import sys 

7from collections import OrderedDict 

8import warnings 

9from pyquickhelper.pycode.profiling import profile 

10from pyquickhelper.helpgen.rst_converters import docstring2html 

11from pyensae.graphhelper import draw_diagram 

12from jyquickhelper import RenderJsDot 

13import sklearn 

14from skl2onnx.proto import TensorProto 

15from onnx import helper 

16 

17 

def graph_persistence_pickle():
    """
    Draws the diagram of a model persisted with pickle: a trained
    *scikit-learn* model is pickled on *machine 1*, restored on
    *machine 2* and used there for predictions or training.
    The blockdiag description is handed to *draw_diagram* and
    its result is returned unchanged.
    See :ref:`onnxsklearnconsortiumrst`.
    """
    spec = """
    blockdiag {
        default_fontsize = 20; node_width = 200; node_height = 100;
        model[label="trained model\\nscikit-learn"];
        pkl[label="pickled model"];
        rest[label="restored model\\nscikit-learn", textcolor="#00AAAA"];
        pkl -> rest;
        pred[label="predictions"];
        train[label="training"];
        group {
            label = "machine 1";
            color = "#FFAAAA";
            model -> pkl; pkl;
        }
        group {
            label = "machine 2";
            color = "#AAFFAA";
            rest -> pred; rest -> train;
        }
    }"""
    return draw_diagram(spec)

42 

43 

def graph_persistence_pickle_issues():
    """
    Draws the pickle persistence diagram with its issues annotated:
    the restored model is labelled *UNSTABLE* and the predictions
    are labelled *SLOW*.
    The blockdiag description is handed to *draw_diagram* and
    its result is returned unchanged.
    See :ref:`onnxsklearnconsortiumrst`.
    """
    spec = """
    blockdiag {
        default_fontsize = 20; node_width = 200; node_height = 100;
        pkl[label="pickled model"];
        rest[label="restored model\\nscikit-learn\\nUNSTABLE", textcolor="#00AAAA"];
        pkl -> rest;
        pred[label="predictions\\nSLOW"];
        train[label="training"];
        group {
            label = "machine 1";
            color = "#FFAAAA"; pkl;
        }
        group {
            label = "machine 2";
            color = "#AAFFAA";
            rest -> pred; rest -> train;
        }
    }"""
    return draw_diagram(spec)

66 

67 

def graph_persistence_onnx():
    """
    Draws the diagram of a model persisted with ONNX: a trained
    *scikit-learn* model is converted into an ONNX model on
    *machine 1*, then an ONNX runtime computes the predictions
    on *machine 2* where training is marked as not possible.
    The blockdiag description is handed to *draw_diagram* and
    its result is returned unchanged.
    See :ref:`onnxsklearnconsortiumrst`.
    """
    spec = """
    blockdiag {
        default_fontsize = 20; node_width = 200; node_height = 100;
        model[label="trained model\\nscikit-learn"];
        onnx[label="ONNX model"];
        rest[label="ONNX runtime", textcolor="#00AAAA"];
        onnx -> rest;
        pred[label="predictions"];
        notrain[label="cannot train", color="#FF0000"];
        group {
            label = "machine 1";
            color = "#FFAAAA";
            model -> onnx[label="conversion"];
            onnx;
        }
        group {
            label = "machine 2";
            color = "#AAFFAA";
            rest ;
            pred;
            rest -> pred;
            rest -> notrain[folded];
        }
    }"""
    return draw_diagram(spec)

96 

97 

def graph_three_components():
    """
    Draws the chain of the three components shown in the talk:
    ONNX (set of mathematical functions), the converter
    (sklearn-onnx) and the runtime (onnxruntime, onnx.js, ...).
    The blockdiag description is handed to *draw_diagram* and
    its result is returned unchanged.
    See :ref:`onnxsklearnconsortiumrst`.
    """
    spec = """
    blockdiag {
        default_fontsize = 20; node_width = 200; node_height = 100;
        onnx[label="ONNX\\n\\nset of mathematical functions", color="#FFFF00"];
        conv[label="converter\\n\\nsklearn-onnx", color="#FFFF00"];
        run[label="runtime\\n\\nonnxruntime\\nonnx.js\\n...", color="#FFFF00"];
        onnx -> conv -> run ;
    }"""
    return draw_diagram(spec)

110 

111 

def profile_fct_graph(fct, title, highlights=None, nb=20, figsize=(10, 3)):
    """
    Returns a graph which profiles the execution of function *fct*.
    See :ref:`onnxsklearnconsortiumrst`.

    @param      fct         function to profile, given to *profile*
    @param      title       title of the graph
    @param      highlights  labels of the bars to frame with a red
                            dashed line, a label which is not in the index
                            selects every string label containing it
    @param      nb          number of rows of the profile to plot
    @param      figsize     figure size given to the plotting method
    @return                 axis returned by the bar plot

    A :epkg:`ValueError` is raised if a highlighted label
    matches no plotted row.
    """
    def matching_labels(label, known):
        # An exact match wins, otherwise fall back on substring matches.
        if label in known:
            return [label]
        found = [name for name in known
                 if isinstance(name, str) and label in name]
        if not found:
            raise ValueError("Unable to find '{}' in '{}'?".format(
                label, ", ".join(sorted(map(str, known)))))
        return found

    # Path prefixes given to the profiler as *rootrem*.
    removed_roots = [os.path.dirname(sklearn.__file__),
                     "site-packages",
                     os.path.join(sys.prefix, "lib")]
    _, stats = profile(  # pylint: disable=W0632
        fct, as_df=True, rootrem=removed_roots)

    # The name of the function column is either 'namefct' or 'fct'.
    name_col = 'namefct' if 'namefct' in stats.columns else 'fct'
    top = stats[[name_col, 'cum_tall']].head(n=nb).set_index(name_col)
    known = list(top.index)

    ax = top.plot(kind='bar', figsize=figsize, rot=30)
    ax.set_title(title)
    for tick in ax.get_xticklabels():
        tick.set_horizontalalignment('right')

    height = 0.15
    for label in highlights or ():
        for name in matching_labels(label, known):
            pos = top.index.get_loc(name)
            left, right = pos - 0.35, pos + 0.3
            # Left side, right side, then top of the frame.
            segments = (([left, left], [0, height]),
                        ([right, right], [0, height]),
                        ([left, right], [height, height]))
            for xs, ys in segments:
                ax.plot(xs, ys, 'r--')
    return ax

146 

147 

def onnx2str(model_onnx, nrows=15):
    """
    Displays the beginning of an ONNX graph.
    See :ref:`onnxsklearnconsortiumrst`.

    @param      model_onnx      object converted with ``str``
    @param      nrows           maximum number of lines to keep
    @return                     the first *nrows* lines of the string,
                                followed by a line ``...`` if lines were removed
    """
    rows = str(model_onnx).split('\n')
    if len(rows) <= nrows:
        return "\n".join(rows)
    # Truncated: an ellipsis line tells the reader something was dropped.
    return "\n".join(rows[:nrows] + ['...'])

157 

158 

def onnx2dotnb(model_onnx, width="100%", orientation="LR"):
    """
    Converts an ONNX graph into dot then into :epkg:`RenderJsDot`.
    See :ref:`onnxsklearnconsortiumrst`.

    @param      model_onnx      ONNX model, its attribute *graph* is drawn
    @param      width           width given to :epkg:`RenderJsDot`
    @param      orientation     value given as *rankdir* to the dot graph
    @return                     :epkg:`RenderJsDot`
    """
    # Imported here so that the drawing tools are only loaded on demand.
    from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer

    graph = model_onnx.graph
    producer = GetOpNodeProducer(
        "docstring", color="yellow", fillcolor="yellow", style="filled")
    drawn = GetPydotGraph(
        graph, name=graph.name, rankdir=orientation, node_producer=producer)
    return RenderJsDot(drawn.to_string(), width=width)

171 

172 

def onnx2graph(onnx_model):
    """
    Converts an :epkg:`ONNX` model into a readable graph.

    @param      onnx_model      onnx_model
    @return                     graph defined with :epkg:`OrderedDict`
                                so that it can be processed by :epkg:`asciitree`

    Every node of the graph is identified by the string
    ``"<name>[<op_type>]"``. The function builds a dictionary of edges
    ``{ input name: [node keys] }`` and ``{ node key: [output names] }``
    and gives it to function *edges2asciitree*.
    """
    edges = {}

    for node in onnx_model.graph.node:
        key = "%s[%s]" % (node.name, node.op_type)

        # Every input of the node points to the node.
        for inp in node.input:
            consumers = edges.setdefault(inp, [])
            if key not in consumers:
                consumers.append(key)

        # Two nodes may share the same key (same name, or no name, and
        # same operator type): the outputs are accumulated instead of
        # being reset so that the outputs of the first node are not lost.
        produced = edges.setdefault(key, [])
        for out in node.output:
            if out not in produced:
                produced.append(out)

    return edges2asciitree(edges)

196 

197 

def edges2asciitree(edges):
    """
    Converts a set of edges into a combination
    of :epkg:`OrderedDict` which can be understood
    by :epkg:`asciitree`. This does not work if one node
    has multiple inputs.

    @param      edges       set of edges, a dictionary
                            ``{ node: [successors] }``
    @return                 :epkg:`OrderedDict`

    The roots are the keys which are nobody's successor.
    If there is more than one, they are attached to
    an artificial node named ``'root'`` added to a copy of *edges*.

    .. runpython::
        :showcode:

        data = {'X': ['LinearClassifier[LinearClassifier]'],
                'LinearClassifier[LinearClassifier]':
                    ['label', 'probability_tensor'],
                'probability_tensor': ['Normalizer[Normalizer]'],
                'Normalizer[Normalizer]': ['probabilities'],
                'label': ['Cast[Cast]'],
                'Cast[Cast]': ['output_label'],
                'probabilities': ['ZipMap[ZipMap]'],
                'ZipMap[ZipMap]': ['output_probability']}

        from jupytalk.talk_examples.sklearn2019 import edges2asciitree
        res = edges2asciitree(data)

        import pprint
        pprint.pprint(res)

        from asciitree import LeftAligned
        tr = LeftAligned()
        print(tr(res))
    """
    # Every name which appears as a successor of some node.
    successors = set()
    for children in edges.values():
        successors.update(children)
    roots = [name for name in edges if name not in successors]

    if len(roots) > 1:
        # Works on a copy to leave the dictionary of the caller untouched.
        edges = edges.copy()
        edges['root'] = roots
        roots = ['root']

    tree = OrderedDict((name, OrderedDict()) for name in roots)
    # Maps a node already placed in the tree to its own subtree.
    located = dict(tree)

    # Sweeps over the edges until a full pass adds nothing new.
    changed = True
    while changed:
        changed = False
        for parent, children in edges.items():
            if parent not in located:
                continue
            subtree = located[parent]
            for child in children:
                if child in subtree:
                    continue
                subtree[child] = located[child] = OrderedDict()
                changed = True

    return tree

264 

265 

def onnxdocstring2html(doc, start="number of targets."):
    """
    Renders the ONNX documentation of an operator
    with function *docstring2html*.

    @param      doc     documentation as a string
    @param      start   if not None, only the text following the last
                        occurrence of this string is kept
    @return             result of *docstring2html*

    The empty statements ``Default value is ````` are removed
    before the conversion and every warning raised during
    the conversion is ignored.
    """
    text = doc if start is None else doc.split(start)[-1]
    text = text.replace("Default value is ````", "")
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        return docstring2html(text)

275 

276 

def rename_input_output(model_onnx):
    """
    Renames all input and output of an ONNX file:
    every underscore is removed from the names of the graph
    inputs, the graph outputs, the node inputs and outputs
    and the initializers.

    @param      model_onnx      ONNX model
    @return                     new ONNX model

    The model is rebuilt from scratch. What is copied is:
    the operator type and the attributes of every node,
    the element type and the ``dim_value`` of every dimension
    for graph inputs and outputs, the data type, ``raw_data``
    and dimensions of every initializer, the graph name
    and the fields *ir_version*, *producer_name*, *producer_version*,
    *domain*, *model_version*, *doc_string* of the model.
    """
    def clean_name(name):
        return name.replace("_", "")

    def copy_inout(inout):
        tensor_type = inout.type.tensor_type
        dims = [d.dim_value for d in tensor_type.shape.dim]
        return helper.make_tensor_value_info(
            clean_name(inout.name), tensor_type.elem_type, dims)

    def copy_node(node):
        renamed = helper.make_node(
            node.op_type,
            [clean_name(name) for name in node.input],
            [clean_name(name) for name in node.output])
        renamed.attribute.extend(node.attribute)  # pylint: disable=E1101
        return renamed

    def copy_initializer(init):
        tensor = TensorProto()
        tensor.data_type = init.data_type
        tensor.name = clean_name(init.name)
        tensor.raw_data = init.raw_data
        tensor.dims.extend(init.dims)  # pylint: disable=E1101
        return tensor

    source = model_onnx.graph
    new_inputs = [copy_inout(i) for i in source.input]
    new_outputs = [copy_inout(o) for o in source.output]
    new_nodes = [copy_node(n) for n in source.node]
    new_inits = [copy_initializer(t) for t in source.initializer]

    new_graph = helper.make_graph(
        new_nodes, source.name, new_inputs, new_outputs, new_inits)
    onnx_model = helper.make_model(new_graph)

    # Metadata of the model copied as is.
    for field in ("ir_version", "producer_name", "producer_version",
                  "domain", "model_version", "doc_string"):
        setattr(onnx_model, field, getattr(model_onnx, field))
    return onnx_model