microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
include/custom_op/custom_op_lite.h
1035 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | #pragma once |
| 5 | |
#include <cstdint>
#include <memory>
#include <numeric>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "tensor_api.h"
#include "ort_c_to_cpp.h"
#include "onnxruntime_f16.h"
| 13 | |
| 14 | namespace Ort { |
| 15 | namespace Custom { |
| 16 | |
| 17 | class OrtKernelContextStorage : public ITensorStorage { |
| 18 | public: |
| 19 | OrtKernelContextStorage(const OrtW::CustomOpApi& custom_op_api, |
| 20 | OrtKernelContext& ctx, |
| 21 | size_t indice, |
| 22 | bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) { |
| 23 | if (is_input) { |
| 24 | auto input_count = api_.KernelContext_GetInputCount(&ctx); |
| 25 | if (indice >= input_count) { |
| 26 | ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); |
| 27 | } |
| 28 | const_value_ = api_.KernelContext_GetInput(&ctx, indice); |
| 29 | auto* info = api_.GetTensorTypeAndShape(const_value_); |
| 30 | shape_ = api_.GetTensorShape(info); |
| 31 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 32 | } |
| 33 | } |
| 34 | |
| 35 | const std::vector<int64_t>& Shape() const override { |
| 36 | if (!IsInitialized()) |
| 37 | ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); |
| 38 | return *shape_; |
| 39 | } |
| 40 | |
| 41 | virtual bool IsInitialized() const override { |
| 42 | return shape_.has_value(); |
| 43 | } |
| 44 | |
| 45 | const void* DataRaw() const override { |
| 46 | return api_.GetTensorRawData(const_value_); |
| 47 | } |
| 48 | |
| 49 | void* Initialize(const std::vector<int64_t>& shape, size_t element_size) override { |
| 50 | if (!const_value_) { |
| 51 | const_value_ = api_.KernelContext_GetOutput(&ctx_, indice_, shape.data(), shape.size()); |
| 52 | shape_ = shape; |
| 53 | } |
| 54 | return api_.GetTensorMutableRawData(const_cast<OrtValue*>(const_value_)); |
| 55 | } |
| 56 | |
| 57 | void* Release() override { |
| 58 | ORTX_CXX_API_THROW("Can't release the tensor buffer with ORT graph mode.", ORT_RUNTIME_EXCEPTION); |
| 59 | } |
| 60 | |
| 61 | private: |
| 62 | const OrtW::CustomOpApi& api_; |
| 63 | OrtKernelContext& ctx_; |
| 64 | size_t indice_; |
| 65 | const OrtValue* const_value_{}; // for input |
| 66 | std::optional<std::vector<int64_t>> shape_; |
| 67 | }; |
| 68 | |
| 69 | static std::string get_mem_type(const OrtW::CustomOpApi& custom_op_api, |
| 70 | OrtKernelContext& ctx, |
| 71 | size_t indice, |
| 72 | bool is_input) { |
| 73 | std::string output = "Cpu"; |
| 74 | if (is_input) { |
| 75 | const OrtValue* const_value = custom_op_api.KernelContext_GetInput(&ctx, indice); |
| 76 | const OrtMemoryInfo* mem_info = {}; |
| 77 | custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().GetTensorMemoryInfo(const_value, &mem_info)); |
| 78 | if (mem_info) { |
| 79 | const char* mem_type = nullptr; |
| 80 | custom_op_api.ThrowOnError(custom_op_api.GetOrtApi().MemoryInfoGetName(mem_info, &mem_type)); |
| 81 | if (mem_type) { |
| 82 | output = mem_type; |
| 83 | } |
| 84 | } |
| 85 | } |
| 86 | return output; |
| 87 | } |
| 88 | |
| 89 | template <typename T> |
| 90 | class OrtTensor : public Tensor<T> { |
| 91 | public: |
| 92 | OrtTensor(const OrtW::CustomOpApi& custom_op_api, |
| 93 | OrtKernelContext& ctx, |
| 94 | size_t indice, |
| 95 | bool is_input) : Tensor<T>(std::make_unique<OrtKernelContextStorage>(custom_op_api, ctx, indice, is_input)), |
| 96 | mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) { |
| 97 | } |
| 98 | |
| 99 | bool IsCpuTensor() const { |
| 100 | return mem_type_ == "Cpu"; |
| 101 | } |
| 102 | |
| 103 | private: |
| 104 | std::string mem_type_ = "Cpu"; |
| 105 | }; |
| 106 | |
| 107 | class OrtStringTensorStorage : public IStringTensorStorage<std::string> { |
| 108 | public: |
| 109 | using strings = std::vector<std::string>; |
| 110 | OrtStringTensorStorage(const OrtW::CustomOpApi& custom_op_api, |
| 111 | OrtKernelContext& ctx, |
| 112 | size_t indice, |
| 113 | bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) { |
| 114 | if (is_input) { |
| 115 | auto input_count = api_.KernelContext_GetInputCount(&ctx_); |
| 116 | if (indice >= input_count) { |
| 117 | ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); |
| 118 | } |
| 119 | |
| 120 | auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); |
| 121 | auto* info = api_.GetTensorTypeAndShape(const_value); |
| 122 | shape_ = api_.GetTensorShape(info); |
| 123 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 124 | |
| 125 | size_t num_chars; |
| 126 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); |
| 127 | std::vector<char> chars(num_chars + 1, '\0'); |
| 128 | // assert((*shape_).size() == 1 || ((*shape_).size() == 2 && (*shape_)[0] == 1)); |
| 129 | |
| 130 | int64_t num_strings = 1; // string scalar |
| 131 | if ((*shape_).size() > 0) { |
| 132 | num_strings = (*shape_).front(); |
| 133 | for (auto iter = (*shape_).begin() + 1; iter != (*shape_).end(); ++iter) { |
| 134 | num_strings *= *iter; |
| 135 | } |
| 136 | } |
| 137 | std::vector<size_t> offsets(num_strings); |
| 138 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, |
| 139 | (void*)chars.data(), |
| 140 | num_chars, |
| 141 | offsets.data(), |
| 142 | offsets.size())); |
| 143 | auto upper_bound = static_cast<int64_t>(num_strings) - 1; |
| 144 | input_strings_.resize(num_strings); |
| 145 | for (int64_t i = upper_bound; i >= 0; --i) { |
| 146 | if (i < upper_bound) { |
| 147 | chars[offsets[i + 1]] = '\0'; |
| 148 | } |
| 149 | input_strings_[i] = chars.data() + offsets[i]; |
| 150 | } |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | const std::vector<int64_t>& Shape() const override { |
| 155 | if (!IsInitialized()) |
| 156 | ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); |
| 157 | return *shape_; |
| 158 | } |
| 159 | |
| 160 | virtual const void* DataRaw() const override { |
| 161 | if (input_strings_.size() != 1) { |
| 162 | ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); |
| 163 | } |
| 164 | return reinterpret_cast<const void*>(input_strings_[0].c_str()); |
| 165 | } |
| 166 | |
| 167 | virtual bool IsInitialized() const override { |
| 168 | return shape_.has_value(); |
| 169 | } |
| 170 | |
| 171 | virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override { |
| 172 | std::vector<const char*> raw; |
| 173 | for (const auto& s : ss) { |
| 174 | raw.push_back(s.data()); |
| 175 | } |
| 176 | auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); |
| 177 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, raw.data(), raw.size())); |
| 178 | } |
| 179 | |
| 180 | virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override { |
| 181 | auto* output = api_.KernelContext_GetOutput(&ctx_, indice_, dims.data(), dims.size()); |
| 182 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().FillStringTensor(output, ss.data(), ss.size())); |
| 183 | } |
| 184 | |
| 185 | const strings& Data() const override { |
| 186 | return input_strings_; |
| 187 | } |
| 188 | |
| 189 | private: |
| 190 | const OrtW::CustomOpApi& api_; |
| 191 | OrtKernelContext& ctx_; |
| 192 | size_t indice_; |
| 193 | std::vector<std::string> input_strings_; |
| 194 | std::optional<std::vector<int64_t>> shape_; |
| 195 | }; |
| 196 | |
| 197 | class OrtStringViewTensorStorage : public IStringTensorStorage<std::string_view> { |
| 198 | public: |
| 199 | using strings = std::vector<std::string_view>; |
| 200 | OrtStringViewTensorStorage(const OrtW::CustomOpApi& custom_op_api, |
| 201 | OrtKernelContext& ctx, |
| 202 | size_t indice, |
| 203 | bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice) { |
| 204 | if (is_input) { |
| 205 | auto input_count = api_.KernelContext_GetInputCount(&ctx_); |
| 206 | if (indice >= input_count) { |
| 207 | ORTX_CXX_API_THROW("invalid indice", ORT_RUNTIME_EXCEPTION); |
| 208 | } |
| 209 | auto* const_value = api_.KernelContext_GetInput(&ctx_, indice); |
| 210 | auto* info = api_.GetTensorTypeAndShape(const_value); |
| 211 | shape_ = api_.GetTensorShape(info); |
| 212 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 213 | |
| 214 | size_t num_chars; |
| 215 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorDataLength(const_value, &num_chars)); |
| 216 | chars_.resize(num_chars + 1, '\0'); |
| 217 | |
| 218 | size_t num_strings = 1; |
| 219 | if ((*shape_).size() > 0) { |
| 220 | num_strings = static_cast<size_t>((*shape_)[0]); |
| 221 | } |
| 222 | |
| 223 | if (num_strings) { |
| 224 | std::vector<size_t> offsets(num_strings); |
| 225 | OrtW::ThrowOnError(api_.GetOrtApi(), api_.GetOrtApi().GetStringTensorContent(const_value, |
| 226 | (void*)chars_.data(), |
| 227 | num_chars, |
| 228 | offsets.data(), |
| 229 | offsets.size())); |
| 230 | offsets.push_back(num_chars); |
| 231 | for (size_t i = 0; i < num_strings; ++i) { |
| 232 | input_string_views_.emplace_back(chars_.data() + offsets[i], offsets[i + 1] - offsets[i]); |
| 233 | } |
| 234 | } |
| 235 | } |
| 236 | } |
| 237 | |
| 238 | const std::vector<int64_t>& Shape() const override { |
| 239 | if (!IsInitialized()) |
| 240 | ORTX_CXX_API_THROW("Tensor not initialized", ORT_RUNTIME_EXCEPTION); |
| 241 | return *shape_; |
| 242 | } |
| 243 | |
| 244 | virtual const void* DataRaw() const override { |
| 245 | if (input_string_views_.size() != 1) { |
| 246 | ORTX_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); |
| 247 | } |
| 248 | return reinterpret_cast<const void*>(input_string_views_[0].data()); |
| 249 | } |
| 250 | |
| 251 | virtual bool IsInitialized() const override { |
| 252 | return shape_.has_value(); |
| 253 | } |
| 254 | |
| 255 | virtual void SetStringOutput(const strings& ss, const std::vector<int64_t>& dims) override { |
| 256 | ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); |
| 257 | } |
| 258 | |
| 259 | virtual void SetStringOutput(const std::vector<const char*>& ss, const std::vector<int64_t>& dims) override { |
| 260 | ORTX_CXX_API_THROW("Set output for string view tensor is not supported", ORT_RUNTIME_EXCEPTION); |
| 261 | } |
| 262 | |
| 263 | const strings& Data() const override { |
| 264 | return input_string_views_; |
| 265 | } |
| 266 | |
| 267 | private: |
| 268 | const OrtW::CustomOpApi& api_; |
| 269 | OrtKernelContext& ctx_; |
| 270 | size_t indice_; |
| 271 | std::vector<char> chars_; // for input |
| 272 | std::vector<std::string_view> input_string_views_; // for input |
| 273 | std::optional<std::vector<int64_t>> shape_; |
| 274 | }; |
| 275 | |
| 276 | // to make the metaprogramming magic happy. |
| 277 | template <> |
| 278 | class OrtTensor<std::string> : public Tensor<std::string> { |
| 279 | public: |
| 280 | OrtTensor(const OrtW::CustomOpApi& custom_op_api, |
| 281 | OrtKernelContext& ctx, |
| 282 | size_t indice, |
| 283 | bool is_input) : Tensor<std::string>(std::make_unique<OrtStringTensorStorage>(custom_op_api, ctx, indice, is_input)), |
| 284 | mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {} |
| 285 | |
| 286 | bool IsCpuTensor() const { |
| 287 | return mem_type_ == "Cpu"; |
| 288 | } |
| 289 | |
| 290 | private: |
| 291 | std::string mem_type_ = "Cpu"; |
| 292 | }; |
| 293 | |
| 294 | template <> |
| 295 | class OrtTensor<std::string_view> : public Tensor<std::string_view> { |
| 296 | public: |
| 297 | OrtTensor(const OrtW::CustomOpApi& custom_op_api, |
| 298 | OrtKernelContext& ctx, |
| 299 | size_t indice, |
| 300 | bool is_input) : Tensor<std::string_view>(std::make_unique<OrtStringViewTensorStorage>(custom_op_api, ctx, indice, is_input)), |
| 301 | mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) {} |
| 302 | |
| 303 | bool IsCpuTensor() const { |
| 304 | return mem_type_ == "Cpu"; |
| 305 | } |
| 306 | |
| 307 | private: |
| 308 | std::string mem_type_ = "Cpu"; |
| 309 | }; |
| 310 | |
| 311 | using TensorPtr = std::unique_ptr<Custom::Arg>; |
| 312 | using TensorPtrs = std::vector<TensorPtr>; |
| 313 | |
| 314 | using TensorBasePtr = std::unique_ptr<Custom::TensorBase>; |
| 315 | using TensorBasePtrs = std::vector<TensorBasePtr>; |
| 316 | |
| 317 | // Represent variadic input or output |
| 318 | struct Variadic : public Arg { |
| 319 | Variadic(const OrtW::CustomOpApi& custom_op_api, |
| 320 | OrtKernelContext& ctx, |
| 321 | size_t indice, |
| 322 | bool is_input) : api_(custom_op_api), ctx_(ctx), indice_(indice), mem_type_(get_mem_type(custom_op_api, ctx, indice, is_input)) { |
| 323 | #if ORT_API_VERSION < 14 |
| 324 | ORTX_CXX_API_THROW("Variadic input or output only supported after onnxruntime 1.14", ORT_RUNTIME_EXCEPTION); |
| 325 | #endif |
| 326 | if (is_input) { |
| 327 | auto input_count = api_.KernelContext_GetInputCount(&ctx_); |
| 328 | for (size_t ith_input = 0; ith_input < input_count; ++ith_input) { |
| 329 | auto* const_value = api_.KernelContext_GetInput(&ctx_, ith_input); |
| 330 | auto* info = api_.GetTensorTypeAndShape(const_value); |
| 331 | auto type = api_.GetTensorElementType(info); |
| 332 | api_.ReleaseTensorTypeAndShapeInfo(info); |
| 333 | TensorBasePtr tensor; |
| 334 | switch (type) { |
| 335 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: |
| 336 | tensor = std::make_unique<Custom::OrtTensor<bool>>(api_, ctx, ith_input, true); |
| 337 | break; |
| 338 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: |
| 339 | tensor = std::make_unique<Custom::OrtTensor<float>>(api_, ctx, ith_input, true); |
| 340 | break; |
| 341 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: |
| 342 | tensor = std::make_unique<Custom::OrtTensor<double>>(api_, ctx, ith_input, true); |
| 343 | break; |
| 344 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: |
| 345 | tensor = std::make_unique<Custom::OrtTensor<uint8_t>>(api_, ctx, ith_input, true); |
| 346 | break; |
| 347 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: |
| 348 | tensor = std::make_unique<Custom::OrtTensor<int8_t>>(api_, ctx, ith_input, true); |
| 349 | break; |
| 350 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: |
| 351 | tensor = std::make_unique<Custom::OrtTensor<uint16_t>>(api_, ctx, ith_input, true); |
| 352 | break; |
| 353 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: |
| 354 | tensor = std::make_unique<Custom::OrtTensor<int16_t>>(api_, ctx, ith_input, true); |
| 355 | break; |
| 356 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: |
| 357 | tensor = std::make_unique<Custom::OrtTensor<uint32_t>>(api_, ctx, ith_input, true); |
| 358 | break; |
| 359 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: |
| 360 | tensor = std::make_unique<Custom::OrtTensor<int32_t>>(api_, ctx, ith_input, true); |
| 361 | break; |
| 362 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: |
| 363 | tensor = std::make_unique<Custom::OrtTensor<uint64_t>>(api_, ctx, ith_input, true); |
| 364 | break; |
| 365 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: |
| 366 | tensor = std::make_unique<Custom::OrtTensor<int64_t>>(api_, ctx, ith_input, true); |
| 367 | break; |
| 368 | case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: |
| 369 | tensor = std::make_unique<Custom::OrtTensor<std::string>>(api_, ctx, ith_input, true); |
| 370 | break; |
| 371 | default: |
| 372 | ORTX_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); |
| 373 | break; |
| 374 | } |
| 375 | tensors_.emplace_back(tensor.release()); |
| 376 | } // for |
| 377 | } else { |
| 378 | // a Variadic used for output is populated by the Compute so leave tensors_ empty here |
| 379 | } |
| 380 | } |
| 381 | template <typename T> |
| 382 | T* AllocateOutput(size_t ith_output, const std::vector<int64_t>& shape) { |
| 383 | auto tensor = std::make_unique<OrtTensor<T>>(api_, ctx_, ith_output, false); |
| 384 | auto raw_output = tensor.get()->Allocate(shape); |
| 385 | tensors_.emplace_back(tensor.release()); |
| 386 | return raw_output; |
| 387 | } |
| 388 | Tensor<std::string>& AllocateStringTensor(size_t ith_output) { |
| 389 | auto tensor = std::make_unique<OrtTensor<std::string>>(api_, ctx_, ith_output, false); |
| 390 | Tensor<std::string>& output = *tensor; |
| 391 | tensors_.emplace_back(tensor.release()); |
| 392 | return output; |
| 393 | } |
| 394 | size_t Size() const { |
| 395 | return tensors_.size(); |
| 396 | } |
| 397 | |
| 398 | const TensorBasePtr& operator[](size_t indice) const { |
| 399 | return tensors_.at(indice); |
| 400 | } |
| 401 | |
| 402 | private: |
| 403 | const OrtW::CustomOpApi& api_; |
| 404 | OrtKernelContext& ctx_; |
| 405 | size_t indice_; |
| 406 | std::string mem_type_ = "Cpu"; |
| 407 | TensorBasePtrs tensors_; |
| 408 | }; |
| 409 | |
#if ORT_API_VERSION >= 17

