Coverage for mlprodict/onnx_tools/optim/graph_schema_helper.py: 73%
143 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Functions to help guessing the final graph structure.
4"""
5import numpy
6from onnx import TensorProto
def _guess_type(var):
    """Guess the :epkg:`skl2onnx` type of *var*, unwrapping ``{'value': ...}`` dicts."""
    from skl2onnx.algebra.type_helper import _guess_type as skl2onnx__guess_type  # delayed
    if isinstance(var, dict) and 'value' in var:
        # constant wrapped in a dictionary: guess from the wrapped value
        return skl2onnx__guess_type(var['value'])  # pragma: no cover
    return skl2onnx__guess_type(var)
def get_defined_inputs(input_names, variables=None, dtype=None,
                       schema=None):
    """
    Retrieves defined inputs in already declared variables
    based on their names.

    @param      input_names     input names
    @param      variables       registered variables created
                                by previous operators
    @param      dtype           float computational type
    @param      schema          defined inputs by schema (*expected_inputs*)
    @return                     typed inputs as ``tuple(name, type)``
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType, FloatTensorType, DoubleTensorType)

    def guess_type_variable(name, schema):
        # Resolves the type of a single input: prefers the registered
        # *variables*, then the operator *schema*, and finally falls back
        # to a float tensor type chosen from *dtype*.
        if variables is None:
            if (schema is None or
                    not isinstance(schema, (DataType, tuple))):
                return (  # pragma: no cover
                    DoubleTensorType() if dtype == numpy.float64 else FloatTensorType())
            # schema entries are either a DataType or a (name, type) tuple
            return schema if isinstance(schema, DataType) else schema[1]
        if name in variables:
            ty = variables[name]
            if isinstance(ty, DataType):
                shape = ty.shape
                if 0 in shape:
                    raise RuntimeError(  # pragma: no cover
                        f"Shape cannot be empty: name='{name}', var={ty}")
                return variables[name]
            if isinstance(ty, dict) and 'value' in ty:
                # constant
                arr = ty['value']
                try:
                    return _guess_type(arr)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        f"Unable to guess type of variable '{name}' - {arr}.") from e
            raise NotImplementedError(  # pragma: no cover
                f"Unable to guess type for '{name}' form '{variables[name]}'.")
        if isinstance(schema, (DataType, tuple)):
            sch = schema if isinstance(schema, DataType) else schema[1]
            if not isinstance(sch, str):
                return sch
        # Inputs. Let's assume it is a vector of floats.
        return DoubleTensorType() if dtype == numpy.float64 else FloatTensorType()

    if schema is None or len(schema) < len(input_names):
        # Schema is missing or shorter than the input list: guess each
        # type individually without schema information.
        inputs = [(name, guess_type_variable(name, None))
                  for name in input_names]
    else:
        inputs = [(name, guess_type_variable(name, schema=sch))
                  for name, sch in zip(input_names, schema)]
    return inputs
