microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
wenbingl-patch-2

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/custom_op_lite.h

1029lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <optional>
7#include <numeric>
8#include <string>
9#include <string_view>
10#include "tensor_api.h"
11#include "ort_c_to_cpp.h"
12#include "onnxruntime_f16.h"
13
14namespace Ort {
15namespace Custom {
16
17class OrtKernelContextStorage : public ITensorStorage {
18 public:
19 OrtKernelContextStorage(const OrtW::CustomOpApi& api,
20 OrtKernelContext& ctx,
21 size_t indice,
22 bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
23 if (is_input) {
24 auto input_count = api.KernelContext_GetInputCount(&ctx);
25 if (indice >= input_count) {
26 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
27 }
28 const_value_ = api.KernelContext_GetInput(&ctx, indice);
29 auto* info = api.GetTensorTypeAndShape(const_value_);
30 shape_ = api.GetTensorShape(info);
31 api.ReleaseTensorTypeAndShapeInfo(info);
32 }
33 }
34
35 const std::vector<int64_t>& Shape() const override {
36 if (!IsInitialized())
37 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
38 return *shape_;
39 }
40
41 virtual bool IsInitialized() const override {
42 return shape_.has_value();
43 }
44
45 const void* DataRaw() const override {
46 return api_.GetTensorRawData(const_value_);
47 }
48
49 void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
50 if (!const_value_) {
51 const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
52 shape_ = shape;
53 }
54 return api_.GetTensorMutableRawData(const_cast<OrtValue*>(const_value_));
55 }
56
57 void* Release() override {
58 ORTX_CXX_API_THROW("Can't release the tensor buffer with ORT graph mode.", ORT_RUNTIME_EXCEPTION);
59 }
60
61 private:
62 const OrtW::CustomOpApi& api_;
63 OrtKernelContext& ctx_;
64 size_t indice_;
65 const OrtValue* const_value_{}; // for input
66 std::optional<std::vector<int64_t>> shape_;
67};
68
69static std::string get_mem_type(const OrtW::CustomOpApi& api,
70 OrtKernelContext& ctx,
71 size_t indice,
72 bool is_input){
73 std::string output = "Cpu";
74 if (is_input) {
75 const OrtValue* const_value = api.KernelContext_GetInput(&ctx, indice);
76 const OrtMemoryInfo* mem_info = {};
77 api.ThrowOnError(api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
78 if (mem_info) {
79 const char* mem_type = nullptr;
80 api.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
81 if (mem_type) {
82 output = mem_type;
83 }
84 }
85 }
86 return output;
87}
88
89template <typename T>
90class OrtTensor : public Tensor<T> {
91public:
92 OrtTensor(const OrtW::CustomOpApi& api,
93 OrtKernelContext& ctx,
94 size_t indice,
95 bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(api, ctx, indice, is_input)),
96 mem_type_(get_mem_type(api, ctx, indice, is_input)) {
97 }
98
99 bool IsCpuTensor() const {
100 return mem_type_ == "Cpu";
101 }
102
103private:
104 std::string mem_type_ = "Cpu";
105};
106
107class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
108 public:
109 using strings = std::vector<std::string>;
110 OrtStringTensorStorage(const OrtW::CustomOpApi& api,
111 OrtKernelContext& ctx,
112 size_t indice,
113 bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
114 if (is_input) {
115 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
116 if (indice >= input_count) {
117 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
118 }
119
120 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
121 auto* info = api_.GetTensorTypeAndShape(const_value);
122 shape_ = api_.GetTensorShape(info);
123 api_.ReleaseTensorTypeAndShapeInfo(info);
124
125 size_t num_chars;
126 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
127 std::vector<char> chars(num_chars + 1, '\0');
128 // assert((*shape_).size() == 1 || ((*shape_).size() == 2 && (*shape_)[0] == 1));
129
130 int64_t num_strings = 1; // string scalar
131 if ((*shape_).size() > 0) {
132 num_strings = (*shape_).front();
133 for (auto iter = (*shape_).begin() + 1; iter != (*shape_).end(); ++iter) {
134 num_strings *= *iter;
135 }
136 }
137 std::vector<size_t> offsets(num_strings);
138 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
139 (void*)chars.data(),
140 num_chars,
141 offsets.data(),
142 offsets.size()));
143 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
144 input_strings_.resize(num_strings);
145 for (int64_t i = upper_bound; i >= 0; --i) {
146 if (i < upper_bound) {
147 chars[offsets[i + 1]] = '\0';
148 }
149 input_strings_[i] = chars.data() + offsets[i];
150 }
151 }
152 }
153
154 const std::vector<int64_t>& Shape() const override {
155 if (!IsInitialized())
156 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
157 return *shape_;
158 }
159
160 virtual const void* DataRaw() const override {
161 if (input_strings_.size() != 1) {
162 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
163 }
164 return reinterpret_cast<const void*>(input_strings_[0].c_str());
165 }
166
167 virtual bool IsInitialized() const override {
168 return shape_.has_value();
169 }
170
171 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
172 std::vector<const char*> raw;
173 for (const auto& s : ss) {
174 raw.push_back(s.data());
175 }
176 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
177 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
178 }
179
180 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
181 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
182 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
183 }
184
185 const strings& Data() const override {
186 return input_strings_;
187 }
188
189 private:
190 const OrtW::CustomOpApi& api_;
191 OrtKernelContext& ctx_;
192 size_t indice_;
193 std::vector<std::string> input_strings_;
194 std::optional<std::vector<int64_t>> shape_;
195};
196
197class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> {
198 public:
199 using strings = std::vector<std::string_view>;
200 OrtStringViewTensorStorage(const OrtW::CustomOpApi& api,
201 OrtKernelContext& ctx,
202 size_t indice,
203 bool is_input) : api_(api), ctx_(ctx), indice_(indice) {
204 if (is_input) {
205 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
206 if (indice >= input_count) {
207 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
208 }
209 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
210 auto* info = api_.GetTensorTypeAndShape(const_value);
211 shape_ = api_.GetTensorShape(info);
212 api_.ReleaseTensorTypeAndShapeInfo(info);
213
214 size_t num_chars;
215 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
216 chars_.resize(num_chars + 1, '\0');
217
218 size_t num_strings = 1;
219 if ((*shape_).size() > 0) {
220 num_strings = static_cast<size_t>((*shape_)[0]);
221 }
222
223 if (num_strings) {
224 std::vector<size_t> offsets(num_strings);
225 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
226 (void*)chars_.data(),
227 num_chars,
228 offsets.data(),
229 offsets.size()));
230 offsets.push_back(num_chars);
231 for (size_t i = 0; i < num_strings; ++i) {
232 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
233 }
234 }
235 }
236 }
237
238 const std::vector<int64_t>& Shape() const override {
239 if (!IsInitialized())
240 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
241 return *shape_;
242 }
243
244 virtual const void* DataRaw() const override {
245 if (input_string_views_.size() != 1) {
246 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
247 }
248 return reinterpret_cast<const void*>(input_string_views_[0].data());
249 }
250
251 virtual bool IsInitialized() const override {
252 return shape_.has_value();
253 }
254
255 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
256 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
257 }
258
259 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
260 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
261 }
262
263 const strings& Data() const override {
264 return input_string_views_;
265 }
266
267 private:
268 const OrtW::CustomOpApi& api_;
269 OrtKernelContext& ctx_;
270 size_t indice_;
271 std::vector<char> chars_; // for input
272 std::vector<std::string_view> input_string_views_; // for input
273 std::optional<std::vector<int64_t>> shape_;
274};
275
276// to make the metaprogramming magic happy.
277template <>
278class OrtTensor<std::string> : public Tensor<std::string>{
279public:
280 OrtTensor(const OrtW::CustomOpApi& api,
281 OrtKernelContext& ctx,
282 size_t indice,
283 bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(api, ctx, indice, is_input)),
284 mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
285
286 bool IsCpuTensor() const {
287 return mem_type_ == "Cpu";
288 }
289
290private:
291 std::string mem_type_ = "Cpu";
292};
293
294template <>
295class OrtTensor<std::string_view> : public Tensor<std::string_view>{
296public:
297 OrtTensor(const OrtW::CustomOpApi& api,
298 OrtKernelContext& ctx,
299 size_t indice,
300 bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(api, ctx, indice, is_input)),
301 mem_type_(get_mem_type(api, ctx, indice, is_input)) {}
302
303 bool IsCpuTensor() const {
304 return mem_type_ == "Cpu";
305 }
306
307private:
308 std::string mem_type_ = "Cpu";
309};
310
311using TensorPtr = std::unique_ptr<Custom::Arg>;
312using TensorPtrs = std::vector<TensorPtr>;
313
314
315using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
316using TensorBasePtrs = std::vector<TensorBasePtr>;
317
318// Represent variadic input or output
319struct Variadic : public Arg {
320 Variadic(const OrtW::CustomOpApi& api,
321 OrtKernelContext& ctx,
322 size_t indice,
323 bool is_input) : api_(api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(api, ctx, indice, is_input)) {
324#if ORT_API_VERSION < 14
325 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
326#endif
327 if (is_input) {
328 auto input_count = api.KernelContext_GetInputCount(&ctx_);
329 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
330 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
331 auto* info = api_.GetTensorTypeAndShape(const_value);
332 auto type = api_.GetTensorElementType(info);
333 api_.ReleaseTensorTypeAndShapeInfo(info);
334 TensorBasePtr tensor;
335 switch (type) {
336 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
337 tensor = std::make_unique<Custom::OrtTensor<bool>>(api, ctx, ith_input, true);
338 break;
339 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
340 tensor = std::make_unique<Custom::OrtTensor<float>>(api, ctx, ith_input, true);
341 break;
342 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
343 tensor = std::make_unique<Custom::OrtTensor<double>>(api, ctx, ith_input, true);
344 break;
345 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
346 tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api, ctx, ith_input, true);
347 break;
348 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
349 tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api, ctx, ith_input, true);
350 break;
351 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
352 tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api, ctx, ith_input, true);
353 break;
354 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
355 tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api, ctx, ith_input, true);
356 break;
357 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
358 tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api, ctx, ith_input, true);
359 break;
360 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
361 tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api, ctx, ith_input, true);
362 break;
363 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
364 tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api, ctx, ith_input, true);
365 break;
366 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
367 tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api, ctx, ith_input, true);
368 break;
369 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
370 tensor = std::make_unique<Custom::OrtTensor<std::string>>(api, ctx, ith_input, true);
371 break;
372 default:
373 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
374 break;
375 }
376 tensors_.emplace_back(tensor.release());
377 } // for
378 } else {
379 // a Variadic used for output is populated by the Compute so leave tensors_ empty here
380 }
381 }
382 template <typename T>
383 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
384 auto tensor = std::make_unique<OrtTensor<T>>(api_, ctx_, ith_output, false);
385 auto raw_output = tensor.get()->Allocate(shape);
386 tensors_.emplace_back(tensor.release());
387 return raw_output;
388 }
389 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
390 auto tensor = std::make_unique<OrtTensor<std::string>>(api_, ctx_, ith_output, false);
391 Tensor<std::string>& output = *tensor;
392 tensors_.emplace_back(tensor.release());
393 return output;
394 }
395 size_t Size() const {
396 return tensors_.size();
397 }
398
399 const TensorBasePtr& operator[](size_t indice) const {
400 return tensors_.at(indice);
401 }
402
403 private:
404 const OrtW::CustomOpApi& api_;
405 OrtKernelContext& ctx_;
406 size_t indice_;
407 std::string mem_type_ = "Cpu";
408 TensorBasePtrs tensors_;
409};
410
411#if ORT_API_VERSION >= 17
412
413class OrtGraphKernelContext : public KernelContext {
414 public:
415 OrtGraphKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
416 OrtMemoryInfo* info;
417 OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
418 OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &allocator_));
419 api.ReleaseMemoryInfo(info);
420 }
421
422 virtual ~OrtGraphKernelContext() {
423 if (allocator_) {
424 api_.ReleaseAllocator(allocator_);
425 }
426 }
427
428 void* AllocScratchBuffer(size_t size) override {
429 return allocator_->Alloc(allocator_, size);
430 }
431
432 void FreeScratchBuffer(void* p) override {
433 if (p) {
434 allocator_->Free(allocator_, p);
435 }
436 }
437
438 private:
439 const OrtApi& api_;
440 OrtAllocator* allocator_;
441};
442
443#endif
444
445#ifdef USE_CUDA
446
// Resource identifiers passed to KernelContext_GetResource by
// OrtGraphCudaKernelContext. Values are consecutive starting at 10000.
// Note: cuda_handle_t is the id that context uses to fetch the CUDA stream.
enum CudaResource {
  cuda_handle_t = 10000,
  cudnn_handle_t = 10001,
  cublas_handle_t = 10002,
  deferred_cpu_allocator_t = 10003,
  // below are cuda ep options
  device_id_t = 10004,
};
455
456#if ORT_API_VERSION >= 17
457class OrtGraphCudaKernelContext : public CUDAKernelContext {
458 public:
459 static const int cuda_resource_ver = 1;
460
461 OrtGraphCudaKernelContext(const OrtApi& api, const OrtKernelContext& ctx) : api_(api) {
462 api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream_);
463 if (!cuda_stream_) {
464 ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
465 }
466 api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cublas_handle_t, &cublas_);
467 if (!cublas_) {
468 ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
469 }
470 void* resource = nullptr;
471 OrtStatusPtr result = api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::device_id_t, &resource);
472 if (result) {
473 ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
474 }
475 memcpy(&device_id_, &resource, sizeof(int));
476
477 OrtMemoryInfo* info;
478 OrtW::ThrowOnError(api, api.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
479 OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, info, &cpu_allocator_));
480 api.ReleaseMemoryInfo(info);
481
482 OrtMemoryInfo* cuda_mem_info;
483 OrtW::ThrowOnError(api, api.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
484 OrtW::ThrowOnError(api, api.KernelContext_GetAllocator(&ctx, cuda_mem_info, &cuda_allocator_));
485 api.ReleaseMemoryInfo(cuda_mem_info);
486 }
487
488 virtual ~OrtGraphCudaKernelContext() {
489 if (cpu_allocator_) {
490 api_.ReleaseAllocator(cpu_allocator_);
491 }
492 if (cuda_allocator_) {
493 api_.ReleaseAllocator(cuda_allocator_);
494 }
495 }
496
497 void* AllocScratchBuffer(size_t size) override {
498 return cpu_allocator_->Alloc(cpu_allocator_, size);
499 }
500
501 void FreeScratchBuffer(void* p) override {
502 if (p) {
503 cpu_allocator_->Free(cpu_allocator_, p);
504 }
505 }
506
507 void* AllocCudaScratchBuffer(size_t size) override {
508 return cuda_allocator_->Alloc(cuda_allocator_, size);
509 }
510
511 void FreeCudaScratchBuffer(void* p) override {
512 if (p) {
513 cuda_allocator_->Free(cuda_allocator_, p);
514 }
515 }
516
517 void* GetCudaStream() const override {
518 return cuda_stream_;
519 }
520
521 void* GetCublasHandle() const override {
522 return cublas_;
523 }
524
525 int GetCudaDeviceId() const override {
526 return device_id_;
527 }
528
529 private:
530 const OrtApi& api_;
531 OrtAllocator* cpu_allocator_;
532 OrtAllocator* cuda_allocator_;
533 void* cuda_stream_ = {};
534 void* cublas_ = {};
535 int device_id_ = 0;
536};
537
538#endif
539#endif
540
541// using mf16_t = uint16_t;
542
543struct OrtLiteCustomOp : public OrtCustomOp {
544 // CreateTuple
545 template <size_t ith_input, size_t ith_output, typename... Ts>
546 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
547 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
548 return std::make_tuple();
549 }
550
551 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
552 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
553 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
554 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
555 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
556 return std::tuple_cat(current, next);
557 }
558
559#if ORT_API_VERSION >= 17
560 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
561 static typename std::enable_if<std::is_same<T, KernelContext*>::value, std::tuple<T, Ts...>>::type
562 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
563 tensors.push_back(std::make_unique<OrtGraphKernelContext>(api->GetOrtApi(), *context));
564 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
565 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
566 return std::tuple_cat(current, next);
567 }
568
569#ifdef USE_CUDA
570
571 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
572 static typename std::enable_if<std::is_same<T, CUDAKernelContext*>::value, std::tuple<T, Ts...>>::type
573 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
574 tensors.push_back(std::make_unique<OrtGraphCudaKernelContext>(api->GetOrtApi(), *context));
575 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
576 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
577 return std::tuple_cat(current, next);
578 }
579
580#endif
581
582#endif
583
584#if ORT_API_VERSION >= 14
585 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
586 static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
587 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
588 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
589 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
590 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
591 return std::tuple_cat(current, next);
592 }
593
594 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
595 static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
596 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
597 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
598 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
599 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
600 return std::tuple_cat(current, next);
601 }
602
603 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
604 static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
605 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
606 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
607 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
608 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
609 return std::tuple_cat(current, next);
610 }
611
612 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
613 static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
614 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
615 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
616 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
617 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
618 return std::tuple_cat(current, next);
619 }
620#endif
621
622#undef data_type_def
623#define data_type_def bool
624#include "tensor_tuple.inc"
625
626#if ORT_API_VERSION >= 16
627#undef data_type_def
628#define data_type_def BFloat16
629#include "tensor_tuple.inc"
630
631#undef data_type_def
632#define data_type_def MFloat16
633#include "tensor_tuple.inc"
634#endif
635
636#undef data_type_def
637#define data_type_def float
638#include "tensor_tuple.inc"
639
640#undef data_type_def
641#define data_type_def double
642#include "tensor_tuple.inc"
643
644#undef data_type_def
645#define data_type_def int8_t
646#include "tensor_tuple.inc"
647
648#undef data_type_def
649#define data_type_def int16_t
650#include "tensor_tuple.inc"
651
652#undef data_type_def
653#define data_type_def int32_t
654#include "tensor_tuple.inc"
655
656#undef data_type_def
657#define data_type_def int64_t
658#include "tensor_tuple.inc"
659
660#undef data_type_def
661#define data_type_def uint8_t
662#include "tensor_tuple.inc"
663
664#undef data_type_def
665#define data_type_def uint16_t
666#include "tensor_tuple.inc"
667
668#undef data_type_def
669#define data_type_def uint32_t
670#include "tensor_tuple.inc"
671
672#undef data_type_def
673#define data_type_def uint64_t
674#include "tensor_tuple.inc"
675
676#undef data_type_def
677#define data_type_def std::string
678#include "tensor_tuple.inc"
679
680#undef data_type_def
681#define data_type_def std::string_view
682#include "tensor_tuple.inc"
683
684#undef data_type_def
685
686 // ParseArgs ...
687 template <typename... Ts>
688 static typename std::enable_if<0 == sizeof...(Ts)>::type
689 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
690 }
691
692 template <typename T, typename... Ts>
693 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
694 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
695 ParseArgs<Ts...>(input_types, output_types);
696 }
697
698#if ORT_API_VERSION >= 17
699 template <typename T, typename... Ts>
700 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, KernelContext*>::value>::type
701 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
702 ParseArgs<Ts...>(input_types, output_types);
703 }
704
705#ifdef USE_CUDA
706 template <typename T, typename... Ts>
707 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, CUDAKernelContext*>::value>::type
708 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
709 ParseArgs<Ts...>(input_types, output_types);
710 }
711#endif
712#endif
713
714#if ORT_API_VERSION >= 14
715 template <typename T, typename... Ts>
716 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
717 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
718 if (!input_types.empty()) {
719 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
720 }
721 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
722 ParseArgs<Ts...>(input_types, output_types);
723 }
724
725 template <typename T, typename... Ts>
726 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
727 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
728 if (!input_types.empty()) {
729 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
730 }
731 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
732 ParseArgs<Ts...>(input_types, output_types);
733 }
734
735 template <typename T, typename... Ts>
736 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
737 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
738 if (!output_types.empty()) {
739 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
740 }
741 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
742 ParseArgs<Ts...>(input_types, output_types);
743 }
744
745 template <typename T, typename... Ts>
746 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
747 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
748 if (!output_types.empty()) {
749 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
750 }
751 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
752 ParseArgs<Ts...>(input_types, output_types);
753 }
754#endif
755
// Generates two ParseArgs overloads for an input argument of type `pack_type`:
// one for the plain type and one for std::optional<pack_type>. Both record
// `onnx_type` as the next input type and continue with the remaining types.
#define PARSE_INPUT_BASE(pack_type, onnx_type)                              \
  template <typename T, typename... Ts>                                     \
  static std::enable_if_t<std::is_same_v<T, pack_type>>                     \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,            \
            std::vector<ONNXTensorElementDataType>& output_types) {         \
    input_types.emplace_back(onnx_type);                                    \
    ParseArgs<Ts...>(input_types, output_types);                            \
  }                                                                         \
  template <typename T, typename... Ts>                                     \
  static std::enable_if_t<std::is_same_v<T, std::optional<pack_type>>>      \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,            \
            std::vector<ONNXTensorElementDataType>& output_types) {         \
    input_types.emplace_back(onnx_type);                                    \
    ParseArgs<Ts...>(input_types, output_types);                            \
  }
