microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
82a59c634565e72b210831c109bf971b90030efb

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/custom_op_lite.h

1035 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <optional>
7#include <numeric>
8#include <string>
9#include <string_view>
10#include "tensor_api.h"
11#include "ort_c_to_cpp.h"
12#include "onnxruntime_f16.h"
13
14namespace Ort {
15namespace Custom {
16
17class OrtKernelContextStorage : public ITensorStorage {
18 public:
19 OrtKernelContextStorage(const OrtW::CustomOpApi& custom_op_api,
20 OrtKernelContext& ctx,
21 size_t indice,
22 bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
23 if (is_input) {
24 auto input_count = api_.KernelContext_GetInputCount(&ctx);
25 if (indice >= input_count) {
26 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
27 }
28 const_value_ = api_.KernelContext_GetInput(&ctx, indice);
29 auto* info = api_.GetTensorTypeAndShape(const_value_);
30 shape_ = api_.GetTensorShape(info);
31 api_.ReleaseTensorTypeAndShapeInfo(info);
32 }
33 }
34
35 const std::vector<int64_t>& Shape() const override {
36 if (!IsInitialized())
37 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
38 return *shape_;
39 }
40
41 virtual bool IsInitialized() const override {
42 return shape_.has_value();
43 }
44
45 const void* DataRaw() const override {
46 return api_.GetTensorRawData(const_value_);
47 }
48
49 void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override {
50 if (!const_value_) {
51 const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
52 shape_ = shape;
53 }
54 return api_.GetTensorMutableRawData(const_cast<OrtValue*>(const_value_));
55 }
56
57 void* Release() override {
58 ORTX_CXX_API_THROW("Can't release the tensor buffer with ORT graph mode.", ORT_RUNTIME_EXCEPTION);
59 }
60
61 private:
62 const OrtW::CustomOpApi& api_;
63 OrtKernelContext& ctx_;
64 size_t indice_;
65 const OrtValue* const_value_{}; // for input
66 std::optional<std::vector<int64_t>> shape_;
67};
68
69static std::string get_mem_type(const OrtW::CustomOpApi& custom_op_api,
70 OrtKernelContext& ctx,
71 size_t indice,
72 bool is_input) {
73 std::string output = "Cpu";
74 if (is_input) {
75 const OrtValue* const_value = custom_op_api.KernelContext_GetInput(&ctx, indice);
76 const OrtMemoryInfo* mem_info = {};
77 custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info));
78 if (mem_info) {
79 const char* mem_type = nullptr;
80 custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type));
81 if (mem_type) {
82 output = mem_type;
83 }
84 }
85 }
86 return output;
87}
88
89template <typename T>
90class OrtTensor : public Tensor<T> {
91 public:
92 OrtTensor(const OrtW::CustomOpApi& custom_op_api,
93 OrtKernelContext& ctx,
94 size_t indice,
95 bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(custom_op_api, ctx, indice, is_input)),
96 mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
97 }
98
99 bool IsCpuTensor() const {
100 return mem_type_ == "Cpu";
101 }
102
103 private:
104 std::string mem_type_ = "Cpu";
105};
106
107class OrtStringTensorStorage : public IStringTensorStorage<std::string> {
108 public:
109 using strings = std::vector<std::string>;
110 OrtStringTensorStorage(const OrtW::CustomOpApi& custom_op_api,
111 OrtKernelContext& ctx,
112 size_t indice,
113 bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
114 if (is_input) {
115 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
116 if (indice >= input_count) {
117 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
118 }
119
120 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
121 auto* info = api_.GetTensorTypeAndShape(const_value);
122 shape_ = api_.GetTensorShape(info);
123 api_.ReleaseTensorTypeAndShapeInfo(info);
124
125 size_t num_chars;
126 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
127 std::vector<char> chars(num_chars + 1, '\0');
128 // assert((*shape_).size() == 1 || ((*shape_).size() == 2 && (*shape_)[0] == 1));
129
130 int64_t num_strings = 1; // string scalar
131 if ((*shape_).size() > 0) {
132 num_strings = (*shape_).front();
133 for (auto iter = (*shape_).begin() + 1; iter != (*shape_).end(); ++iter) {
134 num_strings *= *iter;
135 }
136 }
137 std::vector<size_t> offsets(num_strings);
138 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
139 (void*)chars.data(),
140 num_chars,
141 offsets.data(),
142 offsets.size()));
143 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
144 input_strings_.resize(num_strings);
145 for (int64_t i = upper_bound; i >= 0; --i) {
146 if (i < upper_bound) {
147 chars[offsets[i + 1]] = '\0';
148 }
149 input_strings_[i] = chars.data() + offsets[i];
150 }
151 }
152 }
153
154 const std::vector<int64_t>& Shape() const override {
155 if (!IsInitialized())
156 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
157 return *shape_;
158 }
159
160 virtual const void* DataRaw() const override {
161 if (input_strings_.size() != 1) {
162 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
163 }
164 return reinterpret_cast<const void*>(input_strings_[0].c_str());
165 }
166
167 virtual bool IsInitialized() const override {
168 return shape_.has_value();
169 }
170
171 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
172 std::vector<const char*> raw;
173 for (const auto& s : ss) {
174 raw.push_back(s.data());
175 }
176 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
177 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
178 }
179
180 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
181 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
182 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
183 }
184
185 const strings& Data() const override {
186 return input_strings_;
187 }
188
189 private:
190 const OrtW::CustomOpApi& api_;
191 OrtKernelContext& ctx_;
192 size_t indice_;
193 std::vector<std::string> input_strings_;
194 std::optional<std::vector<int64_t>> shape_;
195};
196
197class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> {
198 public:
199 using strings = std::vector<std::string_view>;
200 OrtStringViewTensorStorage(const OrtW::CustomOpApi& custom_op_api,
201 OrtKernelContext& ctx,
202 size_t indice,
203 bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) {
204 if (is_input) {
205 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
206 if (indice >= input_count) {
207 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
208 }
209 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
210 auto* info = api_.GetTensorTypeAndShape(const_value);
211 shape_ = api_.GetTensorShape(info);
212 api_.ReleaseTensorTypeAndShapeInfo(info);
213
214 size_t num_chars;
215 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
216 chars_.resize(num_chars + 1, '\0');
217
218 size_t num_strings = 1;
219 if ((*shape_).size() > 0) {
220 num_strings = static_cast<size_t>((*shape_)[0]);
221 }
222
223 if (num_strings) {
224 std::vector<size_t> offsets(num_strings);
225 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
226 (void*)chars_.data(),
227 num_chars,
228 offsets.data(),
229 offsets.size()));
230 offsets.push_back(num_chars);
231 for (size_t i = 0; i < num_strings; ++i) {
232 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
233 }
234 }
235 }
236 }
237
238 const std::vector<int64_t>& Shape() const override {
239 if (!IsInitialized())
240 ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION);
241 return *shape_;
242 }
243
244 virtual const void* DataRaw() const override {
245 if (input_string_views_.size() != 1) {
246 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
247 }
248 return reinterpret_cast<const void*>(input_string_views_[0].data());
249 }
250
251 virtual bool IsInitialized() const override {
252 return shape_.has_value();
253 }
254
255 virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override {
256 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
257 }
258
259 virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override {
260 ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION);
261 }
262
263 const strings& Data() const override {
264 return input_string_views_;
265 }
266
267 private:
268 const OrtW::CustomOpApi& api_;
269 OrtKernelContext& ctx_;
270 size_t indice_;
271 std::vector<char> chars_; // for input
272 std::vector<std::string_view> input_string_views_; // for input
273 std::optional<std::vector<int64_t>> shape_;
274};
275
276// to make the metaprogramming magic happy.
277template <>
278class OrtTensor<std::string> : public Tensor<std::string> {
279 public:
280 OrtTensor(const OrtW::CustomOpApi& custom_op_api,
281 OrtKernelContext& ctx,
282 size_t indice,
283 bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(custom_op_api, ctx, indice, is_input)),
284 mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}
285
286 bool IsCpuTensor() const {
287 return mem_type_ == "Cpu";
288 }
289
290 private:
291 std::string mem_type_ = "Cpu";
292};
293
294template <>
295class OrtTensor<std::string_view> : public Tensor<std::string_view> {
296 public:
297 OrtTensor(const OrtW::CustomOpApi& custom_op_api,
298 OrtKernelContext& ctx,
299 size_t indice,
300 bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(custom_op_api, ctx, indice, is_input)),
301 mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {}
302
303 bool IsCpuTensor() const {
304 return mem_type_ == "Cpu";
305 }
306
307 private:
308 std::string mem_type_ = "Cpu";
309};
310
311using TensorPtr = std::unique_ptr<Custom::Arg>;
312using TensorPtrs = std::vector<TensorPtr>;
313
314using TensorBasePtr = std::unique_ptr<Custom::TensorBase>;
315using TensorBasePtrs = std::vector<TensorBasePtr>;
316
317// Represent variadic input or output
318struct Variadic : public Arg {
319 Variadic(const OrtW::CustomOpApi& custom_op_api,
320 OrtKernelContext& ctx,
321 size_t indice,
322 bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {
323#if ORT_API_VERSION < 14
324 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
325#endif
326 if (is_input) {
327 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
328 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
329 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
330 auto* info = api_.GetTensorTypeAndShape(const_value);
331 auto type = api_.GetTensorElementType(info);
332 api_.ReleaseTensorTypeAndShapeInfo(info);
333 TensorBasePtr tensor;
334 switch (type) {
335 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
336 tensor = std::make_unique<Custom::OrtTensor<bool>>(api_, ctx, ith_input, true);
337 break;
338 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
339 tensor = std::make_unique<Custom::OrtTensor<float>>(api_, ctx, ith_input, true);
340 break;
341 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
342 tensor = std::make_unique<Custom::OrtTensor<double>>(api_, ctx, ith_input, true);
343 break;
344 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
345 tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api_, ctx, ith_input, true);
346 break;
347 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
348 tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api_, ctx, ith_input, true);
349 break;
350 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
351 tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api_, ctx, ith_input, true);
352 break;
353 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
354 tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api_, ctx, ith_input, true);
355 break;
356 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
357 tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api_, ctx, ith_input, true);
358 break;
359 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
360 tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api_, ctx, ith_input, true);
361 break;
362 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
363 tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api_, ctx, ith_input, true);
364 break;
365 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
366 tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api_, ctx, ith_input, true);
367 break;
368 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
369 tensor = std::make_unique<Custom::OrtTensor<std::string>>(api_, ctx, ith_input, true);
370 break;
371 default:
372 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
373 break;
374 }
375 tensors_.emplace_back(tensor.release());
376 } // for
377 } else {
378 // a Variadic used for output is populated by the Compute so leave tensors_ empty here
379 }
380 }
381 template <typename T>
382 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
383 auto tensor = std::make_unique<OrtTensor<T>>(api_, ctx_, ith_output, false);
384 auto raw_output = tensor.get()->Allocate(shape);
385 tensors_.emplace_back(tensor.release());
386 return raw_output;
387 }
388 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
389 auto tensor = std::make_unique<OrtTensor<std::string>>(api_, ctx_, ith_output, false);
390 Tensor<std::string>& output = *tensor;
391 tensors_.emplace_back(tensor.release());
392 return output;
393 }
394 size_t Size() const {
395 return tensors_.size();
396 }
397
398 const TensorBasePtr& operator[](size_t indice) const {
399 return tensors_.at(indice);
400 }
401
402 private:
403 const OrtW::CustomOpApi& api_;
404 OrtKernelContext& ctx_;
405 size_t indice_;
406 std::string mem_type_ = "Cpu";
407 TensorBasePtrs tensors_;
408};
409
#if ORT_API_VERSION >= 17

