Coverage for mlprodict/onnxrt/ops_cpu/__init__.py: 81%

115 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1# -*- encoding: utf-8 -*- 

2""" 

3@file 

4@brief Shortcut to *ops_cpu*. 

5""" 

6import inspect 

7import textwrap 

8from onnx import FunctionProto 

9from onnx.reference.ops import load_op as onnx_load_op 

10from onnx.defs import get_schema 

11from ..excs import MissingOperatorError 

12from ._op import OpRunCustom, OpFunction 

13from ._op_list import __dict__ as d_op_list 

14 

15 

# Registry of additional runtime operators, keyed by operator name
# (either the plain name or ``<name>_<opset>``). It is filled by
# register_operator and looked up by load_op before the operators
# listed in ``_op_list``.
_additional_ops = {}

17 

18 

def register_operator(cls, name=None, overwrite=True):
    """
    Registers a new runtime operator.

    @param      cls         class
    @param      name        by default ``cls.__name__``,
                            or *name* if defined
    @param      overwrite   overwrite or raise an exception

    If *name* is already registered, the existing class is replaced
    when *overwrite* is True, otherwise a ``RuntimeError`` is raised
    and the registry is left unchanged.
    """
    if name is None:
        name = cls.__name__
    if name in _additional_ops and not overwrite:
        raise RuntimeError(  # pragma: no cover
            "Unable to overwrite existing operator '{}': {} "
            "by {}".format(name, _additional_ops[name], cls))
    # Always store the class: with overwrite=True an already registered
    # name must be replaced, not silently kept.
    _additional_ops[name] = cls

36 

37 

