Coverage for mlprodict/tools/ort_wrapper.py: 88%

68 statements  

« prev     ^ index     » next       coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

1""" 

2@file 

3@brief Wrapper around :epkg:`onnxruntime`. 

4 

5.. versionadded:: 0.6 

6""" 

7import os 

8from onnx import numpy_helper 

9 

10 

class InferenceSession:  # pylint: disable=E0102
    """
    Wrappers around InferenceSession from :epkg:`onnxruntime`.

    :param onnx_bytes: onnx bytes
    :param sess_options: session options, if None, new ones are created
        with *log_severity_level*
    :param log_severity_level: logging level given to the session options
        (when they are created here) and to the default run options
    :param runtime: runtime to use, `onnxruntime`, `onnxruntime-cuda`, ...,
        only used to pick the providers when *providers* is None
    :param providers: providers
    :raises ValueError: if *providers* is None and *runtime* is unknown
    """

    def __init__(self, onnx_bytes, sess_options=None, log_severity_level=4,
                 runtime='onnxruntime', providers=None):
        from onnxruntime import (  # pylint: disable=W0611
            SessionOptions, RunOptions,
            InferenceSession as OrtInferenceSession,
            set_default_logger_severity)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            OrtValue as C_OrtValue)

        # Kept so that run can detect inputs already stored as C OrtValue.
        self.C_OrtValue = C_OrtValue

        self.log_severity_level = log_severity_level
        if providers is None:
            providers = self._providers_from_runtime(runtime)
        self.providers = providers
        # NOTE(review): the process-wide logger severity is hard-coded to 3
        # and does not follow log_severity_level -- confirm this is intended.
        set_default_logger_severity(3)
        if sess_options is None:
            sess_options = SessionOptions()
            sess_options.log_severity_level = log_severity_level
        # Session options actually used to build the session.
        self.so = sess_options
        # The wrapped onnxruntime session.
        self.sess = OrtInferenceSession(
            onnx_bytes, sess_options=self.so, providers=self.providers)
        # Default run options used when run receives no run_options.
        self.ro = RunOptions()
        self.ro.log_severity_level = log_severity_level
        self.ro.log_verbosity_level = log_severity_level
        # Names of every model output, in model order.
        self.output_names = [o.name for o in self.get_outputs()]

    @staticmethod
    def _providers_from_runtime(runtime):
        "Returns the default execution providers for *runtime*."
        if runtime in (None, 'onnxruntime', 'onnxruntime1', 'onnxruntime2'):
            return ['CPUExecutionProvider']
        if runtime in ('onnxruntime-cuda', 'onnxruntime1-cuda',
                       'onnxruntime2-cuda'):
            return ['CUDAExecutionProvider', 'CPUExecutionProvider']
        raise ValueError(
            f"Unexpected value {runtime!r} for onnxruntime.")

    def run(self, output_names, input_feed, run_options=None):
        """
        Executes the ONNX graph.

        :param output_names: None (or empty) for all outputs, a list of
            output names, or a single output name
        :param input_feed: dictionary of inputs
        :param run_options: None or RunOptions, None uses the run options
            created in the constructor
        :return: list of results, one per requested output
        """
        if isinstance(output_names, str):
            # onnxruntime expects a sequence of names, not a single string.
            output_names = [output_names]
        run_options = run_options or self.ro
        if any(isinstance(v, self.C_OrtValue) for v in input_feed.values()):
            # The direct binding call receives an explicit list of names:
            # fall back to every output only when none was requested.
            names = output_names or self.output_names
            return self.sess._sess.run_with_ort_values(  # pylint: disable=W0212
                input_feed, names, run_options)
        return self.sess.run(output_names, input_feed, run_options)

    def get_inputs(self):
        "Returns input types."
        return self.sess.get_inputs()

    def get_outputs(self):
        "Returns output types."
        return self.sess.get_outputs()

    def end_profiling(self):
        "Ends profiling."
        return self.sess.end_profiling()

87 

88 

def _save_tensors(folder, prefix, arrays, names):
    """
    Serializes every array of *arrays* as an ONNX TensorProto named after
    the matching entry of *names* into file ``<folder>/<prefix>_<i>.pb``.
    """
    for i, (name, array) in enumerate(zip(names, arrays)):
        proto = numpy_helper.from_array(array, name)
        with open(os.path.join(folder, f"{prefix}_{i}.pb"), "wb") as f:
            f.write(proto.SerializeToString())


def prepare_c_profiling(model_onnx, inputs, dest=None):
    """
    Prepares model and data to be profiled with tool `perftest
    <https://github.com/microsoft/onnxruntime/tree/
    master/onnxruntime/test/perftest>`_ (onnxruntime) or
    `onnx_test_runner <https://github.com/microsoft/
    onnxruntime/blob/master/docs/Model_Test.md>`_.
    It saves the model as ``model.onnx`` in folder *dest*, runs it
    with onnxruntime on CPU and dumps the inputs and the expected outputs
    as ``input_<i>.pb`` / ``output_<i>.pb`` in subfolder ``test_data_set_0``.

    :param model_onnx: onnx model
    :param inputs: inputs as a list (in the model input order)
        or a dictionary
    :param dest: destination folder, None means the current folder
    :return: command line to use
    :raises ValueError: if *inputs* is a list whose length differs from
        the number of model inputs
    """
    if dest is None:
        dest = "."
    # exist_ok avoids the race between checking for the folder and creating it.
    os.makedirs(dest, exist_ok=True)
    dest = os.path.abspath(dest)
    model_bytes = model_onnx.SerializeToString()
    with open(os.path.join(dest, "model.onnx"), "wb") as f:
        f.write(model_bytes)

    sess = InferenceSession(model_bytes, providers=['CPUExecutionProvider'])
    input_names = [i.name for i in sess.get_inputs()]
    output_names = [o.name for o in sess.get_outputs()]
    if isinstance(inputs, list):
        # zip would silently drop extra inputs, report the mismatch instead.
        if len(inputs) != len(input_names):
            raise ValueError(
                f"The model expects {len(input_names)} inputs "
                f"{input_names!r} but {len(inputs)} were given.")
        dict_inputs = dict(zip(input_names, inputs))
    else:
        dict_inputs = inputs
    ordered_inputs = [dict_inputs[n] for n in input_names]
    # Expected outputs are computed by onnxruntime itself.
    outputs = sess.run(None, dict_inputs)

    sub = os.path.join(dest, "test_data_set_0")
    os.makedirs(sub, exist_ok=True)
    _save_tensors(sub, "input", ordered_inputs, input_names)
    _save_tensors(sub, "output", outputs, output_names)

    return f'onnx_test_runner -e cpu -r 100 -c 1 "{dest}"'