Coverage for mlprodict/onnxrt/doc/doc_helper.py: 96%

153 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Documentation helper. 

4""" 

5import keyword 

6import textwrap 

7import re 

8from onnx.defs import OpSchema 

9 

10 

def type_mapping(name):
    """
    Mapping between types name and type integer value.

    The integer values follow the ``AttributeProto.AttributeType``
    enumeration of :epkg:`onnx`.

    :param name: None to get the whole mapping (a new dictionary
        ``{type name: integer value}`` built at every call),
        a type name (str) to get its integer value,
        or an integer value to get its type name
    :return: dictionary, integer or string depending on *name*
    :raises KeyError: if *name* is an unknown type name or value

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import type_mapping
        import pprint
        pprint.pprint(type_mapping(None))
        print(type_mapping("INT"))
        print(type_mapping(2))
    """
    di = dict(FLOAT=1, FLOATS=6, GRAPH=5, GRAPHS=10, INT=2,
              INTS=7, STRING=3, STRINGS=8, TENSOR=4,
              TENSORS=9, UNDEFINED=0, SPARSE_TENSOR=11,
              # the last entries complete the onnx enumeration, without
              # them looking up these attribute types raised KeyError
              SPARSE_TENSORS=12, TYPE_PROTO=13, TYPE_PROTOS=14)
    if name is None:
        return di
    if isinstance(name, str):
        return di[name]
    # reverse lookup: integer value -> type name
    rev = {v: k for k, v in di.items()}
    return rev[name]

34 

35 

def _get_doc_template():
    """
    Builds the :epkg:`jinja2` template used by function
    @see fn get_rst_doc to render operator schemas in RST.

    :return: a compiled ``jinja2.Template``

    :epkg:`jinja2` is imported when this function is called.
    The template expects a variable ``schemas`` and the helper callables
    (``len``, ``sorted``, ``enumerate``, ``format_option``, ...) passed
    to ``render``. The *Other versions* section iterates over
    ``sch.versions`` (the template previously read ``sch.version``,
    a different attribute than the one tested by its guard).
    """
    from jinja2 import Template  # delayed import
    return Template(textwrap.dedent("""
        {% for sch in schemas %}

        {{format_name_with_domain(sch)}}
        {{'=' * len(format_name_with_domain(sch))}}

        {{process_documentation(sch.doc)}}

        {% if sch.attributes %}
        **Attributes**

        {% for _, attr in sorted(sch.attributes.items()) %}* *{{attr.name}}*{%
        if attr.required %} (required){% endif %}: {{
        process_attribute_doc(attr.description)}} {%
        if attr.default_value %} {{
        process_default_value(attr.default_value)
        }} ({{type_mapping(attr.type)}}){% endif %}
        {% endfor %}
        {% endif %}

        {% if sch.inputs %}
        **Inputs**

        {% if sch.min_input != sch.max_input %}Between {{sch.min_input
        }} and {{sch.max_input}} inputs.
        {% endif %}
        {% for ii, inp in enumerate(sch.inputs) %}
        * *{{getname(inp, ii)}}*{{format_option(inp)}}{{inp.typeStr}}: {{
        inp.description}}{% endfor %}
        {% endif %}

        {% if sch.outputs %}
        **Outputs**

        {% if sch.min_output != sch.max_output %}Between {{sch.min_output
        }} and {{sch.max_output}} outputs.
        {% endif %}
        {% for ii, out in enumerate(sch.outputs) %}
        * *{{getname(out, ii)}}*{{format_option(out)}}{{out.typeStr}}: {{
        out.description}}{% endfor %}
        {% endif %}

        {% if sch.type_constraints %}
        **Type Constraints**

        {% for ii, type_constraint in enumerate(sch.type_constraints)
        %}* {{getconstraint(type_constraint, ii)}}: {{
        type_constraint.description}}
        {% endfor %}
        {% endif %}

        **Version**

        *Onnx name:* `{{sch.name}} <{{build_doc_url(sch)}}{{sch.name}}>`_

        {% if sch.support_level == OpSchema.SupportType.EXPERIMENTAL %}
        No versioning maintained for experimental ops.
        {% else %}
        This version of the operator has been {% if
        sch.deprecated %}deprecated{% else %}available{% endif %} since
        version {{sch.since_version}}{% if
        sch.domain %} of domain {{sch.domain}}{% endif %}.
        {% if len(sch.versions) > 1 %}
        Other versions of this operator:
        {% for v in sch.versions[:-1] %} {{v}} {% endfor %}
        {% endif %}
        {% endif %}

        **Runtime implementation:**
        :class:`{{sch.name}}
        <mlprodict.onnxrt.ops_cpu.op_{{change_style(sch.name)}}.{{sch.name}}>`

        {% endfor %}
        """))

113 

114 

# Template compiled once, when the module is imported; building it imports
# jinja2, so importing this module requires jinja2 to be installed.
_template_operator = _get_doc_template()

116 

117 

class NewOperatorSchema:
    """
    Lightweight schema describing an operator implemented in this
    package (for example @see cl TreeEnsembleRegressorDouble) instead
    of one declared by :epkg:`onnx`. It only stores the operator name
    and its domain, which is always ``'mlprodict'``.
    """

    def __init__(self, name):
        # name first, then the fixed domain of this package's operators
        self.name, self.domain = name, 'mlprodict'

127 

128 

def change_style(name):
    """
    Converts a camel-case name (*AaBb*) into snake case (*aa_bb*).
    An underscore is appended when the result is a Python keyword
    (``If`` becomes ``if_``).

    @param      name        name to convert
    @return                 converted name
    """
    converted = name
    # first split before capitalized words, then between a lowercase
    # letter or digit and the following uppercase letter
    for pattern in ('(.)([A-Z][a-z]+)', '([a-z0-9])([A-Z])'):
        converted = re.sub(pattern, r'\1_\2', converted)
    converted = converted.lower()
    if keyword.iskeyword(converted):
        return converted + "_"
    return converted

139 

140 

141def get_rst_doc(op_name): 

142 """ 

