Coverage for mlprodict/onnxrt/ops_shape/_op_shape_op.py: 90%

20 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Computes shape inference for onnx operators. 

4""" 

5from .shape_excs import ShapeInferenceException, ShapeInferenceDimensionError 

6from .shape_result import ( 

7 ShapeResult, OnnxKind, ShapeConstraintList, ShapeConstraint) 

8 

9 

def shape_det(known_shapes, node):
    """
    Infers shape for operator Det.

    The input must be a tensor whose last two dimensions form square
    matrices. The output keeps the leading dimensions and drops the last
    two, so a 2-D input produces a result with an empty shape.

    :param known_shapes: known shapes, indexable by result name and exposing
        ``update(name, result)``
    :param node: ONNX node; ``node.input[0]`` is the input and
        ``node.output[0]`` the output
    :return: the value returned by ``known_shapes.update(name, result)``
    :raises ShapeInferenceException: if the input is not a tensor, has
        exactly one dimension, has last two dimensions that are known
        integers but differ, or has last two dimensions of an unexpected type
    :raises ShapeInferenceDimensionError: if ``x.n_dims()`` is not positive
        (presumably a scalar or an unknown rank — confirm against
        ``ShapeResult.n_dims``)
    """
    x = known_shapes[node.input[0]]
    if x.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            f"Result {x!r} must be a tensor.")
    if x.n_dims() < 2:
        if x.n_dims() > 0:
            raise ShapeInferenceException(  # pragma: no cover
                f"Operator Det requires at least two dimensions not {x.n_dims()!r}.")
        # A rank of zero is reported with a distinct exception type,
        # presumably so callers can treat it differently from a rank of one.
        raise ShapeInferenceDimensionError(  # pragma: no cover
            f"Operator Det requires at least two dimensions not {x.n_dims()!r}.")
    name = node.output[0]

    constraints = ShapeConstraintList()
    a, b = x.shape[-2:]
    if isinstance(a, int) and isinstance(b, int):
        if a != b:
            # Report the offending dimensions, not the rank.
            raise ShapeInferenceException(  # pragma: no cover
                f"Operator Det only applies on square matrices not "
                f"{a!r}x{b!r} (shape={x.shape!r}).")
    elif isinstance(a, str):
        # Symbolic dimension: record that it must be equal to the other one.
        constraints.append(ShapeConstraint(a, {b}))
    elif isinstance(b, str):
        constraints.append(ShapeConstraint(b, {a}))
    else:
        raise ShapeInferenceException(  # pragma: no cover
            f"Unexpected case for operator Det ({x!r}).")

    # The output drops the two matrix dimensions.
    out_shape = [] if x.n_dims() == 2 else x.shape[:-2]
    r = ShapeResult(name, out_shape, x.dtype, False,
                    x.mtype, constraints)
    return known_shapes.update(name, r)