Coverage for mlprodict/testing/test_utils/utils_backend_onnxruntime.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Inspired by sklearn-onnx, handles two backends. 

4""" 

5from pyquickhelper.pycode import is_travis_or_appveyor 

6from .utils_backend_common_compare import compare_runtime_session 

7 

8 

def _capture_output(fct, kind):
    """
    Calls *fct* and, when possible, captures what it writes
    on the standard output and error.

    The capture relies on ``cpyquickhelper.io.capture_output``. It is
    skipped when ``is_travis_or_appveyor()`` is true or when
    :epkg:`cpyquickhelper` cannot be imported.

    :param fct: function to call, it takes no argument
    :param kind: forwarded unchanged to ``capture_output``
    :return: ``(fct(), None, None)`` when the capture is skipped,
        otherwise the value returned by ``capture_output(fct, kind)``,
        which the callers in this file unpack into three values
    """
    capture = None
    if not is_travis_or_appveyor():
        try:
            from cpyquickhelper.io import capture_output as capture
        except ImportError:  # pragma: no cover
            # cpyquickhelper not available
            capture = None
    if capture is None:
        # No capture: same three-value shape so callers can still unpack.
        return fct(), None, None  # pragma: no cover
    return capture(fct, kind)

18 

19 

class InferenceSession2:
    """
    Wraps an :epkg:`onnxruntime` *InferenceSession* and routes its
    construction and every call to *run* through ``_capture_output``
    to capture the standard output and error.

    Attributes set on the instance:

    * ``sess``: the underlying *InferenceSession*
    * ``outi``, ``erri``: second and third values returned by
      ``_capture_output`` while the session was created
    * ``outr``, ``errr``: same values for the most recent call to *run*,
      they only exist once *run* has been called
    """

    @staticmethod
    def _pop_disable_optimisation(kwargs):
        """
        Removes *runtime_options* from *kwargs* and tells whether
        it asks to disable the graph optimisations.

        :param kwargs: keyword arguments received by the constructor,
            key ``'runtime_options'`` is removed in place so that it
            is not forwarded to *InferenceSession*
        :return: truth value of ``runtime_options['disable_optimisation']``,
            False if the key is missing or if *runtime_options* is
            missing or None
        """
        runtime_options = kwargs.pop('runtime_options', None) or {}
        # ``get`` and not ``pop``: the dictionary belongs to the caller
        # and must not be modified as a side effect of building a session.
        return bool(runtime_options.get('disable_optimisation', False))

    def __init__(self, *args, **kwargs):
        """
        Overwrites the constructor.

        :param args: positional arguments forwarded to *InferenceSession*
        :param kwargs: keyword arguments forwarded to *InferenceSession*
            except *runtime_options*, a dictionary whose key
            ``'disable_optimisation'`` (default False) sets the graph
            optimisation level to ``ORT_DISABLE_ALL`` through a new
            *SessionOptions*; *providers* is set to
            ``['CPUExecutionProvider']`` when not specified
        :raises RuntimeError: if the optimisations are disabled and
            *sess_options* is specified as well
        """
        from onnxruntime import (
            InferenceSession, GraphOptimizationLevel, SessionOptions)
        if self._pop_disable_optimisation(kwargs):
            if 'sess_options' in kwargs:
                raise RuntimeError(  # pragma: no cover
                    "Incompatible options, 'disable_optimisation' and "
                    "'sess_options' cannot be specified at the same time.")
            sess_options = SessionOptions()
            sess_options.graph_optimization_level = (
                GraphOptimizationLevel.ORT_DISABLE_ALL)
            kwargs['sess_options'] = sess_options
        # kwargs is already private to this call, no copy is needed.
        kwargs.setdefault('providers', ['CPUExecutionProvider'])
        self.sess, self.outi, self.erri = _capture_output(
            lambda: InferenceSession(*args, **kwargs), 'c')

    def run(self, *args, **kwargs):
        """
        Overwrites method *run*.

        Forwards all arguments to ``self.sess.run`` through
        ``_capture_output`` and stores the two extra values it returns
        in ``self.outr`` and ``self.errr``.

        :return: result of ``self.sess.run``
        """
        res, self.outr, self.errr = _capture_output(
            lambda: self.sess.run(*args, **kwargs), 'c')
        return res

    def get_inputs(self, *args, **kwargs):
        "Overwrites method *get_inputs*, delegates to the wrapped session."
        return self.sess.get_inputs(*args, **kwargs)

    def get_outputs(self, *args, **kwargs):
        "Overwrites method *get_outputs*, delegates to the wrapped session."
        return self.sess.get_outputs(*args, **kwargs)

60 

61 

def compare_runtime(test, decimal=5, options=None,
                    verbose=False, context=None, comparable_outputs=None,
                    intermediate_steps=False, classes=None,
                    disable_optimisation=False):
    """
    Compares the expected output (computed with the model before
    its conversion to ONNX) with the ONNX output produced by
    :epkg:`onnxruntime`. The function delegates everything to
    ``compare_runtime_session`` with class ``InferenceSession2``
    as the runtime session.

    :param test: dictionary with the following keys:
        - *onnx*: onnx model (filename or object)
        - *expected*: expected output (filename pkl or object)
        - *data*: input data (filename pkl or object)
    :param decimal: precision of the comparison
    :param options: comparison options
    :param verbose: in case of error, the function may print
        more information on the standard output
    :param context: specifies custom operators
    :param comparable_outputs: compare only these outputs
    :param intermediate_steps: displays intermediate steps
        in case of an error
    :param classes: classes names (if option 'nocl' is used)
    :param disable_optimisation: disable optimisation onnxruntime
        could do
    :return: the value returned by ``compare_runtime_session``,
        described by the original documentation as a tuple
        *(output, lambda function to run the predictions)*

    The comparison itself, and any error raised when it fails,
    happens inside ``compare_runtime_session``.
    """
    forwarded = dict(
        decimal=decimal, options=options,
        verbose=verbose, context=context,
        comparable_outputs=comparable_outputs,
        intermediate_steps=intermediate_steps,
        classes=classes, disable_optimisation=disable_optimisation)
    return compare_runtime_session(InferenceSession2, test, **forwarded)
96 classes=classes, disable_optimisation=disable_optimisation)