microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
fb2a8c28419f255fcd7283ba14bcac61e721d4e4

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

322 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper of the ONNX Runtime custom-operator callback ABI, used only by the
// custom-op kernels. For general ORT C++ usage (e.g. end-to-end testing), use the public
// ONNX Runtime C++ APIs instead, since no binary-compatibility requirement applies there.
7
8#pragma once
#include <array>
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "onnxruntime_c_api.h"
#include "exceptions.h"
#include "onnxruntime_cpp_api_legacy.hpp"
#include "onnxruntime_extensions.h"
#include "custom_op_lite.h"
23
// Minimum ONNX Runtime API version, judging by the name, that these wrappers support.
// It is not referenced anywhere in this header; confirm its consumers before changing it.
#define MIN_ORT_VERSION_SUPPORTED 11
25
26// namespace of ORT ABI Wrapper
27namespace OrtW {
28
29class API {
30 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
31 public:
32 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
33 static API self(ort_api);
34 return self;
35 }
36
37 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
38 return instance()->CreateStatus(code, msg);
39 }
40
41 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
42 instance()->ReleaseStatus(ptr);
43 }
44
45 template <typename T>
46 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
47
48 static void ThrowOnError(OrtStatusPtr ptr) {
49 OrtW::ThrowOnError(instance().api_, ptr);
50 }
51
52 private:
53 const OrtApi* operator->() const {
54 return &api_;
55 }
56
57 API(const OrtApi* api) : api_(*api) {
58 if (api == nullptr) {
59 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
60 }
61 }
62
63 const OrtApi& api_;
64};
65
66template <>
67inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
68 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
69}
70
71template <>
72inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
73 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
74}
75
76template <>
77inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
78 size_t size = 0;
79 std::string out;
80 // Feed nullptr for the data buffer to query the true size of the string attribute
81 OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
82 if (status == nullptr) {
83 out.resize(size);
84 status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
85 out.resize(size - 1); // remove the terminating character '\0'
86 }
87
88 if (status == nullptr) {
89 value = std::move(out);
90 }
91
92 return status;
93}
94
95template <class T>
96inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
97 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
98 // Ideally, we should know which kind of error code can be ignored, but it is not available now.
99 // Just ignore all of them.
100 API::ReleaseStatus(status);
101 }
102
103 return nullptr;
104}
105
106inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
107 return API::CreateStatus(code, msg);
108}
109
110inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
111 return API::CreateStatus(code, msg.c_str());
112}
113
114inline void ReleaseStatus(OrtStatusPtr& status) {
115 API::ReleaseStatus(status);
116 status = nullptr;
117}
118
119} // namespace OrtW
120
// Evaluates `expr` exactly once. If the resulting status is non-null (an error), the macro
// returns it from the enclosing function.
// The temporary is given an unusual name so that an `expr` mentioning a caller variable
// called `_status` is not captured by the macro's own local. With `auto`, such a
// self-referencing initializer would not even compile.
#define ORTX_RETURN_IF_ERROR(expr)                   \
  do {                                               \
    auto ortx_return_if_error_status_ = (expr);      \
    if (ortx_return_if_error_status_ != nullptr) {   \
      return ortx_return_if_error_status_;           \
    }                                                \
  } while (0)
