microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
edgchen1/update_xcode_fix

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/custom_op_lite.h

663 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
#include "onnxruntime_customop.hpp"

#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
// Raise this version when the supported ORT API version migrates to a newer one.
9#define SUPPORT_ORT_API_VERSION_TO 13
10
11namespace Ort {
12namespace Custom {
13
14class TensorBase {
15 public:
16 TensorBase(const OrtW::CustomOpApi& api,
17 OrtKernelContext& ctx,
18 size_t indice,
19 bool is_input) : api_(api),
20 ctx_(ctx),
21 indice_(indice),
22 is_input_(is_input) {}
23
24 virtual ~TensorBase() = default;
25 operator bool() const {
26 return shape_.has_value();
27 }
28 const std::vector<int64_t>& Shape() const {
29 if (shape_.has_value()) {
30 return *shape_;
31 } else {
32 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
33 }
34 }
35 int64_t NumberOfElement() const {
36 if (shape_.has_value()) {
37 return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>());
38 } else {
39 ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION);
40 }
41 }
42 std::string Shape2Str() const {
43 if (shape_.has_value()) {
44 std::string shape_str;
45 for (const auto& dim: *shape_) {
46 shape_str.append(std::to_string(dim));
47 shape_str.append(", ");
48 }
49 return shape_str;
50 } else {
51 return "empty";
52 }
53 }
54 protected:
55 const OrtW::CustomOpApi& api_;
56 OrtKernelContext& ctx_;
57 size_t indice_;
58 bool is_input_;
59 std::optional<std::vector<int64_t>> shape_;
60};
61
/// Non-owning, read-only view of `size_` contiguous elements starting at `data_`.
///
/// The view does not keep the underlying buffer alive; it is valid only while
/// the owner it was taken from (for example the Tensor that filled it) lives.
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};

  /// Re-points the view at `size` elements starting at `data`.
  void Assign(const T* data, size_t size) {
    data_ = data;
    size_ = size;
  }
  size_t size() const { return size_; }
  bool empty() const { return size_ == 0; }
  /// Element access by value; no bounds checking.
  T operator[](size_t indice) const {
    return data_[indice];
  }
  const T* data() const { return data_; }
  // Iterator access so a Span works with range-for and <algorithm>.
  const T* begin() const { return data_; }
  const T* end() const { return data_ + size_; }
};
76
77template <typename T>
78class Tensor : public TensorBase {
79 public:
80 using TT = typename std::remove_reference<T>::type;
81 Tensor(const OrtW::CustomOpApi& api,
82 OrtKernelContext& ctx,
83 size_t indice,
84 bool is_input) : TensorBase(api,
85 ctx,
86 indice,
87 is_input) {
88 if (is_input) {
89 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
90 if (indice >= input_count) {
91 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
92 }
93 const_value_ = api_.KernelContext_GetInput(&ctx_, indice);
94 auto* info = api_.GetTensorTypeAndShape(const_value_);
95 shape_ = api_.GetTensorShape(info);
96 api_.ReleaseTensorTypeAndShapeInfo(info);
97 }
98 }
99 const TT* Data() const {
100 return api_.GetTensorData<TT>(const_value_);
101 }
102 TT* Allocate(const std::vector<int64_t>& shape) {
103 if (!data_) {
104 OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size());
105 shape_ = shape;
106 data_ = api_.GetTensorMutableData<TT>(out);
107 }
108 return data_;
109 }
110 const Span<T>& AsSpan() {
111 if (!shape_.has_value() || shape_->size() != 1) {
112 ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
113 }
114 span_.Assign(Data(), (*shape_)[0]);
115 return span_;
116 }
117 const T& AsScalar() {
118 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
119 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
120 }
121 return *Data();
122 }
123
124 private:
125 const OrtValue* const_value_{}; // for input
126 TT* data_{}; // for output
127 Span<T> span_;
128};
129
130template <>
131class Tensor<std::string> : public TensorBase {
132 public:
133 using strings = std::vector<std::string>;
134
135 Tensor(const OrtW::CustomOpApi& api,
136 OrtKernelContext& ctx,
137 size_t indice,
138 bool is_input) : TensorBase(api,
139 ctx,
140 indice,
141 is_input) {
142 if (is_input) {
143 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
144 if (indice >= input_count) {
145 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
146 }
147
148 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
149 auto* info = api_.GetTensorTypeAndShape(const_value);
150 shape_ = api_.GetTensorShape(info);
151 api_.ReleaseTensorTypeAndShapeInfo(info);
152
153 size_t num_chars;
154 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
155 std::vector<char> chars(num_chars + 1, '\0');
156 auto num_strings = NumberOfElement();
157 std::vector<size_t> offsets(NumberOfElement());
158 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
159 (void*)chars.data(),
160 num_chars,
161 offsets.data(),
162 offsets.size()));
163 auto upper_bound = static_cast<int64_t>(num_strings) - 1;
164 input_strings_.resize(num_strings);
165 for (int64_t i = upper_bound; i >= 0; --i) {
166 if (i < upper_bound) {
167 chars[offsets[i + 1]] = '\0';
168 }
169 input_strings_[i] = chars.data() + offsets[i];
170 }
171 }
172 }
173 const strings& Data() const {
174 return input_strings_;
175 }
176 void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) {
177 std::vector<const char*> raw;
178 for (const auto& s : ss) {
179 raw.push_back(s.data());
180 }
181 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
182 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size()));
183 }
184 void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) {
185 auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size());
186 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size()));
187 }
188 const Span<std::string>& AsSpan() {
189 ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION);
190 }
191 const std::string& AsScalar() {
192 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
193 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
194 }
195 return input_strings_[0];
196 }
197
198 private:
199 std::vector<std::string> input_strings_; // for input
200};
201
202template <>
203class Tensor<std::string_view> : public TensorBase {
204 public:
205 using strings = std::vector<std::string>;
206 using string_views = std::vector<std::string_view>;
207
208 Tensor(const OrtW::CustomOpApi& api,
209 OrtKernelContext& ctx,
210 size_t indice,
211 bool is_input) : TensorBase(api,
212 ctx,
213 indice,
214 is_input) {
215 if (is_input_) {
216 auto input_count = api_.KernelContext_GetInputCount(&ctx_);
217 if (indice >= input_count) {
218 ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION);
219 }
220 auto* const_value = api_.KernelContext_GetInput(&ctx_, indice);
221 auto* info = api_.GetTensorTypeAndShape(const_value);
222 shape_ = api_.GetTensorShape(info);
223 api_.ReleaseTensorTypeAndShapeInfo(info);
224
225 size_t num_chars;
226 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars));
227 chars_.resize(num_chars + 1, '\0');
228
229 auto num_strings = static_cast<size_t>(NumberOfElement());
230 if (num_strings) {
231 std::vector<size_t> offsets(num_strings);
232 OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value,
233 (void*)chars_.data(),
234 num_chars,
235 offsets.data(),
236 offsets.size()));
237 offsets.push_back(num_chars);
238 for (size_t i = 0; i < num_strings; ++i) {
239 input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]);
240 }
241 }
242 }
243 }
244 int64_t NumberOfElement() const {
245 if (shape_.has_value()) {
246 return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>());
247 } else {
248 return 0;
249 }
250 }
251 const string_views& Data() const {
252 return input_string_views_;
253 }
254 const Span<std::string_view>& AsSpan() {
255 ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION);
256 }
257 std::string_view AsScalar() {
258 if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) {
259 ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION);
260 }
261 return input_string_views_[0];
262 }
263
264 private:
265 std::vector<char> chars_; // for input
266 std::vector<std::string_view> input_string_views_; // for input
267};
268
269using TensorPtr = std::unique_ptr<Custom::TensorBase>;
270
271struct OrtLiteCustomOp : public OrtCustomOp {
272 using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
273 using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
274
275 // CreateTuple
276 template <size_t ith_input, size_t ith_output, typename... Ts>
277 static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type
278 CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) {
279 return std::make_tuple();
280 }
281
282 template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
283 static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type
284 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
285 std::tuple<T> current = std::tuple<OrtKernelContext*>{context};
286 auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
287 return std::tuple_cat(current, next);
288 }
289
290#define CREATE_TUPLE_INPUT(data_type) \
291 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
292 static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
293 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
294 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
295 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
296 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
297 return std::tuple_cat(current, next); \
298 } \
299 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
300 static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
301 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
302 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
303 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
304 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
305 return std::tuple_cat(current, next); \
306 } \
307 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
308 static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
309 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
310 if (ith_input < num_input) { \
311 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
312 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
313 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
314 return std::tuple_cat(current, next); \
315 } else { \
316 std::tuple<T> current = std::tuple<T>{}; \
317 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
318 return std::tuple_cat(current, next); \
319 } \
320 } \
321 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
322 static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
323 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
324 if ("CPUExecutionProvider" != ep) { \
325 ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
326 } \
327 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
328 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
329 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
330 return std::tuple_cat(current, next); \
331 } \
332 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
333 static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
334 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
335 if ("CPUExecutionProvider" != ep) { \
336 ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
337 } \
338 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
339 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
340 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
341 return std::tuple_cat(current, next); \
342 } \
343 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
344 static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
345 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
346 if (ith_input < num_input) { \
347 if ("CPUExecutionProvider" != ep) { \
348 ORTX_CXX_API_THROW("span input could only be applied to CPU EP", ORT_FAIL); \
349 } \
350 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
351 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
352 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
353 return std::tuple_cat(current, next); \
354 } else { \
355 std::tuple<T> current = std::tuple<T>{}; \
356 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
357 return std::tuple_cat(current, next); \
358 } \
359 } \
360 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
361 static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
362 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
363 if ("CPUExecutionProvider" != ep) { \
364 ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
365 } \
366 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
367 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
368 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
369 return std::tuple_cat(current, next); \
370 } \
371 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
372 static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
373 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
374 if (ith_input < num_input) { \
375 if ("CPUExecutionProvider" != ep) { \
376 ORTX_CXX_API_THROW("scalar input could only be applied to CPU EP", ORT_FAIL); \
377 } \
378 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
379 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
380 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
381 return std::tuple_cat(current, next); \
382 } else { \
383 std::tuple<T> current = std::tuple<T>{}; \
384 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
385 return std::tuple_cat(current, next); \
386 } \
387 }
388#define CREATE_TUPLE_OUTPUT(data_type) \
389 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
390 static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
391 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
392 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
393 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())}; \
394 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
395 return std::tuple_cat(current, next); \
396 } \
397 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
398 static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
399 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
400 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
401 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())}; \
402 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
403 return std::tuple_cat(current, next); \
404 } \
405 template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
406 static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
407 CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
408 if (ith_output < num_output) { \
409 tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
410 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
411 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
412 return std::tuple_cat(current, next); \
413 } else { \
414 std::tuple<T> current = std::tuple<T>{}; \
415 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
416 return std::tuple_cat(current, next); \
417 } \
418 }
419#define CREATE_TUPLE(data_type) \
420 CREATE_TUPLE_INPUT(data_type) \
421 CREATE_TUPLE_OUTPUT(data_type)
422
423 CREATE_TUPLE(bool)
424 CREATE_TUPLE(float)
425 CREATE_TUPLE(double)
426 CREATE_TUPLE(int8_t)
427 CREATE_TUPLE(int16_t)
428 CREATE_TUPLE(int32_t)
429 CREATE_TUPLE(int64_t)
430 CREATE_TUPLE(uint8_t)
431 CREATE_TUPLE(uint16_t)
432 CREATE_TUPLE(uint32_t)
433 CREATE_TUPLE(uint64_t)
434 CREATE_TUPLE(std::string)
435 CREATE_TUPLE_INPUT(std::string_view)
436
437 // ParseArgs ...
438 template <typename... Ts>
439 static typename std::enable_if<0 == sizeof...(Ts)>::type
440 ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) {
441 }
442
443 template <typename T, typename... Ts>
444 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type
445 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
446 ParseArgs<Ts...>(input_types, output_types);
447 }
448
449#define PARSE_INPUT_BASE(pack_type, onnx_type) \
450 template <typename T, typename... Ts> \
451 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, pack_type>::value>::type \
452 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
453 input_types.push_back(onnx_type); \
454 ParseArgs<Ts...>(input_types, output_types); \
455 } \
456 template <typename T, typename... Ts> \
457 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<pack_type>>::value>::type \
458 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
459 input_types.push_back(onnx_type); \
460 ParseArgs<Ts...>(input_types, output_types); \
461 }
462
463#define PARSE_INPUT(data_type, onnx_type) \
464 PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type) \
465 PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type) \
466 PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type) \
467 PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type) \
468 PARSE_INPUT_BASE(data_type, onnx_type)
469
470#define PARSE_OUTPUT(data_type, onnx_type) \
471 template <typename T, typename... Ts> \
472 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>*>::value>::type \
473 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
474 output_types.push_back(onnx_type); \
475 ParseArgs<Ts...>(input_types, output_types); \
476 } \
477 template <typename T, typename... Ts> \
478 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, Custom::Tensor<data_type>&>::value>::type \
479 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
480 output_types.push_back(onnx_type); \
481 ParseArgs<Ts...>(input_types, output_types); \
482 } \
483 template <typename T, typename... Ts> \
484 static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>::type \
485 ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { \
486 output_types.push_back(onnx_type); \
487 ParseArgs<Ts...>(input_types, output_types); \
488 }
489
490#define PARSE_ARGS(data_type, onnx_type) \
491 PARSE_INPUT(data_type, onnx_type) \
492 PARSE_OUTPUT(data_type, onnx_type)
493
494 PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL)
495 PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
496 PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE)
497 PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8)
498 PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16)
499 PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
500 PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
501 PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8)
502 PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16)
503 PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32)
504 PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64)
505 PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING)
506 PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output
507
508 OrtLiteCustomOp(const char* op_name,
509 const char* execution_provider) : op_name_(op_name),
510 execution_provider_(execution_provider) {
511 OrtCustomOp::version = MIN_ORT_VERSION_SUPPORTED;
512
513 OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
514 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); };
515
516 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) {
517 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
518 return self->input_types_.size();
519 };
520
521 OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) {
522 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
523 return self->input_types_[indice];
524 };
525
526 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) {
527 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
528 return self->output_types_.size();
529 };
530
531 OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) {
532 auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
533 return self->output_types_[indice];
534 };
535
536 OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) {
537 return INPUT_OUTPUT_OPTIONAL;
538 };
539
540 OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) {
541 return INPUT_OUTPUT_OPTIONAL;
542 };
543 }
544
545 const std::string op_name_;
546 const std::string execution_provider_;
547
548 std::vector<ONNXTensorElementDataType> input_types_;
549 std::vector<ONNXTensorElementDataType> output_types_;
550};
551
552template <typename... Args>
553struct OrtLiteCustomFunc : public OrtLiteCustomOp {
554 using ComputeFn = void (*)(Args...);
555 using MyType = OrtLiteCustomFunc<Args...>;
556
557 struct Kernel {
558 ComputeFn compute_fn_{};
559 std::string ep_{};
560 std::unique_ptr<OrtW::CustomOpApi> api_;
561 };
562
563 OrtLiteCustomFunc(const char* op_name,
564 const char* execution_provider,
565 ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider),
566 compute_fn_(compute_fn) {
567 ParseArgs<Args...>(input_types_, output_types_);
568
569 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
570 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
571 std::vector<TensorPtr> tensors;
572 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
573 context,
574 tensors,
575 kernel->api_->KernelContext_GetInputCount(context),
576 kernel->api_->KernelContext_GetOutputCount(context),
577 kernel->ep_);
578 std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t);
579 };
580
581 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
582 auto kernel = std::make_unique<Kernel>();
583 auto self = static_cast<const OrtLiteCustomFunc*>(this_);
584 kernel->compute_fn_ = self->compute_fn_;
585 kernel->ep_ = self->execution_provider_;
586 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
587 return reinterpret_cast<void*>(kernel.release());
588 };
589
590 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
591 delete reinterpret_cast<Kernel*>(op_kernel);
592 };
593 }
594
595 ComputeFn compute_fn_;
596};
597
598template <typename CustomOp>
599struct OrtLiteCustomStruct : public OrtLiteCustomOp {
600 template <typename... Args>
601 using CustomComputeFn = void (CustomOp::*)(Args...);
602 using MyType = OrtLiteCustomStruct<CustomOp>;
603
604 struct Kernel {
605 std::unique_ptr<CustomOp> custom_op_;
606 std::string ep_{};
607 std::unique_ptr<OrtW::CustomOpApi> api_;
608 };
609
610 OrtLiteCustomStruct(const char* op_name,
611 const char* execution_provider) : OrtLiteCustomOp(op_name,
612 execution_provider) {
613 init(&CustomOp::Compute);
614 }
615
616 template <typename... Args>
617 void init(CustomComputeFn<Args...>) {
618 ParseArgs<Args...>(input_types_, output_types_);
619
620 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
621 auto kernel = reinterpret_cast<Kernel*>(op_kernel);
622 std::vector<TensorPtr> tensors;
623 auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(),
624 context,
625 tensors,
626 kernel->api_->KernelContext_GetInputCount(context),
627 kernel->api_->KernelContext_GetOutputCount(context),
628 kernel->ep_);
629 std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t);
630 };
631
632 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
633 auto kernel = std::make_unique<Kernel>();
634 kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info);
635 auto self = static_cast<const MyType*>(this_);
636 kernel->ep_ = self->execution_provider_;
637 kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api);
638 return reinterpret_cast<void*>(kernel.release());
639 };
640
641 OrtCustomOp::KernelDestroy = [](void* op_kernel) {
642 delete reinterpret_cast<Kernel*>(op_kernel);
643 };
644 }
645};
646
647template <typename... Args>
648OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
649 const char* execution_provider,
650 void (*custom_compute_fn)(Args...)) {
651 using LiteOp = OrtLiteCustomFunc<Args...>;
652 return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release();
653}
654
655template <typename CustomOp>
656OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
657 const char* execution_provider) {
658 using LiteOp = OrtLiteCustomStruct<CustomOp>;
659 return std::make_unique<LiteOp>(op_name, execution_provider).release();
660}
661
662} // namespace Custom
663} // namespace Ort