Coverage for mlprodict/onnxrt/ops_shape/shape_result.py: 93%

174 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Class ShapeResult 

4""" 

5from enum import Enum 

6import numpy 

7from .shape_excs import ( 

8 ShapeInferenceException, NotImplementedShapeInferenceError, 

9 ShapeInferenceDimensionError) 

10 

11 

class OnnxKind(Enum):
    """
    Enumerates the kinds of result a @see cl ShapeResult can describe.
    """
    # Values are explicit integers so that they stay stable.
    Tensor = 0  # a tensor (see ShapeResult.sparse for sparse tensors)
    Sequence = 1  # a sequence
    Map = 2  # a map

19 

20 

class ShapeConstraint:
    """
    One constraint: the variable *name* must take one of the
    values in *values*.

    :param name: variable name (cannot be ``'?'``)
    :param values: set of possible values
    :raises ValueError: if *name* is ``'?'``
    :raises TypeError: if *values* is not a set
    """

    def __init__(self, name, values):
        if name == '?':
            raise ValueError(  # pragma: no cover
                "Name cannot be '?'.")
        if not isinstance(values, set):
            raise TypeError(  # pragma: no cover
                f"values must be a set not {type(values)!r}.")
        self.name = name
        self.values = values

    def __eq__(self, other):
        """
        Two constraints are equal if they share the same name and the
        same set of values. Objects without ``name``/``values``
        attributes are not comparable and ``NotImplemented`` is returned
        (so ``cst == 5`` is False instead of raising AttributeError).
        """
        try:
            other_name = other.name
            other_values = other.values
        except AttributeError:
            # Let Python fall back on the reflected comparison / identity.
            return NotImplemented
        return self.name == other_name and self.values == other_values

    def __repr__(self):
        "usual"
        return f"{self.__class__.__name__}({self.name!r}, {self.values!r})"

    def merge(self, cst):
        """
        Merges this constraint with *cst* into this one: the set of
        possible values becomes the intersection of both sets.

        :param cst: a constraint or a list of constraints
            (merged one after another)
        """
        if isinstance(cst, list):
            for c in cst:
                self.merge(c)
            return
        self.values = self.values.intersection(cst.values)

    def copy(self, deep=False):
        """
        Makes a copy of the object. The set of values is always copied,
        *deep* is accepted for API consistency and has no effect.
        """
        return ShapeConstraint(self.name, self.values.copy())

66 

67 

class ShapeConstraintList:
    """
    Ordered container of ShapeConstraint objects.
    Membership is tested with ``==``; *append* does not filter duplicates.
    """

    def __init__(self):
        self.csts = []

    def __contains__(self, cst):
        # Equality based lookup, stops at the first match.
        return any(cst == known for known in self.csts)

    def append(self, cst):
        "Appends a new constraint to the list."
        self.csts.append(cst)

    def __repr__(self):
        return f"ShapeConstraintList({self.csts!r})"

    def __iter__(self):
        yield from self.csts

    def __len__(self):
        return len(self.csts)

    def copy(self, deep=False):
        """
        Returns a new ShapeConstraintList. With *deep*, every constraint
        is copied as well, otherwise the constraints are shared.
        """
        duplicate = ShapeConstraintList()
        duplicate.csts = (
            [item.copy(deep=deep) for item in self]
            if deep else list(self.csts))
        return duplicate

106 

107 

class ShapeResult:
    """
    Contains information about shape and type of a result
    in an onnx graph.

    :param name: result name
    :param shape: shape if the result is a tensor, a sequence of
        dimensions (integers or variable names); None or an empty
        sequence is stored as an empty list (shape not known)
    :param dtype: element type if the result is a tensor
    :param sparse: is the tensor sparse
    :param mtype: kind of the result (see class @see cl OnnxKind)
    :param constraints: list of constraints applying on variables,
        a @see cl ShapeConstraintList, a new empty one if None
    """

    def __init__(self, name, shape=None, dtype=None, sparse=False,
                 mtype=OnnxKind.Tensor, constraints=None):
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                f"name must be a string not {type(name)!r}.")
        if not isinstance(sparse, bool):
            raise TypeError(  # pragma: no cover
                f"sparse must be a boolean not {sparse!r}.")
        if not isinstance(mtype, OnnxKind):
            raise TypeError(  # pragma: no cover
                f"mtype must be of type OnnxKind not {type(mtype)!r}.")
        # shape=None is the default: list(None) would raise TypeError,
        # store an empty (unknown) shape instead.
        self.shape = [] if shape is None else list(shape)
        # Check the materialized list, *shape* may be a one-shot iterable.
        for dim in self.shape:
            if dim in ('', None, '?'):
                raise ValueError(  # pragma: no cover
                    f"All dimensions must be an int or a variable name, {shape} is not.")
        self.name = name
        self.mtype = mtype
        self.dtype = dtype
        self.sparse = sparse
        if constraints is None:
            self.constraints = ShapeConstraintList()
        elif isinstance(constraints, ShapeConstraintList):
            self.constraints = constraints
        else:
            raise TypeError(  # pragma: no cover
                "constraints must be of type(ShapeConstraintList).")

    def is_compatible(self, shape):
        """
        Tells if this shape is compatible with the given tuple.

        :param shape: tuple or numpy array (its shape is then used)
        :return: boolean
        :raises NotImplementedError: if this shape contains non integer
            dimensions
        """
        if isinstance(shape, numpy.ndarray):
            shape = shape.shape
        if all(isinstance(dim, int) for dim in self.shape):
            return tuple(self.shape) == tuple(shape)
        raise NotImplementedError(f"{self!r} ? {shape!r}")

    def copy(self, deep=False):
        """
        Returns a copy for the result. The shape list is always copied,
        constraints are copied with ``ShapeConstraintList.copy(deep=deep)``.
        """
        return ShapeResult(self.name, self.shape, self.dtype, self.sparse,
                           self.mtype, self.constraints.copy(deep=deep))

    def __repr__(self):
        """
        Usual
        """
        if len(self.constraints) > 0:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r, constraints=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype, self.constraints)
        if self.mtype != OnnxKind.Tensor:
            return "%s(%r, %r, %r, sparse=%r, mtype=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse, self.mtype)
        if self.sparse:
            return "%s(%r, %r, %r,sparse=%r)" % (
                self.__class__.__name__, self.name, self.shape, self.dtype,
                self.sparse)
        return "%s(%r, %r, %r)" % (
            self.__class__.__name__, self.name, self.shape, self.dtype)

    def __eq__(self, shape):
        """
        Tells if two shapes are identical: same kind, shape, element type
        and sparsity (names and constraints are ignored).
        Returns ``NotImplemented`` for objects missing these attributes
        instead of raising AttributeError.
        """
        try:
            return (self.mtype == shape.mtype and self.shape == shape.shape and
                    self.dtype == shape.dtype and self.sparse == shape.sparse)
        except AttributeError:
            return NotImplemented

    def n_dims(self):
        """
        Returns the number of dimensions if it is a tensor.
        Raises an exception otherwise.

        :raises ShapeInferenceException: if the result is not a tensor
        """
        if self.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                f"This shape is not a tensor {self!r}.")
        return len(self.shape)

    def merge(self, other_result):
        """
        Merges constraints from *other_result* into *self*.
        Constraints of *other_result* missing in *self* are appended,
        an unknown (empty) shape takes the other shape, and every pair of
        dimensions where one side is a variable adds a constraint.

        :param other_result: ShapeResult
        :return: True if *self* was modified
        :raises RuntimeError: if kinds differ or two integer dimensions
            differ
        :raises ShapeInferenceDimensionError: if both shapes are known
            and have different lengths
        """
        if self.mtype != other_result.mtype:
            raise RuntimeError(  # pragma: no cover
                f"Unable to merge {self!r} and {other_result!r}.")
        if (len(self.shape) != 0 and len(other_result.shape) != 0 and
                len(self.shape) != len(other_result.shape)):
            raise ShapeInferenceDimensionError(  # pragma: no cover
                f"Length mismatch, unable to merge {self!r} and {other_result!r}.")
        updated = False
        if other_result.constraints is not None:
            for c in other_result.constraints:
                if c not in self.constraints:
                    self.constraints.append(c)
                    updated = True

        if len(self.shape) == 0 and len(other_result.shape) > 0:
            # Then self.shape is unknown and the other one is known.
            self.shape = other_result.shape.copy()
            return True

        for a, b in zip(self.shape, other_result.shape):
            if a == b:
                continue
            if isinstance(a, int) and isinstance(b, int):
                raise RuntimeError(
                    f"Inconsistency between {self!r} and {other_result!r}.")
            elif isinstance(a, str):
                c = ShapeConstraint(a, {b})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            elif isinstance(b, str):
                c = ShapeConstraint(b, {a})
                if c not in self.constraints:
                    updated = True
                    self.constraints.append(c)
            else:
                raise NotImplementedError(  # pragma: no cover
                    f"Merge not implemented between {self!r} and {other_result!r}.")
        return updated

    def resolve(self, variables):
        """
        Resolves variables in a shape using values stored
        in *variables*. It does not copy any constraints.
        A variable mapped to None is left unchanged, a variable with a
        single value is replaced by that value, otherwise by the set
        of its values.

        :param variables: dictionary `{ name: values }`
        :return: new ShapeResult
        :raises RuntimeError: if a variable of the shape is missing
            from *variables*
        """
        res = ShapeResult(self.name, shape=self.shape, dtype=self.dtype,
                          sparse=self.sparse, mtype=self.mtype)
        # res.shape is a new list (see __init__), self.shape is untouched.
        for i, v in enumerate(res.shape):
            if not isinstance(v, str):
                continue
            if v not in variables:
                raise RuntimeError(  # pragma: no cover
                    f"Unable to resolve shape {self!r} due to missing {v!r}.")
            vals = variables[v]
            if vals is None:
                # size unknown
                continue
            if len(vals) == 1:
                res.shape[i] = next(iter(vals))
            else:
                res.shape[i] = set(vals)
        return res

    @staticmethod
    def broadcast(sh1, sh2, name=None, dtype=None, same_type=True):
        """
        Broadcasts dimensions for an element wise operator.

        :param sh1: ShapeResult
        :param sh2: ShapeResult
        :param name: name of the output ShapeResult
        :param dtype: type of the result or the same as the first
            element if None
        :param same_type: check the type are the same
        :return: ShapeResult
        :raises ShapeInferenceException: if dtypes or dimensions
            cannot be broadcast
        :raises NotImplementedShapeInferenceError: for shapes of
            different ranks not handled by the specific cases
        """
        if not isinstance(sh1, ShapeResult):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for sh1 {type(sh1)!r}.")
        if not isinstance(sh2, ShapeResult):
            raise TypeError(  # pragma: no cover
                f"Unexpected type for sh2 {type(sh2)!r}.")
        if sh1.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                f"sh1 must be a tensor not {sh1.mtype!r}.")
        if sh2.mtype != OnnxKind.Tensor:
            raise TypeError(  # pragma: no cover
                f"sh2 must be a tensor not {sh2.mtype!r}.")
        if same_type and sh1.dtype != sh2.dtype:
            if sh1.dtype is not None and sh2.dtype is not None:
                raise ShapeInferenceException(  # pragma: no cover
                    f"Cannot broadcast shapes {sh1!r} and {sh2!r} (dtypes).")

        # Specific cases.
        n1, n2 = sh1.n_dims(), sh2.n_dims()
        if n1 != n2:
            if n1 == 1 and sh1.shape[0] == 1:
                return ShapeResult(
                    name, sh2.shape, dtype or sh2.dtype, sh2.sparse, sh2.mtype)
            if n2 == 1 and sh2.shape[0] == 1:
                return ShapeResult(
                    name, sh1.shape, dtype or sh1.dtype, sh1.sparse, sh1.mtype)
            # Trailing dimensions match: the shorter shape is broadcast.
            # The ``0 <`` guard avoids the ``shape[-0:]`` (whole list) pitfall.
            if 0 < n2 < n1 and sh1.shape[-n2:] == sh2.shape:
                return ShapeResult(
                    name, sh1.shape, dtype or sh1.dtype, sh1.sparse, sh1.mtype)
            # Mirror case: broadcasting is symmetric.
            if 0 < n1 < n2 and sh2.shape[-n1:] == sh1.shape:
                return ShapeResult(
                    name, sh2.shape, dtype or sh2.dtype, sh2.sparse, sh2.mtype)
            raise NotImplementedShapeInferenceError(  # pragma: no cover
                "Broadcasting is only implemented for shape of the same "
                "size, shapes are %r and %r." % (sh1, sh2))

        # Other cases.
        constraints = ShapeConstraintList()
        shape = []
        for a, b in zip(sh1.shape, sh2.shape):
            if isinstance(a, int) and isinstance(b, int):
                if a != b:
                    if min(a, b) == 1:
                        d = max(a, b)
                    else:
                        raise ShapeInferenceException(  # pragma: no cover
                            "Cannot broadcast shapes %r and %r (dimensions)."
                            "" % (sh1, sh2))
                else:
                    d = a
            elif isinstance(a, int):
                if a != 1:
                    d = a
                    # The variable b must be 1 or equal to a.
                    constraints.append(ShapeConstraint(b, {1, a}))
                else:
                    d = b
            elif isinstance(b, int):
                if b != 1:
                    d = b
                    # The variable a must be 1 or equal to b.
                    constraints.append(ShapeConstraint(a, {1, b}))
                else:
                    d = a
            elif a == b:
                d = a
            elif isinstance(a, str) and isinstance(b, str):
                # Both dimensions are (distinct) variables.
                constraints.append(ShapeConstraint(a, {1, b}))
                constraints.append(ShapeConstraint(b, {1, a}))
                d = a
            else:
                raise ShapeInferenceException(  # pragma: no cover
                    f"Cannot broadcast shapes {sh1!r} and {sh2!r}.")
            shape.append(d)
        if name in (None, ''):
            raise ValueError(  # pragma: no cover
                "name cannot be empty.")
        res = ShapeResult(name, shape, dtype or sh1.dtype, sh1.sparse or sh2.sparse,
                          sh1.mtype, constraints)
        return res