Coverage for mlprodict/onnxrt/ops_whole/session.py: 96%
69 statements
coverage.py v7.1.0, created at 2023-02-04 02:28 +0100
1# -*- encoding: utf-8 -*-
2"""
3@file
4@brief Shortcut to *ops_whole*.
5"""
6import json
7import numpy
class OnnxWholeSession:
    """
    Runs the prediction for a single :epkg:`ONNX`,
    it lets the runtime handle the graph logic as well.

    :param onnx_data: :epkg:`ONNX` model or data; an object exposing
        `SerializeToString` is serialized before being handed to the runtime
    :param runtime: runtime to be used, `'onnxruntime1'` or
        `'onnxruntime1-cuda'` (the latter lists the CUDA execution provider
        before the CPU one)
    :param runtime_options: runtime options, either an :epkg:`onnxruntime`
        `SessionOptions` instance used as is, or a dictionary with the
        optional keys `'session_options'` (a `SessionOptions` instance),
        `'disable_optimisation'`, `'enable_profiling'` and
        `'log_severity_level'`; `'session_options'` cannot be combined
        with any of the three other keys
    :param device: device, a string `cpu`, `cuda`, `cuda:0`...
    :raises NotImplementedError: if *runtime* is not supported
    :raises RuntimeError: if the options conflict or if the
        inference session cannot be created

    .. versionchanged:: 0.8
        Parameter *device* was added.
    """

    def __init__(self, onnx_data, runtime, runtime_options=None, device=None):
        if runtime not in ('onnxruntime1', 'onnxruntime1-cuda'):
            raise NotImplementedError(  # pragma: no cover
                f"runtime '{runtime}' is not implemented.")

        from onnxruntime import (  # delayed
            InferenceSession, SessionOptions, RunOptions,
            GraphOptimizationLevel)
        from onnxruntime.capi._pybind_state import (  # pylint: disable=E0611
            Fail as OrtFail, InvalidGraph as OrtInvalidGraph,
            InvalidArgument as OrtInvalidArgument,
            NotImplemented as OrtNotImplemented,
            RuntimeException as OrtRuntimeException)

        # Assigned whatever the type of runtime_options is.
        self.runtime = runtime

        # The original object is kept to render it if the session fails.
        onnx_data0 = onnx_data
        if hasattr(onnx_data, 'SerializeToString'):
            onnx_data = onnx_data.SerializeToString()

        if isinstance(runtime_options, SessionOptions):
            # A ready-made SessionOptions is used as is, nothing else applies.
            sess_options = runtime_options
            session_options = None
            runtime_options = None
        else:
            session_options = (
                None if runtime_options is None
                else runtime_options.get('session_options', None))
            sess_options = (
                SessionOptions() if session_options is None
                else session_options)

        self.run_options = RunOptions()
        self.run_options.log_severity_level = 3
        self.run_options.log_verbosity_level = 1

        if session_options is None:
            if runtime_options is not None:
                if runtime_options.get('disable_optimisation', False):
                    sess_options.graph_optimization_level = (  # pragma: no cover
                        GraphOptimizationLevel.ORT_DISABLE_ALL)
                # Profiling writes files, it must be explicitly requested.
                if runtime_options.get('enable_profiling', False):
                    sess_options.enable_profiling = True
                if 'log_severity_level' in runtime_options:
                    level = runtime_options['log_severity_level']
                    sess_options.log_severity_level = level
                    self.run_options.log_severity_level = level
        else:
            # session_options came from the runtime_options dictionary:
            # the other keys would be silently ignored, they are rejected.
            for key in ('enable_profiling', 'disable_optimisation',
                        'log_severity_level'):
                if key in runtime_options:
                    raise RuntimeError(  # pragma: no cover
                        f"session_options and {key} cannot be defined at the "
                        "same time.")

        providers = ['CPUExecutionProvider']
        if runtime == 'onnxruntime1-cuda':
            providers = ['CUDAExecutionProvider'] + providers
        try:
            self.sess = InferenceSession(onnx_data, sess_options=sess_options,
                                         device=device, providers=providers)
        except (OrtFail, OrtNotImplemented, OrtInvalidGraph,
                OrtInvalidArgument, OrtRuntimeException, RuntimeError) as e:
            from ...plotting.text_plot import onnx_simple_text_plot
            raise RuntimeError(
                "Unable to create InferenceSession due to '{}'\n{}.".format(
                    e, onnx_simple_text_plot(onnx_data0, recursive=True))) from e
        self.output_names = [out.name for out in self.sess.get_outputs()]

    def run(self, inputs):
        """
        Computes the predictions.

        Feeds made of numpy arrays or dictionaries (or an empty feed) go
        through the session's `run`; any other values are given to
        `run_with_ort_values`.

        :param inputs: dictionary *{variable, value}*
        :return: list of outputs
        :raises ValueError: if the runtime rejects the numpy inputs, the
            message lists the given and the expected input names
        """
        first = next(iter(inputs.values()), None)
        if not inputs or isinstance(first, (numpy.ndarray, dict)):
            try:
                return self.sess._sess.run(
                    self.output_names, inputs, self.run_options)
            except ValueError as e:
                raise ValueError(
                    "Issue running inference inputs=%r, expected inputs=%r."
                    "" % (
                        sorted(inputs),
                        [i.name for i in self.sess.get_inputs()])) from e
        try:
            return self.sess._sess.run_with_ort_values(
                inputs, self.output_names, self.run_options)
        except RuntimeError:
            # Retry with the value returned by `_get_c_value()` of every
            # input (presumably Python-level OrtValue wrappers - to confirm).
            return self.sess._sess.run_with_ort_values(
                {k: v._get_c_value() for k, v in inputs.items()},
                self.output_names, self.run_options)

    @staticmethod
    def process_profiling(js):
        """
        Flattens json returned by onnxruntime profiling.

        For every row holding a dictionary under key `'args'`, each of its
        items becomes a top-level key `'args_<name>'` and `'args'` is
        removed. Such rows are copied, the input is left untouched.

        :param js: json, a list of dictionaries
        :return: list of dictionaries
        """
        rows = []
        for row in js:
            if isinstance(row.get('args'), dict):
                flat = dict(row)
                for k, v in row['args'].items():
                    flat[f'args_{k}'] = v
                del flat['args']
                row = flat
            rows.append(row)
        return rows

    def get_profiling(self):
        """
        Returns the profiling information.

        Ends the profiling of the session, loads the json file whose name
        `end_profiling` returns and flattens it with `process_profiling`.

        :return: list of dictionaries
        """
        prof = self.sess.end_profiling()
        with open(prof, 'r', encoding='utf-8') as f:
            js = json.load(f)
        return OnnxWholeSession.process_profiling(js)