// Graph-mode KernelContext. Scratch buffers come from the kernel context's allocator for a
// CPU arena memory info.
class OrtGraphKernelContext : public KernelContext {
 private:
  // Hands a context allocator back to ORT via ReleaseAllocator.
  struct AllocatorDeleter {
    const OrtApi* api;
    void operator()(OrtAllocator* allocator) const { api->ReleaseAllocator(allocator); }
  };

 public:
  OrtGraphKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx)
      : api_(ort_api), allocator_(nullptr, AllocatorDeleter{&ort_api}) {
    OrtMemoryInfo* info = nullptr;
    OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &info));
    // Own the memory info right away so it is released even if fetching the allocator throws.
    std::unique_ptr<OrtMemoryInfo, decltype(OrtApi::ReleaseMemoryInfo)> info_guard(info, api_.ReleaseMemoryInfo);
    OrtAllocator* allocator = nullptr;
    OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, info, &allocator));
    allocator_.reset(allocator);
  }

  // allocator_ releases the allocator; the unique_ptr member also makes this class
  // non-copyable, so the allocator can never be released twice.
  virtual ~OrtGraphKernelContext() = default;

  void* AllocScratchBuffer(size_t size) override {
    return allocator_->Alloc(allocator_.get(), size);
  }

  // Null pointers are ignored.
  void FreeScratchBuffer(void* p) override {
    if (p) {
      allocator_->Free(allocator_.get(), p);
    }
  }

 private:
  const OrtApi& api_;
  std::unique_ptr<OrtAllocator, AllocatorDeleter> allocator_;
};

