Coverage for mlprodict/onnx_tools/onnx_grammar/onnx_translator.py: 97%

384 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief One class which visits a syntax tree. 

4""" 

5import pprint 

6import numpy 

7 

8 

class CodeTranslator:
    """
    Base class for translators which turn a parsed Python function
    into another representation. A subclass is expected to provide
    its own *visit*, *depart* and *export* methods; the versions
    defined here only raise :class:`NotImplementedError`.
    """

    def __init__(self, visitor):
        """
        :param visitor: :class:`CodeNodeVisitor
            <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`
        """
        # The visitor walks the syntax tree and calls back visit/depart.
        self._visitor = visitor

    @staticmethod
    def _must_be_overwritten():
        """
        Raises the error shared by every method a subclass must provide.
        """
        raise NotImplementedError(  # pragma: no cover
            "This function should be overwritten.")

    def export(self, context=None, **kwargs):
        """
        Exports the parsed :epkg:`python` code into the target
        representation.

        :param context: names the translated function refers to
        :param kwargs: subclass specific options
        :raises NotImplementedError: always, subclasses must override it
        """
        self._must_be_overwritten()

    def visit(self, node, info):
        """
        Called when the visitor enters a node.

        :param node: visited node
        :param info: information extracted by the visitor
        :raises NotImplementedError: always, subclasses must override it
        """
        self._must_be_overwritten()

    def depart(self, node, info):
        """
        Called when the visitor leaves a node.

        :param node: visited node
        :param info: information extracted by the visitor
        :raises NotImplementedError: always, subclasses must override it
        """
        self._must_be_overwritten()

50 

51 

class OnnxTranslator(CodeTranslator):
    """
    Class which converts a Python function into
    an :epkg:`ONNX` function. It implements
    methods *visit* and *depart* to collect the function
    and *export* to produce the python code which builds
    the equivalent :epkg:`ONNX` graph.
    """
    # Python operator name (as reported by the visitor) -> ONNX operator type.
    _binary_operators = {
        'Add': 'Add', 'Div': 'Div',
        'Mult': 'Mul', 'Sub': 'Sub',
        'Pow': 'Pow', 'MatMult': 'MatMul',
    }

    # Unary operator name -> ONNX operator type.
    _unary_operators = {
        'Sub': 'Neg',
    }

    # numpy function name -> ONNX operator type. A lowercase value is a
    # complex function rendered as a call to ``onnx_<value>`` by *export*.
    _numpy2onnx_op = {
        'absolute': 'Abs',
        'cos': 'Cos',
        'exp': 'Exp',
        'power': 'Pow',
        'transpose': 'Transpose',
        'sin': 'Sin',
        # complex function
        'inner': 'inner',
    }

    # ONNX operator type -> {numpy keyword name: ONNX attribute name}.
    _parameter_mapping = {
        'Transpose': {'axes': 'perm'}
    }

    class Parameter:
        """
        Holds parameter information.
        """

        def __init__(self, name, value=('#NODEFAULT#', ), annotation=None):
            """
            @param      name        parameter name
            @param      value       parameter value, ``('#NODEFAULT#', )``
                                    means the parameter has no default value
            @param      annotation  parameter annotation
            """
            self.name = name
            self.value = value
            self.annotation = annotation

        @staticmethod
        def format_value(value):
            """
            Returns a formatted value in python code.

            Strings become double-quoted literals, lists and tuples are
            formatted recursively, the no-default marker
            ``('#NODEFAULT#', )`` returns None and any other value is
            converted with :func:`str`.
            """
            if isinstance(value, str):
                # Backslashes must be escaped first, otherwise the
                # backslash inserted in front of each quote gets doubled
                # and the literal is broken.
                escaped = (value.replace('\\', '\\\\')
                           .replace('"', '\\"')
                           .replace('\n', '\\n')
                           .replace('\r', '\\r'))
                return f'"{escaped}"'
            if isinstance(value, list):
                return f"[{', '.join(map(OnnxTranslator.Parameter.format_value, value))}]"
            if isinstance(value, tuple):
                if value == ('#NODEFAULT#', ):
                    return None
                items = ', '.join(
                    map(OnnxTranslator.Parameter.format_value, value))
                # "(x)" is not a tuple, a single element needs a comma.
                if len(value) == 1:
                    return f"({items},)"
                return f"({items})"
            return str(value)

        @property
        def formatted_value(self):
            """
            Returns a formatted value in python code.
            """
            return OnnxTranslator.Parameter.format_value(self.value)

        def __str__(self):
            """
            Into python syntax: ``name`` or ``name=value``.
            """
            rows = [self.name]
            if self.value != ('#NODEFAULT#', ):
                rows.append('=')
                rows.append(self.formatted_value)
            return ''.join(rows)

    def __init__(self, visitor):
        """
        :param visitor: :class:`CodeNodeVisitor
            <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`
        """
        CodeTranslator.__init__(self, visitor)
        # Stack of (kind, buffer) pairs for the constructs being parsed.
        self._stack = []
        # ('FunctionDef', buffer) once the function has been fully parsed.
        self._code_fct = None

    def _is_stacked(self, name):
        """
        Tells if an element of kind *name* is in the stack.
        """
        return any(line[0] == name for line in self._stack)

    def _get_last(self, name, info=None):
        """
        Returns the last element of the stack and checks its kind
        is *name* (a string) or one of *name* (a tuple).

        @raises RuntimeError if the stack is empty or the kind differs
        """
        if len(self._stack) == 0:
            raise RuntimeError("Stack is empty.")  # pragma: no cover
        last = self._stack[-1]
        if ((isinstance(name, str) and last[0] != name) or
                (isinstance(name, tuple) and last[0] not in name)):
            raise RuntimeError(  # pragma: no cover
                "Last item is not '{}'\n{}\n---\n{}".format(
                    name, pprint.pformat(self._stack),
                    pprint.pformat(info) if info else ""))
        return last

    def make_msg(self, info):
        """
        Make a message with line and column information.
        *info* is either a dictionary (holding a ``'node'`` or
        ``'lineno'``/``'col_offset'`` keys) or a node.
        Missing information is replaced by ``'?'``.
        """
        lineno = '?'
        col_offset = '?'
        if isinstance(info, dict):
            if 'node' in info:
                node = info['node']
                lineno = node.lineno
                col_offset = node.col_offset
            else:
                if 'lineno' in info:
                    lineno = info['lineno']
                if 'col_offset' in info:
                    col_offset = info['col_offset']
        else:
            if hasattr(info, 'lineno'):
                lineno = info.lineno
            if hasattr(info, 'col_offset'):
                col_offset = info.col_offset

        return f"line {lineno}, col {col_offset}"

    def export(self, context=None, format='code',  # pylint: disable=W0221
               output_names=None):
        """
        Returns an :epkg:`ONNX` graph or a piece
        of code which could generate the graph.

        @param      context         function used in the function code
        @param      format          ``'code'``
        @param      output_names    add code in the final function
                                    to overwrite the names of the
                                    outputs in the :epkg:`ONNX` graph
        @return                     string or :epkg:`onnx` graph

        This method is used in function @see fn translate_fct2onnx.
        An example of code can be found there.
        """
        if self._code_fct is None:
            raise RuntimeError(  # pragma: no cover
                "No python code was parsed.")
        if context is None:
            context = {}

        def find_onnx_correspondance(fct, info):
            # Maps a function found in *context* to an ONNX operator type,
            # a complex function name, or returns it as is when it is a
            # ``py_`` python function or already a string.
            if isinstance(fct, numpy.ufunc):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__module__', '') in (
                    'numpy', 'numpy.core.fromnumeric', 'numpy._core.fromnumeric'):
                name = fct.__name__
            elif callable(fct) and getattr(fct, '__name__', '').startswith("py_"):
                return fct
            else:
                name = None
            if name is not None:
                if name not in OnnxTranslator._numpy2onnx_op:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to find a correspondance to '{}' at {} in \n{}".format(
                            name, self.make_msg(info),
                            "\n".join(sorted(OnnxTranslator._numpy2onnx_op))))
                return OnnxTranslator._numpy2onnx_op[name]
            if isinstance(fct, str):
                return fct
            raise RuntimeError(  # pragma: no cover
                "Unable to find a correspondance for function name '{}' in module '{}', "
                "'{}' (type {}) at {}.".format(
                    name, getattr(fct, '__module__', ''),
                    fct, type(fct), self.make_msg(info)))

        def write_arguments(stack_fct_used, rows, operands, indent,
                            parameter_mapping=None):
            # Appends the operands of an operator call to *rows*, then the
            # op_version argument and the closing parenthesis.
            for operand in operands:
                sub_rows = write_expression(
                    stack_fct_used, operand, indent + 1, parameter_mapping)
                if any('op_version="op_version"' in s for s in sub_rows):
                    # an explicit op_version keyword would duplicate
                    # the one appended after the loop
                    continue  # pragma: no cover
                rows.extend(sub_rows)
                rows[-1] += ","
            rows.append(f"{' ' * (indent + 1) * 4}op_version=op_version")
            rows.append(f"{' ' * indent * 4})")

        def write_expression(stack_fct_used, expr, indent, parameter_mapping=None):
            # Returns the rows of python code which rebuild *expr*.
            spaces = ' ' * indent * 4
            if isinstance(expr, (str, int, float, numpy.number)):
                # an argument
                return [f"{spaces}{expr}"]
            if isinstance(expr, OnnxTranslator.Parameter):
                if parameter_mapping is None:
                    name = expr.name
                else:
                    name = parameter_mapping.get(expr.name, expr.name)
                return [f"{spaces}{name}={expr.formatted_value}"]
            rows = []
            if isinstance(expr, tuple):
                expr = [expr]
            for op, args in expr:
                if op == 'BinOp':
                    onnx_name = OnnxTranslator._binary_operators[args["op"]]
                    rows.append(f"{spaces}Onnx{onnx_name}(")
                    write_arguments(stack_fct_used, rows, args["args"], indent)
                elif op == 'UnaryOp':
                    onnx_name = OnnxTranslator._unary_operators[args["op"]]
                    rows.append(f"{spaces}Onnx{onnx_name}(")
                    write_arguments(stack_fct_used, rows, args["args"], indent)
                elif op == 'Call':
                    name = args['name']
                    if name.startswith("onnx_"):
                        raise RuntimeError("The code must not use a function prefixed by 'onnx_' (%s). "
                                           "It indicates that function manipulate ONNX node and "
                                           "the fonction to convert must only deal with arrays." % name)
                    if name not in context:
                        raise RuntimeError(
                            "Unable to find function '{}' at {} in context\n{}\n--\n{}".format(
                                name, self.make_msg(args),
                                '\n'.join(sorted(context)),
                                pprint.pformat(args)))
                    op_conv = find_onnx_correspondance(context[name], args)
                    if callable(op_conv) and op_conv.__name__.startswith('py_'):
                        rows.append(f"{spaces}{op_conv.__name__}(")
                    elif callable(op_conv) and op_conv.__name__.startswith('onnx_'):
                        # Every name in stack_fct_used gets a ``_<name>``
                        # wrapper in the header, call through it like the
                        # string case below (not the function's repr).
                        stack_fct_used.append(op_conv.__name__)
                        rows.append(f"{spaces}_{op_conv.__name__}(")
                    else:
                        prefix = "onnx_" if 'a' <= op_conv[0] <= 'z' else 'Onnx'
                        if prefix == "onnx_":
                            stack_fct_used.append(f"{prefix}{op_conv}")
                            prefix = '_' + prefix
                        rows.append(f"{spaces}{prefix}{op_conv}(")

                    # args[0] holds the callee name (stored by depart),
                    # only the remaining items are operands.
                    write_arguments(
                        stack_fct_used, rows, args["args"][1:], indent,
                        OnnxTranslator._parameter_mapping.get(op_conv, None))
                else:
                    raise RuntimeError(  # pragma: no cover
                        f"Unable to interpret '{expr}'.")
            return rows

        def write_function(stack_fct_used, to_replaces, node):
            # Returns the rows of the generated function. A placeholder row
            # is inserted where the helper wrappers must be declared.
            rows = []
            name, args = node
            if name != 'FunctionDef':
                raise RuntimeError(  # pragma: no cover
                    "The code being translated should be a single function not "
                    "'{}' at {}.".format(name, self.make_msg(args)))
            list_args = list(map(str, args['args']))
            if all('dtype=' not in s for s in list_args):
                list_args.append("dtype=numpy.float32")
            if all('op_version=' not in s for s in list_args):
                list_args.append("op_version=None")
            fct_name = args['name']
            rows.append(f"def {fct_name}({', '.join(list_args)}):")
            indent = 1
            spaces = ' ' * (indent * 4)
            sub_spaces = ' ' * ((indent + 1) * 4)

            to_replace = f"# __HEADER__{id(node)}"
            to_replaces.append(to_replace)
            rows.append(f"{spaces}{to_replace}")

            for op, instr in args['code']:
                if op == "Assign":
                    rows.append(f"{spaces}{instr['name']} = (")
                    rows.extend(write_expression(
                        stack_fct_used, instr["args"], indent + 1))
                    rows.append(f"{spaces})")
                elif op == "Return":
                    code = instr["code"]
                    if output_names is None:
                        rows.append(f"{spaces}return (")
                        rows.extend(write_expression(
                            stack_fct_used, code, indent + 1))
                        rows.append(f"{spaces})")
                    else:
                        rows.append(f"{spaces}return OnnxIdentity(")
                        subrows = write_expression(
                            stack_fct_used, code, indent + 1)
                        subrows[-1] += ","
                        rows.extend(subrows)
                        rows.append(f"{sub_spaces}output_names={str(output_names)},")
                        rows.append(f"{sub_spaces}op_version=op_version")
                        rows.append(f"{spaces})")
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to process operator '{}' at {}. "
                        "Make sure it is either an affectation, "
                        "either a return.".format(op, self.make_msg(instr)))
            return rows

        stack_fct_used = []
        to_replaces = []
        rows = write_function(stack_fct_used, to_replaces, self._code_fct)

        # handling dtype parameter
        index = None
        if len(to_replaces) == 1:
            index = next((i for i, row in enumerate(rows)
                          if to_replaces[0] in row), None)
        if index is None:
            raise RuntimeError(  # pragma: no cover
                "The following code misses a placeholder:\n{}".format(
                    "\n".join(rows)))

        # The wrappers inject dtype and op_version. They are indented like
        # the other rows of the function body (indent=1, i.e. 4 spaces).
        header = [
            "    _{0} = lambda *args, op_version=op_version, **kwargs: {0}(*args, dtype=dtype, "
            "op_version=op_version, **kwargs)".format(fct)
            for fct in stack_fct_used]
        if header:
            header.append('')
        rows[index:index + 1] = header

        return "\n".join(rows)

    def visit(self, node, info):
        """
        Visits a node. Constructs which collect children
        (FunctionDef, Assign, BinOp, UnaryOp, Call, Return, keyword, List)
        push a new entry on the stack, other kinds check they
        appear in an expected context.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind in ("Module", "arg"):
            return
        if kind == "FunctionDef":
            if self._is_stacked('FunctionDef'):
                raise RuntimeError("Nested functions are not allowed at {}.".format(
                    self.make_msg(node)))
            self._stack.append(
                ('FunctionDef', {'args': [], 'code': [], 'name': info['name'], 'default': [],
                                 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "arguments":
            # arguments are collected when the node is departed
            self._get_last('FunctionDef')
            return
        if kind in ("Assign", "BinOp", "UnaryOp"):
            self._stack.append(
                (kind, {'args': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind in ('Name', 'Cst'):
            self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword', 'UnaryOp'))
            return
        if kind in OnnxTranslator._binary_operators:
            # the operator token is stored on the enclosing operation
            _, buf = self._get_last(('BinOp', 'UnaryOp'))
            buf['op'] = kind
            return
        if kind == 'Call':
            self._stack.append(
                ('Call', {'name': info['str'], 'args': [], 'lineno': node.lineno,
                          'col_offset': node.col_offset}))
            return
        if kind == 'Return':
            self._get_last('FunctionDef')
            self._stack.append(
                ('Return', {'code': [], 'lineno': node.lineno, 'col_offset': node.col_offset}))
            return
        if kind == "Attribute":
            if info.get('str', '') == 'T':
                raise NotImplementedError(  # pragma: no cover
                    "Transpose should be done with numpy.transpose not with .T'{}' "
                    "at {}\n{}\n---\n{}".format(
                        info.get('type', '?'), self.make_msg(node),
                        pprint.pformat(info), pprint.pformat(self._stack)))
            self._get_last('Call')
            return
        if kind == 'keyword':
            self._get_last('Call')
            self._stack.append(
                ('keyword', {'name': f"{node.arg}",
                             'lineno': getattr(node, 'lineno', '?'),
                             'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'List':
            self._get_last('keyword')
            self._stack.append(
                ('List', {'elts': [], 'lineno': getattr(node, 'lineno', '?'),
                          'col_offset': getattr(node, 'col_offset', '?')}))
            return
        if kind == 'Num':
            self._get_last(('List', 'UnaryOp', 'BinOp', 'FunctionDef', 'Call'))
            return
        if kind == 'Str':
            self._get_last('keyword')
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))

    def _fix_default_values(self, code_fct):
        """
        Maps default values with parameter names.
        The collected defaults are assigned to the trailing parameters,
        every parameter becomes an instance of @see cl Parameter.
        """
        defaults = code_fct[1]['default']
        params = code_fct[1]['args']
        nbdef = len(defaults)
        nbpar = len(params)
        args = []
        for i, (name, annotation) in enumerate(params):
            # j >= 0 iff parameter i is one of the last nbdef parameters
            j = nbdef - (nbpar - i)
            if j >= 0:
                p = OnnxTranslator.Parameter(
                    name, annotation=annotation, value=defaults[j])
            else:
                p = OnnxTranslator.Parameter(name, annotation=annotation)
            args.append(p)
        code_fct[1]['args'] = args

    def _post_process(self, op, node):
        """
        Simplifies some operator such as ``OnnxNeg(2)``:
        a negated numerical constant is replaced by its opposite.
        Only applies when *op* is None.
        """
        if op is not None or 'args' not in node:
            return
        for i, arg in enumerate(node['args']):
            if not isinstance(arg, tuple):
                continue
            o, v = arg
            if (o == 'UnaryOp' and len(v['args']) == 1 and
                    isinstance(v['args'][0], (int, float, numpy.int64,
                                              numpy.float32, numpy.float64)) and
                    v.get('op') == 'Sub'):
                node['args'][i] = -v['args'][0]

    def depart(self, node, info):
        """
        Leaves a node: the information collected for it is moved
        into its parent on the stack.

        @param      node        visited node
        @param      info        info extracted by the visitor
        """
        if 'type' not in info:
            return

        kind = info['type']
        if kind == "arg":
            return
        if kind == "arguments":
            _, buf = self._get_last('FunctionDef')
            for child in info['children']:
                ctype = child['type']
                if ctype == 'Str':
                    buf['default'].append(child['str'])
                elif ctype in ('Num', 'Cst'):
                    buf['default'].append(child['n'])
                elif ctype == 'arg':
                    buf['args'].append(
                        (child['str'], child.get('annotation', None)))
                else:
                    raise RuntimeError(  # pragma: no cover
                        "Unable to interpret type '{}' in function definition."
                        "\n{}".format(
                            ctype, pprint.pformat(info)))
            return

        if kind == "Name":
            op, buf = self._get_last(
                ('Assign', 'BinOp', 'Call', 'Return', 'FunctionDef', 'keyword',
                 'UnaryOp'),
                info)
            if op == 'Assign':
                buf['name'] = info['str']
            elif op in ('BinOp', 'Call', 'UnaryOp'):
                buf['args'].append(info['str'])
            elif op == 'Return':
                buf['code'] = info['str']
            elif op == 'keyword':
                buf['value'] = info['str']
            else:
                # op == 'FunctionDef': a variable used as a default value
                raise RuntimeError("Default value must be constant, variable '{}' was "
                                   "detected.".format(info['str']))
            return

        if kind in OnnxTranslator._binary_operators:
            self._get_last(('BinOp', 'UnaryOp'))
            return
        if kind in ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'):
            op, buf = self._get_last(
                ('Call', 'BinOp', 'Assign', 'Return', 'UnaryOp'))
            self._post_process(op, buf)
            self._stack.pop()
            opp, parent = self._get_last(
                ('Call', 'BinOp', 'Assign', 'FunctionDef', 'Return', 'UnaryOp'))
            if opp in ('FunctionDef', 'Return'):
                parent['code'].append((op, buf))
            else:
                parent['args'].append((op, buf))
            self._post_process(None, parent)
            return
        if kind == 'FunctionDef':
            if len(self._stack) == 1:
                self._code_fct = self._stack[-1]
                self._fix_default_values(self._code_fct)
                self._stack = []
            return
        if kind == 'Module':
            return
        if kind == 'Attribute':
            _, buf = self._get_last(('Call', 'BinOp'))

            if len(info["children"]) > 0:
                fir = info["children"][0]
                if fir["type"] == "Name":
                    # ``module.function``: the attribute absorbs its owner
                    parent = fir["node"].id
                    info["str"] = f"{parent}.{info['str']}"
                    info["children"][0]["remove"] = True

            buf['name'] = info["str"]
            buf['args'][0] = info["str"]
            return
        if kind in ('Num', 'Cst'):
            op, buf = self._get_last(
                ('List', 'BinOp', 'UnaryOp', 'FunctionDef', 'Call'))
            if op == 'FunctionDef':
                # default values are collected with 'arguments'
                return
            if op == 'List':
                buf['elts'].append(info['n'])
            else:
                buf['args'].append(info['n'])
            return
        if kind == 'Str':
            _, buf = self._get_last('keyword')
            buf['value'] = info['str']
            return
        if kind == 'List':
            op, buf = self._get_last('List')
            value = buf['elts']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('keyword')
            parent['value'] = value
            self._post_process(None, parent)
            return
        if kind == 'keyword':
            op, buf = self._get_last('keyword')
            name = buf["name"]
            if 'value' not in buf:
                raise RuntimeError(str(buf))  # pragma: no cover
            value = buf['value']
            self._post_process(op, buf)
            self._stack.pop()
            _, parent = self._get_last('Call')
            parent['args'].append(OnnxTranslator.Parameter(name, value))
            self._post_process(None, parent)
            return

        raise NotImplementedError(  # pragma: no cover
            "Unable to interpret kind '{}' at {}\n{}\n---\n{}".format(
                info.get('type', '?'), self.make_msg(
                    node), pprint.pformat(info),
                pprint.pformat(self._stack)))