microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a03eded71e8263ca579307e41f04fbdf6ad4abc5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/custom_op_lite.h

1139 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include "onnxruntime_customop.hpp"
#include "onnxruntime_f16.h"

#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <string_view>
#include <type_traits>
#include <utility>
#include <vector>
9
10namespace Ort {
11namespace Custom {
12
13class TensorBase {
14 public:
15 TensorBase(const OrtW::CustomOpApi& api,
16 OrtKernelContext& ctx,
17 size_t indice,
18 bool is_input) : api_(api),
19 ctx_(ctx),
20 indice_(indice),
21 is_input_(is_input) {}
22
23 virtual ~TensorBase() = default;
24 operator bool() const {
25 return shape_.has_value();
26 }
27 const std::vector<int64_t>& Shape() const {
28 if (shape_.has_value()) {
29 return *shape_;
30 } else {
31 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
32 }
33 }
34 ONNXTensorElementDataType Type() const {
35 return type_;
36 }
37 int64_t NumberOfElement() const {
38 if (shape_.has_value()) {
39 return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
40 } else {
41 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
42 }
43 }
44 std::string Shape2Str() const {
45 if (shape_.has_value()) {
46 std::string shape_str;
47 for (const auto& dim : *shape_) {
48 shape_str.append(std::to_string(dim));
49 shape_str.append(", ");
50 }
51 return shape_str;
52 } else {
53 return "empty";
54 }
55 }
56 bool IsCpuTensor() const {
57 return strcmp("Cpu", mem_type_) == 0;
58 }
59 virtual const void* DataRaw() const = 0;
60 virtual size_t SizeInBytes() const = 0;
61
62 protected:
63 const OrtW::CustomOpApi& api_;
64 OrtKernelContext& ctx_;
65 size_t indice_;
66 bool is_input_;
67 std::optional<std::vector<int64_t>> shape_;
68 ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
69 const char* mem_type_ = "Cpu";
70};
71
/// Non-owning, read-only view over `size_` contiguous elements of type T.
/// The viewed buffer must outlive the Span.
template <typename T>
struct Span {
  const T* data_ = nullptr;
  size_t size_ = 0;

  /// Re-points the view at [data, data + size).
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }

  /// Number of elements in the view.
  size_t size() const { return size_; }

  /// Returns a copy of element `indice`; no bounds checking is performed.
  T operator[](size_t indice) const { return *(data_ + indice); }

  /// Pointer to the first element (nullptr for a default-constructed view).
  const T* data() const { return data_; }
};
86
// NOTE: Span<MFloat16> and Span<BFloat16> need no explicit specialization.
// The primary Span<T> template provides exactly the same members for these
// 16-bit float wrapper types: the Assign(data, size) setter, size(), a
// by-value operator[], and data(). The hand-written copies that used to live
// here were member-for-member duplicates and have been removed so the
// interface is maintained in one place.
120
121template <typename T>
122class Tensor : public TensorBase {
123 public:
124 using TT = typename std::remove_reference<T>::type;
125 Tensor(const OrtW::CustomOpApi& api,
126 OrtKernelContext& ctx,
127 size_t indice,
128 bool is_input) : TensorBase(api,
129 ctx,
130 indice,
131 is_input) {
132 if (is_input) {
133 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
134 if (indice >= input_count) {
135 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
136 }
137 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
138 auto* info = api_.GetTensorTypeAndShape(const_value_);
139 shape_ = api_.GetTensorShape(info);
140 type_ = api_.GetTensorElementType(info);
141 api_.ReleaseTensorTypeAndShapeInfo(info);
142 const OrtMemoryInfo* mem_info = {};
143 api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
144 if (mem_info) {
145 api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
146 }
147 }
148 }
149 const TT* Data() const {
150 return api_.GetTensorData<TT>(const_value_);
151 }
152
153 const void* DataRaw() const override {
154 return reinterpret_cast<const void*>(Data());
155 }
156
157 size_t SizeInBytes() const override {
158 return NumberOfElement() * sizeof(TT);
159 }
160
161 TT* Allocate(const std::vector<int64_t>& shape) {
162 if (!data_) {
163 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
164 shape_ = shape;
165 data_ = api_.GetTensorMutableData<TT>(out);
166 }
167 return data_;
168 }
169 const Span<T>& AsSpan() {
170 if (!shape_.has_value() || shape_->size() != 1) {
171 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
172 }
173 span_.Assign(Data(), (*shape_)[0]);
174 return span_;
175 }
176 const T& AsScalar() {
177 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
178 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
179 }
180 return *Data();
181 }
182
183 private:
184 const OrtValue* const_value_{}; // for input
185 TT* data_{}; // for output
186 Span<T> span_;
187};
188
189template <>
190class Tensor<std::string> : public TensorBase {
191 public:
192 using strings = std::vector<std::string>;
193
194 Tensor(const OrtW::CustomOpApi& api,
195 OrtKernelContext& ctx,
196 size_t indice,
197 bool is_input) : TensorBase(api,
198 ctx,
199 indice,
200 is_input) {
201 if (is_input) {
202 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
203 if (indice >= input_count) {
204 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
205 }
206
207 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
208 auto* info = api_.GetTensorTypeAndShape(const_value);
209 shape_ = api_.GetTensorShape(info);
210 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
211 api_.ReleaseTensorTypeAndShapeInfo(info);
212
213 size_t num_chars;
214 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
215 std::vector<char> chars(num_chars + 1, '\0');
216 auto num_strings = NumberOfElement();
217 std::vector<size_t> offsets(NumberOfElement());
218 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
219 (void*)chars.data(),
220 num_chars,
221 offsets.data(),
222 offsets.size()));
223 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
224 input_strings_.resize(num_strings);
225 for (int64_t i = upper_bound; i >= 0; --i) {
226 if (i < upper_bound) {
227 chars[offsets[i + 1]] = '\0';
228 }
229 input_strings_[i] = chars.data() + offsets[i];
230 }
231 }
232 }
233 const strings& Data() const {
234 return input_strings_;
235 }
236 const void* DataRaw() const override {
237 if (input_strings_.size() != 1) {
238 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
239 }
240 return reinterpret_cast<const void*>(input_strings_[0].c_str());
241 }
242 size_t SizeInBytes() const override {
243 if (input_strings_.size() != 1) {
244 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
245 }
246 return input_strings_[0].size();
247 }
248 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
249 std::vector<const char*> raw;
250 for (const auto& s : ss) {
251 raw.push_back(s.data());
252 }
253 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
254 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
255 }
256 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
257 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
258 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
259 }
260 const Span<std::string>& AsSpan() {
261 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
262 }
263 const std::string& AsScalar() {
264 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
265 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
266 }
267 return input_strings_[0];
268 }
269
270 private:
271 std::vector<std::string> input_strings_; // for input
272};
273
274template <>
275class Tensor<std::string_view> : public TensorBase {
276 public:
277 using strings = std::vector<std::string>;
278 using string_views = std::vector<std::string_view>;
279
280 Tensor(const OrtW::CustomOpApi& api,
281 OrtKernelContext& ctx,
282 size_t indice,
283 bool is_input) : TensorBase(api,
284 ctx,
285 indice,
286 is_input) {
287 if (is_input_) {
288 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
289 if (indice >= input_count) {
290 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
291 }
292 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
293 auto* info = api_.GetTensorTypeAndShape(const_value);
294 shape_ = api_.GetTensorShape(info);
295 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
296 api_.ReleaseTensorTypeAndShapeInfo(info);
297
298 size_t num_chars;
299 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
300 chars_.resize(num_chars + 1, '\0');
301
302 auto num_strings = static_cast<size_t>(NumberOfElement());
303 if (num_strings) {
304 std::vector<size_t> offsets(num_strings);
305 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
306 (void*)chars_.data(),
307 num_chars,
308 offsets.data(),
309 offsets.size()));
310 offsets.push_back(num_chars);
311 for (size_t i = 0; i < num_strings; ++i) {
312 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
313 }
314 }
315 }
316 }
317 int64_t NumberOfElement() const {
318 if (shape_.has_value()) {
319 return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
320 } else {
321 return 0;
322 }
323 }
324 const string_views& Data() const {
325 return input_string_views_;
326 }
327 const void* DataRaw() const override {
328 if (input_string_views_.size() != 1) {
329 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
330 }
331 return reinterpret_cast<const void*>(input_string_views_[0].data());
332 }
333 size_t SizeInBytes() const override {
334 if (input_string_views_.size() != 1) {
335 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
336 }
337 return input_string_views_[0].size();
338 }
339 const Span<std::string_view>& AsSpan() {
340 ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
341 }
342 std::string_view AsScalar() {
343 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
344 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
345 }
346 return input_string_views_[0];
347 }
348
349 private:
350 std::vector<char> chars_; // for input
351 std::vector<std::string_view> input_string_views_; // for input
352};
353
#if ORT_API_VERSION >= 16

