microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
leca/gqa2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ort_c_to_cpp.h

377 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include <vector>
6#include "exceptions.h"
7
8// OrtW: ONNX Runtime C ABI Wrapper
9namespace OrtW {
10
11struct CustomOpApi {
12 CustomOpApi(const OrtApi& api) : api_(api) {}
13
14 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
15 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
16
17 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
18 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
19 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
20 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
21 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
22 size_t dim_values_length) const;
23 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
24
25 template <typename T>
26 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
27 template <typename T>
28 const T* GetTensorData(_Inout_ const OrtValue* value) const;
29
30 void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
31 const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
32
33 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
34 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
35 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
36 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
37 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
38 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
39 size_t dim_count) const;
40
41 void ThrowOnError(OrtStatus* status) const {
42 OrtW::ThrowOnError(api_, status);
43 }
44
45 const OrtApi& GetOrtApi() const { return api_; }
46
47 private:
48 const OrtApi& api_;
49};
50
51class API {
52 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
53 public:
54 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
55 static API self(ort_api);
56 return self;
57 }
58
59 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
60 return instance()->CreateStatus(code, msg);
61 }
62
63 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
64 instance()->ReleaseStatus(ptr);
65 }
66
67 template <typename T>
68 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
69
70 static void ThrowOnError(OrtStatusPtr ptr) {
71 OrtW::ThrowOnError(instance().api_, ptr);
72 }
73
74 // Caller is responsible for releasing OrtMemoryInfo object
75 static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
76 return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
77 }
78
79 static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) {
80 return instance()->ReleaseMemoryInfo(mem_info);
81 }
82
83#if ORT_API_VERSION >= 18
84 static OrtStatusPtr KernelContextGetScratchBuffer(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, size_t count_or_bytes, void** out) {
85 return instance()->KernelContext_GetScratchBuffer(context, mem_info, count_or_bytes, out);
86 }
87#endif
88 private:
89 const OrtApi* operator->() const {
90 return &api_;
91 }
92
93 API(const OrtApi* api) : api_(*api) {
94 if (api == nullptr) {
95 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
96 }
97 }
98
99 const OrtApi& api_;
100};
101
102
103//
104// Custom OP API Inlines
105//
106
107template <>
108inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
109 float out;
110 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
111 return out;
112}
113
114template <>
115inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
116 int64_t out;
117 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
118 return out;
119}
120
121template <>
122inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
123 size_t size = 0;
124 std::string out;
125
126 // Feed nullptr for the data buffer to query the true size of the string attribute
127 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
128
129 if (status == nullptr) {
130 out.resize(size);
131 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
132 out.resize(size - 1); // remove the terminating character '\0'
133 } else {
134 ThrowOnError(status);
135 }
136 return out;
137}
138
139template <>
140inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
141 size_t size = 0;
142 std::vector<float> out;
143
144 // Feed nullptr for the data buffer to query the true size of the attribute
145 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
146
147 if (status == nullptr) {
148 out.resize(size);
149 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
150 } else {
151 ThrowOnError(status);
152 }
153 return out;
154}
155
156template <>
157inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
158 size_t size = 0;
159 std::vector<int64_t> out;
160
161 // Feed nullptr for the data buffer to query the true size of the attribute
162 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
163
164 if (status == nullptr) {
165 out.resize(size);
166 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
167 } else {
168 ThrowOnError(status);
169 }
170 return out;
171}
172
173inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
174 OrtTensorTypeAndShapeInfo* out;
175 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
176 return out;
177}
178
179inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
180 size_t out;
181 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
182 return out;
183}
184
185inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
186 ONNXTensorElementDataType out;
187 ThrowOnError(api_.GetTensorElementType(info, &out));
188 return out;
189}
190
191inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
192 size_t out;
193 ThrowOnError(api_.GetDimensionsCount(info, &out));
194 return out;
195}
196
197inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
198 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
199}
200
201inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
202 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
203}
204
205template <typename T>
206inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
207 T* data = nullptr;
208 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
209 return data;
210}
211
212template <typename T>
213inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
214 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
215}
216
217inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
218 void* data = nullptr;
219 ThrowOnError(api_.GetTensorMutableData(value, &data));
220 return data;
221}
222
223inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
224 return GetTensorMutableRawData(const_cast<OrtValue*>(value));
225}
226
227inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
228 std::vector<int64_t> output(GetDimensionsCount(info));
229 GetDimensions(info, output.data(), output.size());
230 return output;
231}
232
233inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
234 api_.ReleaseTensorTypeAndShapeInfo(input);
235}
236
237inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
238 size_t out;
239 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
240 return out;
241}
242
243inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
244 const OrtValue* out;
245 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
246 return out;
247}
248
249inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
250 size_t out;
251 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
252 return out;
253}
254
255inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
256 _In_ const int64_t* dim_values, size_t dim_count) const {
257 OrtValue* out;
258 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
259 return out;
260}
261
262template <>
263inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
264 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
265}
266
267template <>
268inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
269 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
270}
271
272template <>
273inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
274 size_t size = 0;
275 std::string out;
276 // Feed nullptr for the data buffer to query the true size of the string attribute
277 OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
278 if (status == nullptr) {
279 out.resize(size);
280 status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
281 out.resize(size - 1); // remove the terminating character '\0'
282 }
283
284 if (status == nullptr) {
285 value = std::move(out);
286 }
287
288 return status;
289}
290
291template <class T>
292inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
293 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
294 // Ideally, we should know which kind of error code can be ignored, but it is not available now.
295 // Just ignore all of them.
296 API::ReleaseStatus(status);
297 }
298
299 return nullptr;
300}
301
302template <class T>
303inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
304 T ret;
305 if (API::KernelInfoGetAttribute(info, name, ret)) {
306 ret = default_value;
307 }
308 return ret;
309}
310
311inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
312 return API::CreateStatus(code, msg);
313}
314
315inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
316 return API::CreateStatus(code, msg.c_str());
317}
318
319inline void ReleaseStatus(OrtStatusPtr& status) {
320 API::ReleaseStatus(status);
321 status = nullptr;
322}
323
324} // namespace of OrtW
325
326
327// Deprecated: No needs to create a new class derived from BaseKernel.
328struct BaseKernel {
329 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
330 : api_(api), info_(info), ort_(api_) {
331 }
332
333 template <class T>
334 bool TryToGetAttribute(const char* name, T& value) const noexcept;
335
336 template <class T>
337 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
338 T result = default_value;
339 TryToGetAttribute(name, result);
340 return result;
341 }
342
343 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
344 const std::vector<int64_t>& data);
345
346 protected:
347 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
348
349 const OrtApi& api_;
350 OrtW::CustomOpApi ort_;
351 const OrtKernelInfo& info_;
352};
353
354// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
355struct OrtTensorDimensions : std::vector<int64_t> {
356 OrtTensorDimensions() = default;
357 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
358 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
359 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
360 ort.ReleaseTensorTypeAndShapeInfo(info);
361 }
362
363 int64_t Size() const {
364 int64_t s = 1;
365 for (auto it = begin(); it != end(); ++it)
366 s *= *it;
367 return s;
368 }
369
370 bool IsScalar() const {
371 return empty();
372 }
373
374 bool IsVector() const {
375 return size() == 1;
376 }
377};
378