def get_defined_outputs(outputs, onnx_node, typed_inputs=None, variables=None,
                        dtype=None, schema=None, schema_inputs=None):
    """
    Gets types of predefined outputs when they cannot be inferred.
    Some part of it should be automated based
    on type constraints.

    :param outputs: requested outputs
    :param onnx_node: :epkg:`ONNX` node definition
    :param typed_inputs: known typed inputs of the node as `tuple(name, type)`
    :param variables: registered variables created by previous operators
    :param dtype: float computational type
    :param schema: defined outputs by schema (*expected_outputs*)
    :param schema_inputs: defined inputs by schema (*expected_inputs*)
    :return: typed outputs as ``tuple(name, type)``
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType,
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        DoubleTensorType, _guess_type_proto, _guess_type_proto_str)

    # Pick the default tensor type class used when nothing better
    # can be inferred from the schema.
    if schema is None:
        ft = DoubleTensorType if dtype == numpy.float64 else FloatTensorType
    elif len(schema) != 1:
        raise ValueError(  # pragma: no cover
            f"Operator {onnx_node.op_type!r}, "
            f"schema should only contain one output not {schema}.")
    else:
        if isinstance(schema, DataType):
            # NOTE(review): indexing a DataType instance with schema[0]
            # looks suspicious — this branch seems unreachable since
            # len(schema) succeeded above; confirm before relying on it.
            ft = schema[0].__class__
        else:
            ft = schema[0][1].__class__

    if onnx_node.op_type in {'ZipMap', 'ArgMin', 'ArgMax', 'Shape',
                             'Greater', 'Less', 'Equal', 'TopK',
                             'Cast', 'ArrayFeatureExtractor',
                             'Reshape', 'Transpose', 'Scan',
                             'ConstantOfShape'}:
        # Hard-coded rules for operators whose output types are known.
        if onnx_node.op_type == "ZipMap":
            # ZipMap: sequence of dictionaries {int64 key: float value}.
            otype = SequenceType(DictionaryType(
                Int64Type(), ft()))
            outputs = [(name, otype) for name in outputs]
        elif (onnx_node.op_type in ("ArgMin", "ArgMax", 'Shape') and
                len(outputs) == 1):
            # ArgMin, ArgMax, Shape: always an int64 tensor.
            outputs = [(outputs[0], Int64TensorType())]
        elif (onnx_node.op_type in ("Greater", "Less", 'Equal') and
                len(outputs) == 1):
            # Greater, Less, Equal: comparisons return boolean tensors.
            outputs = [(outputs[0], BooleanTensorType())]
        elif onnx_node.op_type == "TopK" and len(outputs) == 2:
            # TopK: values keep the first input's type, indices are int64.
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    f"Wrong typed_inputs, got {typed_inputs}.")
            outputs = [(outputs[0], typed_inputs[0][1]),
                       (outputs[1], Int64TensorType())]
        elif onnx_node.op_type == "Cast" and len(outputs) == 1:
            # Cast: output type comes from the node's first attribute.
            ttyp = _guess_type_proto(onnx_node.attribute[0].i, dims=None)
            outputs = [(outputs[0], ttyp)]
        elif onnx_node.op_type == "ArrayFeatureExtractor":
            # ArrayFeatureExtractor: keeps the first input's type.
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    f"Wrong typed_inputs, got {typed_inputs}.")
            outputs = [(outputs[0], typed_inputs[0][1])]
        elif onnx_node.op_type in ('Reshape', 'Transpose'):
            # Reshape, Transpose: same element type, shape left unknown.
            outputs = [(outputs[0], typed_inputs[0][1].__class__())]
        elif onnx_node.op_type == 'Scan':
            # Scan: one output per input, element types carried over.
            if len(outputs) != len(typed_inputs):
                raise RuntimeError(  # pragma: no cover
                    "Dimension mismatch, operator Scan should have "
                    "the same number of inputs and outputs {} != {}"
                    ".".format(len(outputs), len(typed_inputs)))
            outputs = [(o, t[1].__class__())
                       for o, t in zip(outputs, typed_inputs)]
        elif onnx_node.op_type == "ConstantOfShape":
            # ConstantOfShape: defaults to the computational float type.
            outputs = [(outputs[0], ft())]
    elif 'Classifier' in onnx_node.op_type:
        # Good chance that's a classifier: labels (int64) then scores.
        outputs = [(outputs[0], Int64TensorType()),
                   (outputs[1], ft())]
    else:
        if schema_inputs is not None and schema is not None:
            # Resolve symbolic type names (e.g. 'T') in the output schema
            # by matching them against the already-typed inputs.
            dt = {}
            for got, exp in zip(typed_inputs, schema_inputs):
                if isinstance(exp[1], str):
                    dt[exp[1]] = got
            out = []
            for i in range(len(outputs)):  # pylint: disable=C0200
                o = outputs[i]
                if isinstance(o, str):
                    exp = schema[i]
                    if exp[1] in dt:
                        out.append((o, dt[exp[1]][1].__class__()))
                    else:
                        nt = _guess_type_proto_str(exp[1], None)
                        out.append((o, nt))
                elif (isinstance(o, tuple) and
                        (isinstance(o[1], str) or o[1] is None)):
                    # (name, symbolic-or-missing type) tuple: resolve type.
                    exp = schema[i]
                    if exp[1] in dt:
                        out.append((o[0], dt[exp[1]][1].__class__()))
                    else:
                        nt = _guess_type_proto_str(exp[1], None)
                        out.append((o[0], nt))
                else:
                    out.append(o)
            outputs = out
        elif len(typed_inputs) == 1 and len(outputs) == 1:
            # Default case:
            # assuming the only output is the same as the only input.
            outputs = [(outputs[0], typed_inputs[0][1])]
        else:
            # Default: fall back to the computational float type.
            outputs = [(name, ft()) for name in outputs]

    # Final sanity checks: every output must carry a resolved type and
    # a string name.
    for name, typ in outputs:
        if typ in ('T', None, '', 'I'):
            raise NotImplementedError(  # pragma: no cover
                "Undefined output type: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    typ, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
        if not isinstance(name, str):
            raise NotImplementedError(  # pragma: no cover
                "Undefined output type: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    typ, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
    return outputs
def proto2vars(values):
    """
    Converts proto values to Variables.

    @param      values      iterable of :epkg:`ONNX` protos
                            (*ValueInfoProto* or type protos)
    @return                 list of ``tuple(name, type)`` with
                            :epkg:`skl2onnx` types
    """
    from skl2onnx.common.data_types import (  # delayed
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        Int32TensorType, DoubleTensorType, FloatType,
        StringTensorType, Float16TensorType)
    from ..onnx2py_helper import (
        get_tensor_elem_type, get_tensor_shape)

    def ptype2vttype(it, shape):
        # Maps a TensorProto element type to a skl2onnx tensor type.
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatTensorType(shape)
        if it == TensorProto.DOUBLE:  # pylint: disable=E1101
            return DoubleTensorType(shape)
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64TensorType(shape)
        if it == TensorProto.INT32:  # pylint: disable=E1101
            return Int32TensorType(shape)
        if it == TensorProto.BOOL:  # pylint: disable=E1101
            return BooleanTensorType(shape)
        if it == TensorProto.STRING:  # pylint: disable=E1101
            return StringTensorType(shape)
        # Float16TensorType is presumably None with older skl2onnx
        # versions. The previous test was inverted (`is None`), which
        # made the FLOAT16 branch unreachable when the type exists and
        # would have called ``None(shape)`` when it does not.
        if Float16TensorType is not None:
            if it == TensorProto.FLOAT16:  # pylint: disable=E1101
                return Float16TensorType(shape)
        raise NotImplementedError(  # pragma: no cover
            f"Unrecognized proto type {it} with shape {shape}")

    def ptype2vtype(it):
        # Maps a TensorProto element type to a skl2onnx scalar type.
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatType()
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64Type()
        raise NotImplementedError(  # pragma: no cover
            f"Unrecognized proto type {it}")

    res = []
    for v_ in values:
        v = v_
        name = v.name if hasattr(v, 'name') else None
        if hasattr(v, 'type') and str(v.type) != '':
            # ValueInfoProto: recurse on its type proto.
            t = v.type
            v = proto2vars([t])[0][1]
        elif hasattr(v, 'sequence_type') and str(v.sequence_type) != '':
            subtype = proto2vars([v.sequence_type.elem_type])[0][1]
            v = SequenceType(subtype)
        elif hasattr(v, 'tensor_type') and str(v.tensor_type) != '':
            v = ptype2vttype(get_tensor_elem_type(v), get_tensor_shape(v))
        elif hasattr(v, 'map_type') and str(v.map_type) != '':
            mt = v.map_type
            keyt = ptype2vtype(mt.key_type)
            valt = proto2vars([mt.value_type])[0][1]
            v = DictionaryType(keyt, valt)
        else:
            raise RuntimeError(  # pragma: no cover
                f"Unable to build a variable from {v}.")
        if v.shape is not None and 0 in v.shape:
            # Replaces 0 by None (unknown dimension).
            new_shape = tuple(None if d == 0 else d for d in v.shape)
            if new_shape in ((None, ), None):
                v = v.__class__()
            else:
                v = v.__class__(new_shape)
        if v.shape is not None and 0 in v.shape:
            raise RuntimeError(  # pragma: no cover
                f"Shape cannot be empty: '{name}': {v_}.")
        res.append((name, v))
    return res