Coverage for mlprodict/onnxrt/ops_onnx/_op.py: 95%
22 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Additional methods for the extension of
4:epkg:`ReferenceEvaluator`.
5"""
6from io import BytesIO
7import pickle
8from typing import Any, Dict
9from onnx import NodeProto
10from onnx.reference.op_run import OpRun
class OpRunExtended(OpRun):
    """
    Base class to cache C++ implementation based on inputs.

    A subclass computes a key from the parameters selecting an
    implementation (:meth:`get_cache_key`), checks it with
    :meth:`has_cache_key`, retrieves it with :meth:`get_cache_impl`
    or stores a new one with :meth:`cache_impl`. The cache is a
    dictionary owned by each instance.
    """

    def __init__(self, onnx_node: NodeProto, run_params: Dict[str, Any]):
        OpRun.__init__(self, onnx_node, run_params)
        # Maps a key (usually built by get_cache_key) to an implementation.
        self._cache: Dict[Any, Any] = {}

    def get_cache_key(self, **kwargs) -> bytes:
        """
        Returns a key mapped to the corresponding C++ implementation.

        The key is the pickled list of keyword arguments sorted by name,
        so two calls passing the same arguments in a different order
        produce the same key.

        :param kwargs: parameters identifying an implementation,
            every value must be picklable
        :return: bytes
        """
        b = BytesIO()
        # **kwargs keeps the caller's keyword order (PEP 468); sorting by
        # name removes that dependency. Names are unique, so the values
        # themselves are never compared during the sort.
        pickle.dump(sorted(kwargs.items()), b)
        return b.getvalue()

    def has_cache_key(self, key) -> bool:
        """
        Tells if a key belongs to the cache.

        :param key: key, usually returned by :meth:`get_cache_key`
        :return: True if an implementation is cached for *key*
        """
        return key in self._cache

    def get_cache_impl(self, key):
        """
        Returns the cached implementation for key *key*.

        :param key: key, usually returned by :meth:`get_cache_key`
        :return: the object previously stored with :meth:`cache_impl`
        :raises KeyError: if *key* is not cached
        """
        return self._cache[key]

    def cache_impl(self, key, rt):
        """
        Caches an implementation.

        :param key: key, usually returned by :meth:`get_cache_key`
        :param rt: implementation to store
        :return: *rt*
        :raises RuntimeError: if *key* is already cached
        """
        if key in self._cache:
            raise RuntimeError(f"Key {key!r} is already cached.")
        self._cache[key] = rt
        return rt