microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
cmake44

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/tensor_api.h

601lines · modeblame

97ee9eb5Wenbing Li2 years ago1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
64646279Wenbing Li2 years ago4#pragma once
5#include <optional>
6#include <numeric>
7#include <type_traits>
311dd354Wenbing Li2 years ago8#include <assert.h>
9
64646279Wenbing Li2 years ago10#include "onnxruntime_f16.h"
11#include "kernel_context.h"
12
13namespace Ort {
14namespace Custom {
15
/// Lightweight, non-owning, read-only view over a contiguous run of `T`.
///
/// The span stores only a pointer and an element count; it never allocates or
/// frees, so the viewed buffer must stay alive for as long as the span is used.
template <typename T>
struct Span {
  const T* data_ = nullptr;  // first viewed element; null while the span is empty
  size_t size_ = 0;          // number of viewed elements

  /// Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  /// Number of elements currently viewed.
  size_t size() const { return size_; }

  /// Returns a copy of the element at `index`. No bounds check is performed.
  T operator[](size_t index) const { return *(data_ + index); }

  /// Pointer to the first viewed element.
  const T* data() const { return data_; }
};
30
31
32#if ORT_API_VERSION >= 16
33
34template <>
35struct Span<MFloat16> {
36const MFloat16* data_ = {};
37size_t size_ = {};
38void Assign(const MFloat16* data, size_t size) {
39data_ = data;
40size_ = size;
41}
42size_t size() const { return size_; }
43MFloat16 operator[](size_t indice) const {
44return data_[indice];
45}
46const MFloat16* data() const { return data_; }
47};
48
49template <>
50struct Span<BFloat16> {
51const BFloat16* data_ = {};
52size_t size_ = {};
53void Assign(const BFloat16* data, size_t size) {
54data_ = data;
55size_ = size;
56}
57size_t size() const { return size_; }
58BFloat16 operator[](size_t indice) const {
59return data_[indice];
60}
61const BFloat16* data() const { return data_; }
62};
63
64#endif
65
/// Abstract backing store for a dense tensor: a shape plus one raw buffer.
class ITensorStorage {
 public:
  /// Dimensions of the stored tensor.
  virtual const std::vector<int64_t>& Shape() const = 0;

  /// Untyped pointer to the first byte of the tensor data.
  virtual const void* DataRaw() const = 0;

  /// Whether the storage currently has a shape (and therefore a buffer) bound to it.
  virtual bool IsInitialized() const = 0;

  /// Binds `shape` to the storage and returns the buffer for its elements,
  /// each of which occupies `element_size` bytes.
  virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;

  /// Detaches and returns the buffer held by the storage.
  virtual void* Release() = 0;

  virtual ~ITensorStorage() = default;
};
75
76
/// Minimal allocator interface used by tensor storage to obtain and return raw buffers.
class IAllocator {
 public:
  /// Returns a buffer of at least `size` bytes.
  virtual void* Alloc(size_t size) = 0;

  /// Returns a buffer previously obtained from Alloc().
  virtual void Free(void* p) = 0;

  // Polymorphic base: a virtual destructor makes destroying an implementation
  // through an IAllocator* well-defined and runs the derived destructor.
  virtual ~IAllocator() = default;
};
82
83
84class OrtEagerTensorStorage : public ITensorStorage {
85public:
86OrtEagerTensorStorage(const std::vector<int64_t>& shape,
87void* buffer) : buffer_(buffer), shape_(shape){
88
89}
90
91OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
92}
93
beb9fbbaScott McKay2 years ago94~OrtEagerTensorStorage() override{
64646279Wenbing Li2 years ago95if (allocator_ && buffer_)
96allocator_->Free(buffer_);
97}
98
99const std::vector<int64_t>& Shape() const override {
100if (!IsInitialized())
101ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
102return *shape_;
103}
104
beb9fbbaScott McKay2 years ago105bool IsInitialized() const override {
64646279Wenbing Li2 years ago106return shape_.has_value();
107}
108
109const void* DataRaw() const override {
110return buffer_;
111}
112
113void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
114if (IsInitialized())
115return buffer_;
116assert(allocator_);
117shape_ = shape;
118int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
119auto buffer_size = n_elem * element_size;
120buffer_ = allocator_->Alloc(buffer_size);
121return buffer_;
122}
123
f0ef40d0Tang, Cheng2 years ago124void* Release() override {
125void* tmp = buffer_;
126buffer_ = 0;
127shape_ = std::nullopt;
128return tmp;
129}
130
64646279Wenbing Li2 years ago131private:
132void* buffer_ {};
133std::optional<std::vector<int64_t>> shape_;
134// caller need to make sure the allocator is alive
f9290e8bWenbing Li2 years ago135IAllocator* allocator_{};
64646279Wenbing Li2 years ago136};
137
138template <typename TT>
139ONNXTensorElementDataType GetOrtDType(){
140if constexpr (std::is_same<TT, bool>::value)
141return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
142else if constexpr (std::is_same<TT, float>::value)
143return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
144else if constexpr (std::is_same<TT, double>::value)
145return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
146else if constexpr (std::is_same<TT, uint8_t>::value)
147return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
148else if constexpr (std::is_same<TT, int8_t>::value)
149return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
150else if constexpr (std::is_same<TT, uint16_t>::value)
151return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
152else if constexpr (std::is_same<TT, int16_t>::value)
153return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
154else if constexpr (std::is_same<TT, uint32_t>::value)
155return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
156else if constexpr (std::is_same<TT, int32_t>::value)
157return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
158else if constexpr (std::is_same<TT, uint64_t>::value)
159return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
160else if constexpr (std::is_same<TT, int64_t>::value)
161return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
162else if constexpr (std::is_same<TT, std::string>::value)
163return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
164ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
165return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
166}
167
168class TensorBase : public Arg {
169public:
beb9fbbaScott McKay2 years ago170virtual ~TensorBase() = default;
64646279Wenbing Li2 years ago171
172virtual ONNXTensorElementDataType Type() const = 0;
173virtual const std::vector<int64_t>& Shape() const = 0;
174virtual int64_t NumberOfElement() const = 0;
175virtual const void* DataRaw() const = 0;
176virtual size_t SizeInBytes() const = 0;
cbed8fd5Wenbing Li2 years ago177
178virtual std::byte* AllocateRaw(const std::vector<int64_t>& shape) = 0;
64646279Wenbing Li2 years ago179};
180
181template <typename T>
182class Tensor : public TensorBase {
183public:
184using TT = typename std::remove_reference<T>::type;
185Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
186}
187
188Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
189
190Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
191
f0ef40d0Tang, Cheng2 years ago192Tensor(const Tensor& src) = delete;
193
194Tensor& operator=(Tensor src) = delete;
195
196Tensor(Tensor&& other) : storage_(std::move(other.storage_)) {
197other.storage_ = nullptr;
198other.span_ = {};
199}
200
201Tensor& operator=(Tensor&& other)
202{
203storage_ = std::move(other.storage_);
204other.span_ = {};
205return *this;
206}
207
64646279Wenbing Li2 years ago208operator bool() const {
f0ef40d0Tang, Cheng2 years ago209return storage_ && storage_->IsInitialized();
64646279Wenbing Li2 years ago210}
211
212ONNXTensorElementDataType Type() const override {
213return GetOrtDType<T>();
214}
215
216const std::vector<int64_t>& Shape() const override {
f0ef40d0Tang, Cheng2 years ago217if (!storage_)
218ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago219return storage_->Shape();
220}
221
222int64_t NumberOfElement() const override {
f0ef40d0Tang, Cheng2 years ago223if (!storage_)
224ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago225auto& shape = storage_->Shape();
226return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
227}
228
229std::string Shape2Str() const {
f0ef40d0Tang, Cheng2 years ago230if (!storage_)
231ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
232if (storage_&& storage_->IsInitialized()) {
64646279Wenbing Li2 years ago233std::string shape_str;
234auto& shape = storage_->Shape();
235for (const auto& dim : shape) {
236shape_str.append(std::to_string(dim));
237shape_str.append(", ");
238}
239return shape_str;
240} else {
241return "empty";
242}
243}
f0ef40d0Tang, Cheng2 years ago244
245void* Release() {
246if (!storage_)
247ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
248span_ = {};
249return storage_->Release();
250}
64646279Wenbing Li2 years ago251
252const TT* Data() const {
f0ef40d0Tang, Cheng2 years ago253if (!storage_)
254ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago255#if ORT_API_VERSION >= 16
256if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
257return reinterpret_cast<const TT*>(storage_->DataRaw());
258else
259#endif
260return static_cast<const TT*>(storage_->DataRaw());
261}
262
263const void* DataRaw() const override {
f0ef40d0Tang, Cheng2 years ago264if (!storage_)
265ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago266return storage_->DataRaw();
267}
268
269size_t SizeInBytes() const override {
f0ef40d0Tang, Cheng2 years ago270if (!storage_)
271ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago272return NumberOfElement() * sizeof(TT);
273}
274
275TT* Allocate(const std::vector<int64_t>& shape) {
f0ef40d0Tang, Cheng2 years ago276if (!storage_)
277ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago278// it should be OK to allocate multiple times
279void* buffer = storage_->Initialize(shape, sizeof(TT));
280#if ORT_API_VERSION >= 16
281if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
282return reinterpret_cast<TT*>(buffer);
283else
284#endif
285return static_cast<TT*>(buffer);
286}
287
cbed8fd5Wenbing Li2 years ago288std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
289return reinterpret_cast<std::byte*>(Allocate(shape));
290}
291
64646279Wenbing Li2 years ago292const Span<T>& AsSpan() {
f0ef40d0Tang, Cheng2 years ago293if (!storage_)
294ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago295#if ORT_API_VERSION >= 16
296if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
297ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
298}
299else{
300#endif
301auto& shape = storage_->Shape();
302if (shape.size() != 1) {
303ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
304}
305span_.Assign(Data(), shape[0]);
306return span_;
307#if ORT_API_VERSION >= 16
308}
309#endif
310}
311
312const T& AsScalar() {
f0ef40d0Tang, Cheng2 years ago313if (!storage_)
314ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago315#if ORT_API_VERSION >= 16
316if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
317ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
318}
319else{
320#endif
321auto& shape = storage_->Shape();
322if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
323ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
324}
325return *Data();
326#if ORT_API_VERSION >= 16
327}
328#endif
329}
330
331private:
332std::unique_ptr<ITensorStorage> storage_;
333Span<T> span_;
334};
335
/// Abstract backing store for a string tensor whose elements are of type `T`.
template <typename T>
class IStringTensorStorage {
 public:
  using strings = std::vector<T>;

