Coverage for mlprodict/npy/xop_auto.py: 97%
326 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Automates the generation of operators for the
4documentation for the Xop API.
6.. versionadded:: 0.9
7"""
8import os
9import textwrap
10import importlib
11import inspect
12import re
13import keyword
14import onnx
15import onnx.defs
16from onnx.backend.test.case.base import _Exporter
17from onnx.onnx_cpp2py_export.defs import SchemaError # pylint: disable=E1101,E0611,E0401
18from onnx.defs import OpSchema
def _get_doc_template():
    """
    Builds the template used to render the documentation of operators.

    The function tries to import :epkg:`jinja2`. If the import fails,
    a minimal local class *Template* replaces it: that class ignores the
    template text and only writes, for every schema, its name underlined
    with ``=`` followed by its documentation.

    :return: an object with a method ``render(**context)``
    """
    try:
        from jinja2 import Template
    except ImportError:  # pragma no cover
        class Template:
            "Docstring template"

            def __init__(self, *args):
                # The template text is ignored by this fallback.
                pass

            def render(self, **context):
                "render"
                schemas = context['schemas']
                rows = []
                for sch in schemas:
                    doc = sch.doc or ''
                    name = sch.name
                    if name is None:
                        raise RuntimeError("An operator must have a name.")
                    rows.extend([name, "=" * len(name),
                                 "", doc, ""])
                return "\n".join(rows)

    # Every name called inside the template (len, str, sorted, text_wrap, ...)
    # must be given to ``render`` as a keyword argument.
    return Template(textwrap.dedent("""
        {% for sch in schemas %}

        .. tag-diff-insert.

        .. _l-onnx-op{{sch.domain.lower().replace(".", "-")}}-{{sch.name.lower()}}-{{str(sch.since_version)}}:

        {{format_name_with_domain(sch)}}
        {{'=' * len(format_name_with_domain(sch))}}

        **Version**

        * **name**: `{{sch.name}} (GitHub) <{{build_doc_url(sch)}}{{sch.name}}>`_
        * **domain**: **{% if sch.domain == '' %}main{% else %}{{sch.domain}}{% endif %}**
        * **since_version**: **{{sch.since_version}}**
        * **function**: {{sch.has_function}}
        * **support_level**: {{sch.support_level}}
        * **shape inference**: {{sch.has_type_and_shape_inference_function}}

        {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
        No versioning maintained for experimental ops.
        {% else %}
        This version of the operator has been {% if
        sch.deprecated %}deprecated{% else %}available{% endif %}
        **since version {{sch.since_version}}{% if
        sch.domain %} of domain {{sch.domain}}{% endif %}**.
        {% if len(sch.versions) > 1 %}
        Other versions of this operator:
        {% for v in sch.version[:-1] %} {{v}} {% endfor %}
        {% endif %}
        {% endif %}

        **Summary**

        {{process_documentation(sch.doc)}}

        {% if sch.attributes %}
        **Attributes**

        {% for _, attr in sorted(sch.attributes.items()) %}* **{{attr.name}}**{%
        if attr.required %} (required){% endif %}:
        {{text_wrap(attr.description, 2)}} {%
        if attr.default_value %}{{clean_default_value(attr.default_value)}}{%
        endif %}
        {% endfor %}
        {% endif %}

        {% if sch.inputs %}
        **Inputs**

        {% if sch.min_input != sch.max_input %}Between {{sch.min_input
        }} and {{sch.max_input}} inputs.
        {% endif %}
        {% for ii, inp in enumerate(sch.inputs) %}
        * **{{getname(inp, ii)}}**{{format_option(inp)}} - **{{inp.typeStr}}**:
        {{text_wrap(inp.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.outputs %}
        **Outputs**

        {% if sch.min_output != sch.max_output %}Between {{sch.min_output
        }} and {{sch.max_output}} outputs.
        {% endif %}
        {% for ii, out in enumerate(sch.outputs) %}
        * **{{getname(out, ii)}}**{{format_option(out)}} - **{{out.typeStr}}**:
        {{text_wrap(out.description, 2)}}{% endfor %}
        {% endif %}

        {% if sch.type_constraints %}
        **Type Constraints**

        {% for ii, type_constraint in enumerate(sch.type_constraints)
        %}* {{get_constraint(type_constraint, ii)}}:
        {{text_wrap(type_constraint.description, 2)}}
        {% endfor %}
        {% endif %}

        {% if get_onnx_example and is_last_schema(sch): %}
        **Examples**

        {% for example, code in get_onnx_example(sch.name).items(): %}
        **{{ example }}**

        ::

        {{ format_example(code) }}

        {% endfor %}
        {% endif %}

        {% endfor %}
    """))
# Template rendering the documentation of operators, built once at import time.
_template_operator = _get_doc_template()
# Cache filled by the first call to _get_all_schemas_with_history.
__get_all_schemas_with_history = None
def _populate__get_all_schemas_with_history():
    """
    Gathers every known schema in nested dictionaries
    ``{domain: {operator name: {since_version: schema}}}``.

    Schemas registered in :epkg:`onnx` come first. If :epkg:`onnxruntime`
    can be imported, the operators it declares are added as well
    unless an operator with the same domain and name is already stored.

    :return: nested dictionaries of schemas
    """
    registry = {}
    for onnx_schema in onnx.defs.get_all_schemas_with_history():
        per_name = registry.setdefault(onnx_schema.domain, {})
        per_version = per_name.setdefault(onnx_schema.name, {})
        per_version[onnx_schema.since_version] = onnx_schema

    try:
        import onnxruntime.capi.onnxruntime_pybind11_state as rtpy
    except ImportError:  # pragma: no cover
        rtpy = None

    if rtpy is None:
        return registry

    # If onnxruntime is available, it is being populated
    # with these operators as well.
    from .xop import _CustomSchema
    try:
        get_schemas = rtpy.get_all_operator_schema
    except AttributeError:
        # onnxruntime must be compiled with flag --gen_doc.
        # a local copy is retrieved.
        from .xop import _get_all_operator_schema
        get_schemas = _get_all_operator_schema

    for op in get_schemas():
        custom = _CustomSchema(op)
        if custom.name in registry.get(custom.domain, ()):
            # already handled
            continue
        per_name = registry.setdefault(custom.domain, {})
        per_name[custom.name] = {custom.since_version: custom}

    return registry
def _get_all_schemas_with_history():
    """
    Returns the nested dictionaries of schemas, building them
    on the first call and returning the cached object afterwards.
    """
    global __get_all_schemas_with_history  # pylint: disable=W0603
    if __get_all_schemas_with_history is not None:
        return __get_all_schemas_with_history
    __get_all_schemas_with_history = _populate__get_all_schemas_with_history()
    return __get_all_schemas_with_history
def get_domain_list():
    """
    Returns the sorted list of distinct domains found among
    the schemas returned by ``onnx.defs.get_all_schemas_with_history``.
    """
    domains = {
        schema.domain for schema in onnx.defs.get_all_schemas_with_history()}
    return sorted(domains)
def _latest_schema(op_name, domain, versions):
    """
    Picks the most recent schema of one operator.

    :param op_name: operator name
    :param domain: operator domain
    :param versions: dictionary ``{since_version: schema}`` for this operator
    :return: the schema returned by ``onnx.defs.get_schema`` for the main
        domain or a domain containing ``'onnx'``, otherwise (or if onnx
        raises ``SchemaError``) the schema with the highest version
    """
    if domain == '' or 'onnx' in domain:
        try:
            return onnx.defs.get_schema(op_name, domain=domain)
        except SchemaError:  # pragma: no cover
            pass
    return versions[max(versions)]


def get_operator_schemas(op_name, version=None, domain=None):
    """
    Returns all schemas mapped to an operator name.

    :param op_name: name of the operator, None for all operators
    :param version: version, None for all versions,
        `'last'` for the most recent one
    :param domain: domain, None for all domains
    :return: list of schemas sorted by domain, name
        and decreasing version
    :raises KeyError: if *domain* is unknown, or if *op_name* is None
        and one operator does not define the requested *version*
    """
    if version == 'last' and op_name is not None and domain is not None:
        return [onnx.defs.get_schema(op_name, domain=domain)]
    all_schemas = _get_all_schemas_with_history()
    if domain is None:
        domains = [dom for dom, ops in all_schemas.items()
                   if op_name is None or op_name in ops]
    else:
        domains = [domain]

    # schemas
    sch = []
    for dom in domains:
        ops = all_schemas[dom]
        if op_name is None:
            for op, v in ops.items():
                if version is None:
                    sch.extend(v.values())
                elif version == 'last':
                    sch.append(_latest_schema(op, dom, v))
                else:
                    sch.append(v[version])
        elif op_name in ops:
            v = ops[op_name]
            if version is None:
                sch.extend(v.values())
            elif version == 'last':
                # 'last' is never a key of the versions dictionary,
                # the most recent schema has to be picked explicitly.
                sch.append(_latest_schema(op_name, dom, v))
            elif version in v:
                sch.append(v[version])

    # sort: a key function avoids comparing schema objects on ties
    return sorted(sch, key=lambda s: (s.domain, s.name, -s.since_version))
def _collapse_blank_lines(text):
    """
    Removes trailing spaces, tabulations and carriage returns on every line
    and merges consecutive empty lines into a single one.
    """
    kept = ['']
    for line in text.split('\n'):
        line = line.rstrip('\r\t ')
        if not line and not kept[-1]:
            continue
        kept.append(line)
    return '\n'.join(kept)


def get_rst_doc(op_name=None, domain=None, version='last', clean=True,
                diff=False, example=False):
    """
    Returns a documentation in RST format
    for all :class:`OnnxOperator`.

    :param op_name: operator name of None for all
    :param domain: domain
    :param version: version, None for all, `'last'` for the most recent one
    :param clean: clean empty lines
    :param diff: highlights differences between two versions
    :param example: add example to the documentation
    :return: string

    The function relies on module :epkg:`jinja2` or replaces it
    with a simple rendering if not present.
    """
    from ..onnx_tools.onnx2py_helper import _var_as_dict
    schemas = get_operator_schemas(op_name, domain=domain, version=version)

    def format_name_with_domain(sch):
        # The version is only part of the title when several
        # versions may appear in the same document.
        if version == 'last':
            if sch.domain:
                return f'{sch.name} ({sch.domain})'
            return sch.name
        if sch.domain:
            return f'{sch.name} - {sch.since_version} ({sch.domain})'
        return f'{sch.name} - {sch.since_version}'

    def format_option(obj):
        opts = []
        if OpSchema.FormalParameterOption.Optional == obj.option:
            opts.append('optional')
        elif OpSchema.FormalParameterOption.Variadic == obj.option:
            opts.append('variadic')
        # A parameter is heterogeneous when it is *not* homogeneous.
        # Objects without the attribute are left untagged.
        if not getattr(obj, 'isHomogeneous', True):
            opts.append('heterogeneous')
        if opts:
            return f" ({', '.join(opts)})"
        return ""

    def format_example(code):
        code = textwrap.indent(code, ' ')
        return code

    def get_constraint(const, ii):
        if const.type_param_str:
            name = const.type_param_str
        else:
            name = str(ii)
        name = f"**{name}** in ("
        if const.allowed_type_strs:
            # Two spaces: continuation lines must align with the text
            # of the bullet ``* `` the template writes before this string.
            text = ",\n  ".join(sorted(const.allowed_type_strs))
            name += "\n  " + text + "\n  )"
        return name

    def getname(obj, i):
        name = obj.name
        if len(name) == 0:
            return str(i)
        return name

    def process_documentation(doc):
        if doc is None:
            doc = ''
        if not isinstance(doc, str):
            raise TypeError(  # pragma: no cover
                f"doc must be a string not {type(doc)!r} - {doc!r}.")
        doc = textwrap.dedent(doc)
        main_docs_url = "https://github.com/onnx/onnx/blob/master/"
        rep = {
            '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_',
            '[the doc](Broadcasting.md)':
                '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_',
            '<dl>': '',
            '</dl>': '',
            '<dt>': '* ',
            '<dd>': ' ',
            '</dt>': '',
            '</dd>': '',
            '<tt>': '``',
            '</tt>': '``',
            '<br>': '\n',
        }
        for k, v in rep.items():
            doc = doc.replace(k, v.format(main_docs_url))
        # Markdown fenced blocks become RST literal blocks:
        # an opening fence writes '::' and indents what follows,
        # a closing fence removes the indentation.
        move = 0
        lines = []
        for line in doc.split('\n'):
            if line.startswith("```"):
                if move > 0:
                    move -= 4
                    lines.append("\n")
                else:
                    lines.append("::\n")
                    move += 4
            elif move > 0:
                lines.append(" " * move + line)
            else:
                lines.append(line)
        return "\n".join(lines)

    def build_doc_url(sch):
        doc_url = "https://github.com/onnx/onnx/blob/main/docs/Operators"
        if "ml" in sch.domain:
            doc_url += "-ml"
        doc_url += ".md"
        doc_url += "#"
        if sch.domain not in (None, '', 'ai.onnx'):
            doc_url += sch.domain + "."
        return doc_url

    def clean_default_value(value):
        dvar = _var_as_dict(value)
        if 'value' in dvar:
            v = dvar['value']
            if isinstance(v, bytes):
                return f"Default value is ``'{v.decode('ascii')}'``."
            return f"Default value is ``{v}``."
        res = str(value).replace('\n', ' ').strip()
        if len(res) > 0:
            return f"Default value is ``{res}``."
        return ""

    def text_wrap(text, indent):
        s = ' ' * indent
        lines = textwrap.wrap(text, initial_indent=s, subsequent_indent=s)
        return '\n'.join(lines)

    # The template has no access to builtins, every function
    # it calls is given as a keyword argument.
    docs = _template_operator.render(
        schemas=schemas, OpSchema=OpSchema,
        len=len, getattr=getattr, sorted=sorted,
        format_option=format_option,
        get_constraint=get_constraint,
        getname=getname, enumerate=enumerate,
        format_name_with_domain=format_name_with_domain,
        process_documentation=process_documentation,
        build_doc_url=build_doc_url, text_wrap=text_wrap,
        str=str, clean_default_value=clean_default_value,
        get_onnx_example=get_onnx_example if example else None,
        format_example=format_example,
        is_last_schema=is_last_schema)
    if diff:
        # Empty lines are merged before the comparison.
        docs = _collapse_blank_lines(docs)
        docs = _insert_diff(docs, '.. tag-diff-insert.')

    if clean:
        docs = _collapse_blank_lines(docs)

    return docs
423def _insert_diff(docs, split='.. tag-diff-insert.'):
424 """
425 Splits a using `split`, insert HTML differences between pieces.
426 The function relies on package :epkg:`pyquickhelper`.
427 """
428 spl = docs.split(split)
429 if len(spl) <= 1:
430 return docs
432 from pyquickhelper.texthelper.edit_text_diff import (
433 edit_distance_text, diff2html)
435 pieces = [spl[0]]
436 for i in range(1, len(spl)):
437 spl1 = spl[i - 1].strip('\n ')
438 spl2 = spl[i].strip('\n ')
439 spl1 = spl1.split('**Examples**')[0].replace('`', '')
440 spl2 = spl2.split('**Examples**')[0].replace('`', '')
441 spl1 = spl1.split('**Summary**')[-1].strip('\n ')
442 spl2 = spl2.split('**Summary**')[-1].strip('\n ')
443 if len(spl1) < 5 or len(spl2) < 5:
444 pieces.append(spl[i])
445 continue
447 _, aligned, final = edit_distance_text( # pylint: disable=W0632
448 spl2, spl1, threshold=0.5)
449 ht = diff2html(spl2, spl1, aligned, final, two_columns=True)
450 ht = ht.replace(">``<", "><")
451 ht = ' ' + '\n '.join(ht.split('\n'))
452 pieces.extend(['', '**Differences**', '', '.. raw:: html',
453 '', ht, '', spl[i]])
455 return '\n'.join(pieces)
def change_style(name):
    """
    Switches from *AaBb* into *aa_bb*. An underscore is appended
    if the result is a Python keyword.

    :param name: name to convert
    :return: converted name

    Example:

    .. runpython::
        :showcode:

        from mlprodict.npy.xop_auto import change_style

        print("changeStyle --> {0}".format(change_style('changeStyle')))
    """
    with_word_breaks = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
    snake = re.sub('([a-z0-9])([A-Z])', r'\1_\2', with_word_breaks).lower()
    if keyword.iskeyword(snake):
        return snake + "_"
    return snake
def get_onnx_example(op_name):
    """
    Retrieves examples associated to one operator
    stored in onnx packages.

    :param op_name: operator name
    :return: dictionary ``{example name: code}``,
        empty if no module of examples can be imported
    """
    candidates = [
        f'onnx.backend.test.case.node.{op_name.lower()}',
        f'onnx.backend.test.case.node.{change_style(op_name).lower()}',
    ]
    # Every candidate is tried, the last one successfully imported is kept.
    mod = None
    for candidate in candidates:
        try:
            mod = importlib.import_module(candidate)
        except (AttributeError, ImportError):
            continue
    if mod is None:
        # Unable to find an example for 'op_name'.
        return {}

    results = {}
    for exporter in mod.__dict__.values():
        if not isinstance(exporter, _Exporter):
            continue
        code_cls = inspect.getsource(exporter)
        codes = code_cls.split('@staticmethod')
        for me in exporter.__dict__:
            if not me.startswith('export'):
                continue
            sub = f' {me}()'
            matches = [code for code in codes if sub in code]
            if not matches:
                raise RuntimeError(  # pragma: no cover
                    f"Unable to find {sub!r} in\n{code_cls}")
            lines = textwrap.dedent(matches[-1]).split('\n')
            # the example is what follows the last line starting with 'def '
            def_rows = [i for i, line in enumerate(lines)
                        if line.startswith('def ')]
            first = def_rows[-1] + 1 if def_rows else 0
            found = textwrap.dedent('\n'.join(lines[first:]))
            key = me[len('export'):]
            if key == '':
                key = 'default'
            if key in results:
                key = f'example {len(results) + 1}'
            results[key] = found
    return results
def is_last_schema(sch):
    """
    Tells if this is the most recent schema for this operator.

    :param sch: schema
    :return: True if *sch* has the same ``since_version`` as the schema
        returned by ``onnx.defs.get_schema`` for the same name and domain,
        True as well if onnx raises ``SchemaError`` for this operator
    """
    try:
        latest = onnx.defs.get_schema(sch.name, domain=sch.domain)
    except SchemaError:  # pragma: no cover
        return True
    else:
        return sch.since_version == latest.since_version
def onnx_documentation_folder(folder, ops=None, title='ONNX operators',
                              fLOG=None):
    """
    Creates documentation in a folder for all known
    ONNX operators or a subset.

    For every domain with at least one selected operator, the function
    writes one page per operator (``onnx_<domain>_<operator>.rst``)
    and one page ``table_<domain>.rst`` listing the versions
    of every operator. A page ``index.rst`` links to all of them.

    :param folder: folder where to write the documentation,
        created if it does not exist
    :param ops: None for all operators or a subset of them
    :param title: index title
    :param fLOG: logging function
    :return: list of created files
    """
    all_schemas = _get_all_schemas_with_history()
    # exist_ok avoids a failure if the folder is created
    # between a check and the creation.
    os.makedirs(folder, exist_ok=True)
    index = ['', title, '=' * len(title), '', '.. contents::',
             ' :local:', '']
    pages = []
    tables_domain_pages = []

    if ops is not None:
        ops = set(ops)
    for dom in sorted(all_schemas):
        sub = all_schemas[dom]
        selected = sorted(sub if ops is None else set(sub).intersection(ops))
        if not selected:
            continue

        sdom = 'main' if dom == '' else dom

        index_dom = [sdom, '+' * len(sdom), '', '.. toctree::',
                     ' :maxdepth: 1', '']

        table_title = f"operator table for domain {sdom}"
        table_dom = [
            "", f".. _l-table-operator-{sdom.replace('.', '-')}:", "",
            table_title, "=" * len(table_title), "",
            f".. list-table:: operators for domain {sdom}",
            " :widths: 10 10",
            " :header-rows: 1",
            "",
            " * - operator",
            # the second cell is nested under the bullet of the row
            "   - versions"]

        for op in selected:
            if fLOG is not None:
                fLOG(  # pragma: no cover
                    f'generate page for onnx {dom!r} - {op!r}')
            page_name = f"onnx_{dom.replace('.', '')}_{op}"
            index_dom.append(f' {page_name}')
            doc = get_rst_doc(op, domain=dom, version=None, example=True,
                              diff=True)
            main = op if dom == '' else f'{dom} - {op}'
            rows = ['', f'.. _l-onnx-doc{dom}-{op}:', '',
                    '=' * len(main), main, '=' * len(main), '',
                    '.. contents::', ' :local:', '', doc]

            full = os.path.join(folder, page_name + '.rst')
            with open(full, 'w', encoding='utf-8') as f:
                f.write("\n".join(rows))
            pages.append(full)

            # table: one link per version of the operator
            schemas = get_operator_schemas(op, domain=dom, version=None)
            links = [
                f":ref:`{sch.since_version} "
                f"<l-onnx-op{sch.domain.lower().replace('.', '-')}-"
                f"{sch.name.lower()}-{sch.since_version}>`"
                for sch in schemas]
            table_dom.extend([f" * - {op}",
                              f"   - {', '.join(links)}"])

        sdom_clean = sdom.replace('.', '_')
        table_page = os.path.join(folder, f'table_{sdom_clean}.rst')
        tables_domain_pages.append(f'table_{sdom_clean}')
        pages.append(table_page)
        with open(table_page, "w", encoding="utf-8") as f:
            f.write("\n".join(table_dom))

        index.extend(index_dom)
        index.append('')

    # adding pages
    index.extend(["", "Tables", "++++++", "",
                  ".. toctree::", " :maxdepth: 1", ""])
    index.extend(f" {page}" for page in tables_domain_pages)
    index.append('')

    # creating a big index
    index_page = os.path.join(folder, 'index.rst')
    with open(index_page, 'w', encoding='utf-8') as f:
        f.write('\n'.join(index))
    pages.append(index_page)
    return pages