microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
1ac33abd689f90b362f127039232175913c60d08

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

306 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper around the ONNX Runtime custom-operator callback ABI, which
// is only used in the custom-op kernels. For general ORT C++ invocation, such as end-to-end
// testing, the ONNX Runtime public C++ APIs should be used, since there is no binary-compatibility requirement there.
7
8#pragma once
9#include <cstddef>
10#include <array>
11#include <memory>
12#include <string>
13#include <vector>
14#include <utility>
15#include <type_traits>
16
17#include "onnxruntime_c_api.h"
18#include "exceptions.h"
19
// Minimum ORT C API version these wrappers target; CustomOpBase writes it into
// OrtCustomOp::version.
#define MIN_ORT_VERSION_SUPPORTED 10

// Declared here with C linkage; the definition lives outside this header.
// NOTE(review): presumably reports the ORT API version in use at runtime — confirm at the definition.
extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
23
24namespace OrtW {
25
26//
27// Custom OPs (only needed to implement custom OPs)
28//
29struct CustomOpApi {
30 CustomOpApi(const OrtApi& api) : api_(api) {}
31
32 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
33 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
34
35 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
36 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
37 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
38 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
39 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
40 size_t dim_values_length) const;
41 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
42
43 template <typename T>
44 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
45 template <typename T>
46 const T* GetTensorData(_Inout_ const OrtValue* value) const;
47
48 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
49 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
50 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
51 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
52 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
53 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
54 size_t dim_count) const;
55
56 void ThrowOnError(OrtStatus* status) const {
57 OrtW::ThrowOnError(api_, status);
58 }
59
60 const OrtApi& GetOrtApi() const { return api_; }
61
62 private:
63 const OrtApi& api_;
64};
65
66template <typename TOp, typename TKernel>
67struct CustomOpBase : OrtCustomOp {
68 CustomOpBase() {
69 OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
70 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
71 void* result = nullptr;
72 OCOS_API_IMPL_BEGIN
73 result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
74 OCOS_API_IMPL_END
75 return result;
76 };
77
78 OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
79 return static_cast<const TOp*>(this_)->GetName();
80 };
81
82 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
83 return static_cast<const TOp*>(this_)->GetExecutionProviderType();
84 };
85
86 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
87 return static_cast<const TOp*>(this_)->GetInputTypeCount();
88 };
89
90 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
91 return static_cast<const TOp*>(this_)->GetInputType(index);
92 };
93
94 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
95 return static_cast<const TOp*>(this_)->GetOutputTypeCount();
96 };
97
98 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
99 return static_cast<const TOp*>(this_)->GetOutputType(index);
100 };
101
102 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
103 OCOS_API_IMPL_BEGIN
104 static_cast<TKernel*>(op_kernel)->Compute(context);
105 OCOS_API_IMPL_END
106 };
107
108#if defined(_MSC_VER) && !defined(__clang__)
109#pragma warning(push)
110#pragma warning(disable : 26409)
111#endif
112 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
113#if defined(_MSC_VER) && !defined(__clang__)
114#pragma warning(pop)
115#endif
116 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
117 return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
118 };
119
120 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
121 return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
122 };
123 }
124
125 // default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
126 // OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
127 // calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
128 void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
129#if defined(_MSC_VER) && !defined(__clang__)
130#pragma warning(push)
131#pragma warning(disable : 26409)
132#endif
133 return new TKernel(api, info);
134#if defined(_MSC_VER) && !defined(__clang__)
135#pragma warning(pop)
136#endif
137 }
138
139 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
140 const char* GetExecutionProviderType() const { return nullptr; }
141
142 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
143 // (inputs and outputs are required by default)
144 OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
145 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
146 }
147
148 OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
149 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
150 }
151};
152
153//
154// Custom OP API Inlines
155//
156
157template <>
158inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
159 float out;
160 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
161 return out;
162}
163
164template <>
165inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
166 int64_t out;
167 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
168 return out;
169}
170
171template <>
172inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
173 size_t size = 0;
174 std::string out;
175
176 // Feed nullptr for the data buffer to query the true size of the string attribute
177 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
178
179 if (status == nullptr) {
180 out.resize(size);
181 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
182 out.resize(size - 1); // remove the terminating character '\0'
183 } else {
184 ThrowOnError(status);
185 }
186 return out;
187}
188
189template <>
190inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
191 size_t size = 0;
192 std::vector<float> out;
193
194 // Feed nullptr for the data buffer to query the true size of the attribute
195 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
196
197 if (status == nullptr) {
198 out.resize(size);
199 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
200 } else {
201 ThrowOnError(status);
202 }
203 return out;
204}
205
206template <>
207inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
208 size_t size = 0;
209 std::vector<int64_t> out;
210
211 // Feed nullptr for the data buffer to query the true size of the attribute
212 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
213
214 if (status == nullptr) {
215 out.resize(size);
216 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
217 } else {
218 ThrowOnError(status);
219 }
220 return out;
221}
222
223inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
224 OrtTensorTypeAndShapeInfo* out;
225 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
226 return out;
227}
228
229inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
230 size_t out;
231 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
232 return out;
233}
234
235inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
236 ONNXTensorElementDataType out;
237 ThrowOnError(api_.GetTensorElementType(info, &out));
238 return out;
239}
240
241inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
242 size_t out;
243 ThrowOnError(api_.GetDimensionsCount(info, &out));
244 return out;
245}
246
247inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
248 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
249}
250
251inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
252 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
253}
254
255template <typename T>
256inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
257 T* data = nullptr;
258 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
259 return data;
260}
261
262template <typename T>
263inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
264 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
265}
266
267inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
268 std::vector<int64_t> output(GetDimensionsCount(info));
269 GetDimensions(info, output.data(), output.size());
270 return output;
271}
272
273inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
274 api_.ReleaseTensorTypeAndShapeInfo(input);
275}
276
277inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
278 size_t out;
279 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
280 return out;
281}
282
283inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
284 const OrtValue* out;
285 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
286 return out;
287}
288
289inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
290 size_t out;
291 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
292 return out;
293}
294
295inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
296 _In_ const int64_t* dim_values, size_t dim_count) const {
297 OrtValue* out;
298 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
299 return out;
300}
301
302} // namespace OrtW
303
// !! TODO: only do this for legacy ORT builds
305#include "custom_op_lite.h"
306namespace ortc = Ort::Custom;
307