microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
leca/gqa

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/custom_op_lite.h

1163lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include "onnxruntime_customop.hpp"
#include "onnxruntime_f16.h"
#include <cstring>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <utility>
9
10namespace Ort {
11namespace Custom {
12
13class TensorBase {
14 public:
15 TensorBase(const OrtW::CustomOpApi& api,
16 OrtKernelContext& ctx,
17 size_t indice,
18 bool is_input) : api_(api),
19 ctx_(ctx),
20 indice_(indice),
21 is_input_(is_input) {}
22
23 virtual ~TensorBase() = default;
24 operator bool() const {
25 return shape_.has_value();
26 }
27 const std::vector<int64_t>& Shape() const {
28 if (shape_.has_value()) {
29 return *shape_;
30 } else {
31 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
32 }
33 }
34 ONNXTensorElementDataType Type() const {
35 return type_;
36 }
37 int64_t NumberOfElement() const {
38 if (shape_.has_value()) {
39 return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
40 } else {
41 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
42 }
43 }
44 std::string Shape2Str() const {
45 if (shape_.has_value()) {
46 std::string shape_str;
47 for (const auto& dim : *shape_) {
48 shape_str.append(std::to_string(dim));
49 shape_str.append(", ");
50 }
51 return shape_str;
52 } else {
53 return "empty";
54 }
55 }
56 bool IsCpuTensor() const {
57 return strcmp("Cpu", mem_type_) == 0;
58 }
59 virtual const void* DataRaw() const = 0;
60 virtual size_t SizeInBytes() const = 0;
61
62 protected:
63 const OrtW::CustomOpApi& api_;
64 OrtKernelContext& ctx_;
65 size_t indice_;
66 bool is_input_;
67 std::optional<std::vector<int64_t>> shape_;
68 ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
69 const char* mem_type_ = "Cpu";
70};
71
// Lightweight, non-owning, read-only view over `size_` contiguous elements of T.
template <typename T>
struct Span {
  const T* data_ = nullptr;
  size_t size_ = 0;

  // Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  const T* data() const { return data_; }
  size_t size() const { return size_; }

  // Unchecked element access; returns a copy of the element.
  T operator[](size_t indice) const {
    return *(data_ + indice);
  }
};
86
87#if ORT_API_VERSION >= 16
88
89template <>
90struct Span<MFloat16> {
91 const MFloat16* data_ = {};
92 size_t size_ = {};
93 void Assign(const MFloat16* data, size_t size) {
94 data_ = data;
95 size_ = size;
96 }
97 size_t size() const { return size_; }
98 MFloat16 operator[](size_t indice) const {
99 return data_[indice];
100 }
101 const MFloat16* data() const { return data_; }
102};
103
104template <>
105struct Span<BFloat16> {
106 const BFloat16* data_ = {};
107 size_t size_ = {};
108 void Assign(const BFloat16* data, size_t size) {
109 data_ = data;
110 size_ = size;
111 }
112 size_t size() const { return size_; }
113 BFloat16 operator[](size_t indice) const {
114 return data_[indice];
115 }
116 const BFloat16* data() const { return data_; }
117};
118
119#endif
120
121template <typename T>
122class Tensor : public TensorBase {
123 public:
124 using TT = typename std::remove_reference<T>::type;
125 Tensor(const OrtW::CustomOpApi& api,
126 OrtKernelContext& ctx,
127 size_t indice,
128 bool is_input) : TensorBase(api,
129 ctx,
130 indice,
131 is_input) {
132 if (is_input) {
133 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
134 if (indice >= input_count) {
135 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
136 }
137 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
138 auto* info = api_.GetTensorTypeAndShape(const_value_);
139 shape_ = api_.GetTensorShape(info);
140 type_ = api_.GetTensorElementType(info);
141 api_.ReleaseTensorTypeAndShapeInfo(info);
142 const OrtMemoryInfo* mem_info = {};
143 api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
144 if (mem_info) {
145 api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
146 }
147 }
148 }
149 const TT* Data() const {
150 return api_.GetTensorData<TT>(const_value_);
151 }
152
153 const void* DataRaw() const override {
154 return reinterpret_cast<const void*>(Data());
155 }
156
157 size_t SizeInBytes() const override {
158 return NumberOfElement() * sizeof(TT);
159 }
160
161 TT* Allocate(const std::vector<int64_t>& shape) {
162 if (!data_) {
163 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
164 shape_ = shape;
165 data_ = api_.GetTensorMutableData<TT>(out);
166 }
167 return data_;
168 }
169 const Span<T>& AsSpan() {
170 if (!shape_.has_value() || shape_->size() != 1) {
171 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
172 }
173 span_.Assign(Data(), (*shape_)[0]);
174 return span_;
175 }
176 const T& AsScalar() {
177 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
178 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
179 }
180 return *Data();
181 }
182
183 private:
184 const OrtValue* const_value_{}; // for input
185 TT* data_{}; // for output
186 Span<T> span_;
187};
188
189template <>
190class Tensor<std::string> : public TensorBase {
191 public:
192 using strings = std::vector<std::string>;
193
194 Tensor(const OrtW::CustomOpApi& api,
195 OrtKernelContext& ctx,
196 size_t indice,
197 bool is_input) : TensorBase(api,
198 ctx,
199 indice,
200 is_input) {
201 if (is_input) {
202 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
203 if (indice >= input_count) {
204 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
205 }
206
207 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
208 auto* info = api_.GetTensorTypeAndShape(const_value);
209 shape_ = api_.GetTensorShape(info);
210 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
211 api_.ReleaseTensorTypeAndShapeInfo(info);
212
213 size_t num_chars;
214 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
215 std::vector<char> chars(num_chars + 1, '\0');
216 auto num_strings = NumberOfElement();
217 std::vector<size_t> offsets(NumberOfElement());
218 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
219 (void*)chars.data(),
220 num_chars,
221 offsets.data(),
222 offsets.size()));
223 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
224 input_strings_.resize(num_strings);
225 for (int64_t i = upper_bound; i >= 0; --i) {
226 if (i < upper_bound) {
227 chars[offsets[i + 1]] = '\0';
228 }
229 input_strings_[i] = chars.data() + offsets[i];
230 }
231 }
232 }
233 const strings& Data() const {
234 return input_strings_;
235 }
236 const void* DataRaw() const override {
237 if (input_strings_.size() != 1) {
238 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
239 }
240 return reinterpret_cast<const void*>(input_strings_[0].c_str());
241 }
242 size_t SizeInBytes() const override {
243 if (input_strings_.size() != 1) {
244 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
245 }
246 return input_strings_[0].size();
247 }
248 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
249 std::vector<const char*> raw;
250 for (const auto& s : ss) {
251 raw.push_back(s.data());
252 }
253 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
254 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
255 }
256 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
257 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
258 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
259 }
260 const Span<std::string>& AsSpan() {
261 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
262 }
263 const std::string& AsScalar() {
264 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
265 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
266 }
267 return input_strings_[0];
268 }
269
270 private:
271 std::vector<std::string> input_strings_; // for input
272};
273
274template <>
275class Tensor<std::string_view> : public TensorBase {
276 public:
277 using strings = std::vector<std::string>;
278 using string_views = std::vector<std::string_view>;
279
280 Tensor(const OrtW::CustomOpApi& api,
281 OrtKernelContext& ctx,
282 size_t indice,
283 bool is_input) : TensorBase(api,
284 ctx,
285 indice,
286 is_input) {
287 if (is_input_) {
288 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
289 if (indice >= input_count) {
290 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
291 }
292 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
293 auto* info = api_.GetTensorTypeAndShape(const_value);
294 shape_ = api_.GetTensorShape(info);
295 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
296 api_.ReleaseTensorTypeAndShapeInfo(info);
297
298 size_t num_chars;
299 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
300 chars_.resize(num_chars + 1, '\0');
301
302 auto num_strings = static_cast<size_t>(NumberOfElement());
303 if (num_strings) {
304 std::vector<size_t> offsets(num_strings);
305 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
306 (void*)chars_.data(),
307 num_chars,
308 offsets.data(),
309 offsets.size()));
310 offsets.push_back(num_chars);
311 for (size_t i = 0; i < num_strings; ++i) {
312 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
313 }
314 }
315 }
316 }
317 int64_t NumberOfElement() const {
318 if (shape_.has_value()) {
319 return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
320 } else {
321 return 0;
322 }
323 }
324 const string_views& Data() const {
325 return input_string_views_;
326 }
327 const void* DataRaw() const override {
328 if (input_string_views_.size() != 1) {
329 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
330 }
331 return reinterpret_cast<const void*>(input_string_views_[0].data());
332 }
333 size_t SizeInBytes() const override {
334 if (input_string_views_.size() != 1) {
335 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
336 }
337 return input_string_views_[0].size();
338 }
339 const Span<std::string_view>& AsSpan() {
340 ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
341 }
342 std::string_view AsScalar() {
343 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
344 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
345 }
346 return input_string_views_[0];
347 }
348
349 private:
350 std::vector<char> chars_; // for input
351 std::vector<std::string_view> input_string_views_; // for input
352};
353
354#if ORT_API_VERSION >= 16
355
356template <>
357struct Tensor<MFloat16> : public TensorBase {
358 Tensor(const OrtW::CustomOpApi& api,
359 OrtKernelContext& ctx,
360 size_t indice,
361 bool is_input) : TensorBase(api,
362 ctx,
363 indice,
364 is_input) {
365 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
366 if (is_input_) {
367 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
368 if (indice >= input_count) {
369 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
370 }
371 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
372 auto* info = api_.GetTensorTypeAndShape(const_value_);
373 shape_ = api_.GetTensorShape(info);
374 type_ = api_.GetTensorElementType(info);
375 api_.ReleaseTensorTypeAndShapeInfo(info);
376 const OrtMemoryInfo* mem_info = {};
377 api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
378 if (mem_info) {
379 api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
380 }
381 }
382 }
383
384 const MFloat16* Data() const {
385 return reinterpret_cast<const MFloat16*>(api_.GetTensorData<uint16_t>(const_value_));
386 }
387
388 MFloat16* Allocate(const std::vector<int64_t>& shape) {
389 if (!data_) {
390 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
391 shape_ = shape;
392 data_ = reinterpret_cast<MFloat16*>(api_.GetTensorMutableData<uint16_t>(out));
393 }
394 return data_;
395 }
396
397 const Span<MFloat16>& AsSpan() {
398 ORTX_CXX_API_THROW("AsSpan for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
399 }
400
401 const MFloat16& AsScalar() {
402 ORTX_CXX_API_THROW("AsScalar for MFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
403 }
404
405 const void* DataRaw() const override {
406 return reinterpret_cast<const void*>(Data());
407 }
408
409 virtual size_t SizeInBytes() const override {
410 return NumberOfElement() * sizeof(uint16_t);
411 }
412
413 private:
414 const OrtValue* const_value_{}; // for input
415 MFloat16* data_{}; // for output
416};
417
418template <>
419struct Tensor<BFloat16> : public TensorBase {
420 Tensor(const OrtW::CustomOpApi& api,
421 OrtKernelContext& ctx,
422 size_t indice,
423 bool is_input) : TensorBase(api,
424 ctx,
425 indice,
426 is_input) {
427 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16;
428 if (is_input_) {
429 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
430 if (indice >= input_count) {
431 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
432 }
433 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
434 auto* info = api_.GetTensorTypeAndShape(const_value_);
435 shape_ = api_.GetTensorShape(info);
436 type_ = api_.GetTensorElementType(info);
437 api_.ReleaseTensorTypeAndShapeInfo(info);
438 const OrtMemoryInfo* mem_info = {};
439 api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
440 if (mem_info) {
441 api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
442 }
443 }
444 }
445
446 const BFloat16* Data() const {
447 return reinterpret_cast<const BFloat16*>(api_.GetTensorData<uint16_t>(const_value_));
448 }
449
450 BFloat16* Allocate(const std::vector<int64_t>& shape) {
451 if (!data_) {
452 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
453 shape_ = shape;
454 data_ = reinterpret_cast<BFloat16*>(api_.GetTensorMutableData<uint16_t>(out));
455 }
456 return data_;
457 }
458
459 const Span<BFloat16>& AsSpan() {
460 ORTX_CXX_API_THROW("AsSpan for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
461 }
462
463 const BFloat16& AsScalar() {
464 ORTX_CXX_API_THROW("AsScalar for BFloat16 not implemented", ORT_RUNTIME_EXCEPTION);
465 }
466
467 const void* DataRaw() const override {
468 return reinterpret_cast<const void*>(Data());
469 }
470
471 virtual size_t SizeInBytes() const override {
472 return NumberOfElement() * sizeof(uint16_t);
473 }
474
475 private:
476 const OrtValue* const_value_{}; // for input
477 BFloat16* data_{}; // for output
478};
479
480#endif
481
482using TensorPtr = std::unique_ptr<Custom::TensorBase>;
483using TensorPtrs = std::vector<TensorPtr>;
484
485// Represent variadic input or output
486struct Variadic : public TensorBase {
487 Variadic(const OrtW::CustomOpApi& api,
488 OrtKernelContext& ctx,
489 size_t indice,
490 bool is_input) : TensorBase(api,
491 ctx,
492 indice,
493 is_input) {
494#if ORT_API_VERSION < 14
495 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
496#endif
497 if (is_input) {
498 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
499 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
500 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
501 auto* info = api_.GetTensorTypeAndShape(const_value);
502 auto type = api_.GetTensorElementType(info);
503 api_.ReleaseTensorTypeAndShapeInfo(info);
504 TensorPtr tensor;
505 switch (type) {
506 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
507 tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
508 break;
509 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
510 tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
511 break;
512 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
513 tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
514 break;
515 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
516 tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
517 break;
518 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
519 tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
520 break;
521 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
522 tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
523 break;
524 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
525 tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
526 break;
527 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
528 tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
529 break;
530 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
531 tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
532 break;
533 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
534 tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
535 break;
536 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
537 tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
538 break;
539 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
540 tensor = std::make_unique<Custom::Tensor<std::string>>(api, ctx, ith_input, true);
541 break;
542 default:
543 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
544 break;
545 }
546 tensors_.emplace_back(tensor.release());
547 } // for
548 } else {
549 // a Variadic used for output is populated by the Compute so leave tensors_ empty here
550 }
551 }
552 template <typename T>
553 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
554 auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
555 auto raw_output = tensor.get()->Allocate(shape);
556 tensors_.emplace_back(tensor.release());
557 return raw_output;
558 }
559 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
560 auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
561 Tensor<std::string>& output = *tensor;
562 tensors_.emplace_back(tensor.release());
563 return output;
564 }
565 const void* DataRaw() const override {
566 ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
567 return nullptr;
568 }
569 size_t SizeInBytes() const override {
570 ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
571 return 0;
572 }
573 size_t Size() const {
574 return tensors_.size();
575 }
576 const TensorPtr& operator[](size_t indice) const {
577 return tensors_.at(indice);
578 }
579
580 private:
581 TensorPtrs tensors_;
582};
583
584#ifdef USE_CUDA
585
// Resource identifiers used when querying CUDA resources from the kernel
// context. Values are consecutive and start at 10000.
enum CudaResource {
  cuda_handle_t = 10000,
  cudnn_handle_t = 10001,
  cublas_handle_t = 10002,
  deferred_cpu_allocator_t = 10003,
  // below are cuda ep options
  device_id_t = 10004,
};
594
595struct CudaContext {
596 static const int cuda_resource_ver = 1;
597 void Init(const OrtW::CustomOpApi& api, const OrtKernelContext& ctx) {
598 const auto& ort_api = api.GetOrtApi();
599 ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream);
600 if (!cuda_stream) {
601 ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
602 }
603 ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas);
604 if (!cublas) {
605 ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
606 }
607 void* resource = nullptr;
608 OrtStatusPtr result = ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
609 if (result) {
610 ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
611 }
612 memcpy(&device_id, &resource, sizeof(int));
613 }
614 void* cuda_stream = {};
615 void* cublas = {};
616 int device_id = 0;
617};
618
619#endif
620
621// using mf16_t = uint16_t;
622
623struct OrtLiteCustomOp : public OrtCustomOp {
624 // CreateTuple
625 template <size_t ith_input, size_t ith_output, typename... Ts>
626 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
627 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
628 return std::make_tuple();
629 }
630
631 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
632 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
633 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
634 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
635 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
636 return std::tuple_cat(current, next);
637 }
638
639#ifdef USE_CUDA
640 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
641 static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
642 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
643 thread_local CudaContext cuda_context;
644 cuda_context.Init(*api, *context);
645 std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
646 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
647 return std::tuple_cat(current, next);
648 }
649#endif
650
651#if ORT_API_VERSION >= 14
652 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
653 static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
654 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
655 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
656 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
657 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
658 return std::tuple_cat(current, next);
659 }
660
661 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
662 static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
663 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
664 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
665 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
666 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
667 return std::tuple_cat(current, next);
668 }
669
670 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
671 static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
672 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
673 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
674 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
675 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
676 return std::tuple_cat(current, next);
677 }
678
679 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
680 static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
681 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
682 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
683 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
684 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
685 return std::tuple_cat(current, next);
686 }
687#endif
688
// Generates the CreateTuple overloads that bind one kernel input of element
// type `data_type` to a compute-function argument. Overloads, in order:
//   const Tensor<data_type>*                  - pointer to the wrapped tensor
//   const Tensor<data_type>&                  - reference to the wrapped tensor
//   std::optional<const Tensor<data_type>*>   - empty when ith_input >= num_input
//   const Span<data_type>*                    - CPU tensors only
//   const Span<data_type>&                    - CPU tensors only
//   std::optional<const Span<data_type>*>     - CPU only; empty when ith_input >= num_input
//   data_type                                 - scalar value, CPU tensors only
//   std::optional<data_type>                  - CPU only; empty when ith_input >= num_input
// Every overload advances the input index for the remaining argument types.
// The head element is always built before the recursive call, because that
// call appends further entries to `tensors`.
#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    auto* base = tensors.back().get(); \
    std::tuple<T> head{reinterpret_cast<T>(base)}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    auto* base = tensors.back().get(); \
    std::tuple<T> head{reinterpret_cast<T>(*base)}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input >= num_input) { \
      std::tuple<T> head{}; \
      auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(head, tail); \
    } \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    auto* typed = reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get()); \
    std::tuple<T> head{typed}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    auto* typed = reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get()); \
    std::tuple<T> head{&typed->AsSpan()}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    auto* typed = reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get()); \
    std::tuple<T> head{typed->AsSpan()}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input >= num_input) { \
      std::tuple<T> head{}; \
      auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(head, tail); \
    } \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    auto* typed = reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get()); \
    std::tuple<T> head{&typed->AsSpan()}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    auto* typed = reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get()); \
    std::tuple<T> head{typed->AsScalar()}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input >= num_input) { \
      std::tuple<T> head{}; \
      auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(head, tail); \
    } \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    auto* typed = reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get()); \
    std::tuple<T> head{typed->AsScalar()}; \
    auto tail = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(head, tail); \
  }
