Coverage for mlprodict/npy/xop_auto.py: 97%

326 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Automates the generation of operators for the 

4documentation for the Xop API. 

5 

6.. versionadded:: 0.9 

7""" 

8import os 

9import textwrap 

10import importlib 

11import inspect 

12import re 

13import keyword 

14import onnx 

15import onnx.defs 

16from onnx.backend.test.case.base import _Exporter 

17from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E1101,E0611,E0401 

18from onnx.defs import OpSchema 

19 

20 

def _get_doc_template():
    """
    Returns the template used to render the documentation of a list
    of operator schemas (see :func:`get_rst_doc`).

    The template is a :epkg:`jinja2` template when the package can be
    imported. Otherwise a minimal replacement is returned: its method
    ``render`` only reads ``context['schemas']`` and writes, for every
    schema, its name underlined with ``=`` followed by its raw docstring.

    :return: object with a method ``render(**context)``
    :raises RuntimeError: (fallback template only, at rendering time)
        if a schema has no name
    """
    try:
        from jinja2 import Template
    except ImportError:  # pragma no cover
        class Template:
            "Docstring template"

            def __init__(self, *args):
                pass

            def render(self, **context):
                "render"
                schemas = context['schemas']
                rows = []
                for sch in schemas:
                    doc = sch.doc or ''
                    name = sch.name
                    if name is None:
                        raise RuntimeError("An operator must have a name.")
                    rows.extend([name, "=" * len(name),
                                 "", doc, ""])
                return "\n".join(rows)

    # The section listing previous versions reads ``sch.versions`` both
    # in the test and in the loop (the loop used to read ``sch.version``).
    return Template(textwrap.dedent("""
        {% for sch in schemas %}

        .. tag-diff-insert.

        .. _l-onnx-op{{sch.domain.lower().replace(".", "-")}}-{{sch.name.lower()}}-{{str(sch.since_version)}}:

        {{format_name_with_domain(sch)}}
        {{'=' * len(format_name_with_domain(sch))}}

        **Version**

        * **name**: `{{sch.name}} (GitHub) <{{build_doc_url(sch)}}{{sch.name}}>`_
        * **domain**: **{% if sch.domain == '' %}main{% else %}{{sch.domain}}{% endif %}**
        * **since_version**: **{{sch.since_version}}**
        * **function**: {{sch.has_function}}
        * **support_level**: {{sch.support_level}}
        * **shape inference**: {{sch.has_type_and_shape_inference_function}}

        {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
        No versioning maintained for experimental ops.
        {% else %}
        This version of the operator has been {% if
        sch.deprecated %}deprecated{% else %}available{% endif %}
        **since version {{sch.since_version}}{% if
        sch.domain %} of domain {{sch.domain}}{% endif %}**.
        {% if len(sch.versions) > 1 %}
        Other versions of this operator:
        {% for v in sch.versions[:-1] %} {{v}} {% endfor %}
        {% endif %}
        {% endif %}

        **Summary**

        {{process_documentation(sch.doc)}}

        {% if sch.attributes %}
        **Attributes**

        {% for _, attr in sorted(sch.attributes.items()) %}* **{{attr.name}}**{%
        if attr.required %} (required){% endif %}:
        {{text_wrap(attr.description, 2)}} {%
        if attr.default_value %}{{clean_default_value(attr.default_value)}}{%
        endif %}
        {% endfor %}
        {% endif %}

        {% if sch.inputs %}
        **Inputs**

        {% if sch.min_input != sch.max_input %}Between {{sch.min_input
        }} and {{sch.max_input}} inputs.
        {% endif %}
        {% for ii, inp in enumerate(sch.inputs) %}
        * **{{getname(inp, ii)}}**{{format_option(inp)}} - **{{inp.typeStr}}**:
        {{text_wrap(inp.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.outputs %}
        **Outputs**

        {% if sch.min_output != sch.max_output %}Between {{sch.min_output
        }} and {{sch.max_output}} outputs.
        {% endif %}
        {% for ii, out in enumerate(sch.outputs) %}
        * **{{getname(out, ii)}}**{{format_option(out)}} - **{{out.typeStr}}**:
        {{text_wrap(out.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.type_constraints %}
        **Type Constraints**

        {% for ii, type_constraint in enumerate(sch.type_constraints)
        %}* {{get_constraint(type_constraint, ii)}}:
        {{text_wrap(type_constraint.description, 2)}}
        {% endfor %}
        {% endif %}

        {% if get_onnx_example and is_last_schema(sch): %}
        **Examples**

        {% for example, code in get_onnx_example(sch.name).items(): %}
        **{{ example }}**

        ::

        {{ format_example(code) }}

        {% endfor %}
        {% endif %}

        {% endfor %}
    """))


# Template rendering the operator documentation, built once at import time.
_template_operator = _get_doc_template()

# Lazily populated cache: {domain: {op_name: {since_version: schema}}}.
__get_all_schemas_with_history = None


def _populate__get_all_schemas_with_history():
    """
    Gathers every known operator schema grouped by domain, name and version.

    Schemas first come from :func:`onnx.defs.get_all_schemas_with_history`.
    When :epkg:`onnxruntime` can be imported, its operator schemas are
    wrapped into ``_CustomSchema`` and added, but only for the
    (domain, name) pairs not registered yet.

    :return: dictionary ``{domain: {name: {since_version: schema}}}``
    """
    collected = {}
    for schema in onnx.defs.get_all_schemas_with_history():
        by_name = collected.setdefault(schema.domain, {})
        by_name.setdefault(schema.name, {})[schema.since_version] = schema

    try:
        import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
    except ImportError:  # pragma: no cover
        return collected

    # If onnxruntime is available, it is being populated with these
    # operators as well.
    from .xop import _CustomSchema
    try:
        get_schemas = rtpy.get_all_operator_schema
    except AttributeError:
        # onnxruntime must be compiled with flag --gen_doc,
        # a local copy is retrieved otherwise.
        from .xop import _get_all_operator_schema
        get_schemas = _get_all_operator_schema

    for op in get_schemas():
        custom = _CustomSchema(op)
        known = collected.get(custom.domain)
        if known is not None and custom.name in known:
            # already handled
            continue
        by_name = collected.setdefault(custom.domain, {})
        by_name.setdefault(custom.name, {})[custom.since_version] = custom

    return collected


def _get_all_schemas_with_history():
    """
    Returns the dictionary built by
    :func:`_populate__get_all_schemas_with_history`. It is computed on the
    first call and stored in a module-level variable for later calls.
    """
    global __get_all_schemas_with_history  # pylint: disable=W0603
    if __get_all_schemas_with_history is not None:
        return __get_all_schemas_with_history
    __get_all_schemas_with_history = _populate__get_all_schemas_with_history()
    return __get_all_schemas_with_history

191 

192 

def get_domain_list():
    """
    Returns the list of available domains, i.e. the distinct domains of
    the schemas returned by :func:`onnx.defs.get_all_schemas_with_history`
    (the main domain is the empty string).

    :return: sorted list of unique domain names
    """
    # sorted already returns a new list, no need for an extra list()
    return sorted({schema.domain
                   for schema in onnx.defs.get_all_schemas_with_history()})

199 

200 

def get_operator_schemas(op_name, version=None, domain=None):
    """
    Returns all schemas mapped to an operator name.

    :param op_name: name of the operator, None for all operators
    :param version: None for all versions, `'last'` for the most recent
        one, or a specific version (operators without that exact version
        are skipped)
    :param domain: domain, None for all domains
    :return: list of schemas sorted by domain, name and decreasing version
    """
    if version == 'last' and op_name is not None and domain is not None:
        return [onnx.defs.get_schema(op_name, domain=domain)]
    all_schemas = _get_all_schemas_with_history()
    if domain is None:
        domains = [dom for dom, ops in all_schemas.items()
                   if op_name is None or op_name in ops]
    else:
        domains = [domain]

    def _last_schema(op, dom, versions):
        # onnx domains are resolved through onnx itself, other domains
        # (or operators onnx does not know) through the collected history
        if dom == '' or 'onnx' in dom:
            try:
                return onnx.defs.get_schema(op, domain=dom)
            except SchemaError:  # pragma: no cover
                pass
        return versions[max(versions)]

    def _select(op, dom, versions):
        # versions: {since_version: schema} for one operator
        if version is None:
            return list(versions.values())
        if version == 'last':
            return [_last_schema(op, dom, versions)]
        if version in versions:
            return [versions[version]]
        return []

    # schemas
    sch = []
    for dom in domains:
        ops = all_schemas[dom]
        if op_name is None:
            for op, versions in ops.items():
                sch.extend(_select(op, dom, versions))
        elif op_name in ops:
            sch.extend(_select(op_name, dom, ops[op_name]))

    # sort, the key avoids comparing schema objects on ties
    return sorted(sch, key=lambda s: (s.domain, s.name, -s.since_version))

249 

250 

def get_rst_doc(op_name=None, domain=None, version='last', clean=True,
                diff=False, example=False):
    """
    Returns a documentation in RST format
    for all :class:`OnnxOperator`.

    :param op_name: operator name or None for all
    :param domain: domain
    :param version: version, None for all, `'last'` for the most recent one
    :param clean: clean empty lines
    :param diff: highlights differences between two versions
    :param example: add example to the documentation
    :return: string

    The function relies on module :epkg:`jinja2` or replaces it
    with a simple rendering if not present.
    """
    from ..onnx_tools.onnx2py_helper import _var_as_dict
    schemas = get_operator_schemas(op_name, domain=domain, version=version)

    # from onnx.backend.sample.ops import collect_sample_implementations
    # from onnx.backend.test.case import collect_snippets
    # SNIPPETS = collect_snippets()
    # SAMPLE_IMPLEMENTATIONS = collect_sample_implementations()
    def format_name_with_domain(sch):
        # the version only appears in the title when not rendering
        # the last version alone
        if version == 'last':
            if sch.domain:
                return f'{sch.name} ({sch.domain})'
            return sch.name
        if sch.domain:
            return f'{sch.name} - {sch.since_version} ({sch.domain})'
        return '%s - %d' % (sch.name, sch.since_version)

    def format_option(obj):
        opts = []
        if OpSchema.FormalParameterOption.Optional == obj.option:
            opts.append('optional')
        elif OpSchema.FormalParameterOption.Variadic == obj.option:
            opts.append('variadic')
            # Schemas may expose either ``isHomogeneous`` or
            # ``is_homogeneous``. A variadic parameter is heterogeneous
            # when it is NOT homogeneous (the previous code had it inverted).
            homogeneous = getattr(obj, 'isHomogeneous',
                                  getattr(obj, 'is_homogeneous', True))
            if not homogeneous:
                opts.append('heterogeneous')
        if opts:
            return f" ({', '.join(opts)})"
        return ""

    def format_example(code):
        code = textwrap.indent(code, ' ')
        return code

    def get_constraint(const, ii):
        if const.type_param_str:
            name = const.type_param_str
        else:
            name = str(ii)
        name = f"**{name}** in ("
        if const.allowed_type_strs:
            text = ",\n ".join(sorted(const.allowed_type_strs))
            name += "\n " + text + "\n )"
        return name

    def getname(obj, i):
        name = obj.name
        if len(name) == 0:
            return str(i)
        return name

    def process_documentation(doc):
        if doc is None:
            doc = ''
        if not isinstance(doc, str):
            # the message must not evaluate ``doc + 42``, which itself
            # raised before the intended error could be reported
            raise TypeError(  # pragma: no cover
                f"doc must be a string not {type(doc)!r} - {doc!r}.")
        doc = textwrap.dedent(doc)
        # same branch as the one used by build_doc_url
        main_docs_url = "https://github.com/onnx/onnx/blob/main/"
        rep = {
            '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_',
            '[the doc](Broadcasting.md)':
                '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_',
            '<dl>': '',
            '</dl>': '',
            '<dt>': '* ',
            '<dd>': ' ',
            '</dt>': '',
            '</dd>': '',
            '<tt>': '``',
            '</tt>': '``',
            '<br>': '\n',
        }
        for k, v in rep.items():
            doc = doc.replace(k, v.format(main_docs_url))
        # markdown code fences become RST literal blocks (indented text)
        move = 0
        lines = []
        for line in doc.split('\n'):
            if line.startswith("```"):
                if move > 0:
                    move -= 4
                    lines.append("\n")
                else:
                    lines.append("::\n")
                    move += 4
            elif move > 0:
                lines.append(" " * move + line)
            else:
                lines.append(line)
        return "\n".join(lines)

    def build_doc_url(sch):
        doc_url = "https://github.com/onnx/onnx/blob/main/docs/Operators"
        if "ml" in sch.domain:
            doc_url += "-ml"
        doc_url += ".md"
        doc_url += "#"
        if sch.domain not in (None, '', 'ai.onnx'):
            doc_url += sch.domain + "."
        return doc_url

    def clean_default_value(value):
        dvar = _var_as_dict(value)
        if 'value' in dvar:
            v = dvar['value']
            if isinstance(v, bytes):
                return f"Default value is ``'{v.decode('ascii')}'``."
            return f"Default value is ``{v}``."
        res = str(value).replace('\n', ' ').strip()
        if len(res) > 0:
            return f"Default value is ``{res}``."
        return ""

    def text_wrap(text, indent):
        s = ' ' * indent
        lines = textwrap.wrap(text, initial_indent=s, subsequent_indent=s)
        return '\n'.join(lines)

    def collapse_empty_lines(text):
        # strips trailing whitespace on every line and merges runs of
        # empty lines into one (shared by the diff and clean steps)
        new_lines = ['']
        for line in text.split('\n'):
            line = line.rstrip('\r\t ')
            if len(line) == 0 and len(new_lines[-1]) == 0:
                continue
            new_lines.append(line)
        return '\n'.join(new_lines)

    docs = _template_operator.render(
        schemas=schemas, OpSchema=OpSchema,
        len=len, getattr=getattr, sorted=sorted,
        format_option=format_option,
        get_constraint=get_constraint,
        getname=getname, enumerate=enumerate,
        format_name_with_domain=format_name_with_domain,
        process_documentation=process_documentation,
        build_doc_url=build_doc_url, text_wrap=text_wrap,
        str=str, clean_default_value=clean_default_value,
        get_onnx_example=get_onnx_example if example else None,
        format_example=format_example,
        is_last_schema=is_last_schema)
    if diff:
        docs = _insert_diff(collapse_empty_lines(docs), '.. tag-diff-insert.')

    if clean:
        docs = collapse_empty_lines(docs)

    return docs

421 

422 

def _insert_diff(docs, split='.. tag-diff-insert.'):
    """
    Splits *docs* on every occurrence of *split* (markers are dropped)
    and inserts, in front of every piece but the first, an HTML table
    showing its differences with the previous piece, when both pieces
    are long enough to be compared.
    The function relies on package :epkg:`pyquickhelper`.

    :param docs: RST documentation
    :param split: marker separating two consecutive pieces
    :return: new documentation, *docs* itself if *split* is absent
    """
    pieces = docs.split(split)
    if len(pieces) <= 1:
        return docs

    from pyquickhelper.texthelper.edit_text_diff import (
        edit_distance_text, diff2html)

    def _comparable(piece):
        # keeps the text before the first **Examples** and after the last
        # **Summary**, without backquotes
        text = piece.strip('\n ')
        text = text.split('**Examples**')[0].replace('`', '')
        return text.split('**Summary**')[-1].strip('\n ')

    comparable = [_comparable(piece) for piece in pieces]
    output = [pieces[0]]
    for before, after, raw in zip(comparable, comparable[1:], pieces[1:]):
        if len(before) < 5 or len(after) < 5:
            output.append(raw)
            continue

        _, aligned, final = edit_distance_text(  # pylint: disable=W0632
            after, before, threshold=0.5)
        html = diff2html(after, before, aligned, final, two_columns=True)
        html = html.replace(">``<", "><")
        html = ' ' + '\n '.join(html.split('\n'))
        output.extend(['', '**Differences**', '', '.. raw:: html',
                       '', html, '', raw])

    return '\n'.join(output)

456 

457 

def change_style(name):
    """
    Switches from *AaBb* into *aa_bb*. A trailing underscore is appended
    when the result is a Python keyword (``If`` becomes ``if_``).

    :param name: name to convert
    :return: converted name

    Example:

    .. runpython::
        :showcode:

        from mlprodict.npy.xop_auto import change_style

        print("changeStyle --> {0}".format(change_style('changeStyle')))
    """
    # first split before a capitalized word, then between a lower case
    # letter or digit and an upper case letter
    snake = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake).lower()
    if keyword.iskeyword(snake):
        return snake + "_"
    return snake

477 

478 

def get_onnx_example(op_name):
    """
    Retrieves examples associated to one operator
    stored in onnx packages.

    The first importable module among
    ``onnx.backend.test.case.node.<op_name.lower()>`` and
    ``onnx.backend.test.case.node.<change_style(op_name).lower()>``
    is used. For every object of that module which is an instance of
    ``_Exporter``, every attribute whose name starts with ``export``
    becomes one example: its source code without the ``def`` line.

    :param op_name: operator name
    :return: dictionary ``{example name: source code}``, empty if no
        module could be imported
    :raises RuntimeError: if the source of an ``export`` method cannot be
        located in its class source
    """
    candidates = [
        f'onnx.backend.test.case.node.{op_name.lower()}',
        f'onnx.backend.test.case.node.{change_style(op_name).lower()}',
    ]
    mod = None
    for candidate in candidates:
        try:
            mod = importlib.import_module(candidate)
        except (AttributeError, ImportError):
            continue
        # no need to import the other candidate once one is found
        break
    if mod is None:
        # Unable to find an example for 'op_name'.
        return {}
    results = {}
    for v in mod.__dict__.values():
        if not isinstance(v, _Exporter):
            continue
        code_cls = inspect.getsource(v)
        codes = code_cls.split('@staticmethod')
        for me in v.__dict__:
            if not me.startswith('export'):
                continue
            sub = f' {me}()'
            found = None
            for code in codes:
                if sub in code:
                    found = code
            if found is None:
                raise RuntimeError(  # pragma: no cover
                    f"Unable to find {sub!r} in\n{code_cls}")
            lines = textwrap.dedent(found).split('\n')
            # the body starts right after the last 'def ' line
            first = 0
            for i, line in enumerate(lines):
                if line.startswith('def '):
                    first = i + 1
            found = textwrap.dedent('\n'.join(lines[first:]))
            key = me[len('export'):]
            if key == '':
                key = 'default'
            if key in results:
                key = f'example {len(results) + 1}'
            results[key] = found
    return results

533 

534 

def is_last_schema(sch):
    """
    Tells if this is the most recent schema for this operator, i.e. if its
    version is the one :func:`onnx.defs.get_schema` returns for the same
    name and domain.

    :param sch: schema
    :return: True if *sch* is the most recent schema, also True when onnx
        does not know the operator
    """
    try:
        latest = onnx.defs.get_schema(sch.name, domain=sch.domain)
    except SchemaError:  # pragma: no cover
        # unknown to onnx: considered as the last schema
        return True
    return latest.since_version == sch.since_version

550 

551 

def onnx_documentation_folder(folder, ops=None, title='ONNX operators',
                              fLOG=None):
    """
    Creates documentation in a folder for all known
    ONNX operators or a subset.

    The function writes:

    * one page per operator, rendered by :func:`get_rst_doc` with all
      versions, differences and examples
    * one page per domain with a table linking to every version
    * ``index.rst`` referencing all the other pages

    :param folder: folder where to write the documentation,
        created if it does not exist
    :param ops: None for all operators or a subset of them
    :param title: index title
    :param fLOG: logging function
    :return: list of created files
    """
    all_schemas = _get_all_schemas_with_history()
    # exist_ok avoids the race between an existence check and the creation
    os.makedirs(folder, exist_ok=True)
    index = ['', title, '=' * len(title), '', '.. contents::',
             ' :local:', '']
    pages = []
    tables_domain_pages = []

    if ops is not None:
        ops = set(ops)
    for dom in sorted(all_schemas):
        sdom = 'main' if dom == '' else dom

        index_dom = [sdom, '+' * len(sdom), '', '.. toctree::',
                     ' :maxdepth: 1', '']

        table_dom = ["", f".. _l-table-operator-{sdom.replace('.', '-')}:", "",
                     f"operator table for domain {sdom}"]
        table_dom.extend(["=" * len(table_dom[-1]), ""])
        table_dom.extend([f".. list-table:: operators for domain {sdom}",
                          " :widths: 10 10",
                          " :header-rows: 1",
                          "",
                          " * - operator",
                          " - versions"])

        sub = all_schemas[dom]
        if ops is None:
            do = list(sub)
        else:
            do = sorted(set(sub).intersection(ops))
        if not do:
            continue

        for op in sorted(do):
            if fLOG is not None:
                fLOG(  # pragma: no cover
                    f'generate page for onnx {dom!r} - {op!r}')
            page_name = f"onnx_{dom.replace('.', '')}_{op}"
            index_dom.append(f' {page_name}')
            doc = get_rst_doc(op, domain=dom, version=None, example=True,
                              diff=True)
            main = op if dom == '' else f'{dom} - {op}'
            rows = ['', f'.. _l-onnx-doc{dom}-{op}:', '',
                    '=' * len(main), main, '=' * len(main), '',
                    '.. contents::', ' :local:', '', doc]

            op_file = os.path.join(folder, page_name + '.rst')
            with open(op_file, 'w', encoding='utf-8') as f:
                f.write("\n".join(rows))
            pages.append(op_file)

            # table: one link per available version of the operator
            links = []
            for sch in get_operator_schemas(op, domain=dom, version=None):
                link = (
                    ':ref:`{sver} <l-onnx-op{lname_}-{lname}-{sver}>`').format(
                        sver=str(sch.since_version), lname=sch.name.lower(),
                        lname_=sch.domain.lower().replace(".", "-"))
                links.append(link)
            table_dom.extend([f" * - {op}",
                              f" - {', '.join(links)}"])

        sdom_clean = sdom.replace('.', '_')
        table_file = os.path.join(folder, f'table_{sdom_clean}.rst')
        tables_domain_pages.append(f'table_{sdom_clean}')
        pages.append(table_file)
        with open(table_file, "w", encoding="utf-8") as f:
            f.write("\n".join(table_dom))

        index.extend(index_dom)
        index.append('')

    # adding pages
    index.extend(["", "Tables", "++++++", "",
                  ".. toctree::", " :maxdepth: 1", ""])
    for page in tables_domain_pages:
        index.append(f" {page}")
    index.append('')

    # creating a big index
    index_file = os.path.join(folder, 'index.rst')
    with open(index_file, 'w', encoding='utf-8') as f:
        f.write('\n'.join(index))
    pages.append(index_file)
    return pages