microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
e3d9198de801fe80ec3896063e016f9db8cf2be2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/custom_op_lite.h

913 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include "onnxruntime_customop.hpp"
6#include <optional>
7#include <numeric>
8
9namespace Ort {
10namespace Custom {
11
12class TensorBase {
13 public:
14 TensorBase(const OrtW::CustomOpApi& api,
15 OrtKernelContext& ctx,
16 size_t indice,
17 bool is_input) : api_(api),
18 ctx_(ctx),
19 indice_(indice),
20 is_input_(is_input) {}
21
22 virtual ~TensorBase() = default;
23 operator bool() const {
24 return shape_.has_value();
25 }
26 const std::vector<int64_t>& Shape() const {
27 if (shape_.has_value()) {
28 return *shape_;
29 } else {
30 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
31 }
32 }
33 ONNXTensorElementDataType Type() const {
34 return type_;
35 }
36 int64_t NumberOfElement() const {
37 if (shape_.has_value()) {
38 return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
39 } else {
40 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
41 }
42 }
43 std::string Shape2Str() const {
44 if (shape_.has_value()) {
45 std::string shape_str;
46 for (const auto& dim: *shape_) {
47 shape_str.append(std::to_string(dim));
48 shape_str.append(", ");
49 }
50 return shape_str;
51 } else {
52 return "empty";
53 }
54 }
55 virtual const void* DataRaw() const = 0;
56 virtual size_t SizeInBytes() const = 0;
57
58 protected:
59 const OrtW::CustomOpApi& api_;
60 OrtKernelContext& ctx_;
61 size_t indice_;
62 bool is_input_;
63 std::optional<std::vector<int64_t>> shape_;
64 ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
65};
66
/// Non-owning, read-only view over `size()` contiguous `T` elements.
///
/// The span only stores a raw pointer, so the viewed memory must outlive it.
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};

  /// Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }

  size_t size() const { return size_; }

  /// True when the view covers no elements.
  bool empty() const { return size_ == 0; }

  /// Element access by value; no bounds checking is performed.
  T operator[](size_t indice) const {
    return data_[indice];
  }

  const T* data() const { return data_; }

  /// Iterator access so a span can be used with range-for and <algorithm>.
  const T* begin() const { return data_; }
  const T* end() const { return data_ + size_; }
};
81
82template <typename T>
83class Tensor : public TensorBase {
84 public:
85 using TT = typename std::remove_reference<T>::type;
86 Tensor(const OrtW::CustomOpApi& api,
87 OrtKernelContext& ctx,
88 size_t indice,
89 bool is_input) : TensorBase(api,
90 ctx,
91 indice,
92 is_input) {
93 if (is_input) {
94 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
95 if (indice >= input_count) {
96 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
97 }
98 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
99 auto* info = api_.GetTensorTypeAndShape(const_value_);
100 shape_ = api_.GetTensorShape(info);
101 type_ = api_.GetTensorElementType(info);
102 api_.ReleaseTensorTypeAndShapeInfo(info);
103 }
104 }
105 const TT* Data() const {
106 return api_.GetTensorData<TT>(const_value_);
107 }
108
109 const void* DataRaw() const override {
110 return reinterpret_cast<const void*>(Data());
111 }
112
113 size_t SizeInBytes() const override {
114 return NumberOfElement() * sizeof(TT);
115 }
116
117 TT* Allocate(const std::vector<int64_t>& shape) {
118 if (!data_) {
119 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
120 shape_ = shape;
121 data_ = api_.GetTensorMutableData<TT>(out);
122 }
123 return data_;
124 }
125 const Span<T>& AsSpan() {
126 if (!shape_.has_value() || shape_->size() != 1) {
127 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
128 }
129 span_.Assign(Data(), (*shape_)[0]);
130 return span_;
131 }
132 const T& AsScalar() {
133 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
134 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
135 }
136 return *Data();
137 }
138
139 private:
140 const OrtValue* const_value_{}; // for input
141 TT* data_{}; // for output
142 Span<T> span_;
143};
144
145template <>
146class Tensor<std::string> : public TensorBase {
147 public:
148 using strings = std::vector<std::string>;
149
150 Tensor(const OrtW::CustomOpApi& api,
151 OrtKernelContext& ctx,
152 size_t indice,
153 bool is_input) : TensorBase(api,
154 ctx,
155 indice,
156 is_input) {
157 if (is_input) {
158 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
159 if (indice >= input_count) {
160 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
161 }
162
163 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
164 auto* info = api_.GetTensorTypeAndShape(const_value);
165 shape_ = api_.GetTensorShape(info);
166 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
167 api_.ReleaseTensorTypeAndShapeInfo(info);
168
169 size_t num_chars;
170 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
171 std::vector<char> chars(num_chars + 1, '\0');
172 auto num_strings = NumberOfElement();
173 std::vector<size_t> offsets(NumberOfElement());
174 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
175 (void*)chars.data(),
176 num_chars,
177 offsets.data(),
178 offsets.size()));
179 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
180 input_strings_.resize(num_strings);
181 for (int64_t i = upper_bound; i >= 0; --i) {
182 if (i < upper_bound) {
183 chars[offsets[i + 1]] = '\0';
184 }
185 input_strings_[i] = chars.data() + offsets[i];
186 }
187 }
188 }
189 const strings& Data() const {
190 return input_strings_;
191 }
192 const void* DataRaw() const override {
193 if (input_strings_.size() != 1) {
194 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
195 }
196 return reinterpret_cast<const void*>(input_strings_[0].c_str());
197 }
198 size_t SizeInBytes() const override {
199 if (input_strings_.size() != 1) {
200 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
201 }
202 return input_strings_[0].size();
203 }
204 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
205 std::vector<const char*> raw;
206 for (const auto& s : ss) {
207 raw.push_back(s.data());
208 }
209 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
210 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
211 }
212 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
213 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
214 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
215 }
216 const Span<std::string>& AsSpan() {
217 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
218 }
219 const std::string& AsScalar() {
220 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
221 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
222 }
223 return input_strings_[0];
224 }
225
226 private:
227 std::vector<std::string> input_strings_; // for input
228};
229
230template <>
231class Tensor<std::string_view> : public TensorBase {
232 public:
233 using strings = std::vector<std::string>;
234 using string_views = std::vector<std::string_view>;
235
236 Tensor(const OrtW::CustomOpApi& api,
237 OrtKernelContext& ctx,
238 size_t indice,
239 bool is_input) : TensorBase(api,
240 ctx,
241 indice,
242 is_input) {
243 if (is_input_) {
244 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
245 if (indice >= input_count) {
246 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
247 }
248 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
249 auto* info = api_.GetTensorTypeAndShape(const_value);
250 shape_ = api_.GetTensorShape(info);
251 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
252 api_.ReleaseTensorTypeAndShapeInfo(info);
253
254 size_t num_chars;
255 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
256 chars_.resize(num_chars + 1, '\0');
257
258 auto num_strings = static_cast<size_t>(NumberOfElement());
259 if (num_strings) {
260 std::vector<size_t> offsets(num_strings);
261 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
262 (void*)chars_.data(),
263 num_chars,
264 offsets.data(),
265 offsets.size()));
266 offsets.push_back(num_chars);
267 for (size_t i = 0; i < num_strings; ++i) {
268 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
269 }
270 }
271 }
272 }
273 int64_t NumberOfElement() const {
274 if (shape_.has_value()) {
275 return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
276 } else {
277 return 0;
278 }
279 }
280 const string_views& Data() const {
281 return input_string_views_;
282 }
283 const void* DataRaw() const override {
284 if (input_string_views_.size() != 1) {
285 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
286 }
287 return reinterpret_cast<const void*>(input_string_views_[0].data());
288 }
289 size_t SizeInBytes() const override {
290 if (input_string_views_.size() != 1) {
291 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
292 }
293 return input_string_views_[0].size();
294 }
295 const Span<std::string_view>& AsSpan() {
296 ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
297 }
298 std::string_view AsScalar() {
299 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
300 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
301 }
302 return input_string_views_[0];
303 }
304
305 private:
306 std::vector<char> chars_; // for input
307 std::vector<std::string_view> input_string_views_; // for input
308};
309
310using TensorPtr = std::unique_ptr<Custom::TensorBase>;
311using TensorPtrs = std::vector<TensorPtr>;
312
313// Represent variadic input or output
314struct Variadic : public TensorBase {
315 Variadic(const OrtW::CustomOpApi& api,
316 OrtKernelContext& ctx,
317 size_t indice,
318 bool is_input) : TensorBase(api,
319 ctx,
320 indice,
321 is_input) {
322#if ORT_API_VERSION < 14
323 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
324#endif
325 if (is_input) {
326 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
327 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
328 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
329 auto* info = api_.GetTensorTypeAndShape(const_value);
330 auto type = api_.GetTensorElementType(info);
331 api_.ReleaseTensorTypeAndShapeInfo(info);
332 TensorPtr tensor;
333 switch (type) {
334 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
335 tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
336 break;
337 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
338 tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
339 break;
340 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
341 tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
342 break;
343 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
344 tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
345 break;
346 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
347 tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
348 break;
349 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
350 tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
351 break;
352 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
353 tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
354 break;
355 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
356 tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
357 break;
358 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
359 tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
360 break;
361 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
362 tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
363 break;
364 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
365 tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
366 break;
367 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
368 tensor = std::make_unique<Custom::Tensor<std::string_view>>(api, ctx, ith_input, true);
369 break;
370 default:
371 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
372 break;
373 }
374 tensors_.emplace_back(tensor.release());
375 } // for
376 }
377 }
378 template<typename T>
379 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
380 auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
381 auto raw_output = tensor.get()->Allocate(shape);
382 tensors_.emplace_back(tensor.release());
383 return raw_output;
384 }
385 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
386 auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
387 Tensor<std::string>& output = *tensor;
388 tensors_.emplace_back(tensor.release());
389 return output;
390 }
391 const void* DataRaw() const override {
392 ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
393 return nullptr;
394 }
395 size_t SizeInBytes() const override {
396 ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
397 return 0;
398 }
399 size_t Size() const {
400 return tensors_.size();
401 }
402 const TensorPtr& operator[](size_t indice) const {
403 return tensors_.at(indice);
404 }
405 private:
406 TensorPtrs tensors_;
407};
408
409struct OrtLiteCustomOp : public OrtCustomOp {
410 using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
411 using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
412
413 // CreateTuple
414 template <size_t ith_input, size_t ith_output, typename... Ts>
415 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
416 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
417 return std::make_tuple();
418 }
419
420 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
421 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
422 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
423 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
424 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
425 return std::tuple_cat(current, next);
426 }
427
428#if ORT_API_VERSION >= 14
429 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
430 static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
431 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
432 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
433 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
434 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
435 return std::tuple_cat(current, next);
436 }
437
438 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
439 static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
440 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
441 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
442 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
443 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
444 return std::tuple_cat(current, next);
445 }
446
447 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
448 static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
449 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
450 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
451 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
452 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
453 return std::tuple_cat(current, next);
454 }
455
456 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
457 static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
458 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
459 tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
460 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
461 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
462 return std::tuple_cat(current, next);
463 }
464#endif
465
466#define CREATE_TUPLE_INPUT(data_type) \
467 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
468 static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
469 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
470 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
471 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
472 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
473 return std::tuple_cat(current, next); \
474 } \
475 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
476 static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
477 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
478 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
479 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
480 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
481 return std::tuple_cat(current, next); \
482 } \
483 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
484 static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
485 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
486 if (ith_input < num_input) { \
487 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
488 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
489 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
490 return std::tuple_cat(current, next); \
491 } else { \
492 std::tuple<T> current = std::tuple<T>{}; \
493 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
494 return std::tuple_cat(current, next); \
495 } \
496 } \
497 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
498 static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
499 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
500 if ("CPUExecutionProvider" != ep) { \
501 ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
502 } \
503 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
504 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
505 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
506 return std::tuple_cat(current, next); \
507 } \
508 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
509 static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
510 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
511 if ("CPUExecutionProvider" != ep) { \
512 ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
513 } \
514 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
515 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
516 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
517 return std::tuple_cat(current, next); \
518 } \
519 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
520 static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
521 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
522 if (ith_input < num_input) { \
523 if ("CPUExecutionProvider" != ep) { \
524 ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
525 } \
526 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
527 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
528 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
529 return std::tuple_cat(current, next); \
530 } else { \
531 std::tuple<T> current = std::tuple<T>{}; \
532 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
533 return std::tuple_cat(current, next); \
534 } \
535 } \
536 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
537 static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
538 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
539 if ("CPUExecutionProvider" != ep) { \
540 ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
541 } \
542 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
543 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
544 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
545 return std::tuple_cat(current, next); \
546 } \
547 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
548 static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
549 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
550 if (ith_input < num_input) { \
551 if ("CPUExecutionProvider" != ep) { \
552 ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
553 } \
554 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
555 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
556 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
557 return std::tuple_cat(current, next); \
558 } else { \
559 std::tuple<T> current = std::tuple<T>{}; \
560 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
561 return std::tuple_cat(current, next); \
562 } \
563 }
564#define CREATE_TUPLE_OUTPUT(data_type) \
565 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
566 static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
567 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
568 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
569 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
570 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
571 return std::tuple_cat(current, next); \
572 } \
573 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
574 static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
575 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
576 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
577 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
578 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
579 return std::tuple_cat(current, next); \
580 } \
581 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
582 static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
583 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
584 if (ith_output < num_output) { \
585 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
586 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
587 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
588 return std::tuple_cat(current, next); \
589 } else { \
590 std::tuple<T> current = std::tuple<T>{}; \
591 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
592 return std::tuple_cat(current, next); \
593 } \
594 }
595#define CREATE_TUPLE(data_type) \
596 CREATE_TUPLE_INPUT(data_type) \
597 CREATE_TUPLE_OUTPUT(data_type)
598
599 CREATE_TUPLE(bool)
600 CREATE_TUPLE(float)
601 CREATE_TUPLE(double)
602 CREATE_TUPLE(int8_t)
603 CREATE_TUPLE(int16_t)
604 CREATE_TUPLE(int32_t)
605 CREATE_TUPLE(int64_t)
606 CREATE_TUPLE(uint8_t)
607 CREATE_TUPLE(uint16_t)
608 CREATE_TUPLE(uint32_t)
609 CREATE_TUPLE(uint64_t)
610 CREATE_TUPLE(std::string)
611 CREATE_TUPLE_INPUT(std::string_view)
612
613 // ParseArgs ...
614 template <typename... Ts>
615 static typename std::enable_if<0 == sizeof...(Ts)>::type
616 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
617 }
618
619 template <typename T, typename... Ts>
620 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
621 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
622 ParseArgs<Ts...>(input_types, output_types);
623 }
624
625#if ORT_API_VERSION >= 14
626 template <typename T, typename... Ts>
627 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic&>::value>::type
628 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
629 if (!input_types.empty()) {
630 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
631 }
632 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
633 ParseArgs<Ts...>(input_types, output_types);
634 }
635
636 template <typename T, typename... Ts>
637 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, const Variadic*>::value>::type
638 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
639 if (!input_types.empty()) {
640 ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
641 }
642 input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
643 ParseArgs<Ts...>(input_types, output_types);
644 }
645
646 template <typename T, typename... Ts>
647 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic&>::value>::type
648 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
649 if (!output_types.empty()) {
650 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
651 }
652 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
653 ParseArgs<Ts...>(input_types, output_types);
654 }
655
656 template <typename T, typename... Ts>
657 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Variadic*>::value>::type
658 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
659 if (!output_types.empty()) {
660 ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
661 }
662 output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
663 ParseArgs<Ts...>(input_types, output_types);
664 }
665#endif
666
// Declares one ParseArgs overload that consumes the leading parameter type T
// when it is exactly `pack_type` or `std::optional<pack_type>`. The overload
// appends `onnx_type` to the input element types and recurses on the rest.
#define PARSE_INPUT_BASE(pack_type, onnx_type)                                              \
  template <typename T, typename... Ts>                                                     \
  static typename std::enable_if<0 <= sizeof...(Ts) &&                                      \
                                 (std::is_same<T, pack_type>::value ||                      \
                                  std::is_same<T, std::optional<pack_type>>::value)>::type  \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                            \
            std::vector<ONNXTensorElementDataType>& output_types) {                         \
    input_types.push_back(onnx_type);                                                       \
    ParseArgs<Ts...>(input_types, output_types);                                            \
  }