// CREATE_TUPLE_OUTPUT(data_type) generates the CreateTuple overloads that are
// selected when the leading argument type T is an output tensor of data_type:
//   Custom::Tensor<data_type>*                 - output passed by pointer
//   Custom::Tensor<data_type>&                 - output passed by reference
//   std::optional<Custom::Tensor<data_type>*>  - left empty when ith_output >= num_output
// Every overload pushes the new tensor wrapper into `tensors` (the vector owns
// it), then recurses over the remaining types with ith_output advanced by one
// and ith_input unchanged.
// The stored pointer is down-cast with static_cast rather than reinterpret_cast,
// so the compiler verifies the class relationship and applies any pointer
// adjustment the hierarchy requires.
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
818#define CREATE_TUPLE(data_type) \
819 CREATE_TUPLE_INPUT(data_type) \
820 CREATE_TUPLE_OUTPUT(data_type)
821
822 CREATE_TUPLE(bool)
823#if ORT_API_VERSION >= 16
824 CREATE_TUPLE(MFloat16)
825 CREATE_TUPLE(BFloat16)
826#endif
827 CREATE_TUPLE(float)
828 CREATE_TUPLE(double)
829 CREATE_TUPLE(int8_t)
830 CREATE_TUPLE(int16_t)
831 CREATE_TUPLE(int32_t)
832 CREATE_TUPLE(int64_t)
833 CREATE_TUPLE(uint8_t)
834 CREATE_TUPLE(uint16_t)
835 CREATE_TUPLE(uint32_t)
836 CREATE_TUPLE(uint64_t)
837 CREATE_TUPLE(std::string)
838 CREATE_TUPLE_INPUT(std::string_view)
839
840 // ParseArgs ...
841 template <typename... Ts>
842 static typename std::enable_if<0 == sizeof...(Ts)>::type
843 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
844 }
845
846 template <typename T, typename... Ts>
847 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
848 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
849 ParseArgs<Ts...>(input_types, output_types);
850 }
851
#ifdef USE_CUDA
// Leading const CudaContext&: contributes neither an input nor an output type,
// so it is skipped and parsing resumes with the remaining types.
template <typename T, typename... Ts>
static std::enable_if_t<std::is_same_v<T, const CudaContext&>>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,
          std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