// KernelContext whose scratch buffers come from the CPU allocator that ORT provides for the
// running kernel. Owns that allocator handle and releases it on destruction.
class OrtGraphKernelContext : public KernelContext {
 public:
  // Throws if the CPU memory info or the kernel allocator cannot be obtained.
  OrtGraphKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
    OrtMemoryInfo* info = nullptr;
    OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
    // Release the memory info before reporting any failure so it is not leaked on the error path.
    auto status = api_.KernelContext_GetAllocator(&ctx, info, &allocator_);
    api_.ReleaseMemoryInfo(info);
    OrtW::ThrowOnError(api_, status);
  }

  // allocator_ is owned, so a member-wise copy would release it twice.
  OrtGraphKernelContext(const OrtGraphKernelContext&) = delete;
  OrtGraphKernelContext& operator=(const OrtGraphKernelContext&) = delete;

  virtual ~OrtGraphKernelContext() {
    if (allocator_) {
      api_.ReleaseAllocator(allocator_);
    }
  }

  void* AllocScratchBuffer(size_t size) override {
    return allocator_->Alloc(allocator_, size);
  }

  // Null pointers are ignored.
  void FreeScratchBuffer(void* p) override {
    if (p) {
      allocator_->Free(allocator_, p);
    }
  }

 private:
  const OrtApi& api_;
  OrtAllocator* allocator_ = nullptr;
};

