microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
sayanshaw/genai-tutorial

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/custom_op_lite.h

966lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include "onnxruntime_customop.hpp"
#include <numeric>
#include <optional>
#include <utility>
8
9namespace Ort {
10namespace Custom {
11
12class TensorBase {
13 public:
14 TensorBase(const OrtW::CustomOpApi& api,
15 OrtKernelContext& ctx,
16 size_t indice,
17 bool is_input) : api_(api),
18 ctx_(ctx),
19 indice_(indice),
20 is_input_(is_input) {}
21
22 virtual ~TensorBase() = default;
23 operator bool() const {
24 return shape_.has_value();
25 }
26 const std::vector<int64_t>& Shape() const {
27 if (shape_.has_value()) {
28 return *shape_;
29 } else {
30 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
31 }
32 }
33 ONNXTensorElementDataType Type() const {
34 return type_;
35 }
36 int64_t NumberOfElement() const {
37 if (shape_.has_value()) {
38 return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
39 } else {
40 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
41 }
42 }
43 std::string Shape2Str() const {
44 if (shape_.has_value()) {
45 std::string shape_str;
46 for (const auto& dim : *shape_) {
47 shape_str.append(std::to_string(dim));
48 shape_str.append(", ");
49 }
50 return shape_str;
51 } else {
52 return "empty";
53 }
54 }
55 bool IsCpuTensor() const {
56 return strcmp("Cpu", mem_type_) == 0;
57 }
58 virtual const void* DataRaw() const = 0;
59 virtual size_t SizeInBytes() const = 0;
60
61 protected:
62 const OrtW::CustomOpApi& api_;
63 OrtKernelContext& ctx_;
64 size_t indice_;
65 bool is_input_;
66 std::optional<std::vector<int64_t>> shape_;
67 ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
68 const char* mem_type_ = "Cpu";
69};
70
// Minimal non-owning view over `size_` contiguous, read-only elements of type T.
// The viewed memory must outlive the span; nothing is copied or freed here.
template <typename T>
struct Span {
  const T* data_{nullptr};
  size_t size_{0};

  // Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  const T* data() const { return data_; }

  size_t size() const { return size_; }

  // Unchecked element access; returns a copy of the element.
  T operator[](size_t indice) const { return *(data_ + indice); }
};
85
86template <typename T>
87class Tensor : public TensorBase {
88 public:
89 using TT = typename std::remove_reference<T>::type;
90 Tensor(const OrtW::CustomOpApi& api,
91 OrtKernelContext& ctx,
92 size_t indice,
93 bool is_input) : TensorBase(api,
94 ctx,
95 indice,
96 is_input) {
97 if (is_input) {
98 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
99 if (indice >= input_count) {
100 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
101 }
102 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
103 auto* info = api_.GetTensorTypeAndShape(const_value_);
104 shape_ = api_.GetTensorShape(info);
105 type_ = api_.GetTensorElementType(info);
106 api_.ReleaseTensorTypeAndShapeInfo(info);
107 const OrtMemoryInfo* mem_info = {};
108 api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
109 if (mem_info) {
110 api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
111 }
112 }
113 }
114 const TT* Data() const {
115 return api_.GetTensorData<TT>(const_value_);
116 }
117
118 const void* DataRaw() const override {
119 return reinterpret_cast<const void*>(Data());
120 }
121
122 size_t SizeInBytes() const override {
123 return NumberOfElement() * sizeof(TT);
124 }
125
126 TT* Allocate(const std::vector<int64_t>& shape) {
127 if (!data_) {
128 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
129 shape_ = shape;
130 data_ = api_.GetTensorMutableData<TT>(out);
131 }
132 return data_;
133 }
134 const Span<T>& AsSpan() {
135 if (!shape_.has_value() || shape_->size() != 1) {
136 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
137 }
138 span_.Assign(Data(), (*shape_)[0]);
139 return span_;
140 }
141 const T& AsScalar() {
142 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
143 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
144 }
145 return *Data();
146 }
147
148 private:
149 const OrtValue* const_value_{}; // for input
150 TT* data_{}; // for output
151 Span<T> span_;
152};
153
154template <>
155class Tensor<std::string> : public TensorBase {
156 public:
157 using strings = std::vector<std::string>;
158
159 Tensor(const OrtW::CustomOpApi& api,
160 OrtKernelContext& ctx,
161 size_t indice,
162 bool is_input) : TensorBase(api,
163 ctx,
164 indice,
165 is_input) {
166 if (is_input) {
167 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
168 if (indice >= input_count) {
169 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
170 }
171
172 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
173 auto* info = api_.GetTensorTypeAndShape(const_value);
174 shape_ = api_.GetTensorShape(info);
175 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
176 api_.ReleaseTensorTypeAndShapeInfo(info);
177
178 size_t num_chars;
179 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
180 std::vector<char> chars(num_chars + 1, '\0');
181 auto num_strings = NumberOfElement();
182 std::vector<size_t> offsets(NumberOfElement());
183 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
184 (void*)chars.data(),
185 num_chars,
186 offsets.data(),
187 offsets.size()));
188 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
189 input_strings_.resize(num_strings);
190 for (int64_t i = upper_bound; i >= 0; --i) {
191 if (i < upper_bound) {
192 chars[offsets[i + 1]] = '\0';
193 }
194 input_strings_[i] = chars.data() + offsets[i];
195 }
196 }
197 }
198 const strings& Data() const {
199 return input_strings_;
200 }
201 const void* DataRaw() const override {
202 if (input_strings_.size() != 1) {
203 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
204 }
205 return reinterpret_cast<const void*>(input_strings_[0].c_str());
206 }
207 size_t SizeInBytes() const override {
208 if (input_strings_.size() != 1) {
209 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
210 }
211 return input_strings_[0].size();
212 }
213 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
214 std::vector<const char*> raw;
215 for (const auto& s : ss) {
216 raw.push_back(s.data());
217 }
218 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
219 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
220 }
221 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
222 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
223 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
224 }
225 const Span<std::string>& AsSpan() {
226 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
227 }
228 const std::string& AsScalar() {
229 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
230 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
231 }
232 return input_strings_[0];
233 }
234
235 private:
236 std::vector<std::string> input_strings_; // for input
237};
238
239template <>
240class Tensor<std::string_view> : public TensorBase {
241 public:
242 using strings = std::vector<std::string>;
243 using string_views = std::vector<std::string_view>;
244
245 Tensor(const OrtW::CustomOpApi& api,
246 OrtKernelContext& ctx,
247 size_t indice,
248 bool is_input) : TensorBase(api,
249 ctx,
250 indice,
251 is_input) {
252 if (is_input_) {
253 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
254 if (indice >= input_count) {
255 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
256 }
257 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
258 auto* info = api_.GetTensorTypeAndShape(const_value);
259 shape_ = api_.GetTensorShape(info);
260 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
261 api_.ReleaseTensorTypeAndShapeInfo(info);
262
263 size_t num_chars;
264 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
265 chars_.resize(num_chars + 1, '\0');
266
267 auto num_strings = static_cast<size_t>(NumberOfElement());
268 if (num_strings) {
269 std::vector<size_t> offsets(num_strings);
270 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
271 (void*)chars_.data(),
272 num_chars,
273 offsets.data(),
274 offsets.size()));
275 offsets.push_back(num_chars);
276 for (size_t i = 0; i < num_strings; ++i) {
277 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
278 }
279 }
280 }
281 }
282 int64_t NumberOfElement() const {
283 if (shape_.has_value()) {
284 return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
285 } else {
286 return 0;
287 }
288 }
289 const string_views& Data() const {
290 return input_string_views_;
291 }
292 const void* DataRaw() const override {
293 if (input_string_views_.size() != 1) {
294 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
295 }
296 return reinterpret_cast<const void*>(input_string_views_[0].data());
297 }
298 size_t SizeInBytes() const override {
299 if (input_string_views_.size() != 1) {
300 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
301 }
302 return input_string_views_[0].size();
303 }
304 const Span<std::string_view>& AsSpan() {
305 ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
306 }
307 std::string_view AsScalar() {
308 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
309 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
310 }
311 return input_string_views_[0];
312 }
313
314 private:
315 std::vector<char> chars_; // for input
316 std::vector<std::string_view> input_string_views_; // for input
317};
318
319using TensorPtr = std::unique_ptr<Custom::TensorBase>;
320using TensorPtrs = std::vector<TensorPtr>;
321
322// Represent variadic input or output
323struct Variadic : public TensorBase {
324 Variadic(const OrtW::CustomOpApi& api,
325 OrtKernelContext& ctx,
326 size_t indice,
327 bool is_input) : TensorBase(api,
328 ctx,
329 indice,
330 is_input) {
331#if ORT_API_VERSION < 14
332 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
333#endif
334 if (is_input) {
335 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
336 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
337 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
338 auto* info = api_.GetTensorTypeAndShape(const_value);
339 auto type = api_.GetTensorElementType(info);
340 api_.ReleaseTensorTypeAndShapeInfo(info);
341 TensorPtr tensor;
342 switch (type) {
343 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
344 tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
345 break;
346 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
347 tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
348 break;
349 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
350 tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
351 break;
352 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
353 tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
354 break;
355 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
356 tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
357 break;
358 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
359 tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
360 break;
361 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
362 tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
363 break;
364 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
365 tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
366 break;
367 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
368 tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
369 break;
370 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
371 tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
372 break;
373 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
374 tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
375 break;
376 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
377 tensor = std::make_unique<Custom::Tensor<std::string>>(api, ctx, ith_input, true);
378 break;
379 default:
380 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
381 break;
382 }
383 tensors_.emplace_back(tensor.release());
384 } // for
385 } else {
386 // a Variadic used for output is populated by the Compute so leave tensors_ empty here
387 }
388 }
389 template <typename T>
390 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
391 auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
392 auto raw_output = tensor.get()->Allocate(shape);
393 tensors_.emplace_back(tensor.release());
394 return raw_output;
395 }
396 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
397 auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
398 Tensor<std::string>& output = *tensor;
399 tensors_.emplace_back(tensor.release());
400 return output;
401 }
402 const void* DataRaw() const override {
403 ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
404 return nullptr;
405 }
406 size_t SizeInBytes() const override {
407 ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
408 return 0;
409 }
410 size_t Size() const {
411 return tensors_.size();
412 }
413 const TensorPtr& operator[](size_t indice) const {
414 return tensors_.at(indice);
415 }
416
417 private:
418 TensorPtrs tensors_;
419};
420
421#ifdef USE_CUDA
422
// Resource identifiers passed to KernelContext_GetResource by CudaContext.
enum CudaResource { cuda_handle_t = 10000 };
426
427struct CudaContext {
428 static const int cuda_resource_ver = 1;
429 void Init(const OrtW::CustomOpApi& api, const OrtKernelContext& ctx) {
430 const auto& ort_api = api.GetOrtApi();
431 ort_api.KernelContext_GetResource(&ctx, cuda_resource_ver, CudaResource::cuda_handle_t, &cuda_stream);
432 if (!cuda_stream) {
433 ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
434 }
435 }
436 void* cuda_stream = {};
437};
438
439#endif
440
441struct OrtLiteCustomOp : public OrtCustomOp {
442 // CreateTuple
443 template <size_t ith_input, size_t ith_output, typename... Ts>
444 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
445 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
446 return std::make_tuple();
447 }
448
449 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
450 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
451 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
452 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
453 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
454 return std::tuple_cat(current, next);
455 }
456
457#ifdef USE_CUDA
458 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
459 static typename std::enable_if<std::is_same<T, const CudaContext&>::value, std::tuple<T, Ts...>>::type
460 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
461 thread_local CudaContext cuda_context;
462 cuda_context.Init(*api, *context);
463 std::tuple<T> current = std::tuple<const CudaContext&>{cuda_context};
464 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
465 return std::tuple_cat(current, next);
466 }
467#endif
468
469#if ORT_API_VERSION >= 14
470 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
471 static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
472 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
473 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
474 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
475 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
476 return std::tuple_cat(current, next);
477 }
478
479 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
480 static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
481 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
482 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
483 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
484 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
485 return std::tuple_cat(current, next);
486 }
487
488 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
489 static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
490 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
491 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
492 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
493 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
494 return std::tuple_cat(current, next);
495 }
496
497 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
498 static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
499 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
500 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
501 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
502 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
503 return std::tuple_cat(current, next);
504 }
505#endif
506
// Generates the CreateTuple overloads that bind ONE kernel input of element
// type `data_type`. Every overload appends a Custom::Tensor<data_type> to
// `tensors` (which owns it) and advances the input index by one. Supported
// argument forms:
//   const Tensor<data_type>* / const Tensor<data_type>&  - the tensor itself
//   std::optional<const Tensor<data_type>*>              - empty when ith_input >= num_input
//   const Span<data_type>* / const Span<data_type>&      - 1-D view, CPU tensors only
//   std::optional<const Span<data_type>*>                - empty when ith_input >= num_input
//   data_type                                            - scalar value, CPU tensors only
//   std::optional<data_type>                             - empty when ith_input >= num_input
// The stored TensorBase pointer is converted back to the concrete tensor type
// with static_cast, the cast defined for base-to-derived conversions.
#define CREATE_TUPLE_INPUT(data_type)                                                                                                                     \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type                                    \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                      \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())};                                                                          \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                          \
    return std::tuple_cat(current, next);                                                                                                                 \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type                                    \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                      \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back().get())};                                                                         \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                          \
    return std::tuple_cat(current, next);                                                                                                                 \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type                     \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) {                                                                                                                          \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                    \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())};                                               \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    } else {                                                                                                                                              \
      std::tuple<T> current = std::tuple<T>{};                                                                                                            \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    }                                                                                                                                                     \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type                                      \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                      \
    if (!tensors.back()->IsCpuTensor()) {                                                                                                                 \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);                                                                     \
    }                                                                                                                                                     \
    std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()};                                      \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                          \
    return std::tuple_cat(current, next);                                                                                                                 \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type                                      \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                      \
    if (!tensors.back()->IsCpuTensor()) {                                                                                                                 \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);                                                                     \
    }                                                                                                                                                     \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()};                                       \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                          \
    return std::tuple_cat(current, next);                                                                                                                 \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type                       \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) {                                                                                                                          \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                    \
      if (!tensors.back()->IsCpuTensor()) {                                                                                                               \
        ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);                                                                   \
      }                                                                                                                                                   \
      std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()};                                    \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    } else {                                                                                                                                              \
      std::tuple<T> current = std::tuple<T>{};                                                                                                            \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    }                                                                                                                                                     \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type                                                           \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                      \
    if (!tensors.back()->IsCpuTensor()) {                                                                                                                 \
      ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);                                                                   \
    }                                                                                                                                                     \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()};                                     \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                          \
    return std::tuple_cat(current, next);                                                                                                                 \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type                                            \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) {                                                                                                                          \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true));                                                    \
      if (!tensors.back()->IsCpuTensor()) {                                                                                                               \
        ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);                                                                 \
      }                                                                                                                                                   \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()};                                   \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    } else {                                                                                                                                              \
      std::tuple<T> current = std::tuple<T>{};                                                                                                            \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    }                                                                                                                                                     \
  }
