Coverage for mlinsights/mlbatch/cache_model.py: 97%

61 statements  

« prev     ^ index     » next       coverage.py v6.4.2, created at 2022-08-09 08:45 +0200

"""
@file
@brief Caches to cache training.
"""
import numpy


# Module-level registry of named caches, managed through the static
# methods MLCache.create_cache / get_cache / remove_cache / has_cache.
_caches = {}

class MLCache:
    """
    Implements a cache to reduce the number of trainings
    a grid search has to do.
    """

    def __init__(self, name):
        """
        @param      name        name of the cache
        """
        self.name = name
        # key (string built by :meth:`as_key`) -> cached value
        self.cached = {}
        # key -> number of times the value was retrieved
        self.count_ = {}

    def cache(self, params, value):
        """
        Caches one object.

        @param      params      dictionary of parameters
        @param      value       value to cache
        """
        key = MLCache.as_key(params)
        if key in self.cached:
            raise KeyError(  # pragma: no cover
                f"Key {params} already exists")
        self.cached[key] = value
        self.count_[key] = 0

    def get(self, params, default=None):
        """
        Retrieves an element from the cache.

        @param      params      dictionary of parameters
        @param      default     value to return if the key is not found
        @return                 cached value or *default* if it does not exist
        """
        key = MLCache.as_key(params)
        # Test key membership instead of comparing the retrieved value to
        # *default*: a value comparison would miss hits whose cached value
        # equals *default* and raises on numpy arrays ('!=' is elementwise).
        if key in self.cached:
            self.count_[key] += 1
            return self.cached[key]
        return default

    def count(self, params):
        """
        Retrieves the number of times
        an element was retrieved from the cache.

        @param      params      dictionary of parameters
        @return                 int
        """
        key = MLCache.as_key(params)
        return self.count_.get(key, 0)

    @staticmethod
    def as_key(params):
        """
        Converts a list of parameters into a key.

        @param      params      dictionary (or an already-built key string)
        @return                 key as a string
        """
        if isinstance(params, str):
            return params
        els = []
        # Sort so that two dictionaries with the same content
        # always produce the same key.
        for k, v in sorted(params.items()):
            if isinstance(v, (int, float, str)):
                sv = str(v)
            elif isinstance(v, tuple):
                if not all(map(lambda e: isinstance(e, (int, float, str)), v)):
                    raise TypeError(  # pragma: no cover
                        f"Unable to create a key with value '{k}':{v}")
                # Assign, do not return: an early return would drop every
                # other parameter from the key and make distinct parameter
                # sets collide.
                sv = str(v)
            elif isinstance(v, numpy.ndarray):
                # id(v) may have been better but
                # it does not play well with joblib.
                # tobytes() replaces tostring(), deprecated since
                # numpy 1.19 and removed in numpy 2.0 (same bytes).
                sv = hash(v.tobytes())
            elif v is None:
                sv = ""
            else:
                raise TypeError(  # pragma: no cover
                    f"Unable to create a key with value '{k}':{v}")
            els.append((k, sv))
        return str(els)

    def __len__(self):
        """
        Returns the number of cached items.
        """
        return len(self.cached)

    def items(self):
        """
        Enumerates all cached items.
        """
        for item in self.cached.items():
            yield item

    def keys(self):
        """
        Enumerates all cached keys.
        """
        for k in self.cached.keys():  # pylint: disable=C0201
            yield k

    @staticmethod
    def create_cache(name):
        """
        Creates a new cache.

        @param      name        name
        @return                 created cache
        @raise RuntimeError     if a cache with the same name already exists
        """
        global _caches  # pylint: disable=W0603,W0602
        if name in _caches:
            raise RuntimeError(  # pragma: no cover
                f"cache '{name}' already exists.")

        cache = MLCache(name)
        _caches[name] = cache
        return cache

    @staticmethod
    def remove_cache(name):
        """
        Removes a cache with a given name.

        @param      name        name
        """
        global _caches  # pylint: disable=W0603,W0602
        del _caches[name]

    @staticmethod
    def get_cache(name):
        """
        Gets a cache with a given name.

        @param      name        name
        @return                 created cache
        """
        global _caches  # pylint: disable=W0603,W0602
        return _caches[name]

    @staticmethod
    def has_cache(name):
        """
        Tells if cache *name* is present.

        @param      name        name
        @return                 boolean
        """
        global _caches  # pylint: disable=W0603,W0602
        return name in _caches