microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
includes/custom_op_lite.h
923 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | #pragma once |
| 5 | #include "onnxruntime_customop.hpp" |
| 6 | #include <optional> |
| 7 | #include <numeric> |
| 8 | |
| 9 | namespace Ort { |
| 10 | namespace Custom { |
| 11 | |
| 12 | class TensorBase { |
| 13 | public: |
| 14 | TensorBase(const OrtW::CustomOpApi& api, |
| 15 | OrtKernelContext& ctx, |
| 16 | size_t indice, |
| 17 | bool is_input) : api_(api), |
| 18 | ctx_(ctx), |
| 19 | indice_(indice), |
| 20 | is_input_(is_input) {} |
| 21 | |
| 22 | virtual ~TensorBase() = default; |
| 23 | operator bool() const { |
| 24 | return shape_.has_value(); |
| 25 | } |
| 26 | const std::vector<int64_t>& Shape() const { |
| 27 | if (shape_.has_value()) { |
| 28 | return *shape_; |
| 29 | } else { |
| 30 | ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); |
| 31 | } |
| 32 | } |
| 33 | ONNXTensorElementDataType Type() const { |
| 34 | return type_; |
| 35 | } |
| 36 | int64_t NumberOfElement() const { |
| 37 | if (shape_.has_value()) { |
| 38 | return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies<int64_t>()); |
| 39 | } else { |
| 40 | ORTX_CXX_API_THROW("tensor shape is not yet initialized", ORT_RUNTIME_EXCEPTION); |
| 41 | } |
| 42 | } |
| 43 | std::string Shape2Str() const { |
| 44 | if (shape_.has_value()) { |
| 45 | std::string shape_str; |
| 46 | for (const auto& dim : *shape_) { |
| 47 | shape_str.append(std::to_string(dim)); |
| 48 | shape_str.append(", "); |
| 49 | } |
| 50 | return shape_str; |
| 51 | } else { |
| 52 | return "empty"; |
| 53 | } |
| 54 | } |
| 55 | bool IsCpuTensor() const { |
| 56 | return strcmp("Cpu", mem_type_) == 0; |
| 57 | } |
| 58 | virtual const void* DataRaw() const = 0; |
| 59 | virtual size_t SizeInBytes() const = 0; |
| 60 | |
| 61 | protected: |
| 62 | const OrtW::CustomOpApi& api_; |
| 63 | OrtKernelContext& ctx_; |
| 64 | size_t indice_; |
| 65 | bool is_input_; |
| 66 | std::optional<std::vector<int64_t>> shape_; |
| 67 | ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; |
| 68 | const char* mem_type_ = "Cpu"; |
| 69 | }; |
| 70 | |
// Non-owning, read-only view of `size_` contiguous elements beginning at
// `data_`. The viewed memory must outlive the Span.
template <typename T>
struct Span {
  const T* data_ = {};
  size_t size_ = {};

  // Re-point the view at [data, data + size).
  void Assign(const T* data, size_t size) {
    size_ = size;
    data_ = data;
  }

  size_t size() const { return size_; }

  const T* data() const { return data_; }