#endif
443
444#ifdef USE_CUDA
445
// Resource ids passed to KernelContext_GetResource to query the CUDA execution provider.
// The values presumably mirror onnxruntime's own CUDA resource ids -- keep them in sync.
enum CudaResource {
  cuda_handle_t = 10000,  // CUDA stream of the running kernel
  cudnn_handle_t = cuda_handle_t + 1,
  cublas_handle_t = cuda_handle_t + 2,
  deferred_cpu_allocator_t = cuda_handle_t + 3,
  // below are cuda ep options
  device_id_t = cuda_handle_t + 4,
};
454
#if ORT_API_VERSION >= 17
// CUDAKernelContext exposing the CUDA stream, cuBLAS handle, device id and the CPU/CUDA
// allocators of the running ORT kernel. Owns both allocator handles.
class OrtGraphCudaKernelContext : public CUDAKernelContext {
 public:
  static const int cuda_resource_ver = 1;

  // Throws if a required CUDA resource or either allocator cannot be obtained from `ctx`.
  // Nothing acquired before a failure is leaked.
  OrtGraphCudaKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx) : api_(ort_api) {
    cuda_stream_ = FetchResource(ctx, CudaResource::cuda_handle_t, true, "Failed to fetch cuda stream from context");
    cublas_ = FetchResource(ctx, CudaResource::cublas_handle_t, true, "Failed to fetch cublas handle from context");
    // Only the status is checked for the device id; a null value is accepted here.
    void* resource = FetchResource(ctx, CudaResource::device_id_t, false, "Failed to fetch device id from context");
    // NOTE(review): copies the leading sizeof(int) bytes of the pointer value, i.e. assumes the
    // device id is packed into the pointer bytes -- confirm against the ORT producer.
    memcpy(&device_id_, &resource, sizeof(int));

    OrtMemoryInfo* info = nullptr;
    OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
    // Guard the CPU allocator until the CUDA one is obtained: if that step throws, the destructor
    // never runs for this partially constructed object.
    auto release_allocator = [this](OrtAllocator* allocator) { api_.ReleaseAllocator(allocator); };
    std::unique_ptr<OrtAllocator, decltype(release_allocator)> cpu_allocator(AcquireAllocator(ctx, info), release_allocator);

    OrtMemoryInfo* cuda_mem_info = nullptr;
    OrtW::ThrowOnError(api_, api_.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
    cuda_allocator_ = AcquireAllocator(ctx, cuda_mem_info);
    cpu_allocator_ = cpu_allocator.release();
  }

  // Both allocators are owned, so a member-wise copy would release them twice.
  OrtGraphCudaKernelContext(const OrtGraphCudaKernelContext&) = delete;
  OrtGraphCudaKernelContext& operator=(const OrtGraphCudaKernelContext&) = delete;

  virtual ~OrtGraphCudaKernelContext() {
    if (cpu_allocator_) {
      api_.ReleaseAllocator(cpu_allocator_);
    }
    if (cuda_allocator_) {
      api_.ReleaseAllocator(cuda_allocator_);
    }
  }

  void* AllocScratchBuffer(size_t size) override {
    return cpu_allocator_->Alloc(cpu_allocator_, size);
  }

  void FreeScratchBuffer(void* p) override {
    if (p) {
      cpu_allocator_->Free(cpu_allocator_, p);
    }
  }

  void* AllocCudaScratchBuffer(size_t size) override {
    return cuda_allocator_->Alloc(cuda_allocator_, size);
  }

  void FreeCudaScratchBuffer(void* p) override {
    if (p) {
      cuda_allocator_->Free(cuda_allocator_, p);
    }
  }

  void* GetCudaStream() const override {
    return cuda_stream_;
  }

  void* GetCublasHandle() const override {
    return cublas_;
  }

  int GetCudaDeviceId() const override {
    return device_id_;
  }

 private:
  // Reads CUDA resource `id` from `ctx`. On failure the returned OrtStatus is released (it was
  // previously leaked) and `error_msg` is thrown; with `require_non_null`, a null resource also throws.
  void* FetchResource(const OrtKernelContext& ctx, int id, bool require_non_null, const char* error_msg) const {
    void* resource = nullptr;
    OrtStatusPtr status = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, id, &resource);
    if (status) {
      api_.ReleaseStatus(status);
      ORTX_CXX_API_THROW(error_msg, ORT_RUNTIME_EXCEPTION);
    }
    if (require_non_null && !resource) {
      ORTX_CXX_API_THROW(error_msg, ORT_RUNTIME_EXCEPTION);
    }
    return resource;
  }

  // Gets the kernel allocator for `mem_info`. `mem_info` is always released, including when the
  // lookup fails.
  OrtAllocator* AcquireAllocator(const OrtKernelContext& ctx, OrtMemoryInfo* mem_info) const {
    OrtAllocator* allocator = nullptr;
    OrtStatusPtr status = api_.KernelContext_GetAllocator(&ctx, mem_info, &allocator);
    api_.ReleaseMemoryInfo(mem_info);
    OrtW::ThrowOnError(api_, status);
    return allocator;
  }

  const OrtApi& api_;
  OrtAllocator* cpu_allocator_ = nullptr;
  OrtAllocator* cuda_allocator_ = nullptr;
  void* cuda_stream_ = {};
  void* cublas_ = {};
  int device_id_ = 0;
};

