Coverage for mlprodict/onnxrt/ops_shape/_element_unary.py: 99%
109 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Computes shape inference for element wise operators with one input.
4"""
5import numpy
6from .shape_excs import ShapeInferenceException
7from .shape_result import OnnxKind
def _element_unary(known_shapes, node, dtype=None, one_input=True):
    """
    Infers shape for an element wise operator.
    The function stores a copy of the first input's result under the
    name of the first output in *known_shapes* and returns whatever
    ``known_shapes.update`` returns.

    :param known_shapes: known shapes
    :param node: Onnx node
    :param dtype: None to keep the same type as input,
        not None to change it
    :param one_input: check there is only one input
    :return: updated or not
    :raises ShapeInferenceException: if *one_input* is True and the node
        does not have exactly one input, or if the first input is not
        a tensor
    """
    if one_input and len(node.input) != 1:
        raise ShapeInferenceException(  # pragma: no cover
            f"Node {node.name!r} must have one input not {len(node.input)}.")
    x = known_shapes[node.input[0]]
    if x.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            f"Result {x!r} must be a tensor.")
    # The output always starts as a copy of the input result; only the
    # element type is overwritten when the caller requests a new one.
    result = x.copy()
    if dtype is not None:
        result.dtype = dtype
    return known_shapes.update(node.output[0], result)
def shape_abs(known_shapes, node):
    """
    Infers shape for operator Abs.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_acos(known_shapes, node):
    """
    Infers shape for operator Acos.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_acosh(known_shapes, node):
    """
    Infers shape for operator Acosh.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_asin(known_shapes, node):
    """
    Infers shape for operator Asin.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_asinh(known_shapes, node):
    """
    Infers shape for operator Asinh.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_atan(known_shapes, node):
    """
    Infers shape for operator Atan.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_atanh(known_shapes, node):
    """
    Infers shape for operator Atanh.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_castlike(known_shapes, node):
    """
    Infers shape for operator CastLike.
    The output is a copy of the first input's result whose element
    type is replaced by the element type of the second input.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by ``known_shapes.update``
    :raises ShapeInferenceException: if the node has fewer than two
        inputs or if one of the first two inputs is not a tensor
    """
    # Fail with the module's own exception instead of a bare IndexError
    # when the second input (the type donor) is missing.
    if len(node.input) < 2:
        raise ShapeInferenceException(  # pragma: no cover
            f"Node {node.name!r} must have two inputs not {len(node.input)}.")
    x = known_shapes[node.input[0]]
    if x.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            f"Result {x!r} must be a tensor.")
    y = known_shapes[node.input[1]]
    if y.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            f"Result {y!r} must be a tensor.")
    cp = x.copy()
    cp.dtype = y.dtype
    return known_shapes.update(node.output[0], cp)
def shape_ceil(known_shapes, node):
    """
    Infers shape for operator Ceil.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_celu(known_shapes, node):
    """
    Infers shape for operator Celu.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_clip(known_shapes, node):
    """
    Infers shape for operator Clip.
    The check on the number of inputs is disabled; only the first
    input is used to build the output result.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node, one_input=False)
    return updated
def shape_cos(known_shapes, node):
    """
    Infers shape for operator Cos.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_cosh(known_shapes, node):
    """
    Infers shape for operator Cosh.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_elu(known_shapes, node):
    """
    Infers shape for operator Elu.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_erf(known_shapes, node):
    """
    Infers shape for operator Erf.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_exp(known_shapes, node):
    """
    Infers shape for operator Exp.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_floor(known_shapes, node):
    """
    Infers shape for operator Floor.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_hardmax(known_shapes, node):
    """
    Infers shape for operator Hardmax.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_hardsigmoid(known_shapes, node):
    """
    Infers shape for operator HardSigmoid.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_identity(known_shapes, node):
    """
    Infers shape for operator Identity.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_isnan(known_shapes, node):
    """
    Infers shape for operator IsNan.
    The output result is a copy of the input result with its
    element type set to ``numpy.bool_``.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node, dtype=numpy.bool_)
    return updated
def shape_isinf(known_shapes, node):
    """
    Infers shape for operator IsInf.
    The output result is a copy of the input result with its
    element type set to ``numpy.bool_``.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node, dtype=numpy.bool_)
    return updated
def shape_leakyrelu(known_shapes, node):
    """
    Infers shape for operator LeakyRelu.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_log(known_shapes, node):
    """
    Infers shape for operator Log.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_logsoftmax(known_shapes, node):
    """
    Infers shape for operator LogSoftmax.
    The inference is delegated to *shape_softmax*.

    :param known_shapes: known shapes
    :param node: Onnx node
    :return: value returned by *shape_softmax*
    """
    updated = shape_softmax(known_shapes, node)
    return updated
def shape_neg(known_shapes, node):
    """
    Infers shape for operator Neg.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_not(known_shapes, node):
    """
    Infers shape for operator Not.
    The element type of the first input is verified before the
    inference is delegated to *_element_unary*.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    :raises ShapeInferenceException: if the first input's element type
        is not ``numpy.bool_``
    """
    x = known_shapes[node.input[0]]
    if x.dtype != numpy.bool_:
        raise ShapeInferenceException(
            f"Unexpected input type for operator Not {x.dtype!r} (must be bool).")
    updated = _element_unary(known_shapes, node)
    return updated
def shape_reciprocal(known_shapes, node):
    """
    Infers shape for operator Reciprocal.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_relu(known_shapes, node):
    """
    Infers shape for operator Relu.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_round(known_shapes, node):
    """
    Infers shape for operator Round.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_selu(known_shapes, node):
    """
    Infers shape for operator Selu.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_shrink(known_shapes, node):
    """
    Infers shape for operator Shrink.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_sigmoid(known_shapes, node):
    """
    Infers shape for operator Sigmoid.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_sign(known_shapes, node):
    """
    Infers shape for operator Sign.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    return _element_unary(known_shapes, node)
def shape_sin(known_shapes, node):
    """
    Infers shape for operator Sin.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_sinh(known_shapes, node):
    """
    Infers shape for operator Sinh.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_softmax(known_shapes, node):
    """
    Infers shape for operator Softmax.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_softplus(known_shapes, node):
    """
    Infers shape for operator Softplus.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_softsign(known_shapes, node):
    """
    Infers shape for operator Softsign.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_sqrt(known_shapes, node):
    """
    Infers shape for operator Sqrt.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_tan(known_shapes, node):
    """
    Infers shape for operator Tan.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_tanh(known_shapes, node):
    """
    Infers shape for operator Tanh.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_thresholdedrelu(known_shapes, node):
    """
    Infers shape for operator ThresholdedRelu.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node)
    return updated
def shape_trilu(known_shapes, node):
    """
    Infers shape for operator Trilu.
    The check on the number of inputs is disabled; only the first
    input is used to build the output result.

    :param known_shapes: known shapes, updated for ``node.output[0]``
    :param node: Onnx node
    :return: value returned by *_element_unary*
    """
    updated = _element_unary(known_shapes, node, one_input=False)
    return updated