microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
62d8598b6b9fa462a440ade891017eaafd4bfaee

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_customop.hpp

307 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
// A very thin wrapper of the ONNX Runtime custom-operator callback ABI, which
// is used only inside custom-op kernels. For general ORT C++ invocation, such as end-to-end
// testing, use the ONNX Runtime public C++ APIs instead, since there is no binary-compatibility requirement there.
7
8#pragma once
9#include <cstddef>
10#include <array>
11#include <memory>
12#include <string>
13#include <vector>
14#include <utility>
15#include <type_traits>
16
17#include "onnxruntime_c_api.h"
18#include "exceptions.h"
19
20#define MIN_ORT_VERSION_SUPPORTED 10
21
22extern "C" int ORT_API_CALL GetActiveOrtAPIVersion();
23
24
25namespace OrtW {
26
27//
28// Custom OPs (only needed to implement custom OPs)
29//
30struct CustomOpApi {
31 CustomOpApi(const OrtApi& api) : api_(api) {}
32
33 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
34 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
35
36 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
37 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
38 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
39 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
40 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
41 size_t dim_values_length) const;
42 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
43
44 template <typename T>
45 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
46 template <typename T>
47 const T* GetTensorData(_Inout_ const OrtValue* value) const;
48
49 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
50 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
51 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
52 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
53 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
54 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
55 size_t dim_count) const;
56
57 void ThrowOnError(OrtStatus* status) const {
58 OrtW::ThrowOnError(api_, status);
59 }
60
61 const OrtApi& GetOrtApi() const { return api_; }
62
63 private:
64 const OrtApi& api_;
65};
66
67template <typename TOp, typename TKernel>
68struct CustomOpBase : OrtCustomOp {
69 CustomOpBase() {
70 OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED; // The minimum ORT version supported
71 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) {
72 void* result = nullptr;
73 OCOS_API_IMPL_BEGIN
74 result = static_cast<const TOp*>(this_)->CreateKernel(*api, *info);
75 OCOS_API_IMPL_END
76 return result;
77 };
78
79 OrtCustomOp::GetName = [](const OrtCustomOp* this_) noexcept {
80 return static_cast<const TOp*>(this_)->GetName();
81 };
82
83 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) noexcept {
84 return static_cast<const TOp*>(this_)->GetExecutionProviderType();
85 };
86
87 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) noexcept {
88 return static_cast<const TOp*>(this_)->GetInputTypeCount();
89 };
90
91 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) noexcept {
92 return static_cast<const TOp*>(this_)->GetInputType(index);
93 };
94
95 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) noexcept {
96 return static_cast<const TOp*>(this_)->GetOutputTypeCount();
97 };
98
99 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) noexcept {
100 return static_cast<const TOp*>(this_)->GetOutputType(index);
101 };
102
103 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
104 OCOS_API_IMPL_BEGIN
105 static_cast<TKernel*>(op_kernel)->Compute(context);
106 OCOS_API_IMPL_END
107 };
108
109#if defined(_MSC_VER) && !defined(__clang__)
110#pragma warning(push)
111#pragma warning(disable : 26409)
112#endif
113 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
114#if defined(_MSC_VER) && !defined(__clang__)
115#pragma warning(pop)
116#endif
117 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
118 return static_cast<const TOp*>(this_)->GetInputCharacteristic(index);
119 };
120
121 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* this_, size_t index) noexcept {
122 return static_cast<const TOp*>(this_)->GetOutputCharacteristic(index);
123 };
124 }
125
126 // default implementation. we can't use a virtual function as the layout of this struct has to be aligned with
127 // OrtCustomOp, but a derived class can override by creating a function with the same name and signature,
128 // calling this base class implementation as needed. e.g. see CustomOpThree in the unit test code
129 void* CreateKernel(const OrtApi& api, const OrtKernelInfo& info) const {
130#if defined(_MSC_VER) && !defined(__clang__)
131#pragma warning(push)
132#pragma warning(disable : 26409)
133#endif
134 return new TKernel(api, info);
135#if defined(_MSC_VER) && !defined(__clang__)
136#pragma warning(pop)
137#endif
138 }
139
140 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
141 const char* GetExecutionProviderType() const { return nullptr; }
142
143 // Default implementations of GetInputCharacteristic() and GetOutputCharacteristic() below
144 // (inputs and outputs are required by default)
145 OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
146 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
147 }
148
149 OrtCustomOpInputOutputCharacteristic GetOutputCharacteristic(size_t /*index*/) const {
150 return OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED;
151 }
152};
153
154//
155// Custom OP API Inlines
156//
157
158template <>
159inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
160 float out;
161 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
162 return out;
163}
164
165template <>
166inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
167 int64_t out;
168 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
169 return out;
170}
171
172template <>
173inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
174 size_t size = 0;
175 std::string out;
176
177 // Feed nullptr for the data buffer to query the true size of the string attribute
178 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
179
180 if (status == nullptr) {
181 out.resize(size);
182 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
183 out.resize(size - 1); // remove the terminating character '\0'
184 } else {
185 ThrowOnError(status);
186 }
187 return out;
188}
189
190template <>
191inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
192 size_t size = 0;
193 std::vector<float> out;
194
195 // Feed nullptr for the data buffer to query the true size of the attribute
196 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
197
198 if (status == nullptr) {
199 out.resize(size);
200 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
201 } else {
202 ThrowOnError(status);
203 }
204 return out;
205}
206
207template <>
208inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
209 size_t size = 0;
210 std::vector<int64_t> out;
211
212 // Feed nullptr for the data buffer to query the true size of the attribute
213 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
214
215 if (status == nullptr) {
216 out.resize(size);
217 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
218 } else {
219 ThrowOnError(status);
220 }
221 return out;
222}
223
224inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
225 OrtTensorTypeAndShapeInfo* out;
226 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
227 return out;
228}
229
230inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
231 size_t out;
232 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
233 return out;
234}
235
236inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
237 ONNXTensorElementDataType out;
238 ThrowOnError(api_.GetTensorElementType(info, &out));
239 return out;
240}
241
242inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
243 size_t out;
244 ThrowOnError(api_.GetDimensionsCount(info, &out));
245 return out;
246}
247
248inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
249 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
250}
251
252inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
253 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
254}
255
256template <typename T>
257inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
258 T* data = nullptr;
259 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
260 return data;
261}
262
263template <typename T>
264inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
265 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
266}
267
268inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
269 std::vector<int64_t> output(GetDimensionsCount(info));
270 GetDimensions(info, output.data(), output.size());
271 return output;
272}
273
274inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
275 api_.ReleaseTensorTypeAndShapeInfo(input);
276}
277
278inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
279 size_t out;
280 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
281 return out;
282}
283
284inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
285 const OrtValue* out;
286 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
287 return out;
288}
289
290inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
291 size_t out;
292 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
293 return out;
294}
295
296inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
297 _In_ const int64_t* dim_values, size_t dim_count) const {
298 OrtValue* out;
299 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
300 return out;
301}
302
303} // namespace OrtW
304
// !! TODO: only do this for legacy ORT builds
306#include "custom_op_lite.h"
307namespace ortc = Ort::Custom;
308