#endif
| 443 | |
#ifdef USE_CUDA

// Resource ids passed to KernelContext_GetResource.
// NOTE(review): these values are assumed to mirror onnxruntime's CUDA resource enum
// (starting at 10000); keep them in sync with it.
enum CudaResource {
  cuda_handle_t = 10000,
  cudnn_handle_t,
  cublas_handle_t,
  deferred_cpu_allocator_t,
  // below are cuda ep options
  device_id_t,
};

#if ORT_API_VERSION >= 17
// Graph-mode CUDAKernelContext. Fetches the CUDA stream, cuBLAS handle and device id from the
// kernel context, plus a CPU arena allocator and a "Cuda" allocator for scratch buffers.
class OrtGraphCudaKernelContext : public CUDAKernelContext {
 private:
  // Hands a context allocator back to ORT via ReleaseAllocator.
  struct AllocatorDeleter {
    const OrtApi* api;
    void operator()(OrtAllocator* allocator) const { api->ReleaseAllocator(allocator); }
  };
  using AllocatorPtr = std::unique_ptr<OrtAllocator, AllocatorDeleter>;

 public:
  static const int cuda_resource_ver = 1;

  OrtGraphCudaKernelContext(const OrtApi& ort_api, const OrtKernelContext& ctx)
      : api_(ort_api),
        cpu_allocator_(nullptr, AllocatorDeleter{&ort_api}),
        cuda_allocator_(nullptr, AllocatorDeleter{&ort_api}) {
    if (!TryGetResource(ctx, CudaResource::cuda_handle_t, &cuda_stream_) || !cuda_stream_) {
      ORTX_CXX_API_THROW("Failed to fetch cuda stream from context", ORT_RUNTIME_EXCEPTION);
    }
    if (!TryGetResource(ctx, CudaResource::cublas_handle_t, &cublas_) || !cublas_) {
      ORTX_CXX_API_THROW("Failed to fetch cublas handle from context", ORT_RUNTIME_EXCEPTION);
    }
    void* resource = nullptr;
    if (!TryGetResource(ctx, CudaResource::device_id_t, &resource)) {
      ORTX_CXX_API_THROW("Failed to fetch device id from context", ORT_RUNTIME_EXCEPTION);
    }
    // The device id travels in the pointer value itself. Converting it arithmetically works on
    // any byte order, unlike memcpy-ing the first sizeof(int) bytes of the pointer.
    device_id_ = static_cast<int>(reinterpret_cast<std::intptr_t>(resource));

    OrtMemoryInfo* cpu_mem_info = nullptr;
    OrtW::ThrowOnError(api_, api_.CreateCpuMemoryInfo(OrtAllocatorType::OrtArenaAllocator, OrtMemType::OrtMemTypeDefault, &cpu_mem_info));
    cpu_allocator_ = FetchAllocator(ctx, cpu_mem_info);

    OrtMemoryInfo* cuda_mem_info = nullptr;
    OrtW::ThrowOnError(api_, api_.CreateMemoryInfo("Cuda", OrtAllocatorType::OrtArenaAllocator, device_id_, OrtMemType::OrtMemTypeDefault, &cuda_mem_info));
    // If this throws, the already-fetched cpu_allocator_ is released by its member destructor.
    cuda_allocator_ = FetchAllocator(ctx, cuda_mem_info);
  }

