Coverage for mlprodict/onnx_tools/onnx_grammar/onnx_translator.py: 97%
384 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief One class which visits a syntax tree.
4"""
5import pprint
6import numpy
class CodeTranslator:
    """
    Base class for objects converting a Python function
    into something else. Subclasses are expected to
    overwrite *export*, *visit* and *depart*; the versions
    defined here only raise ``NotImplementedError``.
    """

    # Message shared by every method a subclass has to overwrite.
    _ABSTRACT_MSG = "This function should be overwritten."

    def __init__(self, visitor):
        """
        Stores the visitor in ``self._visitor``.

        :param visitor: :class:`CodeNodeVisitor
            <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`
        """
        self._visitor = visitor

    def export(self, context=None, **kwargs):
        """
        Exports the parsed :epkg:`python` code
        into something. Must be overwritten.

        @param      context     unused by this base implementation
        @param      kwargs      unused by this base implementation
        """
        raise NotImplementedError(  # pragma: no cover
            CodeTranslator._ABSTRACT_MSG)

    def visit(self, node, info):
        """
        Called when a node is entered. Must be overwritten.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        raise NotImplementedError(  # pragma: no cover
            CodeTranslator._ABSTRACT_MSG)

    def depart(self, node, info):
        """
        Called when a node is left. Must be overwritten.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        raise NotImplementedError(  # pragma: no cover
            CodeTranslator._ABSTRACT_MSG)
class OnnxTranslator(CodeTranslator):
    """
    Class which converts a Python function into
    an :epkg:`ONNX` function. It implements
    methods *visit* and *depart* which fill a stack of
    ``(kind, dict)`` pairs while the syntax tree is walked,
    and *export* which turns the collected description
    into Python code.
    """
    # Binary operator kinds (as reported by the visitor) -> ONNX operator.
    _binary_operators = {
        'Add': 'Add', 'Div': 'Div',
        'Mult': 'Mul', 'Sub': 'Sub',
        'Pow': 'Pow', 'MatMult': 'MatMul',
    }

    # Unary operator kinds -> ONNX operator.
    _unary_operators = {
        'Sub': 'Neg',
    }

    # numpy function name -> ONNX operator name. A lowercase value
    # is exported as a call to a helper named ``_onnx_<value>``.
    _numpy2onnx_op = {
        'absolute': 'Abs',
        'cos': 'Cos',
        'exp': 'Exp',
        'power': 'Pow',
        'transpose': 'Transpose',
        'sin': 'Sin',
        # complex function
        'inner': 'inner',
    }

    # Per ONNX operator, renaming of keyword parameters.
    _parameter_mapping = {
        'Transpose': {'axes': 'perm'}
    }

    class Parameter:
        """
        Holds parameter information: a name, an optional default
        value and an optional annotation. The value
        ``('#NODEFAULT#', )`` means *no default value*.
        """

        def __init__(self, name, value=('#NODEFAULT#', ), annotation=None):
            """
            @param      name        parameter name
            @param      value       parameter value,
                                    ``('#NODEFAULT#', )`` if none
            @param      annotation  parameter annotation, only stored
            """
            self.name = name
            self.value = value
            self.annotation = annotation

        @staticmethod
        def format_value(value):
            """
            Returns a formatted value in python code.

            @param      value       string, list, tuple or any other
                                    object (converted with ``str``)
            @return                 a string, or None for the sentinel
                                    ``('#NODEFAULT#', )``
            """
            fmt = OnnxTranslator.Parameter.format_value
            if isinstance(value, str):
                # Backslashes must be escaped first, otherwise the
                # backslash introduced to escape a quote is doubled
                # and the quote ends the literal.
                escaped = value.replace('\\', '\\\\').replace('"', '\\"')
                return f'"{escaped}"'
            if isinstance(value, list):
                return f"[{', '.join(map(fmt, value))}]"
            if isinstance(value, tuple):
                if value == ('#NODEFAULT#', ):
                    return None
                items = ', '.join(map(fmt, value))
                if len(value) == 1:
                    # ``(x)`` is not a tuple, ``(x,)`` is.
                    items += ','
                return f"({items})"
            return str(value)

        @property
        def formatted_value(self):
            """
            Returns the parameter value formatted as python code.
            """
            return OnnxTranslator.Parameter.format_value(self.value)

        def __str__(self):
            """
            Into python syntax: ``name`` or ``name=value``.
            The annotation is not written.
            """
            if self.value == ('#NODEFAULT#', ):
                return self.name
            return f"{self.name}={self.formatted_value}"

    def __init__(self, visitor):
        """
        :param visitor: :class:`CodeNodeVisitor
            <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`
        """
        CodeTranslator.__init__(self, visitor)
        # Stack of (kind, dict) pairs for the nodes being visited.
        self._stack = []
        # Set by *depart* once the whole function was parsed.
        self._code_fct = None

    def _is_stacked(self, name):
        """
        Tells if an item of kind *name* is in the stack.
        """
        return any(item[0] == name for item in self._stack)

    def _get_last(self, name, info=None):
        """
        Returns the last item of the stack and checks its kind.

        @param      name        expected kind, or a tuple of
                                accepted kinds
        @param      info        only used in the error message
        @return                 the last ``(kind, dict)`` pair
        """
        if not self._stack:
            raise RuntimeError("Stack is empty.")  # pragma: no cover
        last = self._stack[-1]
        if ((isinstance(name, str) and last[0] != name) or
                (isinstance(name, tuple) and last[0] not in name)):
            raise RuntimeError(  # pragma: no cover
                "Last item is not '{}'\n{}\n---\n{}".format(
                    name, pprint.pformat(self._stack),
                    pprint.pformat(info) if info else ""))
        return last

    def make_msg(self, info):
        """
        Make a message with line and column information.

        @param      info        a dictionary (with key ``'node'`` or
                                keys ``'lineno'``, ``'col_offset'``)
                                or an object with attributes
                                *lineno*, *col_offset*
        @return                 string, ``'?'`` replaces a missing
                                information
        """
        if isinstance(info, dict):
            if 'node' in info:
                node = info['node']
                lineno = node.lineno
                col_offset = node.col_offset
            else:
                lineno = info.get('lineno', '?')
                col_offset = info.get('col_offset', '?')
        else:
            lineno = getattr(info, 'lineno', '?')
            col_offset = getattr(info, 'col_offset', '?')
        return f"line {lineno}, col {col_offset}"

    def export(self, context=None, format='code',  # pylint: disable=W0221
               output_names=None):
        """
        Returns a piece of code which could generate
        the :epkg:`ONNX` graph.

        @param      context         function used in the function code,
                                    dictionary ``{name: function}``
        @param      format          ``'code'`` (not read by this
                                    implementation)
        @param      output_names    add code in the final function
                                    to overwrite the names of the
                                    outputs in the :epkg:`ONNX` graph
        @return                     string

        Raises ``RuntimeError`` if no code was parsed or if the
        parsed code uses something which cannot be converted.

        This method is used in function @see fn translate_fct2onnx.
        An example of code can be found there.
        """
        if self._code_fct is None:
            raise RuntimeError(  # pragma: no cover
                "No python code was parsed.")
        if context is None:
            context = {}

        def find_onnx_correspondance(fct, info):
            # Returns the ONNX operator name for a numpy function,
            # the function itself if its name starts with 'py_',
            # the string itself if *fct* is a string.
            if isinstance(fct, numpy.ufunc):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__module__', '') in (
                    'numpy', 'numpy.core.fromnumeric'):
                name = fct.__name__
            elif callable(fct) and fct.__name__.startswith("py_"):
                return fct
            else:
                name = None
            if name is not None and name not in OnnxTranslator._numpy2onnx_op:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find a correspondance to '{}' at {} in \n{}".format(
                        name, self.make_msg(info),
                        "\n".join(sorted(OnnxTranslator._numpy2onnx_op))))
            if name is not None:
                return OnnxTranslator._numpy2onnx_op[name]
            if isinstance(fct, str):
                return fct
            raise RuntimeError(  # pragma: no cover
                "Unable to find a correspondance for function name '{}' in module '{}', "
                "'{}' (type {}) at {}.".format(
                    name, getattr(fct, '__module__', ''),
                    fct, type(fct), self.make_msg(info)))

        def write_operands(stack_fct_used, rows, opening, operands, indent,
                           parameter_mapping=None):
            # Appends to *rows* a call: the opening line, one block
            # per operand followed by a comma, ``op_version`` and
            # the closing parenthesis.
            rows.append(opening)
            for operand in operands:
                lines = write_expression(
                    stack_fct_used, operand, indent + 1, parameter_mapping)
                if any('op_version="op_version"' in s for s in lines):
                    continue
                rows.extend(lines)
                rows[-1] += ","
            rows.append(
                f"{' ' * (indent + 1) * 4}op_version=op_version")
            rows.append(f"{' ' * indent * 4})")

        def write_expression(stack_fct_used, expr, indent, parameter_mapping=None):
            # Converts an expression into a list of code lines.
            pad = ' ' * indent * 4
            if isinstance(expr, (str, int, float)):
                # an argument
                return [f"{pad}{expr}"]
            if isinstance(expr, OnnxTranslator.Parameter):
                if parameter_mapping is None:
                    name = expr.name
                else:
                    name = parameter_mapping.get(expr.name, expr.name)
                return [f"{pad}{name}={expr.formatted_value}"]
            rows = []
            if isinstance(expr, tuple):
                expr = [expr]
            for op, args in expr:
                if op in ('BinOp', 'UnaryOp'):
                    opname = args["op"]
                    opon = args["args"]
                    table = (OnnxTranslator._binary_operators if op == 'BinOp'
                             else OnnxTranslator._unary_operators)
                    onnx_name = table[opname]
                    write_operands(
                        stack_fct_used, rows, f"{pad}Onnx{onnx_name}(",
                        opon, indent)
                elif op == 'Call':
                    name = args['name']
                    if name.startswith("onnx_"):
                        raise RuntimeError("The code must not use a function prefixed by 'onnx_' (%s). "
                                           "It indicates that function manipulate ONNX node and "
                                           "the fonction to convert must only deal with arrays." % name)
                    if name not in context:
                        raise RuntimeError(
                            "Unable to find function '{}' at {} in context\n{}\n--\n{}".format(
                                name, self.make_msg(args),
                                '\n'.join(sorted(context)),
                                pprint.pformat(args)))
                    op_conv = find_onnx_correspondance(context[name], args)
                    if callable(op_conv) and op_conv.__name__.startswith('py_'):
                        opening = f"{pad}{op_conv.__name__}("
                    elif callable(op_conv) and op_conv.__name__.startswith('onnx_'):
                        # NOTE(review): this writes the function object,
                        # not its name; find_onnx_correspondance only
                        # returns callables named 'py_*', so this branch
                        # looks unreachable -- confirm before changing.
                        stack_fct_used.append(op_conv.__name__)
                        opening = f"{pad}{op_conv}("
                    else:
                        prefix = "onnx_" if 'a' <= op_conv[0] <= 'z' else 'Onnx'
                        if prefix == "onnx_":
                            # the helper is wrapped by a lambda
                            # declared in the function header
                            stack_fct_used.append(
                                f"{prefix}{op_conv}")
                            prefix = '_' + prefix
                        opening = f"{pad}{prefix}{op_conv}("

                    # the first argument is skipped
                    opon = args["args"][1:]
                    write_operands(
                        stack_fct_used, rows, opening, opon, indent,
                        OnnxTranslator._parameter_mapping.get(op_conv, None))
                else:
                    raise RuntimeError(  # pragma: no cover
                        f"Unable to interpret '{expr}'.")
            return rows

        def write_function(stack_fct_used, to_replaces, node):
            # Converts the parsed function into a list of code lines.
            # A placeholder line is inserted after the signature and
            # registered in *to_replaces*.
            rows = []
            kind, fct = node
            if kind != 'FunctionDef':
                raise RuntimeError(  # pragma: no cover
                    "The code being translated should be a single function not "
                    "'{}' at {}.".format(kind, self.make_msg(fct)))
            list_args = [str(a) for a in fct['args']]
            if not any('dtype=' in s for s in list_args):
                list_args.append("dtype=numpy.float32")
            if not any('op_version=' in s for s in list_args):
                list_args.append("op_version=None")
            fct_name = fct['name']
            rows.append(f"def {fct_name}({', '.join(list_args)}):")
            indent = 1
            pad = ' ' * (indent * 4)

            to_replace = f"# __HEADER__{id(node)}"
            to_replaces.append(to_replace)
            rows.append(f"{pad}{to_replace}")

            for op, args in fct['code']:
                if op == "Assign":
                    rows.append(f"{pad}{args['name']} = (")
                    rows.extend(write_expression(
                        stack_fct_used, args["args"], indent + 1))
                    rows.append(f"{pad})")
                elif op == "Return":
                    returned = args["code"]
                    if output_names is None:
                        rows.append(f"{pad}return (")
                        rows.extend(write_expression(
                            stack_fct_used, returned, indent + 1))
                        rows.append(f"{pad})")
                    else:
                        rows.append(
                            f"{pad}return OnnxIdentity(")
                        subrows = write_expression(
                            stack_fct_used, returned, indent + 1)
                        subrows[-1] += ","
                        rows.extend(subrows)
                        rows.append("{}output_names={},".format(
                            " " * ((indent + 1) * 4), str(output_names)))
                        rows.append(
                            f"{' ' * ((indent + 1) * 4)}op_version=op_version")
                        rows.append(f"{pad})")
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to process operator '{}' at {}. "
                        "Make sure it is either an affectation, "
                        "either a return.".format(op, self.make_msg(args)))
            return rows

        stack_fct_used = []
        to_replaces = []
        rows = write_function(stack_fct_used, to_replaces, self._code_fct)

        # handling dtype parameter
        if len(to_replaces) != 1:
            raise RuntimeError(  # pragma: no cover
                "The following code misses a placeholder:\n{}".format(
                    "\n".join(rows)))
        index = next(
            (i for i, row in enumerate(rows) if to_replaces[0] in row), -1)

        # The header lines replace the placeholder: they must use the
        # same indentation as the placeholder, otherwise the generated
        # function mixes two indentations and cannot be compiled.
        placeholder_row = rows[index]
        pad = placeholder_row[:len(placeholder_row) -
                              len(placeholder_row.lstrip())]
        header = [
            "{1}_{0} = lambda *args, op_version=op_version, **kwargs: {0}(*args, dtype=dtype, "
            "op_version=op_version, **kwargs)".format(fct, pad)
            for fct in stack_fct_used]
        if header:
            header.append('')
        rows[index:index + 1] = header

        return "\n".join(rows)

    def visit(self, node, info):
        """
        Visits a node: pushes a new item on the stack for the
        kinds which own children, checks the top of the stack
        for the others.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind in ("Module", "arg"):
            return
        if kind == "FunctionDef":
            if self._is_stacked('FunctionDef'):
                raise RuntimeError("Nested functions are not allowed at {}.".format(
                    self.make_msg(node)))
            self._stack.append(
                ('FunctionDef', {'args': [], 'code': [], 'name': info['name'], 'default': [],
                                 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "arguments":
            self._get_last('FunctionDef')
            return
        if kind in ("Assign", 'BinOp', 'UnaryOp'):
            self._stack.append(
                (kind, {'args': [], 'lineno': node.lineno,
                        'col_offset': node.col_offset}))
            return
        if kind in ('Name', 'Cst'):
            self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword', 'UnaryOp'))
            return
        if kind in OnnxTranslator._binary_operators:
            # the operator is stored in its parent
            _, buf = self._get_last(('BinOp', 'UnaryOp'))
            buf['op'] = kind
            return
        if kind == 'Call':
            self._stack.append(
                ('Call', {'name': info['str'], 'args': [], 'lineno': node.lineno,
                          'col_offset': node.col_offset}))
            return
        if kind == 'Return':
            self._get_last('FunctionDef')
            self._stack.append(
                ('Return', {'code': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "Attribute":
            if info.get('str', '') == 'T':
                raise NotImplementedError(  # pragma: no cover
                    "Transpose should be done with numpy.transpose not with .T'{}' "
                    "at {}\n{}\n---\n{}".format(
                        info.get('type', '?'), self.make_msg(node),
                        pprint.pformat(info), pprint.pformat(self._stack)))
            self._get_last('Call')
            return
        if kind == 'keyword':
            self._get_last('Call')
            self._stack.append(
                ('keyword', {'name': f"{node.arg}",
                             'lineno': getattr(node, 'lineno', '?'),
                             'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'List':
            self._get_last('keyword')
            self._stack.append(
                ('List', {'elts': [], 'lineno': getattr(node, 'lineno', '?'),
                          'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'Num':
            self._get_last(('List', 'UnaryOp', 'BinOp', 'FunctionDef', 'Call'))
            return
        if kind == 'Str':
            self._get_last('keyword')
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))

    def _fix_default_values(self, code_fct):
        """
        Maps default values with parameter names: the default
        values are assigned to the last parameters. ``'args'``
        is replaced by a list of *Parameter*.

        @param      code_fct    pair ``('FunctionDef', dict)``,
                                modified inplace
        """
        fct = code_fct[1]
        defaults = fct['default']
        params = fct['args']
        # index of the default value of the first parameter,
        # negative as long as the parameter has no default value
        offset = len(defaults) - len(params)
        args = []
        for i, (name, annotation) in enumerate(params):
            j = offset + i
            if j >= 0:
                p = OnnxTranslator.Parameter(
                    name, annotation=annotation, value=defaults[j])
            else:
                p = OnnxTranslator.Parameter(name, annotation=annotation)
            args.append(p)
        fct['args'] = args

    def _post_process(self, op, node):
        """
        Simplifies some operator such as ``OnnxNeg(2)``:
        a unary minus applied on a constant is replaced by
        the negative constant. Nothing is done if *op* is not None.

        @param      op          kind of the node, only None
                                triggers the simplification
        @param      node        dictionary, ``node['args']`` is
                                modified inplace
        """
        if op is not None or 'args' not in node:
            return
        constants = (int, float, numpy.int64, numpy.float32, numpy.float64)
        for i, arg in enumerate(node['args']):
            if not isinstance(arg, tuple):
                continue
            o, v = arg
            if (o == 'UnaryOp' and len(v['args']) == 1 and
                    isinstance(v['args'][0], constants) and
                    v['op'] == 'Sub'):
                node['args'][i] = -v['args'][0]

    def depart(self, node, info):
        """
        Leaves a node: the information collected for the node
        is moved into its parent on the stack.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind in ("arg", 'Module'):
            return
        if kind == "arguments":
            _, buf = self._get_last('FunctionDef')
            for child in info['children']:
                child_type = child['type']
                if child_type == 'Str':
                    buf['default'].append(child['str'])
                elif child_type in ('Num', 'Cst'):
                    buf['default'].append(child['n'])
                elif child_type == 'arg':
                    buf['args'].append(
                        (child['str'], child.get('annotation', None)))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret type '{}' in function definition."
                        "\n{}".format(
                            child_type, pprint.pformat(info)))
            return

        if kind == "Name":
            op, buf = self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword',
                 'UnaryOp'),
                info)
            if op == 'Assign':
                buf['name'] = info['str']
                return
            if op in ('BinOp', 'Call', 'UnaryOp'):
                buf['args'].append(info['str'])
                return
            if op == 'Return':
                buf['code'] = info['str']
                return
            if op == 'keyword':
                buf['value'] = info['str']
                return
            if op == 'FunctionDef':
                raise RuntimeError("Default value must be constant, variable '{}' was "
                                   "detected.".format(info['str']))

        if kind in OnnxTranslator._binary_operators:
            self._get_last(('BinOp', 'UnaryOp'))
            return
        if kind in ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'):
            op, buf = self._get_last(
                ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'))
            self._post_process(op, buf)
            self._stack.pop()
            opp, parent = self._get_last(
                ('Call', 'BinOp', 'Assign', 'FunctionDef', 'Return', 'UnaryOp'))
            if opp in ('FunctionDef', 'Return'):
                parent['code'].append((op, buf))
            else:
                parent['args'].append((op, buf))
            self._post_process(None, parent)
            return
        if kind == 'FunctionDef':
            if len(self._stack) == 1:
                # the whole function was parsed
                self._code_fct = self._stack[-1]
                self._fix_default_values(self._code_fct)
                self._stack = []
            return
        if kind == 'Attribute':
            _, buf = self._get_last(('Call', 'BinOp'))

            children = info["children"]
            if len(children) > 0:
                fir = children[0]
                if fir["type"] == "Name":
                    # builds the full name ``parent.attribute``
                    parent = fir["node"].id
                    info["str"] = f"{parent}.{info['str']}"
                    children[0]["remove"] = True

            buf['name'] = info["str"]
            buf['args'][0] = info["str"]
            return
        if kind in ('Num', 'Cst'):
            op, buf = self._get_last(
                ('List', 'BinOp', 'UnaryOp', 'FunctionDef', 'Call'))
            if op == 'FunctionDef':
                return
            if op == 'List':
                buf['elts'].append(info['n'])
            else:
                buf['args'].append(info['n'])
            return
        if kind == 'Str':
            _, buf = self._get_last('keyword')
            buf['value'] = info['str']
            return
        if kind == 'List':
            op, buf = self._get_last('List')
            value = buf['elts']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('keyword')
            parent['value'] = value
            self._post_process(None, parent)
            return
        if kind == 'keyword':
            op, buf = self._get_last('keyword')
            name = buf["name"]
            if 'value' not in buf:
                raise RuntimeError(str(buf))  # pragma: no cover
            value = buf['value']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('Call')
            parent['args'].append(OnnxTranslator.Parameter(name, value))
            self._post_process(None, parent)
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))