microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
b661d5f22f396e757eb1de6e1ab28f2a50f0e81b

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ort_c_to_cpp.h

404lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include <vector>
6#include "exceptions.h"
7
8// OrtW: ONNX Runtime C ABI Wrapper
9namespace OrtW {
10struct CustomOpApi {
11 CustomOpApi(const OrtApi& api) : api_(api) {}
12
13 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
14 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
15
16 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
17 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
18 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
19 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
20 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
21 size_t dim_values_length) const;
22 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
23
24 template <typename T>
25 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
26 template <typename T>
27 const T* GetTensorData(_Inout_ const OrtValue* value) const;
28
29 void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
30 const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
31
32 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
33 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
34 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
35 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
36 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
37 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
38 size_t dim_count) const;
39
40 void ThrowOnError(OrtStatus* status) const {
41 OrtW::ThrowOnError(api_, status);
42 }
43
44 const OrtApi& GetOrtApi() const { return api_; }
45
46 private:
47 const OrtApi& api_;
48};
49
50class API {
51 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
52 public:
53 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
54 static API self(ort_api);
55 return self;
56 }
57
58 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
59 return instance()->CreateStatus(code, msg);
60 }
61
62 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
63 instance()->ReleaseStatus(ptr);
64 }
65
66 static OrtStatusPtr GetOpAttributeString(const OrtApi& api,
67 const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
68 size_t size = 0;
69 OrtStatus* status = api.KernelInfoGetAttribute_string(&info, name, nullptr, &size);
70 if (status == nullptr) {
71 value.resize(size);
72 status = api.KernelInfoGetAttribute_string(&info, name, &value[0], &size);
73 value.resize(size - 1); // remove the terminating character '\0'
74 if (status != nullptr) {
75 return status; // some unexpected error
76 }
77 } else {
78 // ignore the error, as the attribute is optional
79 api.ReleaseStatus(status);
80 }
81
82 return nullptr;
83 }
84
85 template <typename T>
86 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
87
88 static void ThrowOnError(OrtStatusPtr ptr) {
89 OrtW::ThrowOnError(instance().api_, ptr);
90 }
91
92 // Caller is responsible for releasing OrtMemoryInfo object
93 static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
94 return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
95 }
96#if ORT_API_VERSION >= 15
97 // Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
98 static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
99 return instance()->KernelContext_GetAllocator(context, mem_info, out);
100 }
101#endif
102 private:
103 const OrtApi* operator->() const {
104 return &api_;
105 }
106
107 API(const OrtApi* api) : api_(*api) {
108 if (api == nullptr) {
109 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
110 }
111 }
112
113 const OrtApi& api_;
114};
115
116
117//
118// Custom OP API Inlines
119//
120
121template <>
122inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
123 float out;
124 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
125 return out;
126}
127
128template <>
129inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
130 int64_t out;
131 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
132 return out;
133}
134
135template <>
136inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
137 size_t size = 0;
138 std::string out;
139
140 // Feed nullptr for the data buffer to query the true size of the string attribute
141 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
142
143 if (status == nullptr) {
144 out.resize(size);
145 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
146 out.resize(size - 1); // remove the terminating character '\0'
147 } else {
148 ThrowOnError(status);
149 }
150 return out;
151}
152
153template <>
154inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
155 size_t size = 0;
156 std::vector<float> out;
157
158 // Feed nullptr for the data buffer to query the true size of the attribute
159 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
160
161 if (status == nullptr) {
162 out.resize(size);
163 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
164 } else {
165 ThrowOnError(status);
166 }
167 return out;
168}
169
170template <>
171inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
172 size_t size = 0;
173 std::vector<int64_t> out;
174
175 // Feed nullptr for the data buffer to query the true size of the attribute
176 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
177
178 if (status == nullptr) {
179 out.resize(size);
180 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
181 } else {
182 ThrowOnError(status);
183 }
184 return out;
185}
186
187inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
188 OrtTensorTypeAndShapeInfo* out;
189 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
190 return out;
191}
192
193inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
194 size_t out;
195 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
196 return out;
197}
198
199inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
200 ONNXTensorElementDataType out;
201 ThrowOnError(api_.GetTensorElementType(info, &out));
202 return out;
203}
204
205inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
206 size_t out;
207 ThrowOnError(api_.GetDimensionsCount(info, &out));
208 return out;
209}
210
211inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
212 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
213}
214
215inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
216 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
217}
218
219template <typename T>
220inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
221 T* data = nullptr;
222 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
223 return data;
224}
225
226template <typename T>
227inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
228 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
229}
230
231inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
232 void* data = nullptr;
233 ThrowOnError(api_.GetTensorMutableData(value, &data));
234 return data;
235}
236
237inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
238 return GetTensorMutableRawData(const_cast<OrtValue*>(value));
239}
240
241inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
242 std::vector<int64_t> output(GetDimensionsCount(info));
243 GetDimensions(info, output.data(), output.size());
244 return output;
245}
246
247inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
248 api_.ReleaseTensorTypeAndShapeInfo(input);
249}
250
251inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
252 size_t out;
253 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
254 return out;
255}
256
257inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
258 const OrtValue* out;
259 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
260 return out;
261}
262
263inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
264 size_t out;
265 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
266 return out;
267}
268
269inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
270 _In_ const int64_t* dim_values, size_t dim_count) const {
271 OrtValue* out;
272 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
273 return out;
274}
275
276template <>
277inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
278 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
279}
280
281template <>
282inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
283 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
284}
285
286template <>
287inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
288 size_t size = 0;
289 std::string out;
290 // Feed nullptr for the data buffer to query the true size of the string attribute
291 OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
292 if (status == nullptr) {
293 out.resize(size);
294 status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
295 out.resize(size - 1); // remove the terminating character '\0'
296 }
297
298 if (status == nullptr) {
299 value = std::move(out);
300 }
301
302 return status;
303}
304
305template <class T>
306inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
307 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
308 // Ideally, we should know which kind of error code can be ignored, but it is not available now.
309 // Just ignore all of them.
310 API::ReleaseStatus(status);
311 }
312
313 return nullptr;
314}
315
316template <class T>
317inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
318 T ret;
319 if (API::KernelInfoGetAttribute(info, name, ret)) {
320 ret = default_value;
321 }
322 return ret;
323}
324
325inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
326 return API::CreateStatus(code, msg);
327}
328
329inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
330 return API::CreateStatus(code, msg.c_str());
331}
332
333inline void ReleaseStatus(OrtStatusPtr& status) {
334 API::ReleaseStatus(status);
335 status = nullptr;
336}
337
338} // namespace of OrtW
339
340
341// Deprecated: No needs to create a new class derived from BaseKernel.
342struct BaseKernel {
343 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
344 : api_(api), info_(info), ort_(api_) {
345 }
346
347 template <class T>
348 bool TryToGetAttribute(const char* name, T& value) const noexcept;
349
350 template <class T>
351 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
352 T result = default_value;
353 TryToGetAttribute(name, result);
354 return result;
355 }
356
357 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
358 const std::vector<int64_t>& data);
359
360 protected:
361 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
362
363 const OrtApi& api_;
364 const OrtKernelInfo& info_;
365 OrtW::CustomOpApi ort_;
366};
367
368// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
369struct OrtTensorDimensions : std::vector<int64_t> {
370 OrtTensorDimensions() = default;
371 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
372 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
373 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
374 ort.ReleaseTensorTypeAndShapeInfo(info);
375 }
376
377 int64_t Size() const {
378 int64_t s = 1;
379 for (auto it = begin(); it != end(); ++it)
380 s *= *it;
381 return s;
382 }
383
384 bool IsScalar() const {
385 return empty();
386 }
387
388 bool IsVector() const {
389 return size() == 1;
390 }
391};
392
// True when `num_dimensions` is 0 (scalar), or when it is 1 and `shape_size` is 1
// (shape_size presumably being the total element count — confirm at call sites).
inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
  switch (num_dimensions) {
    case 0:
      return true;
    case 1:
      return shape_size == 1;
    default:
      return false;
  }
}
397
// Evaluates `expr` exactly once; if the resulting status pointer is non-null, returns it from the
// enclosing function. The do/while(0) wrapper makes the macro behave as a single statement.
#define ORTW_RETURN_IF_ERROR(expr)                                   \
  do {                                                               \
    if (auto ortw_status_ = (expr); ortw_status_ != nullptr) {       \
      return ortw_status_;                                           \
    }                                                                \
  } while (0)
405