Coverage for mlprodict/onnxrt/ops_shape/shape_container.py: 93%

156 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Class ShapeContainer 

4""" 

5import pprint 

6from .shape_result import ShapeResult 

7 

8 

class ShapeContainer:
    """
    Stores all inferred shapes as @see cl ShapeResult.

    Attributes:

    * `shapes`: dictionary `{ result name: ShapeResult }`
    * `names`: some dimensions are unknown and represented as
      variables, this dictionary keeps track of them,
      `{ variable name: (base name, result name, dim) }`
    * `names_rev`: reverse dictionary of `names`,
      `{ base name: [variable names] }`
    * `resolved_`: only set by method `resolve()`,
      `{ result name: value returned by ShapeResult.resolve }`
    """

    def __init__(self):
        self.shapes = dict()
        self.names = dict()
        self.names_rev = dict()

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}()"

    def __len__(self):
        "usual"
        return len(self.shapes)

    def __getitem__(self, key):
        "Retrieves one shape from its name."
        return self.shapes[key]

    def copy(self, deep=False):
        """
        Makes a copy.

        Every shape is copied with `ShapeResult.copy(deep=deep)`,
        `names` is shallow-copied and every list in `names_rev`
        is copied. Attribute `resolved_` is not copied.

        :param deep: forwarded to `ShapeResult.copy`
        :return: new @see cl ShapeContainer
        """
        cont = ShapeContainer()
        cont.shapes = {k: v.copy(deep=deep) for k, v in self.shapes.items()}
        cont.names = self.names.copy()
        cont.names_rev = {k: v.copy() for k, v in self.names_rev.items()}
        return cont

    def update(self, key, value):
        """
        Updates one shape. Returns True if the shape was different.

        :param key: result name
        :param value: @see cl ShapeResult
        :return: True if *key* was not registered yet, otherwise
            what `ShapeResult.merge` returns for the existing shape
        :raises TypeError: if *key* is not a string or *value*
            is not a ShapeResult
        """
        if not isinstance(key, str):
            raise TypeError(  # pragma: no cover
                f"key must be a string not {type(key)!r}.")
        if not isinstance(value, ShapeResult):
            raise TypeError(  # pragma: no cover
                f"value must be a ShapeResult not {type(value)!r}.")
        if key not in self.shapes:
            self.shapes[key] = value
            return True
        return self.shapes[key].merge(value)

    def __contains__(self, key):
        "Operator in."
        return key in self.shapes

    def __str__(self):
        """
        Displays the shapes, the variable names and, if there are
        any, the constraints gathered by `get_all_constraints`.
        """
        rows = ["ShapeContainer({"]
        rows.extend(f" {k!r}: {v!r}" for k, v in self.shapes.items())
        rows.append("}, names={")
        rows.extend(f" {k!r}: {v!r}" for k, v in self.names.items())
        cst = self.get_all_constraints()
        if len(cst) > 0:
            rows.append("}, constraint={")
            rows.extend(f" {c!r}: {v!r}" for c, v in cst.items())
        rows.append("})")
        return "\n".join(rows)

    def get_new_name(self, name, result_name, dim):
        """
        Returns a variable name when a dimension is not
        specified.

        When *name* is None, empty or not a key of `names`, a new
        name `<name>_<i>` is created with the smallest integer
        `i >= 0` not already in `names`. It is registered in
        `names` as `(name, result_name, dim)` and appended to
        `names_rev[name]`.

        :param name: base name, None is treated as an empty string
        :param result_name: result the dimension belongs to
            (stored in `names`)
        :param dim: stored as-is in `names`
        :return: variable name
        :raises TypeError: if *name* is neither None nor a string
        :raises RuntimeError: if *name* maps to several variables
        """
        if name is not None and not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                f"name must be string not {name!r}.")
        if name is None:
            name = ''
        if name == '' or name not in self.names:
            i = 0
            new_name = f"{name}_{i}"
            while new_name in self.names:
                i += 1
                new_name = f"{name}_{i}"
            self.names[new_name] = (name, result_name, dim)
            self.names_rev.setdefault(name, []).append(new_name)
            return new_name
        # NOTE(review): this branch is reached only when *name* is
        # already a key of `names` (a generated name), but `names_rev`
        # is keyed by base names, so the lookup below raises KeyError
        # unless that generated name was also used as a base name.
        # Confirm the intended condition.
        val = self.names_rev[name]
        if len(val) != 1:
            raise RuntimeError(  # pragma: no cover
                f"Name {name!r} has more than one correspondance ({val!r}).")
        return val[0]

    def get_all_constraints(self):
        """
        Gathers all constraints.

        Constraints are grouped by name. When several constraints
        share a name, the first one is merged in place with the others
        and only that first one is kept. The merged object is the one
        stored in the shape, so this method modifies the constraints
        held by `shapes`. Note that `__str__` calls it.

        :return: dictionary `{ constraint name: [constraint] }`
        """
        cons = {}
        for v in self.shapes.values():
            if v.constraints is not None:
                for c in v.constraints:
                    cons.setdefault(c.name, []).append(c)
        for v in cons.values():
            if len(v) > 1:
                v[0].merge(v[1:])
                del v[1:]
        return cons

    def get(self):
        """
        Returns the value of attribute `resolved_`
        (method `resolve()` must have been called first).

        :raises AttributeError: if `resolve()` was not called
        """
        resolved = getattr(self, 'resolved_', None)
        if resolved is None:
            raise AttributeError(  # pragma: no cover
                "Attribute 'resolved_' is missing. You must run "
                "method 'resolve()'.")
        return resolved

    def resolve(self):
        """
        Resolves all constraints. It adds the attribute
        `resolved_`.

        Every string found in a shape is a variable. Resolution
        runs in two steps:

        1. A constraint whose values are all integers directly gives
           the set of possible values of its variable.
        2. The other constraints, which also refer to variables, are
           propagated until nothing changes. When no constraint makes
           progress, the last still-unknown variable (in order of
           first appearance in the shapes) receives a new symbolic
           dimension name `d0`, `d1`, ..., and propagation resumes.

        Finally, every shape is resolved with the collected values.

        :return: dictionary `{ result name: resolved shape }`, also
            stored in `resolved_`
        :raises RuntimeError: if a constraint name is not a variable
            of any shape, if constraints are incompatible, or if a
            shape cannot be resolved
        """
        def vars_in_values(values):
            # splits values into a set of non-string values and the
            # list of variable names (strings)
            i_vals, s_vals = [], []
            for v in values:
                if isinstance(v, str):
                    s_vals.append(v)
                else:
                    i_vals.append(v)
            return set(i_vals), s_vals

        # every string dimension is a variable, None means not known yet
        variables = {}
        for v in self.shapes.values():
            for sh in v.shape:
                if isinstance(sh, str):
                    variables[sh] = None

        # first step: resolves all constraint with integer
        dcsts = self.get_all_constraints()
        csts = [c for li in dcsts.values() for c in li]
        new_csts = []
        for cst in csts:
            if cst.name in variables and variables[cst.name] is None:
                if all(isinstance(n, int) for n in cst.values):
                    variables[cst.name] = cst.values.copy()
                else:
                    new_csts.append(cst)
            else:
                raise RuntimeError(  # pragma: no cover
                    "Unable to find any correspondance for variable %r "
                    "in %r." % (cst.name, ", ".join(sorted(variables))))

        # second step: everything else, like a logic algorithm
        dim_names = set()
        csts = new_csts
        updates = 1
        while updates > 0 and len(new_csts) > 0:
            updates = 0
            new_csts = []
            for cst in csts:
                rvalues = variables[cst.name]
                ivalues, lvars = vars_in_values(cst.values)

                if len(lvars) > 0:
                    miss = 0
                    for lv in lvars:
                        if lv in variables and variables[lv] is not None:
                            ivalues |= variables[lv]
                        else:
                            miss += 1

                    if miss == 0:
                        # simple case: only integers
                        if rvalues is None:
                            inter = ivalues
                        else:
                            inter = rvalues.intersection(ivalues)
                        if len(inter) == 0:
                            raise RuntimeError(  # pragma: no cover
                                "Resolution failed for variable %r, "
                                "current possibilities %r does not match "
                                "constraint %r." % (cst.name, rvalues, cst))
                        if rvalues is None or len(inter) < len(rvalues):
                            # narrowed: keep the constraint, it may
                            # narrow further on the next pass
                            variables[cst.name] = inter
                            updates += 1
                        else:
                            # no progress possible: drop the constraint
                            continue
                    elif len(dim_names) > 0:
                        # more complex case: variables
                        if len(cst.values) == 1 and len(lvars) == 1:
                            # exact mapping between cst.name and lvars[0]
                            a, b = cst.name, lvars[0]
                            if variables[a] is None and variables[b] is not None:
                                if variables[b].intersection(dim_names):
                                    variables[a] = variables[b]
                                    updates += 1
                                    continue
                            elif variables[b] is None and variables[a] is not None:
                                if variables[a].intersection(dim_names):
                                    variables[b] = variables[a]
                                    updates += 1
                                    continue

                new_csts.append(cst)
            csts = new_csts

            if len(new_csts) > 0 and updates == 0:
                # It means that a dimension needs to be left unknown.
                # The last variable still set to None gets a new
                # symbolic dimension name.
                found = None
                for k, v in variables.items():
                    if v is None:
                        found = k
                if found is not None:
                    name = f"d{len(dim_names)}"
                    dim_names.add(name)
                    variables[found] = {name}
                    updates += 1
                else:
                    raise RuntimeError(  # pragma: no cover
                        f"Inconsistency in {self!r} with\n{variables!r}")

        # final
        results = {}
        for k, v in self.shapes.items():
            try:
                results[k] = v.resolve(variables)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    "Unable to resolve shapes and constraints:\n%s"
                    "" % pprint.pformat(self.shapes)) from e
        self.resolved_ = results
        return self.resolved_