microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/_ortapi2.py
185 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | # Licensed under the MIT License. See License.txt in the project root for |
| 3 | # license information. |
| 4 | ############################################################################### |
| 5 | |
| 6 | import numpy as np |
| 7 | from ._ocos import default_opset_domain, get_library_path # noqa |
| 8 | from ._cuops import onnx, onnx_proto, CustomOpConverter, SingleOpGraph |
| 9 | |
# Import-time gate: this module is only usable when both `packaging` and an
# onnxruntime build of version 1.10.0 or newer can be imported.
try:
    from packaging import version as _ver
    import onnxruntime as _ort
    # The flag is simply the outcome of the minimum-version comparison.
    _ort_check_passed = bool(
        _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"))
except ImportError:
    # Either import failing is treated the same as a too-old runtime.
    _ort_check_passed = False

if not _ort_check_passed:
    raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")
| 21 | |
| 22 | |
def get_opset_version_from_ort(ort_version=None):
    """Return the ONNX opset version associated with an ONNXRuntime release.

    The lookup uses only the ``major.minor`` part of the version. Releases
    newer than the newest table entry are clamped to that newest entry, and
    releases that are not in the table (or cannot be parsed) fall back to
    opset 11.

    Args:
        ort_version: Optional version string such as ``"1.12.1"``. When
            omitted, the version of the imported ``onnxruntime`` package
            is used.

    Returns:
        The opset version as an ``int``.
    """
    import re  # local import so the block does not depend on file-level imports

    # Keys are numeric (major, minor) tuples so that ordering is numeric;
    # comparing the raw strings would rank "1.9" above "1.14".
    _ORT_OPSET_SUPPORT_TABLE = {
        (1, 5): 11,
        (1, 6): 12,
        (1, 7): 13,
        (1, 8): 14,
        (1, 9): 15,
        (1, 10): 15,
        (1, 11): 16,
        (1, 12): 17,
        (1, 13): 17,
        (1, 14): 18,
    }
    fallback_opset = 11

    if ort_version is None:
        ort_version = _ort.__version__

    parsed = re.match(r'\s*(\d+)(?:\.(\d+))?', str(ort_version))
    if parsed is None:
        return fallback_opset
    major_minor = (int(parsed.group(1)), int(parsed.group(2) or 0))

    # Anything released after the newest known entry gets that entry's opset.
    newest_known = max(_ORT_OPSET_SUPPORT_TABLE)
    if major_minor > newest_known:
        major_minor = newest_known
    return _ORT_OPSET_SUPPORT_TABLE.get(major_minor, fallback_opset)
| 42 | |
| 43 | |
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
    """Wrap ``graph`` in an ONNX model that imports two operator sets.

    The model imports the ``ai.onnx`` domain at ``opset_version`` plus
    ``extra_domain`` at ``extra_opset_version``.

    Args:
        graph: The ONNX graph to wrap.
        opset_version: ``ai.onnx`` opset to import; ``0`` means "ask
            ``get_opset_version_from_ort()``".
        extra_domain: Name of the additional operator-set domain. The default
            is computed once, when this function is defined.
        extra_opset_version: Version of the additional operator set.

    Returns:
        The model object produced by the onnx helper.
    """
    if opset_version == 0:
        opset_version = get_opset_version_from_ort()

    # Use `make_model_gen_version` when this onnx build offers it,
    # otherwise fall back to the plain `make_model`.
    if hasattr(onnx.helper, 'make_model_gen_version'):
        build_model = onnx.helper.make_model_gen_version
    else:
        build_model = onnx.helper.make_model

    standard_opset = onnx.helper.make_operatorsetid('ai.onnx', opset_version)
    model = build_model(graph, opset_imports=[standard_opset])

    extra_opset = onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)
    model.opset_import.extend([extra_opset])
    return model
