microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
edgchen1/retry_android_emulator_startup

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

329lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper around the ONNXRuntime Custom Operator Callback ABI, which
// is only used in the custom-op kernels. For general ORT C++ invocation, such as end-to-end
// testing, the ONNXRuntime public C++ APIs should be used since there is no binary-compatibility requirement.
7
8#pragma once
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <new>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "onnxruntime_c_api.h"
#include "exceptions.h"
#include "onnxruntime_cpp_api_legacy.hpp"
#include "onnxruntime_extensions.h"
#include "custom_op_lite.h"
23
24#define MIN_ORT_VERSION_SUPPORTED 11
25
26// namespace of ORT ABI Wrapper
27namespace OrtW {
28
29class API {
30 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
31 public:
32 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
33 static API self(ort_api);
34 return self;
35 }
36
37 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
38 return instance()->CreateStatus(code, msg);
39 }
40
41 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
42 instance()->ReleaseStatus(ptr);
43 }
44
45 template <typename T>
46 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
47
48 static void ThrowOnError(OrtStatusPtr ptr) {
49 OrtW::ThrowOnError(instance().api_, ptr);
50 }
51
52 private:
53 const OrtApi* operator->() const {
54 return &api_;
55 }
56
57 API(const OrtApi* api) : api_(*api) {
58 if (api == nullptr) {
59 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
60 }
61 }
62
63 const OrtApi& api_;
64};
65
66template <>
67inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
68 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
69}
70
71template <>
72inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
73 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
74}
75
76template <>
77inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
78 size_t size = 0;
79 std::string out;
80 // Feed nullptr for the data buffer to query the true size of the string attribute
81 OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
82 if (status == nullptr) {
83 out.resize(size);
84 status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
85 out.resize(size - 1); // remove the terminating character '\0'
86 }
87
88 if (status == nullptr) {
89 value = std::move(out);
90 }
91
92 return status;
93}
94
95template <class T>
96inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
97 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
98 // Ideally, we should know which kind of error code can be ignored, but it is not available now.
99 // Just ignore all of them.
100 API::ReleaseStatus(status);
101 }
102
103 return nullptr;
104}
105
106inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
107 return API::CreateStatus(code, msg);
108}
109
110inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
111 return API::CreateStatus(code, msg.c_str());
112}
113
114inline void ReleaseStatus(OrtStatusPtr& status) {
115 API::ReleaseStatus(status);
116 status = nullptr;
117}
118
119} // namespace OrtW
120
121#define ORTX_RETURN_IF_ERROR(expr) \
122 do { \
123 auto _status = (expr); \
124 if (_status != nullptr) { \
125 return _status; \
126 } \
127 } while (0)
128
129namespace Ort {
130namespace Custom {
131
132#ifdef USE_CUDA
133///////////////////////////////////////////////////////////////////////////
134// TODO: include the definition from the header file in ONNXRuntime
135struct CudaContext {};
136
137#endif // USE_CUDA
138
139template <typename... Args>
140struct FunctionKernel {
141 using ComputeFn = std::function<OrtStatusPtr(Args...)>;
142
143 OrtStatusPtr Compute(Args... args) const {
144 return compute_fn_(args...);
145 }
146
147 ComputeFn compute_fn_;
148};
149
150// primary template handles types that have no nested ::type member:
151template <class, class = void>
152struct IsFunctionKernel {
153 typedef std::false_type type;
154};
155
156// specialization recognizes types that do have a nested ::type member:
157template <class T>
158struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> {
159 typedef std::true_type type;
160};
161
162// Helper type
163template <typename T>
164struct ComputeArgsList;
165
166// Specialization for member function
167template <typename C, typename... Args>
168struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
169 using FunctionType = OrtStatusPtr (*)(Args...);
170 using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
171};
172
173template <typename CustomOpKernel>
174struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
175 using ComputeFunction = decltype(&CustomOpKernel::Compute);
176 using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;
177
178 template <typename... Args>
179 using MemberComputeType = OrtStatusPtr (CustomOpKernel::*)(Args...) const;
180
181 struct KernelEx : public CustomOpKernel {
182 struct {
183 std::string ep_{};
184 std::unique_ptr<OrtW::CustomOpApi> api_;
185 } extra_;
186 };
187
188 template <typename T>
189 static OrtStatusPtr InitKernel(KernelEx& kernel,
190 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
191 return kernel.OnModelAttach(api, info);
192 }
193
194 static OrtStatusPtr InitKernel(
195 KernelEx& kernel,
196 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
197 kernel.compute_fn_ = fn;
198 return nullptr;
199 }
200
201 template <typename... Args>
202 void ParseArgs(MemberComputeType<Args...> fn) {
203 OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
204 }
205
206 // TODO: consider to disable these legacy functions for mobile build to save binary size
207 template <typename... Args>
208 void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
209 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
210 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
211 auto kernel = std::make_unique<KernelEx>();
212 typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
213 OrtStatusPtr status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
214 OrtW::ThrowOnError(*ort_api, status);
215
216 kernel->extra_.ep_ = self->execution_provider_;
217 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
218 return reinterpret_cast<void*>(kernel.release());
219 };
220
221 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
222 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
223 std::vector<TensorPtr> tensors;
224 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
225 context,
226 tensors,
227 kernel->extra_.api_->KernelContext_GetInputCount(context),
228 kernel->extra_.api_->KernelContext_GetOutputCount(context),
229 kernel->extra_.ep_);
230 std::apply([kernel](Args const&... t_args) {
231 auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(status); }, t);
232 };
233
234 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
235 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
236 };
237 }
238
239#if ORT_API_VERSION > 15
240 template <typename... Args>
241 void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
242 OrtCustomOp::CreateKernel = nullptr;
243 OrtCustomOp::KernelCompute = nullptr;
244
245 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
246 const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
247 if (api == nullptr) {
248 assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
249 // should never happened, what we can do?
250 return nullptr;
251 }
252
253 if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
254 return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
255 }
256
257 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
258 auto kernel = std::make_unique<KernelEx>();
259 if (kernel == nullptr) {
260 return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
261 }
262
263 typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
264 OrtStatusPtr status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
265 if (status == nullptr) {
266 kernel->extra_.ep_ = self->execution_provider_;
267 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
268 *op_kernel = reinterpret_cast<void*>(kernel.release());
269 }
270
271 return status;
272 };
273
274 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
275 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
276 std::vector<TensorPtr> tensors;
277 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
278 context,
279 tensors,
280 kernel->extra_.api_->KernelContext_GetInputCount(context),
281 kernel->extra_.api_->KernelContext_GetOutputCount(context),
282 kernel->extra_.ep_);
283 return std::apply([kernel](Args const&... t_args) { return kernel->Compute(t_args...); }, t);
284 };
285
286 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
287 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
288 };
289 }
290#endif // ORT_API_VERSION > 15
291
292 OrtLiteCustomStructV2(const char* op_name,
293 const char* execution_provider,
294 RegularComputeType fn_compute = nullptr)
295 : OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
296 ParseArgs(&CustomOpKernel::Compute);
297
298#if ORT_API_VERSION > 15
299 if (OrtCustomOp::version > 15) {
300 DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
301 } else
302#endif // ORT_API_VERSION > 15
303 {
304 DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
305 }
306 }
307
308 RegularComputeType regular_fn_{};
309};
310
311template <typename... Args>
312OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
313 const char* execution_provider,
314 OrtStatusPtr (*custom_compute_fn)(Args...)) {
315 using LiteOp = OrtLiteCustomStructV2<FunctionKernel<Args...>>;
316 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
317}
318
319template <typename OpKernel>
320OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
321 const char* execution_provider) {
322 using LiteOp = OrtLiteCustomStructV2<OpKernel>;
323 return std::make_unique<LiteOp>(op_name, execution_provider).release();
324}
325
326} // namespace Custom
327} // namespace Ort
328
329namespace ortc = Ort::Custom;
330