Coverage for mlprodict/onnxrt/validate/side_by_side.py: 100%
101 statements
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
« prev ^ index » next coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1"""
2@file
3@brief Helpers to compare executions.
4"""
5import copy
6import numpy
7from .validate_difference import measure_relative_difference
10def _side_by_side_by_values_inputs(sess, inputs, i):
11 if isinstance(sess, tuple) and inputs is None:
12 new_sess, new_inputs = sess
13 elif isinstance(inputs, list):
14 new_sess = sess
15 new_inputs = inputs[i]
16 else:
17 new_sess = sess
18 new_inputs = copy.deepcopy(inputs)
19 return new_sess, new_inputs
def side_by_side_by_values(sessions, *args, inputs=None,
                           return_results=False, **kwargs):
    """
    Compares the execution of several sessions.
    It calls method :meth:`OnnxInference.run
    <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    with value ``intermediate=True`` and compares the results.

    :param sessions: list of class @see cl OnnxInference
    :param inputs: inputs
    :param args: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    :param return_results: if True, returns the results as well.
    :param kwargs: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    :return: list of dictionaries, or ``(rows, results)`` if
        *return_results* is True, *results* being one list of
        ``(name, value)`` per session
    :raises ValueError: if *kwargs* sets ``intermediate`` to False,
        if a session runs with the python runtime and ``inplace``
        enabled, or if *sessions* is empty

    The first session is considered as the baseline.
    See notebook :ref:`onnxsbsrst` for an example.
    If *inputs* is None, the function assumes
    *sessions* is a list of *tuple(sessions, inputs)*
    because sometimes inputs must be different.

    The first row (``metric='nb_results'``, ``step=-1``) stores the
    number of intermediate results of every session in ``v[i]``
    and ``cmp='OK'`` if they are all equal, ``'!='`` otherwise.
    Then every merged intermediate result gets two rows, one for
    ``metric='rel-diff'`` and one for ``metric='abs-diff'``, with keys
    ``step``, ``name``, ``value[j]``, ``shape[j]`` (when the value has
    a shape), ``order[j]`` (when the session returns an execution
    order), ``v[j]`` (difference with session 0, 0 for session 0)
    and ``cmp`` (a label bucketing the largest difference).

    .. versionchanged:: 0.7
        Parameter *return_results* was added. The function
        returns the execution order when available.
    """
    if not kwargs.get('intermediate', True):
        raise ValueError(  # pragma: no cover
            "kwargs must not set intermediate to False")
    kwargs['intermediate'] = True
    verbose = kwargs.get('verbose', 0)
    fLOG = kwargs.get('fLOG', None)

    # run every session, keep its intermediate results and execution order
    results = []
    orders = []
    for i, sess in enumerate(sessions):
        if (hasattr(sess, 'runtime') and hasattr(sess, 'inplace') and
                sess.runtime in (None, 'python') and sess.inplace):
            raise ValueError(
                "You must disable the inplace mechanism in order to get "
                "true results. See OnnxInference constructor.")
        new_sess, new_inputs = _side_by_side_by_values_inputs(sess, inputs, i)
        if verbose > 0 and fLOG:
            fLOG(  # pragma: no cover
                f'[side_by_side_by_values] run session {i + 1}/{len(sessions)}')
        res = new_sess.run(new_inputs, *args, **kwargs)
        order = new_sess.get_execution_order()
        results.append(list(res.items()))
        orders.append(order)

    if not results:
        # min/max and merge_results below cannot handle an empty list
        raise ValueError("sessions must contain at least one session.")

    # same number of results?
    rows = []
    row = {"metric": "nb_results", 'step': -1}
    for i, res in enumerate(results):
        row["v[%d]" % i] = len(res)
    lengths = [len(res) for res in results]
    row['cmp'] = 'OK' if min(lengths) == max(lengths) else '!='
    rows.append(row)

    merged = merge_results(results)

    # analysis: two rows (one per metric) for every intermediate result
    for i, (name, res_row) in enumerate(merged):
        for metric in ('rel-diff', 'abs-diff'):
            row = {'step': i, 'name': name, 'metric': metric}

            vals = []
            for j, r in enumerate(res_row):
                order = orders[j]
                if order is not None:
                    row['order[%d]' % j] = order.get(
                        ('res', name), (numpy.nan, ))[0]
                row['value[%d]' % j] = r
                if hasattr(r, 'shape'):
                    row['shape[%d]' % j] = r.shape

                if j == 0:
                    # the first session is the baseline
                    row['v[%d]' % j] = 0
                elif res_row[0] is not None and r is not None:
                    v = measure_relative_difference(
                        res_row[0], r, abs_diff=metric == 'abs-diff')
                    row['v[%d]' % j] = v
                    vals.append(v)

            if vals:
                diff = max(vals)
                if diff < 1e-5:
                    row['cmp'] = 'OK'
                elif diff < 0.0001:  # pragma: no cover
                    row['cmp'] = 'e<0.0001'
                elif diff < 0.001:  # pragma: no cover
                    row['cmp'] = 'e<0.001'
                elif diff < 0.01:  # pragma: no cover
                    row['cmp'] = 'e<0.01'
                elif diff < 0.1:  # pragma: no cover
                    row['cmp'] = 'e<0.1'
                else:  # pragma: no cover
                    row['cmp'] = f"ERROR->={diff:1.1f}"

            rows.append(row)
    if return_results:
        return rows, results
    return rows
def merge_results(results):
    """
    Merges results by name. The first ones
    are used to keep the order.

    Names missing from the first result are inserted just before
    the name which follows them in the result they come from,
    or appended at the end if nothing follows them.

    :param results: results of intermediate variables, one list
        of ``(name, value)`` per session
    :return: list of tuple ``(name, values)``, *values* holding one
        entry per element of *results*, None when that result does
        not contain *name*; an empty list if *results* is empty
    """
    if not results:
        return []

    # one row per name of the first result; every row collects,
    # for each result, (value, position of the name in that result)
    rows = [(name, []) for name, _ in results[0]]
    positions = {row[0]: i for i, row in enumerate(rows)}
    todos = []
    for result in results:
        for row in rows:
            row[1].append(None)
        todo = []
        for i, (name, value) in enumerate(result):
            pos = positions.get(name, None)
            if pos is None:
                todo.append((i, name, value))
            else:
                rows[pos][1][-1] = (value, i)
        todos.append(todo)

    # left over: names unknown to the first result.
    # Each one is inserted before the row holding its successor
    # (position pos + 1 in the same result). Walking the list
    # backwards guarantees that successor already has a row even
    # when several unknown names are consecutive.
    for i, todo in enumerate(todos):
        for pos, name, value in reversed(todo):
            successor = pos + 1
            found = -1
            for ik, row in enumerate(rows):
                cell = row[1][i]
                if cell is not None and cell[1] == successor:
                    found = ik
                    break
            vv = [None] * len(results)
            # store the position within result i (not within rows) so
            # that later successor lookups compare like with like
            vv[i] = (value, pos)
            if found == -1:
                rows.append((name, vv))
            else:
                rows.insert(found, (name, vv))

    # final: drop the positions, keep only the values
    return [(name, [None if cell is None else cell[0] for cell in cells])
            for name, cells in rows]