Coverage for mlprodict/npy/xop_variable.py: 94%
192 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Xop API to build onnx graphs. Inspired from :epkg:`sklearn-onnx`.
5.. versionadded:: 0.9
6"""
7import numpy
8from onnx import ValueInfoProto
9from onnx.helper import make_tensor_type_proto
10from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE
11from onnx.defs import onnx_opset_version
12from .. import __max_supported_opset__
def max_supported_opset():
    """
    Returns the latest supported opset for the main domain.
    It is the smaller of the opset declared by the package
    (``__max_supported_opset__``) and the opset supported by the
    installed :epkg:`onnx` (``onnx_opset_version()``).

    .. runpython::
        :showcode:

        from mlprodict.npy.xop_variable import max_supported_opset
        print("max_supported_opset() returns", max_supported_opset())
    """
    package_limit = __max_supported_opset__
    onnx_limit = onnx_opset_version()
    return min(package_limit, onnx_limit)
def is_numpy_dtype(dtype):
    """
    Tells if a dtype is a numpy dtype.

    :param dtype: anything
    :return: boolean, True if *dtype* or ``numpy.dtype(dtype)`` is a key
        of ``NP_TYPE_TO_TENSOR_TYPE``; False otherwise, including when
        *dtype* is unhashable or cannot be interpreted by ``numpy.dtype``

    .. note::
        ``numpy.dtype(None)`` is float64, so None is reported
        as a numpy dtype.
    """
    if isinstance(dtype, (list, dict, Variable)):
        return False
    try:
        if dtype in NP_TYPE_TO_TENSOR_TYPE:
            return True
    except TypeError:
        # Unhashable objects (numpy arrays, sets, ...) cannot be dict keys.
        return False
    try:
        dt = numpy.dtype(dtype)
    except (TypeError, ValueError):
        # numpy cannot interpret it as a dtype: not a numpy dtype.
        return False
    return dt in NP_TYPE_TO_TENSOR_TYPE
def numpy_type_prototype(dtype):
    """
    Converts a numpy dtype into a TensorProto element type.

    *dtype* is first looked up as is, then after conversion
    with ``numpy.dtype``.

    :param dtype: dtype (numpy type, ``numpy.dtype`` or anything
        ``numpy.dtype`` accepts)
    :return: proto dtype, a value of ``NP_TYPE_TO_TENSOR_TYPE``
    :raises ValueError: if no TensorProto type matches
    """
    mapping = NP_TYPE_TO_TENSOR_TYPE
    if dtype not in mapping:
        as_numpy = numpy.dtype(dtype)
        if as_numpy not in mapping:
            raise ValueError(  # pragma: no cover
                f"Unable to convert dtype {dtype!r} into ProtoType.")
        return mapping[as_numpy]
    return mapping[dtype]
def guess_numpy_type(data_type):
    """
    Guesses the corresponding numpy type based on *data_type*.

    :param data_type: one of the following:

        * a numpy scalar type, or anything comparing equal to one such as
          a ``numpy.dtype``; it is returned unchanged
        * :class:`str` or :class:`bool`
        * an instance of a tensor type whose class name is one of the
          :epkg:`sklearn-onnx` names listed below, e.g. ``FloatTensorType``
        * any object with a ``type`` attribute, which is resolved
          recursively

    :return: numpy type
    :raises NotImplementedError: if the type cannot be guessed
    """
    if data_type in (numpy.float64, numpy.float32, numpy.float16,
                     numpy.int8, numpy.uint8, numpy.int16, numpy.uint16,
                     numpy.int32, numpy.uint32, numpy.int64, numpy.uint64,
                     numpy.complex64, numpy.complex128,
                     numpy.str_, numpy.bool_):
        return data_type
    if data_type == str:
        return numpy.str_
    if data_type == bool:
        return numpy.bool_
    # Matched by class name to avoid importing sklearn-onnx here.
    name2numpy = {
        'FloatTensorType': numpy.float32,
        'DoubleTensorType': numpy.float64,
        'Float16TensorType': numpy.float16,
        'Int8TensorType': numpy.int8,
        'UInt8TensorType': numpy.uint8,
        'Int16TensorType': numpy.int16,
        'UInt16TensorType': numpy.uint16,
        'Int32TensorType': numpy.int32,
        'UInt32TensorType': numpy.uint32,
        'Int64TensorType': numpy.int64,
        'UInt64TensorType': numpy.uint64,
        'StringTensorType': numpy.str_,
        'BooleanTensorType': numpy.bool_,
        'Complex64TensorType': numpy.complex64,
        'Complex128TensorType': numpy.complex128,
    }
    cl_name = data_type.__class__.__name__
    if cl_name in name2numpy:
        return name2numpy[cl_name]
    if hasattr(data_type, 'type'):
        # e.g. a variable holding its type in attribute `type`
        return guess_numpy_type(data_type.type)
    raise NotImplementedError(  # pragma: no cover
        f"Unsupported data_type '{data_type}'.")
class ExistingVariable:
    """
    Placeholder for a variable known only by its name; its type is unknown.

    :param name: variable name
    :param op: operator it comes from
    """

    def __init__(self, name, op):
        self.name, self.op = name, op

    def __repr__(self):
        "usual"
        return "%s(%r)" % (type(self).__name__, self.name)

    @property
    def dtype(self):
        "Unknown type, returns None."
        return None

    @property
    def added_dtype(self):
        "Unknown type, returns None."
        return None
class Variable:
    """
    An input or output to an ONNX graph.

    :param name: name (a string), or another @see cl Variable to copy
        (all the other parameters must then be None)
    :param dtype: :epkg:`numpy` dtype (can be None)
    :param shape: shape (can be None)
    :param added_dtype: :epkg:`numpy` dtype specified at conversion time
        (can be None)
    :param added_shape: :epkg:`numpy` shape specified at conversion time
        (can be None)
    :raises TypeError: if a parameter has an unexpected type
    :raises ValueError: if *name* is a Variable and another parameter
        is not None
    """

    def __init__(self, name, dtype=None, shape=None, added_dtype=None,
                 added_shape=None):
        # dtypes must not be ints, Variables, tuples or arrays
        if (dtype is not None and isinstance(
                dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                f"Unexpected type {type(dtype)!r} for dtype.")
        if (added_dtype is not None and isinstance(
                added_dtype, (int, Variable, tuple, numpy.ndarray))):
            raise TypeError(
                f"Unexpected type {type(added_dtype)!r} for added_dtype.")
        if shape is not None and not isinstance(shape, (tuple, list)):
            raise TypeError(
                f"Unexpected type {type(shape)!r} for shape.")
        if (added_shape is not None and not isinstance(
                added_shape, (tuple, list))):
            raise TypeError(
                f"Unexpected type {type(added_shape)!r} for added_shape.")

        if isinstance(name, Variable):
            # copy constructor
            if (dtype is not None or shape is not None or
                    added_dtype is not None or added_shape is not None):
                raise ValueError(  # pragma: no cover
                    "If name is a Variable, then all others attributes "
                    "should be None.")

            self.name_ = name.name_
            self.dtype_ = name.dtype_
            self.added_dtype_ = name.added_dtype_
            self.shape_ = name.shape_
            self.added_shape_ = name.added_shape_
        else:
            if not isinstance(name, str):
                raise TypeError(  # pragma: no cover
                    f"name must be a string not {type(name)!r}.")

            self.name_ = name
            self.dtype_ = dtype
            self.added_dtype_ = added_dtype
            self.shape_ = shape
            self.added_shape_ = added_shape

    def to_skl2onnx(self, scope=None):
        """
        Converts this instance into an instance of *Variable*
        from :epkg:`sklearn-onnx`.
        """
        from skl2onnx.common._topology import Variable as skl2onnxVariable  # delayed
        from skl2onnx.common.data_types import _guess_numpy_type  # delayed
        inst = _guess_numpy_type(self.dtype, self.shape)
        var = skl2onnxVariable(self.name, self.name, type=inst, scope=scope)
        return var

    @staticmethod
    def from_skl2onnx(var):
        """
        Converts variable from :epkg:`sklearn-onnx` into this class.
        """
        return Variable(var.onnx_name, guess_numpy_type(var.type),
                        shape=var.type.shape)

    @staticmethod
    def from_skl2onnx_tuple(var):
        """
        Converts variable from :epkg:`sklearn-onnx` into this class
        defined as a tuple ``(name, type)``.
        """
        return Variable(var[0], guess_numpy_type(var[1]),
                        shape=var[1].shape)

    @property
    def name(self):
        "Returns the variable name (`self.name_`)."
        return self.name_

    @property
    def dtype(self):
        "Returns `self.dtype_`."
        return self.dtype_

    @property
    def added_dtype(self):
        "Returns `self.added_dtype_`."
        return self.added_dtype_

    @property
    def shape(self):
        "Returns `self.shape_`."
        return self.shape_

    @property
    def added_shape(self):
        "Returns `self.added_shape_`."
        return self.added_shape_

    @property
    def proto_type(self):
        "Returns the proto type for `self.dtype_` (0 if unknown)."
        if self.dtype_ is None:
            return 0
        return numpy_type_prototype(self.dtype_)

    @property
    def proto_added_type(self):
        """
        Returns the proto type for `self.added_dtype_` if it is not None,
        for `self.dtype_` otherwise (0 if both are None).
        """
        # Compare to None rather than relying on truthiness: a
        # numpy.dtype defines len() as its number of fields and may
        # therefore evaluate as falsy.
        dt = (self.added_dtype_ if self.added_dtype_ is not None
              else self.dtype_)
        if dt is None:
            return 0
        return numpy_type_prototype(dt)

    @property
    def proto_added_shape(self):
        """
        Returns `self.added_shape_` if it is not None, `self.shape_`
        otherwise, as a list (None if both are None).
        """
        # An empty shape () (scalar) is a valid shape but falsy,
        # so it must not fall back to `self.shape_`.
        shape = (self.added_shape_ if self.added_shape_ is not None
                 else self.shape_)
        if shape is None:
            return None
        return list(shape)

    def __repr__(self):
        "usual"
        kwargs = dict(dtype=self.dtype_, shape=self.shape_,
                      added_dtype=self.added_dtype_,
                      added_shape=self.added_shape_)
        kwargs = {k: v for k, v in kwargs.items() if v is not None}
        if len(kwargs) > 0:
            msg = ", " + ", ".join(f"{k}={v!r}" for k, v in kwargs.items())
        else:
            msg = ''
        return f"{self.__class__.__name__}({self.name_!r}{msg})"

    def is_named(self, name):
        "Tells the variable is named like that."
        if not isinstance(name, str):
            raise TypeError(  # pragma: no cover
                f"name is expected to be a string not {type(name)!r}.")
        return self.name == name

    def copy_add(self, dtype):
        """
        Returns a copy of this variable with a new added dtype.

        :param dtype: added type; if it is a numpy array, its dtype and
            its shape become the added dtype and the added shape
        :return: @see cl Variable
        :raises RuntimeError: if this variable already has an added dtype
        """
        if self.added_dtype_ is not None:
            raise RuntimeError(  # pragma: no cover
                "Cannot copy as added_dtype is not None.")
        if isinstance(dtype, numpy.ndarray):
            dtype, shape = dtype.dtype, dtype.shape
        else:
            shape = None
        return Variable(self.name_, self.dtype_, self.shape_, dtype, shape)

    def copy_merge(self, var, shape=None):
        """
        Merges information from both Variable.

        :param var: a @see cl Variable, or anything accepted by
            @see me copy_add (a dtype or a numpy array)
        :param shape: if not None, replaces the shape of the copy,
            only allowed when *var* is a @see cl Variable
        :return: new @see cl Variable; missing added dtype and added
            shape are filled with the dtype and shape of *var*
        :raises RuntimeError: if *shape* is given while *var*
            is not a Variable
        """
        if not isinstance(var, Variable):
            if shape is not None:
                raise RuntimeError(  # pragma: no cover
                    "shape must be None if var is not a Variable.")
            return self.copy_add(var)
        # An empty shape () is a valid shape, test against None.
        new_shape = shape if shape is not None else self.shape_
        res = Variable(self.name_, self.dtype_,
                       new_shape, self.added_dtype_,
                       self.added_shape_)
        if self.added_dtype_ is None and var.dtype_ is not None:
            res.added_dtype_ = var.dtype_
        if self.added_shape_ is None and var.shape_ is not None:
            res.added_shape_ = var.shape_
        return res

    def copy_name(self, name):
        """
        Returns a copy with a new name (the current name is kept
        if *name* is empty or None).
        """
        return Variable(
            name or self.name_, self.dtype_,
            self.shape_, self.added_dtype_,
            self.added_shape_)

    def __eq__(self, other):
        """
        Compares name, shape and dtype
        (*added_dtype* and *added_shape* are not compared).

        :raises TypeError: if *other* is not a @see cl Variable
        """
        if not isinstance(other, Variable):
            raise TypeError(
                f"Unexpected type {type(other)!r}.")
        if self.name != other.name:
            return False
        if self.shape_ != other.shape_:
            return False
        if self.dtype_ != other.dtype_:
            return False
        return True

    def make_value_info(self):
        """
        Converts the variable into `onnx.ValueInfoProto`.

        :return: instance of `onnx.ValueInfoProto`
        """
        value_info = ValueInfoProto()
        value_info.name = self.name
        tensor_type_proto = make_tensor_type_proto(self.proto_type, self.shape)
        value_info.type.CopyFrom(tensor_type_proto)  # pylint: disable=E1101
        return value_info

    @staticmethod
    def from_pb(obj):
        """
        Creates a Variable from a protobuf object.

        :param obj: initializer, tensor
        :return: @see cl Variable
        """
        from ..onnx_tools.onnx2py_helper import from_pb
        name, ty, shape = from_pb(obj)
        return Variable(name, ty, shape=shape)
class NodeResultName:
    """
    Name of one output of a node, resolved lazily.

    :param node: node it comes from
    :param index: index of the output
    """

    def __init__(self, node, index):
        self.node = node
        self.index = index

    def __repr__(self):
        "Usual"
        return "%s(%r, %r)" % (type(self).__name__, self.node, self.index)

    def get_name(self):
        """
        Returns the name stored in the node's `output_names` at position
        `index`. If `output_names` is None, it suggests a name built from
        the first three letters of the lowercased operator type and
        the index.

        :raises RuntimeError: if the node is None
        """
        node = self.node
        if node is None:
            raise RuntimeError(  # pragma: no cover
                "node must not be None.")
        names = node.output_names
        if names is None:
            prefix = node.op_type.lower()[:3]
            return "out_%s_%d" % (prefix, self.index)
        return names[self.index].name
class DetectedVariable:
    """
    Wraps a @see cl Variable (or an @see cl ExistingVariable) found
    while detecting the inputs and outputs of a graph.

    :param node: node where the variable was detected
    :param var: instance of @see cl Variable
    :param index: index, only used if it is an output
        (negative values are not shown by `__repr__`)
    :raises TypeError: if *var* is neither a Variable
        nor an ExistingVariable
    """

    def __init__(self, node, var, index):
        if not isinstance(var, (Variable, ExistingVariable)):
            raise TypeError(  # pragma: no cover
                f"Unexpected type {type(var)!r}, it should be a Variable.")
        self.node, self.var, self.index = node, var, index

    @property
    def name(self):
        "Returns variable name."
        return self.var.name

    def __repr__(self):
        "usual"
        cls_name = self.__class__.__name__
        suffix = "" if self.index < 0 else f", {self.index}"
        if self.node is not None:
            # nodes are identified by class name and id
            return "%s(%s-%d, %r%s)" % (
                cls_name, self.node.__class__.__name__,
                id(self.node), self.var, suffix)
        return f"{cls_name}(None, {self.var!r}{suffix})"
class InputDetectedVariable(DetectedVariable):
    """
    @see cl DetectedVariable restricted to graph inputs;
    its index is always -1.
    """

    def __init__(self, node, var):
        super().__init__(node, var, -1)
class OutputDetectedVariable(DetectedVariable):
    """
    @see cl DetectedVariable restricted to graph outputs;
    the constructor is inherited and takes the output index.
    """