  // Unchecked element access; yields a copy of the element.
  T operator[](size_t indice) const { return *(data_ + indice); }
};
| 85 | |
| 86 | template <typename T> |
| 87 | class Tensor : public TensorBase { |
| 88 | public: |
| 89 | using TT = typename std::remove_reference<T>::type; |
| 90 | Tensor(const OrtW::CustomOpApi& api, |
| 91 | OrtKernelContext& ctx, |
| 92 | size_t indice, |
| 93 | bool is_input) : TensorBase(api, |
| 94 | ctx, |
| 95 | indice, |
| 96 | is_input) { |
| 97 | if (is_input) { |
| 98 | auto input_count = api_.KernelContext_GetInputCount(&ctx_); |
| 99 | if (indice >= input_count) { |
| 100 | ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); |
| 101 | } |
| 102 | const_value_ = api_.KernelContext_GetInput(&ctx_, indice); |
| 103 | auto* info = api_.GetTensorTypeAndShape(const_value_); |
| 104 | shape_ = api_.GetTensorShape(info); |
| 105 | type_ = api_.GetTensorElementType(info); |
| 106 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 107 | const OrtMemoryInfo* mem_info = {}; |
| 108 | api_.ThrowOnError(api_.GetOrtApi().GetTensorMemoryInfo(const_value_, &mem_info)); |
| 109 | if (mem_info) { |
| 110 | api_.ThrowOnError(api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type_)); |
| 111 | } |
| 112 | } |
| 113 | } |
| 114 | const TT* Data() const { |
| 115 | return api_.GetTensorData<TT>(const_value_); |
| 116 | } |
| 117 | |
| 118 | const void* DataRaw() const override { |
| 119 | return reinterpret_cast<const void*>(Data()); |
| 120 | } |
| 121 | |
| 122 | size_t SizeInBytes() const override { |
| 123 | return NumberOfElement() * sizeof(TT); |
| 124 | } |
| 125 | |
| 126 | TT* Allocate(const std::vector<int64_t>& shape) { |
| 127 | if (!data_) { |
| 128 | OrtValue* out = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); |
| 129 | shape_ = shape; |
| 130 | data_ = api_.GetTensorMutableData<TT>(out); |
| 131 | } |
| 132 | return data_; |
| 133 | } |
| 134 | const Span<T>& AsSpan() { |
| 135 | if (!shape_.has_value() || shape_->size() != 1) { |
| 136 | ORTX_CXX_API_THROW("to get a span, shape must be 1-D, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); |
| 137 | } |
| 138 | span_.Assign(Data(), (*shape_)[0]); |
| 139 | return span_; |
| 140 | } |
| 141 | const T& AsScalar() { |
| 142 | if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { |
| 143 | ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); |
| 144 | } |
| 145 | return *Data(); |
| 146 | } |
| 147 | |
| 148 | private: |
| 149 | const OrtValue* const_value_{}; // for input |
| 150 | TT* data_{}; // for output |
| 151 | Span<T> span_; |
| 152 | }; |
| 153 | |
| 154 | template <> |
| 155 | class Tensor<std::string> : public TensorBase { |
| 156 | public: |
| 157 | using strings = std::vector<std::string>; |
| 158 | |
| 159 | Tensor(const OrtW::CustomOpApi& api, |
| 160 | OrtKernelContext& ctx, |
| 161 | size_t indice, |
| 162 | bool is_input) : TensorBase(api, |
| 163 | ctx, |
| 164 | indice, |
| 165 | is_input) { |
| 166 | if (is_input) { |
| 167 | auto input_count = api_.KernelContext_GetInputCount(&ctx_); |
| 168 | if (indice >= input_count) { |
| 169 | ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); |
| 170 | } |
| 171 | |
| 172 | auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); |
| 173 | auto* info = api_.GetTensorTypeAndShape(const_value); |
| 174 | shape_ = api_.GetTensorShape(info); |
| 175 | type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 176 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 177 | |
| 178 | size_t num_chars; |
| 179 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); |
| 180 | std::vector<char> chars(num_chars + 1, '\0'); |
| 181 | auto num_strings = NumberOfElement(); |
| 182 | std::vector<size_t> offsets(NumberOfElement()); |
| 183 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, |
| 184 | (void*)chars.data(), |
| 185 | num_chars, |
| 186 | offsets.data(), |
| 187 | offsets.size())); |
| 188 | auto upper_bound = static_cast<int64_t>(num_strings) - 1; |
| 189 | input_strings_.resize(num_strings); |
| 190 | for (int64_t i = upper_bound; i >= 0; --i) { |
| 191 | if (i < upper_bound) { |
| 192 | chars[offsets[i + 1]] = '\0'; |
| 193 | } |
| 194 | input_strings_[i] = chars.data() + offsets[i]; |
| 195 | } |
| 196 | } |
| 197 | } |
| 198 | const strings& Data() const { |
| 199 | return input_strings_; |
| 200 | } |
| 201 | const void* DataRaw() const override { |
| 202 | if (input_strings_.size() != 1) { |
| 203 | ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); |
| 204 | } |
| 205 | return reinterpret_cast<const void*>(input_strings_[0].c_str()); |
| 206 | } |
| 207 | size_t SizeInBytes() const override { |
| 208 | if (input_strings_.size() != 1) { |
| 209 | ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); |
| 210 | } |
| 211 | return input_strings_[0].size(); |
| 212 | } |
| 213 | void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) { |
| 214 | std::vector<const char*> raw; |
| 215 | for (const auto& s : ss) { |
| 216 | raw.push_back(s.data()); |
| 217 | } |
| 218 | auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); |
| 219 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size())); |
| 220 | } |
| 221 | void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) { |
| 222 | auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); |
| 223 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size())); |
| 224 | } |
| 225 | const Span<std::string>& AsSpan() { |
| 226 | ORTX_CXX_API_THROW("span for TensorT of string not implemented", ORT_RUNTIME_EXCEPTION); |
| 227 | } |
| 228 | const std::string& AsScalar() { |
| 229 | if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { |
| 230 | ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); |
| 231 | } |
| 232 | return input_strings_[0]; |
| 233 | } |
| 234 | |
| 235 | private: |
| 236 | std::vector<std::string> input_strings_; // for input |
| 237 | }; |
| 238 | |
| 239 | template <> |
| 240 | class Tensor<std::string_view> : public TensorBase { |
| 241 | public: |
| 242 | using strings = std::vector<std::string>; |
| 243 | using string_views = std::vector<std::string_view>; |
| 244 | |
| 245 | Tensor(const OrtW::CustomOpApi& api, |
| 246 | OrtKernelContext& ctx, |
| 247 | size_t indice, |
| 248 | bool is_input) : TensorBase(api, |
| 249 | ctx, |
| 250 | indice, |
| 251 | is_input) { |
| 252 | if (is_input_) { |
| 253 | auto input_count = api_.KernelContext_GetInputCount(&ctx_); |
| 254 | if (indice >= input_count) { |
| 255 | ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); |
| 256 | } |
| 257 | auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); |
| 258 | auto* info = api_.GetTensorTypeAndShape(const_value); |
| 259 | shape_ = api_.GetTensorShape(info); |
| 260 | type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 261 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 262 | |
| 263 | size_t num_chars; |
| 264 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); |
| 265 | chars_.resize(num_chars + 1, '\0'); |
| 266 | |
| 267 | auto num_strings = static_cast<size_t>(NumberOfElement()); |
| 268 | if (num_strings) { |
| 269 | std::vector<size_t> offsets(num_strings); |
| 270 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, |
| 271 | (void*)chars_.data(), |
| 272 | num_chars, |
| 273 | offsets.data(), |
| 274 | offsets.size())); |
| 275 | offsets.push_back(num_chars); |
| 276 | for (size_t i = 0; i < num_strings; ++i) { |
| 277 | input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]); |
| 278 | } |
| 279 | } |
| 280 | } |
| 281 | } |
| 282 | int64_t NumberOfElement() const { |
| 283 | if (shape_.has_value()) { |
| 284 | return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies<int64_t>()); |
| 285 | } else { |
| 286 | return 0; |
| 287 | } |
| 288 | } |
| 289 | const string_views& Data() const { |
| 290 | return input_string_views_; |
| 291 | } |
| 292 | const void* DataRaw() const override { |
| 293 | if (input_string_views_.size() != 1) { |
| 294 | ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); |
| 295 | } |
| 296 | return reinterpret_cast<const void*>(input_string_views_[0].data()); |
| 297 | } |
| 298 | size_t SizeInBytes() const override { |
| 299 | if (input_string_views_.size() != 1) { |
| 300 | ORTX_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); |
| 301 | } |
| 302 | return input_string_views_[0].size(); |
| 303 | } |
| 304 | const Span<std::string_view>& AsSpan() { |
| 305 | ORTX_CXX_API_THROW("span for TensorT of string view not implemented", ORT_RUNTIME_EXCEPTION); |
| 306 | } |
| 307 | std::string_view AsScalar() { |
| 308 | if (!shape_.has_value() || (shape_->size() == 1 && (*shape_)[0] != 1) || shape_->size() > 1) { |
| 309 | ORTX_CXX_API_THROW("to get a scalar, shape must be {1}, actual shape: " + Shape2Str(), ORT_RUNTIME_EXCEPTION); |
| 310 | } |
| 311 | return input_string_views_[0]; |
| 312 | } |
| 313 | |
| 314 | private: |
| 315 | std::vector<char> chars_; // for input |
| 316 | std::vector<std::string_view> input_string_views_; // for input |
| 317 | }; |
| 318 | |
| 319 | using TensorPtr = std::unique_ptr<Custom::TensorBase>; |
| 320 | using TensorPtrs = std::vector<TensorPtr>; |
| 321 | |
| 322 | // Represent variadic input or output |
| 323 | struct Variadic : public TensorBase { |
| 324 | Variadic(const OrtW::CustomOpApi& api, |
| 325 | OrtKernelContext& ctx, |
| 326 | size_t indice, |
| 327 | bool is_input) : TensorBase(api, |
| 328 | ctx, |
| 329 | indice, |
| 330 | is_input) { |
| 331 | #if ORT_API_VERSION < 14 |
| 332 | ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION); |
| 333 | #endif |
| 334 | if (is_input) { |
| 335 | auto input_count = api_.KernelContext_GetInputCount(&ctx_); |
| 336 | for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { |
| 337 | auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input); |
| 338 | auto* info = api_.GetTensorTypeAndShape(const_value); |
| 339 | auto type = api_.GetTensorElementType(info); |
| 340 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 341 | TensorPtr tensor; |
| 342 | switch (type) { |
| 343 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: |
| 344 | tensor = std::make_unique<Custom::Tensor<bool>>(api, ctx, ith_input, true); |
| 345 | break; |
| 346 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: |
| 347 | tensor = std::make_unique<Custom::Tensor<float>>(api, ctx, ith_input, true); |
| 348 | break; |
| 349 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: |
| 350 | tensor = std::make_unique<Custom::Tensor<double>>(api, ctx, ith_input, true); |
| 351 | break; |
| 352 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: |
| 353 | tensor = std::make_unique<Custom::Tensor<uint8_t>>(api, ctx, ith_input, true); |
| 354 | break; |
| 355 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: |
| 356 | tensor = std::make_unique<Custom::Tensor<int8_t>>(api, ctx, ith_input, true); |
| 357 | break; |
| 358 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: |
| 359 | tensor = std::make_unique<Custom::Tensor<uint16_t>>(api, ctx, ith_input, true); |
| 360 | break; |
| 361 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: |
| 362 | tensor = std::make_unique<Custom::Tensor<int16_t>>(api, ctx, ith_input, true); |
| 363 | break; |
| 364 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: |
| 365 | tensor = std::make_unique<Custom::Tensor<uint32_t>>(api, ctx, ith_input, true); |
| 366 | break; |
| 367 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: |
| 368 | tensor = std::make_unique<Custom::Tensor<int32_t>>(api, ctx, ith_input, true); |
| 369 | break; |
| 370 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: |
| 371 | tensor = std::make_unique<Custom::Tensor<uint64_t>>(api, ctx, ith_input, true); |
| 372 | break; |
| 373 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: |
| 374 | tensor = std::make_unique<Custom::Tensor<int64_t>>(api, ctx, ith_input, true); |
| 375 | break; |
| 376 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: |
| 377 | tensor = std::make_unique<Custom::Tensor<std::string>>(api, ctx, ith_input, true); |
| 378 | break; |
| 379 | default: |
| 380 | ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); |
| 381 | break; |
| 382 | } |
| 383 | tensors_.emplace_back(tensor.release()); |
| 384 | } // for |
| 385 | } else { |
| 386 | // a Variadic used for output is populated by the Compute so leave tensors_ empty here |
| 387 | } |
| 388 | } |
| 389 | template <typename T> |
| 390 | T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) { |
| 391 | auto tensor = std::make_unique<Tensor<T>>(api_, ctx_, ith_output, false); |
| 392 | auto raw_output = tensor.get()->Allocate(shape); |
| 393 | tensors_.emplace_back(tensor.release()); |
| 394 | return raw_output; |
| 395 | } |
| 396 | Tensor<std::string>& AllocateStringTensor(size_t ith_output) { |
| 397 | auto tensor = std::make_unique<Tensor<std::string>>(api_, ctx_, ith_output, false); |
| 398 | Tensor<std::string>& output = *tensor; |
| 399 | tensors_.emplace_back(tensor.release()); |
| 400 | return output; |
| 401 | } |
| 402 | const void* DataRaw() const override { |
| 403 | ORTX_CXX_API_THROW("DataRaw() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); |
| 404 | return nullptr; |
| 405 | } |
| 406 | size_t SizeInBytes() const override { |
| 407 | ORTX_CXX_API_THROW("SizeInBytes() cannot be applied to Variadic", ORT_RUNTIME_EXCEPTION); |
| 408 | return 0; |
| 409 | } |
| 410 | size_t Size() const { |
| 411 | return tensors_.size(); |
| 412 | } |
| 413 | const TensorPtr& operator[](size_t indice) const { |
| 414 | return tensors_.at(indice); |
| 415 | } |
| 416 | |
| 417 | private: |
| 418 | TensorPtrs tensors_; |
| 419 | }; |
| 420 | |
| 421 | struct OrtLiteCustomOp : public OrtCustomOp { |
| 422 | // CreateTuple |
| 423 | template <size_t ith_input, size_t ith_output, typename... Ts> |
| 424 | static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type |
| 425 | CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) { |
| 426 | return std::make_tuple(); |
| 427 | } |
| 428 | |
| 429 | template <size_t ith_input, size_t ith_output, typename T, typename... Ts> |
| 430 | static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type |
| 431 | CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { |
| 432 | std::tuple<T> current = std::tuple<OrtKernelContext*>{context}; |
| 433 | auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); |
| 434 | return std::tuple_cat(current, next); |
| 435 | } |
| 436 | |
#if ORT_API_VERSION >= 14
// Variadic parameters.
// `const Variadic*` / `const Variadic&` wrap the kernel inputs (advancing ith_input).
// `Variadic*` / `Variadic&` create an initially empty output collection (advancing ith_output).
// `tensors` owns the created Variadic. The conversion from TensorBase uses static_cast,
// the valid base-to-derived conversion; reinterpret_cast is not guaranteed to produce a
// usable pointer for a polymorphic class hierarchy.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
  std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())};
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_input, true));
  std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())};
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
  std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())};
  auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  tensors.push_back(std::make_unique<Variadic>(*api, *context, ith_output, false));
  std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())};
  auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}
