Coverage for mlprodict/testing/script_testing.py: 100%
66 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Utilities to test scripts from :epkg:`scikit-learn` documentation.
4"""
5import os
6from io import StringIO
7from contextlib import redirect_stdout, redirect_stderr
8import pprint
9import numpy
10from sklearn.base import BaseEstimator
11from .verify_code import verify_code
class MissingVariableError(RuntimeError):
    """
    Exception raised when an expected variable (a fitted model
    or its training data) cannot be found.
    """
21def _clean_script(content):
22 """
23 Comments out all lines containing ``.show()``.
24 """
25 new_lines = []
26 for line in content.split('\n'):
27 if '.show()' in line or 'sys.exit' in line:
28 new_lines.append("# " + line)
29 else:
30 new_lines.append(line)
31 return "\n".join(new_lines)
34def _enumerate_fit_info(fits):
35 """
36 Extracts the name of the fitted models and the data
37 used to train it.
38 """
39 for fit in fits:
40 chs = fit['children']
41 if len(chs) < 2:
42 # unable to extract the needed information
43 continue # pragma: no cover
44 model = chs[0]['str']
45 if model.endswith('.fit'):
46 model = model[:-4]
47 args = [ch['str'] for ch in chs[1:]]
48 yield model, args
def _try_onnx(loc, model_name, args_name, **options):
    """
    Tries the ONNX conversion of a model.

    @param      loc         available variables (name -> value)
    @param      model_name  name of the model among these variables
    @param      args_name   names of the arguments among these variables,
                            only the first one is used as the data
    @param      options     additional options for the conversion,
                            all of them are forwarded to *to_onnx*,
                            *dtype* (default is ``numpy.float32``) is also
                            the type the data is cast to before the conversion
    @return                 tuple *(onnx model, dictionary)*, the dictionary
                            has the keys ``'onx'``, ``'model'``, ``'X'``
                            (the data after the cast)

    Raises @see cl MissingVariableError if the model or the data
    cannot be found in *loc*.
    """
    from ..onnx_conv import to_onnx

    # the model is checked first, the data name is only read afterwards
    if model_name not in loc:
        raise MissingVariableError(  # pragma: no cover
            f"Unable to find model '{model_name}' in {', '.join(sorted(loc))}")
    if args_name[0] not in loc:
        raise MissingVariableError(  # pragma: no cover
            f"Unable to find data '{args_name[0]}' in {', '.join(sorted(loc))}")

    fitted = loc[model_name]
    features = loc[args_name[0]].astype(options.get('dtype', numpy.float32))
    onx = to_onnx(fitted, features, **options)
    return onx, {'onx': onx, 'model': fitted, 'X': features}
def verify_script(file_or_name, try_onnx=True, existing_loc=None,
                  **options):
    """
    Checks that models fitted in an example from :epkg:`scikit-learn`
    documentation can be converted into :epkg:`ONNX`.

    @param      file_or_name    filename or content of the script,
                                it is considered as a filename if it has
                                no end of line and the file exists
    @param      try_onnx        try the onnx conversion
    @param      existing_loc    existing local variables, they are added
                                to both the globals and the locals
                                the script is executed with
    @param      options         conversion options
    @return                     dictionary with the keys ``'locals'``
                                (estimators and arrays created by the script,
                                plus the converted models stored as
                                ``<model name>_onnx``), ``'globals'``,
                                ``'stdout'``, ``'stderr'``, ``'onx_info'``
                                (one dictionary per converted model)

    The script is executed with *exec* in the current process,
    it should only be used with trusted scripts.
    Raises @see cl MissingVariableError if *try_onnx* is True and
    no trained model is detected or if a variable cannot be found.
    """
    is_file = '\n' not in file_or_name and os.path.exists(file_or_name)
    if is_file:
        filename = file_or_name
        with open(file_or_name, 'r', encoding='utf-8') as f:
            content = f.read()
    else:  # pragma: no cover
        filename, content = "<string>", file_or_name

    # comments out the lines calling .show() or sys.exit
    content = _clean_script(content)

    # looks for fit expressions
    _, node = verify_code(content, exc=False)
    fits = node._fits
    models_args = list(_enumerate_fit_info(fits))

    # execution, stdout and stderr are captured
    code = compile(content, filename, 'exec')
    glo = globals().copy()
    loc = {}
    if existing_loc is not None:
        loc.update(existing_loc)  # pragma: no cover
        glo.update(existing_loc)  # pragma: no cover
    out, err = StringIO(), StringIO()
    with redirect_stdout(out), redirect_stderr(err):
        exec(code, glo, loc)  # pylint: disable=W0122

    # only keeps estimators and arrays among the local variables
    kept_types = (BaseEstimator, numpy.ndarray)
    loc_fil = {}
    for name, value in loc.items():
        if isinstance(value, kept_types):
            loc_fil[name] = value
    glo_fil = {name: value for name, value in glo.items()
               if name != '__builtins__'}
    onx_info = []

    # onnx conversion
    if try_onnx:
        if not models_args:
            raise MissingVariableError(  # pragma: no cover
                "No detected trained model in '{}'\n{}\n--LOCALS--\n{}".format(
                    filename, content, pprint.pformat(loc_fil)))
        for model_name, args_name in models_args:
            try:
                onx, args = _try_onnx(
                    loc_fil, model_name, args_name, **options)
            except MissingVariableError as e:  # pragma: no cover
                raise MissingVariableError("Unable to find variable in '{}'\n{}".format(
                    filename, pprint.pformat(fits))) from e
            # the converted model becomes available under '<name>_onnx'
            loc_fil[model_name + "_onnx"] = onx
            onx_info.append(args)

    # final results
    return {
        'locals': loc_fil,
        'globals': glo_fil,
        'stdout': out.getvalue(),
        'stderr': err.getvalue(),
        'onx_info': onx_info,
    }