Coverage for mlprodict/testing/test_utils/utils_backend_python.py: 100%
36 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Inspired from sklearn-onnx, handles two backends.
4"""
5from ...onnxrt import OnnxInference
6from .utils_backend_common_compare import compare_runtime_session
class MockVariableName:
    "A string."

    def __init__(self, name):
        # Only the variable name is known; shape and type are undefined here.
        self.name = name

    @property
    def shape(self):
        "returns shape"
        # A bare name carries no shape information.
        raise NotImplementedError(  # pragma: no cover
            f"No shape for '{self.name}'.")

    @property
    def type(self):
        "returns type"
        # A bare name carries no type information.
        raise NotImplementedError(  # pragma: no cover
            f"No type for '{self.name}'.")
class MockVariableNameShape(MockVariableName):
    "A string and a shape."

    def __init__(self, name, sh):
        # Delegate name storage to the base class, keep the shape locally.
        super().__init__(name)
        self._shape = sh

    @property
    def shape(self):
        "returns shape"
        return self._shape
class MockVariableNameShapeType(MockVariableNameShape):
    "A string and a shape and a type."

    def __init__(self, name, sh, stype):
        # Name and shape go through the parent; the type is stored locally.
        super().__init__(name, sh)
        self._stype = stype

    @property
    def type(self):
        "returns type"
        return self._stype
class OnnxInference2(OnnxInference):
    "onnxruntime API"

    def run(self, name, inputs, *args, **kwargs):  # pylint: disable=W0221
        "onnxruntime API"
        # *args is accepted only so the signature mimics onnxruntime's;
        # it is intentionally ignored.
        outputs = OnnxInference.run(self, inputs, **kwargs)
        if name is None:
            # No specific output requested: return all of them,
            # in the order declared by the model.
            return [outputs[out] for out in self.output_names]
        if name in outputs:  # pragma: no cover
            return outputs[name]
        raise RuntimeError(  # pragma: no cover
            f"Unable to find output '{name}'.")

    def get_inputs(self):
        "onnxruntime API"
        return [MockVariableNameShapeType(*row)
                for row in self.input_names_shapes_types]

    def get_outputs(self):
        "onnxruntime API"
        return [MockVariableNameShape(*row)
                for row in self.output_names_shapes]

    def run_in_scan(self, inputs, attributes=None, verbose=0, fLOG=None):
        "Instance to run in operator scan."
        # Bypass this class's run() override and use the base signature.
        return OnnxInference.run(self, inputs, attributes=attributes,
                                 verbose=verbose, fLOG=fLOG)
def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    The function compares the expected output (computed with
    the model before being converted to ONNX) and the ONNX output
    produced with module :epkg:`onnxruntime` or :epkg:`mlprodict`.

    :param test: dictionary with the following keys:
      - *onnx*: onnx model (filename or object)
      - *expected*: expected output (filename pkl or object)
      - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param context: specifies custom operators
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation the runtime may do
    :return: tuple (output, lambda function to run the predictions)

    The function raises an exception if the comparison fails.
    """
    # Delegate to the shared comparison logic, using the python runtime
    # (OnnxInference2 wraps OnnxInference behind the onnxruntime API).
    return compare_runtime_session(
        OnnxInference2, test, decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)