microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
leca/pagedAttention

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ort_c_to_cpp.h

376lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include <vector>
6#include "exceptions.h"
7
8// OrtW: ONNX Runtime C ABI Wrapper
9namespace OrtW {
10
11struct CustomOpApi {
12 CustomOpApi(const OrtApi& api) : api_(api) {}
13
14 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
15 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
16
17 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
18 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
19 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
20 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
21 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
22 size_t dim_values_length) const;
23 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
24
25 template <typename T>
26 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
27 template <typename T>
28 const T* GetTensorData(_Inout_ const OrtValue* value) const;
29
30 void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
31 const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
32
33 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
34 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
35 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
36 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
37 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
38 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
39 size_t dim_count) const;
40
41 void ThrowOnError(OrtStatus* status) const {
42 OrtW::ThrowOnError(api_, status);
43 }
44
45 const OrtApi& GetOrtApi() const { return api_; }
46
47 private:
48 const OrtApi& api_;
49};
50
51class API {
52 // To use ONNX C ABI in a way like OrtW::API::CreateStatus.
53 public:
54 static API& instance(const OrtApi* ort_api = nullptr) noexcept {
55 static API self(ort_api);
56 return self;
57 }
58
59 static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
60 return instance()->CreateStatus(code, msg);
61 }
62
63 static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
64 instance()->ReleaseStatus(ptr);
65 }
66
67 template <typename T>
68 static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
69
70 static void ThrowOnError(OrtStatusPtr ptr) {
71 OrtW::ThrowOnError(instance().api_, ptr);
72 }
73
74 // Caller is responsible for releasing OrtMemoryInfo object
75 static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
76 return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
77 }
78#if ORT_API_VERSION >= 15
79 // Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
80 static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
81 return instance()->KernelContext_GetAllocator(context, mem_info, out);
82 }
83#endif
84 static void ReleaseMemoryInfo(OrtMemoryInfo* mem_info) {
85 return instance()->ReleaseMemoryInfo(mem_info);
86 }
87 private:
88 const OrtApi* operator->() const {
89 return &api_;
90 }
91
92 API(const OrtApi* api) : api_(*api) {
93 if (api == nullptr) {
94 ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
95 }
96 }
97
98 const OrtApi& api_;
99};
100
101
102//
103// Custom OP API Inlines
104//
105
106template <>
107inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
108 float out;
109 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
110 return out;
111}
112
113template <>
114inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
115 int64_t out;
116 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
117 return out;
118}
119
120template <>
121inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
122 size_t size = 0;
123 std::string out;
124
125 // Feed nullptr for the data buffer to query the true size of the string attribute
126 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
127
128 if (status == nullptr) {
129 out.resize(size);
130 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
131 out.resize(size - 1); // remove the terminating character '\0'
132 } else {
133 ThrowOnError(status);
134 }
135 return out;
136}
137
138template <>
139inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
140 size_t size = 0;
141 std::vector<float> out;
142
143 // Feed nullptr for the data buffer to query the true size of the attribute
144 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
145
146 if (status == nullptr) {
147 out.resize(size);
148 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
149 } else {
150 ThrowOnError(status);
151 }
152 return out;
153}
154
155template <>
156inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
157 size_t size = 0;
158 std::vector<int64_t> out;
159
160 // Feed nullptr for the data buffer to query the true size of the attribute
161 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
162
163 if (status == nullptr) {
164 out.resize(size);
165 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
166 } else {
167 ThrowOnError(status);
168 }
169 return out;
170}
171
172inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
173 OrtTensorTypeAndShapeInfo* out;
174 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
175 return out;
176}
177
178inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
179 size_t out;
180 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
181 return out;
182}
183
184inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
185 ONNXTensorElementDataType out;
186 ThrowOnError(api_.GetTensorElementType(info, &out));
187 return out;
188}
189
190inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
191 size_t out;
192 ThrowOnError(api_.GetDimensionsCount(info, &out));
193 return out;
194}
195
196inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
197 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
198}
199
200inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
201 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
202}
203
204template <typename T>
205inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
206 T* data = nullptr;
207 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
208 return data;
209}
210
211template <typename T>
212inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
213 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
214}
215
216inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
217 void* data = nullptr;
218 ThrowOnError(api_.GetTensorMutableData(value, &data));
219 return data;
220}
221
222inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
223 return GetTensorMutableRawData(const_cast<OrtValue*>(value));
224}
225
226inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
227 std::vector<int64_t> output(GetDimensionsCount(info));
228 GetDimensions(info, output.data(), output.size());
229 return output;
230}
231
232inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
233 api_.ReleaseTensorTypeAndShapeInfo(input);
234}
235
236inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
237 size_t out;
238 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
239 return out;
240}
241
242inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
243 const OrtValue* out;
244 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
245 return out;
246}
247
248inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
249 size_t out;
250 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
251 return out;
252}
253
254inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
255 _In_ const int64_t* dim_values, size_t dim_count) const {
256 OrtValue* out;
257 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
258 return out;
259}
260
261template <>
262inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
263 return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
264}
265
266template <>
267inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
268 return instance()->KernelInfoGetAttribute_float(&info, name, &value);
269}
270
271template <>
272inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
273 size_t size = 0;
274 std::string out;
275 // Feed nullptr for the data buffer to query the true size of the string attribute
276 OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
277 if (status == nullptr) {
278 out.resize(size);
279 status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
280 out.resize(size - 1); // remove the terminating character '\0'
281 }
282
283 if (status == nullptr) {
284 value = std::move(out);
285 }
286
287 return status;
288}
289
290template <class T>
291inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
292 if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
293 // Ideally, we should know which kind of error code can be ignored, but it is not available now.
294 // Just ignore all of them.
295 API::ReleaseStatus(status);
296 }
297
298 return nullptr;
299}
300
301template <class T>
302inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
303 T ret;
304 if (API::KernelInfoGetAttribute(info, name, ret)) {
305 ret = default_value;
306 }
307 return ret;
308}
309
310inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
311 return API::CreateStatus(code, msg);
312}
313
314inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
315 return API::CreateStatus(code, msg.c_str());
316}
317
318inline void ReleaseStatus(OrtStatusPtr& status) {
319 API::ReleaseStatus(status);
320 status = nullptr;
321}
322
323} // namespace of OrtW
324
325
// Deprecated: there is no need to create a new class derived from BaseKernel.
327struct BaseKernel {
328 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
329 : api_(api), info_(info), ort_(api_) {
330 }
331
332 template <class T>
333 bool TryToGetAttribute(const char* name, T& value) const noexcept;
334
335 template <class T>
336 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
337 T result = default_value;
338 TryToGetAttribute(name, result);
339 return result;
340 }
341
342 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
343 const std::vector<int64_t>& data);
344
345 protected:
346 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
347
348 const OrtApi& api_;
349 OrtW::CustomOpApi ort_;
350 const OrtKernelInfo& info_;
351};
352
353// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
354struct OrtTensorDimensions : std::vector<int64_t> {
355 OrtTensorDimensions() = default;
356 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
357 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
358 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
359 ort.ReleaseTensorTypeAndShapeInfo(info);
360 }
361
362 int64_t Size() const {
363 int64_t s = 1;
364 for (auto it = begin(); it != end(); ++it)
365 s *= *it;
366 return s;
367 }
368
369 bool IsScalar() const {
370 return empty();
371 }
372
373 bool IsVector() const {
374 return size() == 1;
375 }
376};
377