microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0f45fef2d9301cc5b479b4df4bd8ecbaac93e1e6

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_ortapi2.py

185 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6import numpy as np
7from ._ocos import default_opset_domain, get_library_path # noqa
8from ._cuops import onnx, onnx_proto, CustomOpConverter, SingleOpGraph
9
# Gate the whole module on a usable ONNXRuntime (>= 1.10.0).
_ort_check_passed = False
# The ImportError (if any) hit while probing, kept so the RuntimeError below can
# report the real cause (e.g. `packaging` missing vs. onnxruntime missing).
_ort_import_error = None
try:
    from packaging import version as _ver
    import onnxruntime as _ort
    if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
        _ort_check_passed = True
except ImportError as _err:
    _ort_import_error = _err

if not _ort_check_passed:
    # Chain the original ImportError (None when onnxruntime imported fine but is too old).
    raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0") from _ort_import_error
21
22
def get_opset_version_from_ort(ort_version=None):
    """Return the ai.onnx opset number associated with an ONNXRuntime version.

    :param ort_version: ONNXRuntime version string such as ``"1.13.1"``.
        Defaults to the installed ``onnxruntime.__version__``.
    :return: the opset from the table below for the version's major.minor.
        Versions newer than the newest table entry map to that entry's opset;
        older or unlisted versions map to 11.
    """
    import re

    _ORT_OPSET_SUPPORT_TABLE = {
        "1.5": 11,
        "1.6": 12,
        "1.7": 13,
        "1.8": 14,
        "1.9": 15,
        "1.10": 15,
        "1.11": 16,
        "1.12": 17,
        "1.13": 17,
        "1.14": 18,
    }

    def _major_minor(ver):
        # Compare versions numerically: as strings, "1.9" > "1.14" is True.
        # Only the leading ASCII digits of each component count, so suffixes
        # such as "dev2023..." or "+cu118" do not break parsing.
        parts = []
        for token in str(ver).split('.')[:2]:
            match = re.match(r'[0-9]+', token)
            parts.append(int(match.group()) if match else 0)
        parts += [0] * (2 - len(parts))
        return tuple(parts)

    if ort_version is None:
        ort_version = _ort.__version__

    table = {_major_minor(k): v for k, v in _ORT_OPSET_SUPPORT_TABLE.items()}
    newest = max(table)
    # Clamp versions newer than the table to its newest entry.
    ver = min(_major_minor(ort_version), newest)
    return table.get(ver, 11)
42
43
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
    """Wrap *graph* in an ONNX ModelProto that imports two operator sets.

    :param graph: the GraphProto to wrap.
    :param opset_version: opset to import for the ``ai.onnx`` domain; 0 means
        "ask get_opset_version_from_ort()".
    :param extra_domain: domain of the second operator-set import; the default
        is ``default_opset_domain()`` evaluated once when the module is loaded.
    :param extra_opset_version: version of *extra_domain* to import.
    :return: the newly built ModelProto.
    """
    target_opset = opset_version if opset_version != 0 else get_opset_version_from_ort()
    helper = onnx.helper
    # Use make_model_gen_version when this onnx build exposes it, else make_model.
    build = getattr(helper, 'make_model_gen_version', helper.make_model)
    standard_import = helper.make_operatorsetid('ai.onnx', target_opset)
    model = build(graph, opset_imports=[standard_import])
    extra_import = helper.make_operatorsetid(extra_domain, extra_opset_version)
    model.opset_import.extend([extra_import])
    return model
