Coverage for mlprodict/onnx_tools/exports/skl2onnx_helper.py: 98%
63 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Helpers to run examples created with :epkg:`sklearn-onnx`.
4"""
5from onnx import helper, TensorProto, ValueInfoProto, TypeProto
def get_tensor_shape(obj):
    """
    Extracts the tensor shape stored in an ONNX value description.

    :param obj: a ``ValueInfoProto`` or a ``TypeProto``
    :return: a list with one entry per declared dimension
        (``dim_value`` when it is strictly positive, ``dim_param``
        otherwise) or None when no dimension is declared
    :raises TypeError: if *obj* is neither of the supported types
    """
    if isinstance(obj, ValueInfoProto):
        # A value info only wraps the type which holds the shape.
        return get_tensor_shape(obj.type)
    if not isinstance(obj, TypeProto):
        raise TypeError(  # pragma: no cover
            f"Unexpected type {type(obj)!r}.")
    # Every entry is either a strictly positive integer or a string,
    # never 0, so no further normalisation of the entries is required.
    dims = [dim.dim_value if dim.dim_value > 0 else dim.dim_param
            for dim in obj.tensor_type.shape.dim]
    return dims if dims else None
def get_tensor_elem_type(obj):
    """
    Extracts the tensor element type stored in an ONNX value description.

    :param obj: a ``ValueInfoProto`` or a ``TypeProto``
    :return: the ``elem_type`` field of the tensor type
    :raises TypeError: if *obj* is neither of the supported types
    """
    if isinstance(obj, ValueInfoProto):
        # A value info only wraps the type which holds the element type.
        return get_tensor_elem_type(obj.type)
    if isinstance(obj, TypeProto):
        return obj.tensor_type.elem_type
    raise TypeError(  # pragma: no cover
        f"Unexpected type {type(obj)!r}.")
def _copy_inout(inout, scope, new_name):
    """
    Builds a tensor value info named *new_name* with the same element
    type and shape as *inout*.

    :param inout: graph input or output to duplicate
    :param scope: not used by this function
    :param new_name: name given to the copy
    :return: the value info created by ``helper.make_tensor_value_info``
    """
    dims = get_tensor_shape(inout)
    dtype = get_tensor_elem_type(inout)
    return helper.make_tensor_value_info(new_name, dtype, dims)
48def _clean_variable_name(name, scope):
49 return scope.get_unique_variable_name(name)
52def _clean_operator_name(name, scope):
53 return scope.get_unique_operator_name(name)
56def _clean_initializer_name(name, scope):
57 return scope.get_unique_variable_name(name)
def add_onnx_graph(scope, operator, container, onx):
    """
    Adds a whole ONNX graph to an existing one following
    :epkg:`skl2onnx` API assuming this ONNX graph implements
    an `operator <http://onnx.ai/sklearn-onnx/api_summary.html?
    highlight=operator#skl2onnx.common._topology.Operator>`_.

    Every tensor and node of *onx* is renamed with a unique name
    requested from *scope*. The inputs and outputs of *operator*
    are connected to the inputs and outputs of the inserted graph
    with ``Identity`` nodes. Nodes, initializers and opset imports
    are appended to *container*.

    :param scope: scope (to get unique names)
    :param operator: operator
    :param container: container
    :param onx: ONNX graph
    """
    graph = onx.graph
    name_mapping = {}

    def _renamed(name, clean=_clean_variable_name):
        # An empty name stands for an omitted optional input or output
        # in ONNX, it must remain empty instead of being turned into a
        # reference to a tensor nobody produces.
        if not name:
            return name
        # A tensor name is renamed once, every later occurrence reuses
        # the same new name.
        if name not in name_mapping:
            name_mapping[name] = clean(name, scope)
        return name_mapping[name]

    # Initializers are variables: their new names come from the
    # variable names of the scope, not from the operator names.
    for init in graph.initializer:
        _renamed(init.name, _clean_initializer_name)

    # Graph inputs and outputs are renamed on demand so that an input
    # or output no node refers to does not raise a KeyError.
    inputs = [_copy_inout(o, scope, _renamed(o.name))
              for o in graph.input]
    outputs = [_copy_inout(o, scope, _renamed(o.name))
               for o in graph.output]

    # Connects the operator inputs to the inputs of the inserted graph.
    for inp, to in zip(operator.inputs, inputs):
        n = helper.make_node('Identity', [inp.onnx_name], [to.name],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    # Connects the outputs of the inserted graph to the operator outputs.
    for inp, to in zip(outputs, operator.outputs):
        n = helper.make_node('Identity', [inp.name], [to.onnx_name],
                             name=_clean_operator_name('Identity', scope))
        container.nodes.append(n)

    for node in graph.node:
        n = helper.make_node(
            node.op_type,
            [_renamed(o) for o in node.input],
            [_renamed(o) for o in node.output],
            # A node name is a string which may be empty: unnamed nodes
            # stay unnamed, named nodes get a unique operator name.
            name=(_clean_operator_name(node.name, scope)
                  if node.name else None),
            domain=node.domain if node.domain else None)
        n.attribute.extend(node.attribute)  # pylint: disable=E1101
        container.nodes.append(n)

    for o in graph.initializer:
        # Copies the initializer so that the original graph is not
        # modified by the renaming.
        as_str = o.SerializeToString()
        tensor = TensorProto()
        tensor.ParseFromString(as_str)
        tensor.name = _renamed(o.name, _clean_initializer_name)
        container.initializers.append(tensor)

    # opset
    for oimp in onx.opset_import:
        container.node_domain_version_pair_sets.add(
            (oimp.domain, oimp.version))