microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
cf7d14bc9c2a1a51de5c02a657d00a1943fbef55

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/tensor_api.h

492 lines · mode: code

#pragma once
#include <optional>
#include <numeric>
#include <string_view>
#include <type_traits>
#include "onnxruntime_customop.hpp"
#include "onnxruntime_f16.h"
#include "kernel_context.h"

9namespace Ort {
10namespace Custom {
11
/// Non-owning, read-only view over a contiguous run of `T` elements.
/// The span never allocates or frees memory; whoever calls Assign() must keep
/// the viewed buffer alive for as long as the span is read.
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};

  /// Points the view at `count` elements starting at `first`.
  void Assign(const T* first, size_t count) {
    size_ = count;
    data_ = first;
  }

  /// Pointer to the first viewed element (null for a default-constructed span).
  const T* data() const { return data_; }

  /// Number of viewed elements.
  size_t size() const { return size_; }

  /// Returns a copy of the element at `pos`; no bounds checking is performed.
  T operator[](size_t pos) const { return *(data_ + pos); }
};


#if ORT_API_VERSION >= 16

// Explicit Span specializations for the 16-bit floating-point types (only
// compiled when ORT_API_VERSION >= 16). Member-for-member they match the
// primary Span template: a non-owning view returning elements by value.
// NOTE(review): they look redundant with the primary template — confirm
// nothing depends on them before removing.
template <>
struct Span<MFloat16> {
  const MFloat16* data_ = {};
  size_t size_ = {};

  void Assign(const MFloat16* first, size_t count) {
    size_ = count;
    data_ = first;
  }
  const MFloat16* data() const { return data_; }
  size_t size() const { return size_; }
  MFloat16 operator[](size_t pos) const { return *(data_ + pos); }
};

template <>
struct Span<BFloat16> {
  const BFloat16* data_ = {};
  size_t size_ = {};

  void Assign(const BFloat16* first, size_t count) {
    size_ = count;
    data_ = first;
  }
  const BFloat16* data() const { return data_; }
  size_t size() const { return size_; }
  BFloat16 operator[](size_t pos) const { return *(data_ + pos); }
};

#endif

/// Abstract backing store for a numeric Tensor: exposes the raw element
/// buffer and the tensor shape.
class ITensorStorage {
public:
  /// Virtual so that destroying a concrete storage through an ITensorStorage
  /// pointer (Tensor holds a std::unique_ptr<ITensorStorage>) runs the
  /// derived destructor and lets it release its buffer.
  virtual ~ITensorStorage() = default;

  /// Tensor dimensions; implementations may throw when not yet initialized.
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Untyped pointer to the first element.
  virtual const void* DataRaw() const = 0;
  /// True once a shape has been attached to the storage.
  virtual bool IsInitialized() const = 0;
  /// Attaches `shape` and is expected to return a writable buffer large enough
  /// for the product of the dims times `element_size` bytes.
  virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;
};


/// Minimal allocation interface used by OrtEagerTensorStorage to obtain and
/// release tensor buffers.
class IAllocator {
public:
  /// Virtual so concrete allocators can be destroyed through an IAllocator pointer.
  virtual ~IAllocator() = default;
  /// Expected to return a buffer of at least `size` bytes.
  virtual void* Alloc(size_t size) = 0;
  /// Releases a buffer previously obtained from Alloc().
  virtual void Free(void* p) = 0;
};

