Coverage for mlprodict/onnxrt/onnx_shape_inference.py: 99%

91 statements  


1""" 

2@file 

3@brief Runtime to infer shapes. 

4 

5.. versionadded:: 0.9 

6""" 

7import numpy 

8from onnx import FunctionProto, ModelProto 

9from onnx.numpy_helper import to_array 

10from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE 

11from .ops_shape.shape_result import ShapeResult 

12from .ops_shape.shape_container import ShapeContainer 

13from .ops_shape import shape_dispatch 

14 

15 

class OnnxShapeInference:
    """
    Implements a micro runtime for ONNX graphs.
    It does not implement all the operator types.

    :param model_onnx: ONNX model

    Other attributes:

    * `known_shapes_`: shapes which can be inferred without any input
    * `cache_`: keeps track of the function used to infer
      the shapes
    * `is_function`: tells if the graph is a function or a model

    .. runpython::
        :showcode:

        import pprint
        import numpy
        from mlprodict.onnxrt.onnx_shape_inference import OnnxShapeInference
        from mlprodict.npy.xop_variable import Variable
        from mlprodict.npy.xop import loadop

        opset = 15
        OnnxAdd = loadop('Add')
        dtype = numpy.float32

        cop = OnnxAdd('X', numpy.array(
            [[1]], dtype=dtype), op_version=opset)
        cop4 = OnnxAdd(cop, numpy.array([[2]], dtype=dtype),
                       output_names=['Y'])
        vari = Variable('X', numpy.float32, [None, 3])
        model_def = cop4.to_onnx([vari], run_shape=False)
        rt = OnnxShapeInference(model_def)
        out = rt.run()
        pprint.pprint(out.get())
    """

    def __init__(self, model_onnx):
        if not isinstance(model_onnx, (FunctionProto, ModelProto)):
            raise TypeError(  # pragma: no cover
                "model_onnx is not a FunctionProto or a ModelProto but "
                "%r." % type(model_onnx))
        self.is_function = isinstance(model_onnx, FunctionProto)
        self.model_onnx = model_onnx
        self.cache_ = {}
        # Shapes that can be deduced without any concrete input are computed
        # once here and reused (deep-copied) by every call to run().
        self.known_shapes_ = self._run_empty()

    @property
    def input_names(self):
        "Returns input names."
        if self.is_function:
            return list(self.model_onnx.input)
        return [i.name for i in self.model_onnx.graph.input]

    @property
    def output_names(self):
        "Returns output names."
        if self.is_function:
            return list(self.model_onnx.output)
        return [i.name for i in self.model_onnx.graph.output]

    def __repr__(self):
        "Usual"
        return f"{self.__class__.__name__}(...)"
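    # For illustration: for an input declared as float32 with shape [None, 3]
    # whose first dimension carries no name, _get_shape returns something like
    # (['d0', 3], numpy.float32, False), where 'd0' stands for the placeholder
    # generated by known_shapes.get_new_name for the unknown dimension.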

    @staticmethod
    def _get_shape(obj, known_shapes=None, result_name=None):
        if obj is None:
            return [], None, False
        dtype = TENSOR_TYPE_TO_NP_TYPE.get(
            obj.type.tensor_type.elem_type, None)
        shape = []
        for dimi, d in enumerate(obj.type.tensor_type.shape.dim):
            # A dimension is either a fixed value (dim_value) or a name
            # (dim_param); an empty name means the dimension is unknown.
            v = d.dim_value if d.dim_value > 0 else d.dim_param
            if v in ('', None):
                if known_shapes is None or result_name is None:
                    raise RuntimeError(  # pragma: no cover
                        "known_shapes must be specified if "
                        "a dimension is not known.")
                v = known_shapes.get_new_name(v, result_name, dimi)
            shape.append(v)
        return shape, dtype, False

    def _run_empty(self):
        """
        Computes the shapes and types of all results.

        :return: all intermediate results and outputs as a dictionary
        """
        def get_obj(name, inputs):
            if self.is_function:
                return None
            if inputs:
                for o in self.model_onnx.graph.input:
                    if o.name == name:
                        return o
            else:
                for o in self.model_onnx.graph.output:
                    if o.name == name:
                        return o
            return None

        known_shapes = ShapeContainer()
        if not self.is_function:
            for init in self.model_onnx.graph.initializer:
                mat = to_array(init)
                known_shapes.update(init.name, ShapeResult(
                    init.name, mat.shape, mat.dtype, sparse=False))

        for name in self.input_names:
            if name in known_shapes:
                raise NotImplementedError(
                    f"Optional inputs are not implemented yet. (name={name!r})")
            shape, dtype, sparse = self._get_shape(
                get_obj(name, True), known_shapes, result_name=name)
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))

        for name in self.output_names:
            if name in known_shapes:
                raise NameError(  # pragma: no cover
                    f"Output {name!r} is already present. Use Identity node.")
            shape, dtype, sparse = self._get_shape(
                get_obj(name, False), known_shapes, result_name=name)
            if dtype is None:
                # The onnx graph was created with named outputs
                # but with no type or shape.
                continue
            known_shapes.update(name, ShapeResult(
                name, shape, dtype, sparse=sparse))
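        # Propagate shapes through the graph: shape_dispatch returns True when
        # it adds or refines a result, so the loop runs until a full pass over
        # the nodes changes nothing (a fixed point).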

        nodes = (
            self.model_onnx.node if self.is_function
            else self.model_onnx.graph.node)
        cont = True
        while cont:
            cont = False
            for node in nodes:
                cont = cont or shape_dispatch(
                    self.cache_, known_shapes, node, rt_class=self.__class__)
        return known_shapes

    def run(self, inputs=None):
        """
        Runs shape and type inference given known inputs.

        :param inputs: inputs
        :return: all results
        """
        known_shapes = self.known_shapes_.copy(deep=True)
        if inputs is None:
            known_shapes.resolve()
            return known_shapes
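        # Inject the concrete shapes of the given inputs, then propagate them
        # through the graph with the same fixed-point loop as in _run_empty.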

        cont = False
        for name, obj in inputs.items():
            shape, dtype, sparse = (
                obj.shape, obj.dtype, not isinstance(obj, numpy.ndarray))
            cont = cont or known_shapes.update(
                name, ShapeResult(name, shape, dtype, sparse=sparse))

        nodes = (
            self.model_onnx.node if self.is_function
            else self.model_onnx.graph.node)
        while cont:
            cont = False
            for node in nodes:
                updated = shape_dispatch(
                    self.cache_, known_shapes, node, rt_class=self.__class__)
                cont = cont or updated
        known_shapes.resolve()
        return known_shapes
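
A minimal usage sketch, assuming the mlprodict xop API shown in the class
docstring is available; it rebuilds the same small model and runs shape
inference once without and once with a concrete input:

    import numpy
    from mlprodict.npy.xop_variable import Variable
    from mlprodict.npy.xop import loadop
    from mlprodict.onnxrt.onnx_shape_inference import OnnxShapeInference

    OnnxAdd = loadop('Add')
    dtype = numpy.float32
    cop = OnnxAdd('X', numpy.array([[1]], dtype=dtype), op_version=15)
    cop4 = OnnxAdd(cop, numpy.array([[2]], dtype=dtype), output_names=['Y'])
    model_def = cop4.to_onnx([Variable('X', dtype, [None, 3])], run_shape=False)

    rt = OnnxShapeInference(model_def)
    # Symbolic run: unknown dimensions keep generated names.
    print(rt.run().get())
    # Concrete run: run() takes a dictionary of input arrays and fills
    # the unknown dimensions in.
    print(rt.run({'X': numpy.zeros((4, 3), dtype=dtype)}).get())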