Coverage for mlprodict/npy/xop_opset.py: 91%

97 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# pylint: disable=E0602 

2""" 

3@file 

4@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`. 

5 

6.. versionadded:: 0.9 

7""" 

8import numpy 

9from .xop import loadop 

10 

11 

def OnnxReduceSumApi11(*x, axes=None, keepdims=1, op_version=None,
                       output_names=None):
    """
    Adds operator ReduceSum with opset>=13 following API from opset 12.

    :param x: inputs of the operator
    :param axes: axes given to the operator, or None to give none
    :param keepdims: value of attribute *keepdims*
    :param op_version: targeted opset, mandatory
    :param output_names: names of the outputs
    :return: the ReduceSum node built for *op_version*
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    common = dict(keepdims=keepdims, op_version=op_version,
                  output_names=output_names)
    if op_version >= 13:
        # From opset 13 on, axes is passed as an extra int64 input
        # instead of an attribute.
        reduce_class = loadop('ReduceSum')
        if axes is None:
            return reduce_class(*x, **common)
        return reduce_class(
            *x, numpy.array(axes, dtype=numpy.int64), **common)
    reduce_class = loadop(
        'ReduceSum_11' if op_version >= 11 else 'ReduceSum_1')
    if axes is not None:
        common['axes'] = axes
    return reduce_class(*x, **common)

46 

47 

def OnnxSplitApi11(*x, axis=0, split=None, op_version=None,
                   output_names=None):
    """
    Adds operator Split with opset>=13 following API from opset 11.

    :param x: inputs of the operator
    :param axis: value of attribute *axis*
    :param split: sizes of the splits, or None to give none
    :param op_version: targeted opset, mandatory
    :param output_names: names of the outputs
    :return: the Split node built for *op_version*
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    options = {'axis': axis, 'op_version': op_version,
               'output_names': output_names}
    inputs = list(x)
    if op_version >= 13:
        split_class = loadop('Split')
        if split is not None:
            # From opset 13 on, split is an int64 input, not an attribute.
            inputs.append(numpy.array(split, dtype=numpy.int64))
    else:
        split_class = loadop('Split_11' if op_version >= 11 else 'Split_2')
        if split is not None:
            options['split'] = split
    return split_class(*inputs, **options)

80 

81 

def OnnxSqueezeApi11(*x, axes=None, op_version=None,
                     output_names=None):
    """
    Adds operator Squeeze with opset>=13 following API from opset 11.

    :param x: inputs of the operator
    :param axes: axes to squeeze, or None to give none to the operator
    :param op_version: targeted opset, mandatory
    :param output_names: names of the outputs
    :return: the Squeeze node built for *op_version*
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 13:
        OnnxSqueeze = loadop('Squeeze')
        if axes is None:
            # numpy.array(None, dtype=numpy.int64) raises TypeError, so
            # the axes input is left out entirely, as ReduceSumApi11 and
            # SplitApi11 already do for their optional int64 input.
            return OnnxSqueeze(
                *x, op_version=op_version, output_names=output_names)
        return OnnxSqueeze(
            *x, numpy.array(axes, dtype=numpy.int64),
            op_version=op_version, output_names=output_names)
    if op_version >= 11:
        OnnxSqueeze_11 = loadop('Squeeze_11')
        return OnnxSqueeze_11(
            *x, axes=axes, op_version=op_version,
            output_names=output_names)
    OnnxSqueeze_1 = loadop('Squeeze_1')
    return OnnxSqueeze_1(*x, axes=axes,
                         op_version=op_version, output_names=output_names)

103 

104 

def OnnxUnsqueezeApi11(*x, axes=None, op_version=None,
                       output_names=None):
    """
    Adds operator Unsqueeze with opset>=13 following API from opset 11.

    :param x: inputs of the operator
    :param axes: axes to insert
    :param op_version: targeted opset, mandatory
    :param output_names: names of the outputs
    :return: the Unsqueeze node built for *op_version*
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 13:
        # From opset 13 on, axes becomes an int64 input.
        unsqueeze_class = loadop('Unsqueeze')
        return unsqueeze_class(
            *x, numpy.array(axes, dtype=numpy.int64),
            op_version=op_version, output_names=output_names)
    unsqueeze_class = loadop(
        'Unsqueeze_11' if op_version >= 11 else 'Unsqueeze_1')
    return unsqueeze_class(
        *x, axes=axes, op_version=op_version, output_names=output_names)

126 

127 