  // Both allocators are released by their unique_ptr members.
  virtual ~OrtGraphCudaKernelContext() = default;

  void* AllocScratchBuffer(size_t size) override {
    return cpu_allocator_->Alloc(cpu_allocator_.get(), size);
  }

  // Null pointers are ignored.
  void FreeScratchBuffer(void* p) override {
    if (p) {
      cpu_allocator_->Free(cpu_allocator_.get(), p);
    }
  }

  void* AllocCudaScratchBuffer(size_t size) override {
    return cuda_allocator_->Alloc(cuda_allocator_.get(), size);
  }

  // Null pointers are ignored.
  void FreeCudaScratchBuffer(void* p) override {
    if (p) {
      cuda_allocator_->Free(cuda_allocator_.get(), p);
    }
  }

  void* GetCudaStream() const override {
    return cuda_stream_;
  }

  void* GetCublasHandle() const override {
    return cublas_;
  }

  int GetCudaDeviceId() const override {
    return device_id_;
  }

 private:
  // Queries one resource. Returns false on error; the error OrtStatus is released here
  // because only success or failure matters to the caller.
  bool TryGetResource(const OrtKernelContext& ctx, int resource_id, void** resource) const {
    OrtStatusPtr status = api_.KernelContext_GetResource(&ctx, cuda_resource_ver, resource_id, resource);
    if (status != nullptr) {
      api_.ReleaseStatus(status);
      return false;
    }
    return true;
  }

