Coverage for mlprodict/onnxrt/ops_shape/_element_unary.py: 99%

109 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Computes shape inference for element wise operators with one input. 

4""" 

5import numpy 

6from .shape_excs import ShapeInferenceException 

7from .shape_result import OnnxKind 

8 

9 

def _element_unary(known_shapes, node, dtype=None, one_input=True):
    """
    Infers shape for an element wise operator. The output result is a
    copy of the result of the node's first input, optionally with a
    different dtype. *known_shapes* is modified through its ``update``
    method.

    :param known_shapes: known shapes
    :param node: Onnx node
    :param dtype: None to keep the same type as input,
        not None to change it
    :param one_input: check there is only one input
    :return: updated or not (the value returned by ``known_shapes.update``)
    :raises ShapeInferenceException: if *one_input* is True and the node
        does not have exactly one input, or if the first input is not
        a tensor
    """
    if one_input:
        count = len(node.input)
        if count != 1:
            raise ShapeInferenceException(  # pragma: no cover
                f"Node {node.name!r} must have one input not {count}.")
    source = known_shapes[node.input[0]]
    if source.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            f"Result {source!r} must be a tensor.")
    result = source.copy()
    # Only override the element type when the caller asked for it.
    if dtype is not None:
        result.dtype = dtype
    return known_shapes.update(node.output[0], result)

34 

35 

def shape_abs(known_shapes, node):
    """
    Infers shape for operator Abs.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

39 

40 

def shape_acos(known_shapes, node):
    """
    Infers shape for operator Acos.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

44 

45 

def shape_acosh(known_shapes, node):
    """
    Infers shape for operator Acosh.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

49 

50 

def shape_asin(known_shapes, node):
    """
    Infers shape for operator Asin.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

54 

55 

def shape_asinh(known_shapes, node):
    """
    Infers shape for operator Asinh.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

59 

60 

def shape_atan(known_shapes, node):
    """
    Infers shape for operator Atan.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

64 

65 

def shape_atanh(known_shapes, node):
    """
    Infers shape for operator Atanh.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

69 

70 

def shape_castlike(known_shapes, node):
    """
    Infers shape for operator CastLike.

    The output result is a copy of the first input result (same shape)
    whose dtype is taken from the second input result.

    :param known_shapes: known shapes
    :param node: Onnx node, expected to have exactly two inputs
        (the tensor to cast, then the tensor providing the target type)
    :return: updated or not (the value returned by ``known_shapes.update``)
    :raises ShapeInferenceException: if the node does not have exactly
        two inputs or if one of them is not a tensor
    """
    # Validate the input count up front, like _element_unary does, so a
    # malformed node raises a clear error instead of an IndexError.
    if len(node.input) != 2:
        raise ShapeInferenceException(  # pragma: no cover
            f"Node {node.name!r} must have two inputs not {len(node.input)}.")
    x = known_shapes[node.input[0]]
    if x.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            f"Result {x!r} must be a tensor.")
    y = known_shapes[node.input[1]]
    if y.mtype != OnnxKind.Tensor:
        raise ShapeInferenceException(  # pragma: no cover
            f"Result {y!r} must be a tensor.")
    cp = x.copy()
    cp.dtype = y.dtype
    return known_shapes.update(node.output[0], cp)

84 

85 

def shape_ceil(known_shapes, node):
    """
    Infers shape for operator Ceil.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

89 

90 

def shape_celu(known_shapes, node):
    """
    Infers shape for operator Celu.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

94 

95 

def shape_clip(known_shapes, node):
    """
    Infers shape for operator Clip.

    The output result is a copy of the first input result
    (same shape, same dtype). Additional inputs (presumably the
    optional min/max bounds) are allowed and not inspected.
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=False)

99 

100 

def shape_cos(known_shapes, node):
    """
    Infers shape for operator Cos.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

104 

105 

def shape_cosh(known_shapes, node):
    """
    Infers shape for operator Cosh.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

109 

110 

def shape_elu(known_shapes, node):
    """
    Infers shape for operator Elu.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

114 

115 

def shape_erf(known_shapes, node):
    """
    Infers shape for operator Erf.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

119 

120 

def shape_exp(known_shapes, node):
    """
    Infers shape for operator Exp.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

124 

125 