77// TODO: remove this
78class TestAllocator : public IAllocator {
79public:
80 void* Alloc(size_t size) override {
81 return malloc(size);
82 }
83
84 void Free(void* p) override {
85 if (p){
86 free(p);
87 }
88 }
89};
90
91class OrtEagerTensorStorage : public ITensorStorage {
92public:
93 OrtEagerTensorStorage(const std::vector<int64_t>& shape,
94 void* buffer) : buffer_(buffer), shape_(shape){
95
96 }
97
98 OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
99 }
100
101 virtual ~OrtEagerTensorStorage(){
102 if (allocator_ && buffer_)
103 allocator_->Free(buffer_);
104 }
105
106 const std::vector<int64_t>& Shape() const override {
107 if (!IsInitialized())
108 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
109 return *shape_;
110 }
111
112 virtual bool IsInitialized() const override {
113 return shape_.has_value();
114 }
115
116 const void* DataRaw() const override {
117 return buffer_;
118 }
119
120 void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
121 if (IsInitialized())
122 return buffer_;
123 assert(allocator_);
124 shape_ = shape;
125 int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
126 auto buffer_size = n_elem * element_size;
127 buffer_ = allocator_->Alloc(buffer_size);
128 return buffer_;
129 }
130
131private:
132 void* buffer_ {};
133 std::optional<std::vector<int64_t>> shape_;
134 // caller need to make sure the allocator is alive
135 IAllocator* allocator_;
136};
137
138template <typename T>
139class Tensor : public Arg {
140 public:
141 using TT = typename std::remove_reference<T>::type;
142 Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
143 }
144
145 Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
146
147 Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
148
149 virtual ~Tensor() = default;
150
151 operator bool() const {
152 return storage_->IsInitialized();
153 }
154
155 const std::vector<int64_t>& Shape() const {
156 return storage_->Shape();
157 }
158
159 int64_t NumberOfElement() const {
160 auto& shape = storage_->Shape();
161 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
162 }
163
164 std::string Shape2Str() const {
165 if (storage_->IsInitialized()) {
166 std::string shape_str;
167 auto& shape = storage_->Shape();
168 for (const auto& dim : shape) {
169 shape_str.append(std::to_string(dim));
170 shape_str.append(", ");
171 }
172 return shape_str;
173 } else {
174 return "empty";
175 }
176 }
177
178 const TT* Data() const {
179#if ORT_API_VERSION >= 16
180 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
181 return reinterpret_cast<const TT*>(storage_->DataRaw());
182 else
183#endif
184 return static_cast<const TT*>(storage_->DataRaw());
185 }
186
187 const void* DataRaw() const {
188 return storage_->DataRaw();
189 }
190
191 size_t SizeInBytes() const {
192 return NumberOfElement() * sizeof(TT);
193 }
194
195 TT* Allocate(const std::vector<int64_t>& shape) {
196 // it should be OK to allocate multiple times
197 void* buffer = storage_->Initialize(shape, sizeof(TT));
198#if ORT_API_VERSION >= 16
199 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
200 return reinterpret_cast<TT*>(buffer);
201 else
202#endif
203 return static_cast<TT*>(buffer);
204 }
205
206 const Span<T>& AsSpan() {
207#if ORT_API_VERSION >= 16
208 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
209 ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
210 }
211 else{
212#endif
213 auto& shape = storage_->Shape();
214 if (shape.size() != 1) {
215 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
216 }
217 span_.Assign(Data(), shape[0]);
218 return span_;
219#if ORT_API_VERSION >= 16
220 }
221#endif
222 }
223
224 const T& AsScalar() {
225#if ORT_API_VERSION >= 16
226 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
227 ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
228 }
229 else{
230#endif
231 auto& shape = storage_->Shape();
232 if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
233 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
234 }
235 return *Data();
236#if ORT_API_VERSION >= 16
237 }
238#endif
239 }
240
241 private:
242 std::unique_ptr<ITensorStorage> storage_;
243 Span<T> span_;
244};
245
/// Abstract backing store for string tensors. T is the element type exposed
/// through Data(); this header instantiates it with std::string and
/// std::string_view.
template<typename T>
class IStringTensorStorage{
public:
  using strings = std::vector<T>;

  /// Virtual so that std::unique_ptr<IStringTensorStorage<T>> (held by the
  /// string Tensor specializations) destroys the concrete storage correctly.
  virtual ~IStringTensorStorage() = default;

  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Raw pointer to character data.
  virtual const void* DataRaw() const = 0;
  virtual const strings& Data() const = 0;
  virtual bool IsInitialized() const = 0;
  /// Sets the tensor content and its shape `dims`.
  virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;
  virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;
};

