Coverage for mlprodict/npy/onnx_numpy_wrapper.py: 97%

109 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Wraps :epkg:`numpy` functions into :epkg:`onnx`. 

4 

5.. versionadded:: 0.6 

6""" 

7import warnings 

8from .onnx_version import FctVersion 

9from .onnx_numpy_annotation import get_args_kwargs 

10from .onnx_numpy_compiler import OnnxNumpyCompiler 

11 

12 

class _created_classes:
    """
    Registry of the classes dynamically created by the wrappers.

    Every registered class is kept in dictionary *stored* and is
    also published in the `globals()` of this module.
    """

    def __init__(self):
        # Maps a class name to the class registered under that name.
        self.stored = {}

    def append(self, name, cl):
        """
        Registers class *cl* under *name* and adds it into `globals()`
        to enable pickling on dynamic classes.

        :param name: name the class is registered under
        :param cl: class to register

        A `RuntimeWarning` is emitted when *name* was already
        registered, the previous class is then replaced.
        """
        already_known = name in self.stored
        if already_known:
            known_names = ", ".join(sorted(self.stored))
            message = "Class %r overwritten in\n%r\n---\n%r" % (
                name, known_names, cl)
            warnings.warn(message, RuntimeWarning)  # pragma: no cover
        globals()[name] = cl
        self.stored[name] = cl

33 

34 

# Module-level registry instance used by the decorators of this module
# to record the classes they create dynamically.
_created_classes_inst = _created_classes()

36 

37 

class wrapper_onnxnumpy:
    """
    Thin callable wrapper holding a reference
    to the compiler (type: @see cl OnnxNumpyCompiler).

    :param compiled: instance of @see cl OnnxNumpyCompiler

    .. versionadded:: 0.6
    """

    def __init__(self, compiled):
        self.compiled = compiled

    def __call__(self, *args, **kwargs):
        """
        Calls the compiled function with arguments `args`.

        If the compiled function raises `TypeError`, `RuntimeError`
        or `ValueError` and one of the positional arguments is
        an `OnnxVar`, the original function stored in class attribute
        `__fct__` is called instead. Otherwise a `RuntimeError`
        chained to the original exception is raised.
        """
        from .onnx_variable import OnnxVar
        try:
            return self.compiled(*args, **kwargs)
        except (TypeError, RuntimeError, ValueError) as exc:
            receives_onnx_var = any(
                isinstance(arg, OnnxVar) for arg in args)
            if not receives_onnx_var:
                raise RuntimeError(
                    "Unable to call the compiled version, args is %r. "
                    "kwargs=%r." % ([type(a) for a in args], kwargs)) from exc
            # `__fct__` is defined on the dynamically created subclass.
            original = self.__class__.__fct__  # pylint: disable=E1101
            return original(*args, **kwargs)

    def __getstate__(self):
        """
        Serializes the compiled function only. The function
        which generates the ONNX graph is not part of the state.
        """
        return {'compiled': self.compiled}

    def __setstate__(self, state):
        """
        Restores the compiled function from the serialized *state*.
        """
        self.compiled = state['compiled']

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        Additional arguments are forwarded to the compiler, they
        distinguish between multiple graphs when a function
        supports multiple types.

        :return: ONNX graph
        """
        compiler = self.compiled
        return compiler.to_onnx(**kwargs)

89 

90 

