microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a6012b383e329c194ba2a3e21368a409800eb8ab

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

305 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper around the ONNX Runtime custom-operator callback ABI, used only inside the
// custom-op kernels. For general ONNX Runtime C++ usage, such as end-to-end testing, use the public
// ONNX Runtime C++ API instead, since there is no binary-compatibility requirement in that case.
7
8#pragma once
9#include <cstddef>
10#include <array>
11#include <memory>
12#include <string>
13#include <vector>
14#include <utility>
15#include <type_traits>
16
17#include "onnxruntime_c_api.h"
18
19#include "exceptions.h"
20
// ORT API version that CustomOpBase (below) writes into OrtCustomOp::version, i.e. the oldest
// ONNX Runtime API version these custom-op wrappers declare support for.
#define MIN_ORT_VERSION_SUPPORTED 10
22
23namespace OrtW {
24
25//
26// Custom OPs (only needed to implement custom OPs)
27//
28struct CustomOpApi {
29 CustomOpApi(const OrtApi& api) : api_(api) {}
30
31 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
32 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
33
34 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
35 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
36 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
37 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
38 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
39 size_t dim_values_length) const;
40 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
41
42 template <typename T>
43 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
44 template <typename T>
45 const T* GetTensorData(_Inout_ const OrtValue* value) const;
46
47 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
48 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
49 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
50 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
51 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
52 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
53 size_t dim_count) const;
54
55 void ThrowOnError(OrtStatus* status) const {
56 OrtW::ThrowOnError(api_, status);
57 }
58
59 const OrtApi& GetOrtApi() const { return api_; }
60
61 private:
62 const OrtApi& api_;
63};
64
65template <typename TOp, typename TKernel>
66struct CustomOpBase : OrtCustomOp {
67 CustomOpBase() {
68 OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
69 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
70 void* result = nullptr;
71 OCOS_API_IMPL_BEGIN
72 result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
73 OCOS_API_IMPL_END
74 return result;
75 };
76
77 OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
78 return static_cast<const TOp*>(this_)->GetName();
79 };
80
81 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
82 return static_cast<const TOp*>(this_)->GetExecutionProviderType();
83 };
84
85 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
86 return static_cast<const TOp*>(this_)->GetInputTypeCount();
87 };
88
89 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
90 return static_cast<const TOp*>(this_)->GetInputType(index);
91 };
92
93 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
94 return static_cast<const TOp*>(this_)->GetOutputTypeCount();
95 };
96
97 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
98 return static_cast<const TOp*>(this_)->GetOutputType(index);
99 };
100
101 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
102 OCOS_API_IMPL_BEGIN
103 static_cast<TKernel*>(op_kernel)->Compute(context);
104 OCOS_API_IMPL_END
105 };
106
107#if defined(_MSC_VER) && !defined(__clang__)
108#pragma warning(push)
109#pragma warning(disable : 26409)
110#endif
111 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
112#if defined(_MSC_VER) && !defined(__clang__)
113#pragma warning(pop)
114#endif
115 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
116 return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
117 };
118
119 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
120 return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
121 };
122 }
123
124 // default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
125 // OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
126 // calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
127 void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
128#if defined(_MSC_VER) && !defined(__clang__)
129#pragma warning(push)
130#pragma warning(disable : 26409)
131#endif
132 return new TKernel(api, info);
133#if defined(_MSC_VER) && !defined(__clang__)
134#pragma warning(pop)
135#endif
136 }
137
138 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
139 const char* GetExecutionProviderType() const { return nullptr; }
140
141 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
142 // (inputs and outputs are required by default)
143 OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
144 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
145 }
146
147 OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
148 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
149 }
150};
151
152//
153// Custom OP API Inlines
154//
155
156template <>
157inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
158 float out;
159 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
160 return out;
161}
162
163template <>
164inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
165 int64_t out;
166 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
167 return out;
168}
169
170template <>
171inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
172 size_t size = 0;
173 std::string out;
174
175 // Feed nullptr for the data buffer to query the true size of the string attribute
176 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
177
178 if (status == nullptr) {
179 out.resize(size);
180 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
181 out.resize(size - 1); // remove the terminating character '\0'
182 } else {
183 ThrowOnError(status);
184 }
185 return out;
186}
187
188template <>
189inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
190 size_t size = 0;
191 std::vector<float> out;
192
193 // Feed nullptr for the data buffer to query the true size of the attribute
194 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
195
196 if (status == nullptr) {
197 out.resize(size);
198 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
199 } else {
200 ThrowOnError(status);
201 }
202 return out;
203}
204
205template <>
206inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
207 size_t size = 0;
208 std::vector<int64_t> out;
209
210 // Feed nullptr for the data buffer to query the true size of the attribute
211 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
212
213 if (status == nullptr) {
214 out.resize(size);
215 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
216 } else {
217 ThrowOnError(status);
218 }
219 return out;
220}
221
222inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
223 OrtTensorTypeAndShapeInfo* out;
224 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
225 return out;
226}
227
228inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
229 size_t out;
230 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
231 return out;
232}
233
234inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
235 ONNXTensorElementDataType out;
236 ThrowOnError(api_.GetTensorElementType(info, &out));
237 return out;
238}
239
240inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
241 size_t out;
242 ThrowOnError(api_.GetDimensionsCount(info, &out));
243 return out;
244}
245
246inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
247 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
248}
249
250inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
251 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
252}
253
254template <typename T>
255inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
256 T* data = nullptr;
257 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
258 return data;
259}
260
261template <typename T>
262inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
263 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
264}
265
266inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
267 std::vector<int64_t> output(GetDimensionsCount(info));
268 GetDimensions(info, output.data(), output.size());
269 return output;
270}
271
272inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
273 api_.ReleaseTensorTypeAndShapeInfo(input);
274}
275
276inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
277 size_t out;
278 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
279 return out;
280}
281
282inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
283 const OrtValue* out;
284 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
285 return out;
286}
287
288inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
289 size_t out;
290 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
291 return out;
292}
293
294inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
295 _In_ const int64_t* dim_values, size_t dim_count) const {
296 OrtValue* out;
297 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
298 return out;
299}
300
301} // namespace OrtW
302
// !! TODO: only do this for legacy ORT builds
304#include "custom_op_lite.h"
305namespace ortc = Ort::Custom;
306