  // Returns the context allocator for `mem_info`. Takes ownership of `mem_info` and always
  // releases it, including when KernelContext_GetAllocator fails.
  AllocatorPtr FetchAllocator(const OrtKernelContext& ctx, OrtMemoryInfo* mem_info) const {
    std::unique_ptr<OrtMemoryInfo, decltype(OrtApi::ReleaseMemoryInfo)> info_guard(mem_info, api_.ReleaseMemoryInfo);
    OrtAllocator* allocator = nullptr;
    OrtW::ThrowOnError(api_, api_.KernelContext_GetAllocator(&ctx, mem_info, &allocator));
    return AllocatorPtr(allocator, AllocatorDeleter{&api_});
  }

  const OrtApi& api_;
  AllocatorPtr cpu_allocator_;
  AllocatorPtr cuda_allocator_;
  void* cuda_stream_ = {};
  void* cublas_ = {};
  int device_id_ = 0;
};

#endif
#endif
| 539 | |
| 540 | // using mf16_t = uint16_t; |
| 541 | |
| 542 | struct OrtLiteCustomOp : public OrtCustomOp { |
| 543 | // CreateTuple |
| 544 | template <size_t ith_input, size_t ith_output, typename... Ts> |
| 545 | static typename std::enable_if<sizeof...(Ts) == 0, std::tuple<>>::type |
| 546 | CreateTuple(const OrtW::CustomOpApi*, OrtKernelContext*, std::vector<TensorPtr>&, size_t, size_t, const std::string&) { |
| 547 | return std::make_tuple(); |
| 548 | } |
| 549 | |
| 550 | template <size_t ith_input, size_t ith_output, typename T, typename... Ts> |
| 551 | static typename std::enable_if<std::is_same<T, OrtKernelContext*>::value, std::tuple<T, Ts...>>::type |
| 552 | CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) { |
| 553 | std::tuple<T> current = std::tuple<OrtKernelContext*>{context}; |
| 554 | auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep); |
| 555 | return std::tuple_cat(current, next); |
| 556 | } |
| 557 | |
#if ORT_API_VERSION >= 17
// KernelContext* parameter: a new OrtGraphKernelContext is owned by `tensors` and a pointer
// to it is passed. No input/output slot is consumed.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, KernelContext*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  auto kernel_context = std::make_unique<OrtGraphKernelContext>(api->GetOrtApi(), *context);
  // Take the typed pointer through the implicit derived-to-base conversion before ownership
  // moves into `tensors`; reinterpret_cast-ing the stored Arg* is not a valid downcast.
  T current_arg = kernel_context.get();
  tensors.push_back(std::move(kernel_context));
  std::tuple<T> current{current_arg};
  auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

#ifdef USE_CUDA

// CUDAKernelContext* parameter: same as above, with an OrtGraphCudaKernelContext.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, CUDAKernelContext*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  auto kernel_context = std::make_unique<OrtGraphCudaKernelContext>(api->GetOrtApi(), *context);
  T current_arg = kernel_context.get();
  tensors.push_back(std::move(kernel_context));
  std::tuple<T> current{current_arg};
  auto next = CreateTuple<ith_input, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

#endif

#endif
| 582 | |
#if ORT_API_VERSION >= 14
// Variadic parameters: const forms are inputs (advance ith_input), non-const forms are
// outputs (advance ith_output). The Variadic is owned by `tensors`. The typed pointer or
// reference is taken from the unique_ptr<Variadic> itself, because reinterpret_cast-ing the
// stored Arg* is not a valid downcast.
template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  auto variadic = std::make_unique<Variadic>(*api, *context, ith_input, true);
  T current_arg = variadic.get();
  tensors.push_back(std::move(variadic));
  std::tuple<T> current{current_arg};
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, const Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  auto variadic = std::make_unique<Variadic>(*api, *context, ith_input, true);
  T current_arg = *variadic;  // the heap object does not move when ownership moves
  tensors.push_back(std::move(variadic));
  std::tuple<T> current{current_arg};
  auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic*>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  auto variadic = std::make_unique<Variadic>(*api, *context, ith_output, false);
  T current_arg = variadic.get();
  tensors.push_back(std::move(variadic));
  std::tuple<T> current{current_arg};
  auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}