#endif
| 474 | |
// CREATE_TUPLE_INPUT(data_type) expands to the CreateTuple overloads that bind
// kernel input `ith_input` to a compute-function parameter of one of these forms:
//   const Custom::Tensor<data_type>*  /  const Custom::Tensor<data_type>&
//   std::optional<const Custom::Tensor<data_type>*>   (empty when ith_input >= num_input)
//   const Custom::Span<data_type>*    /  const Custom::Span<data_type>&   (CPU tensors only)
//   std::optional<const Custom::Span<data_type>*>     (empty when ith_input >= num_input)
//   data_type  /  std::optional<data_type>            (scalar copy; CPU tensors only)
// Each created tensor is appended to `tensors`, which owns it, and the recursion
// continues with ith_input + 1. Conversions from TensorBase use static_cast, the
// valid base-to-derived conversion (reinterpret_cast is not guaranteed to yield a
// usable pointer for this polymorphic hierarchy).
// Line comments must not appear inside the macro body: they would swallow the
// backslash-continued lines that follow.
#define CREATE_TUPLE_INPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
  std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
  return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
  std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())}; \
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
  return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  if (ith_input < num_input) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } else { \
    std::tuple<T> current = std::tuple<T>{}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>*>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
  if (!tensors.back()->IsCpuTensor()) { \
    ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
  } \
  std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
  return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, const Custom::Span<data_type>&>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
  if (!tensors.back()->IsCpuTensor()) { \
    ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
  } \
  std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
  return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type>*>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  if (ith_input < num_input) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{&static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsSpan()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } else { \
    std::tuple<T> current = std::tuple<T>{}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, data_type>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
  if (!tensors.back()->IsCpuTensor()) { \
    ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
  } \
  std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
  return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<data_type>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  if (ith_input < num_input) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_input, true)); \
    if (!tensors.back()->IsCpuTensor()) { \
      ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL); \
    } \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())->AsScalar()}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } else { \
    std::tuple<T> current = std::tuple<T>{}; \
    auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
}
// CREATE_TUPLE_OUTPUT(data_type) expands to the CreateTuple overloads that bind
// kernel output `ith_output` to a compute-function parameter of one of these forms:
//   Custom::Tensor<data_type>*  /  Custom::Tensor<data_type>&
//   std::optional<Custom::Tensor<data_type>*>   (empty when ith_output >= num_output)
// The Tensor is created as an output (is_input == false) and owned by `tensors`;
// the compute function presumably writes it through the Tensor's output API.
// The recursion continues with ith_output + 1. Conversions from TensorBase use
// static_cast, the valid base-to-derived conversion.
// Line comments must not appear inside the macro body (they would swallow the
// continued lines).
#define CREATE_TUPLE_OUTPUT(data_type) \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>*>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
  std::tuple<T> current = std::tuple<T>{static_cast<T>(tensors.back().get())}; \
  auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
  return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type>&>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
  std::tuple<T> current = std::tuple<T>{static_cast<T>(*tensors.back())}; \
  auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
  return std::tuple_cat(current, next); \
} \
template <size_t ith_input, size_t ith_output, typename T, typename... Ts> \
static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value, std::tuple<T, Ts...>>::type \
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { \
  if (ith_output < num_output) { \
    tensors.push_back(std::make_unique<Custom::Tensor<data_type>>(*api, *context, ith_output, false)); \
    std::tuple<T> current = std::tuple<T>{static_cast<Custom::Tensor<data_type>*>(tensors.back().get())}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } else { \
    std::tuple<T> current = std::tuple<T>{}; \
    auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep); \
    return std::tuple_cat(current, next); \
  } \
}
| 604 | #define CREATE_TUPLE(data_type) \ |
| 605 | CREATE_TUPLE_INPUT(data_type) \ |
| 606 | CREATE_TUPLE_OUTPUT(data_type) |
| 607 | |
| 608 | CREATE_TUPLE(bool) |
| 609 | CREATE_TUPLE(float) |
| 610 | CREATE_TUPLE(double) |
| 611 | CREATE_TUPLE(int8_t) |
| 612 | CREATE_TUPLE(int16_t) |
| 613 | CREATE_TUPLE(int32_t) |
| 614 | CREATE_TUPLE(int64_t) |
| 615 | CREATE_TUPLE(uint8_t) |
| 616 | CREATE_TUPLE(uint16_t) |
| 617 | CREATE_TUPLE(uint32_t) |
| 618 | CREATE_TUPLE(uint64_t) |
| 619 | CREATE_TUPLE(std::string) |
| 620 | CREATE_TUPLE_INPUT(std::string_view) |
| 621 | |
| 622 | // ParseArgs ... |
| 623 | template <typename... Ts> |
| 624 | static typename std::enable_if<0 == sizeof...(Ts)>::type |
| 625 | ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) { |
| 626 | } |
| 627 | |
| 628 | template <typename T, typename... Ts> |
| 629 | static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type |
| 630 | ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { |
| 631 | ParseArgs<Ts...>(input_types, output_types); |
| 632 | } |
| 633 | |
#if ORT_API_VERSION >= 14
// Variadic parameters are recorded as a single ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
// slot. GetInputCharacteristic / GetOutputCharacteristic report *every* slot as
// variadic when the first entry is UNDEFINED, so a variadic input (output) must be
// the op's only input (output). Both the parameters before the variadic one and
// those after it are checked; previously, ones declared after it slipped through.

