microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b661d5f22f396e757eb1de6e1ab28f2a50f0e81b

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/tensor_api.h

601 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include <assert.h>

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include "onnxruntime_f16.h"
#include "kernel_context.h"
12
13namespace Ort {
14namespace Custom {
15
/// Lightweight, non-owning view over `size()` contiguous elements of type T.
///
/// Assign() only stores the pointer and count, so the viewed memory (e.g. a
/// tensor's buffer) must outlive the Span. Indexing is not bounds-checked.
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};

  /// Points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }

  size_t size() const { return size_; }

  /// True when the view contains no elements.
  bool empty() const { return size_ == 0; }

  /// Returns a copy of the element at `indice`; the index is not bounds-checked.
  T operator[](size_t indice) const {
    return data_[indice];
  }

  const T* data() const { return data_; }

  // Pointer iterators so a Span can be used in range-based for loops and std algorithms.
  // For a default (null, 0) span, data_ + 0 is well-defined and begin() == end().
  const T* begin() const { return data_; }
  const T* end() const { return data_ + size_; }
};
30
31
32#if ORT_API_VERSION >= 16
33
34template <>
35struct Span<MFloat16> {
36 const MFloat16* data_ = {};
37 size_t size_ = {};
38 void Assign(const MFloat16* data, size_t size) {
39 data_ = data;
40 size_ = size;
41 }
42 size_t size() const { return size_; }
43 MFloat16 operator[](size_t indice) const {
44 return data_[indice];
45 }
46 const MFloat16* data() const { return data_; }
47};
48
49template <>
50struct Span<BFloat16> {
51 const BFloat16* data_ = {};
52 size_t size_ = {};
53 void Assign(const BFloat16* data, size_t size) {
54 data_ = data;
55 size_ = size;
56 }
57 size_t size() const { return size_; }
58 BFloat16 operator[](size_t indice) const {
59 return data_[indice];
60 }
61 const BFloat16* data() const { return data_; }
62};
63
64#endif
65
/// Abstract backing store for a non-string Tensor<T>: a shape plus a raw element buffer.
///
/// NOTE(review): the declaration order of these virtual functions determines the vtable
/// layout that every implementation relies on; keep it stable.
class ITensorStorage{
public:
  /// Dimensions of the stored tensor. (OrtEagerTensorStorage throws if no shape is set yet.)
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Untyped pointer to the element buffer; Tensor<T>::Data() casts it to the element type.
  virtual const void* DataRaw() const = 0;
  /// Whether the storage is ready for use (OrtEagerTensorStorage: a shape has been set).
  /// Tensor<T>'s operator bool reports this value.
  virtual bool IsInitialized() const = 0;
  /// Prepares storage for `shape` and returns a writable buffer. Tensor<T>::Allocate passes
  /// sizeof(element) as `element_size` and treats the result as the element array.
  virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;
  /// Detaches the buffer and returns it to the caller
  /// (OrtEagerTensorStorage also resets itself to uninitialized).
  virtual void* Release() = 0;
  virtual ~ITensorStorage() = default;
};
75
76
/// Abstract allocator used by OrtEagerTensorStorage to obtain and return tensor buffers.
class IAllocator {
public:
  /// Returns a buffer for `size` bytes (OrtEagerTensorStorage requests
  /// number-of-elements * element-size bytes).
  virtual void* Alloc(size_t size) = 0;
  /// Returns a buffer previously obtained from Alloc().
  virtual void Free(void* p) = 0;
  // Virtual so a concrete allocator can be destroyed through an IAllocator* without UB.
  // Declared last so the existing Alloc/Free entries keep their vtable positions.
  virtual ~IAllocator() = default;
};
82
83
84class OrtEagerTensorStorage : public ITensorStorage {
85public:
86 OrtEagerTensorStorage(const std::vector<int64_t>& shape,
87 void* buffer) : buffer_(buffer), shape_(shape){
88
89 }
90
91 OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
92 }
93
94 ~OrtEagerTensorStorage() override{
95 if (allocator_ && buffer_)
96 allocator_->Free(buffer_);
97 }
98
99 const std::vector<int64_t>& Shape() const override {
100 if (!IsInitialized())
101 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
102 return *shape_;
103 }
104
105 bool IsInitialized() const override {
106 return shape_.has_value();
107 }
108
109 const void* DataRaw() const override {
110 return buffer_;
111 }
112
113 void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
114 if (IsInitialized())
115 return buffer_;
116 assert(allocator_);
117 shape_ = shape;
118 int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
119 auto buffer_size = n_elem * element_size;
120 buffer_ = allocator_->Alloc(buffer_size);
121 return buffer_;
122 }
123
124 void* Release() override {
125 void* tmp = buffer_;
126 buffer_ = 0;
127 shape_ = std::nullopt;
128 return tmp;
129 }
130
131private:
132 void* buffer_ {};
133 std::optional<std::vector<int64_t>> shape_;
134 // caller need to make sure the allocator is alive
135 IAllocator* allocator_{};
136};
137
138template <typename TT>
139ONNXTensorElementDataType GetOrtDType(){
140 if constexpr (std::is_same<TT, bool>::value)
141 return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
142 else if constexpr (std::is_same<TT, float>::value)
143 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
144 else if constexpr (std::is_same<TT, double>::value)
145 return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
146 else if constexpr (std::is_same<TT, uint8_t>::value)
147 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
148 else if constexpr (std::is_same<TT, int8_t>::value)
149 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
150 else if constexpr (std::is_same<TT, uint16_t>::value)
151 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
152 else if constexpr (std::is_same<TT, int16_t>::value)
153 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
154 else if constexpr (std::is_same<TT, uint32_t>::value)
155 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
156 else if constexpr (std::is_same<TT, int32_t>::value)
157 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
158 else if constexpr (std::is_same<TT, uint64_t>::value)
159 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
160 else if constexpr (std::is_same<TT, int64_t>::value)
161 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
162 else if constexpr (std::is_same<TT, std::string>::value)
163 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
164 ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
165 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
166}
167
/// Type-erased base of every Tensor<T> specialization in this header: element type,
/// shape, element count, raw data access and raw allocation.
///
/// NOTE(review): the declaration order of these virtuals fixes the vtable layout; keep it stable.
class TensorBase : public Arg {
public:
  virtual ~TensorBase() = default;

