microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
cmake44

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/op_def_struct.h

306lines · modeblame

c599b00dWenbing Li3 years ago1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper over the ONNX Runtime custom-operator callback ABI, intended
// only for use inside custom-op kernels. For general ONNX Runtime C++ usage, such as
// end-to-end testing, use the public ONNX Runtime C++ API instead, since no binary
// compatibility is required there.
7
#pragma once

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <array>
#include <functional>
#include <memory>
#include <new>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "exceptions.h"
#include "onnxruntime_extensions.h"
#include "custom_op/custom_op_lite.h"

#define MIN_ORT_VERSION_SUPPORTED 11
8f36cf32Tang, Cheng3 years ago25
914509d5Wenbing Li2 years ago26namespace Ort {
27namespace Custom {
28
f9290e8bWenbing Li2 years ago29template <typename T>
30inline OrtStatusPtr ToApiStatus(const T& status) {
97ee9eb5Wenbing Li2 years ago31return (OrtStatus*)status;
f9290e8bWenbing Li2 years ago32}
33
34template <>
35inline OrtStatusPtr ToApiStatus(const OrtStatusPtr& status) {
36return status;
37}
38
/// Adapts a free compute function to the kernel interface expected by
/// OrtLiteCustomStructV2 (a const `Compute` member plus a nested `ComputeFn`).
///
/// @tparam RType return type of the wrapped function.
/// @tparam Args  parameter types of the wrapped function; they also form the
///               parameter list of Compute.
template <typename RType, typename... Args>
struct FunctionKernel {
  using ComputeFn = std::function<RType(Args...)>;

  /// Invokes the stored function.
  /// Arguments are forwarded with their declared value category: reference
  /// parameters stay references and by-value parameters (which Compute owns)
  /// are moved, so move-only parameter types work and no extra copies are made.
  RType Compute(Args... args) const {
    return compute_fn_(std::forward<Args>(args)...);
  }

  ComputeFn compute_fn_;
};
49
// Detects function-style kernels (FunctionKernel), which are identified by a
// nested `ComputeFn` type. Exposes `::type` (std::true_type / std::false_type,
// suitable for tag dispatch) as well as `::value`.

// Primary template: chosen for types that have no nested ::ComputeFn type.
template <class, class = void>
struct IsFunctionKernel : std::false_type {};

// Partial specialization: chosen when T::ComputeFn names a type.
template <class T>
struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> : std::true_type {};
61
// Decomposes the type of a kernel's `Compute` member-function pointer.
// Only const member functions (with or without noexcept) are supported; for
// any other type the primary template stays undefined, so using it fails to
// compile.
template <typename T>
struct ComputeArgsList;

// Specialization for a const member function.
template <typename RType, typename C, typename... Args>
struct ComputeArgsList<RType (C::*)(Args...) const> {
  using FunctionType = RType (*)(Args...);  // free-function pointer with the same signature
  using MemberFunctionType = RType (C::*)(Args...) const;
  using ResultType = RType;
};

// Specialization for a const noexcept member function. Since C++17 noexcept is
// part of the function type, so `Compute(...) const noexcept` would otherwise
// match no specialization. FunctionType is kept non-noexcept so it is the same
// type as for the specialization above.
template <typename RType, typename C, typename... Args>
struct ComputeArgsList<RType (C::*)(Args...) const noexcept> {
  using FunctionType = RType (*)(Args...);
  using MemberFunctionType = RType (C::*)(Args...) const noexcept;
  using ResultType = RType;
};
73
// Compile-time check for a member `OnModelAttach` callable with a given
// signature. The second template argument must be a function type
// `Ret(Args...)`; instantiating this primary template with anything else
// trips the static_assert.
template <typename, typename Signature>
struct HasOnModelAttach {
  static_assert(
      std::integral_constant<Signature, false>::value,
      "Second template parameter needs to be of function type.");
};

// Function-type form: `value` is true iff
// `std::declval<C>().OnModelAttach(std::declval<Args>()...)` is well-formed and
// its type is exactly `Ret` (a convertible return type does not count).
template <typename C, typename Ret, typename... Args>
struct HasOnModelAttach<C, Ret(Args...)> {
 private:
  // Viable only when the call expression is valid; reports whether its type is Ret.
  template <typename U>
  static constexpr auto probe(U*)
      -> std::is_same<decltype(std::declval<U>().OnModelAttach(std::declval<Args>()...)), Ret>;

  // Fallback when no matching OnModelAttach exists.
  template <typename>
  static constexpr std::false_type probe(...);

  using result = decltype(probe<C>(nullptr));

 public:
  static constexpr bool value = result::value;
};
102
// Detection traits for optional static hooks on a kernel type. Each trait is
// true exactly when taking the member's address (`&T::Name`) is well-formed;
// an absent or overloaded name yields false.

// T::GetInputMemoryType
template <typename T, typename = void>
struct CustomOp_defined_getInputMemoryType : std::bool_constant<false> {};

template <typename T>
struct CustomOp_defined_getInputMemoryType<T, std::void_t<decltype(&T::GetInputMemoryType)>>
    : std::bool_constant<true> {};

// T::GetMayInplace
template <typename T, typename = void>
struct CustomOp_defined_getMayInplace : std::bool_constant<false> {};

template <typename T>
struct CustomOp_defined_getMayInplace<T, std::void_t<decltype(&T::GetMayInplace)>>
    : std::bool_constant<true> {};

// T::ReleaseMayInplace
template <typename T, typename = void>
struct CustomOp_defined_releaseMayInplace : std::bool_constant<false> {};

template <typename T>
struct CustomOp_defined_releaseMayInplace<T, std::void_t<decltype(&T::ReleaseMayInplace)>>
    : std::bool_constant<true> {};
120
914509d5Wenbing Li2 years ago121template <typename CustomOpKernel>
122struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
123using ComputeFunction = decltype(&CustomOpKernel::Compute);
124using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;
f9290e8bWenbing Li2 years ago125using RType = typename ComputeArgsList<ComputeFunction>::ResultType;
914509d5Wenbing Li2 years ago126
127template <typename... Args>
f9290e8bWenbing Li2 years ago128using MemberComputeType = RType (CustomOpKernel::*)(Args...) const;
914509d5Wenbing Li2 years ago129
130struct KernelEx : public CustomOpKernel {
131struct {
132std::string ep_{};
133std::unique_ptr<OrtW::CustomOpApi> api_;
134} extra_;
135};
136
137template <typename T>
138static OrtStatusPtr InitKernel(KernelEx& kernel,
a0c26255Wenbing Li2 years ago139const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
3b889fc4Tang, Cheng2 years ago140if constexpr (HasOnModelAttach<KernelEx, OrtStatusPtr(const OrtApi&, const OrtKernelInfo&)>::value){
141auto status = kernel.OnModelAttach(api, info);
142return ToApiStatus(status);
143}
144else {
145auto status = kernel.OnModelAttach(OrtAttributeReader(api, info));
146return ToApiStatus(status);
147}
914509d5Wenbing Li2 years ago148}
149
150static OrtStatusPtr InitKernel(
a0c26255Wenbing Li2 years ago151KernelEx& kernel,
152const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
914509d5Wenbing Li2 years ago153kernel.compute_fn_ = fn;
154return nullptr;
155}
156
157template <typename... Args>
158void ParseArgs(MemberComputeType<Args...> fn) {
159OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
160}
161
162// TODO: consider to disable these legacy functions for mobile build to save binary size
163template <typename... Args>
164void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
165OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
166auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
167auto kernel = std::make_unique<KernelEx>();
168typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
f9290e8bWenbing Li2 years ago169auto status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
914509d5Wenbing Li2 years ago170OrtW::ThrowOnError(*ort_api, status);
171
172kernel->extra_.ep_ = self->execution_provider_;
173kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
174return reinterpret_cast<void*>(kernel.release());
175};
176
177OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
178auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
179std::vector<TensorPtr> tensors;
180auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
181context,
182tensors,
183kernel->extra_.api_->KernelContext_GetInputCount(context),
184kernel->extra_.api_->KernelContext_GetOutputCount(context),
185kernel->extra_.ep_);
186std::apply([kernel](Args const&... t_args) {
f9290e8bWenbing Li2 years ago187auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(ToApiStatus(status)); }, t);
914509d5Wenbing Li2 years ago188};
189
190OrtCustomOp::KernelDestroy = [](void* op_kernel) {
191std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
192};
193}
194
cd7d6459Wenbing Li2 years ago195#if ORT_API_VERSION >= 16
914509d5Wenbing Li2 years ago196template <typename... Args>
197void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
198OrtCustomOp::CreateKernel = nullptr;
199OrtCustomOp::KernelCompute = nullptr;
200
22340011cao lei2 years ago201if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) {
202OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType {
203return CustomOpKernel::GetInputMemoryType(index);
204};
205}
206
95d65e4ecao lei1 years ago207#if ORT_API_VERSION >= 18
208if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) {
209OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t {
210return CustomOpKernel::GetMayInplace(input_index, output_index);
211};
212}
213if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) {
214OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void {
215CustomOpKernel::ReleaseMayInplace(input_index, output_index);
216};
217}
218#endif
219
914509d5Wenbing Li2 years ago220OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
221const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
222if (api == nullptr) {
223assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
224// should never happened, what we can do?
225return nullptr;
226}
227
228if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
229return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
230}
231
232auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
233auto kernel = std::make_unique<KernelEx>();
234if (kernel == nullptr) {
235return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
236}
237
238typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
f9290e8bWenbing Li2 years ago239auto status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
914509d5Wenbing Li2 years ago240if (status == nullptr) {
241kernel->extra_.ep_ = self->execution_provider_;
242kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
243*op_kernel = reinterpret_cast<void*>(kernel.release());
244}
245
246return status;
247};
248
249OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
a0c26255Wenbing Li2 years ago250auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
914509d5Wenbing Li2 years ago251std::vector<TensorPtr> tensors;
252auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
253context,
254tensors,
255kernel->extra_.api_->KernelContext_GetInputCount(context),
256kernel->extra_.api_->KernelContext_GetOutputCount(context),
257kernel->extra_.ep_);
f9290e8bWenbing Li2 years ago258return std::apply([kernel](Args const&... t_args) {
259auto status = kernel->Compute(t_args...);
260return ToApiStatus(status); }, t);
914509d5Wenbing Li2 years ago261};
262
263OrtCustomOp::KernelDestroy = [](void* op_kernel) {
264std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
265};
266}
cd7d6459Wenbing Li2 years ago267#endif // ORT_API_VERSION >= 16
914509d5Wenbing Li2 years ago268
269OrtLiteCustomStructV2(const char* op_name,
270const char* execution_provider,
271RegularComputeType fn_compute = nullptr)
272: OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
273ParseArgs(&CustomOpKernel::Compute);
274
cd7d6459Wenbing Li2 years ago275#if ORT_API_VERSION >= 16
f9290e8bWenbing Li2 years ago276if (OrtCustomOp::version >= 16) {
914509d5Wenbing Li2 years ago277DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
278} else
cd7d6459Wenbing Li2 years ago279#endif // ORT_API_VERSION >= 16
914509d5Wenbing Li2 years ago280{
281DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
282}
283}
284
285RegularComputeType regular_fn_{};
286};
287
f9290e8bWenbing Li2 years ago288template <typename RType, typename... Args>
30609b74yunmengxie1 years ago289std::shared_ptr<OrtLiteCustomOp> CreateLiteCustomOpV2(const char* op_name,
914509d5Wenbing Li2 years ago290const char* execution_provider,
f9290e8bWenbing Li2 years ago291RType (*custom_compute_fn)(Args...)) {
292using LiteOp = OrtLiteCustomStructV2<FunctionKernel<RType, Args...>>;
30609b74yunmengxie1 years ago293return std::make_shared<LiteOp>(op_name, execution_provider, custom_compute_fn);
914509d5Wenbing Li2 years ago294}
295
296template <typename OpKernel>
30609b74yunmengxie1 years ago297std::shared_ptr<OrtLiteCustomOp> CreateLiteCustomOpV2(const char* op_name,
914509d5Wenbing Li2 years ago298const char* execution_provider) {
299using LiteOp = OrtLiteCustomStructV2<OpKernel>;
30609b74yunmengxie1 years ago300return std::make_shared<LiteOp>(op_name, execution_provider);
914509d5Wenbing Li2 years ago301}
302
303} // namespace Custom
304} // namespace Ort
305
8f36cf32Tang, Cheng3 years ago306namespace ortc = Ort::Custom;