/// Accessor for ONNX FLOAT16 tensors. Element storage is reinterpreted from the
/// tensor's uint16_t buffer as MFloat16.
template <>
struct Tensor<MFloat16> : public TensorBase {
  Tensor(const OrtW::CustomOpApi& api,
         OrtKernelContext& ctx,
         size_t indice,
         bool is_input)
      : TensorBase(api, ctx, indice, is_input) {
    // Type for outputs; inputs overwrite it with the type reported by ORT below.
    type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
    if (is_input_) {
      auto input_count = api_.KernelContext_GetInputCount(&ctx_);
      if (indice >= input_count) {
        ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
      }
      const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
      auto* info = api_.GetTensorTypeAndShape(const_value_);
      shape_ = api_.GetTensorShape(info);
      type_ = api_.GetTensorElementType(info);
      api_.ReleaseTensorTypeAndShapeInfo(info);
      const OrtMemoryInfo* mem_info = {};
      api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
      if (mem_info) {
        api_.ThrowOnError(api_.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
      }
    }
  }

  /// Read-only pointer to the input tensor's elements.
  const MFloat16* Data() const {
    return reinterpret_cast<const MFloat16*>(api_.GetTensorData<uint16_t>(const_value_));
  }

  /// Creates the output with `shape` on first call and returns its buffer;
  /// later calls return the same buffer and ignore `shape`.
  MFloat16* Allocate(const std::vector<int64_t>& shape) {
    if (!data_) {
      OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
      shape_ = shape;
      data_ = reinterpret_cast<MFloat16*>(api_.GetTensorMutableData<uint16_t>(out));
    }
    return data_;
  }

  const Span<MFloat16>& AsSpan() {
    ORTX_CXX_API_THROW("AsSpan for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
  }

  const MFloat16& AsScalar() {
    ORTX_CXX_API_THROW("AsScalar for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
  }

  const void* DataRaw() const override {
    return reinterpret_cast<const void*>(Data());
  }

  size_t SizeInBytes() const override {
    return NumberOfElement() * sizeof(uint16_t);
  }

 private:
  const OrtValue* const_value_{};  // for input
  MFloat16* data_{};               // for output
};

/// Accessor for ONNX BFLOAT16 tensors. Element storage is reinterpreted from the
/// tensor's uint16_t buffer as BFloat16.
template <>
struct Tensor<BFloat16> : public TensorBase {
  Tensor(const OrtW::CustomOpApi& api,
         OrtKernelContext& ctx,
         size_t indice,
         bool is_input)
      : TensorBase(api, ctx, indice, is_input) {
    // Type for outputs; inputs overwrite it with the type reported by ORT below.
    // (Previously FLOAT16, which mislabelled BFloat16 output tensors.)
    type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16;
    if (is_input_) {
      auto input_count = api_.KernelContext_GetInputCount(&ctx_);
      if (indice >= input_count) {
        ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
      }
      const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
      auto* info = api_.GetTensorTypeAndShape(const_value_);
      shape_ = api_.GetTensorShape(info);
      type_ = api_.GetTensorElementType(info);
      api_.ReleaseTensorTypeAndShapeInfo(info);
      const OrtMemoryInfo* mem_info = {};
      api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
      if (mem_info) {
        api_.ThrowOnError(api_.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
      }
    }
  }

  /// Read-only pointer to the input tensor's elements.
  const BFloat16* Data() const {
    return reinterpret_cast<const BFloat16*>(api_.GetTensorData<uint16_t>(const_value_));
  }

  /// Creates the output with `shape` on first call and returns its buffer;
  /// later calls return the same buffer and ignore `shape`.
  BFloat16* Allocate(const std::vector<int64_t>& shape) {
    if (!data_) {
      OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
      shape_ = shape;
      data_ = reinterpret_cast<BFloat16*>(api_.GetTensorMutableData<uint16_t>(out));
    }
    return data_;
  }

  const Span<BFloat16>& AsSpan() {
    ORTX_CXX_API_THROW("AsSpan for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
  }

  const BFloat16& AsScalar() {
    ORTX_CXX_API_THROW("AsScalar for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
  }

  const void* DataRaw() const override {
    return reinterpret_cast<const void*>(Data());
  }

  size_t SizeInBytes() const override {
    return NumberOfElement() * sizeof(uint16_t);
  }

 private:
  const OrtValue* const_value_{};  // for input
  BFloat16* data_{};               // for output
};

#endif
481
482using TensorPtr = std::unique_ptr<Custom::TensorBase>;
483using TensorPtrs = std::vector<TensorPtr>;
484
485// Represent variadic input or output
486struct Variadic : public TensorBase {
487 Variadic(const OrtW::CustomOpApi& api,
488 OrtKernelContext& ctx,
489 size_t indice,
490 bool is_input) : TensorBase(api,
491 ctx,
492 indice,
493 is_input) {
494#if ORT_API_VERSION < 14
495 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
496#endif
497 if (is_input) {
498 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
499 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
500 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
501 auto* info = api_.GetTensorTypeAndShape(const_value);
502 auto type = api_.GetTensorElementType(info);
503 api_.ReleaseTensorTypeAndShapeInfo(info);
504 TensorPtr tensor;
505 switch (type) {
506 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
507 tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
508 break;
509 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
510 tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
511 break;
512 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
513 tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
514 break;
515 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
516 tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
517 break;
518 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
519 tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
520 break;
521 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
522 tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
523 break;
524 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
525 tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
526 break;
527 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
528 tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
529 break;
530 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
531 tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
532 break;
533 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
534 tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
535 break;
536 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
537 tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
538 break;
539 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
540 tensor = std::make_unique<Custom::Tensor<std::string>>(api, ctx, ith_input, true);
541 break;
542 default:
543 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
544 break;
545 }
546 tensors_.emplace_back(tensor.release());
547 } // for
548 } else {
549 // a Variadic used for output is populated by the Compute so leave tensors_ empty here
550 }
551 }
552 template <typename T>
553 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
554 auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
555 auto raw_output = tensor.get()->Allocate(shape);
556 tensors_.emplace_back(tensor.release());
557 return raw_output;
558 }
559 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
560 auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
561 Tensor<std::string>& output = *tensor;
562 tensors_.emplace_back(tensor.release());
563 return output;
564 }
565 const void* DataRaw() const override {
566 ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
567 return nullptr;
568 }
569 size_t SizeInBytes() const override {
570 ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
571 return 0;
572 }
573 size_t Size() const {
574 return tensors_.size();
575 }
576 const TensorPtr& operator[](size_t indice) const {
577 return tensors_.at(indice);
578 }
579
580 private:
581 TensorPtrs tensors_;
582};
583
#ifdef USE_CUDA

enum CudaResource {
  cuda_handle_t = 10000,
};

/// CUDA resources fetched from the kernel context for one compute call.
struct CudaContext {
  static const int cuda_resource_ver = 1;

  /// Fetches the CUDA stream from `ctx` into `cuda_stream`.
  /// Throws if ORT returns an error status or yields a null stream.
  void Init(const OrtW::CustomOpApi& api, const OrtKernelContext& ctx) {
    const auto& ort_api = api.GetOrtApi();
    // Reset first: this object may be reused across calls (e.g. as a
    // thread_local), and a stale stream from an earlier call must not let a
    // failed fetch pass the null check below.
    cuda_stream = nullptr;
    // The returned OrtStatus* was previously discarded, which dropped the error
    // and leaked the status object. Route it through ThrowOnError instead.
    api.ThrowOnError(ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream));
    if (!cuda_stream) {
      ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
    }
  }

  void* cuda_stream = {};
};

#endif
603
604// using mf16_t = uint16_t;
605
606struct OrtLiteCustomOp : public OrtCustomOp {
607 // CreateTuple
608 template <size_t ith_input, size_t ith_output, typename... Ts>
609 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
610 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
611 return std::make_tuple();
612 }
613
614 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
615 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
616 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
617 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
618 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
619 return std::tuple_cat(current, next);
620 }
621
622#ifdef USE_CUDA
623 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
624 static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
625 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
626 thread_local CudaContext cuda_context;
627 cuda_context.Init(*api, *context);
628 std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
629 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
630 return std::tuple_cat(current, next);
631 }
632#endif
633
634#if ORT_API_VERSION >= 14
635 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
636 static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
637 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
638 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
639 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
640 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
641 return std::tuple_cat(current, next);
642 }
643
644 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
645 static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
646 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
647 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
648 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
649 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
650 return std::tuple_cat(current, next);
651 }
652
653 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
654 static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
655 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
656 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
657 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
658 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
659 return std::tuple_cat(current, next);
660 }
661
662 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
663 static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
664 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
665 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
666 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
667 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
668 return std::tuple_cat(current, next);
669 }
670#endif
671
672#define CREATE_TUPLE_INPUT(data_type) \
673 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
674 static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
675 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
676 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
677 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
678 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
679 return std::tuple_cat(current, next); \
680 } \
681 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
682 static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
683 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
684 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
685 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
686 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
687 return std::tuple_cat(current, next); \
688 } \
689 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
690 static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
691 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
692 if (ith_input < num_input) { \
693 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
694 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
695 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
696 return std::tuple_cat(current, next); \
697 } else { \
698 std::tuple<T> current = std::tuple<T>{}; \
699 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
700 return std::tuple_cat(current, next); \
701 } \
702 } \
703 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
704 static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
705 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
706 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
707 if (!tensors.back()->IsCpuTensor()) { \
708 ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
709 } \
710 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
711 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
712 return std::tuple_cat(current, next); \
713 } \
714 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
715 static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
716 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
717 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
718 if (!tensors.back()->IsCpuTensor()) { \
719 ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
720 } \
721 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
722 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
723 return std::tuple_cat(current, next); \
724 } \
725 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
726 static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
727 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
728 if (ith_input < num_input) { \
729 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
730 if (!tensors.back()->IsCpuTensor()) { \
731 ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
732 } \
733 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
734 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
735 return std::tuple_cat(current, next); \
736 } else { \
737 std::tuple<T> current = std::tuple<T>{}; \
738 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
739 return std::tuple_cat(current, next); \
740 } \
741 } \
742 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
743 static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
744 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
745 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
746 if (!tensors.back()->IsCpuTensor()) { \
747 ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
748 } \
749 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
750 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
751 return std::tuple_cat(current, next); \
752 } \
753 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
754 static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
755 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
756 if (ith_input < num_input) { \
757 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
758 if (!tensors.back()->IsCpuTensor()) { \
759 ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
760 } \
761 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
762 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
763 return std::tuple_cat(current, next); \
764 } else { \
765 std::tuple<T> current = std::tuple<T>{}; \
766 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
767 return std::tuple_cat(current, next); \
768 } \
769 }
// CREATE_TUPLE_OUTPUT(data_type) declares the CreateTuple overloads that bind the
// ith_output-th kernel output to a compute-function parameter of one of these types:
//   Custom::Tensor<data_type>*                 - pointer to the output tensor;
//   Custom::Tensor<data_type>&                 - reference to the output tensor;
//   std::optional<Custom::Tensor<data_type>*>  - pointer to the output tensor when
//                                                ith_output < num_output, empty otherwise.
// Whenever a tensor wrapper is created it is appended to `tensors`, which owns it; the
// returned tuple only holds pointers/references into that vector. Each overload then
// recurses with ith_output + 1 over the remaining parameter types Ts... and concatenates
// the results into a single tuple.
// The TensorBase -> Tensor<data_type> downcasts use static_cast: the object was just
// created as a Tensor<data_type>, and static_cast (not reinterpret_cast) is the
// conversion the language defines for a base-to-derived pointer/reference cast.
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
// CREATE_TUPLE(data_type) emits both families of CreateTuple overloads for one element
// type: the input-side overloads (CREATE_TUPLE_INPUT) followed by the output-side
// overloads (CREATE_TUPLE_OUTPUT).
#define CREATE_TUPLE(data_type)  \
  CREATE_TUPLE_INPUT(data_type)  \
  CREATE_TUPLE_OUTPUT(data_type)
