microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
leca/gqa

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

258 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper around the ONNX Runtime custom operator callback ABI, which
// is only used in the custom-op kernels. For general ORT C++ invocation, such as end-to-end
// testing, the public ONNX Runtime C++ APIs should be used, since there is no binary-compatibility requirement.
7
8#pragma once
#include <cassert>
#include <cstddef>
#include <cstdint>

#include <array>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
19
20#include "exceptions.h"
21#include "onnxruntime_no_customop.h"
22#include "onnxruntime_cpp_api_legacy.hpp"
23#include "onnxruntime_extensions.h"
24#include "custom_op_lite.h"
25
26#define MIN_ORT_VERSION_SUPPORTED 11
27
28namespace Ort {
29namespace Custom {
30
31template <typename... Args>
32struct FunctionKernel {
33 using ComputeFn = std::function<OrtStatusPtr(Args...)>;
34
35 OrtStatusPtr Compute(Args... args) const {
36 return compute_fn_(args...);
37 }
38
39 ComputeFn compute_fn_;
40};
41
// Trait: IsFunctionKernel<T>::type is std::true_type when T declares a nested `ComputeFn`
// type (as FunctionKernel does) and std::false_type otherwise. OrtLiteCustomStructV2 uses
// this tag to select the InitKernel overload for a newly created kernel.
//
// Primary template: selected when T has no nested ::ComputeFn type.
template <class, class = void>
struct IsFunctionKernel {
  using type = std::false_type;
};

// Partial specialization: selected when `typename T::ComputeFn` names a type.
template <class T>
struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> {
  using type = std::true_type;
};
53
54// Helper type
55template <typename T>
56struct ComputeArgsList;
57
58// Specialization for member function
59template <typename C, typename... Args>
60struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
61 using FunctionType = OrtStatusPtr (*)(Args...);
62 using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
63};
64
// Compile-time detectors for the optional static hooks a custom-op kernel may provide.
// Each trait derives from std::bool_constant<true> when `&Kernel::<Hook>` is a well-formed
// expression, and from std::bool_constant<false> otherwise.

// Detects Kernel::GetInputMemoryType.
template <class Kernel, class = void>
struct CustomOp_defined_getInputMemoryType : std::bool_constant<false> {};
template <class Kernel>
struct CustomOp_defined_getInputMemoryType<Kernel, std::void_t<decltype(&Kernel::GetInputMemoryType)>>
    : std::bool_constant<true> {};

// Detects Kernel::GetMayInplace.
template <class Kernel, class = void>
struct CustomOp_defined_getMayInplace : std::bool_constant<false> {};
template <class Kernel>
struct CustomOp_defined_getMayInplace<Kernel, std::void_t<decltype(&Kernel::GetMayInplace)>>
    : std::bool_constant<true> {};

// Detects Kernel::ReleaseMayInplace.
template <class Kernel, class = void>
struct CustomOp_defined_releaseMayInplace : std::bool_constant<false> {};
template <class Kernel>
struct CustomOp_defined_releaseMayInplace<Kernel, std::void_t<decltype(&Kernel::ReleaseMayInplace)>>
    : std::bool_constant<true> {};
82
83template <typename CustomOpKernel>
84struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
85 using ComputeFunction = decltype(&CustomOpKernel::Compute);
86 using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;
87
88 template <typename... Args>
89 using MemberComputeType = OrtStatusPtr (CustomOpKernel::*)(Args...) const;
90
91 struct KernelEx : public CustomOpKernel {
92 struct {
93 std::string ep_{};
94 std::unique_ptr<OrtW::CustomOpApi> api_;
95 } extra_;
96 };
97
98 template <typename T>
99 static OrtStatusPtr InitKernel(KernelEx& kernel,
100 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
101 return kernel.OnModelAttach(api, info);
102 }
103
104 static OrtStatusPtr InitKernel(
105 KernelEx& kernel,
106 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
107 kernel.compute_fn_ = fn;
108 return nullptr;
109 }
110
111 template <typename... Args>
112 void ParseArgs(MemberComputeType<Args...> fn) {
113 OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
114 }
115
116 // TODO: consider to disable these legacy functions for mobile build to save binary size
117 template <typename... Args>
118 void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
119 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
120 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
121 auto kernel = std::make_unique<KernelEx>();
122 typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
123 OrtStatusPtr status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
124 OrtW::ThrowOnError(*ort_api, status);
125
126 kernel->extra_.ep_ = self->execution_provider_;
127 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
128 return reinterpret_cast<void*>(kernel.release());
129 };
130
131 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
132 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
133 std::vector<TensorPtr> tensors;
134 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
135 context,
136 tensors,
137 kernel->extra_.api_->KernelContext_GetInputCount(context),
138 kernel->extra_.api_->KernelContext_GetOutputCount(context),
139 kernel->extra_.ep_);
140 std::apply([kernel](Args const&... t_args) {
141 auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(status); }, t);
142 };
143
144 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
145 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
146 };
147 }
148
149#if ORT_API_VERSION >= 16
150 template <typename... Args>
151 void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
152 OrtCustomOp::CreateKernel = nullptr;
153 OrtCustomOp::KernelCompute = nullptr;
154
155 if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) {
156 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType {
157 return CustomOpKernel::GetInputMemoryType(index);
158 };
159 }
160
161#if ORT_API_VERSION >= 18
162 if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
163 OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
164 return CustomOpKernel::GetMayInplace(input_index, output_index);
165 };
166 }
167 if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) {
168 OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void {
169 CustomOpKernel::ReleaseMayInplace(input_index, output_index);
170 };
171 }
172#endif
173
174 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
175 const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
176 if (api == nullptr) {
177 assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
178 // should never happened, what we can do?
179 return nullptr;
180 }
181
182 if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
183 return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
184 }
185
186 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
187 auto kernel = std::make_unique<KernelEx>();
188 if (kernel == nullptr) {
189 return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
190 }
191
192 typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
193 OrtStatusPtr status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
194 if (status == nullptr) {
195 kernel->extra_.ep_ = self->execution_provider_;
196 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
197 *op_kernel = reinterpret_cast<void*>(kernel.release());
198 }
199
200 return status;
201 };
202
203 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
204 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
205 std::vector<TensorPtr> tensors;
206 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
207 context,
208 tensors,
209 kernel->extra_.api_->KernelContext_GetInputCount(context),
210 kernel->extra_.api_->KernelContext_GetOutputCount(context),
211 kernel->extra_.ep_);
212 return std::apply([kernel](Args const&... t_args) { return kernel->Compute(t_args...); }, t);
213 };
214
215 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
216 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
217 };
218 }
219#endif // ORT_API_VERSION >= 16
220
221 OrtLiteCustomStructV2(const char* op_name,
222 const char* execution_provider,
223 RegularComputeType fn_compute = nullptr)
224 : OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
225 ParseArgs(&CustomOpKernel::Compute);
226
227#if ORT_API_VERSION >= 16
228 if (OrtCustomOp::version > 15) {
229 DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
230 } else
231#endif // ORT_API_VERSION >= 16
232 {
233 DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
234 }
235 }
236
237 RegularComputeType regular_fn_{};
238};
239
240template <typename... Args>
241OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
242 const char* execution_provider,
243 OrtStatusPtr (*custom_compute_fn)(Args...)) {
244 using LiteOp = OrtLiteCustomStructV2<FunctionKernel<Args...>>;
245 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
246}
247
248template <typename OpKernel>
249OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
250 const char* execution_provider) {
251 using LiteOp = OrtLiteCustomStructV2<OpKernel>;
252 return std::make_unique<LiteOp>(op_name, execution_provider).release();
253}
254
255} // namespace Custom
256} // namespace Ort
257
258namespace ortc = Ort::Custom;