template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
static typename std::enable_if<std::is_same<T, Variadic&>::value, std::tuple<T, Ts...>>::type
CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
  auto variadic = std::make_unique<Variadic>(*api, *context, ith_output, false);
  T current_arg = *variadic;
  tensors.push_back(std::move(variadic));
  std::tuple<T> current{current_arg};
  auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
  return std::tuple_cat(current, next);
}
#endif
| 620 | |
// Each round below sets data_type_def to one element type and textually includes
// tensor_tuple.inc for that type. The .inc presumably defines the CreateTuple overloads for
// tensor-style parameters of data_type_def.
// NOTE(review): confirm against tensor_tuple.inc.
// BFloat16 / MFloat16 are only expanded when ORT_API_VERSION >= 16.
#undef data_type_def
#define data_type_def bool
#include "tensor_tuple.inc"

#if ORT_API_VERSION >= 16
#undef data_type_def
#define data_type_def BFloat16
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def MFloat16
#include "tensor_tuple.inc"
#endif

#undef data_type_def
#define data_type_def float
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def double
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def int8_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def int16_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def int32_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def int64_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def uint8_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def uint16_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def uint32_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def uint64_t
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def std::string
#include "tensor_tuple.inc"

#undef data_type_def
#define data_type_def std::string_view
#include "tensor_tuple.inc"

// Keep the helper macro from leaking past the expansions.
#undef data_type_def
| 684 | |
| 685 | // ParseArgs ... |
| 686 | template <typename... Ts> |
| 687 | static typename std::enable_if<0 == sizeof...(Ts)>::type |
| 688 | ParseArgs(std::vector<ONNXTensorElementDataType>&, std::vector<ONNXTensorElementDataType>&) { |
| 689 | } |
| 690 | |
| 691 | template <typename T, typename... Ts> |
| 692 | static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same<T, OrtKernelContext*>::value>::type |
| 693 | ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) { |
| 694 | ParseArgs<Ts...>(input_types, output_types); |
| 695 | } |
| 696 | |
#if ORT_API_VERSION >= 17
// KernelContext* (built from the OrtKernelContext by CreateTuple) is not a graph input or
// output, so it records nothing.
template <typename T, typename... Ts>
static std::enable_if_t<std::is_same_v<T, KernelContext*>>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}

#ifdef USE_CUDA
// Same for the CUDA flavour of the kernel context.
template <typename T, typename... Ts>
static std::enable_if_t<std::is_same_v<T, CUDAKernelContext*>>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
#endif
| 712 | |
#if ORT_API_VERSION >= 14
// A Variadic parameter records a single ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED slot.
// A variadic input (const Variadic& / const Variadic*) is rejected unless it is the first
// input parameter; a variadic output (Variadic& / Variadic*) is rejected unless it is the
// first output parameter.
template <typename T, typename... Ts>
static std::enable_if_t<std::is_same_v<T, const Variadic&> || std::is_same_v<T, const Variadic*>>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  if (!input_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic input, only one input is allowed", ORT_RUNTIME_EXCEPTION);
  }
  input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}

template <typename T, typename... Ts>
static std::enable_if_t<std::is_same_v<T, Variadic&> || std::is_same_v<T, Variadic*>>
ParseArgs(std::vector<ONNXTensorElementDataType>& input_types, std::vector<ONNXTensorElementDataType>& output_types) {
  if (!output_types.empty()) {
    ORTX_CXX_API_THROW("for op has variadic output, only one output is allowed", ORT_RUNTIME_EXCEPTION);
  }
  output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED);
  ParseArgs<Ts...>(input_types, output_types);
}
#endif
| 754 | |
// PARSE_INPUT_BASE(arg_type, elem_type) defines one ParseArgs overload. It matches a
// parameter declared exactly as `arg_type` or as `std::optional<arg_type>`, appends
// `elem_type` to the input type list, and then parses the remaining parameters.
#define PARSE_INPUT_BASE(arg_type, elem_type)                                           \
  template <typename T, typename... Ts>                                                 \
  static std::enable_if_t<std::is_same_v<T, arg_type> ||                                \
                          std::is_same_v<T, std::optional<arg_type>>>                   \
  ParseArgs(std::vector<ONNXTensorElementDataType>& inputs,                             \
            std::vector<ONNXTensorElementDataType>& outputs) {                          \
    inputs.emplace_back(elem_type);                                                     \
    ParseArgs<Ts...>(inputs, outputs);                                                  \
  }

// PARSE_INPUT(data_type, elem_type) covers these input forms for `data_type`:
// - a plain scalar;
// - a Custom::Tensor, by const reference or const pointer;
// - a Custom::Span, by const reference or const pointer.
#define PARSE_INPUT(data_type, elem_type)                          \
  PARSE_INPUT_BASE(data_type, elem_type)                           \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>&, elem_type)    \
  PARSE_INPUT_BASE(const Custom::Tensor<data_type>*, elem_type)    \
  PARSE_INPUT_BASE(const Custom::Span<data_type>&, elem_type)      \
  PARSE_INPUT_BASE(const Custom::Span<data_type>*, elem_type)

