Coverage for mlprodict/sklapi/onnx_transformer.py: 97%

160 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# coding: utf-8
"""
@file
@brief Wraps runtime into a :epkg:`scikit-learn` transformer.
"""
from io import BytesIO
import numpy
import pandas
import onnx
from sklearn.base import BaseEstimator, TransformerMixin
from skl2onnx.algebra.onnx_operator_mixin import OnnxOperatorMixin
from mlprodict.onnx_tools.onnx_manipulations import (
    select_model_inputs_outputs, enumerate_model_node_outputs)
from ..onnx_tools.onnx2py_helper import _var_as_dict, onnx_model_opsets
from ..onnx_tools.exports.skl2onnx_helper import add_onnx_graph
from ..onnxrt import OnnxInference


class OnnxTransformer(BaseEstimator, TransformerMixin, OnnxOperatorMixin):
    """
    Calls :epkg:`onnxruntime` or the runtime implemented
    in this package to transform inputs based on an ONNX graph.
    It follows the :epkg:`scikit-learn` API
    so that it can be included in a :epkg:`scikit-learn` pipeline.
    See notebook :ref:`transferlearningrst` for an example.

    :param onnx_bytes: bytes, serialized :epkg:`ONNX` model
    :param output_name: string,
        requested output name, or None to request all outputs
        and let method *transform* store them in a dataframe
    :param enforce_float32: boolean,
        :epkg:`onnxruntime` only supports *float32*,
        :epkg:`scikit-learn` usually relies on double floats, this parameter
        ensures that every array of double floats is converted into
        single floats
    :param runtime: string, defines the runtime to use
        as described in @see cl OnnxInference
    :param change_batch_size: some models are converted for
        a specific batch size, this parameter changes it,
        None leaves it unchanged, 0 fixes an undefined
        first dimension
    :param reshape: reshape the output into
        a matrix rather than a multidimensional array
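
    A minimal usage sketch (assuming *onx_bytes* holds an ONNX model
    serialized as bytes, for instance produced by *skl2onnx*, and
    *X_test* an input array)::

        import numpy
        from mlprodict.sklapi.onnx_transformer import OnnxTransformer

        tr = OnnxTransformer(onx_bytes)
        tr.fit()
        pred = tr.transform(X_test.astype(numpy.float32))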

44 """ 

45 

46 def __init__(self, onnx_bytes, output_name=None, enforce_float32=True, 

47 runtime='python', change_batch_size=None, reshape=False): 

48 BaseEstimator.__init__(self) 

49 TransformerMixin.__init__(self) 

50 self.onnx_bytes = (onnx_bytes 

51 if not hasattr(onnx_bytes, 'SerializeToString') 

52 else onnx_bytes.SerializeToString()) 

53 self.output_name = output_name 

54 self.enforce_float32 = enforce_float32 

55 self.runtime = runtime 

56 self.change_batch_size = change_batch_size 

57 self.reshape = reshape 

58 

    def __repr__(self):  # pylint: disable=W0222
        """
        usual
        """
        ob = self.onnx_bytes
        if len(ob) > 20:
            ob = ob[:10] + b"..." + ob[-10:]
        return ("{0}(onnx_bytes={1}, output_name={2}, enforce_float32={3}, "
                "runtime='{4}')".format(
                    self.__class__.__name__, ob, self.output_name,
                    self.enforce_float32, self.runtime))

    def fit(self, X=None, y=None, **fit_params):
        """
        Loads the :epkg:`ONNX` model.

        :param X: unused
        :param y: unused
        :param fit_params: additional parameters (unused)
        :return: self
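
        *X* and *y* are ignored, so the model can be loaded without any
        data (a sketch, assuming *onx_bytes* holds a serialized model)::

            tr = OnnxTransformer(onx_bytes)
            tr.fit()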

79 """ 

80 from ..onnx_tools.optim.onnx_helper import change_input_first_dimension 

81 onx = onnx.load(BytesIO(self.onnx_bytes)) 

82 self.op_version = onnx_model_opsets(onx) 

83 

84 output_names = set( 

85 o.name for o in onx.graph.output) # pylint: disable=E1101 

86 updated = False 

87 if (self.output_name is not None and 

88 self.output_name not in output_names): 

89 # The model refers to intermediate outputs. 

90 onx = select_model_inputs_outputs( 

91 onx, outputs=[self.output_name]) 

92 updated = True 

93 

94 if self.change_batch_size is not None: 

95 onx = change_input_first_dimension( 

96 onx, self.change_batch_size) 

97 updated = True 

98 

99 onnx_bytes = ( 

100 onx.SerializeToString() if updated else self.onnx_bytes) 

101 self.onnxrt_ = OnnxInference( 

102 onnx_bytes, runtime=self.runtime, 

103 runtime_options=dict(log_severity_level=3)) 

104 self.inputs_ = self.onnxrt_.input_names 

105 self.inputs_shape_types_ = self.onnxrt_.input_names_shapes_types 

106 return self 

107 

    def _check_arrays(self, inputs):
        """
        Ensures that double floats are converted into single floats
        if *enforce_float32* is True or raises an exception.
        """
        has = hasattr(self, "onnxrt_")
        sht = self.inputs_shape_types_ if has else None
        if sht is not None and len(sht) < len(inputs):
            raise RuntimeError(  # pragma: no cover
                f"Unexpected number of inputs {len(inputs)} > {len(sht)} (expected).")
        for i, k in enumerate(inputs):
            v = inputs[k]
            if isinstance(v, numpy.ndarray):
                if v.dtype == numpy.float64 and self.enforce_float32:
                    inputs[k] = v.astype(numpy.float32)
                    continue
                if not has:
                    continue
                exp = sht[i]
                if exp[1] != ('?', ) and exp[1][1:] != v.shape[1:]:
                    raise RuntimeError(  # pragma: no cover
                        "Unexpected shape for input '{}': {} != {} "
                        "(expected).".format(
                            k, v.shape, exp[1]))
                if ((v.dtype == numpy.float32 and exp[2] != 'tensor(float)') or
                        (v.dtype == numpy.float64 and exp[2] != 'tensor(double)')):
                    raise TypeError(  # pragma: no cover
                        "Unexpected dtype for input '{}': {} != {} "
                        "(expected).".format(
                            k, v.dtype, exp[2]))

    def transform(self, X, y=None, **inputs):
        """
        Runs the predictions. If *X* is a dataframe,
        the function assumes every column is a separate input,
        otherwise, *X* is considered as the first input and *inputs*
        can be used to specify the extra ones.

        :param X: iterable, data to process
            (or first input if several are expected)
        :param y: unused
        :param inputs: :epkg:`ONNX` graphs support multiple inputs,
            each column of a dataframe is converted into an input if
            *X* is a dataframe, otherwise, *X* is considered as the first
            input and *inputs* can be used to specify the other ones
        :return: :epkg:`DataFrame`
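
        A minimal sketch with two inputs (assuming the graph declares
        *float32* inputs named ``X1`` and ``X2`` and that *x1*, *x2* are
        *float32* arrays; the first is matched by position, the second
        by name)::

            tr = OnnxTransformer(onx_bytes)
            tr.fit()
            res = tr.transform(x1, X2=x2)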

154 """ 

155 if not hasattr(self, "onnxrt_"): 

156 raise AttributeError( # pragma: no cover 

157 "Transform OnnxTransformer must be fit first.") 

158 rt_inputs = {} 

159 if isinstance(X, numpy.ndarray): 

160 rt_inputs[self.inputs_[0]] = X 

161 elif isinstance(X, pandas.DataFrame): 

162 for c in X.columns: 

163 rt_inputs[c] = X[c] 

164 elif isinstance(X, dict) and len(inputs) == 0: 

165 for k, v in X.items(): 

166 rt_inputs[k] = v 

167 elif isinstance(X, list): 

168 if len(self.inputs_) == 1: 

169 rt_inputs[self.inputs_[0]] = numpy.array(X) 

170 else: 

171 for i in range(len(self.inputs_)): # pylint: disable=C0200 

172 rt_inputs[self.inputs_[i]] = [row[i] for row in X] 

173 

174 for k, v in inputs.items(): 

175 rt_inputs[k] = v 

176 

177 names = ([self.output_name] 

178 if self.output_name else self.onnxrt_.output_names) 

179 self._check_arrays(rt_inputs) 

180 doutputs = self.onnxrt_.run(rt_inputs) 

181 outputs = [doutputs[n] for n in names] 

182 

183 if self.reshape: 

184 n = outputs[0].shape[0] 

185 outputs = [o.reshape((n, -1)) for o in outputs] 

186 

187 if self.output_name or len(outputs) == 1: 

188 if isinstance(outputs[0], list): 

189 return pandas.DataFrame(outputs[0]) 

190 return outputs[0] 

191 

192 names = self.output_name if self.output_name else [ 

193 o for o in self.onnxrt_.output_names] 

194 concat = [] 

195 colnames = [] 

196 for k, v in zip(names, outputs): 

197 if isinstance(v, numpy.ndarray): 

198 if len(v.shape) == 1: 

199 v = v.reshape((-1, 1)) 

200 colnames.append(k) 

201 elif len(v.shape) == 2: 

202 colnames.extend("%s%d" % (k, i) for i in range(v.shape[1])) 

203 else: 

204 raise RuntimeError( # pragma: no cover 

205 f"Unexpected shape for results {k!r}: {v.shape!r}.") 

206 if isinstance(v, list): 

207 if len(v) == 0: 

208 raise RuntimeError( # pragma: no cover 

209 f"Output {k!r} is empty.") 

210 if not isinstance(v[0], dict): 

211 raise RuntimeError( # pragma: no cover 

212 f"Unexpected type for output {k!r} - value={v[0]!r}.") 

213 df = pandas.DataFrame(v) 

214 cols = list(sorted(df.columns)) 

215 v = df[cols].copy().values 

216 colnames.extend("%s%d" % (k, i) for i in range(v.shape[1])) 

217 concat.append(v) 

218 res = numpy.hstack(concat) 

219 return pandas.DataFrame(res, columns=colnames) 

220 

    def fit_transform(self, X, y=None, **inputs):
        """
        Loads the *ONNX* model and runs the predictions.

        :param X: iterable, data to process
            (or first input if several are expected)
        :param y: unused
        :param inputs: :epkg:`ONNX` graphs support multiple inputs,
            each column of a dataframe is converted into an input if
            *X* is a dataframe, otherwise, *X* is considered as the first
            input and *inputs* can be used to specify the other ones
        :return: :epkg:`DataFrame`
        """
        return self.fit(X, y=y, **inputs).transform(X, y)

    @staticmethod
    def enumerate_create(onnx_bytes, output_names=None, enforce_float32=True):
        """
        Creates multiple *OnnxTransformer*,
        one for each requested intermediate node.

        :param onnx_bytes: bytes, serialized :epkg:`ONNX` model
        :param output_names: strings,
            requested output names, or None to request all outputs
            and let method *transform* store them in a dataframe
        :param enforce_float32: boolean,
            :epkg:`onnxruntime` only supports *float32*,
            :epkg:`scikit-learn` usually relies on double floats, this parameter
            ensures that every array of double floats is converted into
            single floats
        :return: iterator on OnnxTransformer *('output name', OnnxTransformer)*
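
        A minimal sketch printing the name of every intermediate output
        that can be requested (assuming *onx_bytes* holds a serialized
        model)::

            for name, tr in OnnxTransformer.enumerate_create(onx_bytes):
                print(name)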

252 """ 

253 selected = None if output_names is None else set(output_names) 

254 model = onnx.load(BytesIO(onnx_bytes)) 

255 for out in enumerate_model_node_outputs(model): 

256 m = select_model_inputs_outputs(model, out) 

257 if selected is None or out in selected: 

258 tr = OnnxTransformer(m.SerializeToString(), 

259 enforce_float32=enforce_float32) 

260 yield out, tr 

261 

    def onnx_parser(self):
        """
        Returns a parser for this model.
        """
        def parser(scope=None, inputs=None):
            if scope is None:
                raise RuntimeError(  # pragma: no cover
                    f"scope cannot be None (parser of class {type(self)!r}).")
            if inputs is None:
                raise RuntimeError(  # pragma: no cover
                    f"inputs cannot be None (parser of class {type(self)!r}).")
            if (not hasattr(self, 'onnxrt_') or
                    not hasattr(self.onnxrt_, 'output_names')):
                raise RuntimeError(  # pragma: no cover
                    'OnnxTransformer not fit.')
            if len(inputs) != len(self.inputs_):
                raise RuntimeError(  # pragma: no cover
                    "Mismatch between the number of inputs, expected %r, "
                    "got %r." % (self.inputs_, inputs))
            return self.onnxrt_.output_names
        return parser

    def onnx_shape_calculator(self):
        """
        Returns a shape calculator for this model.
        """
        def shape_calculator(operator):
            from skl2onnx.common.data_types import (  # delayed
                FloatTensorType, DoubleTensorType, Int64TensorType)
            cout = self.onnxrt_.output_names
            if len(operator.outputs) != len(cout):
                raise RuntimeError(  # pragma: no cover
                    "Mismatched number of outputs: {} != {}."
                    "".format(len(operator.outputs), len(cout)))
            for out_op, out in zip(operator.outputs, self.onnxrt_.obj.graph.output):
                var = _var_as_dict(out)
                if var['type']['kind'] != 'tensor':
                    raise NotImplementedError(  # pragma: no cover
                        f"Not yet implemented for output:\n{out}")
                shape = var['type']['shape']
                if shape[0] == 0:
                    shape = (None,) + tuple(shape[1:])
                elem = var['type']['elem']
                if elem == 'float':
                    out_op.type = FloatTensorType(shape=shape)
                elif elem == 'int64':
                    out_op.type = Int64TensorType(shape=shape)
                elif elem == 'double':
                    out_op.type = DoubleTensorType(shape=shape)
                else:
                    raise NotImplementedError(  # pragma: no cover
                        f"Not yet implemented for elem_type: {elem!r}")
        return shape_calculator

    def onnx_converter(self):
        """
        Returns a converter for this model.
        If not overloaded, it fetches the converter
        mapped to the first *scikit-learn* parent
        it can find.
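
        A sketch of how the converter is typically exercised (assuming
        *pipe* is a fitted :epkg:`scikit-learn` pipeline embedding this
        transformer and *X32* a *float32* sample)::

            from skl2onnx import to_onnx
            onx2 = to_onnx(pipe, X32)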

319 """ 

320 def converter(scope, operator, container, onnx_model=None): 

321 op = operator.raw_operator 

322 onx = onnx_model or op.onnxrt_.obj 

323 add_onnx_graph(scope, operator, container, onx) 

324 

325 return converter 

326 

    @property
    def opsets(self):
        """
        Returns the opsets as a dictionary ``{domain: opset}``.
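
        The property can be read before or after *fit* (a sketch,
        assuming *onx_bytes* holds a serialized model)::

            print(OnnxTransformer(onx_bytes).opsets)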

331 """ 

332 if hasattr(self, 'onnxrt_'): 

333 model = self.onnxrt_.obj 

334 else: 

335 model = onnx.load(BytesIO(self.onnx_bytes)) 

336 res = {} 

337 for oimp in model.opset_import: 

338 res[oimp.domain] = oimp.version 

339 return res