Coverage for mlprodict/onnx_tools/optim/graph_schema_helper.py: 73%

143 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Functions to help guessing the final graph structure. 

4""" 

5import numpy 

6from onnx import TensorProto 

7 

8 

def _guess_type(var):
    """
    Guesses the :epkg:`skl2onnx` type of *var*, unwrapping constants
    stored as ``{'value': ...}`` dictionaries first.
    """
    from skl2onnx.algebra.type_helper import _guess_type as skl2onnx__guess_type  # delayed
    if isinstance(var, dict) and 'value' in var:
        var = var['value']  # pragma: no cover
    return skl2onnx__guess_type(var)

14 

15 

def get_defined_inputs(input_names, variables=None, dtype=None,
                       schema=None):
    """
    Retrieves defined inputs in already declared variables
    based on their names.

    @param      input_names     input names
    @param      variables       registered variables created
                                by previous operators
    @param      dtype           float computational type
    @param      schema          defined inputs by schema (*expected_inputs*)
    @return                     typed inputs as ``tuple(name, type)``
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType, FloatTensorType, DoubleTensorType)

    def guess_type_variable(name, schema):
        # Resolves the type of one input, preferring registered variables,
        # then the schema, then a float tensor matching *dtype*.
        if variables is None:
            if (schema is None or
                    not isinstance(schema, (DataType, tuple))):
                # No information at all: assume a float tensor.
                return (  # pragma: no cover
                    DoubleTensorType() if dtype == numpy.float64 else FloatTensorType())
            return schema if isinstance(schema, DataType) else schema[1]
        if name in variables:
            ty = variables[name]
            if isinstance(ty, DataType):
                shape = ty.shape
                if 0 in shape:
                    # A zero dimension denotes an empty tensor,
                    # which cannot be a valid input type.
                    raise RuntimeError(  # pragma: no cover
                        f"Shape cannot be empty: name='{name}', var={ty}")
                return variables[name]
            if isinstance(ty, dict) and 'value' in ty:
                # constant: guess the type from the stored value
                arr = ty['value']
                try:
                    return _guess_type(arr)
                except RuntimeError as e:  # pragma: no cover
                    raise RuntimeError(
                        f"Unable to guess type of variable '{name}' - {arr}.") from e
            raise NotImplementedError(  # pragma: no cover
                f"Unable to guess type for '{name}' from '{variables[name]}'.")
        if isinstance(schema, (DataType, tuple)):
            sch = schema if isinstance(schema, DataType) else schema[1]
            if not isinstance(sch, str):
                # A string would be an unresolved type constraint (e.g. 'T'),
                # which is not a usable concrete type.
                return sch
        # Inputs. Let's assume it is a vector of floats.
        return DoubleTensorType() if dtype == numpy.float64 else FloatTensorType()

    if schema is None or len(schema) < len(input_names):
        # Schema missing or incomplete: resolve every input without it.
        inputs = [(name, guess_type_variable(name, None))
                  for name in input_names]
    else:
        inputs = [(name, guess_type_variable(name, schema=sch))
                  for name, sch in zip(input_names, schema)]
    return inputs

71 

72 

