microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
wenbingl-patch-2

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/tensor_api.h

588lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include <optional>
6#include <numeric>
7#include <type_traits>
8#include <assert.h>
9
10#include "onnxruntime_f16.h"
11#include "kernel_context.h"
12
13namespace Ort {
14namespace Custom {
15
/// Non-owning, read-only view over `size_` consecutive elements of type T.
/// The view only stores a pointer and a count; it never allocates or frees.
template <typename T>
struct Span {
  const T* data_{nullptr};  // first element of the viewed range (null when empty)
  size_t size_{0};          // number of elements reachable through data_

  /// Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  /// Pointer to the first viewed element.
  const T* data() const { return data_; }

  /// Number of viewed elements.
  size_t size() const { return size_; }

  /// Returns a copy of the element at position `indice`; no bounds check is performed.
  T operator[](size_t indice) const {
    return *(data_ + indice);
  }
};
30
31
32#if ORT_API_VERSION >= 16
33
34template <>
35struct Span<MFloat16> {
36 const MFloat16* data_ = {};
37 size_t size_ = {};
38 void Assign(const MFloat16* data, size_t size) {
39 data_ = data;
40 size_ = size;
41 }
42 size_t size() const { return size_; }
43 MFloat16 operator[](size_t indice) const {
44 return data_[indice];
45 }
46 const MFloat16* data() const { return data_; }
47};
48
49template <>
50struct Span<BFloat16> {
51 const BFloat16* data_ = {};
52 size_t size_ = {};
53 void Assign(const BFloat16* data, size_t size) {
54 data_ = data;
55 size_ = size;
56 }
57 size_t size() const { return size_; }
58 BFloat16 operator[](size_t indice) const {
59 return data_[indice];
60 }
61 const BFloat16* data() const { return data_; }
62};
63
64#endif
65
/// Abstract backing store for a tensor: a shape plus an untyped data pointer.
/// Concrete implementations decide where the bytes live and who owns them.
class ITensorStorage {
 public:
  /// Dimensions of the stored tensor.
  virtual const std::vector<int64_t>& Shape() const = 0;

  /// Untyped pointer to the tensor data.
  virtual const void* DataRaw() const = 0;

  /// Whether the storage currently holds a tensor.
  virtual bool IsInitialized() const = 0;

  /// Prepares the storage for a tensor of `shape` whose elements are
  /// `element_size` bytes each, and returns the data pointer.
  virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;

  /// Detaches the data pointer from the storage and returns it.
  virtual void* Release() = 0;

  virtual ~ITensorStorage() = default;
};
75
76
/// Minimal allocator interface used by tensor storage to obtain and return raw buffers.
class IAllocator {
 public:
  /// Returns a buffer of at least `size` bytes.
  virtual void* Alloc(size_t size) = 0;

  /// Returns a buffer previously obtained from Alloc() to the allocator.
  virtual void Free(void* p) = 0;

  // Polymorphic base: the destructor must be virtual so that destroying a
  // concrete allocator through an IAllocator* runs the derived destructor
  // instead of invoking undefined behaviour.
  virtual ~IAllocator() = default;
};
82
83
84class OrtEagerTensorStorage : public ITensorStorage {
85public:
86 OrtEagerTensorStorage(const std::vector<int64_t>& shape,
87 void* buffer) : buffer_(buffer), shape_(shape){
88
89 }
90
91 OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
92 }
93
94 ~OrtEagerTensorStorage() override{
95 if (allocator_ && buffer_)
96 allocator_->Free(buffer_);
97 }
98
99 const std::vector<int64_t>& Shape() const override {
100 if (!IsInitialized())
101 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
102 return *shape_;
103 }
104
105 bool IsInitialized() const override {
106 return shape_.has_value();
107 }
108
109 const void* DataRaw() const override {
110 return buffer_;
111 }
112
113 void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
114 if (IsInitialized())
115 return buffer_;
116 assert(allocator_);
117 shape_ = shape;
118 int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
119 auto buffer_size = n_elem * element_size;
120 buffer_ = allocator_->Alloc(buffer_size);
121 return buffer_;
122 }
123
124 void* Release() override {
125 void* tmp = buffer_;
126 buffer_ = 0;
127 shape_ = std::nullopt;
128 return tmp;
129 }
130
131private:
132 void* buffer_ {};
133 std::optional<std::vector<int64_t>> shape_;
134 // caller need to make sure the allocator is alive
135 IAllocator* allocator_{};
136};
137
138template <typename TT>
139ONNXTensorElementDataType GetOrtDType(){
140 if constexpr (std::is_same<TT, bool>::value)
141 return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
142 else if constexpr (std::is_same<TT, float>::value)
143 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
144 else if constexpr (std::is_same<TT, double>::value)
145 return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
146 else if constexpr (std::is_same<TT, uint8_t>::value)
147 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
148 else if constexpr (std::is_same<TT, int8_t>::value)
149 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
150 else if constexpr (std::is_same<TT, uint16_t>::value)
151 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
152 else if constexpr (std::is_same<TT, int16_t>::value)
153 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
154 else if constexpr (std::is_same<TT, uint32_t>::value)
155 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
156 else if constexpr (std::is_same<TT, int32_t>::value)
157 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
158 else if constexpr (std::is_same<TT, uint64_t>::value)
159 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
160 else if constexpr (std::is_same<TT, int64_t>::value)
161 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
162 else if constexpr (std::is_same<TT, std::string>::value)
163 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
164 ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
165 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
166}
167
168class TensorBase : public Arg {
169public:
170 virtual ~TensorBase() = default;
171
172 virtual ONNXTensorElementDataType Type() const = 0;
173 virtual const std::vector<int64_t>& Shape() const = 0;
174 virtual int64_t NumberOfElement() const = 0;
175 virtual const void* DataRaw() const = 0;
176 virtual size_t SizeInBytes() const = 0;
177};
178
179template <typename T>
180class Tensor : public TensorBase {
181 public:
182 using TT = typename std::remove_reference<T>::type;
183 Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
184 }
185
186 Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
187
188 Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
189
190 Tensor(const Tensor& src) = delete;
191
192 Tensor& operator=(Tensor src) = delete;
193
194 Tensor(Tensor&& other) : storage_(std::move(other.storage_)) {
195 other.storage_ = nullptr;
196 other.span_ = {};
197 }
198
199 Tensor& operator=(Tensor&& other)
200 {
201 storage_ = std::move(other.storage_);
202 other.span_ = {};
203 return *this;
204 }
205
206 operator bool() const {
207 return storage_ && storage_->IsInitialized();
208 }
209
210 ONNXTensorElementDataType Type() const override {
211 return GetOrtDType<T>();
212 }
213
214 const std::vector<int64_t>& Shape() const override {
215 if (!storage_)
216 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
217 return storage_->Shape();
218 }
219
220 int64_t NumberOfElement() const override {
221 if (!storage_)
222 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
223 auto& shape = storage_->Shape();
224 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
225 }
226
227 std::string Shape2Str() const {
228 if (!storage_)
229 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
230 if (storage_&& storage_->IsInitialized()) {
231 std::string shape_str;
232 auto& shape = storage_->Shape();
233 for (const auto& dim : shape) {
234 shape_str.append(std::to_string(dim));
235 shape_str.append(", ");
236 }
237 return shape_str;
238 } else {
239 return "empty";
240 }
241 }
242
243 void* Release() {
244 if (!storage_)
245 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
246 span_ = {};
247 return storage_->Release();
248 }
249
250 const TT* Data() const {
251 if (!storage_)
252 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
253#if ORT_API_VERSION >= 16
254 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
255 return reinterpret_cast<const TT*>(storage_->DataRaw());
256 else
257#endif
258 return static_cast<const TT*>(storage_->DataRaw());
259 }
260
261 const void* DataRaw() const override {
262 if (!storage_)
263 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
264 return storage_->DataRaw();
265 }
266
267 size_t SizeInBytes() const override {
268 if (!storage_)
269 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
270 return NumberOfElement() * sizeof(TT);
271 }
272
273 TT* Allocate(const std::vector<int64_t>& shape) {
274 if (!storage_)
275 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
276 // it should be OK to allocate multiple times
277 void* buffer = storage_->Initialize(shape, sizeof(TT));
278#if ORT_API_VERSION >= 16
279 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
280 return reinterpret_cast<TT*>(buffer);
281 else
282#endif
283 return static_cast<TT*>(buffer);
284 }
285
286 const Span<T>& AsSpan() {
287 if (!storage_)
288 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
289#if ORT_API_VERSION >= 16
290 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
291 ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
292 }
293 else{
294#endif
295 auto& shape = storage_->Shape();
296 if (shape.size() != 1) {
297 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
298 }
299 span_.Assign(Data(), shape[0]);
300 return span_;
301#if ORT_API_VERSION >= 16
302 }
303#endif
304 }
305
306 const T& AsScalar() {
307 if (!storage_)
308 ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
309#if ORT_API_VERSION >= 16
310 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
311 ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
312 }
313 else{
314#endif
315 auto& shape = storage_->Shape();
316 if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
317 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
318 }
319 return *Data();
320#if ORT_API_VERSION >= 16
321 }
322#endif
323 }
324
325 private:
326 std::unique_ptr<ITensorStorage> storage_;
327 Span<T> span_;
328};
329
/// Abstract backing store for a tensor of string-like elements of type T.
template <typename T>
class IStringTensorStorage {
 public:
  using strings = std::vector<T>;