#endif
538#endif
539
540// using mf16_t = uint16_t;
541
542struct OrtLiteCustomOp : public OrtCustomOp {
543 // CreateTuple
544 template <size_t ith_input, size_t ith_output, typename... Ts>
545 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
546 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
547 return std::make_tuple();
548 }
549
550 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
551 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
552 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
553 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
554 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
555 return std::tuple_cat(current, next);
556 }
557
#if ORT_API_VERSION >= 17
  // KernelContext*: creates an OrtGraphKernelContext owned by `tensors`; consumes no slot.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, KernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<OrtGraphKernelContext>(api->GetOrtApi(), *context));
    // tensors stores Arg pointers. static_cast is the well-defined base-to-derived conversion;
    // reinterpret_cast is only correct while the Arg subobject sits at offset 0.
    std::tuple<T> current{static_cast<T>(tensors.back().get())};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#ifdef USE_CUDA

  // CUDAKernelContext*: creates an OrtGraphCudaKernelContext owned by `tensors`; consumes no slot.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, CUDAKernelContext*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<OrtGraphCudaKernelContext>(api->GetOrtApi(), *context));
    std::tuple<T> current{static_cast<T>(tensors.back().get())};
    auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

#endif

#endif
582
#if ORT_API_VERSION >= 14
  // const Variadic*: a Variadic input built from kernel inputs starting at ith_input.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
    // tensors stores Arg pointers; static_cast is the well-defined downcast to Variadic.
    std::tuple<T> current{static_cast<T>(tensors.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  // const Variadic&: same as above, passed by reference.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
    std::tuple<T> current{static_cast<T>(*tensors.back())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  // Variadic*: an initially empty Variadic output anchored at ith_output.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
    std::tuple<T> current{static_cast<T>(tensors.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  // Variadic&: same as above, passed by reference.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
    std::tuple<T> current{static_cast<T>(*tensors.back())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif
620
621#undef data_type_def
622#define data_type_def bool
623#include "tensor_tuple.inc"
624
625#if ORT_API_VERSION >= 16
626#undef data_type_def
627#define data_type_def BFloat16
628#include "tensor_tuple.inc"
629
630#undef data_type_def
631#define data_type_def MFloat16
632#include "tensor_tuple.inc"
633#endif
634
635#undef data_type_def
636#define data_type_def float
637#include "tensor_tuple.inc"
638
639#undef data_type_def
640#define data_type_def double
641#include "tensor_tuple.inc"
642
643#undef data_type_def
644#define data_type_def int8_t
645#include "tensor_tuple.inc"
646
647#undef data_type_def
648#define data_type_def int16_t
649#include "tensor_tuple.inc"
650
651#undef data_type_def
652#define data_type_def int32_t
653#include "tensor_tuple.inc"
654
655#undef data_type_def
656#define data_type_def int64_t
657#include "tensor_tuple.inc"
658
659#undef data_type_def
660#define data_type_def uint8_t
661#include "tensor_tuple.inc"
662
663#undef data_type_def
664#define data_type_def uint16_t
665#include "tensor_tuple.inc"
666
667#undef data_type_def
668#define data_type_def uint32_t
669#include "tensor_tuple.inc"
670
671#undef data_type_def
672#define data_type_def uint64_t
673#include "tensor_tuple.inc"
674
675#undef data_type_def
676#define data_type_def std::string
677#include "tensor_tuple.inc"
678
679#undef data_type_def
680#define data_type_def std::string_view
681#include "tensor_tuple.inc"
682
683#undef data_type_def
684
685 // ParseArgs ...
686 template <typename... Ts>
687 static typename std::enable_if<0 == sizeof...(Ts)>::type
688 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
689 }
690
691 template <typename T, typename... Ts>
692 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
693 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
694 ParseArgs<Ts...>(input_types, output_types);
695 }
696
697#if ORT_API_VERSION >= 17
698 template <typename T, typename... Ts>
699 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, KernelContext*>::value>::type
700 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
701 ParseArgs<Ts...>(input_types, output_types);
702 }
703
704#ifdef USE_CUDA
705 template <typename T, typename... Ts>
706 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, CUDAKernelContext*>::value>::type
707 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
708 ParseArgs<Ts...>(input_types, output_types);
709 }
710#endif
711#endif
712
713#if ORT_API_VERSION >= 14
714 template <typename T, typename... Ts>
715 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
716 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
717 if (!input_types.empty()) {
718 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
719 }
720 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
721 ParseArgs<Ts...>(input_types, output_types);
722 }
723
724 template <typename T, typename... Ts>
725 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
726 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
727 if (!input_types.empty()) {
728 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
729 }
730 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
731 ParseArgs<Ts...>(input_types, output_types);
732 }
733
734 template <typename T, typename... Ts>
735 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
736 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
737 if (!output_types.empty()) {
738 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
739 }
740 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
741 ParseArgs<Ts...>(input_types, output_types);
742 }
743
744 template <typename T, typename... Ts>
745 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
746 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
747 if (!output_types.empty()) {
748 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
749 }
750 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
751 ParseArgs<Ts...>(input_types, output_types);
752 }
753#endif
754
755#define PARSE_INPUT_BASE(pack_type, onnx_type) \
756 template <typename T, typename... Ts> \
757 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
758 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
759 input_types.push_back(onnx_type); \
760 ParseArgs<Ts...>(input_types, output_types); \
761 } \
762 template <typename T, typename... Ts> \
763 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
764 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
765 input_types.push_back(onnx_type); \
766 ParseArgs<Ts...>(input_types, output_types); \
767 }
768
769#define PARSE_INPUT(data_type, onnx_type) \
770 PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
771 PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
772 PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
773 PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
774 PARSE_INPUT_BASE(data_type, onnx_type)
775
776#define PARSE_OUTPUT(data_type, onnx_type) \
777 template <typename T, typename... Ts> \
778 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
779 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
780 output_types.push_back(onnx_type); \
781 ParseArgs<Ts...>(input_types, output_types); \
782 } \
783 template <typename T, typename... Ts> \
784 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
785 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
786 output_types.push_back(onnx_type); \
787 ParseArgs<Ts...>(input_types, output_types); \
788 } \
789 template <typename T, typename... Ts> \
790 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
791 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
792 output_types.push_back(onnx_type); \
793 ParseArgs<Ts...>(input_types, output_types); \
794 }
795
796#define PARSE_ARGS(data_type, onnx_type) \
797 PARSE_INPUT(data_type, onnx_type) \
798 PARSE_OUTPUT(data_type, onnx_type)
799
800 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
801#if ORT_API_VERSION >= 16
802 PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
803 PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16)
804#endif
805 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
806 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
807 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
808 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
809 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
810 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
811 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
812 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
813 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
814 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
815 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
816 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
817
818 OrtLiteCustomOp(const char* op_name,
819 const char* execution_provider) : op_name_(op_name),
820 execution_provider_(execution_provider) {
821 // Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility
822 memset(&this->version, 0, sizeof(OrtCustomOp));
823
824 int act_ver = GetActiveOrtAPIVersion();
825 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
826
827 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
828 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
829
830 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
831 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
832 return self->input_types_.size();
833 };
834
835 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
836 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
837 return self->input_types_[indice];
838 };
839
840 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
841 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
842 return self->output_types_.size();
843 };
844
845 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
846 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
847 return self->output_types_[indice];
848 };
849
850#if ORT_API_VERSION >= 14
851 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
852 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
853 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
854 };
855
856 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
857 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
858 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
859 };
860
861 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
862 return 1;
863 };
864
865 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
866 return 0;
867 };
868
869 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
870 return 1;
871 };
872
873 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
874 return 0;
875 };
876
877 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
878 return OrtMemTypeDefault;
879 };
880#else
881 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
882 return INPUT_OUTPUT_OPTIONAL;
883 };
884
885 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
886 return INPUT_OUTPUT_OPTIONAL;
887 };
888#endif
889
890#if ORT_API_VERSION >= 18
891 OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t {
892 return 0;
893 };
894 OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {};
895#endif
896 }
897
898 const std::string op_name_;
899 const std::string execution_provider_;
900
901 std::vector<ONNXTensorElementDataType> input_types_;
902 std::vector<ONNXTensorElementDataType> output_types_;
903};
904
905template <typename... Args>
906struct OrtLiteCustomFunc : public OrtLiteCustomOp {
907 using ComputeFn = void (*)(Args...);
908 using MyType = OrtLiteCustomFunc<Args...>;
909
910 struct Kernel {
911 ComputeFn compute_fn_{};
912 std::string ep_{};
913 std::unique_ptr<OrtW::CustomOpApi> api_;
914 };
915
916 OrtLiteCustomFunc(const char* op_name,
917 const char* execution_provider,
918 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
919 compute_fn_(compute_fn) {
920 ParseArgs<Args...>(input_types_, output_types_);
921
922 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
923 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
924 std::vector<TensorPtr> tensors;
925 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
926 context,
927 tensors,
928 kernel->api_->KernelContext_GetInputCount(context),
929 kernel->api_->KernelContext_GetOutputCount(context),
930 kernel->ep_);
931 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
932 };
933
934 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
935 auto kernel = std::make_unique<Kernel>();
936 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
937 kernel->compute_fn_ = self->compute_fn_;
938 kernel->ep_ = self->execution_provider_;
939 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
940 return reinterpret_cast<void*>(kernel.release());
941 };
942
943 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
944 delete reinterpret_cast<Kernel*>(op_kernel);
945 };
946 }
947
948 ComputeFn compute_fn_;
949};
950
951class OrtAttributeReader {
952 public:
953 OrtAttributeReader(const OrtApi& ort_api, const OrtKernelInfo& info) : base_kernel_(ort_api, info) {
954 }
955
956 template <class T>
957 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
958 return base_kernel_.TryToGetAttributeWithDefault(name, default_value);
959 }
960
961 private:
962 BaseKernel base_kernel_;
963};
964
965template <typename CustomOp>
966struct OrtLiteCustomStruct : public OrtLiteCustomOp {
967 template <typename... Args>
968 using CustomComputeFn = void (CustomOp::*)(Args...) const;
969 using MyType = OrtLiteCustomStruct<CustomOp>;
970
971 struct Kernel {
972 std::unique_ptr<CustomOp> custom_op_;
973 std::string ep_{};
974 std::unique_ptr<OrtW::CustomOpApi> api_;
975 };
976
977 OrtLiteCustomStruct(const char* op_name,
978 const char* execution_provider) : OrtLiteCustomOp(op_name,
979 execution_provider) {
980 init(&CustomOp::Compute);
981 }
982
983 template <typename... Args>
984 void init(CustomComputeFn<Args...>) {
985 ParseArgs<Args...>(input_types_, output_types_);
986
987 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
988 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
989 std::vector<TensorPtr> tensors;
990 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
991 context,
992 tensors,
993 kernel->api_->KernelContext_GetInputCount(context),
994 kernel->api_->KernelContext_GetOutputCount(context),
995 kernel->ep_);
996 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
997 };
998
999 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
1000 auto kernel = std::make_unique<Kernel>();
1001
1002 if constexpr (std::is_constructible<CustomOp, const OrtApi&, const OrtKernelInfo&>::value) {
1003 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
1004 } else {
1005 kernel->custom_op_ = std::make_unique<CustomOp>(OrtAttributeReader(*ort_api, *info));
1006 }
1007 auto self = static_cast<const MyType*>(this_);
1008 kernel->ep_ = self->execution_provider_;
1009 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
1010 return reinterpret_cast<void*>(kernel.release());
1011 };
1012
1013 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
1014 delete reinterpret_cast<Kernel*>(op_kernel);
1015 };
1016 }
1017};
1018
1019template <typename... Args>
1020OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1021 const char* execution_provider,
1022 void (*custom_compute_fn)(Args...)) {
1023 using LiteOp = OrtLiteCustomFunc<Args...>;
1024 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
1025}
1026
1027template <typename CustomOp>
1028OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
1029 const char* execution_provider) {
1030 using LiteOp = OrtLiteCustomStruct<CustomOp>;
1031 return std::make_unique<LiteOp>(op_name, execution_provider).release();
1032}
1033
1034} // namespace Custom
1035} // namespace Ort
1036