54
55
class OrtPyFunction:
    """Call an ONNX model (or a single custom op) like a Python function.

    An instance holds a bound ONNX ModelProto and an ``onnxruntime``
    InferenceSession. The session is created either eagerly (when bound from a
    file path) or lazily on the first call. Session options always register
    the custom-op library returned by ``get_library_path()``. Positional call
    arguments are matched, in order, to the graph inputs that have no default
    registered via :meth:`add_default_input`.
    """

    @classmethod
    def get_ort_session_options(cls):
        """Return a new SessionOptions with the custom-op library registered."""
        # ONNXRuntime has an issue to support reusing the SessionOptions object.
        # Create a new one every time here
        so = _ort.SessionOptions()
        so.register_custom_ops_library(get_library_path())
        return so

    def __init__(self, cpu_only=None):
        """
        :param cpu_only: when falsy and ``onnxruntime.get_device()`` reports
            ``'GPU'``, sessions use the CUDA execution provider; otherwise the
            CPU execution provider is used.
        """
        self._onnx_model = None  # kept for backward compatibility; not read in this class
        self._oxml = None  # bound ModelProto, set by _bind(); checked by onnx_model
        self.inputs = []
        self.outputs = []
        self.ort_session = None
        self.default_inputs = {}
        self.execution_providers = ['CPUExecutionProvider']
        if not cpu_only:
            if _ort.get_device() == 'GPU':
                self.execution_providers = ['CUDAExecutionProvider']

    def create_from_customop(self, op_type, *args, **kwargs):
        """Bind a single-node model built for custom op *op_type*.

        A CustomOpConverter may be supplied either as ``cvt=`` or as the first
        positional argument. When present, it is called with the remaining
        keyword arguments, and its result is passed as the keyword arguments of
        ``SingleOpGraph.build_my_graph``.

        :return: self, for chaining.
        """
        # pop (not get) so an explicit ``cvt=None`` is not forwarded to the graph builder.
        cvt = kwargs.pop('cvt', None)
        if cvt is None and args and isinstance(args[0], CustomOpConverter):
            # Consume the first positional argument only when it really is a converter.
            cvt, args = args[0], args[1:]

        new_kwargs = kwargs if cvt is None else cvt(**kwargs)
        graph = SingleOpGraph.build_my_graph(op_type, *args, **new_kwargs)
        self._bind(make_onnx_model(graph))
        return self

    def add_default_input(self, **kwargs):
        """Register fixed values for graph inputs, keyed by input name.

        numpy arrays and numpy scalars are stored unchanged. Any other value is
        converted with ``np.asarray(list(value), dtype=np.uint8)``; for example,
        a ``bytes`` object becomes a 1-D uint8 array.
        """
        inputs = {
            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
            np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
        }

        self.default_inputs.update(inputs)

    @property
    def onnx_model(self):
        """The bound ModelProto; asserts that one has been bound."""
        assert self._oxml is not None, "No onnx model attached yet."
        return self._oxml

    @property
    def input_names(self):
        """Names of the bound graph's inputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.input]

    @property
    def output_names(self):
        """Names of the bound graph's outputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.output]

    def _bind(self, oxml, model_path=None):
        """Attach *oxml*. If *model_path* is given, create the session from that file now."""
        self.inputs = list(oxml.graph.input)
        self.outputs = list(oxml.graph.output)
        self._oxml = oxml
        # Drop any session built for a previously bound model so it is never reused.
        self.ort_session = None
        if model_path is not None:
            self.ort_session = _ort.InferenceSession(
                model_path, self.get_ort_session_options(),
                self.execution_providers)
        return self

    def _ensure_ort_session(self):
        """Create the session from the serialized bound model on first use and return it."""
        if self.ort_session is None:
            sess = _ort.InferenceSession(
                self.onnx_model.SerializeToString(), self.get_ort_session_options(),
                self.execution_providers)
            self.ort_session = sess

        return self.ort_session

    @staticmethod
    def _get_kwarg_device(kwargs):
        """Remove and return the ``cpu_only`` entry of *kwargs* (None if absent)."""
        # pop even when the value is None, so ``cpu_only=None`` is never forwarded.
        return kwargs.pop('cpu_only', None)

    @classmethod
    def from_customop(cls, op_type, *args, **kwargs):
        """Create an instance wrapping custom op *op_type*.

        Accepts the keyword ``cpu_only``; see create_from_customop for the rest.
        """
        return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)

    @classmethod
    def from_model(cls, path_or_model, *args, **kwargs):
        """Create an instance from a model file path or an in-memory ModelProto.

        For a path (``str`` or ``os.PathLike``), the model is loaded with
        ``onnx.load_model`` and the session is created from the file right
        away. For a ModelProto, the session is created lazily on the first
        call. Only the ``cpu_only`` keyword is used; other arguments are
        ignored.
        """
        import os

        mpath = None
        if isinstance(path_or_model, (str, os.PathLike)):
            mpath = os.fspath(path_or_model)
            oxml = onnx.load_model(mpath)
        else:
            oxml = path_or_model
        return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)

    def _argument_map(self, *args, **kwargs):
        """Build the ``{input_name: value}`` feed for ``InferenceSession.run``.

        Inputs with a registered default use it. The remaining inputs consume
        *args* in graph-input order. Python int/float/bool scalars are wrapped
        with ``np.array``, and values for INT64 inputs are cast to int64.
        *kwargs* is currently ignored.

        :raises IndexError: if there are fewer positional args than inputs
            without defaults.
        """
        idx = 0
        feed = {}
        for i_ in self.inputs:
            if i_.name in self.default_inputs:
                feed[i_.name] = self.default_inputs[i_.name]
                continue

            if idx >= len(args):
                raise IndexError("no value for model input '{}': got {} positional argument(s)".format(
                    i_.name, len(args)))
            x = args[idx]
            ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
            # numpy's default integer may be int32 (e.g. on Windows) while pytorch
            # uses int64, so cast INT64 inputs automatically. np.asarray also
            # lets plain sequences such as lists be cast (they have no .astype).
            if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64:
                ts_x = np.asarray(ts_x).astype(np.int64)
            feed[i_.name] = ts_x
            idx += 1

        return feed

    def __call__(self, *args, **kwargs):
        """Run the model; return the single output, or a tuple of all outputs otherwise."""
        self._ensure_ort_session()
        outputs = self.ort_session.run(
            None, self._argument_map(*args, **kwargs))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
175
176
def optimize_model(model_or_file, output_file):
    """Apply ONNXRuntime's basic graph optimizations and save the result.

    :param model_or_file: model path (``str`` or ``os.PathLike``) or an
        in-memory ModelProto.
    :param output_file: path where ONNXRuntime writes the optimized model.
    """
    import os

    sess_options = OrtPyFunction.get_ort_session_options()
    sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess_options.optimized_model_filepath = os.fspath(output_file)
    if isinstance(model_or_file, (str, os.PathLike)):
        model_src = os.fspath(model_or_file)
    else:
        model_src = model_or_file.SerializeToString()
    # Since ORT 1.9, builds with several execution providers (e.g. the GPU
    # package) raise ValueError unless `providers` is given explicitly. Per the
    # ORT docs, basic-level rewrites are provider-independent, so CPU is used.
    _ort.InferenceSession(model_src, sess_options, providers=['CPUExecutionProvider'])
183
184
# Re-export the ``Fail`` type from ONNXRuntime's pybind11 binding module under a
# public name, so callers can catch it without reaching into onnxruntime.capi.
_pybind_state = _ort.capi.onnxruntime_pybind11_state
ONNXRuntimeError = _pybind_state.Fail
del _pybind_state