804
805 CREATE_TUPLE(bool)
806#if ORT_API_VERSION >= 16
807 CREATE_TUPLE(MFloat16)
808 CREATE_TUPLE(BFloat16)
809#endif
810 CREATE_TUPLE(float)
811 CREATE_TUPLE(double)
812 CREATE_TUPLE(int8_t)
813 CREATE_TUPLE(int16_t)
814 CREATE_TUPLE(int32_t)
815 CREATE_TUPLE(int64_t)
816 CREATE_TUPLE(uint8_t)
817 CREATE_TUPLE(uint16_t)
818 CREATE_TUPLE(uint32_t)
819 CREATE_TUPLE(uint64_t)
820 CREATE_TUPLE(std::string)
821 CREATE_TUPLE_INPUT(std::string_view)
822
823 // ParseArgs ...
824 template <typename... Ts>
825 static typename std::enable_if<0 == sizeof...(Ts)>::type
826 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
827 }
828
829 template <typename T, typename... Ts>
830 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
831 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
832 ParseArgs<Ts...>(input_types, output_types);
833 }
834
#ifdef USE_CUDA
// A `const CudaContext&` parameter (CUDA builds only) is neither a tensor input nor an
// output: no element type is recorded for it and parsing continues with the rest.
template <typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const CudaContext&>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
842
#if ORT_API_VERSION >= 14
// Variadic parameters (ORT API 14+). A variadic input is spelled `const Variadic&` or
// `const Variadic*`; a variadic output is spelled `Variadic&` or `Variadic*`. Each is
// recorded as a single ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED entry, which is the marker
// GetInputCharacteristic / GetOutputCharacteristic check in slot 0 to report
// INPUT_OUTPUT_VARIADIC. The pointer and reference spellings share one overload per direction.
// NOTE(review): only parameters recorded *before* the variadic one are rejected; a tensor
// parameter declared after it is still appended. Confirm whether that should be an error too.