859
#if ORT_API_VERSION >= 14
// Variadic input, declared as `const Variadic&` or `const Variadic*`.
// Throws if an input type was already recorded; otherwise records
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED as the input type and continues.
template <typename T, typename... Ts>
static std::enable_if_t<(std::is_same_v<T, const Variadic&> || std::is_same_v<T, const Variadic*>)>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,
          std::vector<ONNXTensorElementDataType>& output_types) {
  const bool has_earlier_input = !input_types.empty();
  if (has_earlier_input) {
    ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
  }
  input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}

// Variadic output, declared as `Variadic&` or `Variadic*`.
// Throws if an output type was already recorded; otherwise records
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED as the output type and continues.
template <typename T, typename... Ts>
static std::enable_if_t<(std::is_same_v<T, Variadic&> || std::is_same_v<T, Variadic*>)>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,
          std::vector<ONNXTensorElementDataType>& output_types) {
  const bool has_earlier_output = !output_types.empty();
  if (has_earlier_output) {
    ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
  }
  output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
901
// PARSE_INPUT_BASE(pack_type, onnx_type) generates the ParseArgs overload that
// is chosen when the leading argument type is pack_type or
// std::optional<pack_type>. Either way onnx_type is appended to input_types and
// parsing continues with the remaining types.
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static std::enable_if_t<(std::is_same_v<T, pack_type> || std::is_same_v<T, std::optional<pack_type>>)> \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }

// PARSE_INPUT(data_type, onnx_type) covers every accepted spelling of an input
// of data_type: tensor or span, taken by const pointer or const reference, and
// the bare scalar data_type itself.
#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type)
922
// PARSE_OUTPUT(data_type, onnx_type) generates the ParseArgs overload that is
// chosen when the leading argument type is an output tensor of data_type,
// spelled as Custom::Tensor<data_type>*, Custom::Tensor<data_type>& or
// std::optional<Custom::Tensor<data_type>*>. onnx_type is appended to
// output_types and parsing continues with the remaining types.
#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static std::enable_if_t<(std::is_same_v<T, Custom::Tensor<data_type>*> || \
                           std::is_same_v<T, Custom::Tensor<data_type>&> || \
                           std::is_same_v<T, std::optional<Custom::Tensor<data_type>*>>)> \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }
942
943#define PARSE_ARGS(data_type, onnx_type) \
944 PARSE_INPUT(data_type, onnx_type) \
945 PARSE_OUTPUT(data_type, onnx_type)
946
947 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
948#if ORT_API_VERSION >= 16
949 PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
950 PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
951#endif
952 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
953 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
954 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
955 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
956 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
957 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
958 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
959 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
960 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
961 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
962 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
963 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
964
965 OrtLiteCustomOp(const char* op_name,
966 const char* execution_provider) : op_name_(op_name),
967 execution_provider_(execution_provider) {
968 // Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility
969 memset(&this->version, 0, sizeof(OrtCustomOp));
970
971 int act_ver = GetActiveOrtAPIVersion();
972 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
973
974 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
975 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
976
977 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
978 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
979 return self->input_types_.size();
980 };
981
982 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
983 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
984 return self->input_types_[indice];
985 };
986
987 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
988 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
989 return self->output_types_.size();
990 };
991
992 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
993 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
994 return self->output_types_[indice];
995 };
996
997#if ORT_API_VERSION >= 14
998 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
999 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
1000 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
1001 };
1002
1003 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
1004 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
1005 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
1006 };
1007
1008 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
1009 return 1;
1010 };
1011
1012 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
1013 return 0;
1014 };
1015
1016 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
1017 return 1;
1018 };
1019
1020 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
1021 return 0;
1022 };
1023
1024 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
1025 return OrtMemTypeDefault;
1026 };
1027#else
1028 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
1029 return INPUT_OUTPUT_OPTIONAL;
1030 };
1031
1032 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
1033 return INPUT_OUTPUT_OPTIONAL;
1034 };
1035#endif
1036
1037#if ORT_API_VERSION >= 18
1038 OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
1039 return 0;
1040 };
1041 OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
1042#endif
1043 }
1044
// Operator name copied from the constructor argument; the GetName callback
// returns its c_str().
const std::string op_name_;
// Execution provider type copied from the constructor argument; the
// GetExecutionProviderType callback returns its c_str().
const std::string execution_provider_;

