Coverage for mlprodict/onnxrt/ops_cpu/op_zipmap.py: 96%

114 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2# pylint: disable=E0203,E1101,C0111 

3""" 

4@file 

5@brief Runtime operator. 

6""" 

7import numpy 

8from ._op import OpRun 

9 

10 

class ZipMapDictionary(dict):
    """
    Custom dictionary class much faster for this runtime,
    it implements a subset of the same methods.

    Nothing is stored in the underlying :class:`dict`: keys come from
    ``_rev_keys`` (*{key: column index}*) and values either from
    ``_values`` (when ``_mat`` is None) or from row ``_values`` of
    ``_mat``. Inherited :class:`dict` methods which are not overridden
    below (``copy``, ``==``, ``repr``...) only see an empty dictionary.
    """
    __slots__ = ['_rev_keys', '_values', '_mat']

    @staticmethod
    def build_rev_keys(keys):
        """
        Builds the mapping expected by the constructor.

        @param      keys        sequence of keys
        @return                 *{key: position of the key in keys}*,
                                the last position wins for a repeated key
        """
        return {k: i for i, k in enumerate(keys)}

    def __init__(self, rev_keys, values, mat=None):
        """
        @param      rev_keys    returns by @see me build_rev_keys,
                                *{keys: column index}*
        @param      values      values, or a row index in *mat*
                                if *mat* is not None
        @param      mat         matrix if values is a row index,
                                two or three dimensions
        """
        if mat is not None:
            if not isinstance(mat, numpy.ndarray):
                raise TypeError(  # pragma: no cover
                    f'matrix is expected, got {type(mat)}.')
            if len(mat.shape) not in (2, 3):
                raise ValueError(  # pragma: no cover
                    f"matrix must have two or three dimensions but got {mat.shape}.")
        # the underlying dict storage stays empty on purpose
        dict.__init__(self)
        self._rev_keys = rev_keys
        self._values = values
        self._mat = mat

    def __getstate__(self):
        """
        For pickle.
        """
        return dict(_rev_keys=self._rev_keys,
                    _values=self._values,
                    _mat=self._mat)

    def __setstate__(self, state):
        """
        For pickle.
        """
        # presumably handles the (dict_state, slots_state) tuple produced
        # by the default pickling of slotted objects -- confirm
        if isinstance(state, tuple):
            state = state[1]
        self._rev_keys = state['_rev_keys']
        self._values = state['_values']
        self._mat = state['_mat']

    def __getitem__(self, key):
        """
        Returns the item mapped to keys.
        """
        if self._mat is None:
            return self._values[self._rev_keys[key]]
        return self._mat[self._values, self._rev_keys[key]]

    def get(self, key, default=None):
        """
        Returns the item mapped to *key* or *default* if *key* is unknown.
        :meth:`dict.get` cannot be inherited because it reads the
        underlying storage which is always empty.
        """
        if key not in self._rev_keys:
            return default
        return self[key]

    def __setitem__(self, pos, value):
        "unused but used by pickle"
        # no-op: the state is restored by __setstate__
        pass

    def __len__(self):
        """
        Returns the number of items.
        """
        return len(self._values) if self._mat is None else self._mat.shape[1]

    def __iter__(self):
        yield from self._rev_keys

    def __contains__(self, key):
        return key in self._rev_keys

    def items(self):
        """
        Yields ``(key, value)`` in the iteration order of the keys.
        """
        if self._mat is None:
            for k, v in self._rev_keys.items():
                yield k, self._values[v]
        else:
            for k, v in self._rev_keys.items():
                yield k, self._mat[self._values, v]

    def keys(self):
        yield from self._rev_keys.keys()

    def values(self):
        """
        Yields the values in storage order (column order of the row
        when a matrix is used).
        """
        if self._mat is None:
            yield from self._values
        else:
            yield from self._mat[self._values]

    def asdict(self):
        """
        Returns a regular :class:`dict` with the same items.
        """
        return dict(self.items())

    def __str__(self):
        return f"ZipMap({str(self.asdict())!r})"

116 

117 