258template<typename T>
259class EagerStringTensorStorage : public IStringTensorStorage<T>{
260public:
261 using strings = std::vector<T>;
262 EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{ss.size()}){}
263
264 EagerStringTensorStorage() {}
265
266 const std::vector<int64_t>& Shape() const override {
267 if (!IsInitialized())
268 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
269 return *shape_;
270 }
271
272 virtual const void* DataRaw() const override {
273 if (input_strings_.size() != 1) {
274 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
275 }
276 if constexpr (std::is_same<std::string_view, T>::value)
277 return reinterpret_cast<const void*>(input_strings_[0].data());
278 else
279 return reinterpret_cast<const void*>(input_strings_[0].c_str());
280 }
281
282 virtual bool IsInitialized() const override {
283 return shape_.has_value();
284 }
285
286 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
287 if constexpr (std::is_same<std::string_view, T>::value)
288 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
289 input_strings_.assign(ss.begin(), ss.end());
290 shape_ = dims;
291 }
292
293 const strings& Data() const override {
294 return input_strings_;
295 }
296
297 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
298 if constexpr (std::is_same<std::string_view, T>::value)
299 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
300
301 for (const char* s : ss){
302 input_strings_.push_back(s);
303 }
304 shape_ = dims;
305 }
306
307private:
308 std::vector<T> input_strings_;
309 std::optional<std::vector<int64_t>> shape_;
310};
311
312template <>
313class Tensor<std::string> : public Arg {
314 public:
315 using strings = std::vector<std::string>;
316
317 Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
318
319 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
320
321 Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
322
323 const strings& Data() const {
324 return storage_->Data();
325 }
326
327 const std::vector<int64_t>& Shape() const {
328 return storage_->Shape();
329 }
330
331 int64_t NumberOfElement() const {
332 auto& shape = storage_->Shape();
333 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
334 }
335
336 std::string Shape2Str() const {
337 if (storage_->IsInitialized()) {
338 std::string shape_str;
339 auto& shape = storage_->Shape();
340 for (const auto& dim : shape) {
341 shape_str.append(std::to_string(dim));
342 shape_str.append(", ");
343 }
344 return shape_str;
345 } else {
346 return "empty";
347 }
348 }
349
350 const void* DataRaw() const {
351 return storage_->DataRaw();
352 }
353
354 size_t SizeInBytes() const {
355 auto& ss = storage_->Data();
356 if (ss.size() != 1) {
357 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
358 }
359 return ss[0].size();
360 }
361
362 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
363 storage_->SetStringOutput(ss, dims);
364 }
365 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
366 storage_->SetStringOutput(ss, dims);
367 }
368 const Span<std::string>& AsSpan() {
369 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
370 }
371 const std::string& AsScalar() {
372 auto& ss = storage_->Data();
373 if (ss.size() != 1) {
374 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
375 }
376 return ss[0];
377 }
378
379 private:
380 std::unique_ptr<IStringTensorStorage<std::string>> storage_;
381};
382
383
384template <>
385class Tensor<std::string_view> : public Arg {
386 public:
387 using strings = std::vector<std::string_view>;
388
389 Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
390
391 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
392
393 const strings& Data() const {
394 return storage_->Data();
395 }
396
397 const std::vector<int64_t>& Shape() const {
398 return storage_->Shape();
399 }
400
401 int64_t NumberOfElement() const {
402 auto& shape = storage_->Shape();
403 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
404 }
405
406 std::string Shape2Str() const {
407 if (storage_->IsInitialized()) {
408 std::string shape_str;
409 auto& shape = storage_->Shape();
410 for (const auto& dim : shape) {
411 shape_str.append(std::to_string(dim));
412 shape_str.append(", ");
413 }
414 return shape_str;
415 } else {
416 return "empty";
417 }
418 }
419
420 const void* DataRaw() const {
421 return storage_->DataRaw();
422 }
423
424 size_t SizeInBytes() const {
425 auto& ss = storage_->Data();
426 if (ss.size() != 1) {
427 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
428 }
429 return ss[0].size();
430 }
431
432 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
433 storage_->SetStringOutput(ss, dims);
434 }
435 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
436 storage_->SetStringOutput(ss, dims);
437 }
438 const Span<std::string_view>& AsSpan() {
439 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
440 }
441 const std::string_view& AsScalar() {
442 auto& ss = storage_->Data();
443 if (ss.size() != 1) {
444 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
445 }
446 return ss[0];
447 }
448
449 private:
450 std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
451};
452
453
/// Fixed collection of named, heterogeneously typed values.
///
/// keys[i] names the i-th element of the value tuple. Names are compared by
/// string content, not by pointer identity, so callers may pass any
/// NUL-terminated string (not only the exact literal used as the key).
template<typename ...Args>
class NamedArgumentDict{
public:
  using ValueTuple = std::tuple<Args...>;

  NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args) : names_(keys), entries_(args) {
  }

  /// Returns the value stored under `name`, or `default_value` when no key matches.
  /// The first matching key wins.
  /// @throws std::runtime_error if a key matches but its value type is not exactly T.
  template<typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
    return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
  }

private:
  // True when both names are null or both spell the same string.
  static bool NameEquals(const char* lhs, const char* rhs) {
    if (lhs == rhs)
      return true;
    if (lhs == nullptr || rhs == nullptr)
      return false;
    return std::string_view(lhs) == std::string_view(rhs);
  }

  // Checks tuple slot I, then recurses to I + 1; yields default_value after the last slot.
  template<size_t I, typename T>
  T TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
    if constexpr (I == sizeof...(Args)) {
      return default_value;
    } else {
      // A key list shorter than the tuple leaves the trailing slots unnamed; they never match.
      if (I < names_.size() && NameEquals(names_[I], name)) {
        if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
          return std::get<I>(entries_);
        else
          throw std::runtime_error("name matched but type is not");
      }
      return TryToGetAttributeWithDefaultInternal<I + 1>(name, default_value);
    }
  }

  std::vector<const char*> names_;
  std::tuple<Args...> entries_;

};

491}
492}
493