microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
62d8598b6b9fa462a440ade891017eaafd4bfaee

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/custom_op_lite.h

921lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include "onnxruntime_customop.hpp"
#include <numeric>
#include <optional>
#include <utility>
8
9namespace Ort {
10namespace Custom {
11
12class TensorBase {
13 public:
14 TensorBase(const OrtW::CustomOpApi& api,
15 OrtKernelContext& ctx,
16 size_t indice,
17 bool is_input) : api_(api),
18 ctx_(ctx),
19 indice_(indice),
20 is_input_(is_input) {}
21
22 virtual ~TensorBase() = default;
23 operator bool() const {
24 return shape_.has_value();
25 }
26 const std::vector<int64_t>& Shape() const {
27 if (shape_.has_value()) {
28 return *shape_;
29 } else {
30 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
31 }
32 }
33 ONNXTensorElementDataType Type() const {
34 return type_;
35 }
36 int64_t NumberOfElement() const {
37 if (shape_.has_value()) {
38 return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
39 } else {
40 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
41 }
42 }
43 std::string Shape2Str() const {
44 if (shape_.has_value()) {
45 std::string shape_str;
46 for (const auto& dim : *shape_) {
47 shape_str.append(std::to_string(dim));
48 shape_str.append(", ");
49 }
50 return shape_str;
51 } else {
52 return "empty";
53 }
54 }
55 bool IsCpuTensor() const {
56 return strcmp("Cpu", mem_type_) == 0;
57 }
58 virtual const void* DataRaw() const = 0;
59 virtual size_t SizeInBytes() const = 0;
60
61 protected:
62 const OrtW::CustomOpApi& api_;
63 OrtKernelContext& ctx_;
64 size_t indice_;
65 bool is_input_;
66 std::optional<std::vector<int64_t>> shape_;
67 ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
68 const char* mem_type_ = "Cpu";
69};
70
// Minimal non-owning, read-only view over `size_` contiguous elements of type T.
// The viewed memory must outlive the Span; no bounds checking is performed.
template <typename T>
struct Span {
  const T* data_ = nullptr;  // first element of the viewed range
  size_t size_ = 0;          // number of elements in the viewed range

  // Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  // Pointer to the first viewed element (nullptr for a default-constructed Span).
  const T* data() const { return data_; }

  // Number of viewed elements.
  size_t size() const { return size_; }

  // Returns a copy of the element at `indice` (unchecked).
  T operator[](size_t indice) const {
    return *(data_ + indice);
  }
};
85
86template <typename T>
87class Tensor : public TensorBase {
88 public:
89 using TT = typename std::remove_reference<T>::type;
90 Tensor(const OrtW::CustomOpApi& api,
91 OrtKernelContext& ctx,
92 size_t indice,
93 bool is_input) : TensorBase(api,
94 ctx,
95 indice,
96 is_input) {
97 if (is_input) {
98 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
99 if (indice >= input_count) {
100 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
101 }
102 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
103 auto* info = api_.GetTensorTypeAndShape(const_value_);
104 shape_ = api_.GetTensorShape(info);
105 type_ = api_.GetTensorElementType(info);
106 api_.ReleaseTensorTypeAndShapeInfo(info);
107 const OrtMemoryInfo* mem_info = {};
108 api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info);
109 if (mem_info) {
110 api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_);
111 }
112 }
113 }
114 const TT* Data() const {
115 return api_.GetTensorData<TT>(const_value_);
116 }
117
118 const void* DataRaw() const override {
119 return reinterpret_cast<const void*>(Data());
120 }
121
122 size_t SizeInBytes() const override {
123 return NumberOfElement() * sizeof(TT);
124 }
125
126 TT* Allocate(const std::vector<int64_t>& shape) {
127 if (!data_) {
128 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
129 shape_ = shape;
130 data_ = api_.GetTensorMutableData<TT>(out);
131 }
132 return data_;
133 }
134 const Span<T>& AsSpan() {
135 if (!shape_.has_value() || shape_->size() != 1) {
136 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
137 }
138 span_.Assign(Data(), (*shape_)[0]);
139 return span_;
140 }
141 const T& AsScalar() {
142 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
143 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
144 }
145 return *Data();
146 }
147
148 private:
149 const OrtValue* const_value_{}; // for input
150 TT* data_{}; // for output
151 Span<T> span_;
152};
153
154template <>
155class Tensor<std::string> : public TensorBase {
156 public:
157 using strings = std::vector<std::string>;
158
159 Tensor(const OrtW::CustomOpApi& api,
160 OrtKernelContext& ctx,
161 size_t indice,
162 bool is_input) : TensorBase(api,
163 ctx,
164 indice,
165 is_input) {
166 if (is_input) {
167 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
168 if (indice >= input_count) {
169 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
170 }
171
172 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
173 auto* info = api_.GetTensorTypeAndShape(const_value);
174 shape_ = api_.GetTensorShape(info);
175 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
176 api_.ReleaseTensorTypeAndShapeInfo(info);
177
178 size_t num_chars;
179 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
180 std::vector<char> chars(num_chars + 1, '\0');
181 auto num_strings = NumberOfElement();
182 std::vector<size_t> offsets(NumberOfElement());
183 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
184 (void*)chars.data(),
185 num_chars,
186 offsets.data(),
187 offsets.size()));
188 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
189 input_strings_.resize(num_strings);
190 for (int64_t i = upper_bound; i >= 0; --i) {
191 if (i < upper_bound) {
192 chars[offsets[i + 1]] = '\0';
193 }
194 input_strings_[i] = chars.data() + offsets[i];
195 }
196 }
197 }
198 const strings& Data() const {
199 return input_strings_;
200 }
201 const void* DataRaw() const override {
202 if (input_strings_.size() != 1) {
203 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
204 }
205 return reinterpret_cast<const void*>(input_strings_[0].c_str());
206 }
207 size_t SizeInBytes() const override {
208 if (input_strings_.size() != 1) {
209 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
210 }
211 return input_strings_[0].size();
212 }
213 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
214 std::vector<const char*> raw;
215 for (const auto& s : ss) {
216 raw.push_back(s.data());
217 }
218 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
219 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
220 }
221 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
222 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
223 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
224 }
225 const Span<std::string>& AsSpan() {
226 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
227 }
228 const std::string& AsScalar() {
229 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
230 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
231 }
232 return input_strings_[0];
233 }
234
235 private:
236 std::vector<std::string> input_strings_; // for input
237};
238
239template <>
240class Tensor<std::string_view> : public TensorBase {
241 public:
242 using strings = std::vector<std::string>;
243 using string_views = std::vector<std::string_view>;
244
245 Tensor(const OrtW::CustomOpApi& api,
246 OrtKernelContext& ctx,
247 size_t indice,
248 bool is_input) : TensorBase(api,
249 ctx,
250 indice,
251 is_input) {
252 if (is_input_) {
253 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
254 if (indice >= input_count) {
255 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
256 }
257 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
258 auto* info = api_.GetTensorTypeAndShape(const_value);
259 shape_ = api_.GetTensorShape(info);
260 type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
261 api_.ReleaseTensorTypeAndShapeInfo(info);
262
263 size_t num_chars;
264 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
265 chars_.resize(num_chars + 1, '\0');
266
267 auto num_strings = static_cast<size_t>(NumberOfElement());
268 if (num_strings) {
269 std::vector<size_t> offsets(num_strings);
270 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
271 (void*)chars_.data(),
272 num_chars,
273 offsets.data(),
274 offsets.size()));
275 offsets.push_back(num_chars);
276 for (size_t i = 0; i < num_strings; ++i) {
277 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
278 }
279 }
280 }
281 }
282 int64_t NumberOfElement() const {
283 if (shape_.has_value()) {
284 return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
285 } else {
286 return 0;
287 }
288 }
289 const string_views& Data() const {
290 return input_string_views_;
291 }
292 const void* DataRaw() const override {
293 if (input_string_views_.size() != 1) {
294 ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
295 }
296 return reinterpret_cast<const void*>(input_string_views_[0].data());
297 }
298 size_t SizeInBytes() const override {
299 if (input_string_views_.size() != 1) {
300 ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION);
301 }
302 return input_string_views_[0].size();
303 }
304 const Span<std::string_view>& AsSpan() {
305 ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
306 }
307 std::string_view AsScalar() {
308 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
309 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
310 }
311 return input_string_views_[0];
312 }
313
314 private:
315 std::vector<char> chars_; // for input
316 std::vector<std::string_view> input_string_views_; // for input
317};
318
// Owning handle to a tensor wrapper of any element type, held through the common base class.
using TensorPtr = std::unique_ptr<Custom::TensorBase>;
// Ordered collection of owned tensor wrappers.
using TensorPtrs = std::vector<TensorPtr>;
321
322// Represent variadic input or output
323struct Variadic : public TensorBase {
324 Variadic(const OrtW::CustomOpApi& api,
325 OrtKernelContext& ctx,
326 size_t indice,
327 bool is_input) : TensorBase(api,
328 ctx,
329 indice,
330 is_input) {
331#if ORT_API_VERSION < 14
332 ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION);
333#endif
334 if (is_input) {
335 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
336 for (size_t ith_input = 0; ith_input < input_count; ++ith_input) {
337 auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input);
338 auto* info = api_.GetTensorTypeAndShape(const_value);
339 auto type = api_.GetTensorElementType(info);
340 api_.ReleaseTensorTypeAndShapeInfo(info);
341 TensorPtr tensor;
342 switch (type) {
343 case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL:
344 tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true);
345 break;
346 case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
347 tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true);
348 break;
349 case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE:
350 tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true);
351 break;
352 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8:
353 tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true);
354 break;
355 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8:
356 tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true);
357 break;
358 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16:
359 tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true);
360 break;
361 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16:
362 tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true);
363 break;
364 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32:
365 tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true);
366 break;
367 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
368 tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true);
369 break;
370 case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64:
371 tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true);
372 break;
373 case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
374 tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true);
375 break;
376 case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING:
377 tensor = std::make_unique<Custom::Tensor<std::string_view>>(api, ctx, ith_input, true);
378 break;
379 default:
380 ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION);
381 break;
382 }
383 tensors_.emplace_back(tensor.release());
384 } // for
385 }
386 }
387 template <typename T>
388 T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) {
389 auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false);
390 auto raw_output = tensor.get()->Allocate(shape);
391 tensors_.emplace_back(tensor.release());
392 return raw_output;
393 }
394 Tensor<std::string>& AllocateStringTensor(size_t ith_output) {
395 auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false);
396 Tensor<std::string>& output = *tensor;
397 tensors_.emplace_back(tensor.release());
398 return output;
399 }
400 const void* DataRaw() const override {
401 ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
402 return nullptr;
403 }
404 size_t SizeInBytes() const override {
405 ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION);
406 return 0;
407 }
408 size_t Size() const {
409 return tensors_.size();
410 }
411 const TensorPtr& operator[](size_t indice) const {
412 return tensors_.at(indice);
413 }
414
415 private:
416 TensorPtrs tensors_;
417};
418
419struct OrtLiteCustomOp : public OrtCustomOp {
420 // CreateTuple
421 template <size_t ith_input, size_t ith_output, typename... Ts>
422 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
423 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
424 return std::make_tuple();
425 }
426
427 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
428 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
429 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
430 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
431 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
432 return std::tuple_cat(current, next);
433 }
434
#if ORT_API_VERSION >= 14
  // Variadic arguments. Each overload creates a Variadic wrapper, stores it in
  // `tensors` (which owns it), and binds a pointer or reference to it.
  // `const Variadic*` / `const Variadic&` are inputs and advance ith_input;
  // `Variadic*` / `Variadic&` are outputs and advance ith_output.
  // The base-to-derived conversion uses static_cast, which is the well-defined
  // downcast here because the object was just created as a Variadic.
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())};
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }

  template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
  static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
    tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())};
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
    return std::tuple_cat(current, next);
  }
