microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
176c1d013864044bcc0747b908bdd32048669401

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ort_c_to_cpp.h

385 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include <vector>
6#include "exceptions.h"
7
8// OrtW: ONNX Runtime C ABI Wrapper
9namespace OrtW {
10struct CustomOpApi {
11 CustomOpApi(const OrtApi& api) : api_(api) {}
12
13 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
14 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
15
16 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
17 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
18 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
19 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
20 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
21 size_t dim_values_length) const;
22 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
23
24 template <typename T>
25 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
26 template <typename T>
27 const T* GetTensorData(_Inout_ const OrtValue* value) const;
28
29 void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
30 const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
31
32 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
33 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
34 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
35 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
36 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
37 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
38 size_t dim_count) const;
39
40 void ThrowOnError(OrtStatus* status) const {
41 OrtW::ThrowOnError(api_, status);
42 }
43
44 const OrtApi& GetOrtApi() const { return api_; }
45
46 private:
47 const OrtApi& api_;
48};
49
50class API {
51 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
52 public:
53 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
54 static API self(ort_api);
55 return self;
56 }
57
58 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
59 return instance()->CreateStatus(code, msg);
60 }
61
62 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
63 instance()->ReleaseStatus(ptr);
64 }
65
66 template <typename T>
67 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
68
69 static void ThrowOnError(OrtStatusPtr ptr) {
70 OrtW::ThrowOnError(instance().api_, ptr);
71 }
72
73 // Caller is responsible for releasing OrtMemoryInfo object
74 static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
75 return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
76 }
77#if ORT_API_VERSION >= 15
78 // Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
79 static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
80 return instance()->KernelContext_GetAllocator(context, mem_info, out);
81 }
82#endif
83 private:
84 const OrtApi* operator->() const {
85 return &api_;
86 }
87
88 API(const OrtApi* api) : api_(*api) {
89 if (api == nullptr) {
90 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
91 }
92 }
93
94 const OrtApi& api_;
95};
96
97
98//
99// Custom OP API Inlines
100//
101
102template <>
103inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
104 float out;
105 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
106 return out;
107}
108
109template <>
110inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
111 int64_t out;
112 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
113 return out;
114}
115
116template <>
117inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
118 size_t size = 0;
119 std::string out;
120
121 // Feed nullptr for the data buffer to query the true size of the string attribute
122 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
123
124 if (status == nullptr) {
125 out.resize(size);
126 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
127 out.resize(size - 1); // remove the terminating character '\0'
128 } else {
129 ThrowOnError(status);
130 }
131 return out;
132}
133
134template <>
135inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
136 size_t size = 0;
137 std::vector<float> out;
138
139 // Feed nullptr for the data buffer to query the true size of the attribute
140 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
141
142 if (status == nullptr) {
143 out.resize(size);
144 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
145 } else {
146 ThrowOnError(status);
147 }
148 return out;
149}
150
151template <>
152inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
153 size_t size = 0;
154 std::vector<int64_t> out;
155
156 // Feed nullptr for the data buffer to query the true size of the attribute
157 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
158
159 if (status == nullptr) {
160 out.resize(size);
161 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
162 } else {
163 ThrowOnError(status);
164 }
165 return out;
166}
167
168inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
169 OrtTensorTypeAndShapeInfo* out;
170 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
171 return out;
172}
173
174inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
175 size_t out;
176 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
177 return out;
178}
179
180inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
181 ONNXTensorElementDataType out;
182 ThrowOnError(api_.GetTensorElementType(info, &out));
183 return out;
184}
185
186inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
187 size_t out;
188 ThrowOnError(api_.GetDimensionsCount(info, &out));
189 return out;
190}
191
192inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
193 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
194}
195
196inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
197 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
198}
199
200template <typename T>
201inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
202 T* data = nullptr;
203 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
204 return data;
205}
206
207template <typename T>
208inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
209 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
210}
211
212inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
213 void* data = nullptr;
214 ThrowOnError(api_.GetTensorMutableData(value, &data));
215 return data;
216}
217
218inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
219 return GetTensorMutableRawData(const_cast<OrtValue*>(value));
220}
221
222inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
223 std::vector<int64_t> output(GetDimensionsCount(info));
224 GetDimensions(info, output.data(), output.size());
225 return output;
226}
227
228inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
229 api_.ReleaseTensorTypeAndShapeInfo(input);
230}
231
232inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
233 size_t out;
234 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
235 return out;
236}
237
238inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
239 const OrtValue* out;
240 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
241 return out;
242}
243
244inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
245 size_t out;
246 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
247 return out;
248}
249
250inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
251 _In_ const int64_t* dim_values, size_t dim_count) const {
252 OrtValue* out;
253 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
254 return out;
255}
256
257template <>
258inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
259 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
260}
261
262template <>
263inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
264 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
265}
266
267template <>
268inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
269 size_t size = 0;
270 std::string out;
271 // Feed nullptr for the data buffer to query the true size of the string attribute
272 OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
273 if (status == nullptr) {
274 out.resize(size);
275 status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
276 out.resize(size - 1); // remove the terminating character '\0'
277 }
278
279 if (status == nullptr) {
280 value = std::move(out);
281 }
282
283 return status;
284}
285
286template <class T>
287inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
288 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
289 // Ideally, we should know which kind of error code can be ignored, but it is not available now.
290 // Just ignore all of them.
291 API::ReleaseStatus(status);
292 }
293
294 return nullptr;
295}
296
297template <class T>
298inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
299 T ret;
300 if (API::KernelInfoGetAttribute(info, name, ret)) {
301 ret = default_value;
302 }
303 return ret;
304}
305
306inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
307 return API::CreateStatus(code, msg);
308}
309
310inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
311 return API::CreateStatus(code, msg.c_str());
312}
313
314inline void ReleaseStatus(OrtStatusPtr& status) {
315 API::ReleaseStatus(status);
316 status = nullptr;
317}
318
319} // namespace of OrtW
320
321
322// Deprecated: No needs to create a new class derived from BaseKernel.
323struct BaseKernel {
324 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
325 : api_(api), info_(info), ort_(api_) {
326 }
327
328 template <class T>
329 bool TryToGetAttribute(const char* name, T& value) const noexcept;
330
331 template <class T>
332 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
333 T result = default_value;
334 TryToGetAttribute(name, result);
335 return result;
336 }
337
338 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
339 const std::vector<int64_t>& data);
340
341 protected:
342 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
343
344 const OrtApi& api_;
345 const OrtKernelInfo& info_;
346 OrtW::CustomOpApi ort_;
347};
348
349// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
350struct OrtTensorDimensions : std::vector<int64_t> {
351 OrtTensorDimensions() = default;
352 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
353 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
354 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
355 ort.ReleaseTensorTypeAndShapeInfo(info);
356 }
357
358 int64_t Size() const {
359 int64_t s = 1;
360 for (auto it = begin(); it != end(); ++it)
361 s *= *it;
362 return s;
363 }
364
365 bool IsScalar() const {
366 return empty();
367 }
368
369 bool IsVector() const {
370 return size() == 1;
371 }
372};
373
// True for a rank-0 tensor (scalar) or a rank-1 tensor whose single dimension is 1.
inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
  const bool is_scalar = num_dimensions == 0;
  const bool is_single_element_vector = num_dimensions == 1 && shape_size == 1;
  return is_scalar || is_single_element_vector;
}
378
// Evaluates `expr` exactly once; if the result is non-null (an error status) it is returned from
// the enclosing function, otherwise execution continues after the macro.
#define ORTX_RETURN_IF_ERROR(expr)                                          \
  do {                                                                      \
    if (auto ortx_status_value = (expr); ortx_status_value != nullptr) {    \
      return ortx_status_value;                                             \
    }                                                                       \
  } while (0)
386