Coverage for mlprodict/npy/xop_variable.py: 94%

192 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`. 

4 

5.. versionadded:: 0.9 

6""" 

7import numpy 

8from onnx import ValueInfoProto 

9from onnx.helper import make_tensor_type_proto 

10from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE 

11from onnx.defs import onnx_opset_version 

12from .. import __max_supported_opset__ 

13 

14 

def max_supported_opset():
    """
    Returns the most recent opset of the main domain which is supported
    both by this package (`__max_supported_opset__`) and by the installed
    :epkg:`onnx` package (`onnx_opset_version()`).

    .. runpython::
        :showcode:

        from mlprodict.npy.xop_variable import max_supported_opset
        print("max_supported_opset() returns", max_supported_opset())
    """
    onnx_latest = onnx_opset_version()
    if onnx_latest < __max_supported_opset__:
        return onnx_latest
    return __max_supported_opset__

26 

27 

def is_numpy_dtype(dtype):
    """
    Tells if a dtype is a numpy dtype which can be mapped to an ONNX
    tensor element type through `NP_TYPE_TO_TENSOR_TYPE`.

    :param dtype: anything
    :return: boolean, False for any object which is not such a dtype
        (this function does not raise on unexpected input)
    """
    if isinstance(dtype, (list, dict, Variable)):
        return False
    try:
        if dtype in NP_TYPE_TO_TENSOR_TYPE:
            return True
    except TypeError:
        # Unhashable object (numpy array, set, ...): it cannot be a key
        # of the mapping, let numpy.dtype decide below.
        pass
    try:
        dt = numpy.dtype(dtype)
    except (TypeError, ValueError):
        # numpy cannot interpret the object as a dtype.
        return False
    return dt in NP_TYPE_TO_TENSOR_TYPE

43 

44 

def numpy_type_prototype(dtype):
    """
    Converts a numpy dtype into the matching TensorProto dtype.

    The value is first looked up as given, then after a conversion
    through `numpy.dtype`.

    :param dtype: numpy dtype or anything `numpy.dtype` accepts
    :return: proto dtype (the value stored in `NP_TYPE_TO_TENSOR_TYPE`)
    :raises ValueError: if no proto dtype matches *dtype*
    """
    if dtype not in NP_TYPE_TO_TENSOR_TYPE:
        converted = numpy.dtype(dtype)
        if converted not in NP_TYPE_TO_TENSOR_TYPE:
            raise ValueError(  # pragma: no cover
                f"Unable to convert dtype {dtype!r} into ProtoType.")
        dtype = converted
    return NP_TYPE_TO_TENSOR_TYPE[dtype]

59 

60 

# numpy scalar types returned unchanged by guess_numpy_type.
# A tuple (not a set) on purpose: membership uses `==`, so a numpy.dtype
# instance equal to one of these types also matches.
_GUESS_NUMPY_TYPES = (
    numpy.float64, numpy.float32, numpy.float16,
    numpy.int8, numpy.int16, numpy.int32, numpy.int64,
    numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64,
    numpy.str_, numpy.bool_, numpy.complex64, numpy.complex128)

# Maps the class name of a :epkg:`sklearn-onnx` tensor type to numpy.
_TENSOR_TYPE_NAME_TO_NUMPY = {
    'FloatTensorType': numpy.float32,
    'DoubleTensorType': numpy.float64,
    'Float16TensorType': numpy.float16,
    'Int8TensorType': numpy.int8,
    'Int16TensorType': numpy.int16,
    'Int32TensorType': numpy.int32,
    'Int64TensorType': numpy.int64,
    'UInt8TensorType': numpy.uint8,
    'UInt16TensorType': numpy.uint16,
    'UInt32TensorType': numpy.uint32,
    'UInt64TensorType': numpy.uint64,
    'StringTensorType': numpy.str_,
    'BooleanTensorType': numpy.bool_,
    'Complex64TensorType': numpy.complex64,
    'Complex128TensorType': numpy.complex128,
}


def guess_numpy_type(data_type):
    """
    Guesses the corresponding numpy type based on data_type.

    :param data_type: a numpy scalar type (returned as is), the python
        types `str` or `bool`, an instance of a tensor type whose class
        name is a key of `_TENSOR_TYPE_NAME_TO_NUMPY`, or any object
        with an attribute `type` holding one of the previous
    :return: numpy type
    :raises NotImplementedError: if the type cannot be guessed
    """
    if data_type in _GUESS_NUMPY_TYPES:
        return data_type
    if data_type == str:
        return numpy.str_
    if data_type == bool:
        return numpy.bool_
    cl_name = data_type.__class__.__name__
    if cl_name in _TENSOR_TYPE_NAME_TO_NUMPY:
        return _TENSOR_TYPE_NAME_TO_NUMPY[cl_name]
    if hasattr(data_type, 'type'):
        # Wrapper objects (variables, numpy dtypes) expose the type
        # through an attribute `type`.
        return guess_numpy_type(data_type.type)
    raise NotImplementedError(  # pragma: no cover
        f"Unsupported data_type '{data_type}'.")

89 

90 

class ExistingVariable:
    """
    Temporary name for a result coming from an operator,
    its type is unknown.

    :param name: variable name
    :param op: operator it comes from
    """

    def __init__(self, name, op):
        self.name, self.op = name, op

    def __repr__(self):
        "usual"
        return "%s(%r)" % (type(self).__name__, self.name)

    @property
    def dtype(self):
        "Unknown type, returns None."
        return None

    @property
    def added_dtype(self):
        "Unknown type, returns None."
        return None

116 

117 

class Variable:
    """
    An input or output to an ONNX graph.

    :param name: name, or an existing @see cl Variable to copy
        (every other parameter must then be None)
    :param dtype: :epkg:`numpy` dtype (can be None)
    :param shape: shape as a tuple or a list (can be None)
    :param added_dtype: :epkg:`numpy` dtype specified at conversion time
        (can be None)
    :param added_shape: :epkg:`numpy` shape specified at conversion time
        (can be None)
    :raises TypeError: if an argument has an unexpected type
    :raises ValueError: if *name* is a Variable and any other
        argument is not None
    """

    def __init__(self, name, dtype=None, shape=None, added_dtype=None,
                 added_shape=None):
        if (dtype is not None and isinstance(
                dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                f"Unexpected type {type(dtype)!r} for dtype.")
        if (added_dtype is not None and isinstance(
                added_dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                f"Unexpected type {type(added_dtype)!r} for added_dtype.")
        if shape is not None and not isinstance(shape, (tuple, list)):
            raise TypeError(
                f"Unexpected type {type(shape)!r} for shape.")
        if (added_shape is not None and not isinstance(
                added_shape, (tuple, list))):
            raise TypeError(
                f"Unexpected type {type(added_shape)!r} for added_shape.")

        if isinstance(name, Variable):
            if (dtype is not None or shape is not None or
                    added_dtype is not None or added_shape is not None):
                raise ValueError(  # pragma: no cover
                    "If name is a Variable, then all others attributes "
                    "should be None.")

            self.name_ = name.name_
            self.dtype_ = name.dtype_
            self.added_dtype_ = name.added_dtype_
            self.shape_ = name.shape_
            self.added_shape_ = name.added_shape_
        else:
            if not isinstance(name, str):
                raise TypeError(  # pragma: no cover
                    f"name must be a string not {type(name)!r}.")

            self.name_ = name
            self.dtype_ = dtype
            self.added_dtype_ = added_dtype
            self.shape_ = shape
            self.added_shape_ = added_shape

    def to_skl2onnx(self, scope=None):
        """
        Converts this instance into an instance of *Variable*
        from :epkg:`sklearn-onnx`.
        """
        from skl2onnx.common._topology import Variable as skl2onnxVariable  # delayed
        from skl2onnx.common.data_types import _guess_numpy_type  # delayed
        inst = _guess_numpy_type(self.dtype, self.shape)
        var = skl2onnxVariable(self.name, self.name, type=inst, scope=scope)
        return var

    @staticmethod
    def from_skl2onnx(var):
        """
        Converts variable from :epkg:`sklearn-onnx` into this class.
        """
        return Variable(var.onnx_name, guess_numpy_type(var.type),
                        shape=var.type.shape)

    @staticmethod
    def from_skl2onnx_tuple(var):
        """
        Converts variable from :epkg:`sklearn-onnx` into this class
        defined as a tuple.
        """
        return Variable(var[0], guess_numpy_type(var[1]),
                        shape=var[1].shape)

    @property
    def name(self):
        "Returns the variable name (`self.name_`)."
        return self.name_

    @property
    def dtype(self):
        "Returns `self.dtype_`."
        return self.dtype_

    @property
    def added_dtype(self):
        "Returns `self.added_dtype_`."
        return self.added_dtype_

    @property
    def shape(self):
        "Returns `self.shape_`."
        return self.shape_

    @property
    def proto_type(self):
        "Returns the proto type for `self.dtype_`, 0 if it is None."
        if self.dtype_ is None:
            return 0
        return numpy_type_prototype(self.dtype_)

    @property
    def proto_added_type(self):
        """
        Returns the proto type for `self.added_dtype_` or, when it is None,
        for `self.dtype_`. Returns 0 if both are None.
        """
        # `is not None` rather than `or`: a numpy.dtype instance (what
        # copy_add stores when given an array) may evaluate to False.
        dt = (self.added_dtype_ if self.added_dtype_ is not None
              else self.dtype_)
        if dt is None:
            return 0
        return numpy_type_prototype(dt)

    @property
    def proto_added_shape(self):
        """
        Returns the shape `self.added_shape_` or, when it is None,
        `self.shape_`, as a list. Returns None if both are None.
        """
        # `is not None` rather than `or`: an empty tuple () is a valid
        # (scalar) shape and must not fall back to `self.shape_`.
        dt = (self.added_shape_ if self.added_shape_ is not None
              else self.shape_)
        if dt is None:
            return None
        return list(dt)

    def __repr__(self):
        "usual"
        kwargs = dict(dtype=self.dtype_, shape=self.shape_,
                      added_dtype=self.added_dtype_,
                      added_shape=self.added_shape_)
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if len(kwargs) > 0:
            msg = ", " + ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
        else:
            msg = ''
        return f"{self.__class__.__name__}({self.name_!r}{msg})"

    def is_named(self, name):
        "Tells the variable is named like that."
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                f"name is expected to be a string not {type(name)!r}.")
        return self.name == name

    def copy_add(self, dtype):
        """
        Returns a copy of this variable with a new dtype.

        :param dtype: added type, or a numpy array whose dtype and shape
            become the added dtype and added shape
        :return: @see cl Variable
        :raises RuntimeError: if `self.added_dtype_` is already set
        """
        if self.added_dtype_ is not None:
            raise RuntimeError(  # pragma: no cover
                "Cannot copy as added_dtype is not None.")
        if isinstance(dtype, numpy.ndarray):
            dtype, shape = dtype.dtype, dtype.shape
        else:
            shape = None
        return Variable(self.name_, self.dtype_, self.shape_, dtype, shape)

    def copy_merge(self, var, shape=None):
        """
        Merges information from both Variable.

        If *var* is not a @see cl Variable, the result is
        `self.copy_add(var)`. Otherwise, the result keeps the name, dtype,
        added dtype and added shape of this instance and fills
        `added_dtype_` / `added_shape_` with `var.dtype_` / `var.shape_`
        when they are None.

        :param var: @see cl Variable, or a value given to copy_add
        :param shape: shape replacing `self.shape_` when not None
            (only allowed when *var* is a Variable)
        :return: new @see cl Variable
        :raises RuntimeError: if *shape* is not None and *var*
            is not a Variable
        """
        if not isinstance(var, Variable):
            if shape is not None:
                raise RuntimeError(  # pragma: no cover
                    "shape must be None if var is not a Variable.")
            return self.copy_add(var)
        res = Variable(self.name_, self.dtype_,
                       shape if shape is not None else self.shape_,
                       self.added_dtype_, self.added_shape_)
        if self.added_dtype_ is None and var.dtype_ is not None:
            res.added_dtype_ = var.dtype_
        if self.added_shape_ is None and var.shape_ is not None:
            res.added_shape_ = var.shape_
        return res

    def copy_name(self, name):
        """
        Returns a copy with a new name (the current name is kept
        if *name* is empty or None).
        """
        return Variable(
            name or self.name_, self.dtype_,
            self.shape_, self.added_dtype_,
            self.added_shape_)

    def __eq__(self, other):
        """
        Compares name, shape and dtype. `added_dtype_` and
        `added_shape_` are not compared.

        :raises TypeError: if *other* is not a @see cl Variable
        """
        if not isinstance(other, Variable):
            raise TypeError(
                f"Unexpected type {type(other)!r}.")
        if self.name != other.name:
            return False
        if self.shape_ != other.shape_:
            return False
        if self.dtype_ != other.dtype_:
            return False
        return True

    def make_value_info(self):
        """
        Converts the variable into `onnx.ValueInfoProto`.

        :return: instance of `onnx.ValueInfoProto`
        """
        value_info = ValueInfoProto()
        value_info.name = self.name
        tensor_type_proto = make_tensor_type_proto(self.proto_type, self.shape)
        value_info.type.CopyFrom(tensor_type_proto)  # pylint: disable=E1101
        return value_info

    @staticmethod
    def from_pb(obj):
        """
        Creates a Variable from a protobuf object.

        :param obj: initializer, tensor
        :return: @see cl Variable
        """
        from ..onnx_tools.onnx2py_helper import from_pb
        name, ty, shape = from_pb(obj)
        return Variable(name, ty, shape=shape)

343 

344 

class NodeResultName:
    """
    Defines a result name for a node.

    :param node: node it comes from
    :param index: index of the output
    """

    def __init__(self, node, index):
        self.node, self.index = node, index

    def __repr__(self):
        "Usual"
        cls_name = type(self).__name__
        return f"{cls_name}({self.node!r}, {self.index!r})"

    def get_name(self):
        """
        Returns the name stored in `node.output_names` at position
        `index` or, when the node has no output names, a suggestion
        built from the first three letters of the lowercased operator
        type and the index.

        :raises RuntimeError: if node is None
        """
        owner = self.node
        if owner is None:
            raise RuntimeError(  # pragma: no cover
                "node must not be None.")
        names = owner.output_names
        if names is None:
            prefix = owner.op_type.lower()[:3]
            return "out_%s_%d" % (prefix, self.index)
        return names[self.index].name

372 

373 

class DetectedVariable:
    """
    Wrapper around a @see cl Variable used to detect inputs
    and outputs of a graph.

    :param node: node where the variable was detected
    :param var: instance of @see cl Variable or @see cl ExistingVariable
    :param index: index, only used if it is an output
        (a negative value is not shown by `__repr__`)
    :raises TypeError: if *var* is neither a Variable
        nor an ExistingVariable
    """

    def __init__(self, node, var, index):
        if not isinstance(var, (Variable, ExistingVariable)):
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(var)!r}, it should be a Variable.")
        self.node = node
        self.var = var
        self.index = index

    @property
    def name(self):
        "Returns variable name."
        return self.var.name

    def __repr__(self):
        "usual"
        cls_name = self.__class__.__name__
        suffix = "" if self.index < 0 else f", {self.index}"
        if self.node is not None:
            return "%s(%s-%d, %r%s)" % (
                cls_name, self.node.__class__.__name__,
                id(self.node), self.var, suffix)
        return f"{cls_name}(None, {self.var!r}{suffix})"

405 

406 

class InputDetectedVariable(DetectedVariable):
    """
    Instance of @see cl DetectedVariable for inputs only,
    the index is always -1.
    """

    def __init__(self, node, var):
        super().__init__(node, var, -1)

415 

416 

class OutputDetectedVariable(DetectedVariable):
    """
    Instance of @see cl DetectedVariable for outputs only.
    It adds nothing to its parent class.
    """