680
// Declares the input-side ParseArgs overloads for element type `data_type`.
// Accepted forms are:
//   - a plain value of `data_type`;
//   - a const Tensor, by reference or by pointer;
//   - a const Span, by reference or by pointer.
// Each form is also accepted wrapped in std::optional, via PARSE_INPUT_BASE.
#define PARSE_INPUT(data_type, onnx_type)                           \
  PARSE_INPUT_BASE(data_type, onnx_type)                            \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type)     \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type)     \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type)       \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type)
687
// Declares one ParseArgs overload that consumes the leading parameter type T
// when it is a mutable output tensor of `data_type`. Accepted forms are
// `Tensor*`, `Tensor&` and `std::optional<Tensor*>`. The overload appends
// `onnx_type` to the output element types and recurses on the rest.
#define PARSE_OUTPUT(data_type, onnx_type)                                                              \
  template <typename T, typename... Ts>                                                                 \
  static typename std::enable_if<0 <= sizeof...(Ts) &&                                                  \
                                 (std::is_same<T, Custom::Tensor<data_type>*>::value ||                 \
                                  std::is_same<T, Custom::Tensor<data_type>&>::value ||                 \
                                  std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value)>::type \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                                        \
            std::vector<ONNXTensorElementDataType>& output_types) {                                     \
    output_types.push_back(onnx_type);                                                                  \
    ParseArgs<Ts...>(input_types, output_types);                                                        \
  }
