microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
3e82549bcb53bc93d04c25d64f16a52d495725c1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/eager_op.py

111 lines · mode: code

1# Copyright (c) Microsoft Corporation. All rights reserved.
2# Licensed under the MIT License. See License.txt in the project root for
3# license information.
4###############################################################################
5
6import numpy as np
7import onnxruntime as _ort
8from ._ocos import default_opset_domain, get_library_path # noqa
9from ._cuops import * # noqa
10
11
12def _get_opset_version_from_ort():
13 _ORT_OPSET_SUPPORT_TABLE = {
14 "1.5": 11,
15 "1.6": 12,
16 "1.7": 13,
17 "1.8": 14
18 }
19
20 ort_ver_string = '.'.join(_ort.__version__.split('.')[0:2])
21 return _ORT_OPSET_SUPPORT_TABLE.get(ort_ver_string, 11)
22
23
class EagerOp:
    """Execute a custom op, or a whole ONNX model, eagerly with onnxruntime.

    Build an instance with :meth:`from_customop` or :meth:`from_model`, then
    call it like a function, passing the model inputs positionally. The
    onnxruntime ``InferenceSession`` is created lazily on the first call and
    cached until another model is bound.
    """

    @classmethod
    def get_ort_session_options(cls):
        """Return new ``SessionOptions`` with the custom-op library registered."""
        # ONNX Runtime has an issue with reusing a SessionOptions object,
        # so a new one is created every time.
        so = _ort.SessionOptions()
        so.register_custom_ops_library(get_library_path())
        return so

    def __init__(self):
        self._oxml = None          # bound ModelProto, read via `onnx_model`; set by _bind()
        self.inputs = []           # graph inputs of the bound model
        self.output = []           # graph outputs of the bound model
        self.ort_session = None    # created lazily by _ensure_ort_session()
        self.default_inputs = {}   # input name -> value fed when not given positionally

    def create_from_customop(self, op_type, *args, **kwargs):
        """Bind a single-op model for custom op ``op_type`` and return ``self``.

        ``args``/``kwargs`` are forwarded to ``SingleOpGraph.build_my_graph``.
        The model imports the ``ai.onnx`` opset chosen by
        ``_get_opset_version_from_ort()`` plus version 1 of the custom-op domain.
        """
        # NOTE(review): `onnx` and `SingleOpGraph` are not imported explicitly
        # here; presumably they come from `from ._cuops import *` -- confirm.
        graph = SingleOpGraph.build_my_graph(op_type, *args, **kwargs)
        model = onnx.helper.make_model(graph, opset_imports=[
            onnx.helper.make_operatorsetid('ai.onnx', _get_opset_version_from_ort()),
            onnx.helper.make_operatorsetid(default_opset_domain(), 1)])
        self._bind(model)
        return self

    def add_default_input(self, **kwargs):
        """Register values fed to the named model inputs on every call.

        numpy arrays and numpy scalars are stored as-is. Any other value is
        converted with ``np.asarray(list(value), dtype=np.uint8)``, which suits
        bytes-like values (each byte becomes one uint8 element).
        """
        inputs = {
            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else
            np.asarray(list(val_), dtype=np.uint8) for ky_, val_ in kwargs.items()
        }
        self.default_inputs.update(inputs)

    @property
    def onnx_model(self):
        """The bound ONNX ``ModelProto``; asserts that a model has been bound."""
        assert self._oxml is not None, "No onnx model attached yet."
        return self._oxml

    @property
    def input_names(self):
        """Names of the bound model's graph inputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.input]

    @property
    def output_names(self):
        """Names of the bound model's graph outputs, in graph order."""
        return [vi_.name for vi_ in self.onnx_model.graph.output]

    def _bind(self, oxml):
        """Attach ModelProto ``oxml`` and return ``self``.

        Any cached session was built for the previously bound model, so it is
        dropped here and recreated on the next call.
        """
        self.inputs = list(oxml.graph.input)
        self.output = list(oxml.graph.output)
        self._oxml = oxml
        self.ort_session = None
        return self

    def _ensure_ort_session(self):
        """Create the InferenceSession for the bound model on first use; return it."""
        if self.ort_session is None:
            sess = _ort.InferenceSession(self.onnx_model.SerializeToString(), self.get_ort_session_options())
            self.ort_session = sess

        return self.ort_session

    @classmethod
    def from_customop(cls, op_type, *args, **kwargs):
        """Create an instance bound to a single-op model for custom op ``op_type``."""
        return cls().create_from_customop(op_type, *args, **kwargs)

    @classmethod
    def from_model(cls, path_or_model, *args, **kwargs):
        """Create an instance bound to an ONNX model.

        Args:
            path_or_model: a file path (``str`` or ``os.PathLike``) loaded with
                ``onnx.load_model``, or an already-loaded model object.
            *args, **kwargs: accepted for signature compatibility; unused.
        """
        import os

        if isinstance(path_or_model, (str, os.PathLike)):
            # Hand onnx a plain str so path-based behaviour matches the str case.
            path_or_model = onnx.load_model(os.fspath(path_or_model))
        return cls()._bind(path_or_model)

    def _argument_map(self, *args, **kwargs):
        """Build the ``{input name: value}`` feed dict for ``InferenceSession.run``.

        Graph inputs are visited in order. An input registered with
        :meth:`add_default_input` takes its default value; every other input
        consumes the next positional argument. Python ``int``/``float``/``bool``
        values are wrapped with ``np.array``. ``kwargs`` are currently ignored.

        Raises:
            IndexError: there are fewer positional arguments than non-default inputs.
        """
        idx = 0
        feed = {}
        for i_ in self.inputs:
            if i_.name in self.default_inputs:
                feed[i_.name] = self.default_inputs[i_.name]
                continue

            if idx >= len(args):
                raise IndexError(
                    "No positional argument supplied for model input '{}' "
                    "(got {} positional argument(s))".format(i_.name, len(args)))
            x = args[idx]
            ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
            # numpy's default integer type can be int32 while pytorch uses
            # int64, so INT64 model inputs are cast automatically.
            if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64:
                # Lists/tuples have no .astype(); materialize them as arrays first.
                if not isinstance(ts_x, (np.ndarray, np.generic)):
                    ts_x = np.asarray(ts_x)
                ts_x = ts_x.astype(np.int64)
            feed[i_.name] = ts_x
            idx += 1

        return feed

    def __call__(self, *args, **kwargs):
        """Run the bound model; return the single output, or the list of outputs."""
        self._ensure_ort_session()
        outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
        return outputs[0] if len(outputs) == 1 else outputs
112