Coverage for mlprodict/onnx_tools/onnx_grammar/node_visitor_translator.py: 95%

171 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief One class which visits a syntax tree. 

4""" 

5 

6import ast 

7from .onnx_translator import OnnxTranslator 

8 

9 

10class CodeNodeVisitor(ast.NodeVisitor): 

11 

12 """ 

13 Defines a visitor which walks through the syntax tree of the code. 

14 

15 .. exref:: 

16 :title: Get the tree of a simple function 

17 

18 The following code uses Python syntax but follows a SQL logic. 

19 

20 .. runpython:: 

21 :showcode: 

22 :warningout: DeprecationWarning 

23 :process: 

24 :store_in_file: fct2onnx1.py 

25 

26 import ast 

27 import inspect 

28 from textwrap import dedent 

29 from mlprodict.onnx_tools.onnx_grammar import CodeNodeVisitor 

30 

31 def norm2(x, y): 

32 delta = x - y 

33 n = delta ** 2 

34 return n 

35 

36 code = dedent(inspect.getsource(norm2)) 

37 node = ast.parse(code) 

38 v = CodeNodeVisitor() 

39 v.visit(node) 

40 for r in v.Rows : 

41 print("{0}{1}: {2}".format(" " * r["indent"], r["type"], r["str"])) 