// PARSE_OUTPUT(data_type, elem_type) defines one ParseArgs overload. It matches an output
// declared as `Custom::Tensor<data_type>*`, `Custom::Tensor<data_type>&` or
// `std::optional<Custom::Tensor<data_type>*>`, appends `elem_type` to the output type list,
// and then parses the remaining parameters.
#define PARSE_OUTPUT(data_type, elem_type)                                              \
  template <typename T, typename... Ts>                                                 \
  static std::enable_if_t<std::is_same_v<T, Custom::Tensor<data_type>*> ||              \
                          std::is_same_v<T, Custom::Tensor<data_type>&> ||              \
                          std::is_same_v<T, std::optional<Custom::Tensor<data_type>*>>> \
  ParseArgs(std::vector<ONNXTensorElementDataType>& inputs,                             \
            std::vector<ONNXTensorElementDataType>& outputs) {                          \
    outputs.emplace_back(elem_type);                                                    \
    ParseArgs<Ts...>(inputs, outputs);                                                  \
  }

// PARSE_ARGS(data_type, elem_type) registers both the input and the output forms of `data_type`.
#define PARSE_ARGS(data_type, elem_type) \
  PARSE_INPUT(data_type, elem_type)      \
  PARSE_OUTPUT(data_type, elem_type)
