microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
68b9d1dc47663a9017c55d136c804417c8efec7d

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

494lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper of the ONNXRuntime Custom Operator Callback ABI, which
// is only used in the custom-op kernels. For general ORT C++ invocation, like end-to-end
// testing, the ONNXRuntime public C++ APIs should be used since there is no binary compatibility requirement.
7
8#pragma once
9#include <cstddef>
10#include <array>
11#include <memory>
12#include <string>
13#include <vector>
14#include <utility>
15#include <type_traits>
16#include <optional>
17
18#include "onnxruntime_c_api.h"
19#include "exceptions.h"
20
21
22#define MIN_ORT_VERSION_SUPPORTED 11
23
24extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
25
26// namespace of ORT ABI Wrapper
27namespace OrtW {
28
29class API {
30 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
31 public:
32 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
33 static API self(*ort_api);
34 return self;
35 }
36
37 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
38 return instance()->CreateStatus(code, msg);
39 }
40
41 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
42 instance()->ReleaseStatus(ptr);
43 }
44
45 template<typename T>
46 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
47
48 static void ThrowOnError(OrtStatusPtr ptr) {
49 OrtW::ThrowOnError(instance().api_, ptr);
50 }
51
52private:
53 const OrtApi* operator->() const {
54 return &api_;
55 }
56
57 API(const OrtApi& api) : api_(api) {
58 if (&api == nullptr) {
59 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
60 }
61 }
62 const OrtApi& api_;
63};
64
65//
// DEPRECATED: Custom OPs (only needed to implement custom OPs)
67//
68struct CustomOpApi {
69 CustomOpApi(const OrtApi& api) : api_(api) {}
70
71 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
72 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
73
74 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
75 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
76 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
77 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
78 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
79 size_t dim_values_length) const;
80 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
81
82 template <typename T>
83 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
84 template <typename T>
85 const T* GetTensorData(_Inout_ const OrtValue* value) const;
86
87 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
88 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
89 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
90 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
91 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
92 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
93 size_t dim_count) const;
94
95 void ThrowOnError(OrtStatus* status) const {
96 OrtW::ThrowOnError(api_, status);
97 }
98
99 const OrtApi& GetOrtApi() const { return api_; }
100
101 private:
102 const OrtApi& api_;
103};
104
105//
106// Custom OP API Inlines
107//
108
109template <>
110inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
111 float out;
112 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
113 return out;
114}
115
116template <>
117inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
118 int64_t out;
119 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
120 return out;
121}
122
123template <>
124inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
125 size_t size = 0;
126 std::string out;
127
128 // Feed nullptr for the data buffer to query the true size of the string attribute
129 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
130
131 if (status == nullptr) {
132 out.resize(size);
133 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
134 out.resize(size - 1); // remove the terminating character '\0'
135 } else {
136 ThrowOnError(status);
137 }
138 return out;
139}
140
141template <>
142inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
143 size_t size = 0;
144 std::vector<float> out;
145
146 // Feed nullptr for the data buffer to query the true size of the attribute
147 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
148
149 if (status == nullptr) {
150 out.resize(size);
151 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
152 } else {
153 ThrowOnError(status);
154 }
155 return out;
156}
157
158template <>
159inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
160 size_t size = 0;
161 std::vector<int64_t> out;
162
163 // Feed nullptr for the data buffer to query the true size of the attribute
164 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
165
166 if (status == nullptr) {
167 out.resize(size);
168 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
169 } else {
170 ThrowOnError(status);
171 }
172 return out;
173}
174
175inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
176 OrtTensorTypeAndShapeInfo* out;
177 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
178 return out;
179}
180
181inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
182 size_t out;
183 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
184 return out;
185}
186
187inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
188 ONNXTensorElementDataType out;
189 ThrowOnError(api_.GetTensorElementType(info, &out));
190 return out;
191}
192
193inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
194 size_t out;
195 ThrowOnError(api_.GetDimensionsCount(info, &out));
196 return out;
197}
198
199inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
200 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
201}
202
203inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
204 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
205}
206
207template <typename T>
208inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
209 T* data = nullptr;
210 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
211 return data;
212}
213
214template <typename T>
215inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
216 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
217}
218
219inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
220 std::vector<int64_t> output(GetDimensionsCount(info));
221 GetDimensions(info, output.data(), output.size());
222 return output;
223}
224
225inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
226 api_.ReleaseTensorTypeAndShapeInfo(input);
227}
228
229inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
230 size_t out;
231 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
232 return out;
233}
234
235inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
236 const OrtValue* out;
237 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
238 return out;
239}
240
241inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
242 size_t out;
243 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
244 return out;
245}
246
247inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
248 _In_ const int64_t* dim_values, size_t dim_count) const {
249 OrtValue* out;
250 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
251 return out;
252}
253
254template <>
255inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
256 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
257}
258
259template <>
260inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
261 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
262}
263
264template <class T>
265 static OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
266 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
267 // Ideally, we should know which kind of error code can be ignored, but it is not availabe now.
268 // Just ignore all of them.
269 API::ReleaseStatus(status);
270 }
271
272 return nullptr;
273 }
274
275
276inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
277 return API::CreateStatus(code, msg);
278}
279
280} // namespace OrtW
281
282
283#if ORT_API_VERSION < 15
284#include "custom_op_lite.h"
285
286#else
287// From onnxruntime 1.17, the custom op lite API header is used the one from onnxruntime package.
288// #include "onnxruntime_lite_custom_op.h"
289// The existing custom op lite API header has more features than the one from onnxruntime 1.16.
290#include "custom_op_lite.h"
291
292#endif // ORT_API_VERSION < 15
293
294
295
296namespace Ort {
297namespace Custom {
298
299
300template <typename... Args>
301struct FunctionKernel {
302 using ComputeFn = std::function<OrtStatusPtr(Args...)>;
303
304 OrtStatusPtr Compute(Args... args) const {
305 return compute_fn_(args...);
306 }
307
308 ComputeFn compute_fn_;
309};
310
// Detection trait: IsFunctionKernel<T>::type is std::true_type when T declares a nested
// `ComputeFn` type, and std::false_type otherwise.

// primary template: chosen when T::ComputeFn does not name a type
template <class, class = void>
struct IsFunctionKernel {
  using type = std::false_type;
};

// partial specialization: chosen when T::ComputeFn is a valid nested type
template <class T>
struct IsFunctionKernel<T, std::void_t<typename T::ComputeFn>> {
  using type = std::true_type;
};
322
323// Helper type
324template <typename T>
325struct ComputeArgsList;
326
327// Specialization for member function
328template <typename C, typename... Args>
329struct ComputeArgsList<OrtStatusPtr (C::*)(Args...) const> {
330 using FunctionType = OrtStatusPtr (*)(Args...);
331 using MemberFunctionType = OrtStatusPtr (C::*)(Args...) const;
332};
333
334template <typename CustomOpKernel>
335struct OrtLiteCustomStructV2 : public OrtLiteCustomOp {
336 using ComputeFunction = decltype(&CustomOpKernel::Compute);
337 using RegularComputeType = typename ComputeArgsList<ComputeFunction>::FunctionType;
338
339 template <typename... Args>
340 using MemberComputeType = OrtStatusPtr (CustomOpKernel::*)(Args...) const;
341
342 struct KernelEx : public CustomOpKernel {
343 struct {
344 std::string ep_{};
345 std::unique_ptr<OrtW::CustomOpApi> api_;
346 } extra_;
347 };
348
349 template <typename T>
350 static OrtStatusPtr InitKernel(KernelEx& kernel,
351 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, T t) {
352 return kernel.OnModelAttach(api, info);
353 }
354
355 static OrtStatusPtr InitKernel(
356 KernelEx& kernel,
357 const OrtApi& api, const OrtKernelInfo& info, RegularComputeType fn, std::true_type) {
358 kernel.compute_fn_ = fn;
359 return nullptr;
360 }
361
362 template <typename... Args>
363 void ParseArgs(MemberComputeType<Args...> fn) {
364 OrtLiteCustomOp::ParseArgs<Args...>(OrtLiteCustomOp::input_types_, OrtLiteCustomOp::output_types_);
365 }
366
367 // TODO: consider to disable these legacy functions for mobile build to save binary size
368 template <typename... Args>
369 void DefineCallbackFunctionsLegacy(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
370
371 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
372 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
373 auto kernel = std::make_unique<KernelEx>();
374 typedef typename IsFunctionKernel<CustomOpKernel>::type type_flag;
375 OrtStatusPtr status = InitKernel(*kernel, *ort_api, *info, self->regular_fn_, type_flag());
376 OrtW::ThrowOnError(*ort_api, status);
377
378 kernel->extra_.ep_ = self->execution_provider_;
379 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
380 return reinterpret_cast<void*>(kernel.release());
381 };
382
383 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
384 auto kernel = reinterpret_cast<KernelEx*>(op_kernel);
385 std::vector<TensorPtr> tensors;
386 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
387 context,
388 tensors,
389 kernel->extra_.api_->KernelContext_GetInputCount(context),
390 kernel->extra_.api_->KernelContext_GetOutputCount(context),
391 kernel->extra_.ep_);
392 std::apply([kernel](Args const&... t_args) {
393 auto status = kernel->Compute(t_args...); OrtW::API::ThrowOnError(status);}, t);
394 };
395
396 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
397 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
398 };
399 }
400
401#if ORT_API_VERSION > 15
402 template <typename... Args>
403 void DefineCallbackFunctions(MemberComputeType<Args...> fn, RegularComputeType regular_fn) {
404 OrtCustomOp::CreateKernel = nullptr;
405 OrtCustomOp::KernelCompute = nullptr;
406
407 OrtCustomOp::CreateKernelV2 = [](const OrtCustomOp* this_,
408 const OrtApi* api, const OrtKernelInfo* info, void** op_kernel) -> OrtStatusPtr {
409 if (api == nullptr) {
410 assert(false && "Got a null pointer for ORT api on calling CreateKernelV2");
411 // should never happened, what we can do?
412 return nullptr;
413 }
414
415 if (this_ == nullptr || info == nullptr || op_kernel == nullptr) {
416 return api->CreateStatus(ORT_INVALID_ARGUMENT, "OrtCustomOp::CreateKernelV2: received a null pointer");
417 }
418
419 auto self = static_cast<const OrtLiteCustomStructV2<CustomOpKernel>*>(this_);
420 auto kernel = std::make_unique<KernelEx>();
421 if (kernel == nullptr) {
422 return api->CreateStatus(ORT_FAIL, "OrtCustomOp::CreateKernelV2: failed to new a kernel, OOM?");
423 }
424
425 typedef typename IsFunctionKernel<CustomOpKernel>::type flag_type;
426 OrtStatusPtr status = InitKernel(*kernel, *api, *info, self->regular_fn_, flag_type());
427 if (status == nullptr) {
428 kernel->extra_.ep_ = self->execution_provider_;
429 kernel->extra_.api_ = std::make_unique<OrtW::CustomOpApi>(*api);
430 *op_kernel = reinterpret_cast<void*>(kernel.release());
431 }
432
433 return status;
434 };
435
436 OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
437 auto kernel = reinterpret_cast<KernelEx* >(op_kernel);
438 std::vector<TensorPtr> tensors;
439 auto t = CreateTuple<0, 0, Args...>(kernel->extra_.api_.get(),
440 context,
441 tensors,
442 kernel->extra_.api_->KernelContext_GetInputCount(context),
443 kernel->extra_.api_->KernelContext_GetOutputCount(context),
444 kernel->extra_.ep_);
445 return std::apply([kernel](Args const&... t_args) {
446 return kernel->Compute(t_args...); }, t);
447 };
448
449 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
450 std::unique_ptr<KernelEx>(reinterpret_cast<KernelEx*>(op_kernel)).reset();
451 };
452 }
453#endif // ORT_API_VERSION > 15
454
455 OrtLiteCustomStructV2(const char* op_name,
456 const char* execution_provider,
457 RegularComputeType fn_compute = nullptr)
458 : OrtLiteCustomOp(op_name, execution_provider), regular_fn_(fn_compute) {
459
460 ParseArgs(&CustomOpKernel::Compute);
461
462#if ORT_API_VERSION > 15
463 if (OrtCustomOp::version > 15) {
464 DefineCallbackFunctions(&CustomOpKernel::Compute, fn_compute);
465 } else
466#endif // ORT_API_VERSION > 15
467
468 {
469 DefineCallbackFunctionsLegacy(&CustomOpKernel::Compute, fn_compute);
470 }
471 }
472
473 RegularComputeType regular_fn_{};
474};
475
476template <typename... Args>
477OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
478 const char* execution_provider,
479 OrtStatusPtr (*custom_compute_fn)(Args...)) {
480 using LiteOp = OrtLiteCustomStructV2<FunctionKernel<Args...>>;
481 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
482}
483
484template <typename OpKernel>
485OrtLiteCustomOp* CreateLiteCustomOpV2(const char* op_name,
486 const char* execution_provider) {
487 using LiteOp = OrtLiteCustomStructV2<OpKernel>;
488 return std::make_unique<LiteOp>(op_name, execution_provider).release();
489}
490
491} // namespace Custom
492} // namespace Ort
493
494namespace ortc = Ort::Custom;
495