707
// Declares every ParseArgs overload, input and output side, for one element
// type. The two expansions declare independent overloads, so their order does not matter.
#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)     \
  PARSE_INPUT(data_type, onnx_type)
711
712 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
713 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
714 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
715 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
716 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
717 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
718 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
719 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
720 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
721 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
722 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
723 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
724 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
725
726 OrtLiteCustomOp(const char* op_name,
727 const char* execution_provider) : op_name_(op_name),
728 execution_provider_(execution_provider) {
729 OrtCustomOp::version = GetActiveOrtAPIVersion();
730
731 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
732 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
733
734 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
735 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
736 return self->input_types_.size();
737 };
738
739 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
740 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
741 return self->input_types_[indice];
742 };
743
744 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
745 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
746 return self->output_types_.size();
747 };
748
749 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
750 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
751 return self->output_types_[indice];
752 };
753
754#if ORT_API_VERSION >= 14
755 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
756 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
757 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
758 };
759
760 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
761 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
762 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
763 };
764
765 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
766 return 1;
767 };
768
769 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
770 return 0;
771 };
772
773 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
774 return 1;
775 };
776
777 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
778 return 0;
779 };
780
781 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
782 return OrtMemTypeDefault;
783 };
784#else
785 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
786 return INPUT_OUTPUT_OPTIONAL;
787 };
788
789 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
790 return INPUT_OUTPUT_OPTIONAL;
791 };
792#endif
793 }
794
795 const std::string op_name_;
796 const std::string execution_provider_;
797
798 std::vector<ONNXTensorElementDataType> input_types_;
799 std::vector<ONNXTensorElementDataType> output_types_;
800};
801
802template <typename... Args>
803struct OrtLiteCustomFunc : public OrtLiteCustomOp {
804 using ComputeFn = void (*)(Args...);
805 using MyType = OrtLiteCustomFunc<Args...>;
806
807 struct Kernel {
808 ComputeFn compute_fn_{};
809 std::string ep_{};
810 std::unique_ptr<OrtW::CustomOpApi> api_;
811 };
812
813 OrtLiteCustomFunc(const char* op_name,
814 const char* execution_provider,
815 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
816 compute_fn_(compute_fn) {
817 ParseArgs<Args...>(input_types_, output_types_);
818
819 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
820 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
821 std::vector<TensorPtr> tensors;
822 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
823 context,
824 tensors,
825 kernel->api_->KernelContext_GetInputCount(context),
826 kernel->api_->KernelContext_GetOutputCount(context),
827 kernel->ep_);
828 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
829 };
830
831 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
832 auto kernel = std::make_unique<Kernel>();
833 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
834 kernel->compute_fn_ = self->compute_fn_;
835 kernel->ep_ = self->execution_provider_;
836 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
837 return reinterpret_cast<void*>(kernel.release());
838 };
839
840 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
841 delete reinterpret_cast<Kernel*>(op_kernel);
842 };
843 }
844
845 ComputeFn compute_fn_;
846};
847
848template <typename CustomOp>
849struct OrtLiteCustomStruct : public OrtLiteCustomOp {
850 template <typename... Args>
851 using CustomComputeFn = void (CustomOp::*)(Args...);
852 using MyType = OrtLiteCustomStruct<CustomOp>;
853
854 struct Kernel {
855 std::unique_ptr<CustomOp> custom_op_;
856 std::string ep_{};
857 std::unique_ptr<OrtW::CustomOpApi> api_;
858 };
859
860 OrtLiteCustomStruct(const char* op_name,
861 const char* execution_provider) : OrtLiteCustomOp(op_name,
862 execution_provider) {
863 init(&CustomOp::Compute);
864 }
865
866 template <typename... Args>
867 void init(CustomComputeFn<Args...>) {
868 ParseArgs<Args...>(input_types_, output_types_);
869
870 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
871 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
872 std::vector<TensorPtr> tensors;
873 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
874 context,
875 tensors,
876 kernel->api_->KernelContext_GetInputCount(context),
877 kernel->api_->KernelContext_GetOutputCount(context),
878 kernel->ep_);
879 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
880 };
881
882 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
883 auto kernel = std::make_unique<Kernel>();
884 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
885 auto self = static_cast<const MyType*>(this_);
886 kernel->ep_ = self->execution_provider_;
887 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
888 return reinterpret_cast<void*>(kernel.release());
889 };
890
891 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
892 delete reinterpret_cast<Kernel*>(op_kernel);
893 };
894 }
895};
896
897template <typename... Args>
898OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
899 const char* execution_provider,
900 void (*custom_compute_fn)(Args...)) {
901 using LiteOp = OrtLiteCustomFunc<Args...>;
902 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
903}
904
905template <typename CustomOp>
906OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
907 const char* execution_provider) {
908 using LiteOp = OrtLiteCustomStruct<CustomOp>;
909 return std::make_unique<LiteOp>(op_name, execution_provider).release();
910}
911
912} // namespace Custom
913} // namespace Ort
914