microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
a0c26255112b7a56e48abc62046477e0c7ea6152

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/custom_op_lite.h

923lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include "onnxruntime_customop.hpp"
#include <numeric>
#include <optional>
#include <utility>
8
9namespace Ort {
10namespace Custom {
11
12class TensorBase {
13 public:
14 TensorBase(const OrtW::CustomOpApi& api,
15 OrtKernelContext& ctx,
16 size_t indice,
17 bool is_input) : api_(api),
18 ctx_(ctx),
19 indice_(indice),
20 is_input_(is_input) {}
21
22 virtual ~TensorBase() = default;
23 operator bool() const {
24 return shape_.has_value();
25 }
26 const std::vector<int64_t>& Shape() const {
27 if (shape_.has_value()) {
28 return *shape_;
29 } else {
30 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
31 }
32 }
33 ONNXTensorElementDataType Type() const {
34 return type_;
35 }
36 int64_t NumberOfElement() const {
37 if (shape_.has_value()) {
38 return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
39 } else {
40 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
41 }
42 }
43 std::string Shape2Str() const {
44 if (shape_.has_value()) {
45 std::string shape_str;
46 for (const auto& dim : *shape_) {
47 shape_str.append(std::to_string(dim));
48 shape_str.append(", ");
49 }
50 return shape_str;
51 } else {
52 return "empty";
53 }
54 }
55 bool IsCpuTensor() const {
56 return strcmp("Cpu", mem_type_) == 0;
57 }
58 virtual const void* DataRaw() const = 0;
59 virtual size_t SizeInBytes() const = 0;
60
61 protected:
62 const OrtW::CustomOpApi& api_;
63 OrtKernelContext& ctx_;
64 size_t indice_;
65 bool is_input_;
66 std::optional<std::vector<int64_t>> shape_;
67 ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
68 const char* mem_type_ = "Cpu";
69};
70
// Non-owning, read-only view over `size_` contiguous elements of type T.
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};

  // Points the view at `size` elements starting at `data`; nothing is copied.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  // First element of the viewed range (null until Assign is called).
  const T* data() const { return data_; }

  // Number of elements in the view.
  size_t size() const { return size_; }

  // Returns a copy of the element at `indice`; no bounds check is performed.
  T operator[](size_t indice) const {
    return *(data_ + indice);
  }
};
85
86template <typename T>
87class Tensor : public TensorBase {
88 public:
89 using TT = typename std::remove_reference<T>::type;
90 Tensor(const OrtW::CustomOpApi& api,
91 OrtKernelContext& ctx,
92 size_t indice,
93 bool is_input) : TensorBase(api,
94 ctx,
95 indice,
96 is_input) {
97 if (is_input) {
98 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
99 if (indice >= input_count) {
100 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
101 }
102 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
103 auto* info = api_.GetTensorTypeAndShape(const_value_);
104 shape_ = api_.GetTensorShape(info);
105 type_ = api_.GetTensorElementType(info);
106 api_.ReleaseTensorTypeAndShapeInfo(info);
107 const OrtMemoryInfo* mem_info = {};
108 api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info));
109 if (mem_info) {
110 api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_));
111 }
112 }
113 }
114 const TT* Data() const {
115 return api_.GetTensorData<TT>(const_value_);
116 }
117
118 const void* DataRaw() const override {
119 return reinterpret_cast<const void*>(Data());
120 }
121
122 size_t SizeInBytes() const override {
123 return NumberOfElement() * sizeof(TT);
124 }
125
126 TT* Allocate(const std::vector<int64_t>& shape) {
127 if (!data_) {
128 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
129 shape_ = shape;
130 data_ = api_.GetTensorMutableData<TT>(out);
131 }
132 return data_;
133 }
134 const Span<T>& AsSpan() {
135 if (!shape_.has_value() || shape_->size() != 1) {
136 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
137 }
138 span_.Assign(Data(), (*shape_)[0]);
139 return span_;
140 }
141 const T& AsScalar() {
142 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
143 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
144 }
145 return *Data();
146 }
147
148 private:
149 const OrtValue* const_value_{}; // for input
150 TT* data_{}; // for output
151 Span<T> span_;
152};
153
154template <>
155class Tensor<std::string> : public TensorBase {
156 public:
157 using strings = std::vector<std::string>;
158
159 Tensor(const OrtW::CustomOpApi& api,
160 OrtKernelContext& ctx,
161 size_t indice,
162 bool is_input) : TensorBase(api,
163 ctx,
164 indice,
165 is_input) {
166 if (is_input) {
167 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
168 if (indice >= input_count) {
169 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
170 }
171
172 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
173 auto* info = api_.GetTensorTypeAndShape(const_value);
174 shape_ = api_.GetTensorShape(info);
175 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
176 api_.ReleaseTensorTypeAndShapeInfo(info);
177
178 size_t num_chars;
179 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
180 std::vector<char> chars(num_chars + 1, '\0');
181 auto num_strings = NumberOfElement();
182 std::vector<size_t> offsets(NumberOfElement());
183 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
184 (void*)chars.data(),
185 num_chars,
186 offsets.data(),
187 offsets.size()));
188 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
189 input_strings_.resize(num_strings);
190 for (int64_t i = upper_bound; i >= 0; --i) {
191 if (i < upper_bound) {
192 chars[offsets[i + 1]] = '\0';
193 }
194 input_strings_[i] = chars.data() + offsets[i];
195 }
196 }
197 }
198 const strings& Data() const {
199 return input_strings_;
200 }
201 const void* DataRaw() const override {
202 if (input_strings_.size() != 1) {
203 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
204 }
205 return reinterpret_cast<const void*>(input_strings_[0].c_str());
206 }
207 size_t SizeInBytes() const override {
208 if (input_strings_.size() != 1) {
209 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
210 }
211 return input_strings_[0].size();
212 }
213 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
214 std::vector<const char*> raw;
215 for (const auto& s : ss) {
216 raw.push_back(s.data());
217 }
218 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
219 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
220 }
221 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
222 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
223 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
224 }
225 const Span<std::string>& AsSpan() {
226 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
227 }
228 const std::string& AsScalar() {
229 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
230 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
231 }
232 return input_strings_[0];
233 }
234
235 private:
236 std::vector<std::string> input_strings_; // for input
237};
238
239template <>
240class Tensor<std::string_view> : public TensorBase {
241 public:
242 using strings = std::vector<std::string>;
243 using string_views = std::vector<std::string_view>;
244
245 Tensor(const OrtW::CustomOpApi& api,
246 OrtKernelContext& ctx,
247 size_t indice,
248 bool is_input) : TensorBase(api,
249 ctx,
250 indice,
251 is_input) {
252 if (is_input_) {
253 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
254 if (indice >= input_count) {
255 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
256 }
257 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
258 auto* info = api_.GetTensorTypeAndShape(const_value);
259 shape_ = api_.GetTensorShape(info);
260 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
261 api_.ReleaseTensorTypeAndShapeInfo(info);
262
263 size_t num_chars;
264 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
265 chars_.resize(num_chars + 1, '\0');
266
267 auto num_strings = static_cast<size_t>(NumberOfElement());
268 if (num_strings) {
269 std::vector<size_t> offsets(num_strings);
270 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
271 (void*)chars_.data(),
272 num_chars,
273 offsets.data(),
274 offsets.size()));
275 offsets.push_back(num_chars);
276 for (size_t i = 0; i < num_strings; ++i) {
277 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
278 }
279 }
280 }
281 }
282 int64_t NumberOfElement() const {
283 if (shape_.has_value()) {
284 return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
285 } else {
286 return 0;
287 }
288 }
289 const string_views& Data() const {
290 return input_string_views_;
291 }
292 const void* DataRaw() const override {
293 if (input_string_views_.size() != 1) {
294 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
295 }
296 return reinterpret_cast<const void*>(input_string_views_[0].data());
297 }
298 size_t SizeInBytes() const override {
299 if (input_string_views_.size() != 1) {
300 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
301 }
302 return input_string_views_[0].size();
303 }
304 const Span<std::string_view>& AsSpan() {
305 ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
306 }
307 std::string_view AsScalar() {
308 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
309 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
310 }
311 return input_string_views_[0];
312 }
313
314 private:
315 std::vector<char> chars_; // for input
316 std::vector<std::string_view> input_string_views_; // for input
317};
318
319using TensorPtr = std::unique_ptr<Custom::TensorBase>;
320using TensorPtrs = std::vector<TensorPtr>;
321
322// Represent variadic input or output
323struct Variadic : public TensorBase {
324 Variadic(const OrtW::CustomOpApi& api,
325 OrtKernelContext& ctx,
326 size_t indice,
327 bool is_input) : TensorBase(api,
328 ctx,
329 indice,
330 is_input) {
331#if ORT_API_VERSION < 14
332 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
333#endif
334 if (is_input) {
335 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
336 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
337 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
338 auto* info = api_.GetTensorTypeAndShape(const_value);
339 auto type = api_.GetTensorElementType(info);
340 api_.ReleaseTensorTypeAndShapeInfo(info);
341 TensorPtr tensor;
342 switch (type) {
343 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
344 tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
345 break;
346 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
347 tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
348 break;
349 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
350 tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
351 break;
352 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
353 tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
354 break;
355 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
356 tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
357 break;
358 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
359 tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
360 break;
361 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
362 tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
363 break;
364 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
365 tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
366 break;
367 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
368 tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
369 break;
370 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
371 tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
372 break;
373 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
374 tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
375 break;
376 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
377 tensor = std::make_unique<Custom::Tensor<std::string>>(api, ctx, ith_input, true);
378 break;
379 default:
380 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
381 break;
382 }
383 tensors_.emplace_back(tensor.release());
384 } // for
385 } else {
386 // a Variadic used for output is populated by the Compute so leave tensors_ empty here
387 }
388 }
389 template <typename T>
390 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
391 auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
392 auto raw_output = tensor.get()->Allocate(shape);
393 tensors_.emplace_back(tensor.release());
394 return raw_output;
395 }
396 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
397 auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
398 Tensor<std::string>& output = *tensor;
399 tensors_.emplace_back(tensor.release());
400 return output;
401 }
402 const void* DataRaw() const override {
403 ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
404 return nullptr;
405 }
406 size_t SizeInBytes() const override {
407 ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
408 return 0;
409 }
410 size_t Size() const {
411 return tensors_.size();
412 }
413 const TensorPtr& operator[](size_t indice) const {
414 return tensors_.at(indice);
415 }
416
417 private:
418 TensorPtrs tensors_;
419};
420
421struct OrtLiteCustomOp : public OrtCustomOp {
422 // CreateTuple
423 template <size_t ith_input, size_t ith_output, typename... Ts>
424 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
425 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
426 return std::make_tuple();
427 }
428
429 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
430 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
431 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
432 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
433 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
434 return std::tuple_cat(current, next);
435 }
436
437#if ORT_API_VERSION >= 14
438 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
439 static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
440 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
441 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
442 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
443 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
444 return std::tuple_cat(current, next);
445 }
446
447 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
448 static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
449 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
450 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
451 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
452 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
453 return std::tuple_cat(current, next);
454 }
455
456 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
457 static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
458 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
459 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
460 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
461 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
462 return std::tuple_cat(current, next);
463 }
464
465 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
466 static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
467 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
468 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
469 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
470 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
471 return std::tuple_cat(current, next);
472 }
473#endif
474
// Generates the CreateTuple overloads that bind an INPUT of element type `data_type`.
// One overload is emitted per accepted parameter form:
//   const Tensor<data_type>*              - pointer to the tensor wrapper
//   const Tensor<data_type>&              - reference to the tensor wrapper
//   std::optional<const Tensor<...>*>     - empty when ith_input >= num_input
//   const Span<data_type>*                - 1-D view, CPU tensors only
//   const Span<data_type>&                - 1-D view, CPU tensors only
//   std::optional<const Span<...>*>       - empty when ith_input >= num_input
//   data_type                             - scalar value, CPU tensors only
//   std::optional<data_type>              - empty when ith_input >= num_input
// Every overload that creates a wrapper stores it in `tensors`, which owns it, and then
// recurses with ith_input + 1. Wrappers are recovered from the TensorBase pointer with
// static_cast, the proper cast for a base-to-derived conversion.
// (Comments live above the #define because a // comment inside would swallow the
// line continuations.)
#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      if (!tensors.back()->IsCpuTensor()) { \
        ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
      } \
      std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      if (!tensors.back()->IsCpuTensor()) { \
        ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
      } \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
