microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
a761fc59b7ace6392d8782fbb2a74888353ecf65

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime_cpp_api_legacy.hpp

251 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include <string>
#include <vector>
#include "onnxruntime_c_api.h"
7
8//
// DEPRECATED: New custom ops must not use any class, struct, or function from this file.
// TODO: Remove this file once all custom ops have been migrated to the new API.
11//
12namespace OrtW {
13
14struct CustomOpApi {
15 CustomOpApi(const OrtApi& api) : api_(api) {}
16
17 template <typename T> // T is only implemented for std::vector<float>, std::vector<int64_t>, float, int64_t, and string
18 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const;
19
20 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value) const;
21 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
22 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const;
23 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const;
24 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values,
25 size_t dim_values_length) const;
26 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const;
27
28 template <typename T>
29 T* GetTensorMutableData(_Inout_ OrtValue* value) const;
30 template <typename T>
31 const T* GetTensorData(_Inout_ const OrtValue* value) const;
32
33 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const;
34 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const;
35 size_t KernelContext_GetInputCount(const OrtKernelContext* context) const;
36 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const;
37 size_t KernelContext_GetOutputCount(const OrtKernelContext* context) const;
38 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values,
39 size_t dim_count) const;
40
41 void ThrowOnError(OrtStatus* status) const {
42 OrtW::ThrowOnError(api_, status);
43 }
44
45 const OrtApi& GetOrtApi() const { return api_; }
46
47 private:
48 const OrtApi& api_;
49};
50
51//
52// Custom OP API Inlines
53//
54
55template <>
56inline float CustomOpApi::KernelInfoGetAttribute<float>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
57 float out;
58 ThrowOnError(api_.KernelInfoGetAttribute_float(info, name, &out));
59 return out;
60}
61
62template <>
63inline int64_t CustomOpApi::KernelInfoGetAttribute<int64_t>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
64 int64_t out;
65 ThrowOnError(api_.KernelInfoGetAttribute_int64(info, name, &out));
66 return out;
67}
68
69template <>
70inline std::string CustomOpApi::KernelInfoGetAttribute<std::string>(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
71 size_t size = 0;
72 std::string out;
73
74 // Feed nullptr for the data buffer to query the true size of the string attribute
75 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, name, nullptr, &size);
76
77 if (status == nullptr) {
78 out.resize(size);
79 ThrowOnError(api_.KernelInfoGetAttribute_string(info, name, &out[0], &size));
80 out.resize(size - 1); // remove the terminating character '\0'
81 } else {
82 ThrowOnError(status);
83 }
84 return out;
85}
86
87template <>
88inline std::vector<float> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
89 size_t size = 0;
90 std::vector<float> out;
91
92 // Feed nullptr for the data buffer to query the true size of the attribute
93 OrtStatus* status = api_.KernelInfoGetAttributeArray_float(info, name, nullptr, &size);
94
95 if (status == nullptr) {
96 out.resize(size);
97 ThrowOnError(api_.KernelInfoGetAttributeArray_float(info, name, out.data(), &size));
98 } else {
99 ThrowOnError(status);
100 }
101 return out;
102}
103
104template <>
105inline std::vector<int64_t> CustomOpApi::KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name) const {
106 size_t size = 0;
107 std::vector<int64_t> out;
108
109 // Feed nullptr for the data buffer to query the true size of the attribute
110 OrtStatus* status = api_.KernelInfoGetAttributeArray_int64(info, name, nullptr, &size);
111
112 if (status == nullptr) {
113 out.resize(size);
114 ThrowOnError(api_.KernelInfoGetAttributeArray_int64(info, name, out.data(), &size));
115 } else {
116 ThrowOnError(status);
117 }
118 return out;
119}
120
121inline OrtTensorTypeAndShapeInfo* CustomOpApi::GetTensorTypeAndShape(_In_ const OrtValue* value) const {
122 OrtTensorTypeAndShapeInfo* out;
123 ThrowOnError(api_.GetTensorTypeAndShape(value, &out));
124 return out;
125}
126
127inline size_t CustomOpApi::GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
128 size_t out;
129 ThrowOnError(api_.GetTensorShapeElementCount(info, &out));
130 return out;
131}
132
133inline ONNXTensorElementDataType CustomOpApi::GetTensorElementType(const OrtTensorTypeAndShapeInfo* info) const {
134 ONNXTensorElementDataType out;
135 ThrowOnError(api_.GetTensorElementType(info, &out));
136 return out;
137}
138
139inline size_t CustomOpApi::GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info) const {
140 size_t out;
141 ThrowOnError(api_.GetDimensionsCount(info, &out));
142 return out;
143}
144
145inline void CustomOpApi::GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length) const {
146 ThrowOnError(api_.GetDimensions(info, dim_values, dim_values_length));
147}
148
149inline void CustomOpApi::SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count) const {
150 ThrowOnError(api_.SetDimensions(info, dim_values, dim_count));
151}
152
153template <typename T>
154inline T* CustomOpApi::GetTensorMutableData(_Inout_ OrtValue* value) const {
155 T* data = nullptr;
156 ThrowOnError(api_.GetTensorMutableData(value, reinterpret_cast<void**>(&data)));
157 return data;
158}
159
160template <typename T>
161inline const T* CustomOpApi::GetTensorData(_Inout_ const OrtValue* value) const {
162 return GetTensorMutableData<T>(const_cast<OrtValue*>(value));
163}
164
165inline std::vector<int64_t> CustomOpApi::GetTensorShape(const OrtTensorTypeAndShapeInfo* info) const {
166 std::vector<int64_t> output(GetDimensionsCount(info));
167 GetDimensions(info, output.data(), output.size());
168 return output;
169}
170
171inline void CustomOpApi::ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input) const {
172 api_.ReleaseTensorTypeAndShapeInfo(input);
173}
174
175inline size_t CustomOpApi::KernelContext_GetInputCount(const OrtKernelContext* context) const {
176 size_t out;
177 ThrowOnError(api_.KernelContext_GetInputCount(context, &out));
178 return out;
179}
180
181inline const OrtValue* CustomOpApi::KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index) const {
182 const OrtValue* out;
183 ThrowOnError(api_.KernelContext_GetInput(context, index, &out));
184 return out;
185}
186
187inline size_t CustomOpApi::KernelContext_GetOutputCount(const OrtKernelContext* context) const {
188 size_t out;
189 ThrowOnError(api_.KernelContext_GetOutputCount(context, &out));
190 return out;
191}
192
193inline OrtValue* CustomOpApi::KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index,
194 _In_ const int64_t* dim_values, size_t dim_count) const {
195 OrtValue* out;
196 ThrowOnError(api_.KernelContext_GetOutput(context, index, dim_values, dim_count, &out));
197 return out;
198}
199
200} // namespace of OrtW
201
202
203struct BaseKernel {
204 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
205 : api_(api), info_(info), ort_(api_) {
206 }
207
208 template <class T>
209 bool TryToGetAttribute(const char* name, T& value) const noexcept;
210
211 template <class T>
212 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
213 T result = default_value;
214 TryToGetAttribute(name, result);
215 return result;
216 }
217
218 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
219 const std::vector<int64_t>& data);
220
221 protected:
222 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
223
224 const OrtApi& api_;
225 OrtW::CustomOpApi ort_;
226 const OrtKernelInfo& info_;
227};
228
229struct OrtTensorDimensions : std::vector<int64_t> {
230 OrtTensorDimensions() = default;
231 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
232 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
233 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
234 ort.ReleaseTensorTypeAndShapeInfo(info);
235 }
236
237 int64_t Size() const {
238 int64_t s = 1;
239 for (auto it = begin(); it != end(); ++it)
240 s *= *it;
241 return s;
242 }
243
244 bool IsScalar() const {
245 return empty();
246 }
247
248 bool IsVector() const {
249 return size() == 1;
250 }
251};
252