769
// Declares every accepted input spelling for element type `data_type`: the
// scalar itself, and Tensor / Span taken by const reference or const pointer.
#define PARSE_INPUT(data_type, onnx_type)                          \
  PARSE_INPUT_BASE(data_type, onnx_type)                           \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type)    \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type)    \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type)      \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type)
776
// Generates the ParseArgs overloads for an output tensor of element type
// `data_type`, taken as Tensor*, Tensor& or std::optional<Tensor*>. Each
// records `onnx_type` as the next output type and continues with the rest.
#define PARSE_OUTPUT(data_type, onnx_type)                                                  \
  template <typename T, typename... Ts>                                                     \
  static std::enable_if_t<std::is_same_v<T, Custom::Tensor<data_type>*>>                    \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                            \
            std::vector<ONNXTensorElementDataType>& output_types) {                         \
    output_types.emplace_back(onnx_type);                                                   \
    ParseArgs<Ts...>(input_types, output_types);                                            \
  }                                                                                         \
  template <typename T, typename... Ts>                                                     \
  static std::enable_if_t<std::is_same_v<T, Custom::Tensor<data_type>&>>                    \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                            \
            std::vector<ONNXTensorElementDataType>& output_types) {                         \
    output_types.emplace_back(onnx_type);                                                   \
    ParseArgs<Ts...>(input_types, output_types);                                            \
  }                                                                                         \
  template <typename T, typename... Ts>                                                     \
  static std::enable_if_t<std::is_same_v<T, std::optional<Custom::Tensor<data_type>*>>>     \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                            \
            std::vector<ONNXTensorElementDataType>& output_types) {                         \
    output_types.emplace_back(onnx_type);                                                   \
    ParseArgs<Ts...>(input_types, output_types);                                            \
  }
