microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
nodeps

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ort_c_to_cpp.h

404 lines · mode: blame

a0c26255Wenbing Li2 years ago1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include <vector>
22340011cao lei2 years ago6#include "exceptions.h"
a0c26255Wenbing Li2 years ago7
64646279Wenbing Li2 years ago8// OrtW: ONNX Runtime C ABI Wrapper
a0c26255Wenbing Li2 years ago9namespace OrtW {
10struct CustomOpApi {
11CustomOpApi(const OrtApi& api) : api_(api) {}
12
13template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
14T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
15
16OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
17size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
18ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
19size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
20void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
21size_t dim_values_length) const;
22void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
23
24template <typename T>
25T* GetTensorMutableData(_Inout_ OrtValue* value) const;
26template <typename T>
27const T* GetTensorData(_Inout_ const OrtValue* value) const;
28
64646279Wenbing Li2 years ago29void* GetTensorMutableRawData(_Inout_ OrtValue* value) const;
30const void* GetTensorRawData(_Inout_ const OrtValue* value) const;
31
a0c26255Wenbing Li2 years ago32std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
33void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
34size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
35const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
36size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
37OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
38size_t dim_count) const;
39
40void ThrowOnError(OrtStatus* status) const {
41OrtW::ThrowOnError(api_, status);
42}
43
44const OrtApi& GetOrtApi() const { return api_; }
45
46private:
47const OrtApi& api_;
48};
49
64646279Wenbing Li2 years ago50class API {
51// To use ONNX C ABI in a way like OrtW::API::CreateStatus.
52public:
53static API& instance(const OrtApi* ort_api = nullptr) noexcept {
54static API self(ort_api);
55return self;
56}
57
58static OrtStatusPtr CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept {
59return instance()->CreateStatus(code, msg);
60}
61
62static void ReleaseStatus(OrtStatusPtr ptr) noexcept {
63instance()->ReleaseStatus(ptr);
64}
65
db33bc33Chester Liu1 years ago66static OrtStatusPtr GetOpAttributeString(const OrtApi& api,
67const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
68size_t size = 0;
69OrtStatus* status = api.KernelInfoGetAttribute_string(&info, name, nullptr, &size);
70if (status == nullptr) {
71value.resize(size);
72status = api.KernelInfoGetAttribute_string(&info, name, &value[0], &size);
73value.resize(size - 1); // remove the terminating character '\0'
74if (status != nullptr) {
75return status; // some unexpected error
76}
77} else {
78// ignore the error, as the attribute is optional
79api.ReleaseStatus(status);
80}
81
82return nullptr;
83}
84
64646279Wenbing Li2 years ago85template <typename T>
86static OrtStatusPtr KernelInfoGetAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept;
87
88static void ThrowOnError(OrtStatusPtr ptr) {
89OrtW::ThrowOnError(instance().api_, ptr);
90}
91
92// Caller is responsible for releasing OrtMemoryInfo object
93static OrtStatusPtr CreateOrtMemoryInfo(const char* name, enum OrtAllocatorType type, int id, enum OrtMemType mem_type, OrtMemoryInfo** out) noexcept {
94return instance()->CreateMemoryInfo(name, type, id, mem_type, out);
95}
96#if ORT_API_VERSION >= 15
97// Caller is responsible for releasing OrtAllocator object: delete static_cast<onnxruntime::OrtAllocatorImpl*> (allocator)
98static OrtStatusPtr GetOrtAllocator(const OrtKernelContext* context, const OrtMemoryInfo* mem_info, OrtAllocator** out) {
99return instance()->KernelContext_GetAllocator(context, mem_info, out);
100}
101#endif
102private:
103const OrtApi* operator->() const {
104return &api_;
105}
106
107API(const OrtApi* api) : api_(*api) {
108if (api == nullptr) {
109ORTX_CXX_API_THROW("ort-extensions internal error: ORT-APIs used before RegisterCustomOps", ORT_RUNTIME_EXCEPTION);
110}
111}
112
113const OrtApi& api_;
114};
115
116
a0c26255Wenbing Li2 years ago117//
118// Custom OP API Inlines
119//
120
121template <>
122inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
123float out;
124ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
125return out;
126}
127
128template <>
129inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
130int64_t out;
131ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
132return out;
133}
134
135template <>
136inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
137size_t size = 0;
138std::string out;
139
140// Feed nullptr for the data buffer to query the true size of the string attribute
141OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
142
143if (status == nullptr) {
144out.resize(size);
145ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
146out.resize(size - 1); // remove the terminating character '\0'
147} else {
148ThrowOnError(status);
149}
150return out;
151}
152
153template <>
154inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
155size_t size = 0;
156std::vector<float> out;
157
158// Feed nullptr for the data buffer to query the true size of the attribute
159OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
160
161if (status == nullptr) {
162out.resize(size);
163ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
164} else {
165ThrowOnError(status);
166}
167return out;
168}
169
170template <>
171inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
172size_t size = 0;
173std::vector<int64_t> out;
174
175// Feed nullptr for the data buffer to query the true size of the attribute
176OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
177
178if (status == nullptr) {
179out.resize(size);
180ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
181} else {
182ThrowOnError(status);
183}
184return out;
185}
186
187inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
188OrtTensorTypeAndShapeInfo* out;
189ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
190return out;
191}
192
193inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
194size_t out;
195ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
196return out;
197}
198
199inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
200ONNXTensorElementDataType out;
201ThrowOnError(api_.GetTensorElementType(info, &out));
202return out;
203}
204
205inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
206size_t out;
207ThrowOnError(api_.GetDimensionsCount(info, &out));
208return out;
209}
210
211inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
212ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
213}
214
215inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
216ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
217}
218
219template <typename T>
220inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
221T* data = nullptr;
222ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
223return data;
224}
225
226template <typename T>
227inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
228return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
229}
230
64646279Wenbing Li2 years ago231inline void* CustomOpApi::GetTensorMutableRawData(_Inout_ OrtValue* value) const {
232void* data = nullptr;
233ThrowOnError(api_.GetTensorMutableData(value, &data));
234return data;
235}
236
237inline const void* CustomOpApi::GetTensorRawData(_Inout_ const OrtValue* value) const {
238return GetTensorMutableRawData(const_cast<OrtValue*>(value));
239}
240
a0c26255Wenbing Li2 years ago241inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
242std::vector<int64_t> output(GetDimensionsCount(info));
243GetDimensions(info, output.data(), output.size());
244return output;
245}
246
247inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
248api_.ReleaseTensorTypeAndShapeInfo(input);
249}
250
251inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
252size_t out;
253ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
254return out;
255}
256
257inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
258const OrtValue* out;
259ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
260return out;
261}
262
263inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
264size_t out;
265ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
266return out;
267}
268
269inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
270_In_ const int64_t* dim_values, size_t dim_count) const {
271OrtValue* out;
272ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
273return out;
274}
275
64646279Wenbing Li2 years ago276template <>
277inline OrtStatusPtr API::KernelInfoGetAttribute<int64_t>(const OrtKernelInfo& info, const char* name, int64_t& value) noexcept {
278return instance()->KernelInfoGetAttribute_int64(&info, name, &value);
279}
280
281template <>
282inline OrtStatusPtr API::KernelInfoGetAttribute<float>(const OrtKernelInfo& info, const char* name, float& value) noexcept {
283return instance()->KernelInfoGetAttribute_float(&info, name, &value);
284}
285
286template <>
287inline OrtStatusPtr API::KernelInfoGetAttribute<std::string>(const OrtKernelInfo& info, const char* name, std::string& value) noexcept {
288size_t size = 0;
289std::string out;
290// Feed nullptr for the data buffer to query the true size of the string attribute
291OrtStatus* status = instance()->KernelInfoGetAttribute_string(&info, name, nullptr, &size);
292if (status == nullptr) {
293out.resize(size);
294status = instance()->KernelInfoGetAttribute_string(&info, name, &out[0], &size);
295out.resize(size - 1); // remove the terminating character '\0'
296}
297
298if (status == nullptr) {
299value = std::move(out);
300}
301
302return status;
303}
304
305template <class T>
306inline OrtStatusPtr GetOpAttribute(const OrtKernelInfo& info, const char* name, T& value) noexcept {
307if (auto status = API::KernelInfoGetAttribute(info, name, value); status) {
308// Ideally, we should know which kind of error code can be ignored, but it is not available now.
309// Just ignore all of them.
310API::ReleaseStatus(status);
311}
312
313return nullptr;
314}
315
316template <class T>
317inline T GetOpAttributeOrDefault(const OrtKernelInfo& info, const char* name, const T& default_value) noexcept {
318T ret;
319if (API::KernelInfoGetAttribute(info, name, ret)) {
320ret = default_value;
321}
322return ret;
323}
324
325inline OrtStatusPtr CreateStatus(const char* msg, OrtErrorCode code) {
326return API::CreateStatus(code, msg);
327}
328
329inline OrtStatusPtr CreateStatus(const std::string& msg, OrtErrorCode code) {
330return API::CreateStatus(code, msg.c_str());
331}
332
333inline void ReleaseStatus(OrtStatusPtr& status) {
334API::ReleaseStatus(status);
335status = nullptr;
336}
337
a0c26255Wenbing Li2 years ago338} // namespace of OrtW
339
340
64646279Wenbing Li2 years ago341// Deprecated: No needs to create a new class derived from BaseKernel.
a0c26255Wenbing Li2 years ago342struct BaseKernel {
343BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
344: api_(api), info_(info), ort_(api_) {
345}
346
347template <class T>
348bool TryToGetAttribute(const char* name, T& value) const noexcept;
349
350template <class T>
351T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
352T result = default_value;
353TryToGetAttribute(name, result);
354return result;
355}
356
357void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
358const std::vector<int64_t>& data);
359
360protected:
361OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
362
363const OrtApi& api_;
364const OrtKernelInfo& info_;
58b55238Chester Liu2 years ago365OrtW::CustomOpApi ort_;
a0c26255Wenbing Li2 years ago366};
367
64646279Wenbing Li2 years ago368// Deprecated: Use OrtW::CustomOpApi::KernelInfoGetAttribute instead
a0c26255Wenbing Li2 years ago369struct OrtTensorDimensions : std::vector<int64_t> {
370OrtTensorDimensions() = default;
371OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
372OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
373std::vector<int64_t>::operator=(ort.GetTensorShape(info));
374ort.ReleaseTensorTypeAndShapeInfo(info);
375}
376
377int64_t Size() const {
378int64_t s = 1;
379for (auto it = begin(); it != end(); ++it)
380s *= *it;
381return s;
382}
383
384bool IsScalar() const {
385return empty();
386}
387
388bool IsVector() const {
389return size() == 1;
390}
391};
176c1d01Wenbing Li1 years ago392
// True for a rank-0 shape, or for a rank-1 shape whose `shape_size` is exactly 1.
// `shape_size` is consulted only when num_dimensions == 1.
inline bool IsScalarOr1ElementVector(size_t num_dimensions, int64_t shape_size) {
  switch (num_dimensions) {
    case 0:
      return true;
    case 1:
      return shape_size == 1;
    default:
      return false;
  }
}
397
// Evaluates `expr` exactly once. If the result is not nullptr (an error status), the macro
// returns it from the enclosing function; otherwise execution continues. The do/while
// wrapper makes the expansion behave as a single statement.
#define ORTW_RETURN_IF_ERROR(expr)                   \
  do {                                               \
    if (auto _status = (expr); _status != nullptr) { \
      return _status;                                \
    }                                                \
  } while (false)