// Generates the CreateTuple overloads that bind ONE kernel output of element
// type `data_type`. Every overload appends a Custom::Tensor<data_type> (created
// in output mode) to `tensors`, which owns it, and advances the output index
// by one. Supported argument forms:
//   Tensor<data_type>* / Tensor<data_type>&   - the output tensor itself
//   std::optional<Tensor<data_type>*>         - empty when ith_output >= num_output
// The stored TensorBase pointer is converted back to the concrete tensor type
// with static_cast, the cast defined for base-to-derived conversions.
#define CREATE_TUPLE_OUTPUT(data_type)                                                                                                                    \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type                                          \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false));                                                    \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())};                                                                          \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);                                          \
    return std::tuple_cat(current, next);                                                                                                                 \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type                                          \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false));                                                    \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back().get())};                                                                         \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);                                          \
    return std::tuple_cat(current, next);                                                                                                                 \
  }                                                                                                                                                       \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>                                                                              \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type                           \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) {                                                                                                                        \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false));                                                  \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())};                                               \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    } else {                                                                                                                                              \
      std::tuple<T> current = std::tuple<T>{};                                                                                                            \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);                                        \
      return std::tuple_cat(current, next);                                                                                                               \
    }                                                                                                                                                     \
  }
636#define CREATE_TUPLE(data_type) \
637 CREATE_TUPLE_INPUT(data_type) \
638 CREATE_TUPLE_OUTPUT(data_type)
639
640 CREATE_TUPLE(bool)
641 CREATE_TUPLE(float)
642 CREATE_TUPLE(double)
643 CREATE_TUPLE(int8_t)
644 CREATE_TUPLE(int16_t)
645 CREATE_TUPLE(int32_t)
646 CREATE_TUPLE(int64_t)
647 CREATE_TUPLE(uint8_t)
648 CREATE_TUPLE(uint16_t)
649 CREATE_TUPLE(uint32_t)
650 CREATE_TUPLE(uint64_t)
651 CREATE_TUPLE(std::string)
652 CREATE_TUPLE_INPUT(std::string_view)
653
654 // ParseArgs ...
655 template <typename... Ts>
656 static typename std::enable_if<0 == sizeof...(Ts)>::type
657 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
658 }
659
660 template <typename T, typename... Ts>
661 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
662 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
663 ParseArgs<Ts...>(input_types, output_types);
664 }
665
666#ifdef USE_CUDA
667 template <typename T, typename... Ts>
668 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const CudaContext&>::value>::type
669 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
670 ParseArgs<Ts...>(input_types, output_types);
671 }
672#endif
673
#if ORT_API_VERSION >= 14
// ParseArgs: variadic input, declared as `const Variadic&` or `const Variadic*`.
// Throws if any input type has already been recorded; otherwise records
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED as the input type and continues with
// the remaining parameter types.
template <typename T, typename... Ts>
static std::enable_if_t<std::is_same<T, const Variadic&>::value ||
                        std::is_same<T, const Variadic*>::value>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,
          std::vector<ONNXTensorElementDataType>& output_types) {
  if (!input_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
  }
  input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}