// Generates the CreateTuple overloads that bind an OUTPUT of element type `data_type`.
// One overload is emitted per accepted parameter form:
//   Tensor<data_type>*                 - pointer to the tensor wrapper
//   Tensor<data_type>&                 - reference to the tensor wrapper
//   std::optional<Tensor<...>*>        - empty when ith_output >= num_output
// Every overload that creates a wrapper stores it in `tensors`, which owns it, and then
// recurses with ith_output + 1. Wrappers are recovered from the TensorBase pointer with
// static_cast, the proper cast for a base-to-derived conversion.
// (Comments live above the #define because a // comment inside would swallow the
// line continuations.)
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
// Expands to both the input-side and the output-side CreateTuple overloads
// for one element type.
#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)

  // Emit the CreateTuple overloads for every element type listed here.
  // std::string_view only gets the input-side overloads.
  CREATE_TUPLE(bool)
  CREATE_TUPLE(float)
  CREATE_TUPLE(double)
  CREATE_TUPLE(int8_t)
  CREATE_TUPLE(int16_t)
  CREATE_TUPLE(int32_t)
  CREATE_TUPLE(int64_t)
  CREATE_TUPLE(uint8_t)
  CREATE_TUPLE(uint16_t)
  CREATE_TUPLE(uint32_t)
  CREATE_TUPLE(uint64_t)
  CREATE_TUPLE(std::string)
  CREATE_TUPLE_INPUT(std::string_view)