  /// Dimensions of the stored tensor.
  virtual const std::vector<int64_t>& Shape() const = 0;

  /// Untyped pointer to the stored string data.
  virtual const void* DataRaw() const = 0;

  /// All stored strings.
  virtual const strings& Data() const = 0;

  /// Whether a shape has been bound to the storage.
  virtual bool IsInitialized() const = 0;

  /// Stores the strings `ss` as the tensor content with shape `dims`.
  virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;

  /// Stores the C strings `ss` as the tensor content with shape `dims`.
  virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;

  virtual ~IStringTensorStorage() = default;
};
348
349template<typename T>
350class EagerStringTensorStorage : public IStringTensorStorage<T>{
351public:
352using strings = std::vector<T>;
353EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}){}
354
355EagerStringTensorStorage() {}
356
357const std::vector<int64_t>& Shape() const override {
358if (!IsInitialized())
359ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
360return *shape_;
361}
362
beb9fbbaScott McKay2 years ago363const void* DataRaw() const override {
64646279Wenbing Li2 years ago364if (input_strings_.size() != 1) {
365ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
366}
367if constexpr (std::is_same<std::string_view, T>::value)
368return reinterpret_cast<const void*>(input_strings_[0].data());
369else
370return reinterpret_cast<const void*>(input_strings_[0].c_str());
371}
372
beb9fbbaScott McKay2 years ago373bool IsInitialized() const override {
64646279Wenbing Li2 years ago374return shape_.has_value();
375}
376
beb9fbbaScott McKay2 years ago377void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
64646279Wenbing Li2 years ago378if constexpr (std::is_same<std::string_view, T>::value)
379ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
380input_strings_.assign(ss.begin(), ss.end());
381shape_ = dims;
382}
383
384const strings& Data() const override {
385return input_strings_;
386}
387
beb9fbbaScott McKay2 years ago388void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
64646279Wenbing Li2 years ago389if constexpr (std::is_same<std::string_view, T>::value)
390ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
391
392for (const char* s : ss){
393input_strings_.push_back(s);
394}
395shape_ = dims;
396}
397
398private:
399std::vector<T> input_strings_;
400std::optional<std::vector<int64_t>> shape_;
401};
402
403template <>
404class Tensor<std::string> : public TensorBase {
405public:
406using strings = std::vector<std::string>;
407
408Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
409
410Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
411
412Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
413
414ONNXTensorElementDataType Type() const override {
415return GetOrtDType<std::string>();
416}
417
418const strings& Data() const {
419return storage_->Data();
420}
421
422const std::vector<int64_t>& Shape() const override {
423return storage_->Shape();
424}
425
426int64_t NumberOfElement() const override {
427auto& shape = storage_->Shape();
428return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
429}
430
431std::string Shape2Str() const {
432if (storage_->IsInitialized()) {
433std::string shape_str;
434auto& shape = storage_->Shape();
435for (const auto& dim : shape) {
436shape_str.append(std::to_string(dim));
437shape_str.append(", ");
438}
439return shape_str;
440} else {
441return "empty";
442}
443}
444
445const void* DataRaw() const override {
446return storage_->DataRaw();
447}
448
449size_t SizeInBytes() const override {
450auto& ss = storage_->Data();
451if (ss.size() != 1) {
452ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
453}
454return ss[0].size();
455}
456
cbed8fd5Wenbing Li2 years ago457std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
458ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
459}
460
64646279Wenbing Li2 years ago461void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
462storage_->SetStringOutput(ss, dims);
463}
464void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
465storage_->SetStringOutput(ss, dims);
466}
176c1d01Wenbing Li1 years ago467const Span<std::string>& AsSpan() const {
64646279Wenbing Li2 years ago468ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
469}
176c1d01Wenbing Li1 years ago470const std::string& AsScalar() const {
64646279Wenbing Li2 years ago471auto& ss = storage_->Data();
472if (ss.size() != 1) {
473ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
474}
475return ss[0];
476}
477
478private:
479std::unique_ptr<IStringTensorStorage<std::string>> storage_;
480};
481
482
483template <>
484class Tensor<std::string_view> : public TensorBase {
485public:
486using strings = std::vector<std::string_view>;
487
488Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
489
490Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
491
492ONNXTensorElementDataType Type() const override {
493return GetOrtDType<std::string_view>();
494}
495
496const strings& Data() const {
497return storage_->Data();
498}
499
500const std::vector<int64_t>& Shape() const override {
501return storage_->Shape();
502}
503
504int64_t NumberOfElement() const override {
505auto& shape = storage_->Shape();
506return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
507}
508
509std::string Shape2Str() const {
510if (storage_->IsInitialized()) {
511std::string shape_str;
512auto& shape = storage_->Shape();
513for (const auto& dim : shape) {
514shape_str.append(std::to_string(dim));
515shape_str.append(", ");
516}
517return shape_str;
518} else {
519return "empty";
520}
521}
522
523const void* DataRaw() const override {
524return storage_->DataRaw();
525}
526
527size_t SizeInBytes() const override {
528auto& ss = storage_->Data();
529if (ss.size() != 1) {
530ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
531}
532return ss[0].size();
533}
534
cbed8fd5Wenbing Li2 years ago535std::byte* AllocateRaw(const std::vector<int64_t>& shape) override {
536ORTX_CXX_API_THROW("AllocateRaw() not supported for string tensor", ORT_RUNTIME_EXCEPTION);
537}
538
64646279Wenbing Li2 years ago539void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
540storage_->SetStringOutput(ss, dims);
541}
542void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
543storage_->SetStringOutput(ss, dims);
544}
176c1d01Wenbing Li1 years ago545const Span<std::string_view>& AsSpan() const {
64646279Wenbing Li2 years ago546ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
547}
176c1d01Wenbing Li1 years ago548const std::string_view& AsScalar() const {
64646279Wenbing Li2 years ago549auto& ss = storage_->Data();
550if (ss.size() != 1) {
551ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
552}
553return ss[0];
554}
555
556private:
557std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
558};
559
/// Fixed set of named, heterogeneously typed values (for example operator
/// attributes): the i-th key names the i-th element of the value tuple.
template <typename... Args>
class NamedArgumentDict {
 public:
  using ValueTuple = std::tuple<Args...>;