def onnxnumpy(op_version=None, runtime=None, signature=None):
    """
    Builds a decorator which declares a function implemented
    using :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by
        @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: decorator, applied to a function, it returns an instance
        of a dynamically created subclass of @see cl wrapper_onnxnumpy

    Used as `onnxnumpy(...)(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        # The compiler is built first, nothing is registered if it fails.
        engine = OnnxNumpyCompiler(
            fct, op_version=op_version, runtime=runtime,
            signature=signature)
        cls_name = f"onnxnumpy_{fct.__name__}_{op_version!s}_{runtime}"
        attributes = {
            '__doc__': fct.__doc__,
            '__name__': cls_name,
            '__fct__': fct,
        }
        wrapper_cls = type(cls_name, (wrapper_onnxnumpy,), attributes)
        _created_classes_inst.append(cls_name, wrapper_cls)
        return wrapper_cls(engine)

    return decorator_fct

118 

119 

def onnxnumpy_default(fct):
    """
    Decorator without options to declare a function implemented
    using :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators. It applies `onnxnumpy()` with its default parameters.

    :param fct: function to wrap
    :return: result of `onnxnumpy()(fct)`

    .. versionadded:: 0.6
    """
    decorate = onnxnumpy()
    return decorate(fct)

131 

132 

class wrapper_onnxnumpy_np:
    """
    Intermediate wrapper to store a pointer
    on the compiler (type: @see cl OnnxNumpyCompiler)
    supporting multiple signatures.

    Compiled functions are created on demand, one per version
    (instance of `FctVersion`), and cached in dictionary
    *signed_compiled*.

    :param kwargs: must contain `fct` and `signature`, may contain
        `fctsig`; method `_populate` also reads `op_version` and
        `runtime` from it

    .. versionadded:: 0.6
    """

    def __init__(self, **kwargs):
        self.fct = kwargs['fct']
        self.signature = kwargs['signature']
        self.fctsig = kwargs.get('fctsig', None)
        self.args, self.kwargs = get_args_kwargs(
            self.fct,
            0 if self.signature is None else self.signature.n_optional)
        # Every parameter is kept, `_populate` reads them later.
        self.data = kwargs
        # Maps a version to the wrapper of its compiled function.
        self.signed_compiled = {}

    def __getstate__(self):
        """
        Serializes everything but the function which generates
        the ONNX graph, not needed anymore.

        Attributes `fct` and `fctsig` are not part of the state and
        key `'fct'` is removed from the serialized copy of `data`.
        """
        data_copy = {k: v for k, v in self.data.items() if k != 'fct'}
        return dict(signature=self.signature, args=self.args,
                    kwargs=self.kwargs, data=data_copy,
                    signed_compiled=self.signed_compiled)

    def __setstate__(self, state):
        """
        Restores serialized data. Every item of *state* becomes
        an attribute of the instance.
        """
        for k, v in state.items():
            setattr(self, k, v)

    def __getitem__(self, dtype):
        """
        Returns the instance of @see cl wrapper_onnxnumpy
        mapped to *dtype*. It is created if it does not exist yet.

        :param dtype: instance of `FctVersion` identifying
            the version of the function
        :return: instance of @see cl wrapper_onnxnumpy
        :raises TypeError: if *dtype* is not an instance of `FctVersion`
        """
        if not isinstance(dtype, FctVersion):
            raise TypeError(  # pragma: no cover
                f"dtype must be of type 'FctVersion' not {type(dtype)}: {dtype}.")
        if dtype not in self.signed_compiled:
            self._populate(dtype)
        return self.signed_compiled[dtype]

    def __call__(self, *args, **kwargs):
        """
        Calls the compiled function. The version of the function
        to convert into *ONNX* is defined by the types of the tensors
        in *args* (`a.dtype.type`, arguments equal to None or having
        an attribute `fit` are kept as they are) and by the values
        of the optional parameters (taken from *kwargs* or from
        their default values).

        If building the key or calling the compiled function raises
        `AttributeError` and one of the positional arguments is
        an `OnnxVar`, the original function stored in class attribute
        `__fct__` is called instead. Otherwise a `RuntimeError`
        chained to the original exception is raised.
        """
        from .onnx_variable import OnnxVar
        if not self.kwargs:
            others = None
        else:
            others = tuple(kwargs.get(k, self.kwargs[k]) for k in self.kwargs)
        try:
            key = FctVersion(  # pragma: no cover
                tuple(a if (a is None or hasattr(a, 'fit'))
                      else a.dtype.type for a in args),
                others)
            # Optional parameters are part of the key, only positional
            # arguments are given to the compiled function.
            return self[key](*args)
        except AttributeError as e:
            if any(isinstance(a, OnnxVar) for a in args):
                return self.__class__.__fct__(  # pylint: disable=E1101
                    *args, **kwargs)
            raise RuntimeError(
                "Unable to call the compiled version, args is %r. "
                "kwargs=%r." % ([type(a) for a in args], kwargs)) from e

    def _populate(self, version):
        """
        Creates the appropriate runtime for function *fct*
        and stores it in *signed_compiled* under key *version*.

        :param version: version of the function, it must
            implement method `as_string`

        The method reads `self.data["fct"]`, which is not available
        on an instance restored with `__setstate__`.
        """
        compiled = OnnxNumpyCompiler(
            fct=self.data["fct"], op_version=self.data["op_version"],
            runtime=self.data["runtime"], signature=self.data["signature"],
            version=version, fctsig=self.data.get('fctsig', None))
        name = "onnxnumpy_np_%s_%s_%s_%s" % (
            self.data["fct"].__name__, str(self.data["op_version"]),
            self.data["runtime"], version.as_string())
        newclass = type(
            name, (wrapper_onnxnumpy,),
            {'__doc__': self.data["fct"].__doc__, '__name__': name})

        self.signed_compiled[version] = newclass(compiled)

    def _validate_onnx_data(self, X):
        """
        Returns *X* unchanged.
        """
        return X

    def to_onnx(self, **kwargs):
        """
        Returns the ONNX graph for the wrapped function.
        It takes additional arguments to distinguish between multiple graphs.
        This happens when a function needs to support multiple types.

        :param kwargs: empty or only `key=...`, the key is either
            a key of *signed_compiled*, the value of attribute `args`
            of one of these keys, or a single value `t` such that
            `args` is equal to `(t, t, ...)`
        :return: ONNX graph
        :raises RuntimeError: if no graph was compiled
        :raises ValueError: if the graph cannot be identified
            without ambiguity or if *kwargs* contains anything
            else than `key`
        """
        if not self.signed_compiled:
            raise RuntimeError(  # pragma: no cover
                "No ONNX graph was compiled.")
        if not kwargs:
            if len(self.signed_compiled) == 1:
                # We take the only one.
                cpl = next(iter(self.signed_compiled.values()))
                return cpl.to_onnx()
            raise ValueError(
                "There are multiple compiled ONNX graphs associated "
                "with keys %r (add key=...)." % list(self.signed_compiled))
        if list(kwargs) != ['key']:
            raise ValueError(
                f"kwargs should contain one parameter key=... but it is {kwargs!r}.")
        key = kwargs['key']
        if key in self.signed_compiled:
            return self.signed_compiled[key].compiled.onnx_
        # The key is not a full version, it is compared to the types
        # of the arguments, or to one type repeated for every argument.
        found = [
            (k, v) for k, v in self.signed_compiled.items()
            if k.args == key or k.args == (key, ) * len(k.args)]
        if len(found) == 1:
            return found[0][1].compiled.onnx_
        raise ValueError(
            "Unable to find signature with key=%r among %r found=%r." % (
                key, list(self.signed_compiled), found))

272 

273 

def onnxnumpy_np(op_version=None, runtime=None, signature=None):
    """
    Builds a decorator which declares a function implemented
    using :epkg:`numpy` syntax but executed with :epkg:`ONNX`
    operators. The decorated function supports multiple signatures.

    :param op_version: :epkg:`ONNX` opset version
    :param runtime: `'onnxruntime'` or one implemented by @see cl OnnxInference
    :param signature: it should be used when the function
        is not annotated.
    :return: decorator, applied to a function, it returns an instance
        of a dynamically created subclass of @see cl wrapper_onnxnumpy_np

    Used as `onnxnumpy_np(...)(foo)`.

    .. versionadded:: 0.6
    """
    def decorator_fct(fct):
        cls_name = f"onnxnumpy_nb_{fct.__name__}_{op_version!s}_{runtime}"
        # Pickling methods are explicitly copied from the base class.
        attributes = dict(
            __doc__=fct.__doc__,
            __name__=cls_name,
            __getstate__=wrapper_onnxnumpy_np.__getstate__,
            __setstate__=wrapper_onnxnumpy_np.__setstate__,
            __fct__=fct)
        wrapper_cls = type(cls_name, (wrapper_onnxnumpy_np,), attributes)
        _created_classes_inst.append(cls_name, wrapper_cls)
        return wrapper_cls(
            fct=fct, op_version=op_version, runtime=runtime,
            signature=signature)

    return decorator_fct