#endif
472
// Generates the CreateTuple overloads that bind an INPUT of element type `data_type`.
// Supported parameter forms: `const Tensor<data_type>*`, `const Tensor<data_type>&`,
// `std::optional<const Tensor<data_type>*>`, `const Span<data_type>*`,
// `const Span<data_type>&`, `std::optional<const Span<data_type>*>`, `data_type`
// (scalar) and `std::optional<data_type>`.
// Every overload stores the new Tensor wrapper in `tensors` (which owns it) and
// advances ith_input. Optional forms bind an empty optional when
// ith_input >= num_input. Span and scalar forms throw for non-CPU tensors.
// Base-to-derived conversions use static_cast: the object was just created with
// exactly that derived type, so this is the well-defined downcast.
#define CREATE_TUPLE_INPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      if (!tensors.back()->IsCpuTensor()) { \
        ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
      } \
      std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_input < num_input) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
      if (!tensors.back()->IsCpuTensor()) { \
        ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
      } \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
// Generates the CreateTuple overloads that bind an OUTPUT of element type `data_type`.
// Supported parameter forms: `Tensor<data_type>*`, `Tensor<data_type>&` and
// `std::optional<Tensor<data_type>*>`.
// Every overload stores the new Tensor wrapper in `tensors` (which owns it) and
// advances ith_output. The optional form binds an empty optional when
// ith_output >= num_output.
// Base-to-derived conversions use static_cast: the object was just created with
// exactly that derived type, so this is the well-defined downcast.
#define CREATE_TUPLE_OUTPUT(data_type) \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
  template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
  static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
  CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
    if (ith_output < num_output) { \
      tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
      std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } else { \
      std::tuple<T> current = std::tuple<T>{}; \
      auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
      return std::tuple_cat(current, next); \
    } \
  }