  /// Dimensions of the stored tensor.
  virtual const std::vector<int64_t>& Shape() const = 0;

  /// Untyped pointer to the stored character data.
  virtual const void* DataRaw() const = 0;

  /// The stored elements.
  virtual const strings& Data() const = 0;

  /// Whether the storage currently holds a tensor.
  virtual bool IsInitialized() const = 0;

  /// Stores the elements `ss` with shape `dims`.
  virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;

  /// Stores the C strings `ss` with shape `dims`.
  virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;

  virtual ~IStringTensorStorage() = default;
};
342
343template<typename T>
344class EagerStringTensorStorage : public IStringTensorStorage<T>{
345public:
346 using strings = std::vector<T>;
347 EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}){}
348
349 EagerStringTensorStorage() {}
350
351 const std::vector<int64_t>& Shape() const override {
352 if (!IsInitialized())
353 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
354 return *shape_;
355 }
356
357 const void* DataRaw() const override {
358 if (input_strings_.size() != 1) {
359 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
360 }
361 if constexpr (std::is_same<std::string_view, T>::value)
362 return reinterpret_cast<const void*>(input_strings_[0].data());
363 else
364 return reinterpret_cast<const void*>(input_strings_[0].c_str());
365 }
366
367 bool IsInitialized() const override {
368 return shape_.has_value();
369 }
370
371 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
372 if constexpr (std::is_same<std::string_view, T>::value)
373 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
374 input_strings_.assign(ss.begin(), ss.end());
375 shape_ = dims;
376 }
377
378 const strings& Data() const override {
379 return input_strings_;
380 }
381
382 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
383 if constexpr (std::is_same<std::string_view, T>::value)
384 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
385
386 for (const char* s : ss){
387 input_strings_.push_back(s);
388 }
389 shape_ = dims;
390 }
391
392private:
393 std::vector<T> input_strings_;
394 std::optional<std::vector<int64_t>> shape_;
395};
396
397template <>
398class Tensor<std::string> : public TensorBase {
399 public:
400 using strings = std::vector<std::string>;
401
402 Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
403
404 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
405
406 Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
407
408 ONNXTensorElementDataType Type() const override {
409 return GetOrtDType<std::string>();
410 }
411
412 const strings& Data() const {
413 return storage_->Data();
414 }
415
416 const std::vector<int64_t>& Shape() const override {
417 return storage_->Shape();
418 }
419
420 int64_t NumberOfElement() const override {
421 auto& shape = storage_->Shape();
422 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
423 }
424
425 std::string Shape2Str() const {
426 if (storage_->IsInitialized()) {
427 std::string shape_str;
428 auto& shape = storage_->Shape();
429 for (const auto& dim : shape) {
430 shape_str.append(std::to_string(dim));
431 shape_str.append(", ");
432 }
433 return shape_str;
434 } else {
435 return "empty";
436 }
437 }
438
439 const void* DataRaw() const override {
440 return storage_->DataRaw();
441 }
442
443 size_t SizeInBytes() const override {
444 auto& ss = storage_->Data();
445 if (ss.size() != 1) {
446 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
447 }
448 return ss[0].size();
449 }
450
451 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
452 storage_->SetStringOutput(ss, dims);
453 }
454 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
455 storage_->SetStringOutput(ss, dims);
456 }
457 const Span<std::string>& AsSpan() {
458 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
459 }
460 const std::string& AsScalar() {
461 auto& ss = storage_->Data();
462 if (ss.size() != 1) {
463 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
464 }
465 return ss[0];
466 }
467
468 private:
469 std::unique_ptr<IStringTensorStorage<std::string>> storage_;
470};
471
472
473template <>
474class Tensor<std::string_view> : public TensorBase {
475 public:
476 using strings = std::vector<std::string_view>;
477
478 Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
479
480 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
481
482 ONNXTensorElementDataType Type() const override {
483 return GetOrtDType<std::string_view>();
484 }
485
486 const strings& Data() const {
487 return storage_->Data();
488 }
489
490 const std::vector<int64_t>& Shape() const override {
491 return storage_->Shape();
492 }
493
494 int64_t NumberOfElement() const override {
495 auto& shape = storage_->Shape();
496 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
497 }
498
499 std::string Shape2Str() const {
500 if (storage_->IsInitialized()) {
501 std::string shape_str;
502 auto& shape = storage_->Shape();
503 for (const auto& dim : shape) {
504 shape_str.append(std::to_string(dim));
505 shape_str.append(", ");
506 }
507 return shape_str;
508 } else {
509 return "empty";
510 }
511 }
512
513 const void* DataRaw() const override {
514 return storage_->DataRaw();
515 }
516
517 size_t SizeInBytes() const override {
518 auto& ss = storage_->Data();
519 if (ss.size() != 1) {
520 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
521 }
522 return ss[0].size();
523 }
524
525 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
526 storage_->SetStringOutput(ss, dims);
527 }
528 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
529 storage_->SetStringOutput(ss, dims);
530 }
531 const Span<std::string_view>& AsSpan() {
532 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
533 }
534 const std::string_view& AsScalar() {
535 auto& ss = storage_->Data();
536 if (ss.size() != 1) {
537 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
538 }
539 return ss[0];
540 }
541
542 private:
543 std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
544};
545
546
/// Fixed set of named values with heterogeneous types: the i-th key names the
/// i-th element of the value tuple.
template <typename... Args>
class NamedArgumentDict {
 public:
  using ValueTuple = std::tuple<Args...>;

