microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
f9290e8bac2758dba8279d7ebc10e7027ffe0503

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/tensor_api.h

534 lines · mode: code

#pragma once
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <vector>
#include "onnxruntime_f16.h"
#include "kernel_context.h"
7
8namespace Ort {
9namespace Custom {
10
/// Minimal read-only view over `size_` contiguous elements of type T.
/// The span only stores a pointer; it does not own or free the viewed memory.
template <typename T>
struct Span {
  const T* data_{nullptr};
  size_t size_{0};

  /// Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  /// Number of elements in the view.
  size_t size() const { return size_; }

  /// Returns a copy of the element at `index` (no bounds check).
  T operator[](size_t index) const { return *(data_ + index); }

  /// Pointer to the first element; null for a default-constructed span.
  const T* data() const { return data_; }
};
25
26
27#if ORT_API_VERSION >= 16
28
29template <>
30struct Span<MFloat16> {
31 const MFloat16* data_ = {};
32 size_t size_ = {};
33 void Assign(const MFloat16* data, size_t size) {
34 data_ = data;
35 size_ = size;
36 }
37 size_t size() const { return size_; }
38 MFloat16 operator[](size_t indice) const {
39 return data_[indice];
40 }
41 const MFloat16* data() const { return data_; }
42};
43
44template <>
45struct Span<BFloat16> {
46 const BFloat16* data_ = {};
47 size_t size_ = {};
48 void Assign(const BFloat16* data, size_t size) {
49 data_ = data;
50 size_ = size;
51 }
52 size_t size() const { return size_; }
53 BFloat16 operator[](size_t indice) const {
54 return data_[indice];
55 }
56 const BFloat16* data() const { return data_; }
57};
58
59#endif
60
/// Abstract backing store for Tensor<T>: provides the tensor's shape and raw data.
class ITensorStorage{
public:
  // Virtual so that deleting through a base pointer (Tensor<T> holds a
  // std::unique_ptr<ITensorStorage>) runs the concrete storage's destructor,
  // e.g. the one that returns an allocator-owned buffer.
  virtual ~ITensorStorage() = default;

  /// Dimensions of the tensor; implementations may throw when not initialized.
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Untyped pointer to the tensor data.
  virtual const void* DataRaw() const = 0;
  /// True once the storage has a shape.
  virtual bool IsInitialized() const = 0;
  /// Sets up storage for `shape` with elements of `element_size` bytes and returns the writable buffer.
  virtual void* Initialize(const std::vector<int64_t>& shape, size_t element_size) = 0;
};
68
69
/// Minimal allocator interface used by eager tensor storage.
class IAllocator {
public:
  // Polymorphic interface: make deletion through a base pointer well-defined.
  virtual ~IAllocator() = default;

  /// Allocates `size` bytes and returns the block.
  virtual void* Alloc(size_t size) = 0;
  /// Releases a block previously returned by Alloc().
  virtual void Free(void* p) = 0;
};
75
76
77class OrtEagerTensorStorage : public ITensorStorage {
78public:
79 OrtEagerTensorStorage(const std::vector<int64_t>& shape,
80 void* buffer) : buffer_(buffer), shape_(shape){
81
82 }
83
84 OrtEagerTensorStorage(IAllocator* allocator) : allocator_(allocator){
85 }
86
87 virtual ~OrtEagerTensorStorage(){
88 if (allocator_ && buffer_)
89 allocator_->Free(buffer_);
90 }
91
92 const std::vector<int64_t>& Shape() const override {
93 if (!IsInitialized())
94 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
95 return *shape_;
96 }
97
98 virtual bool IsInitialized() const override {
99 return shape_.has_value();
100 }
101
102 const void* DataRaw() const override {
103 return buffer_;
104 }
105
106 void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
107 if (IsInitialized())
108 return buffer_;
109 assert(allocator_);
110 shape_ = shape;
111 int64_t n_elem = std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
112 auto buffer_size = n_elem * element_size;
113 buffer_ = allocator_->Alloc(buffer_size);
114 return buffer_;
115 }
116
117private:
118 void* buffer_ {};
119 std::optional<std::vector<int64_t>> shape_;
120 // caller need to make sure the allocator is alive
121 IAllocator* allocator_{};
122};
123
124template <typename TT>
125ONNXTensorElementDataType GetOrtDType(){
126 if constexpr (std::is_same<TT, bool>::value)
127 return ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL;
128 else if constexpr (std::is_same<TT, float>::value)
129 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
130 else if constexpr (std::is_same<TT, double>::value)
131 return ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE;
132 else if constexpr (std::is_same<TT, uint8_t>::value)
133 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8;
134 else if constexpr (std::is_same<TT, int8_t>::value)
135 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8;
136 else if constexpr (std::is_same<TT, uint16_t>::value)
137 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16;
138 else if constexpr (std::is_same<TT, int16_t>::value)
139 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16;
140 else if constexpr (std::is_same<TT, uint32_t>::value)
141 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32;
142 else if constexpr (std::is_same<TT, int32_t>::value)
143 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32;
144 else if constexpr (std::is_same<TT, uint64_t>::value)
145 return ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64;
146 else if constexpr (std::is_same<TT, int64_t>::value)
147 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
148 else if constexpr (std::is_same<TT, std::string>::value)
149 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
150 ORTX_CXX_API_THROW("Unexpected type", ORT_RUNTIME_EXCEPTION);
151 return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
152}
153
154class TensorBase : public Arg {
155public:
156 virtual ~TensorBase() {}
157
158 virtual ONNXTensorElementDataType Type() const = 0;
159 virtual const std::vector<int64_t>& Shape() const = 0;
160 virtual int64_t NumberOfElement() const = 0;
161 virtual const void* DataRaw() const = 0;
162 virtual size_t SizeInBytes() const = 0;
163};
164
165template <typename T>
166class Tensor : public TensorBase {
167 public:
168 using TT = typename std::remove_reference<T>::type;
169 Tensor(std::unique_ptr<ITensorStorage> tensor_storage) : storage_(std::move(tensor_storage)){
170 }
171
172 Tensor(const std::vector<int64_t>& shape, void* buffer) : Tensor(std::make_unique<OrtEagerTensorStorage>(shape, buffer)) {}
173
174 Tensor(IAllocator* allocator) : storage_(std::make_unique<OrtEagerTensorStorage>(allocator)){}
175
176 virtual ~Tensor() = default;
177
178 operator bool() const {
179 return storage_->IsInitialized();
180 }
181
182 ONNXTensorElementDataType Type() const override {
183 return GetOrtDType<T>();
184 }
185
186 const std::vector<int64_t>& Shape() const override {
187 return storage_->Shape();
188 }
189
190 int64_t NumberOfElement() const override {
191 auto& shape = storage_->Shape();
192 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
193 }
194
195 std::string Shape2Str() const {
196 if (storage_->IsInitialized()) {
197 std::string shape_str;
198 auto& shape = storage_->Shape();
199 for (const auto& dim : shape) {
200 shape_str.append(std::to_string(dim));
201 shape_str.append(", ");
202 }
203 return shape_str;
204 } else {
205 return "empty";
206 }
207 }
208
209 const TT* Data() const {
210#if ORT_API_VERSION >= 16
211 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
212 return reinterpret_cast<const TT*>(storage_->DataRaw());
213 else
214#endif
215 return static_cast<const TT*>(storage_->DataRaw());
216 }
217
218 const void* DataRaw() const override {
219 return storage_->DataRaw();
220 }
221
222 size_t SizeInBytes() const override {
223 return NumberOfElement() * sizeof(TT);
224 }
225
226 TT* Allocate(const std::vector<int64_t>& shape) {
227 // it should be OK to allocate multiple times
228 void* buffer = storage_->Initialize(shape, sizeof(TT));
229#if ORT_API_VERSION >= 16
230 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value)
231 return reinterpret_cast<TT*>(buffer);
232 else
233#endif
234 return static_cast<TT*>(buffer);
235 }
236
237 const Span<T>& AsSpan() {
238#if ORT_API_VERSION >= 16
239 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
240 ORTX_CXX_API_THROW("AsSpan for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
241 }
242 else{
243#endif
244 auto& shape = storage_->Shape();
245 if (shape.size() != 1) {
246 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
247 }
248 span_.Assign(Data(), shape[0]);
249 return span_;
250#if ORT_API_VERSION >= 16
251 }
252#endif
253 }
254
255 const T& AsScalar() {
256#if ORT_API_VERSION >= 16
257 if constexpr (std::is_same<TT, MFloat16>::value || std::is_same<TT, BFloat16>::value) {
258 ORTX_CXX_API_THROW("AsScalar for MFloat16 / BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
259 }
260 else{
261#endif
262 auto& shape = storage_->Shape();
263 if ((shape.size() == 1 && shape[0] != 1) || shape.size() > 1) {
264 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
265 }
266 return *Data();
267#if ORT_API_VERSION >= 16
268 }
269#endif
270 }
271
272 private:
273 std::unique_ptr<ITensorStorage> storage_;
274 Span<T> span_;
275};
276
/// Abstract backing store for string tensors; T is the element string type.
template<typename T>
class IStringTensorStorage{
public:
  using strings = std::vector<T>;

  // Virtual so that Tensor<std::string> / Tensor<std::string_view>, which hold a
  // std::unique_ptr to this interface, destroy the concrete storage correctly.
  virtual ~IStringTensorStorage() = default;

  /// Dimensions of the tensor.
  virtual const std::vector<int64_t>& Shape() const = 0;
  /// Untyped pointer to the tensor data.
  virtual const void* DataRaw() const = 0;
  /// All elements of the tensor.
  virtual const strings& Data() const = 0;
  /// True once the storage has a shape.
  virtual bool IsInitialized() const = 0;
  /// Stores `ss` as the tensor contents with dimensions `dims`.
  virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) = 0;
  /// Stores the C strings in `ss` as the tensor contents with dimensions `dims`.
  virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) = 0;
};
288
289template<typename T>
290class EagerStringTensorStorage : public IStringTensorStorage<T>{
291public:
292 using strings = std::vector<T>;
293 EagerStringTensorStorage(const strings& ss) : input_strings_(ss), shape_(std::vector<int64_t>{static_cast<int64_t>(ss.size())}){}
294
295 EagerStringTensorStorage() {}
296
297 const std::vector<int64_t>& Shape() const override {
298 if (!IsInitialized())
299 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
300 return *shape_;
301 }
302
303 virtual const void* DataRaw() const override {
304 if (input_strings_.size() != 1) {
305 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
306 }
307 if constexpr (std::is_same<std::string_view, T>::value)
308 return reinterpret_cast<const void*>(input_strings_[0].data());
309 else
310 return reinterpret_cast<const void*>(input_strings_[0].c_str());
311 }
312
313 virtual bool IsInitialized() const override {
314 return shape_.has_value();
315 }
316
317 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
318 if constexpr (std::is_same<std::string_view, T>::value)
319 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
320 input_strings_.assign(ss.begin(), ss.end());
321 shape_ = dims;
322 }
323
324 const strings& Data() const override {
325 return input_strings_;
326 }
327
328 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
329 if constexpr (std::is_same<std::string_view, T>::value)
330 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
331
332 for (const char* s : ss){
333 input_strings_.push_back(s);
334 }
335 shape_ = dims;
336 }
337
338private:
339 std::vector<T> input_strings_;
340 std::optional<std::vector<int64_t>> shape_;
341};
342
343template <>
344class Tensor<std::string> : public TensorBase {
345 public:
346 using strings = std::vector<std::string>;
347
348 Tensor(std::unique_ptr<IStringTensorStorage<std::string>> storage) : storage_(std::move(storage)) {}
349
350 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string>>(ss)) {}
351
352 Tensor() : storage_(std::make_unique<EagerStringTensorStorage<std::string>>()) {}
353
354 ONNXTensorElementDataType Type() const override {
355 return GetOrtDType<std::string>();
356 }
357
358 const strings& Data() const {
359 return storage_->Data();
360 }
361
362 const std::vector<int64_t>& Shape() const override {
363 return storage_->Shape();
364 }
365
366 int64_t NumberOfElement() const override {
367 auto& shape = storage_->Shape();
368 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
369 }
370
371 std::string Shape2Str() const {
372 if (storage_->IsInitialized()) {
373 std::string shape_str;
374 auto& shape = storage_->Shape();
375 for (const auto& dim : shape) {
376 shape_str.append(std::to_string(dim));
377 shape_str.append(", ");
378 }
379 return shape_str;
380 } else {
381 return "empty";
382 }
383 }
384
385 const void* DataRaw() const override {
386 return storage_->DataRaw();
387 }
388
389 size_t SizeInBytes() const override {
390 auto& ss = storage_->Data();
391 if (ss.size() != 1) {
392 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
393 }
394 return ss[0].size();
395 }
396
397 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
398 storage_->SetStringOutput(ss, dims);
399 }
400 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
401 storage_->SetStringOutput(ss, dims);
402 }
403 const Span<std::string>& AsSpan() {
404 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
405 }
406 const std::string& AsScalar() {
407 auto& ss = storage_->Data();
408 if (ss.size() != 1) {
409 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
410 }
411 return ss[0];
412 }
413
414 private:
415 std::unique_ptr<IStringTensorStorage<std::string>> storage_;
416};
417
418
419template <>
420class Tensor<std::string_view> : public TensorBase {
421 public:
422 using strings = std::vector<std::string_view>;
423
424 Tensor(std::unique_ptr<IStringTensorStorage<std::string_view>> storage) : storage_(std::move(storage)) {}
425
426 Tensor(const strings& ss) : storage_(std::make_unique<EagerStringTensorStorage<std::string_view>>(ss)) {}
427
428 ONNXTensorElementDataType Type() const override {
429 return GetOrtDType<std::string_view>();
430 }
431
432 const strings& Data() const {
433 return storage_->Data();
434 }
435
436 const std::vector<int64_t>& Shape() const override {
437 return storage_->Shape();
438 }
439
440 int64_t NumberOfElement() const override {
441 auto& shape = storage_->Shape();
442 return std::accumulate(shape.begin(), shape.end(), 1LL, std::multiplies<int64_t>());
443 }
444
445 std::string Shape2Str() const {
446 if (storage_->IsInitialized()) {
447 std::string shape_str;
448 auto& shape = storage_->Shape();
449 for (const auto& dim : shape) {
450 shape_str.append(std::to_string(dim));
451 shape_str.append(", ");
452 }
453 return shape_str;
454 } else {
455 return "empty";
456 }
457 }
458
459 const void* DataRaw() const override {
460 return storage_->DataRaw();
461 }
462
463 size_t SizeInBytes() const override {
464 auto& ss = storage_->Data();
465 if (ss.size() != 1) {
466 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
467 }
468 return ss[0].size();
469 }
470
471 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
472 storage_->SetStringOutput(ss, dims);
473 }
474 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
475 storage_->SetStringOutput(ss, dims);
476 }
477 const Span<std::string_view>& AsSpan() {
478 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
479 }
480 const std::string_view& AsScalar() {
481 auto& ss = storage_->Data();
482 if (ss.size() != 1) {
483 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
484 }
485 return ss[0];
486 }
487
488 private:
489 std::unique_ptr<IStringTensorStorage<std::string_view>> storage_;
490};
491
492
/// Fixed, heterogeneously typed set of values that can be looked up by name at runtime.
///
/// keys[i] names the i-th element of `args`. Keys beyond the number of values are
/// ignored. Values without a corresponding key can never be found by name.
template<typename ...Args>
class NamedArgumentDict{
public:
  using ValueTuple = std::tuple<Args...>;