// `const Variadic&` / `const Variadic*` parameter: the variadic input.
template <typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic&>::value ||
                               std::is_same<T, const Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  // Reject inputs declared before the variadic one...
  if (!input_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
  }
  input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
  // ...and inputs declared after it (the recursion appends them to the same vector).
  if (input_types.size() != 1) {
    ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
  }
}

// `Variadic&` / `Variadic*` parameter: the variadic output.
template <typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic&>::value ||
                               std::is_same<T, Variadic*>::value>::type
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  // Reject outputs declared before the variadic one...
  if (!output_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
  }
  output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
  // ...and outputs declared after it.
  if (output_types.size() != 1) {
    ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
  }
}
#endif
| 675 | |
// PARSE_INPUT_BASE(pack_type, onnx_type): when the current parameter is
// `pack_type` or `std::optional<pack_type>`, record `onnx_type` as the next
// input type and continue with the remaining parameters.
#define PARSE_INPUT_BASE(pack_type, onnx_type)                                                     \
  template <typename T, typename... Ts>                                                            \
  static std::enable_if_t<std::is_same<T, pack_type>::value ||                                     \
                          std::is_same<T, std::optional<pack_type>>::value>                        \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                                   \
            std::vector<ONNXTensorElementDataType>& output_types) {                                \
    input_types.push_back(onnx_type);                                                              \
    ParseArgs<Ts...>(input_types, output_types);                                                   \
  }
