Coverage for mlprodict/onnxrt/ops_shape/shape_result.py: 93%
174 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Class ShapeResult
4"""
5from enum import Enum
6import numpy
7from .shape_excs import (
8 ShapeInferenceException, NotImplementedShapeInferenceError,
9 ShapeInferenceDimensionError)
class OnnxKind(Enum):
    """
    Enumerates the kinds of results an ONNX graph can hold.
    """
    # a plain (dense or sparse) tensor
    Tensor = 0
    # an ordered sequence of tensors
    Sequence = 1
    # a map from keys to values
    Map = 2
class ShapeConstraint:
    """
    One constraint restricting the possible values of a named
    shape variable.

    :param name: variable name, must not be ``'?'``
    :param values: set of possible values for the variable

    :raises ValueError: if *name* is ``'?'``
    :raises TypeError: if *values* is not a :class:`set`
    """

    def __init__(self, name, values):
        if name == '?':
            raise ValueError(  # pragma: no cover
                "Name cannot be '?'.")
        if not isinstance(values, set):
            raise TypeError(  # pragma: no cover
                f"values must be a set not {type(values)!r}.")
        self.name = name
        self.values = values

    def __eq__(self, other):
        "usual"
        # Returning NotImplemented (instead of reading other.name, which
        # raised AttributeError for foreign types) lets Python fall back
        # to its default comparison when *other* is not a ShapeConstraint.
        if not isinstance(other, ShapeConstraint):
            return NotImplemented
        if self.name != other.name:
            return False
        if self.values != other.values:
            return False
        return True

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}({self.name!r}, {self.values!r})"

    def merge(self, cst):
        """
        Merges this constraint with *cst* into this one:
        the set of values becomes the intersection of both sets.
        *cst* may be a single constraint or a list of constraints.
        """
        if isinstance(cst, list):
            for c in cst:
                self.merge(c)
            return
        self.values = self.values.intersection(cst.values)

    def copy(self, deep=False):
        """
        Makes a copy of the object.
        The set of values is always copied, so *deep* has no
        further effect here.
        """
        return ShapeConstraint(self.name, self.values.copy())
class ShapeConstraintList:
    """
    A list of ShapeConstraint.
    """

    def __init__(self):
        # ordered container of constraints
        self.csts = []

    def __contains__(self, cst):
        # membership is based on equality, not identity
        return any(cst == existing for existing in self.csts)

    def append(self, cst):
        "Appends a new constraint to the list."
        self.csts.append(cst)

    def __repr__(self):
        return f"ShapeConstraintList({self.csts!r})"

    def __iter__(self):
        yield from self.csts

    def __len__(self):
        return len(self.csts)

    def copy(self, deep=False):
        """
        Copies the object.

        :param deep: if True, every constraint is copied as well,
            otherwise only the enclosing list is duplicated
        """
        cp = ShapeConstraintList()
        cp.csts = ([item.copy(deep=deep) for item in self.csts]
                   if deep else self.csts.copy())
        return cp
class ShapeResult:
    """
    Contains information about shape and type of a result
    in an onnx graph.

    :param name: result name
    :param shape: shape if the result is a tensor, dimensions are
        integers or variable names; None means the shape is unknown
    :param dtype: element type if the result is a tensor
    :param sparse: is the tensor sparse
    :param mtype: kind of the result (see class @see cl OnnxKind)
    :param constraints: list of constraints applying on variables
    """

    def __init__(self, name, shape=None, dtype=None, sparse=False,
                 mtype=OnnxKind.Tensor, constraints=None):
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                f"name must be a string not {type(name)!r}.")
        if not isinstance(sparse, bool):
            raise TypeError(  # pragma: no cover
                f"sparse must be a boolean not {sparse!r}.")
        if not isinstance(mtype, OnnxKind):
            raise TypeError(  # pragma: no cover
                f"mtype must be of type OnnxKind not {type(mtype)!r}.")
        # The documented default shape=None used to crash in list(None);
        # an unknown shape is stored as an empty list, which is the
        # convention merge() already relies on.
        self.shape = [] if shape is None else list(shape)
        for i in range(0, len(self.shape)):  # pylint: disable=C0200
            if self.shape[i] in ('', None, '?'):
                raise ValueError(  # pragma: no cover
                    f"All dimensions must be an int or a variable name, "
                    f"{shape} is not.")
        self.name = name
        self.mtype = mtype
        self.dtype = dtype
        self.sparse = sparse
        if constraints is None:
            self.constraints = ShapeConstraintList()
        elif isinstance(constraints, ShapeConstraintList):
            self.constraints = constraints
        else:
            raise TypeError(  # pragma: no cover
                "constraints must be of type(ShapeConstraintList).")

    def is_compatible(self, shape):
        """
        Tells if this shape is compatible with the given tuple.

        :param shape: tuple, or numpy array whose *shape* attribute is used
        :return: boolean
        :raises NotImplementedError: if this result's shape contains
            symbolic (non integer) dimensions
        """
        if isinstance(shape, numpy.ndarray):
            shape = shape.shape
        if all(map(lambda x: isinstance(x, int), self.shape)):
            return tuple(self.shape) == tuple(shape)
        raise NotImplementedError(f"{self!r} ? {shape!r}")

    def copy(self, deep=False):
        """
        Returns a copy for the result.

        :param deep: also deep-copies the constraints
        :return: new @see cl ShapeResult
        """
        return ShapeResult(self.name, self.shape, self.dtype, self.sparse,
                           self.mtype, self.constraints.copy(deep=deep))

    def __repr__(self):
        """
        Usual
        """
        if len(self.constraints) > 0:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r, constraints=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype, self.constraints)
        if self.mtype != OnnxKind.Tensor:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype)
        if self.sparse:
            # NOTE(review): missing space after the comma in ',sparse=' is
            # kept as-is in case callers match the repr text.
            return "%s(%r, %r, %r,sparse=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse)
        return "%s(%r, %r, %r)" % (
            self.__class__.__name__, self.name, self.shape, self.dtype)

    def __eq__(self, shape):
        """
        Tells if two shapes are identical
        (same kind, shape, element type and sparsity).
        """
        return (self.mtype == shape.mtype and self.shape == shape.shape and
                self.dtype == shape.dtype and self.sparse == shape.sparse)

    def n_dims(self):
        """
        Returns the number of dimensions if it is a tensor.
        Raises an exception otherwise.

        :return: number of dimensions
        :raises ShapeInferenceException: if the result is not a tensor
        """
        if self.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                f"This shape is not a tensor {self!r}.")
        return len(self.shape)

    def merge(self, other_result):
        """
        Merges constraints from *other_result* into *self*.

        :param other_result: other @see cl ShapeResult
        :return: True if *self* was updated, False otherwise
        :raises ShapeInferenceDimensionError: if both shapes are known
            but have different ranks
        :raises RuntimeError: if the kinds differ or two known integer
            dimensions disagree
        """
        if self.mtype != other_result.mtype:
            raise RuntimeError(  # pragma: no cover
                f"Unable to merge {self!r} and {other_result!r}.")
        if (len(self.shape) != 0 and len(other_result.shape) != 0 and
                len(self.shape) != len(other_result.shape)):
            raise ShapeInferenceDimensionError(  # pragma: no cover
                f"Length mismatch, unable to merge {self!r} and {other_result!r}.")
        updated = False
        if other_result.constraints is not None:
            for c in other_result.constraints:
                if c not in self.constraints:
                    self.constraints.append(c)
                    updated = True

        if len(self.shape) == 0 and len(other_result.shape) > 0:
            # Then self.shape is unknown and the other one is known.
            self.shape = other_result.shape.copy()
            return True

        for a, b in zip(self.shape, other_result.shape):
            if a == b:
                continue
            if isinstance(a, int) and isinstance(b, int):
                # Two known integer dimensions disagree: unrecoverable.
                # NOTE(review): 'Inconsistancy' spelling kept as-is in case
                # callers match the message text.
                raise RuntimeError(
                    f"Inconsistancy between {self!r} and {other_result!r}.")
            elif isinstance(a, str):
                # variable a constrained by the other dimension
                c = ShapeConstraint(a, {b})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            elif isinstance(b, str):
                # variable b constrained by the other dimension
                c = ShapeConstraint(b, {a})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            else:
                raise NotImplementedError(  # pragma: no cover
                    f"Merge not implemented between {self!r} and {other_result!r}.")
        return updated

    def resolve(self, variables):
        """
        Resolves variables in a shape using values stored
        in *variables*. It does not copy any constraints.

        :param variables: dictionary `{ name: values }`, a value of None
            means the size is unknown and the variable is left as is
        :return: new ShapeResult
        :raises RuntimeError: if a variable of the shape is missing
            from *variables*
        """
        res = ShapeResult(self.name, shape=self.shape, dtype=self.dtype,
                          sparse=self.sparse, mtype=self.mtype)
        for i in range(len(res.shape)):  # pylint: disable=C0200
            v = res.shape[i]
            if isinstance(v, str):
                if v in variables:
                    vals = variables[v]
                    if vals is None:
                        # size unknown
                        continue
                    if len(vals) == 1:
                        # a single candidate value: the dimension is known
                        res.shape[i] = list(vals)[0]
                    else:
                        # several candidates: keep the whole set
                        res.shape[i] = set(vals)
                else:
                    raise RuntimeError(  # pragma: no cover
                        f"Unable to resolve shape {self!r} due to missing {v!r}.")
        return res

    @staticmethod
    def broadcast(sh1, sh2, name=None, dtype=None, same_type=True):
        """
        Broadcasts dimensions for an element wise operator.

        :param sh1: ShapeResult
        :param sh2: ShapeResult
        :param name: name of the output ShapeResult
        :param dtype: type of the result or the same as the first
            element if None
        :param same_type: check the type are the same
        :return: ShapeResult
        :raises ShapeInferenceException: if dtypes differ (with
            *same_type*) or two integer dimensions cannot broadcast
        :raises NotImplementedShapeInferenceError: for unsupported
            rank combinations
        """
        if not isinstance(sh1, ShapeResult):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for sh1 {type(sh1)!r}.")
        if not isinstance(sh2, ShapeResult):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for sh2 {type(sh2)!r}.")
        if sh1.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                f"sh1 must be a tensor not {sh1.mtype!r}.")
        if sh2.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                f"sh2 must be a tensor not {sh2.mtype!r}.")
        if same_type and sh1.dtype != sh2.dtype:
            if sh1.dtype is not None and sh2.dtype is not None:
                raise ShapeInferenceException(  # pragma: no cover
                    f"Cannot broadcast shapes {sh1!r} and {sh2!r} (dtypes).")

        # Specific cases: different ranks.
        if sh1.n_dims() != sh2.n_dims():
            if sh1.n_dims() == 1 and sh1.shape[0] == 1:
                # scalar-like sh1: the result follows sh2
                return ShapeResult(
                    name, sh2.shape, dtype or sh2.dtype, sh2.sparse, sh2.mtype)
            if sh2.n_dims() == 1 and sh2.shape[0] == 1:
                # scalar-like sh2: the result follows sh1
                return ShapeResult(
                    name, sh1.shape, dtype or sh1.dtype, sh1.sparse, sh1.mtype)
            if sh2.n_dims() < sh1.n_dims() and sh1.shape[-sh2.n_dims():] == sh2.shape:
                # sh2 matches the trailing dimensions of sh1
                return ShapeResult(
                    name, sh1.shape, dtype or sh1.dtype, sh1.sparse, sh1.mtype)
            raise NotImplementedShapeInferenceError(  # pragma: no cover
                "Broadcasting is only implemented for shape of the same "
                "size, shapes are %r and %r." % (sh1, sh2))

        # Other cases: same rank, dimensions compared pairwise.
        constraints = ShapeConstraintList()
        shape = []
        for a, b in zip(sh1.shape, sh2.shape):
            if isinstance(a, int) and isinstance(b, int):
                if a != b:
                    if min(a, b) == 1:
                        d = max(a, b)
                    else:
                        raise ShapeInferenceException(  # pragma: no cover
                            "Cannot broadcast shapes %r and %r (dimensions)."
                            "" % (sh1, sh2))
                else:
                    d = a
            elif isinstance(a, int):
                if a != 1:
                    # b must be 1 or equal to a to broadcast
                    d = a
                    constraints.append(ShapeConstraint(b, {1, a}))
                else:
                    d = b
            elif isinstance(b, int):
                if b != 1:
                    # a must be 1 or equal to b to broadcast
                    d = b
                    constraints.append(ShapeConstraint(a, {1, b}))
                else:
                    d = a
            elif a == b:
                d = a
            elif isinstance(a, str) and isinstance(b, str):
                if a != b:
                    # Both dimensions are variables.
                    constraints.append(ShapeConstraint(a, {1, b}))
                    constraints.append(ShapeConstraint(b, {1, a}))
                d = a
            else:
                raise ShapeInferenceException(  # pragma: no cover
                    f"Cannot broadcast shapes {sh1!r} and {sh2!r}.")
            shape.append(d)
        if name in (None, ''):
            raise ValueError(  # pragma: no cover
                "name cannot be empty.")
        res = ShapeResult(name, shape, dtype or sh1.dtype, sh1.sparse or sh2.sparse,
                          sh1.mtype, constraints)
        return res