Coverage for mlprodict/onnxrt/validate/side_by_side.py: 100%

101 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Helpers to compare executions. 

4""" 

5import copy 

6import numpy 

7from .validate_difference import measure_relative_difference 

8 

9 

10def _side_by_side_by_values_inputs(sess, inputs, i): 

11 if isinstance(sess, tuple) and inputs is None: 

12 new_sess, new_inputs = sess 

13 elif isinstance(inputs, list): 

14 new_sess = sess 

15 new_inputs = inputs[i] 

16 else: 

17 new_sess = sess 

18 new_inputs = copy.deepcopy(inputs) 

19 return new_sess, new_inputs 

20 

21 

def side_by_side_by_values(sessions, *args, inputs=None,
                           return_results=False, **kwargs):
    """
    Compares the execution of several sessions.
    It calls method :meth:`OnnxInference.run
    <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    with value ``intermediate=True`` and compares the results.

    :param sessions: list of class @see cl OnnxInference
    :param inputs: inputs
    :param args: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    :param return_results: if True, returns the results as well.
    :param kwargs: additional parameters for
        :meth:`OnnxInference.run
        <mlprodict.onnxrt.onnx_inference.OnnxInference.run>`
    :return: list of dictionaries, or ``(rows, results)`` when
        *return_results* is True (*results* holds one list of
        ``(name, value)`` pairs per session)
    :raises ValueError: if *kwargs* sets ``intermediate`` to False,
        if a python-runtime session has the inplace mechanism enabled,
        or if *sessions* is empty

    The first session is considered as the baseline.
    See notebook :ref:`onnxsbsrst` for an example.
    If *inputs* is None, the function assumes
    *sessions* is a list of *tuple(sessions, inputs)*
    because sometimes inputs must be different.

    The first returned row (``metric='nb_results'``, ``step=-1``) stores
    the number of results of every session in ``v[i]``. Then every merged
    intermediate result produces two rows, one for ``rel-diff`` and one
    for ``abs-diff``, holding ``name``, ``value[j]``, ``shape[j]`` (when
    the value has a shape), ``order[j]`` (when the session returns an
    execution order), ``v[j]`` (difference with the baseline) and ``cmp``
    (a bucket of the largest difference: ``'OK'`` below 1e-5, then
    ``'e<0.0001'`` ... ``'e<0.1'``, otherwise ``'ERROR->=...'``).

    .. versionchanged:: 0.7
        Parameter *return_results* was added. The function
        returns the execution order when available.
    """
    if not kwargs.get('intermediate', True):
        raise ValueError(  # pragma: no cover
            "kwargs must not set intermediate to False")
    kwargs['intermediate'] = True
    verbose = kwargs.get('verbose', 0)
    fLOG = kwargs.get('fLOG', None)

    # run every session and keep all intermediate results
    results = []
    orders = []
    for i, sess in enumerate(sessions):
        new_sess, new_inputs = _side_by_side_by_values_inputs(sess, inputs, i)
        # The check must look at the unpacked session: when *sessions*
        # holds (session, inputs) tuples, the tuple itself has no
        # 'runtime' or 'inplace' attribute and the check would be skipped.
        if (hasattr(new_sess, 'runtime') and hasattr(new_sess, 'inplace') and
                new_sess.runtime in (None, 'python') and new_sess.inplace):
            raise ValueError(
                "You must disable the inplace mechanism in order to get "
                "true results. See OnnxInference constructor.")
        if verbose > 0 and fLOG:
            fLOG(  # pragma: no cover
                f'[side_by_side_by_values] run session {i + 1}/{len(sessions)}')
        res = new_sess.run(new_inputs, *args, **kwargs)
        orders.append(new_sess.get_execution_order())
        results.append(list(res.items()))

    if not results:
        raise ValueError(
            "sessions must contain at least one session to compare.")

    # same number of results?
    rows = []
    row = {"metric": "nb_results", 'step': -1}
    for i, res in enumerate(results):
        row["v[%d]" % i] = len(res)
    sizes = [len(res) for res in results]
    row['cmp'] = 'OK' if min(sizes) == max(sizes) else '!='
    rows.append(row)

    merged = merge_results(results)

    # analysis: two rows (relative and absolute difference) per name
    for step, (name, res_row) in enumerate(merged):
        for metric in ('rel-diff', 'abs-diff'):
            row = {'step': step, 'name': name, 'metric': metric}

            vals = []
            for j, r in enumerate(res_row):
                order = orders[j]
                if order is not None:
                    row['order[%d]' % j] = order.get(
                        ('res', name), (numpy.nan, ))[0]
                row['value[%d]' % j] = r
                if hasattr(r, 'shape'):
                    row['shape[%d]' % j] = r.shape

                if j == 0:
                    # the first session is the baseline
                    row['v[%d]' % j] = 0
                elif res_row[0] is not None and r is not None:
                    v = measure_relative_difference(
                        res_row[0], r, abs_diff=metric == 'abs-diff')
                    row['v[%d]' % j] = v
                    vals.append(v)

            if vals:
                diff = max(vals)
                if diff < 1e-5:
                    row['cmp'] = 'OK'
                elif diff < 0.0001:  # pragma: no cover
                    row['cmp'] = 'e<0.0001'
                elif diff < 0.001:  # pragma: no cover
                    row['cmp'] = 'e<0.001'
                elif diff < 0.01:  # pragma: no cover
                    row['cmp'] = 'e<0.01'
                elif diff < 0.1:  # pragma: no cover
                    row['cmp'] = 'e<0.1'
                else:  # pragma: no cover
                    row['cmp'] = f"ERROR->={diff:1.1f}"

            rows.append(row)
    if return_results:
        return rows, results
    return rows

133 

134 

def merge_results(results):
    """
    Merges results by name. The first ones
    are used to keep the order.

    :param results: results of intermediate variables, one list of
        ``(name, value)`` pairs per session
    :return: list of tuple ``(name, values)`` where *values* holds one
        entry per element of *results*, None when that element does not
        contain *name*; an empty list if *results* is empty

    Names which do not appear in the first result are inserted just
    before the name that follows them in their own result. If nothing
    follows them, they are appended at the end.
    """
    if not results:
        return []

    # matrix of names: every row is (name, cells), one cell per result,
    # a cell being None or (value, position of the name in that result)
    rows = [(k, []) for k, _ in results[0]]
    positions = {k[0]: i for i, k in enumerate(rows)}
    todos = []
    for result in results:
        todo = []
        for row in rows:
            row[1].append(None)
        for pos, (k, v) in enumerate(result):
            idx = positions.get(k, None)
            if idx is None:
                todo.append((pos, k, v))
            else:
                rows[idx][1][-1] = (v, pos)
        todos.append(todo)

    # left over: pending names are processed from last to first so that
    # the successor of a pending name (pos + 1) is already placed, even
    # when that successor was itself pending; cells always store the
    # position within the result so lookups never mix index spaces
    for i, todo in enumerate(todos):
        for pos, name, val in reversed(todo):
            found = next(
                (ik for ik, row in enumerate(rows)
                 if row[1][i] is not None and row[1][i][1] == pos + 1),
                -1)
            cells = [None] * len(results)
            cells[i] = (val, pos)
            if found == -1:
                rows.append((name, cells))
            else:
                rows.insert(found, (name, cells))

    # final: drop the positions, keep only the values
    return [(name, [None if cell is None else cell[0] for cell in cells])
            for name, cells in rows]