796
// Declares both the input and the output ParseArgs overloads for one element type.
#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)     \
  PARSE_INPUT(data_type, onnx_type)
800
801 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
802#if ORT_API_VERSION >= 16
803 PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
804 PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
805#endif
806 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
807 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
808 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
809 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
810 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
811 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
812 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
813 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
814 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
815 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
816 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
817 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
818
819 OrtLiteCustomOp(const char* op_name,
820 const char* execution_provider) : op_name_(op_name),
821 execution_provider_(execution_provider) {
822 // Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility
823 memset(&this->version, 0, sizeof(OrtCustomOp));
824
825 int act_ver = GetActiveOrtAPIVersion();
826 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
827
828 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
829 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
830
831 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
832 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
833 return self->input_types_.size();
834 };
835
836 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
837 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
838 return self->input_types_[indice];
839 };
840
841 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
842 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
843 return self->output_types_.size();
844 };
845
846 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
847 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
848 return self->output_types_[indice];
849 };
850
851#if ORT_API_VERSION >= 14
852 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
853 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
854 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
855 };
856
857 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
858 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
859 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
860 };
861
862 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
863 return 1;
864 };
865
866 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
867 return 0;
868 };
869
870 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
871 return 1;
872 };
873
874 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
875 return 0;
876 };
877
878 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
879 return OrtMemTypeDefault;
880 };
881#else
882 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
883 return INPUT_OUTPUT_OPTIONAL;
884 };
885
886 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
887 return INPUT_OUTPUT_OPTIONAL;
888 };
889#endif
890 }
891
892 const std::string op_name_;
893 const std::string execution_provider_;
894
895 std::vector<ONNXTensorElementDataType> input_types_;
896 std::vector<ONNXTensorElementDataType> output_types_;
897};
898
899template <typename... Args>
900struct OrtLiteCustomFunc : public OrtLiteCustomOp {
901 using ComputeFn = void (*)(Args...);
902 using MyType = OrtLiteCustomFunc<Args...>;
903
904 struct Kernel {
905 ComputeFn compute_fn_{};
906 std::string ep_{};
907 std::unique_ptr<OrtW::CustomOpApi> api_;
908 };
909
910 OrtLiteCustomFunc(const char* op_name,
911 const char* execution_provider,
912 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
913 compute_fn_(compute_fn) {
914 ParseArgs<Args...>(input_types_, output_types_);
915
916 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
917 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
918 std::vector<TensorPtr> tensors;
919 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
920 context,
921 tensors,
922 kernel->api_->KernelContext_GetInputCount(context),
923 kernel->api_->KernelContext_GetOutputCount(context),
924 kernel->ep_);
925 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
926 };
927
928 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
929 auto kernel = std::make_unique<Kernel>();
930 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
931 kernel->compute_fn_ = self->compute_fn_;
932 kernel->ep_ = self->execution_provider_;
933 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
934 return reinterpret_cast<void*>(kernel.release());
935 };
936
937 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
938 delete reinterpret_cast<Kernel*>(op_kernel);
939 };
940 }
941
942 ComputeFn compute_fn_;
943};
944
945class OrtAttributeReader {
946 public:
947 OrtAttributeReader(const OrtApi& api, const OrtKernelInfo& info) : base_kernel_(api, info) {
948 }
949
950 template <class T>
951 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
952 return base_kernel_.TryToGetAttributeWithDefault(name, default_value);
953 }
954
955 private:
956 BaseKernel base_kernel_;
957};
958
959template <typename CustomOp>
960struct OrtLiteCustomStruct : public OrtLiteCustomOp {
961 template <typename... Args>
962 using CustomComputeFn = void (CustomOp::*)(Args...) const;
963 using MyType = OrtLiteCustomStruct<CustomOp>;
964
965 struct Kernel {
966 std::unique_ptr<CustomOp> custom_op_;
967 std::string ep_{};
968 std::unique_ptr<OrtW::CustomOpApi> api_;
969 };
970
971 OrtLiteCustomStruct(const char* op_name,
972 const char* execution_provider) : OrtLiteCustomOp(op_name,
973 execution_provider) {
974 init(&CustomOp::Compute);
975 }
976
977 template <typename... Args>
978 void init(CustomComputeFn<Args...>) {
979 ParseArgs<Args...>(input_types_, output_types_);
980
981 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
982 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
983 std::vector<TensorPtr> tensors;
984 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
985 context,
986 tensors,
987 kernel->api_->KernelContext_GetInputCount(context),
988 kernel->api_->KernelContext_GetOutputCount(context),
989 kernel->ep_);
990 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
991 };
992
993 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
994 auto kernel = std::make_unique<Kernel>();
995
996 if constexpr (std::is_constructible<CustomOp, const OrtApi&, const OrtKernelInfo&>::value) {
997 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
998 } else {
999 kernel->custom_op_ = std::make_unique<CustomOp>(OrtAttributeReader(*ort_api, *info));
1000 }
1001 auto self = static_cast<const MyType*>(this_);
1002 kernel->ep_ = self->execution_provider_;
1003 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
1004 return reinterpret_cast<void*>(kernel.release());
1005 };
1006
1007 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1008 delete reinterpret_cast<Kernel*>(op_kernel);
1009 };
1010 }
1011};
1012
1013template <typename... Args>
1014OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1015 const char* execution_provider,
1016 void (*custom_compute_fn)(Args...)) {
1017 using LiteOp = OrtLiteCustomFunc<Args...>;
1018 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
1019}
1020
1021template <typename CustomOp>
1022OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1023 const char* execution_provider) {
1024 using LiteOp = OrtLiteCustomStruct<CustomOp>;
1025 return std::make_unique<LiteOp>(op_name, execution_provider).release();
1026}
1027
1028} // namespace Custom
1029} // namespace Ort
1030