Coverage for mlprodict/onnxrt/ops_onnx/_op.py: 95%

22 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Additional methods for the extension of 

4:epkg:`ReferenceEvaluator`. 

5""" 

6from io import BytesIO 

7import pickle 

8from typing import Any, Dict 

9from onnx import NodeProto 

10from onnx.reference.op_run import OpRun 

11 

12 

13class OpRunExtended(OpRun): 

14 """ 

15 Base class to cache C++ implementation based on inputs. 

16 """ 

17 

18 def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]): 

19 OpRun.__init__(self, onnx_node, run_params) 

20 self._cache = {} 

21 

22 def get_cache_key(self, **kwargs): 

23 """ 

24 Returns a key mapped to the corresponding C++ implementation. 

25 """ 

26 b = BytesIO() 

27 pickle.dump(kwargs, b) 

28 return b.getvalue() 

29 

30 def has_cache_key(self, key): 

31 """ 

32 Tells if a key belongs to the cache. 

33 """ 

34 return key in self._cache 

35 

36 def get_cache_impl(self, key): 

37 """ 

38 Returns the cached implementation for key *key*. 

39 """ 

40 return self._cache[key] 

41 

42 def cache_impl(self, key, rt): 

43 """ 

44 Caches an implementation. 

45 """ 

46 if key in self._cache: 

47 raise RuntimeError(f"Key {key!r} is already cached.") 

48 self._cache[key] = rt 

49 return rt