  /// ONNX element type of this tensor.
  virtual ONNXTensorElementDataType Type() const = 0;
  /// Tensor dimensions.
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Number of elements; the implementations in this header return the product of the
  /// dimensions (1 for a rank-0 shape).
  virtual int64_t NumberOfElement() const = 0;
  /// Untyped pointer to the element data (string tensors only support this for one element).
  virtual const void* DataRaw() const = 0;
  /// Size of the data in bytes (string tensors: length of the single string only).
  virtual size_t SizeInBytes() const = 0;

  /// Allocates storage for `shape` and returns it as bytes; string tensors throw.
  virtual std::byte* AllocateRaw(const std::vector<int64_t>& shape) = 0;
};
180
181template <typename T>
182class Tensor : public TensorBase {
183 public:
184 using TT = typename std::remove_reference<T>::type;
185 Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
186 }
187
188 Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
189
190 Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
191
192 Tensor(const Tensor& src) = delete;
193
194 Tensor& operator=(Tensor src) = delete;
195
196 Tensor(Tensor&& other) : storage_(std::move(other.storage_)) {
197 other.storage_ = nullptr;
198 other.span_ = {};
199 }
200
201 Tensor& operator=(Tensor&& other)
202 {
203 storage_ = std::move(other.storage_);
204 other.span_ = {};
205 return *this;
206 }
207
208 operator bool() const {
209 return storage_ && storage_->IsInitialized();
210 }
211
212 ONNXTensorElementDataType Type() const override {
213 return GetOrtDType<T>();
214 }
215
216 const std::vector<int64_t>& Shape() const override {
217 if (!storage_)
218 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
219 return storage_->Shape();
220 }
221
222 int64_t NumberOfElement() const override {
223 if (!storage_)
224 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
225 auto& shape = storage_->Shape();
226 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
227 }
228
229 std::string Shape2Str() const {
230 if (!storage_)
231 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
232 if (storage_&& storage_->IsInitialized()) {
233 std::string shape_str;
234 auto& shape = storage_->Shape();
235 for (const auto& dim : shape) {
236 shape_str.append(std::to_string(dim));
237 shape_str.append(", ");
238 }
239 return shape_str;
240 } else {
241 return "empty";
242 }
243 }
244
245 void* Release() {
246 if (!storage_)
247 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
248 span_ = {};
249 return storage_->Release();
250 }
251
252 const TT* Data() const {
253 if (!storage_)
254 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
255#if ORT_API_VERSION >= 16
256 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
257 return reinterpret_cast<const TT*>(storage_->DataRaw());
258 else
259#endif
260 return static_cast<const TT*>(storage_->DataRaw());
261 }
262
263 const void* DataRaw() const override {
264 if (!storage_)
265 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
266 return storage_->DataRaw();
267 }
268
269 size_t SizeInBytes() const override {
270 if (!storage_)
271 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
272 return NumberOfElement() * sizeof(TT);
273 }
274
275 TT* Allocate(const std::vector<int64_t>& shape) {
276 if (!storage_)
277 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
278 // it should be OK to allocate multiple times
279 void* buffer = storage_->Initialize(shape, sizeof(TT));
280#if ORT_API_VERSION >= 16
281 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
282 return reinterpret_cast<TT*>(buffer);
283 else
284#endif
285 return static_cast<TT*>(buffer);
286 }
287
288 std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
289 return reinterpret_cast<std::byte*>(Allocate(shape));
290 }
291
292 const Span<T>& AsSpan() {
293 if (!storage_)
294 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
295#if ORT_API_VERSION >= 16
296 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
297 ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
298 }
299 else{
300#endif
301 auto& shape = storage_->Shape();
302 if (shape.size() != 1) {
303 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
304 }
305 span_.Assign(Data(), shape[0]);
306 return span_;
307#if ORT_API_VERSION >= 16
308 }
309#endif
310 }
311
312 const T& AsScalar() {
313 if (!storage_)
314 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
315#if ORT_API_VERSION >= 16
316 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
317 ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
318 }
319 else{
320#endif
321 auto& shape = storage_->Shape();
322 if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
323 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
324 }
325 return *Data();
326#if ORT_API_VERSION >= 16
327 }
328#endif
329 }
330
331 private:
332 std::unique_ptr<ITensorStorage> storage_;
333 Span<T> span_;
334};
335
/// Abstract backing store for string tensors. T is the element type; this header
/// instantiates it with std::string and std::string_view.
///
/// NOTE(review): the declaration order of these virtuals fixes the vtable layout; keep it stable.
template<typename T>
class IStringTensorStorage{
public:
  using strings = std::vector<T>;
  /// Tensor dimensions.
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Pointer to the character data (EagerStringTensorStorage only supports a single string).
  virtual const void* DataRaw() const = 0;
  /// The stored string elements.
  virtual const strings& Data() const = 0;
  /// Whether a shape has been set.
  virtual bool IsInitialized() const = 0;
  /// Sets the elements to `ss` and the shape to `dims`.
  virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;
  /// Sets the elements from C strings and the shape to `dims`.
  virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;
  virtual ~IStringTensorStorage() = default;
};
348
349template<typename T>
350class EagerStringTensorStorage : public IStringTensorStorage<T>{
351public:
352 using strings = std::vector<T>;
353 EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}){}
354
355 EagerStringTensorStorage() {}
356
357 const std::vector<int64_t>& Shape() const override {
358 if (!IsInitialized())
359 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
360 return *shape_;
361 }
362
363 const void* DataRaw() const override {
364 if (input_strings_.size() != 1) {
365 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
366 }
367 if constexpr (std::is_same<std::string_view, T>::value)
368 return reinterpret_cast<const void*>(input_strings_[0].data());
369 else
370 return reinterpret_cast<const void*>(input_strings_[0].c_str());
371 }
372
373 bool IsInitialized() const override {
374 return shape_.has_value();
375 }
376
377 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
378 if constexpr (std::is_same<std::string_view, T>::value)
379 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
380 input_strings_.assign(ss.begin(), ss.end());
381 shape_ = dims;
382 }
383
384 const strings& Data() const override {
385 return input_strings_;
386 }
387
388 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
389 if constexpr (std::is_same<std::string_view, T>::value)
390 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
391
392 for (const char* s : ss){
393 input_strings_.push_back(s);
394 }
395 shape_ = dims;
396 }
397
398private:
399 std::vector<T> input_strings_;
400 std::optional<std::vector<int64_t>> shape_;
401};
402
403template <>
404class Tensor<std::string> : public TensorBase {
405 public:
406 using strings = std::vector<std::string>;
407
408 Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
409
410 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
411
412 Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
413
414 ONNXTensorElementDataType Type() const override {
415 return GetOrtDType<std::string>();
416 }
417
418 const strings& Data() const {
419 return storage_->Data();
420 }
421
422 const std::vector<int64_t>& Shape() const override {
423 return storage_->Shape();
424 }
425
426 int64_t NumberOfElement() const override {
427 auto& shape = storage_->Shape();
428 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
429 }
430
431 std::string Shape2Str() const {
432 if (storage_->IsInitialized()) {
433 std::string shape_str;
434 auto& shape = storage_->Shape();
435 for (const auto& dim : shape) {
436 shape_str.append(std::to_string(dim));
437 shape_str.append(", ");
438 }
439 return shape_str;
440 } else {
441 return "empty";
442 }
443 }
444
445 const void* DataRaw() const override {
446 return storage_->DataRaw();
447 }
448
449 size_t SizeInBytes() const override {
450 auto& ss = storage_->Data();
451 if (ss.size() != 1) {
452 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
453 }
454 return ss[0].size();
455 }
456
457 std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
458 ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
459 }
460
461 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
462 storage_->SetStringOutput(ss, dims);
463 }
464 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
465 storage_->SetStringOutput(ss, dims);
466 }
467 const Span<std::string>& AsSpan() const {
468 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
469 }
470 const std::string& AsScalar() const {
471 auto& ss = storage_->Data();
472 if (ss.size() != 1) {
473 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
474 }
475 return ss[0];
476 }
477
478 private:
479 std::unique_ptr<IStringTensorStorage<std::string>> storage_;
480};
481
482
483template <>
484class Tensor<std::string_view> : public TensorBase {
485 public:
486 using strings = std::vector<std::string_view>;
487
488 Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
489
490 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
491
492 ONNXTensorElementDataType Type() const override {
493 return GetOrtDType<std::string_view>();
494 }
495
496 const strings& Data() const {
497 return storage_->Data();
498 }
499
500 const std::vector<int64_t>& Shape() const override {
501 return storage_->Shape();
502 }
503
504 int64_t NumberOfElement() const override {
505 auto& shape = storage_->Shape();
506 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
507 }
508
509 std::string Shape2Str() const {
510 if (storage_->IsInitialized()) {
511 std::string shape_str;
512 auto& shape = storage_->Shape();
513 for (const auto& dim : shape) {
514 shape_str.append(std::to_string(dim));
515 shape_str.append(", ");
516 }
517 return shape_str;
518 } else {
519 return "empty";
520 }
521 }
522
523 const void* DataRaw() const override {
524 return storage_->DataRaw();
525 }
526
527 size_t SizeInBytes() const override {
528 auto& ss = storage_->Data();
529 if (ss.size() != 1) {
530 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
531 }
532 return ss[0].size();
533 }
534
535 std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
536 ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
537 }
538
539 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
540 storage_->SetStringOutput(ss, dims);
541 }
542 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
543 storage_->SetStringOutput(ss, dims);
544 }
545 const Span<std::string_view>& AsSpan() const {
546 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
547 }
548 const std::string_view& AsScalar() const {
549 auto& ss = storage_->Data();
550 if (ss.size() != 1) {
551 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
552 }
553 return ss[0];
554 }
555
556 private:
557 std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
558};
559
/// Fixed set of named, heterogeneously-typed values.
///
/// keys[i] names the i-th tuple element. A lookup compares `name` against each key in
/// tuple order, with one instantiation per tuple index.
template<typename ...Args>
class NamedArgumentDict{
public:
  using ValueTuple = std::tuple<Args...>;

