microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
df7a9f337c69567dc9c58400d3ec8004bcafb794

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_ortapi2.py

191 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6"""
7_ortapi2.py: ONNXRuntime-Extensions Python API
8"""
9
10import numpy as np
11from ._ocos import default_opset_domain, get_library_path # noqa
12from ._cuops import onnx, onnx_proto, SingleOpGraph
13
# onnxruntime >= 1.10.0 is a hard requirement of this module. The check runs at
# import time and binds `_ver` (packaging.version) and `_ort` (onnxruntime) as
# module globals that the rest of the file relies on.
_ort_check_passed = False
_ort_import_error = None  # the ImportError that made the check fail, if any
try:
    from packaging import version as _ver
    import onnxruntime as _ort

    if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
        _ort_check_passed = True
except ImportError as _err:
    # The `as` target is unbound when the except clause ends, so keep a separate
    # reference. Chaining it below preserves the real cause (for example a
    # missing `packaging` package) instead of hiding it behind the generic message.
    _ort_import_error = _err

if not _ort_check_passed:
    raise RuntimeError(
        "please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0"
    ) from _ort_import_error
26
27
def get_opset_version_from_ort(ort_version=None):
    """Return the ai.onnx opset version associated with an onnxruntime release.

    :param ort_version: onnxruntime version string such as "1.13.1". When omitted,
        the version of the imported onnxruntime package is used.
    :return: the opset version listed for that (major, minor) release in the table
        below. Releases newer than the newest table entry are clamped to that
        entry. Older releases, or version strings whose first two components are
        not integers, fall back to opset 11.
    """
    _ORT_OPSET_SUPPORT_TABLE = {
        "1.5": 11,
        "1.6": 12,
        "1.7": 13,
        "1.8": 14,
        "1.9": 15,
        "1.10": 15,
        "1.11": 16,
        "1.12": 17,
        "1.13": 17,
        "1.14": 18,
        "1.15": 18
    }

    if ort_version is None:
        ort_version = _ort.__version__

    # Compare versions as (major, minor) integer tuples. Raw string comparison is
    # wrong: lexicographically "1.9" > "1.14" and "1.100" < "1.14".
    def _major_minor(ver_string):
        major, minor = ver_string.split('.')[0:2]
        return int(major), int(minor)

    table = {_major_minor(key): opset for key, opset in _ORT_OPSET_SUPPORT_TABLE.items()}
    try:
        requested = _major_minor(ort_version)
    except ValueError:
        # Fewer than two components, or non-integer ones: use the same default
        # as an unknown release.
        return 11

    newest = max(table)
    return table.get(min(requested, newest), 11)
48
49
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(),
                    extra_opset_version=1):
    """Wrap an ONNX graph in a ModelProto that imports two operator sets.

    The model imports the 'ai.onnx' opset at `opset_version` (0 means "use
    get_opset_version_from_ort()"), plus `extra_domain` at `extra_opset_version`.
    The `extra_domain` default comes from default_opset_domain() and is evaluated
    once, when this module is imported.
    """
    if opset_version == 0:
        opset_version = get_opset_version_from_ort()

    helper = onnx.helper
    # Use make_model_gen_version when this onnx release provides it.
    if hasattr(helper, 'make_model_gen_version'):
        build_model = helper.make_model_gen_version
    else:
        build_model = helper.make_model

    core_opset = helper.make_operatorsetid('ai.onnx', opset_version)
    custom_opset = helper.make_operatorsetid(extra_domain, extra_opset_version)

    model = build_model(graph, opset_imports=[core_opset])
    model.opset_import.extend([custom_opset])
    return model
