microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
wenbingl-patch-2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/tensor_api.h

588 lines · mode: blame

97ee9eb5Wenbing Li2 years ago1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
64646279Wenbing Li2 years ago4#pragma once
5#include <optional>
6#include <numeric>
7#include <type_traits>
311dd354Wenbing Li2 years ago8#include <assert.h>
9
64646279Wenbing Li2 years ago10#include "onnxruntime_f16.h"
11#include "kernel_context.h"
12
13namespace Ort {
14namespace Custom {
15
/// Non-owning, read-only view over a contiguous run of `T` values.
/// The span only stores a pointer and a length; it never allocates or frees,
/// so the pointed-to buffer must outlive every use of the span.
template <typename T>
struct Span {
  const T* data_ = nullptr;  // first element (nullptr until Assign() is called)
  size_t size_ = 0;          // number of elements reachable from data_

  /// Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  /// Number of elements in the view.
  size_t size() const { return size_; }

  /// Returns a copy of element `index`; no bounds checking is performed.
  T operator[](size_t index) const { return *(data_ + index); }

  /// Raw pointer to the first element.
  const T* data() const { return data_; }
};
30
31
#if ORT_API_VERSION >= 16

// Explicit specializations of Span for the 16-bit floating-point element
// types: a non-owning (pointer, length) view with no bounds checking.

template <>
struct Span<MFloat16> {
  const MFloat16* data_ = nullptr;  // first element
  size_t size_ = 0;                 // element count

  void Assign(const MFloat16* data, size_t size) {
    size_ = size;
    data_ = data;
  }
  size_t size() const { return size_; }
  MFloat16 operator[](size_t index) const { return *(data_ + index); }
  const MFloat16* data() const { return data_; }
};

template <>
struct Span<BFloat16> {
  const BFloat16* data_ = nullptr;  // first element
  size_t size_ = 0;                 // element count

  void Assign(const BFloat16* data, size_t size) {
    size_ = size;
    data_ = data;
  }
  size_t size() const { return size_; }
  BFloat16 operator[](size_t index) const { return *(data_ + index); }
  const BFloat16* data() const { return data_; }
};

#endif
65
/// Abstract backing store for a numeric tensor: a raw, untyped buffer plus
/// the tensor's shape.
/// NOTE: the virtual members are kept in their original declaration order;
/// that order defines the vtable layout seen by separately compiled code.
class ITensorStorage {
 public:
  /// Dimensions of the stored tensor.
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Untyped pointer to the start of the buffer.
  virtual const void* DataRaw() const = 0;
  /// Whether the storage currently holds a shape.
  virtual bool IsInitialized() const = 0;
  /// Sets up storage for `shape` with elements of `elem_bytes` bytes each and
  /// returns a writable pointer to the buffer.
  virtual void* Initialize(const std::vector<int64_t>& shape, size_t elem_bytes) = 0;
  /// Gives up the buffer and returns it to the caller (ownership details are
  /// defined by the implementation).
  virtual void* Release() = 0;
  virtual ~ITensorStorage() = default;
};
75
76
/// Minimal allocation interface through which tensor storages obtain and
/// return raw memory blocks.
class IAllocator {
 public:
  /// Allocates `size` bytes and returns the block.
  virtual void* Alloc(size_t size) = 0;
  /// Returns a block previously obtained from Alloc().
  virtual void Free(void* p) = 0;
  // This is a polymorphic interface, so deleting an implementation through an
  // IAllocator* must be well-defined; that requires a virtual destructor.
  // It is declared last so Alloc/Free keep their existing vtable slots.
  virtual ~IAllocator() = default;
};
82
83
84class OrtEagerTensorStorage : public ITensorStorage {
85public:
86OrtEagerTensorStorage(const std::vector<int64_t>& shape,
87void* buffer) : buffer_(buffer), shape_(shape){
88
89}
90
91OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
92}
93
beb9fbbaScott McKay2 years ago94~OrtEagerTensorStorage() override{
64646279Wenbing Li2 years ago95if (allocator_ && buffer_)
96allocator_->Free(buffer_);
97}
98
99const std::vector<int64_t>& Shape() const override {
100if (!IsInitialized())
101ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
102return *shape_;
103}
104
beb9fbbaScott McKay2 years ago105bool IsInitialized() const override {
64646279Wenbing Li2 years ago106return shape_.has_value();
107}
108
109const void* DataRaw() const override {
110return buffer_;
111}
112
113void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
114if (IsInitialized())
115return buffer_;
116assert(allocator_);
117shape_ = shape;
118int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
119auto buffer_size = n_elem * element_size;
120buffer_ = allocator_->Alloc(buffer_size);
121return buffer_;
122}
123
f0ef40d0Tang, Cheng2 years ago124void* Release() override {
125void* tmp = buffer_;
126buffer_ = 0;
127shape_ = std::nullopt;
128return tmp;
129}
130
64646279Wenbing Li2 years ago131private:
132void* buffer_ {};
133std::optional<std::vector<int64_t>> shape_;
134// caller need to make sure the allocator is alive
f9290e8bWenbing Li2 years ago135IAllocator* allocator_{};
64646279Wenbing Li2 years ago136};
137
138template <typename TT>
139ONNXTensorElementDataType GetOrtDType(){
140if constexpr (std::is_same<TT, bool>::value)
141return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
142else if constexpr (std::is_same<TT, float>::value)
143return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
144else if constexpr (std::is_same<TT, double>::value)
145return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
146else if constexpr (std::is_same<TT, uint8_t>::value)
147return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
148else if constexpr (std::is_same<TT, int8_t>::value)
149return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
150else if constexpr (std::is_same<TT, uint16_t>::value)
151return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
152else if constexpr (std::is_same<TT, int16_t>::value)
153return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
154else if constexpr (std::is_same<TT, uint32_t>::value)
155return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
156else if constexpr (std::is_same<TT, int32_t>::value)
157return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
158else if constexpr (std::is_same<TT, uint64_t>::value)
159return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
160else if constexpr (std::is_same<TT, int64_t>::value)
161return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
162else if constexpr (std::is_same<TT, std::string>::value)
163return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
164ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
165return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
166}
167
168class TensorBase : public Arg {
169public:
beb9fbbaScott McKay2 years ago170virtual ~TensorBase() = default;
64646279Wenbing Li2 years ago171
172virtual ONNXTensorElementDataType Type() const = 0;
173virtual const std::vector<int64_t>& Shape() const = 0;
174virtual int64_t NumberOfElement() const = 0;
175virtual const void* DataRaw() const = 0;
176virtual size_t SizeInBytes() const = 0;
177};
178
179template <typename T>
180class Tensor : public TensorBase {
181public:
182using TT = typename std::remove_reference<T>::type;
183Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
184}
185
186Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
187
188Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
189
f0ef40d0Tang, Cheng2 years ago190Tensor(const Tensor& src) = delete;
191
192Tensor& operator=(Tensor src) = delete;
193
194Tensor(Tensor&& other) : storage_(std::move(other.storage_)) {
195other.storage_ = nullptr;
196other.span_ = {};
197}
198
199Tensor& operator=(Tensor&& other)
200{
201storage_ = std::move(other.storage_);
202other.span_ = {};
203return *this;
204}
205
64646279Wenbing Li2 years ago206operator bool() const {
f0ef40d0Tang, Cheng2 years ago207return storage_ && storage_->IsInitialized();
64646279Wenbing Li2 years ago208}
209
210ONNXTensorElementDataType Type() const override {
211return GetOrtDType<T>();
212}
213
214const std::vector<int64_t>& Shape() const override {
f0ef40d0Tang, Cheng2 years ago215if (!storage_)
216ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago217return storage_->Shape();
218}
219
220int64_t NumberOfElement() const override {
f0ef40d0Tang, Cheng2 years ago221if (!storage_)
222ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago223auto& shape = storage_->Shape();
224return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
225}
226
227std::string Shape2Str() const {
f0ef40d0Tang, Cheng2 years ago228if (!storage_)
229ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
230if (storage_&& storage_->IsInitialized()) {
64646279Wenbing Li2 years ago231std::string shape_str;
232auto& shape = storage_->Shape();
233for (const auto& dim : shape) {
234shape_str.append(std::to_string(dim));
235shape_str.append(", ");
236}
237return shape_str;
238} else {
239return "empty";
240}
241}
f0ef40d0Tang, Cheng2 years ago242
243void* Release() {
244if (!storage_)
245ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
246span_ = {};
247return storage_->Release();
248}
64646279Wenbing Li2 years ago249
250const TT* Data() const {
f0ef40d0Tang, Cheng2 years ago251if (!storage_)
252ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago253#if ORT_API_VERSION >= 16
254if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
255return reinterpret_cast<const TT*>(storage_->DataRaw());
256else
257#endif
258return static_cast<const TT*>(storage_->DataRaw());
259}
260
261const void* DataRaw() const override {
f0ef40d0Tang, Cheng2 years ago262if (!storage_)
263ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago264return storage_->DataRaw();
265}
266
267size_t SizeInBytes() const override {
f0ef40d0Tang, Cheng2 years ago268if (!storage_)
269ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago270return NumberOfElement() * sizeof(TT);
271}
272
273TT* Allocate(const std::vector<int64_t>& shape) {
f0ef40d0Tang, Cheng2 years ago274if (!storage_)
275ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago276// it should be OK to allocate multiple times
277void* buffer = storage_->Initialize(shape, sizeof(TT));
278#if ORT_API_VERSION >= 16
279if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
280return reinterpret_cast<TT*>(buffer);
281else
282#endif
283return static_cast<TT*>(buffer);
284}
285
286const Span<T>& AsSpan() {
f0ef40d0Tang, Cheng2 years ago287if (!storage_)
288ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago289#if ORT_API_VERSION >= 16
290if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
291ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
292}
293else{
294#endif
295auto& shape = storage_->Shape();
296if (shape.size() != 1) {
297ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
298}
299span_.Assign(Data(), shape[0]);
300return span_;
301#if ORT_API_VERSION >= 16
302}
303#endif
304}
305
306const T& AsScalar() {
f0ef40d0Tang, Cheng2 years ago307if (!storage_)
308ORTX_CXX_API_THROW("tensor not initialized.", ORT_RUNTIME_EXCEPTION);
64646279Wenbing Li2 years ago309#if ORT_API_VERSION >= 16
310if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
311ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
312}
313else{
314#endif
315auto& shape = storage_->Shape();
316if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
317ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
318}
319return *Data();
320#if ORT_API_VERSION >= 16
321}
322#endif
323}
324
325private:
326std::unique_ptr<ITensorStorage> storage_;
327Span<T> span_;
328};
329
/// Abstract storage for a string tensor whose elements are of type `T`
/// (this header instantiates it with std::string and std::string_view).
/// NOTE: the virtual members keep their original declaration order so the
/// vtable layout is unchanged.
template <typename T>
class IStringTensorStorage {
 public:
  using strings = std::vector<T>;

