Coverage for mlprodict/testing/script_testing.py: 100%

66 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Utilities to test scripts from :epkg:`scikit-learn` documentation. 

4""" 

5import os 

6from io import StringIO 

7from contextlib import redirect_stdout, redirect_stderr 

8import pprint 

9import numpy 

10from sklearn.base import BaseEstimator 

11from .verify_code import verify_code 

12 

13 

class MissingVariableError(RuntimeError):
    """
    Error raised when an expected variable cannot be found,
    for instance a model or its training data.
    """

19 

20 

21def _clean_script(content): 

22 """ 

23 Comments out all lines containing ``.show()``. 

24 """ 

25 new_lines = [] 

26 for line in content.split('\n'): 

27 if '.show()' in line or 'sys.exit' in line: 

28 new_lines.append("# " + line) 

29 else: 

30 new_lines.append(line) 

31 return "\n".join(new_lines) 

32 

33 

34def _enumerate_fit_info(fits): 

35 """ 

36 Extracts the name of the fitted models and the data 

37 used to train it. 

38 """ 

39 for fit in fits: 

40 chs = fit['children'] 

41 if len(chs) < 2: 

42 # unable to extract the needed information 

43 continue # pragma: no cover 

44 model = chs[0]['str'] 

45 if model.endswith('.fit'): 

46 model = model[:-4] 

47 args = [ch['str'] for ch in chs[1:]] 

48 yield model, args 

49 

50 

def _try_onnx(loc, model_name, args_name, **options):
    """
    Converts one fitted model into ONNX.

    @param      loc         available variables (name -> value)
    @param      model_name  name of the model among these variables
    @param      args_name   names of the arguments among these variables,
                            only the first one is used, as the data
    @param      options     additional options forwarded to the conversion,
                            ``dtype`` (default: ``numpy.float32``) is also
                            the type the data is cast into before
                            the conversion
    @return                 tuple ``(onx, args)``, *onx* is the converted
                            model, *args* is a dictionary with keys
                            ``'onx'``, ``'model'``, ``'X'`` (cast data)

    Raises @see cl MissingVariableError if the model or the data
    is not part of *loc*.
    """
    from ..onnx_conv import to_onnx

    if model_name not in loc:
        raise MissingVariableError(  # pragma: no cover
            f"Unable to find model '{model_name}' in {', '.join(sorted(loc))}")
    data_name = args_name[0]
    if data_name not in loc:
        raise MissingVariableError(  # pragma: no cover
            f"Unable to find data '{data_name}' in {', '.join(sorted(loc))}")

    fitted = loc[model_name]
    data = loc[data_name]
    target_type = options.get('dtype', numpy.float32)
    cast_data = data.astype(target_type)
    converted = to_onnx(fitted, cast_data, **options)
    return converted, {'onx': converted, 'model': fitted, 'X': cast_data}

75 

76 

def verify_script(file_or_name, try_onnx=True, existing_loc=None,
                  **options):
    """
    Runs an example from :epkg:`scikit-learn` documentation and checks
    that the models it fits can be converted into :epkg:`ONNX`.

    @param      file_or_name    filename, or the script itself as a string
                                (any string with a line break or which is
                                not an existing path)
    @param      try_onnx        try the onnx conversion
    @param      existing_loc    existing local variables, added to both
                                the globals and the locals the script
                                is executed with
    @param      options         conversion options
    @return                     dictionary with keys ``'locals'``
                                (estimators and arrays created by the script
                                plus the converted models under
                                ``<model>_onnx``), ``'globals'``,
                                ``'stdout'``, ``'stderr'``, ``'onx_info'``
                                (one dictionary per converted model)

    Raises @see cl MissingVariableError if *try_onnx* is True and no
    fitted model is detected or a detected variable cannot be found.

    The script is executed with ``exec``: it should only be used
    on trusted code.
    """
    if '\n' not in file_or_name and os.path.exists(file_or_name):
        filename = file_or_name
        with open(filename, 'r', encoding='utf-8') as f:
            content = f.read()
    else:  # pragma: no cover
        filename = "<string>"
        content = file_or_name

    # neutralizes .show() and sys.exit
    content = _clean_script(content)

    # static analysis: look for fit expressions
    _, node = verify_code(content, exc=False)
    fits = node._fits
    models_args = list(_enumerate_fit_info(fits))

    # execution, standard output and error are captured
    code = compile(content, filename, 'exec')
    glo = dict(globals())
    loc = {}
    if existing_loc is not None:
        for namespace in (loc, glo):  # pragma: no cover
            namespace.update(existing_loc)  # pragma: no cover
    out, err = StringIO(), StringIO()

    with redirect_stdout(out), redirect_stderr(err):
        exec(code, glo, loc)  # pylint: disable=W0122

    # only estimators and arrays are kept among the local variables
    kept_types = (BaseEstimator, numpy.ndarray)
    loc_fil = {name: value for name, value in loc.items()
               if isinstance(value, kept_types)}
    glo_fil = {name: value for name, value in glo.items()
               if name != '__builtins__'}
    onx_info = []

    # onnx
    if try_onnx:
        if not models_args:
            raise MissingVariableError(  # pragma: no cover
                f"No detected trained model in '{filename}'\n{content}"
                f"\n--LOCALS--\n{pprint.pformat(loc_fil)}")
        for model_name, args_name in models_args:
            try:
                onx, args = _try_onnx(
                    loc_fil, model_name, args_name, **options)
            except MissingVariableError as e:  # pragma: no cover
                raise MissingVariableError(
                    f"Unable to find variable in '{filename}'"
                    f"\n{pprint.pformat(fits)}") from e
            loc_fil[model_name + "_onnx"] = onx
            onx_info.append(args)

    # final results
    return {
        'locals': loc_fil,
        'globals': glo_fil,
        'stdout': out.getvalue(),
        'stderr': err.getvalue(),
        'onx_info': onx_info,
    }