143 Returns a documentation in RST format 

144 for all :epkg:`OnnxOperator`. 

145 

146 :param op_name: operator name of None for all 

147 :return: string 

148 

149 The function relies on module :epkg:`jinja2` or replaces it 

150 with a simple rendering if not present. 

151 """ 

152 from jinja2.runtime import Undefined 

153 from ..ops_cpu._op import _schemas 

154 schemas = [_schemas.get(op_name, NewOperatorSchema(op_name))] 

155 

156 def format_name_with_domain(sch): 

157 if sch.domain: 

158 return f'{sch.name} ({sch.domain})' 

159 return sch.name 

160 

161 def format_option(obj): 

162 opts = [] 

163 if OpSchema.FormalParameterOption.Optional == obj.option: 

164 opts.append('optional') 

165 elif OpSchema.FormalParameterOption.Variadic == obj.option: 

166 opts.append('variadic') 

167 if getattr(obj, 'isHomogeneous', False): 

168 opts.append('heterogeneous') 

169 if opts: 

170 return f" ({', '.join(opts)})" 

171 return "" # pragma: no cover 

172 

173 def getconstraint(const, ii): 

174 if const.type_param_str: 

175 name = const.type_param_str 

176 else: 

177 name = str(ii) # pragma: no cover 

178 if const.allowed_type_strs: 

179 name += " " + ", ".join(const.allowed_type_strs) 

180 return name 

181 

182 def getname(obj, i): 

183 name = obj.name 

184 if len(name) == 0: 

185 return str(i) # pragma: no cover 

186 return name 

187 

188 def process_documentation(doc): 

189 if doc is None: 

190 doc = '' # pragma: no cover 

191 if isinstance(doc, Undefined): 

192 doc = '' # pragma: no cover 

193 if not isinstance(doc, str): 

194 raise TypeError( # pragma: no cover 

195 f"Unexpected type {type(doc)} for {doc}") 

196 doc = textwrap.dedent(doc) 

197 main_docs_url = "https://github.com/onnx/onnx/blob/master/" 

198 rep = { 

199 '[the doc](IR.md)': '`ONNX <{0}docs/IR.md>`_', 

200 '[the doc](Broadcasting.md)': 

201 '`Broadcasting in ONNX <{0}docs/Broadcasting.md>`_', 

202 '<dl>': '', 

203 '</dl>': '', 

204 '<dt>': '* ', 

205 '<dd>': ' ', 

206 '</dt>': '', 

207 '</dd>': '', 

208 '<tt>': '``', 

209 '</tt>': '``', 

210 '<br>': '\n', 

211 '```': '``', 

212 } 

213 for k, v in rep.items(): 

214 doc = doc.replace(k, v.format(main_docs_url)) 

215 move = 0 

216 lines = [] 

217 for line in doc.split('\n'): 

218 if line.startswith("```"): 

219 if move > 0: 

220 move -= 4 

221 lines.append("\n") 

222 else: 

223 lines.append("::\n") 

224 move += 4 

225 elif move > 0: 

226 lines.append(" " * move + line) 

227 else: 

228 lines.append(line) 

229 return "\n".join(lines) 

230 

231 def process_attribute_doc(doc): 

232 return doc.replace("<br>", " ") 

233 

234 def build_doc_url(sch): 

235 doc_url = "https://github.com/onnx/onnx/blob/master/docs/Operators" 

236 if "ml" in sch.domain: 

237 doc_url += "-ml" 

238 doc_url += ".md" 

239 doc_url += "#" 

240 if sch.domain not in (None, '', 'ai.onnx'): 

241 doc_url += sch.domain + "." 

242 return doc_url 

243 

244 def process_default_value(value): 

245 if value is None: 

246 return '' # pragma: no cover 

247 res = [] 

248 for c in str(value): 

249 if ((c >= 'A' and c <= 'Z') or (c >= 'a' and c <= 'z') or 

250 (c >= '0' and c <= '9')): 

251 res.append(c) 

252 continue 

253 if c in '[]-+(),.?': 

254 res.append(c) 

255 continue 

256 if len(res) == 0: 

257 return "*default value cannot be automatically retrieved*" 

258 return "Default value is ``" + ''.join(res) + "``" 

259 

260 fnwd = format_name_with_domain 

261 tmpl = _template_operator 

262 

263 docs = tmpl.render(schemas=schemas, OpSchema=OpSchema, 

264 len=len, getattr=getattr, sorted=sorted, 

265 format_option=format_option, 

266 getconstraint=getconstraint, 

267 getname=getname, enumerate=enumerate, 

268 format_name_with_domain=fnwd, 

269 process_documentation=process_documentation, 

270 build_doc_url=build_doc_url, str=str, 

271 type_mapping=type_mapping, 

272 process_attribute_doc=process_attribute_doc, 

273 process_default_value=process_default_value, 

274 change_style=change_style) 

275 return docs.replace(" Default value is ````", "") 

276 

277 

def debug_onnx_object(obj, depth=3):
    """
    ``__dict__`` is not in most of :epkg:`onnx` objects.
    This function uses function *dir* to explore this object.

    :param obj: object to inspect
    :param depth: maximum recursion depth, the function returns None
        when *depth* <= 0
    :return: a multi-line string, the first line is the type of *obj*,
        then one line per attribute (nested values are indented),
        or None if *depth* <= 0
    """
    def iterable(o):
        try:
            iter(o)
            return True
        except TypeError:
            return False

    if depth <= 0:
        return None

    rows = [str(type(obj))]

    def append_nested(value):
        # inserts the description of a nested value, indented
        res = debug_onnx_object(value, depth - 1)
        if res is not None:
            rows.extend(" " + line for line in res.split("\n"))

    if not isinstance(obj, (int, str, float, bool)):

        for k in sorted(dir(obj)):
            try:
                val = getattr(obj, k)
                sval = str(val).replace("\n", " ")
            except (AttributeError, ValueError) as e:  # pragma: no cover
                sval = "ERRROR-" + str(e)
                val = None

            if 'method-wrapper' in sval or "built-in method" in sval:
                continue

            rows.append(f"- {k}: {sval}")
            if (k.startswith('__') and k.endswith('__')) or val is None:
                continue

            if isinstance(val, dict):
                try:
                    sorted_list = sorted(val.items())
                except TypeError:  # pragma: no cover
                    sorted_list = list(val.items())
                for kk, vv in sorted_list:
                    rows.append(f" - [{str(kk)}]: {str(vv)}")
                    append_nested(vv)
            elif isinstance(val, (str, bytes, bytearray)):
                # Already fully displayed in sval. Iterating bytes yields
                # integers, which would otherwise list every single byte.
                continue
            elif iterable(val):
                # materialized once so that a one-shot iterator is not
                # consumed by the check below before being enumerated
                items = list(val)
                if all(isinstance(o, (str, bytes)) and len(o) == 1
                       for o in items):
                    continue
                for i, vv in enumerate(items):
                    rows.append(" - [%d]: %s" % (i, str(vv)))
                    append_nested(vv)
            elif not callable(val):
                append_nested(val)

    return "\n".join(rows)

343 

344 

def visual_rst_template():
    """
    Returns a :epkg:`jinja2` template to display DOT graph for each
    converter from :epkg:`sklearn-onnx`.

    The returned value is the raw template source (a string, not a
    compiled template). It refers to the variables ``link``, ``title``,
    ``kind``, ``method``, ``output_index``, ``optim_param``, ``model``,
    ``table``, ``dot`` and to the callables ``len`` and ``indent``, which
    the caller must provide when rendering it.

    .. runpython::
        :showcode:
        :warningout: DeprecationWarning

        from mlprodict.onnxrt.doc.doc_helper import visual_rst_template
        print(visual_rst_template())
    """
    # NOTE(review): in RST the content following ``::`` and ``.. gdot::``
    # must be indented relative to them to form a literal block / directive
    # body; confirm the relative indentation of the two ``indent(...)``
    # lines (and the prefix passed to ``indent``) in the original file.
    return textwrap.dedent("""

    .. _l-{{link}}:

    {{ title }}
    {{ '=' * len(title) }}

    Fitted on a problem type *{{ kind }}*
    (see :func:`find_suitable_problem
    <mlprodict.onnxrt.validate.validate_problems.find_suitable_problem>`),
    method `{{ method }}` matches output {{ output_index }}.
    {{ optim_param }}

    ::

    {{ indent(model, " ") }}

    {{ table }}

    .. gdot::

    {{ indent(dot, " ") }}
    """)