// Variadic input: throws if any input type has already been recorded.
template <typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic&>::value ||
                               std::is_same<T, const Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  if (!input_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
  }
  input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}

// Variadic output: throws if any output type has already been recorded.
template <typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic&>::value ||
                               std::is_same<T, Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  if (!output_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
  }
  output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
884
// PARSE_INPUT_BASE(pack_type, onnx_type) declares one ParseArgs overload that accepts a
// compute-function parameter of type pack_type or std::optional<pack_type>. Both spellings
// record onnx_type as the next input type, then parsing continues with the remaining
// parameters. The two spellings previously had separate overloads with identical bodies.
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, pack_type>::value || \
                                 std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }
898
// PARSE_INPUT(data_type, onnx_type) registers every accepted input spelling for one
// element type: the value itself (a scalar input), and const Custom::Tensor /
// const Custom::Span by pointer or by reference. PARSE_INPUT_BASE also accepts each of
// these wrapped in std::optional.
#define PARSE_INPUT(data_type, onnx_type)                      \
  PARSE_INPUT_BASE(data_type, onnx_type)                       \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type)   \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type)
905
// PARSE_OUTPUT(data_type, onnx_type) declares one ParseArgs overload for output parameters
// of element type data_type, spelled Custom::Tensor<data_type>*, Custom::Tensor<data_type>&
// or std::optional<Custom::Tensor<data_type>*>. Each spelling records onnx_type as the next
// output type, then parsing continues with the remaining parameters. The three spellings
// previously had separate overloads with identical bodies.
#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value || \
                                 std::is_same<T, Custom::Tensor<data_type>&>::value || \
                                 std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }
925
// PARSE_ARGS(data_type, onnx_type) registers both the input-side and the output-side
// ParseArgs overloads that map the C++ type data_type to the ONNX element type onnx_type.
#define PARSE_ARGS(data_type, onnx_type)  \
  PARSE_INPUT(data_type, onnx_type)       \
  PARSE_OUTPUT(data_type, onnx_type)
929
930 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
931#if ORT_API_VERSION >= 16
932 PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
933 PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
934#endif
935 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
936 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
937 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
938 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
939 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
940 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
941 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
942 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
943 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
944 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
945 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
946 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
947
948 OrtLiteCustomOp(const char* op_name,
949 const char* execution_provider) : op_name_(op_name),
950 execution_provider_(execution_provider) {
951 // Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility
952 memset(&this->version, 0, sizeof(OrtCustomOp));
953
954 int act_ver = GetActiveOrtAPIVersion();
955 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
956
957 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
958 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
959
960 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
961 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
962 return self->input_types_.size();
963 };
964
965 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
966 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
967 return self->input_types_[indice];
968 };
969
970 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
971 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
972 return self->output_types_.size();
973 };
974
975 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
976 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
977 return self->output_types_[indice];
978 };
979
980#if ORT_API_VERSION >= 14
981 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
982 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
983 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
984 };
985
986 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
987 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
988 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
989 };
990
991 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
992 return 1;
993 };
994
995 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
996 return 0;
997 };
998
999 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
1000 return 1;
1001 };
1002
1003 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
1004 return 0;
1005 };
1006
1007 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
1008 return OrtMemTypeDefault;
1009 };
1010#else
1011 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
1012 return INPUT_OUTPUT_OPTIONAL;
1013 };
1014
1015 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
1016 return INPUT_OUTPUT_OPTIONAL;
1017 };
1018#endif
1019 }
1020
1021 const std::string op_name_;
1022 const std::string execution_provider_;
1023
1024 std::vector<ONNXTensorElementDataType> input_types_;
1025 std::vector<ONNXTensorElementDataType> output_types_;
1026};
1027
1028template <typename... Args>
1029struct OrtLiteCustomFunc : public OrtLiteCustomOp {
1030 using ComputeFn = void (*)(Args...);
1031 using MyType = OrtLiteCustomFunc<Args...>;
1032
1033 struct Kernel {
1034 ComputeFn compute_fn_{};
1035 std::string ep_{};
1036 std::unique_ptr<OrtW::CustomOpApi> api_;
1037 };
1038
1039 OrtLiteCustomFunc(const char* op_name,
1040 const char* execution_provider,
1041 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
1042 compute_fn_(compute_fn) {
1043 ParseArgs<Args...>(input_types_, output_types_);
1044
1045 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
1046 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1047 std::vector<TensorPtr> tensors;
1048 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
1049 context,
1050 tensors,
1051 kernel->api_->KernelContext_GetInputCount(context),
1052 kernel->api_->KernelContext_GetOutputCount(context),
1053 kernel->ep_);
1054 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
1055 };
1056
1057 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1058 auto kernel = std::make_unique<Kernel>();
1059 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
1060 kernel->compute_fn_ = self->compute_fn_;
1061 kernel->ep_ = self->execution_provider_;
1062 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
1063 return reinterpret_cast<void*>(kernel.release());
1064 };
1065
1066 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1067 delete reinterpret_cast<Kernel*>(op_kernel);
1068 };
1069 }
1070
1071 ComputeFn compute_fn_;
1072};
1073
1074template <typename CustomOp>
1075struct OrtLiteCustomStruct : public OrtLiteCustomOp {
1076 template <typename... Args>
1077 using CustomComputeFn = void (CustomOp::*)(Args...) const;
1078 using MyType = OrtLiteCustomStruct<CustomOp>;
1079
1080 struct Kernel {
1081 std::unique_ptr<CustomOp> custom_op_;
1082 std::string ep_{};
1083 std::unique_ptr<OrtW::CustomOpApi> api_;
1084 };
1085
1086 OrtLiteCustomStruct(const char* op_name,
1087 const char* execution_provider) : OrtLiteCustomOp(op_name,
1088 execution_provider) {
1089 init(&CustomOp::Compute);
1090 }
1091
1092 template <typename... Args>
1093 void init(CustomComputeFn<Args...>) {
1094 ParseArgs<Args...>(input_types_, output_types_);
1095
1096 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
1097 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1098 std::vector<TensorPtr> tensors;
1099 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
1100 context,
1101 tensors,
1102 kernel->api_->KernelContext_GetInputCount(context),
1103 kernel->api_->KernelContext_GetOutputCount(context),
1104 kernel->ep_);
1105 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
1106 };
1107
1108 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1109 auto kernel = std::make_unique<Kernel>();
1110 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
1111 auto self = static_cast<const MyType*>(this_);
1112 kernel->ep_ = self->execution_provider_;
1113 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
1114 return reinterpret_cast<void*>(kernel.release());
1115 };
1116
1117 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1118 delete reinterpret_cast<Kernel*>(op_kernel);
1119 };
1120 }
1121};
1122
1123template <typename... Args>
1124OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1125 const char* execution_provider,
1126 void (*custom_compute_fn)(Args...)) {
1127 using LiteOp = OrtLiteCustomFunc<Args...>;
1128 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
1129}
1130
1131template <typename CustomOp>
1132OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1133 const char* execution_provider) {
1134 using LiteOp = OrtLiteCustomStruct<CustomOp>;
1135 return std::make_unique<LiteOp>(op_name, execution_provider).release();
1136}
1137
1138} // namespace Custom
1139} // namespace Ort
1140