  NamedArgumentDict(const std::vector<const char*>& keys, const std::tuple<Args...>& args)
      : names_(keys.begin(), keys.end()), entries_(args) {
  }

  /// Returns the value stored under `name`, or `default_value` if no entry has that name.
  /// If several entries share a name, the first one wins.
  /// @throws std::runtime_error if `name` matches an entry whose type is not exactly T.
  template<typename T>
  T TryToGetAttributeWithDefault(const char* name, const T& default_value) const {
    return TryToGetAttributeWithDefaultInternal<0>(name, default_value);
  }

private:
  // Recursion end: every entry was checked without a match.
  template<size_t I, typename T>
  typename std::enable_if<(I == sizeof...(Args)), T>::type
  TryToGetAttributeWithDefaultInternal(const char* /*name*/, const T& default_value) const {
    return default_value;
  }

  // Checks entry I, then continues with entry I + 1.
  template<size_t I, typename T>
  typename std::enable_if<(I < sizeof...(Args)), T>::type
  TryToGetAttributeWithDefaultInternal(const char* name, const T& default_value) const {
    // Fewer keys than values is allowed: unnamed entries simply never match
    // (previously names_[I] was read out of bounds).
    if (I < names_.size() && names_[I] == name) {
      if constexpr (std::is_same<std::tuple_element_t<I, ValueTuple>, T>::value)
        return std::get<I>(entries_);
      else
        throw std::runtime_error("name matched but type is not");
    }
    return TryToGetAttributeWithDefaultInternal<I + 1>(name, default_value);
  }

  std::vector<std::string> names_;
  std::tuple<Args...> entries_;
};
532
533}
534}