  /// @param keys  Names of the values, in tuple order. Keys beyond sizeof...(Args) are never
  ///              consulted; values without a key cannot be found by name.
  /// @param args  The values.
  NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args)
      : names_(keys.begin(), keys.end()), entries_(args) {
  }

  /// Returns the value stored under `name`, or `default_value` if no key matches
  /// (or `name` is null).
  /// @throws std::runtime_error if `name` matches a key whose value type is not exactly T.
  template<typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
    // Comparing std::string with a null const char* would be undefined behavior.
    if (name == nullptr)
      return default_value;
    return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
  }

private:
  // Recursion end: every tuple index was checked without a match.
  template<size_t I, typename T>
  typename std::enable_if<I == sizeof...(Args), T>::type
  TryToGetAttributeWithDefaultInternal(const char* /*name*/, const T& default_value) const {
    return default_value;
  }

  template<size_t I, typename T>
  typename std::enable_if<I < sizeof...(Args), T>::type
  TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
    // Fewer keys than values is possible; indexing past names_ would be undefined behavior.
    if (I < names_.size() && names_[I] == name){
      if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
        return std::get<I>(entries_);
      else
        throw std::runtime_error("name matched but type is not");
    }
    return TryToGetAttributeWithDefaultInternal<I+1>(name, default_value);
  }

  std::vector<std::string> names_;
  std::tuple<Args...> entries_;

};
599
600}
601}
602