// ParseArgs: variadic output, declared as `Variadic&` or `Variadic*`.
// Throws if any output type has already been recorded; otherwise records
// ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED as the output type and continues with
// the remaining parameter types.
template <typename T, typename... Ts>
static std::enable_if_t<std::is_same<T, Variadic&>::value ||
                        std::is_same<T, Variadic*>::value>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,
          std::vector<ONNXTensorElementDataType>& output_types) {
  if (!output_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
  }
  output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
715
716#define PARSE_INPUT_BASE(pack_type, onnx_type) \
717 template <typename T, typename... Ts> \
718 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
719 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
720 input_types.push_back(onnx_type); \
721 ParseArgs<Ts...>(input_types, output_types); \
722 } \
723 template <typename T, typename... Ts> \
724 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
725 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
726 input_types.push_back(onnx_type); \
727 ParseArgs<Ts...>(input_types, output_types); \
728 }
729
730#define PARSE_INPUT(data_type, onnx_type) \
731 PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
732 PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
733 PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
734 PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
735 PARSE_INPUT_BASE(data_type, onnx_type)
736
737#define PARSE_OUTPUT(data_type, onnx_type) \
738 template <typename T, typename... Ts> \
739 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
740 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
741 output_types.push_back(onnx_type); \
742 ParseArgs<Ts...>(input_types, output_types); \
743 } \
744 template <typename T, typename... Ts> \
745 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
746 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
747 output_types.push_back(onnx_type); \
748 ParseArgs<Ts...>(input_types, output_types); \
749 } \
750 template <typename T, typename... Ts> \
751 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
752 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
753 output_types.push_back(onnx_type); \
754 ParseArgs<Ts...>(input_types, output_types); \
755 }
756
757#define PARSE_ARGS(data_type, onnx_type) \
758 PARSE_INPUT(data_type, onnx_type) \
759 PARSE_OUTPUT(data_type, onnx_type)
760
761 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
762 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
763 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
764 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
765 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
766 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
767 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
768 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
769 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
770 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
771 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
772 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
773 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
774
775 OrtLiteCustomOp(const char* op_name,
776 const char* execution_provider) : op_name_(op_name),
777 execution_provider_(execution_provider) {
778 // Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility
779 memset(&this->version, 0, sizeof(OrtCustomOp));
780
781 int act_ver = GetActiveOrtAPIVersion();
782 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
783
784 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
785 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
786
787 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
788 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
789 return self->input_types_.size();
790 };
791
792 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
793 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
794 return self->input_types_[indice];
795 };
796
797 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
798 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
799 return self->output_types_.size();
800 };
801
802 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
803 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
804 return self->output_types_[indice];
805 };
806
807#if ORT_API_VERSION >= 14
808 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
809 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
810 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
811 };
812
813 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
814 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
815 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
816 };
817
818 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
819 return 1;
820 };
821
822 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
823 return 0;
824 };
825
826 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
827 return 1;
828 };
829
830 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
831 return 0;
832 };
833
834 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
835 return OrtMemTypeDefault;
836 };
837#else
838 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
839 return INPUT_OUTPUT_OPTIONAL;
840 };
841
842 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
843 return INPUT_OUTPUT_OPTIONAL;
844 };
845#endif
846 }
847
848 const std::string op_name_;
849 const std::string execution_provider_;
850
851 std::vector<ONNXTensorElementDataType> input_types_;
852 std::vector<ONNXTensorElementDataType> output_types_;
853};
854
855template <typename... Args>
856struct OrtLiteCustomFunc : public OrtLiteCustomOp {
857 using ComputeFn = void (*)(Args...);
858 using MyType = OrtLiteCustomFunc<Args...>;
859
860 struct Kernel {
861 ComputeFn compute_fn_{};
862 std::string ep_{};
863 std::unique_ptr<OrtW::CustomOpApi> api_;
864 };
865
866 OrtLiteCustomFunc(const char* op_name,
867 const char* execution_provider,
868 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
869 compute_fn_(compute_fn) {
870 ParseArgs<Args...>(input_types_, output_types_);
871
872 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
873 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
874 std::vector<TensorPtr> tensors;
875 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
876 context,
877 tensors,
878 kernel->api_->KernelContext_GetInputCount(context),
879 kernel->api_->KernelContext_GetOutputCount(context),
880 kernel->ep_);
881 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
882 };
883
884 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
885 auto kernel = std::make_unique<Kernel>();
886 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
887 kernel->compute_fn_ = self->compute_fn_;
888 kernel->ep_ = self->execution_provider_;
889 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
890 return reinterpret_cast<void*>(kernel.release());
891 };
892
893 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
894 delete reinterpret_cast<Kernel*>(op_kernel);
895 };
896 }
897
898 ComputeFn compute_fn_;
899};
900
901template <typename CustomOp>
902struct OrtLiteCustomStruct : public OrtLiteCustomOp {
903 template <typename... Args>
904 using CustomComputeFn = void (CustomOp::*)(Args...) const;
905 using MyType = OrtLiteCustomStruct<CustomOp>;
906
907 struct Kernel {
908 std::unique_ptr<CustomOp> custom_op_;
909 std::string ep_{};
910 std::unique_ptr<OrtW::CustomOpApi> api_;
911 };
912
913 OrtLiteCustomStruct(const char* op_name,
914 const char* execution_provider) : OrtLiteCustomOp(op_name,
915 execution_provider) {
916 init(&CustomOp::Compute);
917 }
918
919 template <typename... Args>
920 void init(CustomComputeFn<Args...>) {
921 ParseArgs<Args...>(input_types_, output_types_);
922
923 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
924 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
925 std::vector<TensorPtr> tensors;
926 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
927 context,
928 tensors,
929 kernel->api_->KernelContext_GetInputCount(context),
930 kernel->api_->KernelContext_GetOutputCount(context),
931 kernel->ep_);
932 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
933 };
934
935 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
936 auto kernel = std::make_unique<Kernel>();
937 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
938 auto self = static_cast<const MyType*>(this_);
939 kernel->ep_ = self->execution_provider_;
940 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
941 return reinterpret_cast<void*>(kernel.release());
942 };
943
944 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
945 delete reinterpret_cast<Kernel*>(op_kernel);
946 };
947 }
948};
949
950template <typename... Args>
951OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
952 const char* execution_provider,
953 void (*custom_compute_fn)(Args...)) {
954 using LiteOp = OrtLiteCustomFunc<Args...>;
955 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
956}
957
958template <typename CustomOp>
959OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
960 const char* execution_provider) {
961 using LiteOp = OrtLiteCustomStruct<CustomOp>;
962 return std::make_unique<LiteOp>(op_name, execution_provider).release();
963}
964
965} // namespace Custom
966} // namespace Ort
967