| 54 | |
| 55 | |
class OrtPyFunction:
    """Callable wrapper that runs an ONNX model through an ONNXRuntime session.

    An instance is bound to a model either by building a single-custom-op
    graph (``from_customop`` / ``create_from_customop``) or from an existing
    model object or file path (``from_model``). Calling the instance maps the
    positional arguments onto the model's graph inputs and runs the session.
    """

    @classmethod
    def get_ort_session_options(cls):
        """Return new ``SessionOptions`` with the custom-op library registered."""
        # ONNXRuntime has an issue to support reusing the SessionOptions object.
        # Create a new one every time here
        so = _ort.SessionOptions()
        so.register_custom_ops_library(get_library_path())
        return so

    def __init__(self, cpu_only=None):
        """Create an unbound function object.

        Args:
            cpu_only: When falsy and ``onnxruntime.get_device()`` reports
                ``'GPU'``, the CUDA execution provider is selected; otherwise
                the CPU execution provider is used.
        """
        self._onnx_model = None  # legacy attribute, kept for backward compatibility
        # `_oxml` is what the `onnx_model` property reads; it must exist before
        # `_bind` runs so the property can report a missing model properly.
        self._oxml = None
        self.inputs = []
        self.outputs = []
        self.ort_session = None
        self.default_inputs = {}
        self.execution_providers = ['CPUExecutionProvider']
        if not cpu_only and _ort.get_device() == 'GPU':
            self.execution_providers = ['CUDAExecutionProvider']

    def create_from_customop(self, op_type, *args, **kwargs):
        """Bind this object to a graph built from a single custom op.

        A ``CustomOpConverter`` may be supplied either as the ``cvt`` keyword
        or as the first positional argument. When present it is called with
        the remaining keyword arguments and its result is passed to the graph
        builder in their place.

        Returns:
            ``self``, so calls can be chained.
        """
        # pop() also removes an explicit `cvt=None`, so it never leaks into
        # the keyword arguments forwarded to the graph builder.
        cvt = kwargs.pop('cvt', None)
        # Only consume the first positional argument when it really is a
        # converter; any other positional argument is forwarded untouched.
        if cvt is None and len(args) > 0 and isinstance(args[0], CustomOpConverter):
            cvt, args = args[0], args[1:]

        new_kwargs = kwargs if cvt is None else cvt(**kwargs)
        graph = SingleOpGraph.build_my_graph(op_type, *args, **new_kwargs)
        self._bind(make_onnx_model(graph))
        return self

    def add_default_input(self, **kwargs):
        """Register values to feed automatically for the named model inputs.

        NumPy arrays and scalars are stored as given; any other value is
        converted with ``np.asarray(list(value), dtype=np.uint8)``.
        """
        inputs = {
            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
            np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
        }

        self.default_inputs.update(inputs)

    @property
    def onnx_model(self):
        """The bound ONNX model.

        Raises:
            AssertionError: If no model has been bound yet.
        """
        # Raised explicitly (rather than via `assert`) so the check still
        # runs under `python -O`.
        if self._oxml is None:
            raise AssertionError("No onnx model attached yet.")
        return self._oxml

    @property
    def input_names(self):
        """Names of the bound model's graph inputs."""
        return [vi_.name for vi_ in self.onnx_model.graph.input]

    @property
    def output_names(self):
        """Names of the bound model's graph outputs."""
        return [vi_.name for vi_ in self.onnx_model.graph.output]

    def _bind(self, oxml, model_path=None):
        """Attach ``oxml`` to this object and return ``self``.

        When ``model_path`` is given the inference session is created from
        that path immediately; otherwise it is created on first use by
        ``_ensure_ort_session``.
        """
        self.inputs = list(oxml.graph.input)
        self.outputs = list(oxml.graph.output)
        self._oxml = oxml
        if model_path is not None:
            self.ort_session = _ort.InferenceSession(
                model_path, self.get_ort_session_options(),
                self.execution_providers)
        return self

    def _ensure_ort_session(self):
        """Create the inference session from the serialized model if needed and return it."""
        if self.ort_session is None:
            sess = _ort.InferenceSession(
                self.onnx_model.SerializeToString(), self.get_ort_session_options(),
                self.execution_providers)
            self.ort_session = sess

        return self.ort_session

    @staticmethod
    def _get_kwarg_device(kwargs):
        """Remove ``cpu_only`` from ``kwargs`` and return it (``None`` if absent)."""
        # pop() drops the key even when its value is None, so it is never
        # forwarded to later calls that receive the same kwargs.
        return kwargs.pop('cpu_only', None)

    @classmethod
    def from_customop(cls, op_type, *args, **kwargs):
        """Build an instance bound to a single-custom-op graph.

        Accepts an optional ``cpu_only`` keyword for the constructor; every
        other argument goes to ``create_from_customop``.
        """
        return cls(cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)

    @classmethod
    def from_model(cls, path_or_model, *args, **kwargs):
        """Build an instance bound to a model given as a file path or a model object.

        A ``str`` is loaded with ``onnx.load_model`` and the session is created
        from that path; any other object is used as the model itself. Accepts
        an optional ``cpu_only`` keyword for the constructor.
        """
        mpath = None
        if isinstance(path_or_model, str):
            oxml = onnx.load_model(path_or_model)
            mpath = path_or_model
        else:
            oxml = path_or_model
        return cls(cls._get_kwarg_device(kwargs))._bind(oxml, mpath)

    def _argument_map(self, *args, **kwargs):
        """Build the session feed dict from positional arguments.

        Graph inputs that have a registered default take that default; the
        remaining inputs consume ``args`` in graph order. Keyword arguments
        are accepted but currently not used.

        Raises:
            IndexError: If there are fewer positional arguments than inputs
                without a default.
        """
        idx = 0
        feed = {}
        for i_ in self.inputs:
            if i_.name in self.default_inputs:
                feed[i_.name] = self.default_inputs[i_.name]
                continue

            if idx >= len(args):
                raise IndexError(
                    "missing positional argument for model input '{}'".format(i_.name))
            x = args[idx]
            idx += 1

            ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
            # an annoying bug is numpy by default is int32, while pytorch is int64.
            # so cast the input here automatically.
            if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64:
                # asarray() lets plain sequences be cast too, not only ndarrays.
                ts_x = np.asarray(ts_x).astype(np.int64)
            feed[i_.name] = ts_x

        return feed

    def __call__(self, *args, **kwargs):
        """Run the model and return its single output, or a tuple of all outputs."""
        self._ensure_ort_session()
        outputs = self.ort_session.run(
            None, self._argument_map(*args, **kwargs))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
| 175 | |
| 176 | |
def optimize_model(model_or_file, output_file):
    """Run ONNXRuntime's basic graph optimizations and save the result.

    Creating the inference session is what triggers the optimization and the
    write to ``output_file``; the session itself is discarded.

    Args:
        model_or_file: A model file path (``str``), or a model object that
            provides ``SerializeToString()``.
        output_file: Path where the optimized model is written.
    """
    sess_options = OrtPyFunction.get_ort_session_options()
    sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess_options.optimized_model_filepath = output_file

    if isinstance(model_or_file, str):
        model_source = model_or_file
    else:
        model_source = model_or_file.SerializeToString()

    # Name the provider explicitly: onnxruntime builds that ship more than one
    # execution provider (e.g. the GPU package) reject a session created
    # without a `providers` argument.
    _ort.InferenceSession(model_source, sess_options,
                          providers=['CPUExecutionProvider'])
| 183 | |
| 184 | |
# Module-level alias that exposes onnxruntime's pybind11 `Fail` object under a
# shorter name (presumably the exception type raised on runtime failures --
# confirm against the onnxruntime documentation).
ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail