Coverage for mlprodict/onnxrt/ops.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Loads runtime operator. 

4""" 

5 

6 

def load_op(onnx_node, desc=None, options=None, variables=None, dtype=None, runtime=None):
    """
    Builds the runtime class implementing one :epkg:`onnx` operator.

    The backend is chosen from the ``'provider'`` entry of *options*
    (``'python'`` when *options* is None or has no such entry):

    * ``'python'``: ``.ops_cpu.load_op``
    * ``'empty'``: ``.ops_empty.load_op``
    * ``'onnxruntime2'`` or ``'onnxruntime2-cuda'``: ``.ops_onnxruntime.load_op``

    :param onnx_node: :epkg:`onnx` node
    :param desc: internal representation, must not be None
    :param options: runtime options; when they contain a ``'provider'``
        key, that key is removed from a copy (the caller's object is left
        untouched) before the options are forwarded to the backend
    :param variables: registered variables created by previous operators,
        only forwarded to the onnxruntime backends
    :param dtype: float computational type, only forwarded to the
        onnxruntime backends
    :param runtime: runtime, only forwarded to the onnxruntime backends
    :return: runtime class returned by the selected backend
    :raises ValueError: if *desc* is None or the provider is not supported
    """
    if desc is None:
        raise ValueError(  # pragma: no cover
            "desc should not be None.")

    if options is None:
        provider = 'python'  # pragma: no cover
    elif 'provider' in options:
        # Strip the selector from a private copy: the backends never see it
        # and the caller keeps its own options intact.
        options = options.copy()
        provider = options.pop('provider')
    else:
        provider = 'python'

    # Backends are imported lazily so only the selected one is loaded.
    if provider == 'python':
        from .ops_cpu import load_op as backend_load_op
    elif provider == 'empty':
        from .ops_empty import load_op as backend_load_op
    elif provider in ('onnxruntime2', 'onnxruntime2-cuda'):
        from .ops_onnxruntime import load_op as backend_load_op
        # Only this backend receives variables, dtype and runtime.
        return backend_load_op(  # pylint: disable=E1123
            onnx_node, desc=desc, options=options,
            variables=variables, dtype=dtype, runtime=runtime)
    else:
        raise ValueError(f"Unable to handle provider '{provider}'.")
    return backend_load_op(onnx_node, desc=desc, options=options)