microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions. Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
292a0297b4388e49a0afcccbdee24f0f0bf68de9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/_ortapi2.py

135 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6import numpy as np
7import onnxruntime as _ort
8from ._ocos import default_opset_domain, get_library_path # noqa
9from ._cuops import * # noqa
10
11
# (major, minor) onnxruntime release -> ai.onnx opset version used for models
# built against that release.
_ORT_OPSET_SUPPORT_TABLE = {
    (1, 5): 11,
    (1, 6): 12,
    (1, 7): 13,
    (1, 8): 14,
    (1, 9): 15,
    (1, 10): 15,
    (1, 11): 16,
}
# Opset used when the ORT version is older than the table or cannot be parsed.
_FALLBACK_OPSET_VERSION = 11


def get_opset_version_from_ort(ort_version=None):
    """Return the ai.onnx opset version to target for an onnxruntime release.

    :param ort_version: version string such as ``"1.10.0"`` or
        ``"1.11.0+cu113"``. Defaults to the installed ``onnxruntime.__version__``.
    :return: the opset from the known-release table. For a release newer than
        every table entry, the newest known opset is returned (ORT keeps
        supporting older opsets). For older or unparseable versions, 11 is returned.
    """
    if ort_version is None:
        ort_version = _ort.__version__

    try:
        # Only major.minor matter; patch/local suffixes ("0+cu113") are ignored.
        major, minor = (int(part) for part in ort_version.split('.')[:2])
    except ValueError:
        # Fewer than two components or a non-numeric major/minor.
        return _FALLBACK_OPSET_VERSION

    key = (major, minor)
    if key in _ORT_OPSET_SUPPORT_TABLE:
        return _ORT_OPSET_SUPPORT_TABLE[key]

    newest_known = max(_ORT_OPSET_SUPPORT_TABLE)
    if key > newest_known:
        # Previously this fell back to 11 even for newer ORT releases.
        return _ORT_OPSET_SUPPORT_TABLE[newest_known]
    return _FALLBACK_OPSET_VERSION
25
26
def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
    """Wrap ``graph`` into an ONNX ModelProto.

    The model imports the ``ai.onnx`` opset at ``opset_version`` and the
    ``extra_domain`` opset at ``extra_opset_version``.

    :param graph: the GraphProto to wrap.
    :param opset_version: ai.onnx opset; ``0`` means "pick one from the installed
        onnxruntime via get_opset_version_from_ort()".
    :param extra_domain: additional operator-set domain to import (the
        package's custom-op domain by default).
    :param extra_opset_version: version for ``extra_domain``.
    :return: the constructed ModelProto.
    """
    if opset_version == 0:
        opset_version = get_opset_version_from_ort()

    helper = onnx.helper
    # Newer onnx releases provide make_model_gen_version, which derives the IR
    # version from the opset; fall back to make_model on older releases.
    build_model = getattr(helper, 'make_model_gen_version', helper.make_model)

    default_opsets = [helper.make_operatorsetid('ai.onnx', opset_version)]
    model = build_model(graph, opset_imports=default_opsets)

    extra_opset = helper.make_operatorsetid(extra_domain, extra_opset_version)
    model.opset_import.extend([extra_opset])
    return model