| 799 | |
| 800 | PARSE_ARGS(bool, ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL) |
| 801 | #if ORT_API_VERSION >= 16 |
| 802 | PARSE_ARGS(MFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16) |
| 803 | PARSE_ARGS(BFloat16, ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16) |
| 804 | #endif |
| 805 | PARSE_ARGS(float, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT) |
| 806 | PARSE_ARGS(double, ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) |
| 807 | PARSE_ARGS(int8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8) |
| 808 | PARSE_ARGS(int16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16) |
| 809 | PARSE_ARGS(int32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32) |
| 810 | PARSE_ARGS(int64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) |
| 811 | PARSE_ARGS(uint8_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8) |
| 812 | PARSE_ARGS(uint16_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16) |
| 813 | PARSE_ARGS(uint32_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32) |
| 814 | PARSE_ARGS(uint64_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64) |
| 815 | PARSE_ARGS(std::string, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) |
| 816 | PARSE_ARGS(std::string_view, ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) // todo - remove string_view output |
| 817 | |
| 818 | OrtLiteCustomOp(const char* op_name, |
| 819 | const char* execution_provider) : op_name_(op_name), |
| 820 | execution_provider_(execution_provider) { |
| 821 | // Zero out OrtCustomOp so that any added func pointers are nullptr for forwards compatibility |
| 822 | memset(&this->version, 0, sizeof(OrtCustomOp)); |
| 823 | |
| 824 | int act_ver = GetActiveOrtAPIVersion(); |
| 825 | OrtCustomOp::version = act_ver < ORT_API_VERSION ? act_ver : ORT_API_VERSION; |
| 826 | |
| 827 | OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); }; |
| 828 | OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* op) { return ((OrtLiteCustomOp*)op)->execution_provider_.c_str(); }; |
| 829 | |
| 830 | OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* op) { |
| 831 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 832 | return self->input_types_.size(); |
| 833 | }; |
| 834 | |
| 835 | OrtCustomOp::GetInputType = [](const OrtCustomOp* op, size_t indice) { |
| 836 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 837 | return self->input_types_[indice]; |
| 838 | }; |
| 839 | |
| 840 | OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* op) { |
| 841 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 842 | return self->output_types_.size(); |
| 843 | }; |
| 844 | |
| 845 | OrtCustomOp::GetOutputType = [](const OrtCustomOp* op, size_t indice) { |
| 846 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 847 | return self->output_types_[indice]; |
| 848 | }; |
| 849 | |
| 850 | #if ORT_API_VERSION >= 14 |
| 851 | OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t) { |
| 852 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 853 | return (self->input_types_.empty() || self->input_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC; |
| 854 | }; |
| 855 | |
| 856 | OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) { |
| 857 | auto self = reinterpret_cast<const OrtLiteCustomOp*>(op); |
| 858 | return (self->output_types_.empty() || self->output_types_[0] != ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) ? INPUT_OUTPUT_OPTIONAL : INPUT_OUTPUT_VARIADIC; |
| 859 | }; |
| 860 | |
| 861 | OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { |
| 862 | return 1; |
| 863 | }; |
| 864 | |
| 865 | OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { |
| 866 | return 0; |
| 867 | }; |
| 868 | |
| 869 | OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { |
| 870 | return 1; |
| 871 | }; |
| 872 | |
| 873 | OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { |
| 874 | return 0; |
| 875 | }; |
| 876 | |
| 877 | OrtCustomOp::GetInputMemoryType = [](const OrtCustomOp*, size_t) { |
| 878 | return OrtMemTypeDefault; |
| 879 | }; |
| 880 | #else |
| 881 | OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { |
| 882 | return INPUT_OUTPUT_OPTIONAL; |
| 883 | }; |
| 884 | |
| 885 | OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t) { |
| 886 | return INPUT_OUTPUT_OPTIONAL; |
| 887 | }; |
| 888 | #endif |
| 889 | |
| 890 | #if ORT_API_VERSION >= 18 |
| 891 | OrtCustomOp::GetMayInplace = [](int**, int**) -> size_t { |
| 892 | return 0; |
| 893 | }; |
| 894 | OrtCustomOp::ReleaseMayInplace = [](int*, int*) -> void {}; |
| 895 | #endif |
| 896 | } |
| 897 | |
| 898 | const std::string op_name_; |
| 899 | const std::string execution_provider_; |
| 900 | |
| 901 | std::vector<ONNXTensorElementDataType> input_types_; |
| 902 | std::vector<ONNXTensorElementDataType> output_types_; |
| 903 | }; |
| 904 | |
| 905 | template <typename... Args> |
| 906 | struct OrtLiteCustomFunc : public OrtLiteCustomOp { |
| 907 | using ComputeFn = void (*)(Args...); |
| 908 | using MyType = OrtLiteCustomFunc<Args...>; |
| 909 | |
| 910 | struct Kernel { |
| 911 | ComputeFn compute_fn_{}; |
| 912 | std::string ep_{}; |
| 913 | std::unique_ptr<OrtW::CustomOpApi> api_; |
| 914 | }; |
| 915 | |
| 916 | OrtLiteCustomFunc(const char* op_name, |
| 917 | const char* execution_provider, |
| 918 | ComputeFn compute_fn) : OrtLiteCustomOp(op_name, execution_provider), |
| 919 | compute_fn_(compute_fn) { |
| 920 | ParseArgs<Args...>(input_types_, output_types_); |
| 921 | |
| 922 | OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { |
| 923 | auto kernel = reinterpret_cast<Kernel*>(op_kernel); |
| 924 | std::vector<TensorPtr> tensors; |
| 925 | auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(), |
| 926 | context, |
| 927 | tensors, |
| 928 | kernel->api_->KernelContext_GetInputCount(context), |
| 929 | kernel->api_->KernelContext_GetOutputCount(context), |
| 930 | kernel->ep_); |
| 931 | std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t); |
| 932 | }; |
| 933 | |
| 934 | OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { |
| 935 | auto kernel = std::make_unique<Kernel>(); |
| 936 | auto self = static_cast<const OrtLiteCustomFunc*>(this_); |
| 937 | kernel->compute_fn_ = self->compute_fn_; |
| 938 | kernel->ep_ = self->execution_provider_; |
| 939 | kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api); |
| 940 | return reinterpret_cast<void*>(kernel.release()); |
| 941 | }; |
| 942 | |
| 943 | OrtCustomOp::KernelDestroy = [](void* op_kernel) { |
| 944 | delete reinterpret_cast<Kernel*>(op_kernel); |
| 945 | }; |
| 946 | } |
| 947 | |
| 948 | ComputeFn compute_fn_; |
| 949 | }; |
| 950 | |
| 951 | class OrtAttributeReader { |
| 952 | public: |
| 953 | OrtAttributeReader(const OrtApi& ort_api, const OrtKernelInfo& info) : base_kernel_(ort_api, info) { |
| 954 | } |
| 955 | |
| 956 | template <class T> |
| 957 | T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept { |
| 958 | return base_kernel_.TryToGetAttributeWithDefault(name, default_value); |
| 959 | } |
| 960 | |
| 961 | private: |
| 962 | BaseKernel base_kernel_; |
| 963 | }; |
| 964 | |
| 965 | template <typename CustomOp> |
| 966 | struct OrtLiteCustomStruct : public OrtLiteCustomOp { |
| 967 | template <typename... Args> |
| 968 | using CustomComputeFn = void (CustomOp::*)(Args...) const; |
| 969 | using MyType = OrtLiteCustomStruct<CustomOp>; |
| 970 | |
| 971 | struct Kernel { |
| 972 | std::unique_ptr<CustomOp> custom_op_; |
| 973 | std::string ep_{}; |
| 974 | std::unique_ptr<OrtW::CustomOpApi> api_; |
| 975 | }; |
| 976 | |
| 977 | OrtLiteCustomStruct(const char* op_name, |
| 978 | const char* execution_provider) : OrtLiteCustomOp(op_name, |
| 979 | execution_provider) { |
| 980 | init(&CustomOp::Compute); |
| 981 | } |
| 982 | |
| 983 | template <typename... Args> |
| 984 | void init(CustomComputeFn<Args...>) { |
| 985 | ParseArgs<Args...>(input_types_, output_types_); |
| 986 | |
| 987 | OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { |
| 988 | auto kernel = reinterpret_cast<Kernel*>(op_kernel); |
| 989 | std::vector<TensorPtr> tensors; |
| 990 | auto t = CreateTuple<0, 0, Args...>(kernel->api_.get(), |
| 991 | context, |
| 992 | tensors, |
| 993 | kernel->api_->KernelContext_GetInputCount(context), |
| 994 | kernel->api_->KernelContext_GetOutputCount(context), |
| 995 | kernel->ep_); |
| 996 | std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); |
| 997 | }; |
| 998 | |
| 999 | OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { |
| 1000 | auto kernel = std::make_unique<Kernel>(); |
| 1001 | |
| 1002 | if constexpr (std::is_constructible<CustomOp, const OrtApi&, const OrtKernelInfo&>::value) { |
| 1003 | kernel->custom_op_ = std::make_unique<CustomOp>(*ort_api, *info); |
| 1004 | } else { |
| 1005 | kernel->custom_op_ = std::make_unique<CustomOp>(OrtAttributeReader(*ort_api, *info)); |
| 1006 | } |
| 1007 | auto self = static_cast<const MyType*>(this_); |
| 1008 | kernel->ep_ = self->execution_provider_; |
| 1009 | kernel->api_ = std::make_unique<OrtW::CustomOpApi>(*ort_api); |
| 1010 | return reinterpret_cast<void*>(kernel.release()); |
| 1011 | }; |
| 1012 | |
| 1013 | OrtCustomOp::KernelDestroy = [](void* op_kernel) { |
| 1014 | delete reinterpret_cast<Kernel*>(op_kernel); |
| 1015 | }; |
| 1016 | } |
| 1017 | }; |
| 1018 | |
| 1019 | template <typename... Args> |
| 1020 | OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, |
| 1021 | const char* execution_provider, |
| 1022 | void (*custom_compute_fn)(Args...)) { |
| 1023 | using LiteOp = OrtLiteCustomFunc<Args...>; |
| 1024 | return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn).release(); |
| 1025 | } |
| 1026 | |
| 1027 | template <typename CustomOp> |
| 1028 | OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, |
| 1029 | const char* execution_provider) { |
| 1030 | using LiteOp = OrtLiteCustomStruct<CustomOp>; |
| 1031 | return std::make_unique<LiteOp>(op_name, execution_provider).release(); |
| 1032 | } |
| 1033 | |
| 1034 | } // namespace Custom |
| 1035 | } // namespace Ort |
| 1036 | |