microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
2d02a687beb1ba10319dc381b3907c91ab370995

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ort_c_to_cpp.h

373 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include <vector>
6#include "exceptions.h"
7
8// OrtW: ONNX Runtime C ABI Wrapper
9namespace OrtW {
10
11struct CustomOpApi {
12 CustomOpApi(const OrtApi& api) : api_(api) {}
13
14 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
15 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
16
17 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
18 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
19 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
20 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
21 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
22 size_t dim_values_length) const;
23 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
24
25 template <typename T>
26 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
27 template <typename T>
28 const T* GetTensorData(_Inout_ const OrtValue* value) const;
29
30 void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
31 const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
32
33 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
34 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
35 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
36 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
37 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
38 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
39 size_t dim_count) const;
40
41 void ThrowOnError(OrtStatus* status) const {
42 OrtW::ThrowOnError(api_, status);
43 }
44
45 const OrtApi& GetOrtApi() const { return api_; }
46
47 private:
48 const OrtApi& api_;
49};
50
51class API {
52 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
53 public:
54 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
55 static API self(ort_api);
56 return self;
57 }
58
59 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
60 return instance()->CreateStatus(code, msg);
61 }
62
63 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
64 instance()->ReleaseStatus(ptr);
65 }
66
67 template <typename T>
68 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
69
70 static void ThrowOnError(OrtStatusPtr ptr) {
71 OrtW::ThrowOnError(instance().api_, ptr);
72 }
73
74 // Caller is responsible for releasing OrtMemoryInfo object
75 static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
76 return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
77 }
78#if ORT_API_VERSION >= 15
79 // Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
80 static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
81 return instance()->KernelContext_GetAllocator(context, mem_info, out);
82 }
83#endif
84 private:
85 const OrtApi* operator->() const {
86 return &api_;
87 }
88
89 API(const OrtApi* api) : api_(*api) {
90 if (api == nullptr) {
91 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
92 }
93 }
94
95 const OrtApi& api_;
96};
97
98
99//
100// Custom OP API Inlines
101//
102
103template <>
104inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
105 float out;
106 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
107 return out;
108}
109
110template <>
111inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
112 int64_t out;
113 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
114 return out;
115}
116
117template <>
118inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
119 size_t size = 0;
120 std::string out;
121
122 // Feed nullptr for the data buffer to query the true size of the string attribute
123 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
124
125 if (status == nullptr) {
126 out.resize(size);
127 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
128 out.resize(size - 1); // remove the terminating character '\0'
129 } else {
130 ThrowOnError(status);
131 }
132 return out;
133}
134
135template <>
136inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
137 size_t size = 0;
138 std::vector<float> out;
139
140 // Feed nullptr for the data buffer to query the true size of the attribute
141 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
142
143 if (status == nullptr) {
144 out.resize(size);
145 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
146 } else {
147 ThrowOnError(status);
148 }
149 return out;
150}
151
152template <>
153inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
154 size_t size = 0;
155 std::vector<int64_t> out;
156
157 // Feed nullptr for the data buffer to query the true size of the attribute
158 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
159
160 if (status == nullptr) {
161 out.resize(size);
162 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
163 } else {
164 ThrowOnError(status);
165 }
166 return out;
167}
168
169inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
170 OrtTensorTypeAndShapeInfo* out;
171 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
172 return out;
173}
174
175inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
176 size_t out;
177 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
178 return out;
179}
180
181inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
182 ONNXTensorElementDataType out;
183 ThrowOnError(api_.GetTensorElementType(info, &out));
184 return out;
185}
186
187inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
188 size_t out;
189 ThrowOnError(api_.GetDimensionsCount(info, &out));
190 return out;
191}
192
193inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
194 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
195}
196
197inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
198 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
199}
200
201template <typename T>
202inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
203 T* data = nullptr;
204 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
205 return data;
206}
207
208template <typename T>
209inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
210 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
211}
212
213inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
214 void* data = nullptr;
215 ThrowOnError(api_.GetTensorMutableData(value, &data));
216 return data;
217}
218
219inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
220 return GetTensorMutableRawData(const_cast<OrtValue*>(value));
221}
222
223inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
224 std::vector<int64_t> output(GetDimensionsCount(info));
225 GetDimensions(info, output.data(), output.size());
226 return output;
227}
228
229inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
230 api_.ReleaseTensorTypeAndShapeInfo(input);
231}
232
233inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
234 size_t out;
235 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
236 return out;
237}
238
239inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
240 const OrtValue* out;
241 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
242 return out;
243}
244
245inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
246 size_t out;
247 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
248 return out;
249}
250
251inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
252 _In_ const int64_t* dim_values, size_t dim_count) const {
253 OrtValue* out;
254 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
255 return out;
256}
257
258template <>
259inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
260 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
261}
262
263template <>
264inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
265 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
266}
267
268template <>
269inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
270 size_t size = 0;
271 std::string out;
272 // Feed nullptr for the data buffer to query the true size of the string attribute
273 OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
274 if (status == nullptr) {
275 out.resize(size);
276 status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
277 out.resize(size - 1); // remove the terminating character '\0'
278 }
279
280 if (status == nullptr) {
281 value = std::move(out);
282 }
283
284 return status;
285}
286
287template <class T>
288inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
289 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
290 // Ideally, we should know which kind of error code can be ignored, but it is not available now.
291 // Just ignore all of them.
292 API::ReleaseStatus(status);
293 }
294
295 return nullptr;
296}
297
298template <class T>
299inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
300 T ret;
301 if (API::KernelInfoGetAttribute(info, name, ret)) {
302 ret = default_value;
303 }
304 return ret;
305}
306
307inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
308 return API::CreateStatus(code, msg);
309}
310
311inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
312 return API::CreateStatus(code, msg.c_str());
313}
314
315inline void ReleaseStatus(OrtStatusPtr& status) {
316 API::ReleaseStatus(status);
317 status = nullptr;
318}
319
320} // namespace of OrtW
321
322
323// Deprecated: No needs to create a new class derived from BaseKernel.
324struct BaseKernel {
325 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
326 : api_(api), info_(info), ort_(api_) {
327 }
328
329 template <class T>
330 bool TryToGetAttribute(const char* name, T& value) const noexcept;
331
332 template <class T>
333 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
334 T result = default_value;
335 TryToGetAttribute(name, result);
336 return result;
337 }
338
339 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
340 const std::vector<int64_t>& data);
341
342 protected:
343 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
344
345 const OrtApi& api_;
346 const OrtKernelInfo& info_;
347 OrtW::CustomOpApi ort_;
348};
349
350// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
351struct OrtTensorDimensions : std::vector<int64_t> {
352 OrtTensorDimensions() = default;
353 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
354 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
355 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
356 ort.ReleaseTensorTypeAndShapeInfo(info);
357 }
358
359 int64_t Size() const {
360 int64_t s = 1;
361 for (auto it = begin(); it != end(); ++it)
362 s *= *it;
363 return s;
364 }
365
366 bool IsScalar() const {
367 return empty();
368 }
369
370 bool IsVector() const {
371 return size() == 1;
372 }
373};
374