def load_op(onnx_node, desc=None, options=None, runtime=None):
    """
    Gets the operator related to the *onnx* node.

    The implementation is searched in this order: ``_additional_ops``
    (versioned name ``<op_type>_<opset>`` first, then the plain name),
    ``_op_list`` (same order) and finally the implementation returned
    by ``onnx.reference.ops.load_op``. In that last case the class is
    wrapped in a subclass built on the fly so that it can be
    instantiated from this function.

    :param onnx_node: :epkg:`onnx` node
    :param desc: internal representation, it cannot be None
    :param options: runtime options, a dictionary or None
        (same as an empty dictionary), keys ``'target_opset'``,
        ``'domain'`` and ``'input_types'`` are read by this function,
        the whole dictionary is forwarded to the operator constructor
    :param runtime: runtime
    :return: instance of the runtime class selected for the node
    :raises ValueError: if *desc* is None
    :raises TypeError: if ``target_opset`` is not an integer, if no class
        can be built for the operator or if its constructor fails
    :raises MissingOperatorError: if no implementation is found
    :raises RuntimeError: if a context dependent function is needed but
        input types are missing, or if the only implementation found
        requires a higher opset than the requested one
    """
    from ... import __max_supported_opset__
    if desc is None:
        raise ValueError("desc should not be None.")  # pragma no cover
    if options is None:
        # Normalised once here: the code below calls options.get,
        # ``in options`` and options.copy, which all fail on None.
        options = {}
    name = onnx_node.op_type
    opset = options.get('target_opset', None)
    current_opset = __max_supported_opset__
    chosen_opset = opset or current_opset
    if opset is not None:
        if not isinstance(opset, int):
            raise TypeError(  # pragma no cover
                f"opset must be an integer not {type(opset)}")
        name_opset = name + "_" + str(opset)
        # Looks for the most recent versioned implementation
        # not newer than the requested opset.
        for op in range(opset, 0, -1):
            nop = name + "_" + str(op)
            if nop in d_op_list:
                name_opset = nop
                chosen_opset = op
                break
    else:
        name_opset = name

    onnx_op = False
    if name_opset in _additional_ops:
        cl = _additional_ops[name_opset]
    elif name in _additional_ops:
        cl = _additional_ops[name]
    elif name_opset in d_op_list:
        cl = d_op_list[name_opset]
    elif name in d_op_list:
        cl = d_op_list[name]
    else:
        # Last resort: implementation from onnx.reference.
        try:
            cl = onnx_load_op(options.get('domain', ''),
                              name, opset)
        except ValueError as e:
            raise MissingOperatorError(
                f"Unable to load class for operator name={name}, "
                f"opset={opset}, options={options}, "
                f"_additional_ops={_additional_ops}.") from e
        onnx_op = True
        if cl is None:
            raise MissingOperatorError(  # pragma no cover
                "Operator '{}' from domain '{}' has no runtime yet. "
                "Available list:\n"
                "{} - {}".format(
                    name, onnx_node.domain,
                    "\n".join(sorted(_additional_ops)),
                    "\n".join(textwrap.wrap(
                        " ".join(
                            _ for _ in sorted(d_op_list)
                            if "_" not in _ and _ not in {
                                'cl', 'clo', 'name'})))))

        class _Wrapper:
            # Holds the methods injected into the subclass created
            # below around the class returned by onnx_load_op.

            def _log(self, *args, **kwargs):
                pass

            @property
            def base_class(self):
                "Returns the parent class."
                return self.__class__.__bases__[0]

            def _onnx_run(self, *args, **kwargs):
                """
                Calls method *run* of the parent class without
                arguments *attributes*, *verbose*, *fLOG*.
                """
                cl = self.base_class
                new_kws = {}
                for k, v in kwargs.items():
                    if k not in {'attributes', 'verbose', 'fLOG'}:
                        new_kws[k] = v
                attributes = kwargs.get('attributes', None)
                if attributes is not None and len(attributes) > 0:
                    raise NotImplementedError(
                        f"attributes is not empty but not implemented yet, "
                        f"attribures={attributes}.")
                return cl.run(self, *args, **new_kws)  # pylint: disable=E1101

            def _onnx__run(self, *args, attributes=None, **kwargs):
                """
                Wraps ONNX call to OpRun._run.
                """
                cl = self.base_class
                if attributes is not None and len(attributes) > 0:
                    raise NotImplementedError(  # pragma: no cover
                        f"Linked attributes are not yet implemented for class "
                        f"{self.__class__!r}.")
                return cl._run(self, *args, **kwargs)  # pylint: disable=E1101

            def _onnx_need_context(self):
                cl = self.base_class
                return cl.need_context(self)  # pylint: disable=E1101

            def __init__(self, onnx_node, desc=None, **options):
                # *desc* and *options* are accepted but not used,
                # the parent class receives its own parameters.
                cl = self.__class__.__bases__[0]
                run_params = {'log': _Wrapper._log,
                              'opsets': {'': opset},
                              'new_ops': None}
                cl.__init__(self, onnx_node, run_params)

        # wrapping the original class
        if inspect.isfunction(cl):
            domain = options.get('domain', '')
            if domain != '':
                raise TypeError(
                    f"Unable to create a class for operator {name!r} and "
                    f"opset {opset} based on {cl} of type={type(cl)}.")
            schema = get_schema(name, opset, domain)
            if schema.has_function:
                from mlprodict.onnxrt import OnnxInference
                body = schema.function_body
                sess = OnnxInference(body)
                # *sess* is bound as a default value, only the first
                # positional argument (the node) is forwarded.
                new_cls = lambda *args, sess=sess: OpFunction(
                    args[0], impl=sess)
            elif schema.has_context_dependent_function:
                input_types = options.get('input_types', '')
                if onnx_node is None or input_types is None:
                    raise RuntimeError(
                        f"No registered implementation for operator {onnx_node.op_type!r} "
                        f"and domain {domain!r}, the operator has a context dependent function. "
                        f"but argument node or input_types is not defined.")
                from mlprodict.onnxrt import OnnxInference
                body = schema.get_context_dependent_function(
                    onnx_node.SerializeToString(),
                    [it.SerializeToString() for it in input_types])
                proto = FunctionProto()
                proto.ParseFromString(body)
                sess = OnnxInference(proto)
                new_cls = lambda *args, sess=sess: OpFunction(
                    args[0], impl=sess)
            else:
                raise TypeError(
                    f"Unable to create a class for operator {name!r} and "
                    f"opset {opset} based on {cl} of type={type(cl)}.")
        else:
            try:
                new_cls = type(f"{name}_{opset}", (cl, ),
                               {'__init__': _Wrapper.__init__,
                                '_run': _Wrapper._onnx__run,
                                'base_class': _Wrapper.base_class,
                                'run': _Wrapper._onnx_run,
                                'need_context': _Wrapper._onnx_need_context})
            except TypeError as e:
                raise TypeError(
                    f"Unable to create a class for operator {name!r} and "
                    f"opset {opset} based on {cl} of type={type(cl)}.") from e
        cl = new_cls

    if hasattr(cl, 'version_higher_than'):
        opv = min(current_opset, chosen_opset)
        if cl.version_higher_than > opv:
            # The chosen implementation does not support
            # the opset version, we need to downgrade it.
            if ('target_opset' in options and
                    options['target_opset'] is not None):  # pragma: no cover
                raise RuntimeError(
                    "Supported version {} > {} (opset={}) required version, "
                    "unable to find an implementation version {} found "
                    "'{}'\n--ONNX--\n{}\n--AVAILABLE--\n{}".format(
                        cl.version_higher_than, opv, opset,
                        options['target_opset'], cl.__name__, onnx_node,
                        "\n".join(
                            _ for _ in sorted(d_op_list)
                            if "_" not in _ and _ not in {'cl', 'clo', 'name'})))
            # Second attempt with an explicit opset, on a copy so the
            # caller's dictionary is left untouched; *runtime* is
            # forwarded so the retried operator gets the same runtime.
            options = options.copy()
            options['target_opset'] = current_opset
            return load_op(onnx_node, desc=desc, options=options,
                           runtime=runtime)

    if onnx_op:
        try:
            return cl(onnx_node, {'log': None})
        except TypeError as e:
            raise TypeError(  # pragma: no cover
                f"Unexpected issue with class {cl}.") from e
    try:
        return cl(onnx_node, desc=desc, runtime=runtime, **options)
    except TypeError as e:
        raise TypeError(  # pragma: no cover
            f"Unexpected issue with class {cl}.") from e

226 except TypeError as e: 

227 raise TypeError( # pragma: no cover 

228 f"Unexpected issue with class {cl}.") from e