// Expands to both the input and the output CreateTuple overloads for `data_type`.
#define CREATE_TUPLE(data_type) \
  CREATE_TUPLE_INPUT(data_type) \
  CREATE_TUPLE_OUTPUT(data_type)

  // Instantiate the CreateTuple overloads for every supported element type.
  CREATE_TUPLE(bool)
  CREATE_TUPLE(float)
  CREATE_TUPLE(double)
  CREATE_TUPLE(int8_t)
  CREATE_TUPLE(int16_t)
  CREATE_TUPLE(int32_t)
  CREATE_TUPLE(int64_t)
  CREATE_TUPLE(uint8_t)
  CREATE_TUPLE(uint16_t)
  CREATE_TUPLE(uint32_t)
  CREATE_TUPLE(uint64_t)
  CREATE_TUPLE(std::string)
  // std::string_view gets only the input overloads (no output overloads are generated).
  CREATE_TUPLE_INPUT(std::string_view)
619
620 // ParseArgs ...
621 template <typename... Ts>
622 static typename std::enable_if<0 == sizeof...(Ts)>::type
623 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
624 }
625
626 template <typename T, typename... Ts>
627 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
628 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
629 ParseArgs<Ts...>(input_types, output_types);
630 }
631
#if ORT_API_VERSION >= 14
  // Variadic arguments are recorded with element type UNDEFINED.
  // A variadic input (`const Variadic&` / `const Variadic*`) is rejected when any
  // input type has already been recorded; a variadic output (`Variadic&` /
  // `Variadic*`) is rejected when any output type has already been recorded.
  // NOTE(review): these overloads only see arguments declared BEFORE the variadic
  // one; whether later arguments are rejected elsewhere is not visible here.
  template <typename T, typename... Ts>
  static std::enable_if_t<std::is_same_v<T, const Variadic&>>
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    const bool has_prior_input = !input_types.empty();
    if (has_prior_input) {
      ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
    }
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static std::enable_if_t<std::is_same_v<T, const Variadic*>>
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    const bool has_prior_input = !input_types.empty();
    if (has_prior_input) {
      ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
    }
    input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static std::enable_if_t<std::is_same_v<T, Variadic&>>
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    const bool has_prior_output = !output_types.empty();
    if (has_prior_output) {
      ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
    }
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }

  template <typename T, typename... Ts>
  static std::enable_if_t<std::is_same_v<T, Variadic*>>
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
    const bool has_prior_output = !output_types.empty();
    if (has_prior_output) {
      ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
    }
    output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
    ParseArgs<Ts...>(input_types, output_types);
  }
