Coverage for mlprodict/onnxrt/ops_shape/shape_container.py: 93%
156 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Class ShapeContainer
4"""
5import pprint
6from .shape_result import ShapeResult
class ShapeContainer:
    """
    Stores all inferred shapes as @see cl ShapeResult.

    Attributes:

    * `shapes`: dictionary `{ result name: ShapeResult }`
    * `names`: some dimensions are unknown and represented as
      variables, this dictionary keeps track of them
    * `names_rev`: reverse dictionary of `names`
    """

    def __init__(self):
        self.shapes = {}
        self.names = {}
        self.names_rev = {}

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}()"

    def __len__(self):
        "Returns the number of stored shapes."
        return len(self.shapes)

    def __getitem__(self, key):
        "Retrieves one shape from its name."
        return self.shapes[key]

    def copy(self, deep=False):
        """
        Makes a copy.

        :param deep: forwarded to the `copy` method of every stored shape
        :return: a new container, dictionaries `names` and `names_rev`
            are duplicated (the lists held by `names_rev` as well)
        """
        cont = ShapeContainer()
        cont.shapes = {k: v.copy(deep=deep) for k, v in self.shapes.items()}
        cont.names = self.names.copy()
        cont.names_rev = {k: v.copy() for k, v in self.names_rev.items()}
        return cont

    def update(self, key, value):
        """
        Updates one shape. Returns True if the shape was different.

        :param key: result name, a string
        :param value: the new shape, an instance of @see cl ShapeResult
        :return: True if *key* was unknown, otherwise whatever
            method `merge` of the existing shape returns
        :raises TypeError: if *key* is not a string or *value*
            is not a @see cl ShapeResult
        """
        if not isinstance(key, str):
            raise TypeError(  # pragma: no cover
                f"key must be a string not {type(key)!r}.")
        if not isinstance(value, ShapeResult):
            # The message reports the type of the offending value,
            # not the type of the key.
            raise TypeError(  # pragma: no cover
                f"value must be a ShapeResult not {type(value)!r}.")
        if key not in self.shapes:
            self.shapes[key] = value
            return True
        return self.shapes[key].merge(value)

    def __contains__(self, key):
        "Operator in."
        return key in self.shapes

    def __str__(self):
        """
        Displays the shapes, the names and, if any, the constraints,
        one entry per line.
        """
        rows = ["ShapeContainer({"]
        rows.extend(f" {k!r}: {v!r}" for k, v in self.shapes.items())
        rows.append("}, names={")
        rows.extend(f" {k!r}: {v!r}" for k, v in self.names.items())
        cst = self.get_all_constraints()
        if cst:
            rows.append("}, constraint={")
            rows.extend(f" {c!r}: {v!r}" for c, v in cst.items())
        rows.append("})")
        return "\n".join(rows)

    def get_new_name(self, name, result_name, dim):
        """
        Returns a variable name when a dimension is not
        specified.

        :param name: suggested name, None is replaced by an empty string
        :param result_name: stored in `names` next to the new name
        :param dim: stored in `names` next to the new name
        :return: a variable name
        :raises TypeError: if *name* is neither None nor a string
        :raises RuntimeError: if *name* maps to more than one variable
        """
        if name is not None and not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                f"name must be string not {name!r}.")
        if name is None:
            name = ''
        if name == '' or name not in self.names:
            # Looks for the first suffix which is not already taken.
            i = 0
            new_name = f"{name}_{i}"
            while new_name in self.names:
                i += 1
                new_name = f"{name}_{i}"
            self.names[new_name] = (name, result_name, dim)
            self.names_rev.setdefault(name, []).append(new_name)
            return new_name
        # NOTE(review): `names` is indexed by the generated names whereas
        # `names_rev` is indexed by the names given by the caller,
        # this lookup assumes *name* belongs to both - to be confirmed.
        val = self.names_rev[name]
        if len(val) != 1:
            raise RuntimeError(  # pragma: no cover
                f"Name {name!r} has more than one correspondance ({val!r}).")
        return val[0]

    def get_all_constraints(self):
        """
        Gathers all constraints.

        :return: dictionary `{ constraint name: [constraint] }`,
            constraints sharing the same name are merged into the
            first one met (method `merge` is called on it)
        """
        cons = {}
        for v in self.shapes.values():
            if v.constraints is not None:
                for c in v.constraints:
                    cons.setdefault(c.name, []).append(c)
        for v in cons.values():
            if len(v) > 1:
                v[0].merge(v[1:])
                del v[1:]
        return cons

    def get(self):
        """
        Returns the value of attribute `resolved_`
        (method `resolve()` must have been called first).

        :raises AttributeError: if `resolved_` is missing or None
        """
        if not hasattr(self, 'resolved_') or self.resolved_ is None:
            raise AttributeError(  # pragma: no cover
                "Attribute 'resolved_' is missing. You must run "
                "method 'resolve()'.")
        return self.resolved_

    def resolve(self):
        """
        Resolves all constraints. It adds the attribute
        `resolved_`.

        :return: dictionary `{ result name: resolved shape }`,
            every value is produced by method `resolve` of the
            corresponding shape
        :raises RuntimeError: if a constraint refers to an unknown
            variable, if the possible values of a variable become
            empty or if a shape cannot be resolved
        """
        def vars_in_values(values):
            "Splits values into a set of non-strings and a list of strings."
            i_vals, s_vals = [], []
            for v in values:
                if isinstance(v, str):
                    s_vals.append(v)
                else:
                    i_vals.append(v)
            return set(i_vals), s_vals

        # Every dimension stored as a string is a variable to resolve.
        variables = {}
        for v in self.shapes.values():
            for sh in v.shape:
                if isinstance(sh, str):
                    variables[sh] = None

        # first step: resolves all constraint with integer
        csts = []
        for li in self.get_all_constraints().values():
            csts.extend(li)
        new_csts = []
        for cst in csts:
            if cst.name in variables and variables[cst.name] is None:
                if all(isinstance(n, int) for n in cst.values):
                    variables[cst.name] = cst.values.copy()
                else:
                    new_csts.append(cst)
            else:
                known = ", ".join(sorted(variables))
                raise RuntimeError(  # pragma: no cover
                    f"Unable to find any correspondance for variable "
                    f"{cst.name!r} in {known!r}.")

        # second step: everything else, like a logic algorithm
        dim_names = set()
        csts = new_csts
        updates = 1
        while updates > 0 and new_csts:
            updates = 0
            new_csts = []
            for cst in csts:
                rvalues = variables[cst.name]
                ivalues, lvars = vars_in_values(cst.values)

                if lvars:
                    # Replaces every known variable by its possible values.
                    miss = 0
                    for lv in lvars:
                        if lv in variables and variables[lv] is not None:
                            ivalues |= variables[lv]
                        else:
                            miss += 1

                    if miss == 0:
                        # simple case: only integers
                        if rvalues is None:
                            inter = ivalues
                        else:
                            inter = rvalues.intersection(ivalues)
                        if not inter:
                            raise RuntimeError(  # pragma: no cover
                                f"Resolution failed for variable "
                                f"{cst.name!r}, current possibilities "
                                f"{rvalues!r} does not match "
                                f"constraint {cst!r}.")
                        if rvalues is None or len(inter) < len(rvalues):
                            variables[cst.name] = inter
                            updates += 1
                        else:
                            # The constraint brings nothing new,
                            # it is dropped.
                            continue
                    elif dim_names:
                        # more complex case: variables
                        if len(cst.values) == 1 and len(lvars) == 1:
                            # exact mapping between cst.name and lvars[0]
                            a, b = cst.name, lvars[0]
                            if (variables[a] is None and
                                    variables[b] is not None):
                                if variables[b].intersection(dim_names):
                                    variables[a] = variables[b]
                                    updates += 1
                                    continue
                            elif (variables[b] is None and
                                    variables[a] is not None):
                                if variables[a].intersection(dim_names):
                                    variables[b] = variables[a]
                                    updates += 1
                                    continue

                # The constraint is kept for the next iteration.
                new_csts.append(cst)
            csts = new_csts

            if new_csts and updates == 0:
                # It means that a dimension needs to be left unknown.
                # The last unresolved variable met is the one chosen.
                found = None
                for k, v in variables.items():
                    if v is None:
                        found = k
                if found is not None:
                    name = f"d{len(dim_names)}"
                    dim_names.add(name)
                    variables[found] = {name}
                    updates += 1
                else:
                    raise RuntimeError(  # pragma: no cover
                        f"Inconsistency in {self!r} with\n{variables!r}")

        # final
        results = {}
        for k, v in self.shapes.items():
            try:
                results[k] = v.resolve(variables)
            except RuntimeError as e:  # pragma: no cover
                raise RuntimeError(
                    f"Unable to resolve shapes and constraints:\n"
                    f"{pprint.pformat(self.shapes)}") from e
        self.resolved_ = results
        return self.resolved_