Coverage for mlprodict/testing/verify_code.py: 95%

337 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Looks into the code and detects errors 

4before finalizing the benchmark. 

5""" 

6import ast 

7import collections 

8import inspect 

9import numpy 

10 

11 

class ImperfectPythonCode(RuntimeError):
    """
    Exception raised when the verified code contains errors.
    """

17 

18 

def verify_code(source, exc=True):
    """
    Verifies :epkg:`python` code.

    The source is parsed, visited with :class:`CodeNodeVisitor`, and every
    identifier it reads is checked against the names known to be defined:
    a fixed list of builtins and modules, imported names, assigned names,
    function arguments and names bound in a *store* context (loop targets,
    comprehension variables, tuple unpacking, ``with ... as`` targets).

    :param source: source to look into
    :param exc: raise an exception or return the list of
        missing identifiers
    :return: tuple(missing identifiers, :class:`CodeNodeVisitor
        <mlprodict.onnx_tools.onnx_grammar.node_visitor_translator>`)
    :raises ImperfectPythonCode: if *exc* is True and at least one
        identifier is unknown
    """
    node = ast.parse(source)
    v = CodeNodeVisitor()
    v.visit(node)

    known = {'super': None, 'ImportError': None, 'print': print,
             'classmethod': classmethod, 'numpy': numpy,
             'dict': dict, 'list': list, 'sorted': sorted, 'len': len,
             'collections': collections, 'inspect': inspect, 'range': range,
             'int': int, 'str': str, 'isinstance': isinstance}

    for kn in v._imports:
        # Entries are (name, asname, node) for ``import`` and
        # (name, asname, module, node) for ``from ... import``.
        name, asname = kn[0], kn[1]
        # Kept for backward compatibility with the previous behaviour.
        known[name] = kn
        # The identifier really bound by the statement is the alias when
        # there is one, otherwise the first component of a dotted name
        # (``import os.path`` binds ``os``).
        bound = asname if asname is not None else name.split('.')[0]
        known[bound] = kn
    for kn in v._assign:
        known[kn[0]] = kn
    for kn in v._args:
        known[kn[0]] = kn
    # A name met in a store context is being defined, not read:
    # it must not be reported as missing.
    for name, name_node in v._names:
        if isinstance(name_node.ctx, ast.Store):
            known[name] = name_node

    issues = {name for name, _ in v._names if name not in known}
    if exc and len(issues) > 0:
        raise ImperfectPythonCode(
            f"Unknown identifiers: '{issues}' in source\n{source}")
    return issues, v

55 

56 

class CodeNodeVisitor(ast.NodeVisitor):
    """
    Visits the code, implements verification rules.

    Every visited node is stored as a *row* (a dictionary) in ``_rows``.
    A row holds at least ``indent`` (depth in the tree) and ``node``;
    most rows also hold ``type`` and ``str`` and, once the node was
    visited, ``children`` (rows of the direct children).
    The visitor also collects imports, names, aliases, assignments,
    arguments and calls to ``fit`` in dedicated lists.
    """

    def __init__(self):
        ast.NodeVisitor.__init__(self)
        self._rows = []
        self._indent = 0
        self._stack = []
        self._imports = []
        self._names = []
        self._alias = []
        self._assign = []
        self._args = []
        self._fits = []

    def push(self, row):
        """
        Pushes an element into a list.
        """
        self._rows.append(row)

    def generic_visit(self, node):
        """
        Overrides ``generic_visit`` to check it is not used.
        """
        raise AttributeError(  # pragma: no cover
            f"generic_visit should not be used for node "
            f"type {type(node)!r} and node={node!r}.")

    def generic_visit_args(self, node, row):
        """
        Overrides ``generic_visit`` to keep track of the indentation
        and the node parent. The function will add field
        ``row["children"] = visited`` nodes from here.

        @param      node        node which needs to be visited
        @param      row         row (a dictionary)
        @return                 See ``ast.NodeVisitor.generic_visit``
        """
        self._indent += 1
        first_new = len(self._rows)
        res = ast.NodeVisitor.generic_visit(  # pylint: disable=E1111
            self, node)  # pylint: disable=E1111
        # Rows added during the visit at the current depth are the
        # direct children of *node*.
        row["children"] = [
            r for r in self._rows[first_new:]
            if r["indent"] == self._indent]
        self._indent -= 1
        return res

    def _record(self, node, **fields):
        """
        Builds the row of a node, pushes it and visits its children.

        @param      node        visited node
        @param      fields      additional fields stored in the row
        @return                 See @see me generic_visit_args
        """
        row = {"indent": self._indent, "node": node}
        row.update(fields)
        self.push(row)
        return self.generic_visit_args(node, row)

    def visit(self, node):
        """
        Visits a node. If no method ``visit_<class name>`` exists,
        the node is recorded with its class name in field ``str``
        (and no field ``type``) and its children are visited.
        """
        class_name = node.__class__.__name__
        visitor = getattr(self, 'visit_' + class_name, None)
        if visitor is not None:
            return visitor(node)
        cont = {"indent": self._indent, "str": class_name, "node": node}
        self.push(cont)
        return self.generic_visit_args(node, cont)

    @staticmethod
    def print_node(node):
        """
        Debugging purpose.
        """
        attributes = node.__dict__
        return " ".join(
            f"{att}={str(attributes[att])}"
            for att in ["s", "name", "str", "id", "body", "n",
                        "arg", "targets", "attr", "returns", "ctx"]
            if att in attributes)

    def print_tree(self):  # pylint: disable=C0116
        """
        Displays the tree of instructions.

        @return     string
        """
        # NOTE(review): the indentation unit below is the single space
        # found in the reviewed listing, confirm it against the original
        # file as whitespace may have been collapsed.
        return "\n".join(
            f"{' ' * r['indent']}{r.get('type', '')}: {r.get('str', '')}"
            for r in self.Rows)

    @property
    def Rows(self):
        """
        returns a list of dictionaries with all the elements of the code
        """
        return [r for r in self._rows if not r.get("remove", False)]

    def visit_Str(self, node):  # pylint: disable=C0116
        return self._record(node, type="Str", str=node.s, value=node.s)

    def visit_Name(self, node):  # pylint: disable=C0116
        self._names.append((node.id, node))
        return self._record(
            node, type="Name", str=node.id, id=node.id, ctx=node.ctx)

    def visit_Constant(self, node):  # pylint: disable=C0116
        return self._record(
            node, type="Constant", str=str(node.value), id=node.value)

    def visit_Expr(self, node):  # pylint: disable=C0116
        return self._record(node, type="Expr", str='', value=node.value)

    def visit_alias(self, node):  # pylint: disable=C0116
        self._alias.append((node.name, node.asname, node))
        return self._record(
            node, type="alias", str="", name=node.name, asname=node.asname)

    def visit_Module(self, node):  # pylint: disable=C0116
        return self._record(node, type="Module", str="", body=node.body)

    def visit_Import(self, node):  # pylint: disable=C0116
        for name in node.names:
            self._imports.append((name.name, name.asname, node))
        return self._record(node, type="Import", str="", names=node.names)

    def visit_ImportFrom(self, node):  # pylint: disable=C0116
        for name in node.names:
            self._imports.append((name.name, name.asname, node.module, node))
        return self._record(
            node, type="ImportFrom", str="", module=node.module,
            names=node.names)

    def visit_ClassDef(self, node):  # pylint: disable=C0116
        return self._record(
            node, type="ClassDef", str="", name=node.name, body=node.body)

    def visit_FunctionDef(self, node):  # pylint: disable=C0116
        return self._record(
            node, type="FunctionDef", str=node.name, name=node.name,
            body=node.body, returns=node.returns)

    def visit_arguments(self, node):  # pylint: disable=C0116
        return self._record(node, type="arguments", str="", args=node.args)

    def visit_arg(self, node):  # pylint: disable=C0116
        self._args.append((node.arg, node))
        return self._record(
            node, type="arg", str=node.arg, arg=node.arg,
            annotation=node.annotation)

    def visit_Assign(self, node):  # pylint: disable=C0116
        for target in node.targets:
            if hasattr(target, 'id'):
                self._assign.append((target.id, node))
                continue
            # Tuple, list or starred targets: every name stored by the
            # assignment is registered (``a, b = f()`` defines a and b).
            stored = [
                sub.id for sub in ast.walk(target)
                if isinstance(sub, ast.Name) and
                isinstance(sub.ctx, ast.Store)]
            if stored:
                self._assign.extend((name, node) for name in stored)
            else:
                # Attribute or subscript target: no new name is defined,
                # the target is identified by its object id as before.
                self._assign.append((id(target), node))
        return self._record(
            node, type="Assign", str="", targets=node.targets,
            value=node.value)

    def visit_Store(self, node):  # pylint: disable=C0116
        # The context is visited but not stored as a row.
        return self.generic_visit_args(node, {})

    def visit_Call(self, node):  # pylint: disable=C0116
        func = node.func
        func_fields = func.__dict__
        if "attr" in func_fields:
            label = func.attr
        elif "id" in func_fields:
            label = func.id
        else:
            label = ""  # pragma: no cover
        cont = {"indent": self._indent, "type": "Call", "str": label,
                "node": node, "func": func}
        self.push(cont)
        if label == 'fit':
            # Calls to a function or a method named ``fit`` are tracked.
            self._fits.append(cont)
        return self.generic_visit_args(node, cont)

    def visit_Attribute(self, node):  # pylint: disable=C0116
        cont = {"indent": self._indent, "type": "Attribute", "str": node.attr,
                "node": node, "value": node.value, "ctx": node.ctx,
                "attr": node.attr}
        self.push(cont)
        res = self.generic_visit_args(node, cont)

        children = cont["children"]
        if len(children) > 0:
            first = children[0]
            if 'type' in first and first["type"] == "Name":
                # ``name.attr`` is merged into a single row and the row
                # of the name is hidden from property ``Rows``.
                cont["str"] = f"{first['node'].id}.{cont['str']}"
                first["remove"] = True
        return res

    def visit_Load(self, node):  # pylint: disable=C0116
        # The context is visited but not stored as a row.
        return self.generic_visit_args(node, {})

    def visit_keyword(self, node):  # pylint: disable=C0116
        return self._record(
            node, type="keyword", str=f"{node.arg}", arg=node.arg,
            value=node.value)

    def visit_BinOp(self, node):  # pylint: disable=C0116
        return self._record(node, type="BinOp", str="")

    def visit_UnaryOp(self, node):  # pylint: disable=C0116
        return self._record(node, type="UnaryOp", str="")

    def visit_Not(self, node):  # pylint: disable=C0116
        return self._record(node, type="Not", str="")

    def visit_Invert(self, node):  # pylint: disable=C0116
        return self._record(node, type="Invert", str="")

    def visit_BoolOp(self, node):  # pylint: disable=C0116
        return self._record(node, type="BoolOp", str="")

    def visit_Mult(self, node):  # pylint: disable=C0116
        return self._record(node, type="Mult", str="")

    def visit_Div(self, node):  # pylint: disable=C0116
        return self._record(node, type="Div", str="")

    def visit_FloorDiv(self, node):  # pylint: disable=C0116
        return self._record(node, type="FloorDiv", str="")

    def visit_Add(self, node):  # pylint: disable=C0116
        return self._record(node, type="Add", str="")

    def visit_Pow(self, node):  # pylint: disable=C0116
        return self._record(node, type="Pow", str="")

    def visit_In(self, node):  # pylint: disable=C0116
        return self._record(node, type="In", str="")

    def visit_AugAssign(self, node):  # pylint: disable=C0116
        return self._record(node, type="AugAssign", str="")

    def visit_Eq(self, node):  # pylint: disable=C0116
        return self._record(node, type="Eq", str="")

    def visit_IsNot(self, node):  # pylint: disable=C0116
        return self._record(node, type="IsNot", str="")

    def visit_Is(self, node):  # pylint: disable=C0116
        return self._record(node, type="Is", str="")

    def visit_And(self, node):  # pylint: disable=C0116
        return self._record(node, type="And", str="")

    def visit_BitAnd(self, node):  # pylint: disable=C0116
        return self._record(node, type="BitAnd", str="")

    def visit_Or(self, node):  # pylint: disable=C0116
        return self._record(node, type="Or", str="")

    def visit_NotEq(self, node):  # pylint: disable=C0116
        return self._record(node, type="NotEq", str="")

    def visit_Mod(self, node):  # pylint: disable=C0116
        return self._record(node, type="Mod", str="")

    def visit_Sub(self, node):  # pylint: disable=C0116
        return self._record(node, type="Sub", str="")

    def visit_USub(self, node):  # pylint: disable=C0116
        return self._record(node, type="USub", str="")

    def visit_Compare(self, node):  # pylint: disable=C0116
        return self._record(node, type="Compare", str="")

    def visit_Gt(self, node):  # pylint: disable=C0116
        return self._record(node, type="Gt", str="")

    def visit_GtE(self, node):  # pylint: disable=C0116
        return self._record(node, type="GtE", str="")

    def visit_Lt(self, node):  # pylint: disable=C0116
        return self._record(node, type="Lt", str="")

    def visit_Num(self, node):  # pylint: disable=C0116
        return self._record(node, type="Num", str=f"{node.n}", n=node.n)

    def visit_Return(self, node):  # pylint: disable=C0116
        return self._record(node, type="Return", str="", value=node.value)

    # The rows of the following nodes have a type but no field ``str``.

    def visit_List(self, node):  # pylint: disable=C0116
        return self._record(node, type="List")

    def visit_ListComp(self, node):  # pylint: disable=C0116
        return self._record(node, type="ListComp")

    def visit_comprehension(self, node):  # pylint: disable=C0116
        return self._record(node, type="comprehension")

    def visit_Dict(self, node):  # pylint: disable=C0116
        return self._record(node, type="Dict")

    def visit_Tuple(self, node):  # pylint: disable=C0116
        return self._record(node, type="Tuple")

    def visit_NameConstant(self, node):  # pylint: disable=C0116
        return self._record(node, type="NameConstant")

    def visit_(self, node):  # pylint: disable=C0116
        raise RuntimeError(  # pragma: no cover
            f"This node is not handled: {node}")

    # The rows of the following nodes only have a field ``str``,
    # no field ``type``.

    def visit_Subscript(self, node):  # pylint: disable=C0116
        return self._record(node, str="Subscript")

    def visit_ExtSlice(self, node):  # pylint: disable=C0116
        return self._record(node, str="ExtSlice")

    def visit_Slice(self, node):  # pylint: disable=C0116
        return self._record(node, str="Slice")

    def visit_Index(self, node):  # pylint: disable=C0116
        return self._record(node, str="Index")

    def visit_If(self, node):  # pylint: disable=C0116
        return self._record(node, str="If")

    def visit_IfExp(self, node):  # pylint: disable=C0116
        return self._record(node, str="IfExp")

    def visit_Lambda(self, node):  # pylint: disable=C0116
        return self._record(node, str="Lambda")

    def visit_GeneratorExp(self, node):  # pylint: disable=C0116
        return self._record(node, str="GeneratorExp")