class ArrayZipMapDictionary(list):
    """
    Mocks an array without changing the data it receives.
    The notebook :ref:`onnxnodetimerst` illustrates the weaknesses
    and the strengths of this class compared to a list
    of dictionaries.

    .. index:: ZipMap
    """

    def __init__(self, rev_keys, mat):
        """
        @param      rev_keys    dictionary *{keys: column index}*
        @param      mat         matrix, every row *i* is exposed as a
                                @see cl ZipMapDictionary,
                                two or three dimensions
        """
        if mat is not None:
            if not isinstance(mat, numpy.ndarray):
                raise TypeError(  # pragma: no cover
                    f'matrix is expected, got {type(mat)}.')
            if len(mat.shape) not in (2, 3):
                raise ValueError(  # pragma: no cover
                    f"matrix must have two or three dimensions but got {mat.shape}.")
        # the underlying list storage stays empty, rows are built on demand
        list.__init__(self)
        self._rev_keys = rev_keys
        self._mat = mat

    @property
    def dtype(self):
        return self._mat.dtype

    def __len__(self):
        return self._mat.shape[0]

    def __iter__(self):
        # indices are valid by construction, no need for the bounds
        # check done by __getitem__
        rev_keys, mat = self._rev_keys, self._mat
        for i in range(len(self)):
            yield ZipMapDictionary(rev_keys, i, mat)

    def __getitem__(self, i):
        """
        Returns a @see cl ZipMapDictionary viewing row *i*.
        Integer positions are checked against the number of rows
        so that an invalid position fails immediately with
        :class:`IndexError`, as it does for a list.
        """
        if isinstance(i, (int, numpy.integer)):
            n = len(self)
            if not -n <= i < n:
                raise IndexError(
                    f"index {i} is out of range for {n} rows.")
        return ZipMapDictionary(self._rev_keys, i, self._mat)

    def __setitem__(self, pos, value):
        raise RuntimeError(
            f"Changing an element is not supported (pos=[{pos}]).")

    @property
    def values(self):
        """
        Equivalent to ``DataFrame(self).values``.
        """
        if len(self._mat.shape) == 3:
            # NOTE(review): this reshape does not transpose, the rows of the
            # result are not ``mat[:, i, :]`` and their number (shape[1])
            # differs from ``len(self)`` (shape[0]) -- confirm the layout
            return self._mat.reshape((self._mat.shape[1], -1))
        return self._mat

    @property
    def columns(self):
        """
        Equivalent to ``DataFrame(self).columns``.
        Returns the keys sorted by column index, or generated names
        ``c0, c1, ...`` when no key is known.
        """
        if self._rev_keys:
            return [k for _, k in sorted(
                (v, k) for k, v in self._rev_keys.items())]
        ndim = len(self._mat.shape)
        if ndim == 2:
            n_columns = self._mat.shape[1]
        elif ndim == 3:
            # multiclass
            n_columns = self._mat.shape[0] * self._mat.shape[2]
        else:
            raise RuntimeError(  # pragma: no cover
                f"Unable to guess the right number of columns for "
                f"shapes: {self._mat.shape}")
        return ['c%d' % i for i in range(n_columns)]

    @property
    def is_zip_map(self):
        return True

    def __str__(self):
        return f"ZipMaps[{', '.join(map(str, self))}]"

199 

200 

class ZipMap(OpRun):
    """
    Runtime for the ONNX ZipMap operator.

    The class does not output a dictionary as
    specified in :epkg:`ONNX` specifications
    but a @see cl ArrayZipMapDictionary which
    is a wrapper on the input so that it does not
    get copied.
    """

    atts = {'classlabels_int64s': [], 'classlabels_strings': []}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=ZipMap.atts,
                       **options)
        # The first non-empty label attribute defines the keys,
        # integer labels are looked at before string labels.
        for att_name in ('classlabels_int64s', 'classlabels_strings'):
            if not hasattr(self, att_name):
                continue
            labels = getattr(self, att_name)
            if len(labels) > 0:
                self.rev_keys_ = ZipMapDictionary.build_rev_keys(labels)
                break
        else:
            # no label at all: the key mapping stays empty
            self.rev_keys_ = {}

    def _run(self, x, attributes=None, verbose=0, fLOG=None):  # pylint: disable=W0221
        # x is wrapped as is, no copy is made
        wrapped = ArrayZipMapDictionary(self.rev_keys_, x)
        return (wrapped, )