// ONNX element types recorded by ParseArgs for the op's inputs and outputs, in
// the order the argument types were parsed. Read by the Get{Input,Output}Type,
// Get{Input,Output}TypeCount and Get{Input,Output}Characteristic callbacks.
std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;
1050};
1051
1052template <typename... Args>
1053struct OrtLiteCustomFunc : public OrtLiteCustomOp {
1054 using ComputeFn = void (*)(Args...);
1055 using MyType = OrtLiteCustomFunc<Args...>;
1056
1057 struct Kernel {
1058 ComputeFn compute_fn_{};
1059 std::string ep_{};
1060 std::unique_ptr<OrtW::CustomOpApi> api_;
1061 };
1062
1063 OrtLiteCustomFunc(const char* op_name,
1064 const char* execution_provider,
1065 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
1066 compute_fn_(compute_fn) {
1067 ParseArgs<Args...>(input_types_, output_types_);
1068
1069 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
1070 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1071 std::vector<TensorPtr> tensors;
1072 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
1073 context,
1074 tensors,
1075 kernel->api_->KernelContext_GetInputCount(context),
1076 kernel->api_->KernelContext_GetOutputCount(context),
1077 kernel->ep_);
1078 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
1079 };
1080
1081 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1082 auto kernel = std::make_unique<Kernel>();
1083 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
1084 kernel->compute_fn_ = self->compute_fn_;
1085 kernel->ep_ = self->execution_provider_;
1086 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
1087 return reinterpret_cast<void*>(kernel.release());
1088 };
1089
1090 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1091 delete reinterpret_cast<Kernel*>(op_kernel);
1092 };
1093 }
1094
1095 ComputeFn compute_fn_;
1096};
1097
1098template <typename CustomOp>
1099struct OrtLiteCustomStruct : public OrtLiteCustomOp {
1100 template <typename... Args>
1101 using CustomComputeFn = void (CustomOp::*)(Args...) const;
1102 using MyType = OrtLiteCustomStruct<CustomOp>;
1103
1104 struct Kernel {
1105 std::unique_ptr<CustomOp> custom_op_;
1106 std::string ep_{};
1107 std::unique_ptr<OrtW::CustomOpApi> api_;
1108 };
1109
1110 OrtLiteCustomStruct(const char* op_name,
1111 const char* execution_provider) : OrtLiteCustomOp(op_name,
1112 execution_provider) {
1113 init(&CustomOp::Compute);
1114 }
1115
1116 template <typename... Args>
1117 void init(CustomComputeFn<Args...>) {
1118 ParseArgs<Args...>(input_types_, output_types_);
1119
1120 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
1121 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
1122 std::vector<TensorPtr> tensors;
1123 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
1124 context,
1125 tensors,
1126 kernel->api_->KernelContext_GetInputCount(context),
1127 kernel->api_->KernelContext_GetOutputCount(context),
1128 kernel->ep_);
1129 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
1130 };
1131
1132 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1133 auto kernel = std::make_unique<Kernel>();
1134 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
1135 auto self = static_cast<const MyType*>(this_);
1136 kernel->ep_ = self->execution_provider_;
1137 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
1138 return reinterpret_cast<void*>(kernel.release());
1139 };
1140
1141 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1142 delete reinterpret_cast<Kernel*>(op_kernel);
1143 };
1144 }
1145};
1146
1147template <typename... Args>
1148OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1149 const char* execution_provider,
1150 void (*custom_compute_fn)(Args...)) {
1151 using LiteOp = OrtLiteCustomFunc<Args...>;
1152 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
1153}
1154
1155template <typename CustomOp>
1156OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1157 const char* execution_provider) {
1158 using LiteOp = OrtLiteCustomStruct<CustomOp>;
1159 return std::make_unique<LiteOp>(op_name, execution_provider).release();
1160}
1161
1162} // namespace Custom
1163} // namespace Ort
1164