Coverage for mlprodict/onnxrt/ops_whole/session.py: 96%

69 statements  

coverage.py v7.1.0, created at 2023-02-04 02:28 +0100

# -*- encoding: utf-8 -*-
"""
@file
@brief Shortcut to *ops_whole*.
"""
import json
import numpy


class OnnxWholeSession:
    """
    Runs the prediction for a single :epkg:`ONNX` model
    and lets the runtime handle the graph logic as well.

    :param onnx_data: :epkg:`ONNX` model or data
    :param runtime: runtime to be used, mostly :epkg:`onnxruntime`
    :param runtime_options: runtime options
    :param device: device, a string such as `cpu`, `cuda`, `cuda:0`...

    .. versionchanged:: 0.8
        Parameter *device* was added.
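
    A minimal usage sketch (``'model.onnx'`` and the input name ``'X'``
    are hypothetical placeholders, they depend on the model)::

        whole = OnnxWholeSession('model.onnx', 'onnxruntime1')
        outputs = whole.run(
            {'X': numpy.array([[0.1, 0.2]], dtype=numpy.float32)})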

    """

    def __init__(self, onnx_data, runtime, runtime_options=None, device=None):
        if runtime not in ('onnxruntime1', 'onnxruntime1-cuda'):
            raise NotImplementedError(  # pragma: no cover
                f"runtime '{runtime}' is not implemented.")

        from onnxruntime import (  # delayed
            InferenceSession, SessionOptions, RunOptions,
            GraphOptimizationLevel)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
            InvalidArgument as OrtInvalidArgument,
            NotImplemented as OrtNotImplemented,
            RuntimeException as OrtRuntimeException)

        # Keeps the original model so that a readable text representation
        # can be shown if the session creation fails.
        onnx_data0 = onnx_data
        if hasattr(onnx_data, 'SerializeToString'):
            onnx_data = onnx_data.SerializeToString()
        if isinstance(runtime_options, SessionOptions):
            sess_options = runtime_options
            session_options = None
            runtime_options = None
        else:
            session_options = (
                None if runtime_options is None
                else runtime_options.get('session_options', None))
            sess_options = session_options or SessionOptions()
        self.runtime = runtime
        self.run_options = RunOptions()
        self.run_options.log_severity_level = 3
        self.run_options.log_verbosity_level = 1

        if session_options is None:
            if runtime_options is not None:
                if runtime_options.get('disable_optimisation', False):
                    # Disables every graph optimisation.
                    sess_options.graph_optimization_level = (  # pragma: no cover
                        GraphOptimizationLevel.ORT_DISABLE_ALL)
                if runtime_options.get('enable_profiling', False):
                    sess_options.enable_profiling = True
                if runtime_options.get('log_severity_level', 2) != 2:
                    v = runtime_options.get('log_severity_level', 2)
                    sess_options.log_severity_level = v
                    self.run_options.log_severity_level = v
        elif runtime_options is not None and 'enable_profiling' in runtime_options:
            raise RuntimeError(  # pragma: no cover
                "session_options and enable_profiling cannot be defined at the "
                "same time.")
        elif runtime_options is not None and 'disable_optimisation' in runtime_options:
            raise RuntimeError(  # pragma: no cover
                "session_options and disable_optimisation cannot be defined at the "
                "same time.")
        elif runtime_options is not None and 'log_severity_level' in runtime_options:
            raise RuntimeError(  # pragma: no cover
                "session_options and log_severity_level cannot be defined at the "
                "same time.")
        providers = ['CPUExecutionProvider']
        if runtime == 'onnxruntime1-cuda':
            providers = ['CUDAExecutionProvider'] + providers
        try:
            self.sess = InferenceSession(onnx_data, sess_options=sess_options,
                                         device=device, providers=providers)
        except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
                OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e:
            from ...plotting.text_plot import onnx_simple_text_plot
            raise RuntimeError(
                "Unable to create InferenceSession due to '{}'\n{}.".format(
                    e, onnx_simple_text_plot(onnx_data0, recursive=True))) from e
        self.output_names = [_.name for _ in self.sess.get_outputs()]

    def run(self, inputs):
        """
        Computes the predictions.

        :param inputs: dictionary *{variable, value}*
        :return: list of outputs
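
        A minimal sketch, assuming ``whole`` is an instance of this class
        and the model has a single float input named ``'X'`` (both names
        are hypothetical)::

            outputs = whole.run(
                {'X': numpy.array([[0.1, 0.2]], dtype=numpy.float32)})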

        """
        # Looks at the first input to choose which binding to call.
        v = next(iter(inputs.values()))
        if isinstance(v, (numpy.ndarray, dict)):
            try:
                return self.sess._sess.run(
                    self.output_names, inputs, self.run_options)
            except ValueError as e:
                raise ValueError(
                    "Issue running inference inputs=%r, expected inputs=%r." % (
                        sorted(inputs),
                        [i.name for i in self.sess.get_inputs()])) from e
        # Otherwise inputs are assumed to hold OrtValue instances.
        try:
            return self.sess._sess.run_with_ort_values(
                inputs, self.output_names, self.run_options)
        except RuntimeError:
            # Retries with the C OrtValue wrapped by every input.
            return self.sess._sess.run_with_ort_values(
                {k: v._get_c_value() for k, v in inputs.items()},
                self.output_names, self.run_options)

    @staticmethod
    def process_profiling(js):
        """
        Flattens the json returned by onnxruntime profiling.

        :param js: json
        :return: list of dictionaries
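
        A small sketch of the flattening (the rows below are hypothetical
        profiling entries)::

            js = [{'name': 'Add', 'dur': 10, 'args': {'op_name': 'Add'}}]
            rows = OnnxWholeSession.process_profiling(js)
            # rows == [{'name': 'Add', 'dur': 10, 'args_op_name': 'Add'}]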

        """
        rows = []
        for row in js:
            if 'args' in row and isinstance(row['args'], dict):
                # Flattens the nested 'args' dictionary into 'args_*' keys.
                for k, v in row['args'].items():
                    row[f'args_{k}'] = v
                del row['args']
            rows.append(row)
        return rows

    def get_profiling(self):
        """
        Returns the profiling information.
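
        Profiling must be enabled when the session is created. A minimal
        sketch (``'model.onnx'`` is a hypothetical placeholder)::

            whole = OnnxWholeSession(
                'model.onnx', 'onnxruntime1',
                runtime_options={'enable_profiling': True})
            # ... calls to whole.run(...) ...
            rows = whole.get_profiling()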

        """
        prof = self.sess.end_profiling()
        # end_profiling dumps the profiling data into a json file
        # and returns its name.
        with open(prof, 'r') as f:
            content = f.read()
        js = json.loads(content)
        return OnnxWholeSession.process_profiling(js)
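

# A minimal end-to-end sketch, not part of the original module: it builds a
# tiny ONNX graph (a single Add node with hypothetical names 'X' and 'Y')
# with the onnx helper API and runs it through OnnxWholeSession, assuming
# both onnx and onnxruntime are installed.
if __name__ == '__main__':
    from onnx import TensorProto
    from onnx.helper import (
        make_graph, make_model, make_node, make_tensor_value_info)

    X = make_tensor_value_info('X', TensorProto.FLOAT, [None, 2])
    Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, 2])
    model = make_model(make_graph(
        [make_node('Add', ['X', 'X'], ['Y'])], 'example', [X], [Y]))

    whole = OnnxWholeSession(model, 'onnxruntime1')
    print(whole.run({'X': numpy.array([[1., 2.]], dtype=numpy.float32)}))
    # expected output: [array([[2., 4.]], dtype=float32)]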