621
622 // ParseArgs ...
623 template <typename... Ts>
624 static typename std::enable_if<0 == sizeof...(Ts)>::type
625 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
626 }
627
628 template <typename T, typename... Ts>
629 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
630 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
631 ParseArgs<Ts...>(input_types, output_types);
632 }
633
634#if ORT_API_VERSION >= 14
635 template <typename T, typename... Ts>
636 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
637 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
638 if (!input_types.empty()) {
639 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
640 }
641 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
642 ParseArgs<Ts...>(input_types, output_types);
643 }
644
645 template <typename T, typename... Ts>
646 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
647 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
648 if (!input_types.empty()) {
649 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
650 }
651 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
652 ParseArgs<Ts...>(input_types, output_types);
653 }
654
655 template <typename T, typename... Ts>
656 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
657 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
658 if (!output_types.empty()) {
659 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
660 }
661 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
662 ParseArgs<Ts...>(input_types, output_types);
663 }
664
665 template <typename T, typename... Ts>
666 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
667 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
668 if (!output_types.empty()) {
669 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
670 }
671 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
672 ParseArgs<Ts...>(input_types, output_types);
673 }
674#endif
675
// Generates two ParseArgs overloads for one input parameter form: one for `pack_type`
// itself and one for std::optional<pack_type>. Both append `onnx_type` to `input_types`
// and continue with the remaining parameter types.
// The enable_if condition is just the is_same test; the former `0 <= sizeof...(Ts)`
// term was always true for an unsigned count.
// (Comments live above the #define because a // comment inside would swallow the
// line continuations.)
#define PARSE_INPUT_BASE(pack_type, onnx_type) \
  template <typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, pack_type>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<pack_type>>::value>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    input_types.push_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }
689
// Emits the input-side ParseArgs overloads for one element type: the scalar
// `data_type` itself plus const reference and const pointer forms of
// Custom::Tensor<data_type> and Custom::Span<data_type>. Each form is expanded
// through PARSE_INPUT_BASE, so its std::optional<> variant is covered as well.
#define PARSE_INPUT(data_type, onnx_type) \
  PARSE_INPUT_BASE(data_type, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type)
696
// Emits the output-side ParseArgs overloads for one element type. An output
// may be declared as Custom::Tensor<data_type>*, Custom::Tensor<data_type>&
// or std::optional<Custom::Tensor<data_type>*>; every form appends
// `onnx_type` to the output type list and recurses on the remaining types.
#define PARSE_OUTPUT(data_type, onnx_type) \
  template <typename T, typename... Ts> \
  static std::enable_if_t<std::is_same<T, Custom::Tensor<data_type>*>::value && 0 <= sizeof...(Ts)> \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.emplace_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static std::enable_if_t<std::is_same<T, Custom::Tensor<data_type>&>::value && 0 <= sizeof...(Ts)> \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.emplace_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  } \
  template <typename T, typename... Ts> \
  static std::enable_if_t<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value && 0 <= sizeof...(Ts)> \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
    output_types.emplace_back(onnx_type); \
    ParseArgs<Ts...>(input_types, output_types); \
  }
716
// Emits every ParseArgs overload (output forms and input forms) that maps the
// C++ element type `data_type` to the ONNX element type `onnx_type`.
#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type) \
  PARSE_INPUT(data_type, onnx_type)
