microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
sayanshaw/triton-test

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/op_def_struct.h

246 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper around the ONNXRuntime Custom Operator callback ABI, which
// is only used in the custom-op kernels. For general ORT C++ invocation, such as end-to-end
// testing, use the public ONNXRuntime C++ APIs instead, since there is no binary-compatibility requirement there.
7
8#pragma once
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <array>
#include <functional>
#include <memory>
#include <new>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "exceptions.h"
#include "onnxruntime_extensions.h"
#include "custom_op/custom_op_lite.h"
23
// Presumably the oldest ORT API version these wrappers are meant to support (inferred from
// the name only). NOTE(review): not referenced anywhere in this header; confirm where it is checked.
#define MIN_ORT_VERSION_SUPPORTED 11
25
26namespace Ort {
27namespace Custom {
28
29template <typename T>
30inline OrtStatusPtr ToApiStatus(const T& status) {
31 return status.CreateOrtStatus();
32}
33
34template <>
35inline OrtStatusPtr ToApiStatus(const OrtStatusPtr& status) {
36 return status;
37}
38
// Kernel type that wraps a free compute function so it can be driven by the same machinery
// as class-style kernels. The nested `ComputeFn` alias is what IsFunctionKernel detects.
//
// RType   - return type of the wrapped function.
// Args... - exact parameter list of the wrapped function; Compute mirrors it so that
//           ComputeArgsList / ParseArgs can recover the op's input/output types.
template <typename RType, typename... Args>
struct FunctionKernel {
  using ComputeFn = std::function<RType(Args...)>;

  // Invokes the stored function. Arguments are forwarded rather than re-passed as lvalues:
  // reference parameters are unaffected, by-value parameters are moved instead of copied a
  // second time, and move-only by-value parameters become usable at all.
  RType Compute(Args... args) const {
    return compute_fn_(std::forward<Args>(args)...);
  }

  // The wrapped function; empty until assigned (invoking an empty std::function throws
  // std::bad_function_call).
  ComputeFn compute_fn_;
};
49
// Detects whether a kernel type is a FunctionKernel, i.e. whether it declares a nested
// `ComputeFn` type. The answer is exposed as the nested `type` alias (std::true_type or
// std::false_type), which OrtLiteCustomStructV2 uses to pick an InitKernel overload.
//
// Primary template: selected for types that do NOT declare a nested ComputeFn.
template <class, class = void>
struct IsFunctionKernel {
  using type = std::false_type;
};

// Partial specialization: selected (via std::void_t SFINAE) for types that DO declare a
// nested ComputeFn.
template <class T>
struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> {
  using type = std::true_type;
};
61
// Trait that takes apart the type of a kernel's `Compute` member-function pointer.
// Only declared for the general case, so anything other than a const member-function
// pointer is an incomplete type and fails to compile when its members are used.
template <typename T>
struct ComputeArgsList;

// Matches `RType (C::*)(Args...) const`.
template <typename RType, typename C, typename... Args>
struct ComputeArgsList<RType (C::*)(Args...) const> {
  using ResultType = RType;                                       // Compute's return type
  using FunctionType = ResultType (*)(Args...);                   // equivalent free-function pointer
  using MemberFunctionType = ResultType (C::*)(Args...) const;    // the matched pointer type itself
};
73
// Compile-time check for an optional `GetInputMemoryType` member on a kernel type.
// `value` is true exactly when `&Kernel::GetInputMemoryType` is a well-formed expression.
template <typename Kernel, typename = void>
struct CustomOp_defined_getInputMemoryType : std::bool_constant<false> {};

template <typename Kernel>
struct CustomOp_defined_getInputMemoryType<Kernel, std::void_t<decltype(&Kernel::GetInputMemoryType)>>
    : std::bool_constant<true> {};
79
80template <typename CustomOpKernel>
81struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
82 using ComputeFunction = decltype(&CustomOpKernel::Compute);
83 using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;
84 using RType = typename ComputeArgsList<ComputeFunction>::ResultType;
85
86 template <typename... Args>
87 using MemberComputeType = RType (CustomOpKernel::*)(Args...) const;
88
89 struct KernelEx : public CustomOpKernel {
90 struct {
91 std::string ep_{};
92 std::unique_ptr<OrtW::CustomOpApi> api_;
93 } extra_;
94 };
95
96 template <typename T>
97 static OrtStatusPtr InitKernel(KernelEx& kernel,
98 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
99 auto status = kernel.OnModelAttach(api, info);
100 return ToApiStatus(status);
101 }
102
103 static OrtStatusPtr InitKernel(
104 KernelEx& kernel,
105 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
106 kernel.compute_fn_ = fn;
107 return nullptr;
108 }
109
110 template <typename... Args>
111 void ParseArgs(MemberComputeType<Args...> fn) {
112 OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
113 }
114
115 // TODO: consider to disable these legacy functions for mobile build to save binary size
116 template <typename... Args>
117 void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
118 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
119 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
120 auto kernel = std::make_unique<KernelEx>();
121 typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
122 auto status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
123 OrtW::ThrowOnError(*ort_api, status);
124
125 kernel->extra_.ep_ = self->execution_provider_;
126 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
127 return reinterpret_cast<void*>(kernel.release());
128 };
129
130 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
131 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
132 std::vector<TensorPtr> tensors;
133 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
134 context,
135 tensors,
136 kernel->extra_.api_->KernelContext_GetInputCount(context),
137 kernel->extra_.api_->KernelContext_GetOutputCount(context),
138 kernel->extra_.ep_);
139 std::apply([kernel](Args const&... t_args) {
140 auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(ToApiStatus(status)); }, t);
141 };
142
143 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
144 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
145 };
146 }
147
148#if ORT_API_VERSION >= 16
149 template <typename... Args>
150 void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
151 OrtCustomOp::CreateKernel = nullptr;
152 OrtCustomOp::KernelCompute = nullptr;
153
154 if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) {
155 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType {
156 return CustomOpKernel::GetInputMemoryType(index);
157 };
158 }
159
160 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
161 const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
162 if (api == nullptr) {
163 assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
164 // should never happened, what we can do?
165 return nullptr;
166 }
167
168 if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
169 return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
170 }
171
172 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
173 auto kernel = std::make_unique<KernelEx>();
174 if (kernel == nullptr) {
175 return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
176 }
177
178 typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
179 auto status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
180 if (status == nullptr) {
181 kernel->extra_.ep_ = self->execution_provider_;
182 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
183 *op_kernel = reinterpret_cast<void*>(kernel.release());
184 }
185
186 return status;
187 };
188
189 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
190 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
191 std::vector<TensorPtr> tensors;
192 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
193 context,
194 tensors,
195 kernel->extra_.api_->KernelContext_GetInputCount(context),
196 kernel->extra_.api_->KernelContext_GetOutputCount(context),
197 kernel->extra_.ep_);
198 return std::apply([kernel](Args const&... t_args) {
199 auto status = kernel->Compute(t_args...);
200 return ToApiStatus(status); }, t);
201 };
202
203 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
204 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
205 };
206 }
207#endif // ORT_API_VERSION >= 16
208
209 OrtLiteCustomStructV2(const char* op_name,
210 const char* execution_provider,
211 RegularComputeType fn_compute = nullptr)
212 : OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
213 ParseArgs(&CustomOpKernel::Compute);
214
215#if ORT_API_VERSION >= 16
216 if (OrtCustomOp::version >= 16) {
217 DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
218 } else
219#endif // ORT_API_VERSION >= 16
220 {
221 DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
222 }
223 }
224
225 RegularComputeType regular_fn_{};
226};
227
228template <typename RType, typename... Args>
229OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
230 const char* execution_provider,
231 RType (*custom_compute_fn)(Args...)) {
232 using LiteOp = OrtLiteCustomStructV2<FunctionKernel<RType, Args...>>;
233 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
234}
235
236template <typename OpKernel>
237OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
238 const char* execution_provider) {
239 using LiteOp = OrtLiteCustomStructV2<OpKernel>;
240 return std::make_unique<LiteOp>(op_name, execution_provider).release();
241}
242
243} // namespace Custom
244} // namespace Ort
245
246namespace ortc = Ort::Custom;
247