Coverage for mlprodict/onnx_tools/onnx_grammar/node_visitor_translator.py: 95%
171 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief One class which visits a syntax tree.
4"""
6import ast
7from .onnx_translator import OnnxTranslator
class CodeNodeVisitor(ast.NodeVisitor):
    """
    Defines a visitor which walks through the syntax tree of the code.

    Most visited nodes are recorded as a row (a dictionary) holding at
    least the keys ``indent`` (depth in the tree), ``type``, ``str`` and
    ``node``, plus ``children`` once the node has been visited. Context
    nodes (``Load``, ``Store``) are visited but not recorded.
    Every node is also forwarded to a translator:
    ``translator.visit(node, row)`` before its children are visited and
    ``translator.depart(node, row)`` after.

    .. exref::
        :title: Get the tree of a simple function

        The following code uses Python syntax but follows a SQL logic.

        .. runpython::
            :showcode:
            :warningout: DeprecationWarning
            :process:
            :store_in_file: fct2onnx1.py

            import ast
            import inspect
            from textwrap import dedent
            from mlprodict.onnx_tools.onnx_grammar import CodeNodeVisitor

            def norm2(x, y):
                delta = x - y
                n = delta ** 2
                return n

            code = dedent(inspect.getsource(norm2))
            node = ast.parse(code)
            v = CodeNodeVisitor()
            v.visit(node)
            for r in v.Rows:
                print("{0}{1}: {2}".format(" " * r["indent"], r["type"], r["str"]))
    """

    def __init__(self, translator=None):
        """
        @param      translator      @see cl CodeTranslator

        By default the translator is @see cl OnnxTranslator.
        """
        ast.NodeVisitor.__init__(self)
        # every row pushed so far, including rows later flagged "remove"
        self._rows = []
        # current depth in the tree, copied into each row as "indent"
        self._indent = 0
        self._stack = []
        self._translator = (
            OnnxTranslator(self) if translator is None else translator)

    def push(self, row):
        """
        Appends a row to the list of recorded rows.

        @param      row     row (a dictionary)
        """
        self._rows.append(row)

    def generic_visit(self, node):
        """
        Overrides ``generic_visit`` to check it is not used.

        @raises     AttributeError  always, ``generic_visit_args`` must be
                                    called instead
        """
        raise AttributeError(  # pragma: no cover
            "generic_visit_args should be used.")

    def generic_visit_args(self, node, row):
        """
        Overrides ``generic_visit`` to keep track of the indentation
        and the node parent. The function will add field
        ``row["children"] = visited`` nodes from here.

        @param      node        node which needs to be visited
        @param      row         row (a dictionary)
        @return                 See ``ast.NodeVisitor.generic_visit``
        """
        if hasattr(node, 'lineno'):
            row['lineno'] = node.lineno
        if hasattr(node, 'col_offset'):
            row['col_offset'] = node.col_offset
        self._indent += 1
        # rows pushed from now on belong to the subtree of node
        last = len(self._rows)
        self._translator.visit(node, row)
        res = ast.NodeVisitor.generic_visit(  # pylint: disable=E1111
            self, node)
        # direct children are the subtree rows exactly one level deeper
        row["children"] = [
            r for r in self._rows[last:] if r["indent"] == self._indent]
        self._indent -= 1
        self._translator.depart(node, row)
        return res

    def _push_visit(self, node, cont):
        """
        Records row *cont* then visits the children of *node*.

        @param      node        node being visited
        @param      cont        row describing *node*
        @return                 See ``generic_visit_args``
        """
        self.push(cont)
        return self.generic_visit_args(node, cont)

    def _visit_simple(self, node, kind):
        """
        Records a row of type *kind* with an empty ``str`` for *node*
        (operators, comparisons, ...) and visits its children.

        @param      node        node being visited
        @param      kind        value stored under key ``type``
        @return                 See ``generic_visit_args``
        """
        cont = {"indent": self._indent, "type": kind,
                "str": "", "node": node}
        return self._push_visit(node, cont)

    @staticmethod
    def _constant_value(node, legacy_field):
        """
        Returns the literal held by a constant node.

        @param      node            ``ast.Constant`` (Python >= 3.8) or a
                                    legacy ``ast.Str`` / ``ast.Num`` node
        @param      legacy_field    field holding the literal on legacy
                                    nodes (``'s'`` or ``'n'``)
        @return                     the literal value
        """
        # ``node.s`` / ``node.n`` are deprecated aliases of ``value`` on
        # ast.Constant and no longer exist in recent Python versions
        if isinstance(node, ast.Constant):
            return node.value
        return getattr(node, legacy_field)

    def make_msg(self, node):
        """
        Displays line and column information into a string.
        """
        return "line {}, col {}".format(  # pragma: no cover
            getattr(node, 'lineno', '?'), getattr(node, 'col_offset', '?'))

    def visit(self, node):
        """
        Visits a node, a method must exist for every object class.

        @param      node        node to visit
        @return                 whatever ``visit_<ClassName>`` returns
        @raises     TypeError   if no method ``visit_<ClassName>`` exists
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, None)
        if visitor is None:
            raise TypeError(  # pragma: no cover
                f"Unable to find a method '{method}' at {self.make_msg(node)}.")
        return visitor(node)

    def visit_(self, node):
        """
        If an element is not found...

        @raises     NotImplementedError     always, with the node and the
                                            current tree in the message
        """
        raise NotImplementedError(  # pragma: no cover
            "Node '{}' ({}) not recognized at {}\nNode\n{}\n--"
            "Status--\n{}".format(
                node, type(node), self.make_msg(node),
                self.print_node(node), self.print_tree()))

    @staticmethod
    def print_node(node):
        """
        Debugging purpose.

        @param      node        any syntax node
        @return                 string ``att=value`` for a fixed set of
                                attributes plus the node's ``_attributes``
        """
        names = {"s", "name", "str", "id", "body", "n",
                 "arg", "targets", "attr", "returns", "ctx",
                 'col_offset', 'lineno', 'value'}
        names.update(getattr(node, '_attributes', []))
        fields = getattr(node, '_fields', [])
        r = []
        for att in sorted(names):
            v = getattr(node, att, None)
            if v is not None or att in fields:
                r.append(f"{att}={v}")
        return " ".join(r)

    def print_tree(self):
        """
        Displays the tree of instructions.

        @return     string
        """
        return "\n".join(
            f"{' ' * r['indent']}{r['type']}: {r['str']}" for r in self.Rows)

    @property
    def Rows(self):
        """
        returns a list of dictionaries with all the elements of the code
        (rows flagged with ``remove`` are skipped)
        """
        return [r for r in self._rows if not r.get("remove", False)]

    def export(self, context=None, **kwargs):
        """
        Calls method *export* from the translator class.

        @param      context     known :epkg:`python` needed to run
                                the translated function
        @param      kwargs      whatever the method *export* from
                                the translator class ingests
        @return                 whatever the method *export* from
                                the translator class returns
        """
        return self._translator.export(context=context, **kwargs)

    ###########
    # Methods for python code elements
    ###########

    def visit_Constant(self, node):
        """
        Dispatches a literal (``ast.Constant``, produced by ``ast.parse``
        since Python 3.8) to :meth:`visit_NameConstant` (``None`` and
        booleans), :meth:`visit_Num` (numbers) or :meth:`visit_Str`
        (strings). Defining it avoids the deprecated dispatch done by
        ``ast.NodeVisitor.visit_Constant``, which emits a
        DeprecationWarning and no longer calls ``visit_Num``/``visit_Str``/
        ``visit_NameConstant`` as of Python 3.14.
        Any other literal (bytes, Ellipsis) is reported by :meth:`visit_`.
        """
        value = node.value
        # bool is a subclass of int, so it must be tested first
        if value is None or isinstance(value, bool):
            return self.visit_NameConstant(node)
        if isinstance(value, (int, float, complex)):
            return self.visit_Num(node)
        if isinstance(value, str):
            return self.visit_Str(node)
        return self.visit_(node)  # pragma: no cover

    def visit_Str(self, node):  # pylint: disable=C0111
        value = self._constant_value(node, 's')
        cont = {
            "indent": self._indent,
            "type": "Str",
            "str": value,
            "node": node,
            "value": value}
        return self._push_visit(node, cont)

    def visit_Name(self, node):  # pylint: disable=C0111
        cont = {
            "indent": self._indent,
            "type": "Name",
            "str": node.id,
            "node": node,
            "id": node.id,
            "ctx": node.ctx}
        return self._push_visit(node, cont)

    def visit_Module(self, node):  # pylint: disable=C0111
        cont = {
            "indent": self._indent,
            "type": "Module",
            "str": "",
            "body": node.body,
            "node": node}
        return self._push_visit(node, cont)

    def visit_FunctionDef(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "FunctionDef",
                "str": node.name, "name": node.name, "body": node.body,
                "node": node, "returns": node.returns}
        return self._push_visit(node, cont)

    def visit_List(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "List",
                "str": "", "elts": node.elts,
                "node": node}
        return self._push_visit(node, cont)

    def visit_arguments(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "arguments", "str": "",
                "node": node, "args": node.args}
        return self._push_visit(node, cont)

    def visit_arg(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "arg", "str": node.arg,
                "node": node,
                "arg": node.arg, "annotation": node.annotation}
        return self._push_visit(node, cont)

    def visit_Assign(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "Assign", "str": "",
                "node": node,
                "targets": node.targets, "value": node.value}
        return self._push_visit(node, cont)

    def visit_Store(self, node):  # pylint: disable=C0111
        # context node: visited (the translator still sees it) but not
        # recorded as a row
        cont = {}
        return self.generic_visit_args(node, cont)

    def visit_Call(self, node):  # pylint: disable=C0111
        func = node.func
        # ``obj.method(...)`` is named after the method, ``f(...)`` after f
        name = func.attr if isinstance(func, ast.Attribute) else func.id
        cont = {"indent": self._indent, "type": "Call", "str": name,
                "node": node, "func": func}
        return self._push_visit(node, cont)

    def visit_Attribute(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "Attribute", "str": node.attr,
                "node": node, "value": node.value, "ctx": node.ctx,
                "attr": node.attr}
        res = self._push_visit(node, cont)
        children = cont["children"]
        if children and children[0]["type"] == "Name":
            # fold ``name.attr`` into this row and hide the ``name`` row
            first = children[0]
            cont["str"] = f"{first['node'].id}.{cont['str']}"
            first["remove"] = True
        return res

    def visit_Load(self, node):  # pylint: disable=C0111
        # context node: visited (the translator still sees it) but not
        # recorded as a row
        cont = {}
        return self.generic_visit_args(node, cont)

    def visit_keyword(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "keyword", "str": f"{node.arg}",
                "node": node, "arg": node.arg, "value": node.value}
        return self._push_visit(node, cont)

    def visit_BinOp(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "BinOp")

    def visit_Div(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Div")

    def visit_Sub(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Sub")

    def visit_USub(self, node):  # pylint: disable=C0111
        # recorded with the same type as binary subtraction; the parent
        # UnaryOp row is what tells them apart
        # NOTE(review): confirm the translator relies on this labelling
        return self._visit_simple(node, "Sub")

    def visit_Add(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Add")

    def visit_Pow(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Pow")

    def visit_Mult(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Mult")

    def visit_MatMult(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "MatMult")

    def visit_Compare(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Compare")

    def visit_Gt(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Gt")

    def visit_Lt(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "Lt")

    def visit_UnaryOp(self, node):  # pylint: disable=C0111
        return self._visit_simple(node, "UnaryOp")

    def visit_Num(self, node):  # pylint: disable=C0111
        value = self._constant_value(node, 'n')
        cont = {"indent": self._indent, "type": "Num",
                "node": node, "str": f"{value}",
                'n': value}
        return self._push_visit(node, cont)

    def visit_Return(self, node):  # pylint: disable=C0111
        cont = {"indent": self._indent, "type": "Return", "node": node,
                "str": "", 'value': node.value}
        return self._push_visit(node, cont)

    def visit_NameConstant(self, node):
        """
        A name constant: only ``None`` is supported (recorded with type
        ``Cst``), anything else (``True``, ``False``) goes to
        :meth:`visit_`.
        """
        if node.value is None:
            cont = {"indent": self._indent, "type": "Cst",
                    "node": node, "str": "None",
                    'n': None}
            return self._push_visit(node, cont)
        return self.visit_(node)  # pragma: no cover