Coverage for mlprodict/onnxrt/ops_shape/_element_wise.py: 100%

48 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Computes shape inference for element wise operators. 

4""" 

5import numpy 

6from .shape_excs import ShapeInferenceException 

7from .shape_result import ShapeResult, OnnxKind 

8 

9 

def _element_wise(known_shapes, node, return_bool=False, same_type=True,
                  one_input=False):
    """
    Infers the output shape of an element wise operator.

    The inferred result is registered in *known_shapes* under the name of
    the first output of the node, and the value returned by
    ``known_shapes.update`` is handed back to the caller.

    :param known_shapes: known shapes, indexed by result name
    :param node: Onnx node
    :param return_bool: if True, the broadcast result is requested with
        dtype ``numpy.bool_``
    :param same_type: forwarded as is to ``ShapeResult.broadcast``
    :param one_input: if True, a node with a single input is accepted and
        its output receives a copy of the shape of that input; in that
        mode the number of inputs is not validated any further
    :return: updated or not (value returned by ``known_shapes.update``)
    :raises ShapeInferenceException: if *one_input* is False and the node
        does not have exactly two inputs, or if one of the first two
        inputs is not a tensor
    """
    n_inputs = len(node.input)

    if one_input and n_inputs == 1:
        # Single operand: the output reuses a copy of the input shape.
        only = known_shapes[node.input[0]]
        return known_shapes.update(node.output[0], only.copy())

    if not one_input and n_inputs != 2:
        raise ShapeInferenceException(  # pragma: no cover
            f"Node {node.name!r} must have two inputs not {n_inputs}.")

    # NOTE(review): only the first two inputs are read below; when
    # one_input is True and the node has more than two inputs, the
    # remaining ones do not take part in the inference.
    left = known_shapes[node.input[0]]
    right = known_shapes[node.input[1]]
    for operand in (left, right):
        if operand.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                f"Result {operand!r} must be a tensor.")

    output_name = node.output[0]
    # Keyword arguments are assembled in the order name, dtype, same_type;
    # dtype is only supplied when a boolean result is requested.
    options = {'name': output_name}
    if return_bool:
        options['dtype'] = numpy.bool_
    options['same_type'] = same_type
    return known_shapes.update(
        output_name, ShapeResult.broadcast(left, right, **options))

48 

49 

def shape_add(known_shapes, node):
    """
    Infers shape for operator Add.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)

53 

54 

def shape_and(known_shapes, node):
    """
    Infers shape for operator And.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)

58 

59 

def shape_div(known_shapes, node):
    """
    Infers shape for operator Div.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)

63 

64 

def shape_equal(known_shapes, node):
    """
    Infers shape for operator Equal.

    Delegates to :func:`_element_wise` with ``return_bool=True``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, return_bool=True)

68 

69 

def shape_greater(known_shapes, node):
    """
    Infers shape for operator Greater.

    Delegates to :func:`_element_wise` with ``return_bool=True``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, return_bool=True)

73 

74 

def shape_greaterorequal(known_shapes, node):
    """
    Infers shape for operator GreaterOrEqual.

    Delegates to :func:`_element_wise` with ``return_bool=True``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, return_bool=True)

78 

79 

def shape_less(known_shapes, node):
    """
    Infers shape for operator Less.

    Delegates to :func:`_element_wise` with ``return_bool=True``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, return_bool=True)

83 

84 

def shape_lessorequal(known_shapes, node):
    """
    Infers shape for operator LessOrEqual.

    Delegates to :func:`_element_wise` with ``return_bool=True``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, return_bool=True)

88 

89 

def shape_max(known_shapes, node):
    """
    Infers shape for operator Max.

    Delegates to :func:`_element_wise` with ``one_input=True``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, one_input=True)

93 

94 

def shape_min(known_shapes, node):
    """
    Infers shape for operator Min.

    Delegates to :func:`_element_wise` with ``one_input=True``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, one_input=True)

98 

99 

def shape_mod(known_shapes, node):
    """
    Infers shape for operator Mod.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)

103 

104 

def shape_mul(known_shapes, node):
    """
    Infers shape for operator Mul.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)

108 

109 

def shape_or(known_shapes, node):
    """
    Infers shape for operator Or.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)

113 

114 

def shape_pow(known_shapes, node):
    """
    Infers shape for operator Pow.

    Delegates to :func:`_element_wise` with ``same_type=False``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(
        known_shapes=known_shapes, node=node, same_type=False)

118 

119 

def shape_sub(known_shapes, node):
    """
    Infers shape for operator Sub.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)

123 

124 

def shape_xor(known_shapes, node):
    """
    Infers shape for operator Xor.

    Delegates to :func:`_element_wise` with its default options.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by :func:`_element_wise`
    """
    return _element_wise(known_shapes=known_shapes, node=node)