| 689 | |
// PARSE_INPUT(data_type, onnx_type): registers every accepted input spelling for
// one element type. These are a Tensor or Span taken by const pointer or const
// reference, or a bare `data_type` scalar; PARSE_INPUT_BASE also covers each
// one's std::optional form.
#define PARSE_INPUT(data_type, onnx_type)                        \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, onnx_type)  \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, onnx_type)  \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, onnx_type)    \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, onnx_type)    \
  PARSE_INPUT_BASE(data_type, onnx_type)
| 696 | |
// PARSE_OUTPUT(data_type, onnx_type): a non-const `Custom::Tensor<data_type>`
// taken by pointer, by reference, or as `std::optional` of a pointer is an
// output; record `onnx_type` as the next output type and keep parsing.
#define PARSE_OUTPUT(data_type, onnx_type)                                                        \
  template <typename T, typename... Ts>                                                           \
  static std::enable_if_t<std::is_same<T, Custom::Tensor<data_type>*>::value ||                   \
                          std::is_same<T, Custom::Tensor<data_type>&>::value ||                   \
                          std::is_same<T, std::optional<Custom::Tensor<data_type>*>>::value>      \
  ParseArgs(std::vector<ONNXTensorElementDataType>& input_types,                                  \
            std::vector<ONNXTensorElementDataType>& output_types) {                               \
    output_types.push_back(onnx_type);                                                            \
    ParseArgs<Ts...>(input_types, output_types);                                                  \
  }