720
721 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
722 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
723 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
724 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
725 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
726 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
727 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
728 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
729 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
730 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
731 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
732 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
733 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
734
735 OrtLiteCustomOp(const char* op_name,
736 const char* execution_provider) : op_name_(op_name),
737 execution_provider_(execution_provider) {
738 int act_ver = GetActiveOrtAPIVersion();
739 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
740
741 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
742 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
743
744 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
745 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
746 return self->input_types_.size();
747 };
748
749 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
750 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
751 return self->input_types_[indice];
752 };
753
754 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
755 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
756 return self->output_types_.size();
757 };
758
759 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
760 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
761 return self->output_types_[indice];
762 };
763
764#if ORT_API_VERSION >= 14
765 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
766 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
767 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
768 };
769
770 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
771 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
772 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
773 };
774
775 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
776 return 1;
777 };
778
779 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
780 return 0;
781 };
782
783 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
784 return 1;
785 };
786
787 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
788 return 0;
789 };
790
791 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
792 return OrtMemTypeDefault;
793 };
794#else
795 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
796 return INPUT_OUTPUT_OPTIONAL;
797 };
798
799 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
800 return INPUT_OUTPUT_OPTIONAL;
801 };
802#endif
803 }
804
805 const std::string op_name_;
806 const std::string execution_provider_;
807
808 std::vector<ONNXTensorElementDataType> input_types_;
809 std::vector<ONNXTensorElementDataType> output_types_;
810};
811
812template <typename... Args>
813struct OrtLiteCustomFunc : public OrtLiteCustomOp {
814 using ComputeFn = void (*)(Args...);
815 using MyType = OrtLiteCustomFunc<Args...>;
816
817 struct Kernel {
818 ComputeFn compute_fn_{};
819 std::string ep_{};
820 std::unique_ptr<OrtW::CustomOpApi> api_;
821 };
822
823 OrtLiteCustomFunc(const char* op_name,
824 const char* execution_provider,
825 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
826 compute_fn_(compute_fn) {
827 ParseArgs<Args...>(input_types_, output_types_);
828
829 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
830 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
831 std::vector<TensorPtr> tensors;
832 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
833 context,
834 tensors,
835 kernel->api_->KernelContext_GetInputCount(context),
836 kernel->api_->KernelContext_GetOutputCount(context),
837 kernel->ep_);
838 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
839 };
840
841 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
842 auto kernel = std::make_unique<Kernel>();
843 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
844 kernel->compute_fn_ = self->compute_fn_;
845 kernel->ep_ = self->execution_provider_;
846 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
847 return reinterpret_cast<void*>(kernel.release());
848 };
849
850 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
851 delete reinterpret_cast<Kernel*>(op_kernel);
852 };
853 }
854
855 ComputeFn compute_fn_;
856};
857
858template <typename CustomOp>
859struct OrtLiteCustomStruct : public OrtLiteCustomOp {
860 template <typename... Args>
861 using CustomComputeFn = void (CustomOp::*)(Args...) const;
862 using MyType = OrtLiteCustomStruct<CustomOp>;
863
864 struct Kernel {
865 std::unique_ptr<CustomOp> custom_op_;
866 std::string ep_{};
867 std::unique_ptr<OrtW::CustomOpApi> api_;
868 };
869
870 OrtLiteCustomStruct(const char* op_name,
871 const char* execution_provider) : OrtLiteCustomOp(op_name,
872 execution_provider) {
873 init(&CustomOp::Compute);
874 }
875
876 template <typename... Args>
877 void init(CustomComputeFn<Args...>) {
878 ParseArgs<Args...>(input_types_, output_types_);
879
880 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
881 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
882 std::vector<TensorPtr> tensors;
883 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
884 context,
885 tensors,
886 kernel->api_->KernelContext_GetInputCount(context),
887 kernel->api_->KernelContext_GetOutputCount(context),
888 kernel->ep_);
889 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
890 };
891
892 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
893 auto kernel = std::make_unique<Kernel>();
894 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
895 auto self = static_cast<const MyType*>(this_);
896 kernel->ep_ = self->execution_provider_;
897 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
898 return reinterpret_cast<void*>(kernel.release());
899 };
900
901 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
902 delete reinterpret_cast<Kernel*>(op_kernel);
903 };
904 }
905};
906
907template <typename... Args>
908OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
909 const char* execution_provider,
910 void (*custom_compute_fn)(Args...)) {
911 using LiteOp = OrtLiteCustomFunc<Args...>;
912 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
913}
914
915template <typename CustomOp>
916OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
917 const char* execution_provider) {
918 using LiteOp = OrtLiteCustomStruct<CustomOp>;
919 return std::make_unique<LiteOp>(op_name, execution_provider).release();
920}
921
922} // namespace Custom
923} // namespace Ort
924