60
61
class OrtPyFunction:
    """
    OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
    equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
    standard Python function. Positional call arguments are matched, in order, to the graph inputs that
    have no default registered via add_default_input().
    """

    def get_ort_session_options(self):
        """Build SessionOptions with `extra_session_options` applied and the
        onnxruntime-extensions custom-op library registered."""
        so = _ort.SessionOptions()
        for option_name, option_value in self.extra_session_options.items():
            setattr(so, option_name, option_value)
        so.register_custom_ops_library(get_library_path())
        return so

    def __init__(self, path_or_model=None, cpu_only=None):
        """
        :param path_or_model: path to an .onnx file, an in-memory model, or None to
            bind a model later (e.g. through create_from_customop).
        :param cpu_only: when falsy and onnxruntime reports a 'GPU' device, use
            CUDAExecutionProvider instead of CPUExecutionProvider.
        """
        # _oxml must exist before any model is bound, so that `onnx_model` reports
        # "No onnx model attached yet." instead of raising AttributeError.
        self._oxml = None
        self.inputs = []
        self.outputs = []
        self.ort_session = None
        self.default_inputs = {}
        self.execution_providers = ['CPUExecutionProvider']
        if not cpu_only and _ort.get_device() == 'GPU':
            self.execution_providers = ['CUDAExecutionProvider']
        # Must be set before _bind(), which builds a session immediately for file paths.
        self.extra_session_options = {}
        if isinstance(path_or_model, str):
            self._bind(onnx.load_model(path_or_model), path_or_model)
        elif path_or_model is not None:
            self._bind(path_or_model)

    def create_from_customop(self, op_type, *args, **kwargs):
        """Bind a model built from SingleOpGraph.build_graph(op_type, ...) and return self."""
        graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
        self._bind(make_onnx_model(graph))
        return self

    def add_default_input(self, **kwargs):
        """Register fixed values for named graph inputs.

        Inputs with a registered default are skipped when mapping positional call
        arguments. numpy arrays and scalars are stored as-is. Any other value is
        passed through list() into a uint8 array (e.g. bytes become their byte values).
        """
        # Convert everything first so that a failing value leaves defaults untouched.
        converted = {
            name: value if isinstance(value, (np.ndarray, np.generic))
            else np.asarray(list(value), dtype=np.uint8)
            for name, value in kwargs.items()
        }
        self.default_inputs.update(converted)

    @property
    def onnx_model(self):
        """The currently bound model; fails the assertion when none is bound."""
        assert self._oxml is not None, "No onnx model attached yet."
        return self._oxml

    @property
    def input_names(self):
        """Names of the bound model's graph inputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.input]

    @property
    def output_names(self):
        """Names of the bound model's graph outputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.output]

    def _bind(self, oxml, model_path=None):
        """Attach `oxml` as the current model and return self.

        With `model_path`, an InferenceSession is created immediately from that file.
        Otherwise any existing session is dropped, so that the next call rebuilds it
        from the newly bound model.
        """
        self.inputs = list(oxml.graph.input)
        self.outputs = list(oxml.graph.output)
        self._oxml = oxml
        if model_path is not None:
            self.ort_session = _ort.InferenceSession(
                model_path, self.get_ort_session_options(),
                self.execution_providers)
        else:
            # Otherwise a rebound function would keep running the previous model.
            self.ort_session = None
        return self

    def _ensure_ort_session(self):
        """Create the InferenceSession from the serialized bound model if it does not exist yet."""
        if self.ort_session is None:
            self.ort_session = _ort.InferenceSession(
                self.onnx_model.SerializeToString(), self.get_ort_session_options(),
                self.execution_providers)
        return self.ort_session

    @staticmethod
    def _get_kwarg_device(kwargs):
        """Remove 'cpu_only' from `kwargs` in place and return its value (None if absent)."""
        # pop() removes the key even when it was passed explicitly as None, so it is
        # never forwarded to the graph builder.
        return kwargs.pop('cpu_only', None)

    @classmethod
    def from_customop(cls, op_type, *args, **kwargs):
        """Create a function wrapping a single custom op; accepts an optional cpu_only kwarg."""
        return (cls(cpu_only=cls._get_kwarg_device(kwargs))
                .create_from_customop(op_type, *args, **kwargs))

    @classmethod
    def from_model(cls, path_or_model, *args, **kwargs):
        """Create a function from a model path or in-memory model.

        Only the optional cpu_only kwarg is used; other extra arguments are ignored.
        """
        return cls(path_or_model, cls._get_kwarg_device(kwargs))

    def _argument_map(self, *args, **kwargs):
        """Build the feed dict passed to InferenceSession.run.

        Graph inputs with a registered default use it. The remaining inputs consume
        `args` in graph order. Python int/float/bool values are wrapped in np.array,
        and values for INT64 graph inputs are cast to int64. `kwargs` is accepted
        but currently ignored.
        """
        feed = {}
        idx = 0
        for graph_input in self.inputs:
            name = graph_input.name
            if name in self.default_inputs:
                feed[name] = self.default_inputs[name]
                continue

            if idx >= len(args):
                raise IndexError(
                    "no argument supplied for model input '{}' ({} positional argument(s) given)"
                    .format(name, len(args)))
            value = args[idx]
            idx += 1

            if isinstance(value, (int, float, bool)):
                value = np.array(value)
            # numpy's default integer is int32 on some platforms and int64 on others,
            # so INT64 inputs are cast explicitly. Non-numpy values (e.g. plain lists)
            # are converted rather than failing on a missing .astype.
            if graph_input.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64:
                if isinstance(value, (np.ndarray, np.generic)):
                    value = value.astype(np.int64)
                else:
                    value = np.asarray(value, dtype=np.int64)
            feed[name] = value

        return feed

    def __call__(self, *args, **kwargs):
        """Run the model: return the single output directly, or a tuple of outputs."""
        self._ensure_ort_session()
        outputs = self.ort_session.run(
            None, self._argument_map(*args, **kwargs))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
181
182
def optimize_model(model_or_file, output_file):
    """Apply onnxruntime's basic graph optimizations to a model and save the result.

    :param model_or_file: path to an .onnx file, or an in-memory model exposing
        SerializeToString().
    :param output_file: path where onnxruntime writes the optimized model (via
        SessionOptions.optimized_model_filepath) while creating the session.
    """
    # Only the session options are needed here, so skip the GPU device probe.
    sess_options = OrtPyFunction(cpu_only=True).get_ort_session_options()
    sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess_options.optimized_model_filepath = output_file

    if isinstance(model_or_file, str):
        model_source = model_or_file
    else:
        model_source = model_or_file.SerializeToString()

    # Since onnxruntime 1.9, builds with several execution providers (such as
    # onnxruntime-gpu) raise ValueError unless `providers` is given explicitly.
    # Basic optimizations are provider-independent, so the CPU provider suffices.
    _ort.InferenceSession(model_source, sess_options,
                          providers=['CPUExecutionProvider'])
189
190
# Module-level alias for onnxruntime's `onnxruntime.capi.onnxruntime_pybind11_state.Fail`
# exception class, so it can be referenced from this module as `ONNXRuntimeError`.
ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
192