#endif
673
// Generates the ParseArgs step for one input argument form.
// The step is selected when the leading argument type T is exactly `pack_type`
// or `std::optional<pack_type>`; in both cases it appends `onnx_type` to the
// input list and recurses on the remaining argument types.
#define PARSE_INPUT_BASE(pack_type, onnx_type)                              \
  template <typename T, typename... Ts>                                     \
  static std::enable_if_t<std::is_same_v<T, pack_type> ||                   \
                          std::is_same_v<T, std::optional<pack_type>>>      \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,            \
            std::vector<ONNXTensorElementDataType>& output_types) {         \
    input_types.push_back(onnx_type);                                       \
    ParseArgs<Ts...>(input_types, output_types);                            \
  }
687
// Generates every input ParseArgs step for one element type: a plain
// `data_type` value, and Span / Tensor of that type taken by const reference
// or const pointer. Each form also gets its std::optional variant through
// PARSE_INPUT_BASE.
#define PARSE_INPUT(data_type, onnx_type)                         \
  PARSE_INPUT_BASE(data_type, onnx_type)                          \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type)     \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type)     \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type)   \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type)
694
// Generates the ParseArgs step for an output tensor of one element type.
// The step is selected when the leading argument type T is a mutable
// `Custom::Tensor<data_type>` taken by pointer, by reference, or as
// `std::optional` of the pointer; it appends `onnx_type` to the output list
// and recurses on the remaining argument types.
#define PARSE_OUTPUT(data_type, onnx_type)                                                 \
  template <typename T, typename... Ts>                                                    \
  static std::enable_if_t<std::is_same_v<T, Custom::Tensor<data_type>*> ||                 \
                          std::is_same_v<T, Custom::Tensor<data_type>&> ||                 \
                          std::is_same_v<T, std::optional<Custom::Tensor<data_type>*>>>    \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                           \
            std::vector<ONNXTensorElementDataType>& output_types) {                        \
    output_types.push_back(onnx_type);                                                     \
    ParseArgs<Ts...>(input_types, output_types);                                           \
  }
