Coverage for mlprodict/onnxrt/ops_shape/_element_wise.py: 100%
48 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Computes shape inference for element wise operators.
4"""
5import numpy
6from .shape_excs import ShapeInferenceException
7from .shape_result import ShapeResult, OnnxKind
def _element_wise(known_shapes, node, return_bool=False, same_type=True,
                  one_input=False):
    """
    Infers shape for an element wise operator.
    The function updates *known_shapes* with the result inferred for
    ``node.output[0]`` and returns the value of ``known_shapes.update``.

    :param known_shapes: known shapes
    :param node: Onnx node
    :param return_bool: force the output dtype to ``numpy.bool_``
    :param same_type: check the type are the same
    :param one_input: the operator accepts one or more inputs
        (variadic operators such as Max or Min) instead of exactly two;
        a single input is copied to the output, several inputs are
        broadcast together
    :return: updated or not
    :raises ShapeInferenceException: wrong number of inputs or
        an input which is not a tensor
    """
    n_inputs = len(node.input)
    if one_input:
        if n_inputs == 0:
            raise ShapeInferenceException(  # pragma: no cover
                f"Node {node.name!r} must have at least one input.")
        if n_inputs == 1:
            # Nothing to broadcast: the output mirrors the only input.
            x = known_shapes[node.input[0]]
            return known_shapes.update(node.output[0], x.copy())
    elif n_inputs != 2:
        raise ShapeInferenceException(  # pragma: no cover
            f"Node {node.name!r} must have two inputs not {n_inputs}.")

    shapes = [known_shapes[inp] for inp in node.input]
    for sh in shapes:
        if sh.mtype != OnnxKind.Tensor:
            raise ShapeInferenceException(  # pragma: no cover
                f"Result {sh!r} must be a tensor.")

    # Broadcasting is associative, so the inputs are folded pairwise from
    # the left; this covers every input of a variadic operator instead of
    # only the first two. The boolean dtype is applied on the last step
    # only so that intermediate results keep the input dtype and the
    # *same_type* check still compares input types.
    kwargs = dict(name=node.output[0], same_type=same_type)
    result = shapes[0]
    for other in shapes[1:-1]:
        result = ShapeResult.broadcast(result, other, **kwargs)
    if return_bool:
        kwargs['dtype'] = numpy.bool_
    result = ShapeResult.broadcast(result, shapes[-1], **kwargs)
    return known_shapes.update(node.output[0], result)
def shape_add(known_shapes, node):
    """
    Infers shape for operator Add.
    Exactly two inputs of the same type are broadcast together.
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=True,
        one_input=False)
def shape_and(known_shapes, node):
    """
    Infers shape for operator And.
    ONNX defines the output of And as a boolean tensor, so the output
    dtype is forced to ``numpy.bool_`` rather than inherited from the
    inputs (which matters when the input dtype is not known).
    """
    return _element_wise(known_shapes, node, return_bool=True)
def shape_div(known_shapes, node):
    """
    Infers shape for operator Div.
    Exactly two inputs of the same type are broadcast together.
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=True,
        one_input=False)
def shape_equal(known_shapes, node):
    """
    Infers shape for operator Equal.
    The two inputs are broadcast and the output dtype is ``numpy.bool_``.
    """
    return _element_wise(
        known_shapes, node, return_bool=True, same_type=True,
        one_input=False)
def shape_greater(known_shapes, node):
    """
    Infers shape for operator Greater.
    The two inputs are broadcast and the output dtype is ``numpy.bool_``.
    """
    return _element_wise(
        known_shapes, node, return_bool=True, same_type=True,
        one_input=False)
def shape_greaterorequal(known_shapes, node):
    """
    Infers shape for operator GreaterOrEqual.
    The two inputs are broadcast and the output dtype is ``numpy.bool_``.
    """
    return _element_wise(
        known_shapes, node, return_bool=True, same_type=True,
        one_input=False)
def shape_less(known_shapes, node):
    """
    Infers shape for operator Less.
    The two inputs are broadcast and the output dtype is ``numpy.bool_``.
    """
    return _element_wise(
        known_shapes, node, return_bool=True, same_type=True,
        one_input=False)
def shape_lessorequal(known_shapes, node):
    """
    Infers shape for operator LessOrEqual.
    The two inputs are broadcast and the output dtype is ``numpy.bool_``.
    """
    return _element_wise(
        known_shapes, node, return_bool=True, same_type=True,
        one_input=False)
def shape_max(known_shapes, node):
    """
    Infers shape for operator Max.
    ``one_input=True`` lets a node with a single input forward that
    input's shape to the output.
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=True,
        one_input=True)
def shape_min(known_shapes, node):
    """
    Infers shape for operator Min.
    ``one_input=True`` lets a node with a single input forward that
    input's shape to the output.
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=True,
        one_input=True)
def shape_mod(known_shapes, node):
    """
    Infers shape for operator Mod.
    Exactly two inputs of the same type are broadcast together.
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=True,
        one_input=False)
def shape_mul(known_shapes, node):
    """
    Infers shape for operator Mul.
    Exactly two inputs of the same type are broadcast together.
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=True,
        one_input=False)
def shape_or(known_shapes, node):
    """
    Infers shape for operator Or.
    ONNX defines the output of Or as a boolean tensor, so the output
    dtype is forced to ``numpy.bool_`` rather than inherited from the
    inputs (which matters when the input dtype is not known).
    """
    return _element_wise(known_shapes, node, return_bool=True)
def shape_pow(known_shapes, node):
    """
    Infers shape for operator Pow.
    The exponent may have a different type than the base, hence the
    type equality check is disabled (``same_type=False``).
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=False,
        one_input=False)
def shape_sub(known_shapes, node):
    """
    Infers shape for operator Sub.
    Exactly two inputs of the same type are broadcast together.
    """
    return _element_wise(
        known_shapes, node, return_bool=False, same_type=True,
        one_input=False)
def shape_xor(known_shapes, node):
    """
    Infers shape for operator Xor.
    ONNX defines the output of Xor as a boolean tensor, so the output
    dtype is forced to ``numpy.bool_`` rather than inherited from the
    inputs (which matters when the input dtype is not known).
    """
    return _element_wise(known_shapes, node, return_bool=True)