def OnnxReduceL2_typed(dtype, x, axes=None, keepdims=1, op_version=None,
                       output_names=None):
    """
    Adds operator ReduceL2 for float or double.

    For ``numpy.float32``, operator ReduceL2 is used directly. For any
    other dtype, the result is built as ``Sqrt(ReduceSum(Mul(x, x)))``.

    :param dtype: element type of *x*
    :param x: input of the operator
    :param axes: axes to reduce, forwarded to both branches
    :param keepdims: value of *keepdims*, forwarded to both branches
    :param op_version: targeted opset
    :param output_names: names of the outputs
    :return: the final node
    """
    if dtype == numpy.float32:
        OnnxReduceL2 = loadop('ReduceL2')
        # NOTE(review): axes is passed as an attribute here. Confirm that
        # loadop('ReduceL2') resolves to a version that still accepts it.
        return OnnxReduceL2(
            x, axes=axes, keepdims=keepdims,
            op_version=op_version, output_names=output_names)
    OnnxMul, OnnxSqrt = loadop('Mul', 'Sqrt')
    x2 = OnnxMul(x, x, op_version=op_version)
    # Forward the caller's axes/keepdims instead of hard-coding
    # axes=[1], keepdims=1, so both branches reduce the same axes.
    red = OnnxReduceSumApi11(
        x2, axes=axes, keepdims=keepdims, op_version=op_version)
    return OnnxSqrt(
        red, op_version=op_version, output_names=output_names)

144 

145 

def OnnxReshapeApi13(*x, allowzero=0, op_version=None,
                     output_names=None):
    """
    Adds operator Reshape with opset>=14 following API from opset 13.

    :param x: inputs of the operator
    :param allowzero: value of attribute *allowzero*, only given to the
        operator when *op_version* >= 14 (dropped for older opsets)
    :param op_version: targeted opset, mandatory
    :param output_names: names of the outputs
    :return: the Reshape node built for *op_version*
    :raises RuntimeError: if *op_version* is None
    """
    if op_version is None:
        raise RuntimeError(  # pragma: no cover
            "op_version must be specified.")
    if op_version >= 14:
        return loadop('Reshape')(
            *x, allowzero=allowzero,
            op_version=op_version, output_names=output_names)
    legacy_class = loadop('Reshape_13' if op_version >= 13 else 'Reshape_5')
    return legacy_class(
        *x, op_version=op_version, output_names=output_names)

166 

167 

def OnnxReduceAnyApi18(cl18, cl13, cl11, cl1, *x, axes=None, keepdims=1,
                       op_version=None, output_names=None):
    """
    Adds operator Reduce* with opset>=18 following API from opset 17.

    :param cl18: class used when *op_version* is None or >= 18,
        axes are given as an extra int64 input
    :param cl13: class used when 13 <= *op_version* < 18
    :param cl11: class used when 11 <= *op_version* < 13
    :param cl1: class used when *op_version* < 11
    :param x: inputs of the operator
    :param axes: axes to reduce, or None to give none to the operator
    :param keepdims: value of attribute *keepdims*
    :param op_version: targeted opset, None selects *cl18*
    :param output_names: names of the outputs
    :return: the node created by the selected class
    """
    options = dict(keepdims=keepdims, op_version=op_version,
                   output_names=output_names)
    inputs = list(x)
    if op_version is None or op_version >= 18:
        chosen = cl18
        if axes is not None:
            inputs.append(numpy.array(axes, dtype=numpy.int64))
    else:
        if op_version >= 13:
            chosen = cl13
        elif op_version >= 11:
            chosen = cl11
        else:
            chosen = cl1
        # Before opset 18, axes remains an attribute.
        if axes is not None:
            options['axes'] = axes
    return chosen(*inputs, **options)

202 

203 

def OnnxReduceSumSquareApi18(*x, axes=None, keepdims=1, op_version=None,
                             output_names=None):
    """
    Adds operator ReduceSumSquare with opset>=18 following API from opset 17.

    Loads the latest, 13, 11 and 1 versions of ReduceSumSquare and lets
    :func:`OnnxReduceAnyApi18` pick one according to *op_version*.
    """
    (latest, since13, since11, since1) = loadop(
        'ReduceSumSquare', 'ReduceSumSquare_13',
        'ReduceSumSquare_11', 'ReduceSumSquare_1')
    return OnnxReduceAnyApi18(
        latest, since13, since11, since1, *x,
        axes=axes, keepdims=keepdims,
        op_version=op_version, output_names=output_names)

218 

219 

def OnnxReduceMeanApi18(*x, axes=None, keepdims=1, op_version=None,
                        output_names=None):
    """
    Adds operator ReduceMean with opset>=18 following API from opset 17.

    Loads the latest, 13, 11 and 1 versions of ReduceMean and lets
    :func:`OnnxReduceAnyApi18` pick one according to *op_version*.
    """
    (latest, since13, since11, since1) = loadop(
        'ReduceMean', 'ReduceMean_13', 'ReduceMean_11', 'ReduceMean_1')
    return OnnxReduceAnyApi18(
        latest, since13, since11, since1, *x,
        axes=axes, keepdims=keepdims,
        op_version=op_version, output_names=output_names)