Coverage for mlprodict/onnxrt/onnx_shape_inference.py: 99%
91 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Runtime to infer shapes.
5.. versionadded:: 0.9
6"""
7import numpy
8from onnx import FunctionProto, ModelProto
9from onnx.numpy_helper import to_array
10from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
11from .ops_shape.shape_result import ShapeResult
12from .ops_shape.shape_container import ShapeContainer
13from .ops_shape import shape_dispatch
class OnnxShapeInference:
    """
    Implements a micro runtime for ONNX graphs.
    It does not implement all the operator types.

    :param model_onnx: ONNX model (`ModelProto` or `FunctionProto`)

    Other attributes:

    * `known_shapes_`: shapes which can be inferred without any input
    * `cache_`: keeps track of the function used to infer
      the shapes
    * `is_function`: tells if the graph is a function or a model

    .. runpython::
        :showcode:

        import pprint
        import numpy
        from mlprodict.onnxrt.onnx_shape_inference import OnnxShapeInference
        from mlprodict.npy.xop_variable import Variable
        from mlprodict.npy.xop import loadop

        opset = 15
        OnnxAdd = loadop('Add')
        dtype = numpy.float32

        cop = OnnxAdd('X', numpy.array(
            [[1]], dtype=dtype), op_version=opset)
        cop4 = OnnxAdd(cop, numpy.array([[2]], dtype=dtype),
                       output_names=['Y'])
        vari = Variable('X', numpy.float32, [None, 3])
        model_def = cop4.to_onnx([vari], run_shape=False)
        rt = OnnxShapeInference(model_def)
        out = rt.run()
        pprint.pprint(out.get())
    """

    def __init__(self, model_onnx):
        if not isinstance(model_onnx, (FunctionProto, ModelProto)):
            raise TypeError(  # pragma: no cover
                f"model_onnx is not from FunctionProto or ModelProto "
                f"{type(model_onnx)!r}.")
        self.is_function = isinstance(model_onnx, FunctionProto)
        self.model_onnx = model_onnx
        # cache_ must exist before _run_empty, which hands it
        # to shape_dispatch.
        self.cache_ = {}
        self.known_shapes_ = self._run_empty()

    @property
    def input_names(self):
        "Returns input names."
        if self.is_function:
            # FunctionProto stores input names directly.
            return list(self.model_onnx.input)
        return [i.name for i in self.model_onnx.graph.input]

    @property
    def output_names(self):
        "Returns output names."
        if self.is_function:
            # FunctionProto stores output names directly.
            return list(self.model_onnx.output)
        return [i.name for i in self.model_onnx.graph.output]

    def __repr__(self):
        "Usual"
        return f"{self.__class__.__name__}(...)"

    @staticmethod
    def _get_shape(obj, known_shapes=None, result_name=None):
        """
        Extracts the shape and the element type of a graph input or output.

        :param obj: graph input or output (an object with a
            ``type.tensor_type`` attribute) or None
        :param known_shapes: ShapeContainer used to create a new name
            for every dimension which is neither a value nor a parameter
        :param result_name: name of the result *obj* describes
        :return: tuple ``(shape, dtype, sparse)``, *shape* is a list of
            integers or dimension names, *dtype* is None if the element
            type is unknown, *sparse* is always False
        :raises RuntimeError: if a dimension is unknown and *known_shapes*
            or *result_name* is missing
        """
        if obj is None:
            return [], None, False
        dtype = TENSOR_TYPE_TO_NP_TYPE.get(
            obj.type.tensor_type.elem_type, None)
        shape = []
        for dimi, d in enumerate(obj.type.tensor_type.shape.dim):
            # NOTE(review): a dimension whose dim_value is 0 is treated
            # like an unknown dimension here; confirm 0-sized dimensions
            # are not expected.
            v = d.dim_value if d.dim_value > 0 else d.dim_param
            if v in ('', None):
                if known_shapes is None or result_name is None:
                    raise RuntimeError(  # pragma: no cover
                        "known_shapes must be specified if "
                        "a dimension is not.")
                v = known_shapes.get_new_name(v, result_name, dimi)
            shape.append(v)
        return shape, dtype, False

    def _get_nodes(self):
        "Returns the nodes of the function or of the model's graph."
        if self.is_function:
            return self.model_onnx.node
        return self.model_onnx.graph.node

    def _propagate(self, known_shapes):
        """
        Dispatches every node to its shape inference function and
        repeats full passes over the nodes until one pass updates nothing.

        :param known_shapes: ShapeContainer, updated in place
        """
        nodes = self._get_nodes()
        cont = True
        while cont:
            cont = False
            for node in nodes:
                # shape_dispatch is called before being combined with
                # `cont`: `cont or shape_dispatch(...)` would short-circuit
                # and skip every node after the first update of the pass.
                updated = shape_dispatch(
                    self.cache_, known_shapes, node, rt_class=self.__class__)
                cont = cont or updated

    def _run_empty(self):
        """
        Computes shape and types of all results.

        :return: all intermediate results and outputs as a ShapeContainer
        """
        def get_obj(name, inputs):
            # A FunctionProto holds no type information for its
            # inputs and outputs.
            if self.is_function:
                return None
            candidates = (
                self.model_onnx.graph.input if inputs
                else self.model_onnx.graph.output)
            for o in candidates:
                if o.name == name:
                    return o
            return None

        known_shapes = ShapeContainer()
        if not self.is_function:
            for init in self.model_onnx.graph.initializer:
                mat = to_array(init)
                known_shapes.update(init.name, ShapeResult(
                    init.name, mat.shape, mat.dtype, sparse=False))

        for name in self.input_names:
            # The name is already registered, e.g. an initializer
            # also listed as a graph input.
            if name in known_shapes:
                raise NotImplementedError(
                    f"Optional inputs are not implemented yet. (name={name!r})")
            shape, dtype, sparse = self._get_shape(
                get_obj(name, True), known_shapes, result_name=name)
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))

        for name in self.output_names:
            if name in known_shapes:
                raise NameError(  # pragma: no cover
                    f"Output {name!r} is already present. Use Identity node.")
            shape, dtype, sparse = self._get_shape(
                get_obj(name, False), known_shapes, result_name=name)
            if dtype is None:
                # The onnx graph was created with named outputs
                # but with no type or shape.
                continue
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))

        self._propagate(known_shapes)
        return known_shapes

    def run(self, inputs=None):
        """
        Runs shape and type inference given known inputs.

        :param inputs: None or a dictionary ``{name: value}``. Every value
            must have the attributes ``shape`` and ``dtype``. Values which
            are not :class:`numpy.ndarray` are registered as sparse.
        :return: all results, a deep copy of `known_shapes_` updated
            with *inputs*, propagated through the nodes and resolved
        """
        known_shapes = self.known_shapes_.copy(deep=True)
        if inputs is None:
            known_shapes.resolve()
            return known_shapes

        cont = False
        for name, obj in inputs.items():
            shape, dtype, sparse = (
                obj.shape, obj.dtype, not isinstance(obj, numpy.ndarray))
            # Every input must be registered:
            # `cont = cont or known_shapes.update(...)` would skip all
            # inputs following the first one which changed something.
            updated = known_shapes.update(
                name, ShapeResult(name, shape, dtype, sparse=sparse))
            cont = cont or updated

        # Nodes only need to be revisited if an input changed something.
        if cont:
            self._propagate(known_shapes)
        known_shapes.resolve()
        return known_shapes