36
37
class OrtPyFunction:
    """Callable wrapper that runs an ONNX model (or a single custom op) with ONNXRuntime.

    A model is attached with ``from_model``/``from_customop``. The ORT
    InferenceSession is created lazily on the first call, using session options
    that register this package's custom-op library.
    """

    # Class-level __name__ so instances expose a function-like name attribute.
    __name__ = 'OrtPyFunction'

    @classmethod
    def get_ort_session_options(cls):
        """Return a new SessionOptions with the custom-op library registered."""
        # ONNXRuntime has an issue to support reusing the SessionOptions object.
        # Create a new one every time here
        so = _ort.SessionOptions()
        so.register_custom_ops_library(get_library_path())
        return so

    def __init__(self):
        # Bound ModelProto; stays None until _bind() is called. The onnx_model
        # property reads this attribute, so it must exist from the start.
        self._oxml = None
        # Graph inputs/outputs of the bound model, captured by _bind().
        self.inputs = []
        self.outputs = []
        # Lazily created InferenceSession; cleared whenever a new model is bound.
        self.ort_session = None
        # Input name -> value fed instead of consuming a positional argument.
        self.default_inputs = {}

    def create_from_customop(self, op_type, *args, **kwargs):
        """Build a single-op graph for ``op_type``, bind it, and return self."""
        graph = SingleOpGraph.build_my_graph(op_type, *args, **kwargs)
        self._bind(make_onnx_model(graph))
        return self

    def add_default_input(self, **kwargs):
        """Register constant values for named graph inputs.

        numpy arrays and numpy scalars are stored unchanged. Any other value is
        turned into a uint8 array via ``list(value)`` (e.g. a ``bytes`` payload).
        """
        inputs = {
            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
            np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
        }

        self.default_inputs.update(inputs)

    @property
    def onnx_model(self):
        """The bound ModelProto; AssertionError if no model is attached."""
        assert self._oxml is not None, "No onnx model attached yet."
        return self._oxml

    @property
    def input_names(self):
        """Names of the bound graph's inputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.input]

    @property
    def output_names(self):
        """Names of the bound graph's outputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.output]

    def _bind(self, oxml):
        """Attach ``oxml`` as the model to run and return self."""
        self.inputs = list(oxml.graph.input)
        self.outputs = list(oxml.graph.output)
        self._oxml = oxml
        # A session built for a previously bound model must not be reused.
        self.ort_session = None
        return self

    def _ensure_ort_session(self):
        """Create the InferenceSession on first use and return it."""
        if self.ort_session is None:
            sess = _ort.InferenceSession(self.onnx_model.SerializeToString(), self.get_ort_session_options())
            self.ort_session = sess

        return self.ort_session

    @classmethod
    def from_customop(cls, op_type, *args, **kwargs):
        """Alternate constructor: wrap a single custom op of type ``op_type``."""
        return cls().create_from_customop(op_type, *args, **kwargs)

    @classmethod
    def from_model(cls, path_or_model, *args, **kwargs):
        """Alternate constructor from a model file path or an in-memory ModelProto."""
        return cls()._bind(onnx.load_model(path_or_model) if isinstance(path_or_model, str) else path_or_model)

    def _argument_map(self, *args, **kwargs):
        """Build the ORT feed dict from positional ``args`` and default inputs.

        Graph inputs present in ``default_inputs`` use that value. Every other
        input consumes the next positional argument, in graph order.

        :raises IndexError: if there are fewer positional arguments than graph
            inputs without a default.
        """
        idx = 0
        feed = {}
        for i_ in self.inputs:
            if i_.name in self.default_inputs:
                feed[i_.name] = self.default_inputs[i_.name]
                continue

            if idx >= len(args):
                raise IndexError(
                    f"missing positional argument for model input '{i_.name}' "
                    f"(got {len(args)} positional argument(s))")

            x = args[idx]
            ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
            if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64:
                # numpy may default to int32 while pytorch uses int64, so cast
                # automatically. asarray first so list/tuple inputs work too.
                ts_x = np.asarray(ts_x).astype(np.int64)
            feed[i_.name] = ts_x
            idx += 1

        # NOTE(review): keyword arguments are accepted but currently ignored.
        return feed

    def __call__(self, *args, **kwargs):
        """Run the bound model; return the single output or a tuple of outputs."""
        self._ensure_ort_session()
        outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
125
126
def optimize_model(model_or_file, output_file):
    """Apply ORT basic graph optimizations and save the optimized model.

    :param model_or_file: path to an ONNX model file, or an in-memory model
        object exposing ``SerializeToString()``.
    :param output_file: destination path assigned to the session's
        ``optimized_model_filepath``.
    """
    options = OrtPyFunction.get_ort_session_options()
    options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    options.optimized_model_filepath = output_file

    if isinstance(model_or_file, str):
        model_source = model_or_file
    else:
        model_source = model_or_file.SerializeToString()

    # The session itself is discarded: it is created only so that ORT applies
    # the optimizations and writes the result to optimized_model_filepath.
    _ort.InferenceSession(model_source, options)
133
134
# Public alias for onnxruntime's pybind ``Fail`` exception class, so callers can
# catch it without reaching into ``onnxruntime.capi`` themselves.
_pybind_state = _ort.capi.onnxruntime_pybind11_state
ONNXRuntimeError = _pybind_state.Fail
del _pybind_state
136