def shape_floor(known_shapes, node):
    """
    Infers shape for operator Floor.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

129 

130 

def shape_hardmax(known_shapes, node):
    """
    Infers shape for operator Hardmax.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

134 

135 

def shape_hardsigmoid(known_shapes, node):
    """
    Infers shape for operator HardSigmoid.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

139 

140 

def shape_identity(known_shapes, node):
    """
    Infers shape for operator Identity.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

144 

145 

def shape_isnan(known_shapes, node):
    """
    Infers shape for operator IsNan.

    The output result is a copy of the input result whose
    dtype is replaced by ``numpy.bool_``.
    """
    return _element_unary(known_shapes, node, dtype=numpy.bool_)

149 

150 

def shape_isinf(known_shapes, node):
    """
    Infers shape for operator IsInf.

    The output result is a copy of the input result whose
    dtype is replaced by ``numpy.bool_``.
    """
    return _element_unary(known_shapes, node, dtype=numpy.bool_)

154 

155 

def shape_leakyrelu(known_shapes, node):
    """
    Infers shape for operator LeakyRelu.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

159 

160 

def shape_log(known_shapes, node):
    """
    Infers shape for operator Log.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

164 

165 

def shape_logsoftmax(known_shapes, node):
    """
    Infers shape for operator LogSoftmax.

    Shape inference is delegated to :func:`shape_softmax` so both
    operators always stay in sync.
    """
    return shape_softmax(known_shapes=known_shapes, node=node)

169 

170 

def shape_neg(known_shapes, node):
    """
    Infers shape for operator Neg.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

174 

175 

def shape_not(known_shapes, node):
    """
    Infers shape for operator Not.

    The input must already be known as a boolean result; the output
    result is then a copy of it (same shape, same dtype).

    :raises ShapeInferenceException: if the input dtype is not
        ``numpy.bool_``
    """
    first_name = node.input[0]
    input_dtype = known_shapes[first_name].dtype
    if input_dtype != numpy.bool_:
        raise ShapeInferenceException(
            f"Unexpected input type for operator Not {input_dtype!r} (must be bool).")
    return _element_unary(known_shapes, node)

183 

184 

def shape_reciprocal(known_shapes, node):
    """
    Infers shape for operator Reciprocal.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

188 

189 

def shape_relu(known_shapes, node):
    """
    Infers shape for operator Relu.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

193 

194 

def shape_round(known_shapes, node):
    """
    Infers shape for operator Round.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

198 

199 

def shape_selu(known_shapes, node):
    """
    Infers shape for operator Selu.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

203 

204 

def shape_shrink(known_shapes, node):
    """
    Infers shape for operator Shrink.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

208 

209 

def shape_sigmoid(known_shapes, node):
    """
    Infers shape for operator Sigmoid.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

213 

214 

def shape_sign(known_shapes, node):
    """
    Infers shape for operator Sign.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node)

218 

219 

def shape_sin(known_shapes, node):
    """
    Infers shape for operator Sin.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

223 

224 

def shape_sinh(known_shapes, node):
    """
    Infers shape for operator Sinh.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

228 

229 

def shape_softmax(known_shapes, node):
    """
    Infers shape for operator Softmax.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

233 

234 

def shape_softplus(known_shapes, node):
    """
    Infers shape for operator Softplus.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

238 

239 

def shape_softsign(known_shapes, node):
    """
    Infers shape for operator Softsign.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

243 

244 

def shape_sqrt(known_shapes, node):
    """
    Infers shape for operator Sqrt.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

248 

249 

def shape_tan(known_shapes, node):
    """
    Infers shape for operator Tan.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

253 

254 

def shape_tanh(known_shapes, node):
    """
    Infers shape for operator Tanh.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

258 

259 

def shape_thresholdedrelu(known_shapes, node):
    """
    Infers shape for operator ThresholdedRelu.

    The output result is a copy of the input result
    (same shape, same dtype).
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=True)

263 

264 

def shape_trilu(known_shapes, node):
    """
    Infers shape for operator Trilu.

    The output result is a copy of the first input result
    (same shape, same dtype). Additional inputs (presumably the
    optional diagonal offset ``k``) are allowed and not inspected.
    """
    return _element_unary(known_shapes, node, dtype=None, one_input=False)