| 716 | |
| 717 | #define PARSE_ARGS(data_type, onnx_type) \ |
| 718 | PARSE_INPUT(data_type, onnx_type) \ |
| 719 | PARSE_OUTPUT(data_type, onnx_type) |
| 720 | |
| 721 | PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) |
| 722 | PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) |
| 723 | PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) |
| 724 | PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) |
| 725 | PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) |
| 726 | PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) |
| 727 | PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) |
| 728 | PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) |
| 729 | PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) |
| 730 | PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) |
| 731 | PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) |
| 732 | PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) |
| 733 | PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output |
| 734 | |
| 735 | OrtLiteCustomOp(const char* op_name, |
| 736 | const char* execution_provider) : op_name_(op_name), |
| 737 | execution_provider_(execution_provider) { |
| 738 | int act_ver = GetActiveOrtAPIVersion(); |
| 739 | OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION; |
| 740 | |
| 741 | OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); }; |
| 742 | OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); }; |
| 743 | |
| 744 | OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) { |
| 745 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 746 | return self->input_types_.size(); |
| 747 | }; |
| 748 | |
| 749 | OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) { |
| 750 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 751 | return self->input_types_[indice]; |
| 752 | }; |
| 753 | |
| 754 | OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) { |
| 755 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 756 | return self->output_types_.size(); |
| 757 | }; |
| 758 | |
| 759 | OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) { |
| 760 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 761 | return self->output_types_[indice]; |
| 762 | }; |
| 763 | |
| 764 | #if ORT_API_VERSION >= 14 |
| 765 | OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) { |
| 766 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 767 | return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC; |
| 768 | }; |
| 769 | |
| 770 | OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) { |
| 771 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 772 | return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC; |
| 773 | }; |
| 774 | |
| 775 | OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { |
| 776 | return 1; |
| 777 | }; |
| 778 | |
| 779 | OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { |
| 780 | return 0; |
| 781 | }; |
| 782 | |
| 783 | OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { |
| 784 | return 1; |
| 785 | }; |
| 786 | |
| 787 | OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { |
| 788 | return 0; |
| 789 | }; |
| 790 | |
| 791 | OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { |
| 792 | return OrtMemTypeDefault; |
| 793 | }; |
| 794 | #else |
| 795 | OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { |
| 796 | return INPUT_OUTPUT_OPTIONAL; |
| 797 | }; |
| 798 | |
| 799 | OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) { |
| 800 | return INPUT_OUTPUT_OPTIONAL; |
| 801 | }; |
| 802 | #endif |
| 803 | } |
| 804 | |
| 805 | const std::string op_name_; |
| 806 | const std::string execution_provider_; |
| 807 | |
| 808 | std::vector<ONNXTensorElementDataType> input_types_; |
| 809 | std::vector<ONNXTensorElementDataType> output_types_; |
| 810 | }; |
| 811 | |
| 812 | template <typename... Args> |
| 813 | struct OrtLiteCustomFunc : public OrtLiteCustomOp { |
| 814 | using ComputeFn = void (*)(Args...); |
| 815 | using MyType = OrtLiteCustomFunc<Args...>; |
| 816 | |
| 817 | struct Kernel { |
| 818 | ComputeFn compute_fn_{}; |
| 819 | std::string ep_{}; |
| 820 | std::unique_ptr<OrtW::CustomOpApi> api_; |
| 821 | }; |
| 822 | |
| 823 | OrtLiteCustomFunc(const char* op_name, |
| 824 | const char* execution_provider, |
| 825 | ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider), |
| 826 | compute_fn_(compute_fn) { |
| 827 | ParseArgs<Args...>(input_types_, output_types_); |
| 828 | |
| 829 | OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { |
| 830 | auto kernel = reinterpret_cast<Kernel*>(op_kernel); |
| 831 | std::vector<TensorPtr> tensors; |
| 832 | auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(), |
| 833 | context, |
| 834 | tensors, |
| 835 | kernel->api_->KernelContext_GetInputCount(context), |
| 836 | kernel->api_->KernelContext_GetOutputCount(context), |
| 837 | kernel->ep_); |
| 838 | std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t); |
| 839 | }; |
| 840 | |
| 841 | OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { |
| 842 | auto kernel = std::make_unique<Kernel>(); |
| 843 | auto self = static_cast<const OrtLiteCustomFunc*>(this_); |
| 844 | kernel->compute_fn_ = self->compute_fn_; |
| 845 | kernel->ep_ = self->execution_provider_; |
| 846 | kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api); |
| 847 | return reinterpret_cast<void*>(kernel.release()); |
| 848 | }; |
| 849 | |
| 850 | OrtCustomOp::KernelDestroy = [](void* op_kernel) { |
| 851 | delete reinterpret_cast<Kernel*>(op_kernel); |
| 852 | }; |
| 853 | } |
| 854 | |
| 855 | ComputeFn compute_fn_; |
| 856 | }; |
| 857 | |
| 858 | template <typename CustomOp> |
| 859 | struct OrtLiteCustomStruct : public OrtLiteCustomOp { |
| 860 | template <typename... Args> |
| 861 | using CustomComputeFn = void (CustomOp::*)(Args...) const; |
| 862 | using MyType = OrtLiteCustomStruct<CustomOp>; |
| 863 | |
| 864 | struct Kernel { |
| 865 | std::unique_ptr<CustomOp> custom_op_; |
| 866 | std::string ep_{}; |
| 867 | std::unique_ptr<OrtW::CustomOpApi> api_; |
| 868 | }; |
| 869 | |
| 870 | OrtLiteCustomStruct(const char* op_name, |
| 871 | const char* execution_provider) : OrtLiteCustomOp(op_name, |
| 872 | execution_provider) { |
| 873 | init(&CustomOp::Compute); |
| 874 | } |
| 875 | |
| 876 | template <typename... Args> |
| 877 | void init(CustomComputeFn<Args...>) { |
| 878 | ParseArgs<Args...>(input_types_, output_types_); |
| 879 | |
| 880 | OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { |
| 881 | auto kernel = reinterpret_cast<Kernel*>(op_kernel); |
| 882 | std::vector<TensorPtr> tensors; |
| 883 | auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(), |
| 884 | context, |
| 885 | tensors, |
| 886 | kernel->api_->KernelContext_GetInputCount(context), |
| 887 | kernel->api_->KernelContext_GetOutputCount(context), |
| 888 | kernel->ep_); |
| 889 | std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); |
| 890 | }; |
| 891 | |
| 892 | OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { |
| 893 | auto kernel = std::make_unique<Kernel>(); |
| 894 | kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info); |
| 895 | auto self = static_cast<const MyType*>(this_); |
| 896 | kernel->ep_ = self->execution_provider_; |
| 897 | kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api); |
| 898 | return reinterpret_cast<void*>(kernel.release()); |
| 899 | }; |
| 900 | |
| 901 | OrtCustomOp::KernelDestroy = [](void* op_kernel) { |
| 902 | delete reinterpret_cast<Kernel*>(op_kernel); |
| 903 | }; |
| 904 | } |
| 905 | }; |
| 906 | |
| 907 | template <typename... Args> |
| 908 | OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, |
| 909 | const char* execution_provider, |
| 910 | void (*custom_compute_fn)(Args...)) { |
| 911 | using LiteOp = OrtLiteCustomFunc<Args...>; |
| 912 | return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release(); |
| 913 | } |
| 914 | |
| 915 | template <typename CustomOp> |
| 916 | OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, |
| 917 | const char* execution_provider) { |
| 918 | using LiteOp = OrtLiteCustomStruct<CustomOp>; |
| 919 | return std::make_unique<LiteOp>(op_name, execution_provider).release(); |
| 920 | } |
| 921 | |
| 922 | } // namespace Custom |
| 923 | } // namespace Ort |