128
129namespace Ort {
130namespace Custom {
131
132template <typename... Args>
133struct FunctionKernel {
134 using ComputeFn = std::function<OrtStatusPtr(Args...)>;
135
136 OrtStatusPtr Compute(Args... args) const {
137 return compute_fn_(args...);
138 }
139
140 ComputeFn compute_fn_;
141};
142
143// primary template handles types that have no nested ::type member:
144template <class, class = void>
145struct IsFunctionKernel {
146 typedef std::false_type type;
147};
148
149// specialization recognizes types that do have a nested ::type member:
150template <class T>
151struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> {
152 typedef std::true_type type;
153};
154
155// Helper type
156template <typename T>
157struct ComputeArgsList;
158
159// Specialization for member function
160template <typename C, typename... Args>
161struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
162 using FunctionType = OrtStatusPtr (*)(Args...);
163 using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
164};
165
166template <typename CustomOpKernel>
167struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
168 using ComputeFunction = decltype(&CustomOpKernel::Compute);
169 using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;
170
171 template <typename... Args>
172 using MemberComputeType = OrtStatusPtr (CustomOpKernel::*)(Args...) const;
173
174 struct KernelEx : public CustomOpKernel {
175 struct {
176 std::string ep_{};
177 std::unique_ptr<OrtW::CustomOpApi> api_;
178 } extra_;
179 };
180
181 template <typename T>
182 static OrtStatusPtr InitKernel(KernelEx& kernel,
183 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
184 return kernel.OnModelAttach(api, info);
185 }
186
187 static OrtStatusPtr InitKernel(
188 KernelEx& kernel,
189 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
190 kernel.compute_fn_ = fn;
191 return nullptr;
192 }
193
194 template <typename... Args>
195 void ParseArgs(MemberComputeType<Args...> fn) {
196 OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
197 }
198
199 // TODO: consider to disable these legacy functions for mobile build to save binary size
200 template <typename... Args>
201 void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
202 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
203 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
204 auto kernel = std::make_unique<KernelEx>();
205 typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
206 OrtStatusPtr status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
207 OrtW::ThrowOnError(*ort_api, status);
208
209 kernel->extra_.ep_ = self->execution_provider_;
210 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
211 return reinterpret_cast<void*>(kernel.release());
212 };
213
214 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
215 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
216 std::vector<TensorPtr> tensors;
217 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
218 context,
219 tensors,
220 kernel->extra_.api_->KernelContext_GetInputCount(context),
221 kernel->extra_.api_->KernelContext_GetOutputCount(context),
222 kernel->extra_.ep_);
223 std::apply([kernel](Args const&... t_args) {
224 auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(status); }, t);
225 };
226
227 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
228 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
229 };
230 }
231
232#if ORT_API_VERSION >= 16
233 template <typename... Args>
234 void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
235 OrtCustomOp::CreateKernel = nullptr;
236 OrtCustomOp::KernelCompute = nullptr;
237
238 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
239 const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
240 if (api == nullptr) {
241 assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
242 // should never happened, what we can do?
243 return nullptr;
244 }
245
246 if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
247 return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
248 }
249
250 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
251 auto kernel = std::make_unique<KernelEx>();
252 if (kernel == nullptr) {
253 return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
254 }
255
256 typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
257 OrtStatusPtr status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
258 if (status == nullptr) {
259 kernel->extra_.ep_ = self->execution_provider_;
260 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
261 *op_kernel = reinterpret_cast<void*>(kernel.release());
262 }
263
264 return status;
265 };
266
267 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
268 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
269 std::vector<TensorPtr> tensors;
270 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
271 context,
272 tensors,
273 kernel->extra_.api_->KernelContext_GetInputCount(context),
274 kernel->extra_.api_->KernelContext_GetOutputCount(context),
275 kernel->extra_.ep_);
276 return std::apply([kernel](Args const&... t_args) { return kernel->Compute(t_args...); }, t);
277 };
278
279 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
280 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
281 };
282 }
283#endif // ORT_API_VERSION >= 16
284
285 OrtLiteCustomStructV2(const char* op_name,
286 const char* execution_provider,
287 RegularComputeType fn_compute = nullptr)
288 : OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
289 ParseArgs(&CustomOpKernel::Compute);
290
291#if ORT_API_VERSION >= 16
292 if (OrtCustomOp::version > 15) {
293 DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
294 } else
295#endif // ORT_API_VERSION >= 16
296 {
297 DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
298 }
299 }
300
301 RegularComputeType regular_fn_{};
302};
303
304template <typename... Args>
305OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
306 const char* execution_provider,
307 OrtStatusPtr (*custom_compute_fn)(Args...)) {
308 using LiteOp = OrtLiteCustomStructV2<FunctionKernel<Args...>>;
309 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
310}
311
312template <typename OpKernel>
313OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
314 const char* execution_provider) {
315 using LiteOp = OrtLiteCustomStructV2<OpKernel>;
316 return std::make_unique<LiteOp>(op_name, execution_provider).release();
317}
318
319} // namespace Custom
320} // namespace Ort
321
322namespace ortc = Ort::Custom;
323