  /// Copies the key names and the values. `keys[i]` names `std::get<i>(args)`;
  /// values without a corresponding key can never be looked up.
  NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args) : entries_(args) {
    names_.reserve(keys.size());
    for (const char* key : keys) {
      names_.push_back(key);
    }
  }

  /// Returns the value stored under `name`, or `default_value` when no entry
  /// has that name (or `name` is null).
  ///
  /// @throws std::runtime_error when an entry named `name` exists but its
  ///         stored type is not exactly `T`.
  template <typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
    // Comparing a std::string with a null pointer is undefined behavior.
    if (name == nullptr) {
      return default_value;
    }
    return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
  }

 private:
  // Checks entry I and recurses to I + 1; past the last entry the default is returned.
  template <size_t I, typename T>
  T TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
    if constexpr (I == sizeof...(Args)) {
      return default_value;
    } else {
      // Fewer keys than values may have been supplied, so bound the index first.
      if (I < names_.size() && names_[I] == name) {
        if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
          return std::get<I>(entries_);
        else
          throw std::runtime_error("name matched but type is not");
      }
      return TryToGetAttributeWithDefaultInternal<I + 1>(name, default_value);
    }
  }

  std::vector<std::string> names_;
  std::tuple<Args...> entries_;
};
599
600}
601}