microsoft/onnxruntime-extensions
Publicmirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable
include/op_def_struct.h
306lines · modecode
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | // A very thin wrapper of the ONNXRuntime Custom Operator Callback ABI, which |
| 5 | // is only used in the custom-op kernels. For general ORT C++ invocation, such as end-to-end |
| 6 | // testing, the ONNXRuntime public C++ APIs should be used since there is no binary-compatibility requirement. |
| 7 | |
| 8 | #pragma once |
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <array>
#include <functional>
#include <memory>
#include <new>
#include <optional>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
| 19 | |
| 20 | #include "exceptions.h" |
| 21 | #include "onnxruntime_extensions.h" |
| 22 | #include "custom_op/custom_op_lite.h" |
| 23 | |
| 24 | #define MIN_ORT_VERSION_SUPPORTED 11 |
| 25 | |
| 26 | namespace Ort { |
| 27 | namespace Custom { |
| 28 | |
| 29 | template <typename T> |
| 30 | inline OrtStatusPtr ToApiStatus(const T& status) { |
| 31 | return (OrtStatus*)status; |
| 32 | } |
| 33 | |
| 34 | template <> |
| 35 | inline OrtStatusPtr ToApiStatus(const OrtStatusPtr& status) { |
| 36 | return status; |
| 37 | } |
| 38 | |
// Gives a free function (or any callable) the shape of a kernel struct: a
// `Compute` member with the same signature plus a nested `ComputeFn` type,
// which IsFunctionKernel uses to recognise this wrapper.
//
// RType - result type of the wrapped callable.
// Args  - parameter types of the wrapped callable, exactly as declared.
template <typename RType, typename... Args>
struct FunctionKernel {
  using ComputeFn = std::function<RType(Args...)>;

  // Invokes the stored callable and returns its result.
  // Each argument is forwarded with its declared value category, so by-value
  // parameters are moved rather than copied a second time, and move-only or
  // rvalue-reference parameter types are accepted.
  // Calling this while `compute_fn_` is empty throws std::bad_function_call.
  RType Compute(Args... args) const {
    return compute_fn_(std::forward<Args>(args)...);
  }

  // The wrapped callable; must be assigned before Compute is called.
  ComputeFn compute_fn_;
};
| 49 | |
// Primary template: selected for any T that does not expose a nested
// `ComputeFn` type; its `type` is std::false_type.
template <class T, class Enable = void>
struct IsFunctionKernel {
  using type = std::false_type;
};

// Partial specialization: selected when `typename T::ComputeFn` is well-formed
// (as it is for FunctionKernel); its `type` is std::true_type.
template <class T>
struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> {
  using type = std::true_type;
};
| 61 | |
// Signature decomposer for a kernel's Compute method. Only the
// specialization below is defined, so any other type is a compile error.
template <typename MemberFn>
struct ComputeArgsList;

// Matches a const-qualified member function pointer and exposes its result
// type, the pointer-to-member type itself, and the equivalent free-function
// pointer type taking the same parameters.
template <typename Result, typename Owner, typename... Params>
struct ComputeArgsList<Result (Owner::*)(Params...) const> {
  using ResultType = Result;
  using MemberFunctionType = Result (Owner::*)(Params...) const;
  using FunctionType = Result (*)(Params...);
};
| 73 | |
// Primary template: only reached when the second argument is not a function
// type, in which case the static_assert rejects the instantiation.
template <typename, typename T>
struct HasOnModelAttach {
  static_assert(
      std::integral_constant<T, false>::value,
      "Second template parameter needs to be of function type.");
};

// Detects whether an object of type C can be called as
// `c.OnModelAttach(Args...)` and whether that call yields exactly `Ret`.
// `value` is true only when both conditions hold.
template <typename C, typename Ret, typename... Args>
struct HasOnModelAttach<C, Ret(Args...)> {
 private:
  // Viable only if the call expression compiles; the resulting type then
  // records whether the call's return type is exactly Ret.
  template <typename U>
  static constexpr auto Probe(U*) ->
      typename std::is_same<decltype(std::declval<U>().OnModelAttach(std::declval<Args>()...)), Ret>::type;

  // Fallback picked when the call expression above is ill-formed.
  template <typename>
  static constexpr std::false_type Probe(...);

  using Outcome = decltype(Probe<C>(nullptr));

 public:
  static constexpr bool value = Outcome::value;
};
| 102 | |
// True when `&Kernel::GetInputMemoryType` is a well-formed expression,
// false otherwise.
template <typename Kernel, typename Enable = void>
struct CustomOp_defined_getInputMemoryType : std::bool_constant<false> {};

template <typename Kernel>
struct CustomOp_defined_getInputMemoryType<Kernel, std::void_t<decltype(&Kernel::GetInputMemoryType)>>
    : std::bool_constant<true> {};
| 108 | |
// True when `&Kernel::GetMayInplace` is a well-formed expression,
// false otherwise.
template <typename Kernel, typename Enable = void>
struct CustomOp_defined_getMayInplace : std::bool_constant<false> {};

template <typename Kernel>
struct CustomOp_defined_getMayInplace<Kernel, std::void_t<decltype(&Kernel::GetMayInplace)>>
    : std::bool_constant<true> {};
| 114 | |
// True when `&Kernel::ReleaseMayInplace` is a well-formed expression,
// false otherwise.
template <typename Kernel, typename Enable = void>
struct CustomOp_defined_releaseMayInplace : std::bool_constant<false> {};

template <typename Kernel>
struct CustomOp_defined_releaseMayInplace<Kernel, std::void_t<decltype(&Kernel::ReleaseMayInplace)>>
    : std::bool_constant<true> {};
| 120 | |
| 121 | template <typename CustomOpKernel> |
| 122 | struct OrtLiteCustomStructV2 : public OrtLiteCustomOp { |
| 123 | using ComputeFunction = decltype(&CustomOpKernel::Compute); |
| 124 | using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType; |
| 125 | using RType = typename ComputeArgsList<ComputeFunction>::ResultType; |
| 126 | |
| 127 | template <typename... Args> |
| 128 | using MemberComputeType = RType (CustomOpKernel::*)(Args...) const; |
| 129 | |
| 130 | struct KernelEx : public CustomOpKernel { |
| 131 | struct { |
| 132 | std::string ep_{}; |
| 133 | std::unique_ptr<OrtW::CustomOpApi> api_; |
| 134 | } extra_; |
| 135 | }; |
| 136 | |
| 137 | template <typename T> |
| 138 | static OrtStatusPtr InitKernel(KernelEx& kernel, |
| 139 | const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) { |
| 140 | if constexpr (HasOnModelAttach<KernelEx, OrtStatusPtr(const OrtApi&, const OrtKernelInfo&)>::value){ |
| 141 | auto status = kernel.OnModelAttach(api, info); |
| 142 | return ToApiStatus(status); |
| 143 | } |
| 144 | else { |
| 145 | auto status = kernel.OnModelAttach(OrtAttributeReader(api, info)); |
| 146 | return ToApiStatus(status); |
| 147 | } |
| 148 | } |
| 149 | |
| 150 | static OrtStatusPtr InitKernel( |
| 151 | KernelEx& kernel, |
| 152 | const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) { |
| 153 | kernel.compute_fn_ = fn; |
| 154 | return nullptr; |
| 155 | } |
| 156 | |
| 157 | template <typename... Args> |
| 158 | void ParseArgs(MemberComputeType<Args...> fn) { |
| 159 | OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_); |
| 160 | } |
| 161 | |
| 162 | // TODO: consider to disable these legacy functions for mobile build to save binary size |
| 163 | template <typename... Args> |
| 164 | void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) { |
| 165 | OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { |
| 166 | auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_); |
| 167 | auto kernel = std::make_unique<KernelEx>(); |
| 168 | typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag; |
| 169 | auto status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag()); |
| 170 | OrtW::ThrowOnError(*ort_api, status); |
| 171 | |
| 172 | kernel->extra_.ep_ = self->execution_provider_; |
| 173 | kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api); |
| 174 | return reinterpret_cast<void*>(kernel.release()); |
| 175 | }; |
| 176 | |
| 177 | OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { |
| 178 | auto kernel = reinterpret_cast<KernelEx*>(op_kernel); |
| 179 | std::vector<TensorPtr> tensors; |
| 180 | auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(), |
| 181 | context, |
| 182 | tensors, |
| 183 | kernel->extra_.api_->KernelContext_GetInputCount(context), |
| 184 | kernel->extra_.api_->KernelContext_GetOutputCount(context), |
| 185 | kernel->extra_.ep_); |
| 186 | std::apply([kernel](Args const&... t_args) { |
| 187 | auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(ToApiStatus(status)); }, t); |
| 188 | }; |
| 189 | |
| 190 | OrtCustomOp::KernelDestroy = [](void* op_kernel) { |
| 191 | std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset(); |
| 192 | }; |
| 193 | } |
| 194 | |
| 195 | #if ORT_API_VERSION >= 16 |
| 196 | template <typename... Args> |
| 197 | void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) { |
| 198 | OrtCustomOp::CreateKernel = nullptr; |
| 199 | OrtCustomOp::KernelCompute = nullptr; |
| 200 | |
| 201 | if constexpr (CustomOp_defined_getInputMemoryType<CustomOpKernel>::value) { |
| 202 | OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp* /*this_*/, size_t index) -> OrtMemType { |
| 203 | return CustomOpKernel::GetInputMemoryType(index); |
| 204 | }; |
| 205 | } |
| 206 | |
| 207 | #if ORT_API_VERSION >= 18 |
| 208 | if constexpr (CustomOp_defined_getMayInplace<CustomOpKernel>::value) { |
| 209 | OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) -> size_t { |
| 210 | return CustomOpKernel::GetMayInplace(input_index, output_index); |
| 211 | }; |
| 212 | } |
| 213 | if constexpr (CustomOp_defined_releaseMayInplace<CustomOpKernel>::value) { |
| 214 | OrtCustomOp::ReleaseMayInplace = [](int* input_index, int* output_index) -> void { |
| 215 | CustomOpKernel::ReleaseMayInplace(input_index, output_index); |
| 216 | }; |
| 217 | } |
| 218 | #endif |
| 219 | |
| 220 | OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_, |
| 221 | const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr { |
| 222 | if (api == nullptr) { |
| 223 | assert(false && "Got a null pointer for ORT api on calling CreateKernelV2"); |
| 224 | // should never happened, what we can do? |
| 225 | return nullptr; |
| 226 | } |
| 227 | |
| 228 | if (this_ == nullptr || info == nullptr || op_kernel == nullptr) { |
| 229 | return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer"); |
| 230 | } |
| 231 | |
| 232 | auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_); |
| 233 | auto kernel = std::make_unique<KernelEx>(); |
| 234 | if (kernel == nullptr) { |
| 235 | return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?"); |
| 236 | } |
| 237 | |
| 238 | typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type; |
| 239 | auto status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type()); |
| 240 | if (status == nullptr) { |
| 241 | kernel->extra_.ep_ = self->execution_provider_; |
| 242 | kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api); |
| 243 | *op_kernel = reinterpret_cast<void*>(kernel.release()); |
| 244 | } |
| 245 | |
| 246 | return status; |
| 247 | }; |
| 248 | |
| 249 | OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { |
| 250 | auto kernel = reinterpret_cast<KernelEx*>(op_kernel); |
| 251 | std::vector<TensorPtr> tensors; |
| 252 | auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(), |
| 253 | context, |
| 254 | tensors, |
| 255 | kernel->extra_.api_->KernelContext_GetInputCount(context), |
| 256 | kernel->extra_.api_->KernelContext_GetOutputCount(context), |
| 257 | kernel->extra_.ep_); |
| 258 | return std::apply([kernel](Args const&... t_args) { |
| 259 | auto status = kernel->Compute(t_args...); |
| 260 | return ToApiStatus(status); }, t); |
| 261 | }; |
| 262 | |
| 263 | OrtCustomOp::KernelDestroy = [](void* op_kernel) { |
| 264 | std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset(); |
| 265 | }; |
| 266 | } |
| 267 | #endif // ORT_API_VERSION >= 16 |
| 268 | |
| 269 | OrtLiteCustomStructV2(const char* op_name, |
| 270 | const char* execution_provider, |
| 271 | RegularComputeType fn_compute = nullptr) |
| 272 | : OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) { |
| 273 | ParseArgs(&CustomOpKernel::Compute); |
| 274 | |
| 275 | #if ORT_API_VERSION >= 16 |
| 276 | if (OrtCustomOp::version >= 16) { |
| 277 | DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute); |
| 278 | } else |
| 279 | #endif // ORT_API_VERSION >= 16 |
| 280 | { |
| 281 | DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute); |
| 282 | } |
| 283 | } |
| 284 | |
| 285 | RegularComputeType regular_fn_{}; |
| 286 | }; |
| 287 | |
| 288 | template <typename RType, typename... Args> |
| 289 | std::shared_ptr<OrtLiteCustomOp> CreateLiteCustomOpV2(const char* op_name, |
| 290 | const char* execution_provider, |
| 291 | RType (*custom_compute_fn)(Args...)) { |
| 292 | using LiteOp = OrtLiteCustomStructV2<FunctionKernel<RType, Args...>>; |
| 293 | return std::make_shared<LiteOp>(op_name, execution_provider, custom_compute_fn); |
| 294 | } |
| 295 | |
| 296 | template <typename OpKernel> |
| 297 | std::shared_ptr<OrtLiteCustomOp> CreateLiteCustomOpV2(const char* op_name, |
| 298 | const char* execution_provider) { |
| 299 | using LiteOp = OrtLiteCustomStructV2<OpKernel>; |
| 300 | return std::make_shared<LiteOp>(op_name, execution_provider); |
| 301 | } |
| 302 | |
| 303 | } // namespace Custom |
| 304 | } // namespace Ort |
| 305 | |
| 306 | namespace ortc = Ort::Custom; |
| 307 | |