42 """ 

43 

def __init__(self, translator=None):
    """
    Creates a visitor with no collected rows.

    @param      translator      @see cl CodeTranslator notified through
                                ``visit`` and ``depart`` for every node;
                                when None, an @see cl OnnxTranslator
                                bound to this visitor is created
    """
    ast.NodeVisitor.__init__(self)
    self._rows = []     # rows in visiting order, filled by push
    self._indent = 0    # depth of the node currently being visited
    self._stack = []
    if translator is None:
        translator = OnnxTranslator(self)
    self._translator = translator

56 

def push(self, row):
    """
    Appends *row* (a dictionary describing one node) at the end of
    the collected rows.
    """
    collected = self._rows
    collected.append(row)

62 

def generic_visit(self, node):
    """
    Disabled on purpose: children must be visited through
    ``generic_visit_args``, which also tracks the indentation.

    @raises     AttributeError      always
    """
    raise AttributeError("generic_visit_args should be used.")  # pragma: no cover

69 

def generic_visit_args(self, node, row):
    """
    Visits the children of *node* one indentation level deeper and
    stores in ``row["children"]`` the rows pushed at that level.

    The translator receives ``visit(node, row)`` before the children
    are walked and ``depart(node, row)`` once they are done.

    @param      node    node whose children need to be visited
    @param      row     row (a dictionary) describing *node*; it gets
                        ``lineno`` and ``col_offset`` when *node* has
                        them, and ``children``
    @return             what ``ast.NodeVisitor.generic_visit`` returns
    """
    for position in ('lineno', 'col_offset'):
        if hasattr(node, position):
            row[position] = getattr(node, position)

    self._indent += 1
    first_new = len(self._rows)
    self._translator.visit(node, row)
    result = ast.NodeVisitor.generic_visit(  # pylint: disable=E1111
        self, node)
    # Keep direct children only: rows pushed by deeper descendants
    # carry a larger indentation.
    level = self._indent
    row["children"] = [
        child for child in self._rows[first_new:]
        if child["indent"] == level]
    self._indent -= 1

    self._translator.depart(node, row)
    return result

95 

def make_msg(self, node):
    """
    Returns the position of *node* as ``"line <lineno>, col <col_offset>"``,
    with ``?`` replacing any missing information.
    """
    return (f"line {getattr(node, 'lineno', '?')}, "  # pragma: no cover
            f"col {getattr(node, 'col_offset', '?')}")

102 

def visit(self, node):
    """
    Dispatches *node* to the method ``visit_<class name of node>``.

    @param      node        node to visit
    @return                 whatever the dedicated method returns
    @raises     TypeError   when no such method exists
    """
    name = f"visit_{node.__class__.__name__}"
    handler = getattr(self, name, None)
    if handler is not None:
        return handler(node)
    raise TypeError(  # pragma: no cover
        f"Unable to find a method '{name}' at {self.make_msg(node)}.")

115 

def visit_(self, node):
    """
    Fallback for nodes this visitor does not support.

    @raises     NotImplementedError     always; the message gives the
                                        position of *node*, its attributes
                                        and the tree visited so far
    """
    raise NotImplementedError(  # pragma: no cover
        f"Node '{node}' ({type(node)}) not recognized at "
        f"{self.make_msg(node)}\nNode\n{self.print_node(node)}"
        f"\n--Status--\n{self.print_tree()}")

125 

@staticmethod
def print_node(node):
    """
    Debugging helper: renders selected attributes of *node* as
    ``name=value`` pairs, sorted by name and separated by spaces.

    The candidates are a fixed set of common AST attribute names plus
    the node's ``_attributes``. A candidate is shown when its value is
    not None, or when it is one of the node's ``_fields``.
    """
    wanted = {"s", "name", "str", "id", "body", "n", "arg", "targets",
              "attr", "returns", "ctx", "col_offset", "lineno", "value"}
    wanted.update(getattr(node, '_attributes', []))
    fields = getattr(node, '_fields', [])
    pairs = []
    for key in sorted(wanted):
        val = getattr(node, key, None)
        if val is None and key not in fields:
            continue
        pairs.append(f"{key}={val}")
    return " ".join(pairs)

140 

def print_tree(self):
    """
    Renders every row returned by ``Rows`` on its own line as
    ``<indent spaces><type>: <str>``.

    @return     string
    """
    lines = (f"{' ' * row['indent']}{row['type']}: {row['str']}"
             for row in self.Rows)
    return "\n".join(lines)

152 

@property
def Rows(self):
    """
    Returns the rows collected so far, skipping every row whose
    ``remove`` key is set to a true value.
    """
    kept = []
    for row in self._rows:
        if row.get("remove", False):
            continue
        kept.append(row)
    return kept

159 

def export(self, context=None, **kwargs):
    """
    Forwards the call to method *export* of the translator.

    @param      context     known :epkg:`python` needed to run
                            the translated function
    @param      kwargs      extra parameters given as they are to
                            the translator's *export*
    @return                 whatever the translator's *export* returns
    """
    translator = self._translator
    return translator.export(context=context, **kwargs)

172 

173 ########### 

174 # Methods for python code elements 

175 ########### 

176 

def visit_Str(self, node):
    """
    Records a string literal as a ``Str`` row and visits its children.

    Since Python 3.8, ``s`` is only a deprecated alias of
    ``ast.Constant.value``: it emits a DeprecationWarning from
    Python 3.12 and is removed in 3.14. The literal is therefore read
    from ``value`` when the node has it. It falls back to ``s`` for the
    ``ast.Str`` nodes of older Python versions.
    """
    text = node.value if hasattr(node, "value") else node.s
    cont = {
        "indent": self._indent,
        "type": "Str",
        "str": text,
        "node": node,
        "value": text}
    self.push(cont)
    return self.generic_visit_args(node, cont)

186 

def visit_Name(self, node):
    """
    Records a variable name as a ``Name`` row. ``str`` and ``id`` both
    hold the identifier, and ``ctx`` holds the node's context. Then
    visits the node's children.
    """
    identifier = node.id
    row = dict(indent=self._indent, type="Name", str=identifier,
               node=node, id=identifier, ctx=node.ctx)
    self.push(row)
    return self.generic_visit_args(node, row)

197 

def visit_Module(self, node):
    """
    Records the module, the root of the parsed tree, as a ``Module``
    row that keeps its list of statements, then visits them.
    """
    row = dict(indent=self._indent, type="Module", str="",
               body=node.body, node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

207 

def visit_FunctionDef(self, node):
    """
    Records a function definition as a ``FunctionDef`` row. The row
    keeps the function's name, statements and return annotation. Then
    visits the node's children.
    """
    fname = node.name
    row = {
        "indent": self._indent,
        "type": "FunctionDef",
        "str": fname,
        "name": fname,
        "body": node.body,
        "node": node,
        "returns": node.returns,
    }
    self.push(row)
    return self.generic_visit_args(node, row)

213 

def visit_List(self, node):
    """
    Records a list display as a ``List`` row that keeps its elements,
    then visits the node's children.
    """
    items = node.elts
    row = dict(indent=self._indent, type="List", str="", elts=items,
               node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

220 

def visit_arguments(self, node):
    """
    Records the argument list of a function as an ``arguments`` row
    that keeps its positional ``args``, then visits the node's
    children.
    """
    row = {"indent": self._indent, "type": "arguments", "str": ""}
    row["node"] = node
    row["args"] = node.args
    self.push(row)
    return self.generic_visit_args(node, row)

226 

def visit_arg(self, node):
    """
    Records one function argument as an ``arg`` row. The row keeps the
    argument name and its annotation (None when there is none). Then
    visits the node's children.
    """
    argname = node.arg
    row = dict(indent=self._indent, type="arg", str=argname, node=node,
               arg=argname, annotation=node.annotation)
    self.push(row)
    return self.generic_visit_args(node, row)

233 

def visit_Assign(self, node):
    """
    Records an assignment as an ``Assign`` row that keeps its targets
    and assigned value, then visits the node's children.
    """
    row = dict(indent=self._indent, type="Assign", str="", node=node,
               targets=node.targets, value=node.value)
    self.push(row)
    return self.generic_visit_args(node, row)

239 

def visit_Store(self, node):
    """
    Handles a ``Store`` context. No row is pushed for it, but it still
    goes through ``generic_visit_args`` with an empty, unpushed
    dictionary, so the translator is notified.
    """
    return self.generic_visit_args(node, {})

245 

def visit_Call(self, node):
    """
    Records a function call as a ``Call`` row, then visits the node's
    children.

    ``str`` holds the called name: the attribute for ``obj.method(...)``
    or the identifier for ``name(...)``.

    @raises     NotImplementedError     for any other kind of callee
                                        (e.g. ``f(1)(2)`` or ``a[0](x)``),
                                        with the position of the call
    """
    func = node.func
    # isinstance instead of peeking at func.__dict__: any callee that
    # is neither an Attribute nor a Name used to fail on ``func.id``
    # with a bare AttributeError and no location.
    if isinstance(func, ast.Attribute):
        called = func.attr
    elif isinstance(func, ast.Name):
        called = func.id
    else:
        raise NotImplementedError(
            f"Unsupported callee {type(func).__name__!r} "
            f"at {self.make_msg(node)}.")
    cont = {"indent": self._indent, "type": "Call", "str": called,
            "node": node, "func": func}
    self.push(cont)
    return self.generic_visit_args(node, cont)

255 

def visit_Attribute(self, node):
    """
    Records an attribute access as an ``Attribute`` row, then visits
    the node's children.

    When the first child row is a ``Name`` (as in ``np.exp``), its
    identifier is prefixed to ``str`` (giving ``"np.exp"``). That child
    row is flagged ``remove`` so ``Rows`` hides it.
    """
    row = dict(indent=self._indent, type="Attribute", str=node.attr,
               node=node, value=node.value, ctx=node.ctx, attr=node.attr)
    self.push(row)
    result = self.generic_visit_args(node, row)

    children = row["children"]
    if children and children[0]["type"] == "Name":
        head = children[0]
        row["str"] = f"{head['node'].id}.{row['str']}"
        head["remove"] = True
    return result

270 

def visit_Load(self, node):
    """
    Handles a ``Load`` context. No row is pushed for it, but it still
    goes through ``generic_visit_args`` with an empty, unpushed
    dictionary, so the translator is notified.
    """
    return self.generic_visit_args(node, {})

274 

def visit_keyword(self, node):
    """
    Records a keyword argument of a call as a ``keyword`` row, then
    visits the node's children. ``str`` is the keyword name rendered as
    text (``"None"`` for ``**kwargs``).
    """
    label = f"{node.arg}"
    row = dict(indent=self._indent, type="keyword", str=label,
               node=node, arg=node.arg, value=node.value)
    self.push(row)
    return self.generic_visit_args(node, row)

280 

def visit_BinOp(self, node):
    """
    Records a binary operation as a ``BinOp`` row, then visits its
    children (left operand, operator, right operand).
    """
    row = dict(indent=self._indent, type="BinOp", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

286 

def visit_Div(self, node):
    """Records the division operator ``/`` as a ``Div`` row."""
    row = dict(indent=self._indent, type="Div", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

292 

def visit_Sub(self, node):
    """Records the subtraction operator ``-`` as a ``Sub`` row."""
    row = dict(indent=self._indent, type="Sub", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

298 

def visit_USub(self, node):
    """
    Records the unary minus operator. The row type is ``"Sub"``, the
    same as for the binary ``-`` operator.
    NOTE(review): presumably the translator relies on this shared type;
    confirm before changing it.
    """
    row = dict(indent=self._indent, type="Sub", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

304 

def visit_Add(self, node):
    """Records the addition operator ``+`` as an ``Add`` row."""
    row = dict(indent=self._indent, type="Add", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

310 

def visit_Pow(self, node):
    """Records the power operator ``**`` as a ``Pow`` row."""
    row = dict(indent=self._indent, type="Pow", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

316 

def visit_Mult(self, node):
    """Records the multiplication operator ``*`` as a ``Mult`` row."""
    row = dict(indent=self._indent, type="Mult", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

322 

def visit_MatMult(self, node):
    """Records the matrix multiplication operator ``@`` as a ``MatMult`` row."""
    row = dict(indent=self._indent, type="MatMult", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

328 

def visit_Compare(self, node):
    """
    Records a comparison as a ``Compare`` row, then visits its
    children (operands and comparison operators).
    """
    row = dict(indent=self._indent, type="Compare", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

334 

def visit_Gt(self, node):
    """Records the comparison operator ``>`` as a ``Gt`` row."""
    row = dict(indent=self._indent, type="Gt", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

339 

def visit_Lt(self, node):
    """Records the comparison operator ``<`` as an ``Lt`` row."""
    row = dict(indent=self._indent, type="Lt", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

344 

def visit_UnaryOp(self, node):
    """
    Records a unary operation as a ``UnaryOp`` row, then visits its
    children (operator and operand).
    """
    row = dict(indent=self._indent, type="UnaryOp", str="", node=node)
    self.push(row)
    return self.generic_visit_args(node, row)

350 

def visit_Num(self, node):
    """
    Records a numeric literal as a ``Num`` row and visits its children.
    ``n`` holds the number and ``str`` its text.

    Since Python 3.8, ``n`` is only a deprecated alias of
    ``ast.Constant.value``: it emits a DeprecationWarning from
    Python 3.12 and is removed in 3.14. The number is therefore read
    from ``value`` when the node has it. It falls back to ``n`` for the
    ``ast.Num`` nodes of older Python versions.
    """
    number = node.value if hasattr(node, "value") else node.n
    cont = {"indent": self._indent, "type": "Num",
            "node": node, "str": f"{number}",
            'n': number}
    self.push(cont)
    return self.generic_visit_args(node, cont)

357 

def visit_Return(self, node):
    """
    Records a ``return`` statement as a ``Return`` row that keeps the
    returned expression (None for a bare ``return``), then visits the
    node's children.
    """
    row = dict(indent=self._indent, type="Return", node=node, str="",
               value=node.value)
    self.push(row)
    return self.generic_visit_args(node, row)

363 

def visit_NameConstant(self, node):
    """
    Records the constant ``None`` as a ``Cst`` row and visits its
    children. Any other named constant (``True``, ``False``) is passed
    to ``visit_``, and its result is returned.
    """
    if node.value is not None:
        return self.visit_(node)  # pragma: no cover
    row = {"indent": self._indent, "type": "Cst",
           "node": node, "str": "None", 'n': None}
    self.push(row)
    return self.generic_visit_args(node, row)