Coverage for mlprodict/tools/code_helper.py: 100%

89 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief A couple of tools unrelated to what the package does. 

4""" 

5import pickle 

6import re 

7import types 

8import numpy 

9 

10 

def numpy_min_max(x, fct, minmax=False):
    """
    Returns the minimum or the maximum of an array.
    Deals with text as well.

    @param      x       array, sparse containers exposing ``todense``
                        are converted first
    @param      fct     function applied to *x* when it does not hold text
    @param      minmax  False to pick the smallest string,
                        True to pick the biggest one
    @return             result of *fct*, or the representation of the
                        selected string (cut after 10 characters),
                        ``numpy.nan`` if no string was found,
                        ``'?'`` if anything went wrong
    """
    try:
        if hasattr(x, 'todense'):
            x = x.todense()
        kind = x.dtype.kind[0]
        if kind not in 'Uc' or x.dtype in {numpy.uint8}:
            # anything but text is delegated to the user function
            return fct(x)
        try:  # pragma: no cover
            x = x.ravel()
        except AttributeError:  # pragma: no cover
            pass
        strings = [s for s in x if isinstance(s, str)]
        if not strings:  # pragma: no cover
            return numpy.nan
        selected = max(strings) if minmax else min(strings)
        if len(selected) > 10:  # pragma: no cover
            # long strings are shortened to keep messages readable
            selected = selected[:10] + '...'
        return f"{selected!r}"
    except (ValueError, TypeError, AttributeError):
        return '?'

36 

37 

def numpy_min(x):
    """
    Returns the minimum of an array.
    Deals with text as well.
    """
    def smallest(array):
        return array.min()

    return numpy_min_max(x, smallest, minmax=False)

44 

45 

def numpy_max(x):
    """
    Returns the maximum of an array.
    Deals with text as well.
    """
    def biggest(array):
        return array.max()

    return numpy_min_max(x, biggest, minmax=True)

52 

53 

def debug_dump(clname, obj, folder=None, ops=None):
    """
    Looks for NaN values in an object and pickles it when NaN values
    appear in the outputs but not in the inputs.

    @param      clname      class name, used in messages and in the filename
    @param      obj         object, a dictionary or a list of arrays
    @param      folder      folder where the pickle file is written
    @param      ops         operator to dump
    @return                 filename, None if nothing was dumped
    """
    def contains_nan(array):
        # element-wise scan of the flattened array
        return any(map(numpy.isnan, array.ravel()))

    def inspect_(item, prefix=''):
        if isinstance(item, dict):
            if 'in' in item and 'out' in item:
                nan_in = any(map(contains_nan, item['in']))
                nan_out = any(map(contains_nan, item['out']))
                if nan_out and not nan_in:
                    attributes = {k: getattr(ops, k, '?')
                                  for k in getattr(ops, 'atts', {})}
                    print("NAN-notin-out ", clname, prefix, attributes)
                    return True
                return False  # pragma: no cover
            # every value is inspected as a one-element list
            for key, value in item.items():  # pragma: no cover
                inspect_([value], key)
            return None  # pragma: no cover
        if isinstance(item, list):
            for position, array in enumerate(item):
                if array is not None and contains_nan(array):
                    print("NAN", prefix, position, clname, array.shape)
            return None
        raise NotImplementedError(  # pragma: no cover
            f"Unable to debug object of type {type(item)}.")

    if not inspect_(obj):
        return None
    filename = f'cpu-{clname}-{id(obj)}-{id(ops)}.pkl'
    if folder is not None:
        filename = "/".join([folder, filename])
    with open(filename, 'wb') as stream:
        pickle.dump(obj, stream)
    return filename

99 

100 

def debug_print(k, obj, printed):
    """
    Displays information about an object the first time
    its name is met.

    @param      k       name
    @param      obj     object
    @param      printed memorizes already printed objects,
                        this dictionary is updated by the function
    """
    if k in printed:
        # already displayed once, nothing to do
        return
    printed[k] = obj
    if hasattr(obj, 'shape'):
        shape = obj.shape
        dtype = obj.dtype
        lowest = numpy_min(obj)
        highest = numpy_max(obj)
        suffix = ' (sparse)' if 'coo_matrix' in str(type(obj)) else ''
        print("-='{}' shape={} dtype={} min={} max={}{}".format(
            k, shape, dtype, lowest, highest, suffix))
        return
    plain_list = (isinstance(obj, list) and len(obj) > 0 and
                  not isinstance(obj[0], dict))
    if plain_list:  # pragma: no cover
        print(f"-='{k}' list len={len(obj)} min={min(obj)} max={max(obj)}")
    else:  # pragma: no cover
        print(f"-='{k}' type={type(obj)}")

121 

122 

def make_callable(fct, obj, code, gl, debug):
    """
    Builds a callable function from a compiled code object and
    restores the default values found in the signature,
    as the combination of functions *compile* and *exec*
    does not seem able to take them into account.

    @param      fct     function name
    @param      obj     output of function *compile*
    @param      code    code including the signature
    @param      gl      context (local and global)
    @param      debug   add debug function
    @return             callable functions
    """
    prefix = "def " + fct + "("
    sig = next(
        (row for row in code.split('\n') if row.startswith(prefix)), None)
    if sig is None:  # pragma: no cover
        raise ValueError(
            f"Unable to find function '{fct}' in\n{code}")

    # default values are parsed from the text of the signature
    pattern = re.compile(
        "([a-z][A-Za-z_0-9]*)=((None)|(False)|(True)|([0-9.e+-]+))")
    constants = {'None': None, 'True': True, 'False': False}
    defs = []
    for groups in pattern.findall(sig):
        arg_name, text = groups[0], groups[1]
        if text in constants:
            defs.append((arg_name, constants[text]))
            continue
        number = float(text)
        if int(number) == number:
            # integral values are stored as integers
            number = int(number)
        defs.append((arg_name, number))

    # debug
    if debug:
        gl = gl.copy()
        gl['debug_print'] = debug_print
        gl['print'] = print
    # specific
    if "value=array([0.], dtype=float32)" in sig:
        defs.append(('value', numpy.array([0.], dtype=numpy.float32)))

    expected = tuple(pair[1] for pair in defs)
    res = types.FunctionType(obj, gl, fct, expected)
    if res.__defaults__ == expected:  # pylint: disable=E1101
        return res

    # See https://docs.python.org/3/library/inspect.html
    # See https://stackoverflow.com/questions/11291242/python-dynamically-create-function-at-runtime
    lines = [str(sig)]  # pragma: no cover
    attributes = [
        'co_argcount', 'co_cellvars', 'co_code', 'co_consts', 'co_filename',
        'co_firstlineno', 'co_flags', 'co_freevars', 'co_kwonlyargcount',
        'co_lnotab', 'co_name', 'co_names', 'co_nlocals', 'co_stacksize',
        'co_varnames']
    for att in attributes:  # pragma: no cover
        content = getattr(res.__code__, att, None)  # pylint: disable=E1101
        if content is not None:
            lines.append(f'{att}={content!r}')
    raise RuntimeError(  # pragma: no cover
        "Defaults values of function '{}' (defaults={}) are missing.\nDefault: "
        "{}\n{}\n----\n{}".format(
            fct, res.__defaults__, defs, "\n".join(lines), code))  # pylint: disable=E1101

192 

193 

def print_code(code, begin=1):
    """
    Returns the code with line numbers.

    @param      code    code to number, lines are separated by ``\\n``
    @param      begin   number given to the first line
    @return             code, every line is prefixed with its number
                        written with at least three digits
    """
    numbered = []
    for index, row in enumerate(code.split("\n")):
        numbered.append("%03d %s" % (index + begin, row))
    return "\n".join(numbered)

200 for i, s in enumerate(rows))