714
// Generates both the output and the input ParseArgs steps that map the C++
// element type `data_type` to the ONNX element type `onnx_type`.
#define PARSE_ARGS(data_type, onnx_type) \
  PARSE_OUTPUT(data_type, onnx_type)     \
  PARSE_INPUT(data_type, onnx_type)
718
719 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
720 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
721 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
722 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
723 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
724 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
725 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
726 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
727 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
728 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
729 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
730 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
731 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
732
733 OrtLiteCustomOp(const char* op_name,
734 const char* execution_provider) : op_name_(op_name),
735 execution_provider_(execution_provider) {
736 int act_ver = GetActiveOrtAPIVersion();
737 OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION;
738
739 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
740 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
741
742 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
743 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
744 return self->input_types_.size();
745 };
746
747 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
748 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
749 return self->input_types_[indice];
750 };
751
752 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
753 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
754 return self->output_types_.size();
755 };
756
757 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
758 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
759 return self->output_types_[indice];
760 };
761
762#if ORT_API_VERSION >= 14
763 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) {
764 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
765 return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
766 };
767
768 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
769 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
770 return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC;
771 };
772
773 OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) {
774 return 1;
775 };
776
777 OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) {
778 return 0;
779 };
780
781 OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) {
782 return 1;
783 };
784
785 OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) {
786 return 0;
787 };
788
789 OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) {
790 return OrtMemTypeDefault;
791 };
792#else
793 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
794 return INPUT_OUTPUT_OPTIONAL;
795 };
796
797 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) {
798 return INPUT_OUTPUT_OPTIONAL;
799 };
800#endif
801 }
802
803 const std::string op_name_;
804 const std::string execution_provider_;
805
806 std::vector<ONNXTensorElementDataType> input_types_;
807 std::vector<ONNXTensorElementDataType> output_types_;
808};
809
810template <typename... Args>
811struct OrtLiteCustomFunc : public OrtLiteCustomOp {
812 using ComputeFn = void (*)(Args...);
813 using MyType = OrtLiteCustomFunc<Args...>;
814
815 struct Kernel {
816 ComputeFn compute_fn_{};
817 std::string ep_{};
818 std::unique_ptr<OrtW::CustomOpApi> api_;
819 };
820
821 OrtLiteCustomFunc(const char* op_name,
822 const char* execution_provider,
823 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
824 compute_fn_(compute_fn) {
825 ParseArgs<Args...>(input_types_, output_types_);
826
827 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
828 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
829 std::vector<TensorPtr> tensors;
830 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
831 context,
832 tensors,
833 kernel->api_->KernelContext_GetInputCount(context),
834 kernel->api_->KernelContext_GetOutputCount(context),
835 kernel->ep_);
836 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
837 };
838
839 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
840 auto kernel = std::make_unique<Kernel>();
841 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
842 kernel->compute_fn_ = self->compute_fn_;
843 kernel->ep_ = self->execution_provider_;
844 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
845 return reinterpret_cast<void*>(kernel.release());
846 };
847
848 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
849 delete reinterpret_cast<Kernel*>(op_kernel);
850 };
851 }
852
853 ComputeFn compute_fn_;
854};
855
856template <typename CustomOp>
857struct OrtLiteCustomStruct : public OrtLiteCustomOp {
858 template <typename... Args>
859 using CustomComputeFn = void (CustomOp::*)(Args...);
860 using MyType = OrtLiteCustomStruct<CustomOp>;
861
862 struct Kernel {
863 std::unique_ptr<CustomOp> custom_op_;
864 std::string ep_{};
865 std::unique_ptr<OrtW::CustomOpApi> api_;
866 };
867
868 OrtLiteCustomStruct(const char* op_name,
869 const char* execution_provider) : OrtLiteCustomOp(op_name,
870 execution_provider) {
871 init(&CustomOp::Compute);
872 }
873
874 template <typename... Args>
875 void init(CustomComputeFn<Args...>) {
876 ParseArgs<Args...>(input_types_, output_types_);
877
878 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
879 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
880 std::vector<TensorPtr> tensors;
881 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
882 context,
883 tensors,
884 kernel->api_->KernelContext_GetInputCount(context),
885 kernel->api_->KernelContext_GetOutputCount(context),
886 kernel->ep_);
887 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
888 };
889
890 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
891 auto kernel = std::make_unique<Kernel>();
892 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
893 auto self = static_cast<const MyType*>(this_);
894 kernel->ep_ = self->execution_provider_;
895 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
896 return reinterpret_cast<void*>(kernel.release());
897 };
898
899 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
900 delete reinterpret_cast<Kernel*>(op_kernel);
901 };
902 }
903};
904
905template <typename... Args>
906OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
907 const char* execution_provider,
908 void (*custom_compute_fn)(Args...)) {
909 using LiteOp = OrtLiteCustomFunc<Args...>;
910 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
911}
912
913template <typename CustomOp>
914OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
915 const char* execution_provider) {
916 using LiteOp = OrtLiteCustomStruct<CustomOp>;
917 return std::make_unique<LiteOp>(op_name, execution_provider).release();
918}
919
920} // namespace Custom
921} // namespace Ort
922