  /// Dimensions of the tensor.
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Pointer to the characters of a scalar (single-element) tensor.
  virtual const void* DataRaw() const = 0;
  /// All elements, in storage order.
  virtual const strings& Data() const = 0;
  /// Whether a shape has been set.
  virtual bool IsInitialized() const = 0;
  /// Sets the output elements and the shape `dims`.
  virtual void SetStringOutput(const strings& values, const std::vector<int64_t>& dims) = 0;
  /// Same as above, taking NUL-terminated C strings.
  virtual void SetStringOutput(const std::vector<const char*>& values, const std::vector<int64_t>& dims) = 0;
  virtual ~IStringTensorStorage() = default;
};
342
343template<typename T>
344class EagerStringTensorStorage : public IStringTensorStorage<T>{
345public:
346using strings = std::vector<T>;
347EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}){}
348
349EagerStringTensorStorage() {}
350
351const std::vector<int64_t>& Shape() const override {
352if (!IsInitialized())
353ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
354return *shape_;
355}
356
beb9fbbaScott McKay2 years ago357const void* DataRaw() const override {
64646279Wenbing Li2 years ago358if (input_strings_.size() != 1) {
359ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
360}
361if constexpr (std::is_same<std::string_view, T>::value)
362return reinterpret_cast<const void*>(input_strings_[0].data());
363else
364return reinterpret_cast<const void*>(input_strings_[0].c_str());
365}
366
beb9fbbaScott McKay2 years ago367bool IsInitialized() const override {
64646279Wenbing Li2 years ago368return shape_.has_value();
369}
370
beb9fbbaScott McKay2 years ago371void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
64646279Wenbing Li2 years ago372if constexpr (std::is_same<std::string_view, T>::value)
373ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
374input_strings_.assign(ss.begin(), ss.end());
375shape_ = dims;
376}
377
378const strings& Data() const override {
379return input_strings_;
380}
381
beb9fbbaScott McKay2 years ago382void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
64646279Wenbing Li2 years ago383if constexpr (std::is_same<std::string_view, T>::value)
384ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
385
386for (const char* s : ss){
387input_strings_.push_back(s);
388}
389shape_ = dims;
390}
391
392private:
393std::vector<T> input_strings_;
394std::optional<std::vector<int64_t>> shape_;
395};
396
397template <>
398class Tensor<std::string> : public TensorBase {
399public:
400using strings = std::vector<std::string>;
401
402Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
403
404Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
405
406Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
407
408ONNXTensorElementDataType Type() const override {
409return GetOrtDType<std::string>();
410}
411
412const strings& Data() const {
413return storage_->Data();
414}
415
416const std::vector<int64_t>& Shape() const override {
417return storage_->Shape();
418}
419
420int64_t NumberOfElement() const override {
421auto& shape = storage_->Shape();
422return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
423}
424
425std::string Shape2Str() const {
426if (storage_->IsInitialized()) {
427std::string shape_str;
428auto& shape = storage_->Shape();
429for (const auto& dim : shape) {
430shape_str.append(std::to_string(dim));
431shape_str.append(", ");
432}
433return shape_str;
434} else {
435return "empty";
436}
437}
438
439const void* DataRaw() const override {
440return storage_->DataRaw();
441}
442
443size_t SizeInBytes() const override {
444auto& ss = storage_->Data();
445if (ss.size() != 1) {
446ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
447}
448return ss[0].size();
449}
450
451void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
452storage_->SetStringOutput(ss, dims);
453}
454void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
455storage_->SetStringOutput(ss, dims);
456}
457const Span<std::string>& AsSpan() {
458ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
459}
460const std::string& AsScalar() {
461auto& ss = storage_->Data();
462if (ss.size() != 1) {
463ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
464}
465return ss[0];
466}
467
468private:
469std::unique_ptr<IStringTensorStorage<std::string>> storage_;
470};
471
472
473template <>
474class Tensor<std::string_view> : public TensorBase {
475public:
476using strings = std::vector<std::string_view>;
477
478Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
479
480Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
481
482ONNXTensorElementDataType Type() const override {
483return GetOrtDType<std::string_view>();
484}
485
486const strings& Data() const {
487return storage_->Data();
488}
489
490const std::vector<int64_t>& Shape() const override {
491return storage_->Shape();
492}
493
494int64_t NumberOfElement() const override {
495auto& shape = storage_->Shape();
496return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
497}
498
499std::string Shape2Str() const {
500if (storage_->IsInitialized()) {
501std::string shape_str;
502auto& shape = storage_->Shape();
503for (const auto& dim : shape) {
504shape_str.append(std::to_string(dim));
505shape_str.append(", ");
506}
507return shape_str;
508} else {
509return "empty";
510}
511}
512
513const void* DataRaw() const override {
514return storage_->DataRaw();
515}
516
517size_t SizeInBytes() const override {
518auto& ss = storage_->Data();
519if (ss.size() != 1) {
520ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
521}
522return ss[0].size();
523}
524
525void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
526storage_->SetStringOutput(ss, dims);
527}
528void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
529storage_->SetStringOutput(ss, dims);
530}
531const Span<std::string_view>& AsSpan() {
532ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
533}
534const std::string_view& AsScalar() {
535auto& ss = storage_->Data();
536if (ss.size() != 1) {
537ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
538}
539return ss[0];
540}
541
542private:
543std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
544};
545
546
/// Fixed set of named, heterogeneously typed values, looked up by name with a
/// fallback default.
///
/// `keys[i]` names the i-th element of `args`. Keys beyond the number of
/// values are never consulted, and values without a key can never be matched.
template<typename ...Args>
class NamedArgumentDict{
 public:
  using ValueTuple = std::tuple<Args...>;

  NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args)
      : names_(keys.begin(), keys.end()), entries_(args) {}

  /// Returns the value stored under `name`, or `default_value` when no key
  /// matches (or `name` is null). Throws std::runtime_error when `name`
  /// matches but the stored value's type is not exactly `T`.
  template<typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
    return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
  }

 private:
  // Compile-time unrolled linear search. Index I is both the tuple slot and
  // the position of its key in names_.
  template<size_t I, typename T>
  T TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
    if constexpr (I == sizeof...(Args)) {
      return default_value;
    } else {
      // names_ may hold fewer keys than there are values: never index past it.
      // A null name cannot be compared against std::string safely.
      if (name != nullptr && I < names_.size() && names_[I] == name) {
        if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
          return std::get<I>(entries_);
        else
          throw std::runtime_error("name matched but type is not");
      }
      return TryToGetAttributeWithDefaultInternal<I + 1>(name, default_value);
    }
  }

  std::vector<std::string> names_;
  std::tuple<Args...> entries_;
};
586
587}
588}