def get_defined_outputs(outputs, onnx_node, typed_inputs=None, variables=None,
                        dtype=None, schema=None, schema_inputs=None):
    """
    Gets types of predefined outputs when they cannot be inferred.
    Some part of it should be automated based
    on type constraints.

    :param outputs: requested outputs
    :param onnx_node: :epkg:`ONNX` node definition
    :param typed_inputs: known typed inputs of the node as `tuple(name, type)`
    :param variables: registered variables created by previous operators
    :param dtype: float computational type
    :param schema: defined outputs by schema (*expected_outputs*)
    :param schema_inputs: defined inputs by schema (*expected_inputs*)
    :return: typed outputs as ``tuple(name, type)``
    :raises ValueError: if *schema* defines more than one output
    :raises NotImplementedError: if an output type or name
        cannot be resolved to a concrete value
    """
    from skl2onnx.common.data_types import (  # delayed
        DataType,
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        DoubleTensorType, _guess_type_proto, _guess_type_proto_str)

    # Default output tensor class: taken from the schema when given,
    # otherwise float32/float64 depending on *dtype*.
    if schema is None:
        ft = DoubleTensorType if dtype == numpy.float64 else FloatTensorType
    elif len(schema) != 1:
        raise ValueError(  # pragma: no cover
            f"Operator {onnx_node.op_type!r}, "
            f"schema should only contain one output not {schema}.")
    else:
        if isinstance(schema, DataType):
            ft = schema[0].__class__
        else:
            ft = schema[0][1].__class__

    # Hard-coded output types for operators whose outputs are known
    # regardless of the schema.
    if onnx_node.op_type in {'ZipMap', 'ArgMin', 'ArgMax', 'Shape',
                             'Greater', 'Less', 'Equal', 'TopK',
                             'Cast', 'ArrayFeatureExtractor',
                             'Reshape', 'Transpose', 'Scan',
                             'ConstantOfShape'}:
        if onnx_node.op_type == "ZipMap":
            # ZipMap: a sequence of maps from int64 labels to scores.
            otype = SequenceType(DictionaryType(
                Int64Type(), ft()))
            outputs = [(name, otype) for name in outputs]
        elif (onnx_node.op_type in ("ArgMin", "ArgMax", 'Shape') and
                len(outputs) == 1):
            # ArgMin, ArgMax, Shape: always return int64 tensors.
            outputs = [(outputs[0], Int64TensorType())]
        elif (onnx_node.op_type in ("Greater", "Less", 'Equal') and
                len(outputs) == 1):
            # Greater, Less, Equal: comparison result is boolean.
            outputs = [(outputs[0], BooleanTensorType())]
        elif onnx_node.op_type == "TopK" and len(outputs) == 2:
            # TopK: values share the first input's type, indices are int64.
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    f"Wrong typed_inputs, got {typed_inputs}.")
            outputs = [(outputs[0], typed_inputs[0][1]),
                       (outputs[1], Int64TensorType())]
        elif onnx_node.op_type == "Cast" and len(outputs) == 1:
            # Cast: output type is given by the 'to' attribute
            # (first attribute of the node).
            ttyp = _guess_type_proto(onnx_node.attribute[0].i, dims=None)
            outputs = [(outputs[0], ttyp)]
        elif onnx_node.op_type == "ArrayFeatureExtractor":
            # ArrayFeatureExtractor: output keeps the first input's type.
            if len(typed_inputs) != 2:
                raise RuntimeError(  # pragma: no cover
                    f"Wrong typed_inputs, got {typed_inputs}.")
            outputs = [(outputs[0], typed_inputs[0][1])]
        elif onnx_node.op_type in ('Reshape', 'Transpose'):
            # Reshape, Transpose: same element type as input, unknown shape,
            # hence a fresh instance of the input's type class.
            outputs = [(outputs[0], typed_inputs[0][1].__class__())]
        elif onnx_node.op_type == 'Scan':
            # Scan: outputs mirror the inputs one-to-one.
            if len(outputs) != len(typed_inputs):
                raise RuntimeError(  # pragma: no cover
                    "Dimension mismatch, operator Scan should have "
                    "the same number of inputs and outputs {} != {}"
                    ".".format(len(outputs), len(typed_inputs)))
            outputs = [(o, t[1].__class__())
                       for o, t in zip(outputs, typed_inputs)]
        elif onnx_node.op_type == "ConstantOfShape":
            # ConstantOfShape: filled with the default float type.
            outputs = [(outputs[0], ft())]
    elif 'Classifier' in onnx_node.op_type:
        # Good chance that's a classifier.
        # Conventionally: labels (int64) then probabilities/scores.
        outputs = [(outputs[0], Int64TensorType()),
                   (outputs[1], ft())]
    else:
        if schema_inputs is not None and schema is not None:
            # Resolve type-constraint names (e.g. 'T') by matching the
            # schema's symbolic input types against the actual typed inputs.
            dt = {}
            for got, exp in zip(typed_inputs, schema_inputs):
                if isinstance(exp[1], str):
                    dt[exp[1]] = got
            out = []
            for i in range(len(outputs)):  # pylint: disable=C0200
                o = outputs[i]
                if isinstance(o, str):
                    # Output given as a bare name: type comes from the schema.
                    exp = schema[i]
                    if exp[1] in dt:
                        out.append((o, dt[exp[1]][1].__class__()))
                    else:
                        nt = _guess_type_proto_str(exp[1], None)
                        out.append((o, nt))
                elif (isinstance(o, tuple) and
                        (isinstance(o[1], str) or o[1] is None)):
                    # Output given as (name, unresolved-type): resolve the type.
                    exp = schema[i]
                    if exp[1] in dt:
                        out.append((o[0], dt[exp[1]][1].__class__()))
                    else:
                        nt = _guess_type_proto_str(exp[1], None)
                        out.append((o[0], nt))
                else:
                    # Already a concrete (name, type) pair.
                    out.append(o)
            outputs = out
        elif len(typed_inputs) == 1 and len(outputs) == 1:
            # Default case
            # Assuming the only output is the same as the only input.
            outputs = [(outputs[0], typed_inputs[0][1])]
        else:
            # Default
            outputs = [(name, ft()) for name in outputs]

    # Final sanity check: every output must have a concrete type and
    # a string name.
    for name, typ in outputs:
        if typ in ('T', None, '', 'I'):
            raise NotImplementedError(  # pragma: no cover
                "Undefined output type: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    typ, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
        if not isinstance(name, str):
            # NOTE(review): this message reports *typ* although the failed
            # check is on *name* — looks like a copy of the message above.
            raise NotImplementedError(  # pragma: no cover
                "Undefined output type: %r (outputs=%r, typed_inputs=%r, "
                "dtype=%r, schema=%r, schema_inputs=%r, onnx_node=%r, "
                "variables=%r)." % (
                    typ, outputs, typed_inputs, dtype,
                    schema, schema_inputs, onnx_node, variables))
    return outputs

212 

213 

def proto2vars(values):
    """
    Converts proto values to Variables.

    :param values: iterable of proto objects (``ValueInfoProto`` or
        type protos such as ``tensor_type``, ``sequence_type``,
        ``map_type``)
    :return: list of ``tuple(name, type)`` where *type* is a
        :epkg:`skl2onnx` data type
    :raises RuntimeError: if a proto cannot be converted or yields
        an empty shape
    :raises NotImplementedError: if the proto element type is not handled
    """
    from skl2onnx.common.data_types import (  # delayed
        FloatTensorType, SequenceType, DictionaryType,
        Int64Type, Int64TensorType, BooleanTensorType,
        Int32TensorType, DoubleTensorType, FloatType,
        StringTensorType, Float16TensorType)
    from ..onnx2py_helper import (
        get_tensor_elem_type, get_tensor_shape)

    def ptype2vttype(it, shape):
        # Maps a TensorProto element type to a skl2onnx tensor type.
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatTensorType(shape)
        if it == TensorProto.DOUBLE:  # pylint: disable=E1101
            return DoubleTensorType(shape)
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64TensorType(shape)
        if it == TensorProto.INT32:  # pylint: disable=E1101
            return Int32TensorType(shape)
        if it == TensorProto.BOOL:  # pylint: disable=E1101
            return BooleanTensorType(shape)
        if it == TensorProto.STRING:  # pylint: disable=E1101
            return StringTensorType(shape)
        # Float16TensorType may be None with older versions of skl2onnx.
        # BUGFIX: the condition was inverted (`is None`), which made the
        # FLOAT16 branch unreachable and raised NotImplementedError instead.
        if Float16TensorType is not None:
            if it == TensorProto.FLOAT16:  # pylint: disable=E1101
                return Float16TensorType(shape)
        raise NotImplementedError(  # pragma: no cover
            f"Unrecognized proto type {it} with shape {shape}")

    def ptype2vtype(it):
        # Maps a TensorProto element type to a skl2onnx scalar type
        # (used for map keys).
        if it == TensorProto.FLOAT:  # pylint: disable=E1101
            return FloatType()
        if it == TensorProto.INT64:  # pylint: disable=E1101
            return Int64Type()
        raise NotImplementedError(  # pragma: no cover
            f"Unrecognized proto type {it}")

    res = []
    for v_ in values:
        v = v_
        name = v.name if hasattr(v, 'name') else None
        # Dispatch on whichever type field the proto actually carries;
        # nested protos (type, sequence, map values) are converted
        # recursively.
        if hasattr(v, 'type') and str(v.type) != '':
            t = v.type
            v = proto2vars([t])[0][1]
        elif hasattr(v, 'sequence_type') and str(v.sequence_type) != '':
            subtype = proto2vars([v.sequence_type.elem_type])[0][1]
            v = SequenceType(subtype)
        elif hasattr(v, 'tensor_type') and str(v.tensor_type) != '':
            v = ptype2vttype(get_tensor_elem_type(v), get_tensor_shape(v))
        elif hasattr(v, 'map_type') and str(v.map_type) != '':
            mt = v.map_type
            keyt = ptype2vtype(mt.key_type)
            valt = proto2vars([mt.value_type])[0][1]
            v = DictionaryType(keyt, valt)
        else:
            raise RuntimeError(  # pragma: no cover
                f"Unable to build a variable from {v}.")
        if v.shape is not None and 0 in v.shape:
            # Replaces 0 by None
            new_shape = tuple(None if d == 0 else d for d in v.shape)
            if new_shape in ((None, ), None):
                v = v.__class__()
            else:
                v = v.__class__(new_shape)
        if v.shape is not None and 0 in v.shape:
            raise RuntimeError(  # pragma: no cover
                f"Shape cannot be empty: '{name}': {v_}.")
        res.append((name, v))
    return res