  /// Copies `keys` (as std::string) and `args`. `keys[i]` is the name of the
  /// i-th tuple element; tuple elements without a corresponding key can never
  /// be looked up.
  NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args) : entries_(args) {
    names_.reserve(keys.size());
    for (const char* key : keys) {
      names_.emplace_back(key);
    }
  }

  /// Looks up `name` and returns the associated value.
  /// @return the stored value when a key equals `name` and its type is exactly T;
  ///         `default_value` when no key matches or `name` is null.
  /// @throws std::runtime_error when a key matches but the stored type is not T.
  template <typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
    // Comparing std::string with a null C string is undefined behaviour.
    if (name == nullptr) {
      return default_value;
    }
    return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
  }

 private:
  // End of the recursion: every tuple position was inspected without a match.
  template <size_t I, typename T>
  typename std::enable_if<(I == sizeof...(Args)), T>::type
  TryToGetAttributeWithDefaultInternal(const char* /*name*/, const T& default_value) const {
    return default_value;
  }

  // Checks tuple position I, then continues with position I + 1.
  template <size_t I, typename T>
  typename std::enable_if<(I < sizeof...(Args)), T>::type
  TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
    // names_ may be shorter than the tuple; the bound check keeps the lookup
    // from reading past the end of the vector.
    if (I < names_.size() && names_[I] == name) {
      if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
        return std::get<I>(entries_);
      else
        throw std::runtime_error("name matched but type is not");
    }
    return TryToGetAttributeWithDefaultInternal<I + 1>(name, default_value);
  }

  std::vector<std::string> names_;
  std::tuple<Args...> entries_;
};
586
587}
588}
589