microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b341d66ebb794e190613545b11bfa2be0e9c4440

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_cpp_api_legacy.hpp

264 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include <memory>
#include <string>
#include <vector>
#include "exceptions.h"
7
8//
// DEPRECATED: New custom ops must not use any class, struct, or function from this file.
// TODO: Remove this file once all custom ops have been migrated to the new API.
11//
12namespace OrtW {
13
14struct CustomOpApi {
15 CustomOpApi(const OrtApi& api) : api_(api) {}
16
17 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
18 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
19
20 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
21 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
22 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
23 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
24 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
25 size_t dim_values_length) const;
26 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
27
28 template <typename T>
29 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
30 template <typename T>
31 const T* GetTensorData(_Inout_ const OrtValue* value) const;
32
33 void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
34 const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
35
36 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
37 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
38 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
39 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
40 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
41 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
42 size_t dim_count) const;
43
44 void ThrowOnError(OrtStatus* status) const {
45 OrtW::ThrowOnError(api_, status);
46 }
47
48 const OrtApi& GetOrtApi() const { return api_; }
49
50 private:
51 const OrtApi& api_;
52};
53
54//
55// Custom OP API Inlines
56//
57
58template <>
59inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
60 float out;
61 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
62 return out;
63}
64
65template <>
66inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
67 int64_t out;
68 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
69 return out;
70}
71
72template <>
73inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
74 size_t size = 0;
75 std::string out;
76
77 // Feed nullptr for the data buffer to query the true size of the string attribute
78 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
79
80 if (status == nullptr) {
81 out.resize(size);
82 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
83 out.resize(size - 1); // remove the terminating character '\0'
84 } else {
85 ThrowOnError(status);
86 }
87 return out;
88}
89
90template <>
91inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
92 size_t size = 0;
93 std::vector<float> out;
94
95 // Feed nullptr for the data buffer to query the true size of the attribute
96 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
97
98 if (status == nullptr) {
99 out.resize(size);
100 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
101 } else {
102 ThrowOnError(status);
103 }
104 return out;
105}
106
107template <>
108inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
109 size_t size = 0;
110 std::vector<int64_t> out;
111
112 // Feed nullptr for the data buffer to query the true size of the attribute
113 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
114
115 if (status == nullptr) {
116 out.resize(size);
117 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
118 } else {
119 ThrowOnError(status);
120 }
121 return out;
122}
123
124inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
125 OrtTensorTypeAndShapeInfo* out;
126 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
127 return out;
128}
129
130inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
131 size_t out;
132 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
133 return out;
134}
135
136inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
137 ONNXTensorElementDataType out;
138 ThrowOnError(api_.GetTensorElementType(info, &out));
139 return out;
140}
141
142inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
143 size_t out;
144 ThrowOnError(api_.GetDimensionsCount(info, &out));
145 return out;
146}
147
148inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
149 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
150}
151
152inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
153 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
154}
155
156template <typename T>
157inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
158 T* data = nullptr;
159 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
160 return data;
161}
162
163template <typename T>
164inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
165 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
166}
167
168inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
169 void* data = nullptr;
170 ThrowOnError(api_.GetTensorMutableData(value, &data));
171 return data;
172}
173
174inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
175 return GetTensorMutableRawData(const_cast<OrtValue*>(value));
176}
177
178inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
179 std::vector<int64_t> output(GetDimensionsCount(info));
180 GetDimensions(info, output.data(), output.size());
181 return output;
182}
183
184inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
185 api_.ReleaseTensorTypeAndShapeInfo(input);
186}
187
188inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
189 size_t out;
190 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
191 return out;
192}
193
194inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
195 const OrtValue* out;
196 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
197 return out;
198}
199
200inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
201 size_t out;
202 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
203 return out;
204}
205
206inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
207 _In_ const int64_t* dim_values, size_t dim_count) const {
208 OrtValue* out;
209 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
210 return out;
211}
212
213} // namespace of OrtW
214
215
216struct BaseKernel {
217 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
218 : api_(api), info_(info), ort_(api_) {
219 }
220
221 template <class T>
222 bool TryToGetAttribute(const char* name, T& value) const noexcept;
223
224 template <class T>
225 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
226 T result = default_value;
227 TryToGetAttribute(name, result);
228 return result;
229 }
230
231 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
232 const std::vector<int64_t>& data);
233
234 protected:
235 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
236
237 const OrtApi& api_;
238 OrtW::CustomOpApi ort_;
239 const OrtKernelInfo& info_;
240};
241
242struct OrtTensorDimensions : std::vector<int64_t> {
243 OrtTensorDimensions() = default;
244 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
245 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
246 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
247 ort.ReleaseTensorTypeAndShapeInfo(info);
248 }
249
250 int64_t Size() const {
251 int64_t s = 1;
252 for (auto it = begin(); it != end(); ++it)
253 s *= *it;
254 return s;
255 }
256
257 bool IsScalar() const {
258 return empty();
259 }
260
261 bool IsVector() const {
262 return size() == 1;
263 }
264};