microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
leca/gqa2

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/custom_op_lite.h

1032lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <optional>
7#include <numeric>
8#include <string>
9#include <string_view>
10#include "tensor_api.h"
11#include "ort_c_to_cpp.h"
12#include "onnxruntime_f16.h"
13
14namespace Ort {
15namespace Custom {
16
17class OrtKernelContextStorage : public ITensorStorage {
18 public:
19 OrtKernelContextStorage(const OrtW::CustomOpApi& api,
20 OrtKernelContext& ctx,
21 size_t indice,
22 bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
23 if (is_input) {
24 auto input_count = api.KernelContext_GetInputCount(&ctx);
25 if (indice >= input_count) {
26 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
27 }
28 const_value_ = api.KernelContext_GetInput(&ctx, indice);
29 auto* info = api.GetTensorTypeAndShape(const_value_);
30 shape_ = api.GetTensorShape(info);
31 api.ReleaseTensorTypeAndShapeInfo(info);
32 }
33 }
34
35 const std::vector<int64_t>& Shape() const override {
36 if (!IsInitialized())
37 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
38 return *shape_;
39 }
40
41 virtual bool IsInitialized() const override {
42 return shape_.has_value();
43 }
44
45 const void* DataRaw() const override {
46 return api_.GetTensorRawData(const_value_);
47 }
48
49 void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
50 if (!const_value_) {
51 const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
52 shape_ = shape;
53 }
54 return api_.GetTensorMutableRawData(const_cast<OrtValue*>(const_value_));
55 }
56
57 private:
58 const OrtW::CustomOpApi& api_;
59 OrtKernelContext& ctx_;
60 size_t indice_;
61 const OrtValue* const_value_{}; // for input
62 std::optional<std::vector<int64_t>> shape_;
63};
64
65static std::string get_mem_type(const OrtW::CustomOpApi& api,
66 OrtKernelContext& ctx,
67 size_t indice,
68 bool is_input){
69 std::string output = "Cpu";
70 if (is_input) {
71 const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice);
72 const OrtMemoryInfo* mem_info = {};
73 api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
74 if (mem_info) {
75 const char* mem_type = nullptr;
76 api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
77 if (mem_type) {
78 output = mem_type;
79 }
80 }
81 }
82 return output;
83}
84
85template <typename T>
86class OrtTensor : public Tensor<T> {
87public:
88 OrtTensor(const OrtW::CustomOpApi& api,
89 OrtKernelContext& ctx,
90 size_t indice,
91 bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(api, ctx, indice, is_input)),
92 mem_type_(get_mem_type(api, ctx, indice, is_input)) {
93 }
94
95 bool IsCpuTensor() const {
96 return mem_type_ == "Cpu";
97 }
98
99private:
100 std::string mem_type_ = "Cpu";
101};
102
103class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
104 public:
105 using strings = std::vector<std::string>;
106 OrtStringTensorStorage(const OrtW::CustomOpApi& api,
107 OrtKernelContext& ctx,
108 size_t indice,
109 bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
110 if (is_input) {
111 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
112 if (indice >= input_count) {
113 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
114 }
115
116 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
117 auto* info = api_.GetTensorTypeAndShape(const_value);
118 shape_ = api_.GetTensorShape(info);
119 api_.ReleaseTensorTypeAndShapeInfo(info);
120
121 size_t num_chars;
122 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
123 std::vector<char> chars(num_chars + 1, '\0');
124 // assert((*shape_).size() == 1 || ((*shape_).size() == 2 && (*shape_)[0] == 1));
125
126 int64_t num_strings = 1; // string scalar
127 if ((*shape_).size() > 0) {
128 num_strings = (*shape_).front();
129 for (auto iter = (*shape_).begin() + 1; iter != (*shape_).end(); ++iter) {
130 num_strings *= *iter;
131 }
132 }
133 std::vector<size_t> offsets(num_strings);
134 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
135 (void*)chars.data(),
136 num_chars,
137 offsets.data(),
138 offsets.size()));
139 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
140 input_strings_.resize(num_strings);
141 for (int64_t i = upper_bound; i >= 0; --i) {
142 if (i < upper_bound) {
143 chars[offsets[i + 1]] = '\0';
144 }
145 input_strings_[i] = chars.data() + offsets[i];
146 }
147 }
148 }
149
150 const std::vector<int64_t>& Shape() const override {
151 if (!IsInitialized())
152 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
153 return *shape_;
154 }
155
156 virtual const void* DataRaw() const override {
157 if (input_strings_.size() != 1) {
158 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
159 }
160 return reinterpret_cast<const void*>(input_strings_[0].c_str());
161 }
162
163 virtual bool IsInitialized() const override {
164 return shape_.has_value();
165 }
166
167 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
168 std::vector<const char*> raw;
169 for (const auto& s : ss) {
170 raw.push_back(s.data());
171 }
172 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
173 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
174 }
175
176 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
177 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
178 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
179 }
180
181 const strings& Data() const override {
182 return input_strings_;
183 }
184
185 private:
186 const OrtW::CustomOpApi& api_;
187 OrtKernelContext& ctx_;
188 size_t indice_;
189 std::vector<std::string> input_strings_;
190 std::optional<std::vector<int64_t>> shape_;
191};
192
193class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> {
194 public:
195 using strings = std::vector<std::string_view>;
196 OrtStringViewTensorStorage(const OrtW::CustomOpApi& api,
197 OrtKernelContext& ctx,
198 size_t indice,
199 bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
200 if (is_input) {
201 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
202 if (indice >= input_count) {
203 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
204 }
205 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
206 auto* info = api_.GetTensorTypeAndShape(const_value);
207 shape_ = api_.GetTensorShape(info);
208 api_.ReleaseTensorTypeAndShapeInfo(info);
209
210 size_t num_chars;
211 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
212 chars_.resize(num_chars + 1, '\0');
213
214 size_t num_strings = 1;
215 if ((*shape_).size() > 0) {
216 num_strings = static_cast<size_t>((*shape_)[0]);
217 }
218
219 if (num_strings) {
220 std::vector<size_t> offsets(num_strings);
221 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
222 (void*)chars_.data(),
223 num_chars,
224 offsets.data(),
225 offsets.size()));
226 offsets.push_back(num_chars);
227 for (size_t i = 0; i < num_strings; ++i) {
228 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
229 }
230 }
231 }
232 }
233
234 const std::vector<int64_t>& Shape() const override {
235 if (!IsInitialized())
236 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
237 return *shape_;
238 }
239
240 virtual const void* DataRaw() const override {
241 if (input_string_views_.size() != 1) {
242 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
243 }
244 return reinterpret_cast<const void*>(input_string_views_[0].data());
245 }
246
247 virtual bool IsInitialized() const override {
248 return shape_.has_value();
249 }
250
251 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
252 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
253 }
254
255 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
256 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
257 }
258
259 const strings& Data() const override {
260 return input_string_views_;
261 }
262
263 private:
264 const OrtW::CustomOpApi& api_;
265 OrtKernelContext& ctx_;
266 size_t indice_;
267 std::vector<char> chars_; // for input
268 std::vector<std::string_view> input_string_views_; // for input
269 std::optional<std::vector<int64_t>> shape_;
270};
271
272// to make the metaprogramming magic happy.
273template <>
274class OrtTensor<std::string> : public Tensor<std::string>{
275public:
276 OrtTensor(const OrtW::CustomOpApi& api,
277 OrtKernelContext& ctx,
278 size_t indice,
279 bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(api, ctx, indice, is_input)),
280 mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
281
282 bool IsCpuTensor() const {
283 return mem_type_ == "Cpu";
284 }
285
286private:
287 std::string mem_type_ = "Cpu";
288};
289
290template <>
291class OrtTensor<std::string_view> : public Tensor<std::string_view>{
292public:
293 OrtTensor(const OrtW::CustomOpApi& api,
294 OrtKernelContext& ctx,
295 size_t indice,
296 bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(api, ctx, indice, is_input)),
297 mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
298
299 bool IsCpuTensor() const {
300 return mem_type_ == "Cpu";
301 }
302
303private:
304 std::string mem_type_ = "Cpu";
305};
306
307using TensorPtr = std::unique_ptr<Custom::Arg>;
308using TensorPtrs = std::vector<TensorPtr>;
309
310
311using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
312using TensorBasePtrs = std::vector<TensorBasePtr>;
313
314// Represent variadic input or output
315struct Variadic : public Arg {
316 Variadic(const OrtW::CustomOpApi& api,
317 OrtKernelContext& ctx,
318 size_t indice,
319 bool is_input) : api_(api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(api, ctx, indice, is_input)) {
320#if ORT_API_VERSION < 14
321 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
322#endif
323 if (is_input) {
324 auto input_count = api.KernelContext_GetInputCount(&ctx_);
325 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
326 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
327 auto* info = api_.GetTensorTypeAndShape(const_value);
328 auto type = api_.GetTensorElementType(info);
329 api_.ReleaseTensorTypeAndShapeInfo(info);
330 TensorBasePtr tensor;
331 switch (type) {
332 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
333 tensor = std::make_unique<Custom::OrtTensor<bool>>(api, ctx, ith_input, true);
334 break;
335 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
336 tensor = std::make_unique<Custom::OrtTensor<float>>(api, ctx, ith_input, true);
337 break;
338 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
339 tensor = std::make_unique<Custom::OrtTensor<double>>(api, ctx, ith_input, true);
340 break;
341 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
342 tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api, ctx, ith_input, true);
343 break;
344 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
345 tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api, ctx, ith_input, true);
346 break;
347 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
348 tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api, ctx, ith_input, true);
349 break;
350 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
351 tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api, ctx, ith_input, true);
352 break;
353 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
354 tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api, ctx, ith_input, true);
355 break;
356 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
357 tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api, ctx, ith_input, true);
358 break;
359 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
360 tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api, ctx, ith_input, true);
361 break;
362 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
363 tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api, ctx, ith_input, true);
364 break;
365 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
366 tensor = std::make_unique<Custom::OrtTensor<std::string>>(api, ctx, ith_input, true);
367 break;
368 default:
369 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
370 break;
371 }
372 tensors_.emplace_back(tensor.release());
373 } // for
374 } else {
375 // a Variadic used for output is populated by the Compute so leave tensors_ empty here
376 }
377 }
378 template <typename T>
379 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
380 auto tensor = std::make_unique<OrtTensor<T>>(api_, ctx_, ith_output, false);
381 auto raw_output = tensor.get()->Allocate(shape);
382 tensors_.emplace_back(tensor.release());
383 return raw_output;
384 }
385 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
386 auto tensor = std::make_unique<OrtTensor<std::string>>(api_, ctx_, ith_output, false);
387 Tensor<std::string>& output = *tensor;
388 tensors_.emplace_back(tensor.release());
389 return output;
390 }
391 size_t Size() const {
392 return tensors_.size();
393 }
394
395 const TensorBasePtr& operator[](size_t indice) const {
396 return tensors_.at(indice);
397 }
398
399 private:
400 const OrtW::CustomOpApi& api_;
401 OrtKernelContext& ctx_;
402 size_t indice_;
403 std::string mem_type_ = "Cpu";
404 TensorBasePtrs tensors_;
405};
406
407#if ORT_API_VERSION >= 17
408
409class OrtGraphKernelContext : public KernelContext {
410 public:
411 OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
412 OrtMemoryInfo* info;
413 OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
414 OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_));
415 api.ReleaseMemoryInfo(info);
416 }
417
418 virtual ~OrtGraphKernelContext() {
419 if (allocator_) {
420 api_.ReleaseAllocator(allocator_);
421 }
422 }
423
424 void* AllocScratchBuffer(size_t size) override {
425 return allocator_->Alloc(allocator_, size);
426 }
427
428 void FreeScratchBuffer(void* p) override {
429 if (p) {
430 allocator_->Free(allocator_, p);
431 }
432 }
433
434 private:
435 const OrtApi& api_;
436 OrtAllocator* allocator_;
437};
438
439#endif
440
441#ifdef USE_CUDA
442
// Resource ids passed to OrtApi::KernelContext_GetResource by the CUDA kernel context.
// Values are consecutive, starting at 10000.
enum CudaResource {
  cuda_handle_t = 10000,
  cudnn_handle_t = 10001,
  cublas_handle_t = 10002,
  deferred_cpu_allocator_t = 10003,
  // below are cuda ep options
  device_id_t = 10004,
};
451
452#if ORT_API_VERSION >= 17
453class OrtGraphCudaKernelContext : public CUDAKernelContext {
454 public:
455 static const int cuda_resource_ver = 1;
456
457 OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
458 api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
459 if (!cuda_stream_) {
460 ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
461 }
462 api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
463 if (!cublas_) {
464 ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
465 }
466 void* resource = nullptr;
467 OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
468 if (result) {
469 ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
470 }
471 memcpy(&device_id_, &resource, sizeof(int));
472
473 OrtMemoryInfo* info;
474 OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
475 OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
476 api.ReleaseMemoryInfo(info);
477
478 OrtMemoryInfo* cuda_mem_info;
479 OrtW::ThrowOnError(api, api.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
480 OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
481 api.ReleaseMemoryInfo(cuda_mem_info);
482 }
483
484 virtual ~OrtGraphCudaKernelContext() {
485 if (cpu_allocator_) {
486 api_.ReleaseAllocator(cpu_allocator_);
487 }
488 if (cuda_allocator_) {
489 api_.ReleaseAllocator(cuda_allocator_);
490 }
491 }
492
493 void* AllocScratchBuffer(size_t size) override {
494 return cpu_allocator_->Alloc(cpu_allocator_, size);
495 }
496
497 void FreeScratchBuffer(void* p) override {
498 if (p) {
499 cpu_allocator_->Free(cpu_allocator_, p);
500 }
501 }
502
503 void* AllocCudaScratchBuffer(size_t size) override {
504 return cuda_allocator_->Alloc(cuda_allocator_, size);
505 }
506
507 void FreeCudaScratchBuffer(void* p) override {
508 if (p) {
509 cuda_allocator_->Free(cuda_allocator_, p);
510 }
511 }
512
513 void* GetCudaStream() const override {
514 return cuda_stream_;
515 }
516
517 void* GetCublasHandle() const override {
518 return cublas_;
519 }
520
521 int GetCudaDeviceId() const override {
522 return device_id_;
523 }
524
525 private:
526 const OrtApi& api_;
527 OrtAllocator* cpu_allocator_;
528 OrtAllocator* cuda_allocator_;
529 void* cuda_stream_ = {};
530 void* cublas_ = {};
531 int device_id_ = 0;
532};
533
534#endif
535#endif
536
537// using mf16_t = uint16_t;
538
539struct OrtLiteCustomOp : public OrtCustomOp {
540 // CreateTuple
541 template <size_t ith_input, size_t ith_output, typename... Ts>
542 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
543 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
544 return std::make_tuple();
545 }
546
547 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
548 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
549 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
550 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
551 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
552 return std::tuple_cat(current, next);
553 }
554
555#if ORT_API_VERSION >= 17
556 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
557 static typename std::enable_if<std::is_same<T, KernelContext*>::value, std::tuple<T, Ts...>>::type
558 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
559 tensors.push_back(std::make_unique<OrtGraphKernelContext>(api->GetOrtApi(), *context));
560 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
561 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
562 return std::tuple_cat(current, next);
563 }
564
565#ifdef USE_CUDA
566
567 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
568 static typename std::enable_if<std::is_same<T, CUDAKernelContext*>::value, std::tuple<T, Ts...>>::type
569 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
570 tensors.push_back(std::make_unique<OrtGraphCudaKernelContext>(api->GetOrtApi(), *context));
571 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
572 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
573 return std::tuple_cat(current, next);
574 }
575
576#endif
577
578#endif
579
580#if ORT_API_VERSION >= 14
581 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
582 static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
583 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
584 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
585 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
586 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
587 return std::tuple_cat(current, next);
588 }
589
590 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
591 static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
592 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
593 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
594 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
595 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
596 return std::tuple_cat(current, next);
597 }
598
599 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
600 static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
601 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
602 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
603 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
604 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
605 return std::tuple_cat(current, next);
606 }
607
608 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
609 static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
610 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
611 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
612 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
613 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
614 return std::tuple_cat(current, next);
615 }
616#endif
617
618#undef data_type_def
619#define data_type_def bool
620#include "tensor_tuple.inc"
621
622#if ORT_API_VERSION >= 16
623#undef data_type_def
624#define data_type_def BFloat16
625#include "tensor_tuple.inc"
626
627#undef data_type_def
628#define data_type_def MFloat16
629#include "tensor_tuple.inc"
630#endif
631
632#undef data_type_def
633#define data_type_def float
634#include "tensor_tuple.inc"
635
636#undef data_type_def
637#define data_type_def double
638#include "tensor_tuple.inc"
639
640#undef data_type_def
641#define data_type_def int8_t
642#include "tensor_tuple.inc"
643
644#undef data_type_def
645#define data_type_def int16_t
646#include "tensor_tuple.inc"
647
648#undef data_type_def
649#define data_type_def int32_t
650#include "tensor_tuple.inc"
651
652#undef data_type_def
653#define data_type_def int64_t
654#include "tensor_tuple.inc"
655
656#undef data_type_def
657#define data_type_def uint8_t
658#include "tensor_tuple.inc"
659
660#undef data_type_def
661#define data_type_def uint16_t
662#include "tensor_tuple.inc"
663
664#undef data_type_def
665#define data_type_def uint32_t
666#include "tensor_tuple.inc"
667
668#undef data_type_def
669#define data_type_def uint64_t
670#include "tensor_tuple.inc"
671
672#undef data_type_def
673#define data_type_def std::string
674#include "tensor_tuple.inc"
675
676#undef data_type_def
677#define data_type_def std::string_view
678#include "tensor_tuple.inc"
679
680#undef data_type_def
681
682 // ParseArgs ...
683 template <typename... Ts>
684 static typename std::enable_if<0 == sizeof...(Ts)>::type
685 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
686 }
687
688 template <typename T, typename... Ts>
689 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
690 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
691 ParseArgs<Ts...>(input_types, output_types);
692 }
693
694#if ORT_API_VERSION >= 17
695 template <typename T, typename... Ts>
696 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, KernelContext*>::value>::type
697 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
698 ParseArgs<Ts...>(input_types, output_types);
699 }
700
701#ifdef USE_CUDA
702 template <typename T, typename... Ts>
703 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, CUDAKernelContext*>::value>::type
704 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
705 ParseArgs<Ts...>(input_types, output_types);
706 }
707#endif
708#endif
709
710#if ORT_API_VERSION >= 14
711 template <typename T, typename... Ts>
712 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
713 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
714 if (!input_types.empty()) {
715 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
716 }
717 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
718 ParseArgs<Ts...>(input_types, output_types);
719 }
720
721 template <typename T, typename... Ts>
722 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
723 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
724 if (!input_types.empty()) {
725 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
726 }
727 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
728 ParseArgs<Ts...>(input_types, output_types);
729 }
730
731 template <typename T, typename... Ts>
732 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
733 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
734 if (!output_types.empty()) {
735 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
736 }
737 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
738 ParseArgs<Ts...>(input_types, output_types);
739 }
740
741 template <typename T, typename... Ts>
742 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
743 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
744 if (!output_types.empty()) {
745 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
746 }
747 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
748 ParseArgs<Ts...>(input_types, output_types);
749 }
750#endif
751
752#define PARSE_INPUT_BASE(pack_type, onnx_type) \
753 template <typename T, typename... Ts> \
754 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
755 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
756 input_types.push_back(onnx_type); \
757 ParseArgs<Ts...>(input_types, output_types); \
758 } \
759 template <typename T, typename... Ts> \
760 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
761 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
762 input_types.push_back(onnx_type); \
763 ParseArgs<Ts...>(input_types, output_types); \
764 }
765
766#define PARSE_INPUT(data_type, onnx_type) \
767 PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
768 PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
769 PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
770 PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
771 PARSE_INPUT_BASE(data_type, onnx_type)
772
773#define PARSE_OUTPUT(data_type, onnx_type) \
774 template <typename T, typename... Ts> \
775 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
776 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
777 output_types.push_back(onnx_type); \
778 ParseArgs<Ts...>(input_types, output_types); \
779 } \
780 template <typename T, typename... Ts> \
781 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
782 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
783 output_types.push_back(onnx_type); \
784 ParseArgs<Ts...>(input_types, output_types); \
785 } \
786 template <typename T, typename... Ts> \
787 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
788 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
789 output_types.push_back(onnx_type); \
790 ParseArgs<Ts...>(input_types, output_types); \
791 }
792
793#define PARSE_ARGS(data_type, onnx_type) \
794 PARSE_INPUT(data_type, onnx_type) \
795 PARSE_OUTPUT(data_type, onnx_type)
796
797 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
798#if ORT_API_VERSION >= 16
799 PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
800 PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
801#endif
802 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
803 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
804 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
805 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
806 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
807 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
808 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
809 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
810 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
811 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
812 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
813 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
814
815 OrtLiteCustomOp(const char* op_name,
816 const char* execution_provider) : op_name_(op_name),
817 execution_provider_(execution_provider) {
818 // Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility
819 memset(&this->version, 0, sizeof(OrtCustomOp));
820
821 int act_ver = GetActiveOrtAPIVersion();
822 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
823
824 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
825 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
826
827 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
828 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
829 return self->input_types_.size();
830 };
831
832 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
833 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
834 return self->input_types_[indice];
835 };
836
837 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
838 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
839 return self->output_types_.size();
840 };
841
842 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
843 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
844 return self->output_types_[indice];
845 };
846
847#if ORT_API_VERSION >= 14
848 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
849 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
850 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
851 };
852
853 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
854 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
855 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
856 };
857
858 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
859 return 1;
860 };
861
862 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
863 return 0;
864 };
865
866 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
867 return 1;
868 };
869
870 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
871 return 0;
872 };
873
874 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
875 return OrtMemTypeDefault;
876 };
877#else
878 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
879 return INPUT_OUTPUT_OPTIONAL;
880 };
881
882 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
883 return INPUT_OUTPUT_OPTIONAL;
884 };
885#endif
886
887#if ORT_API_VERSION >= 18
888 OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
889 return 0;
890 };
891 OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
892#endif
893 }
894
895 const std::string op_name_;
896 const std::string execution_provider_;
897
898 std::vector<ONNXTensorElementDataType> input_types_;
899 std::vector<ONNXTensorElementDataType> output_types_;
900};
901
902template <typename... Args>
903struct OrtLiteCustomFunc : public OrtLiteCustomOp {
904 using ComputeFn = void (*)(Args...);
905 using MyType = OrtLiteCustomFunc<Args...>;
906
907 struct Kernel {
908 ComputeFn compute_fn_{};
909 std::string ep_{};
910 std::unique_ptr<OrtW::CustomOpApi> api_;
911 };
912
913 OrtLiteCustomFunc(const char* op_name,
914 const char* execution_provider,
915 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
916 compute_fn_(compute_fn) {
917 ParseArgs<Args...>(input_types_, output_types_);
918
919 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
920 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
921 std::vector<TensorPtr> tensors;
922 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
923 context,
924 tensors,
925 kernel->api_->KernelContext_GetInputCount(context),
926 kernel->api_->KernelContext_GetOutputCount(context),
927 kernel->ep_);
928 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
929 };
930
931 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
932 auto kernel = std::make_unique<Kernel>();
933 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
934 kernel->compute_fn_ = self->compute_fn_;
935 kernel->ep_ = self->execution_provider_;
936 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
937 return reinterpret_cast<void*>(kernel.release());
938 };
939
940 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
941 delete reinterpret_cast<Kernel*>(op_kernel);
942 };
943 }
944
945 ComputeFn compute_fn_;
946};
947
948class OrtAttributeReader {
949 public:
950 OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) {
951 }
952
953 template <class T>
954 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
955 return base_kernel_.TryToGetAttributeWithDefault(name, default_value);
956 }
957
958 private:
959 BaseKernel base_kernel_;
960};
961
962template <typename CustomOp>
963struct OrtLiteCustomStruct : public OrtLiteCustomOp {
964 template <typename... Args>
965 using CustomComputeFn = void (CustomOp::*)(Args...) const;
966 using MyType = OrtLiteCustomStruct<CustomOp>;
967
968 struct Kernel {
969 std::unique_ptr<CustomOp> custom_op_;
970 std::string ep_{};
971 std::unique_ptr<OrtW::CustomOpApi> api_;
972 };
973
974 OrtLiteCustomStruct(const char* op_name,
975 const char* execution_provider) : OrtLiteCustomOp(op_name,
976 execution_provider) {
977 init(&CustomOp::Compute);
978 }
979
980 template <typename... Args>
981 void init(CustomComputeFn<Args...>) {
982 ParseArgs<Args...>(input_types_, output_types_);
983
984 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
985 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
986 std::vector<TensorPtr> tensors;
987 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
988 context,
989 tensors,
990 kernel->api_->KernelContext_GetInputCount(context),
991 kernel->api_->KernelContext_GetOutputCount(context),
992 kernel->ep_);
993 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
994 };
995
996 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
997 auto kernel = std::make_unique<Kernel>();
998
999 if constexpr (std::is_constructible<CustomOp, const OrtApi&, const OrtKernelInfo&>::value) {
1000 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
1001 } else {
1002 kernel->custom_op_ = std::make_unique<CustomOp>(OrtAttributeReader(*ort_api, *info));
1003 }
1004 auto self = static_cast<const MyType*>(this_);
1005 kernel->ep_ = self->execution_provider_;
1006 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
1007 return reinterpret_cast<void*>(kernel.release());
1008 };
1009
1010 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1011 delete reinterpret_cast<Kernel*>(op_kernel);
1012 };
1013 }
1014};
1015
1016template <typename... Args>
1017OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1018 const char* execution_provider,
1019 void (*custom_compute_fn)(Args...)) {
1020 using LiteOp = OrtLiteCustomFunc<Args...>;
1021 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
1022}
1023
1024template <typename CustomOp>
1025OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1026 const char* execution_provider) {
1027 using LiteOp = OrtLiteCustomStruct<CustomOp>;
1028 return std::make_unique<LiteOp>(op_name, execution_provider).release();
1029}
1030
1031} // namespace Custom
1032} // namespace Ort
1033