microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
natke-patch-1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

297 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper of the ONNXRuntime Custom Operator Callback ABI, which
// is only used in the custom-op kernels. For general ORT C++ invocation, such as end-to-end
// testing, the ONNXRuntime public C++ APIs should be used since there is no binary compatibility requirement.
7
8#pragma once
9#include <cstddef>
10#include <array>
11#include <memory>
12#include <string>
13#include <vector>
14#include <utility>
15#include <type_traits>
16
17#include "onnxruntime_c_api.h"
18
19#include "exceptions.h"
20
21namespace OrtW {
22
23//
24// Custom OPs (only needed to implement custom OPs)
25//
26struct CustomOpApi {
27 CustomOpApi(const OrtApi& api) : api_(api) {}
28
29 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
30 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
31
32 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
33 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
34 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
35 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
36 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
37 size_t dim_values_length) const;
38 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
39
40 template <typename T>
41 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
42 template <typename T>
43 const T* GetTensorData(_Inout_ const OrtValue* value) const;
44
45 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
46 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
47 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
48 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
49 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
50 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
51 size_t dim_count) const;
52
53 void ThrowOnError(OrtStatus* status) const {
54 OrtW::ThrowOnError(api_, status);
55 }
56
57 private:
58 const OrtApi& api_;
59};
60
61template <typename TOp, typename TKernel>
62struct CustomOpBase : OrtCustomOp {
63 CustomOpBase() {
64 OrtCustomOp::version = 10; // The minimum ORT version supported
65 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
66 void* result = nullptr;
67 OCOS_API_IMPL_BEGIN
68 result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
69 OCOS_API_IMPL_END
70 return result;
71 };
72
73 OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
74 return static_cast<const TOp*>(this_)->GetName();
75 };
76
77 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
78 return static_cast<const TOp*>(this_)->GetExecutionProviderType();
79 };
80
81 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
82 return static_cast<const TOp*>(this_)->GetInputTypeCount();
83 };
84
85 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
86 return static_cast<const TOp*>(this_)->GetInputType(index);
87 };
88
89 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
90 return static_cast<const TOp*>(this_)->GetOutputTypeCount();
91 };
92
93 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
94 return static_cast<const TOp*>(this_)->GetOutputType(index);
95 };
96
97 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
98 OCOS_API_IMPL_BEGIN
99 static_cast<TKernel*>(op_kernel)->Compute(context);
100 OCOS_API_IMPL_END
101 };
102
103#if defined(_MSC_VER) && !defined(__clang__)
104#pragma warning(push)
105#pragma warning(disable : 26409)
106#endif
107 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
108#if defined(_MSC_VER) && !defined(__clang__)
109#pragma warning(pop)
110#endif
111 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
112 return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
113 };
114
115 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
116 return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
117 };
118 }
119
120 // default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
121 // OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
122 // calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
123 void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
124#if defined(_MSC_VER) && !defined(__clang__)
125#pragma warning(push)
126#pragma warning(disable : 26409)
127#endif
128 return new TKernel(api, info);
129#if defined(_MSC_VER) && !defined(__clang__)
130#pragma warning(pop)
131#endif
132 }
133
134 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
135 const char* GetExecutionProviderType() const { return nullptr; }
136
137 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
138 // (inputs and outputs are required by default)
139 OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
140 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
141 }
142
143 OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
144 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
145 }
146};
147
148//
149// Custom OP API Inlines
150//
151
152template <>
153inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
154 float out;
155 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
156 return out;
157}
158
159template <>
160inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
161 int64_t out;
162 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
163 return out;
164}
165
166template <>
167inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
168 size_t size = 0;
169 std::string out;
170
171 // Feed nullptr for the data buffer to query the true size of the string attribute
172 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
173
174 if (status == nullptr) {
175 out.resize(size);
176 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
177 out.resize(size - 1); // remove the terminating character '\0'
178 } else {
179 ThrowOnError(status);
180 }
181 return out;
182}
183
184template <>
185inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
186 size_t size = 0;
187 std::vector<float> out;
188
189 // Feed nullptr for the data buffer to query the true size of the attribute
190 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
191
192 if (status == nullptr) {
193 out.resize(size);
194 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
195 } else {
196 ThrowOnError(status);
197 }
198 return out;
199}
200
201template <>
202inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
203 size_t size = 0;
204 std::vector<int64_t> out;
205
206 // Feed nullptr for the data buffer to query the true size of the attribute
207 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
208
209 if (status == nullptr) {
210 out.resize(size);
211 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
212 } else {
213 ThrowOnError(status);
214 }
215 return out;
216}
217
218inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
219 OrtTensorTypeAndShapeInfo* out;
220 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
221 return out;
222}
223
224inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
225 size_t out;
226 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
227 return out;
228}
229
230inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
231 ONNXTensorElementDataType out;
232 ThrowOnError(api_.GetTensorElementType(info, &out));
233 return out;
234}
235
236inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
237 size_t out;
238 ThrowOnError(api_.GetDimensionsCount(info, &out));
239 return out;
240}
241
242inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
243 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
244}
245
246inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
247 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
248}
249
250template <typename T>
251inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
252 T* data = nullptr;
253 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
254 return data;
255}
256
257template <typename T>
258inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
259 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
260}
261
262inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
263 std::vector<int64_t> output(GetDimensionsCount(info));
264 GetDimensions(info, output.data(), output.size());
265 return output;
266}
267
268inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
269 api_.ReleaseTensorTypeAndShapeInfo(input);
270}
271
272inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
273 size_t out;
274 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
275 return out;
276}
277
278inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
279 const OrtValue* out;
280 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
281 return out;
282}
283
284inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
285 size_t out;
286 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
287 return out;
288}
289
290inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
291 _In_ const int64_t* dim_values, size_t dim_count) const {
292 OrtValue* out;
293 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
294 return out;
295}
296
297} // namespace OrtW
298