microsoft/onnxruntime-extensions
Publicmirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable
includes/onnxruntime/onnxruntime_cxx_api.h
632lines · modecode
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | // Summary: The Ort C++ API is a header only wrapper around the Ort C API. |
| 5 | // |
| 6 | // The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors |
| 7 | // and automatically releasing resources in the destructors. |
| 8 | // |
| 9 | // Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers. |
| 10 | // To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};). |
| 11 | // |
| 12 | // Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone' |
| 13 | // methods for this purpose. |
| 14 | |
| 15 | #pragma once |
| 16 | #include "onnxruntime_c_api.h" |
| 17 | #include <cstddef> |
| 18 | #include <array> |
| 19 | #include <memory> |
| 20 | #include <stdexcept> |
| 21 | #include <string> |
| 22 | #include <vector> |
| 23 | #include <utility> |
| 24 | #include <type_traits> |
| 25 | |
| 26 | #ifdef ORT_NO_EXCEPTIONS |
| 27 | #include <iostream> |
| 28 | #endif |
| 29 | |
| 30 | namespace Ort { |
| 31 | |
| 32 | // All C++ methods that can fail will throw an exception of this type |
| 33 | struct Exception : std::exception { |
| 34 | Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {} |
| 35 | |
| 36 | OrtErrorCode GetOrtErrorCode() const { return code_; } |
| 37 | const char* what() const noexcept override { return message_.c_str(); } |
| 38 | |
| 39 | private: |
| 40 | std::string message_; |
| 41 | OrtErrorCode code_; |
| 42 | }; |
| 43 | |
// ORT_CXX_API_THROW(string, code): error-reporting hook used by the C++ wrappers.
//  - ORT_NO_EXCEPTIONS defined: prints the message of an Ort::Exception built from
//    (string, code) to std::cerr and then calls abort().
//  - otherwise: throws Ort::Exception(string, code).
#ifdef ORT_NO_EXCEPTIONS
#define ORT_CXX_API_THROW(string, code)                            \
  do {                                                             \
    std::cerr << Ort::Exception(string, code).what() << std::endl; \
    abort();                                                       \
  } while (false)
#else
#define ORT_CXX_API_THROW(string, code) throw Ort::Exception(string, code)
#endif
| 56 | |
| 57 | // This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make |
| 58 | // it transparent to the users of the API. |
| 59 | template <typename T> |
| 60 | struct Global { |
| 61 | static const OrtApi* api_; |
| 62 | }; |
| 63 | |
| 64 | // If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it. |
| 65 | |
| 66 | template <typename T> |
| 67 | #ifdef ORT_API_MANUAL_INIT |
| 68 | const OrtApi* Global<T>::api_{}; |
| 69 | inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); } |
| 70 | #else |
| 71 | const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); |
| 72 | #endif |
| 73 | |
| 74 | // This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions |
| 75 | inline const OrtApi& GetApi() { return *Global<void>::api_; } |
| 76 | |
// C++ wrapper for the GetAvailableProviders() C API.
// Returns a vector of strings representing the available execution providers.
std::vector<std::string> GetAvailableProviders();
| 80 | |
| 81 | // This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type |
| 82 | // This can't be done in the C API since C doesn't have function overloading. |
| 83 | #define ORT_DEFINE_RELEASE(NAME) \ |
| 84 | inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); } |
| 85 | |
| 86 | ORT_DEFINE_RELEASE(Allocator); |
| 87 | ORT_DEFINE_RELEASE(MemoryInfo); |
| 88 | ORT_DEFINE_RELEASE(CustomOpDomain); |
| 89 | ORT_DEFINE_RELEASE(Env); |
| 90 | ORT_DEFINE_RELEASE(RunOptions); |
| 91 | ORT_DEFINE_RELEASE(Session); |
| 92 | ORT_DEFINE_RELEASE(SessionOptions); |
| 93 | ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo); |
| 94 | ORT_DEFINE_RELEASE(SequenceTypeInfo); |
| 95 | ORT_DEFINE_RELEASE(MapTypeInfo); |
| 96 | ORT_DEFINE_RELEASE(TypeInfo); |
| 97 | ORT_DEFINE_RELEASE(Value); |
| 98 | ORT_DEFINE_RELEASE(ModelMetadata); |
| 99 | ORT_DEFINE_RELEASE(ThreadingOptions); |
| 100 | ORT_DEFINE_RELEASE(IoBinding); |
| 101 | ORT_DEFINE_RELEASE(ArenaCfg); |
| 102 | |
/*! \class Ort::Float16_t
 * \brief A structure that represents float16 data.
 * \details It is necessary for type dispatching to make use of the C++ API.
 * The type is implicitly convertible to/from uint16_t.
 * The size of the structure matches uint16_t (enforced by the static_assert that follows it),
 * so uint16_t buffers can be cast to/from Ort::Float16_t buffers to feed and retrieve data.
 *
 * Generally, you can feed any of your types as float16/bfloat16 data to create a tensor
 * on top of it, provided it forms a contiguous buffer of 16-bit elements with no padding.
 * You can also feed an array of uint16_t elements directly. For example,
 *
 * \code{.unparsed}
 * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
 * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
 * std::vector<int64_t> dims = {values_length};  // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
 * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
 *                                                dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
 * \endcode
 *
 * Here is another, slightly more elaborate example. Assume that you use your own float16 type and you want to use
 * the templated version of the API above so that the element type is deduced from your type. You will need to supply
 * an extra template specialization.
 *
 * \code{.unparsed}
 * namespace yours { struct half {}; }  // assume this is your type, define this:
 * namespace Ort {
 * template<>
 * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
 * }  // namespace Ort
 *
 * std::vector<yours::half> values;
 * std::vector<int64_t> dims = {static_cast<int64_t>(values.size())};  // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Here we are passing element count -> values.size()
 * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
 *
 * \endcode
 */
struct Float16_t {
  uint16_t value;  // raw 16-bit pattern; no numeric conversion is performed by this type

  constexpr Float16_t() noexcept : value{0} {}
  // Intentionally implicit: allows uint16_t -> Float16_t conversion.
  constexpr Float16_t(uint16_t v) noexcept : value{v} {}

  // Intentionally implicit: allows Float16_t -> uint16_t conversion.
  constexpr operator uint16_t() const noexcept { return value; }

  // Comparisons are on the raw bit pattern.
  constexpr bool operator==(const Float16_t& rhs) const noexcept { return rhs.value == value; }
  constexpr bool operator!=(const Float16_t& rhs) const noexcept { return !(rhs.value == value); }
};

static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
| 153 | |
/*! \class Ort::BFloat16_t
 * \brief A structure that represents bfloat16 data.
 * \details It is necessary for type dispatching to make use of the C++ API.
 * The type is implicitly convertible to/from uint16_t.
 * The size of the structure matches uint16_t (enforced by the static_assert that follows it),
 * so uint16_t buffers can be cast to/from Ort::BFloat16_t buffers to feed and retrieve data.
 *
 * See also the code examples for Float16_t.
 */
struct BFloat16_t {
  uint16_t value;  // raw 16-bit pattern; no numeric conversion is performed by this type

  constexpr BFloat16_t() noexcept : value{0} {}
  // Intentionally implicit: allows uint16_t -> BFloat16_t conversion.
  constexpr BFloat16_t(uint16_t v) noexcept : value{v} {}

  // Intentionally implicit: allows BFloat16_t -> uint16_t conversion.
  constexpr operator uint16_t() const noexcept { return value; }

  // Comparisons are on the raw bit pattern.
  constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return rhs.value == value; }
  constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return !(rhs.value == value); }
};

static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
| 173 | |
| 174 | // This is used internally by the C++ API. This is the common base class used by the wrapper objects. |
| 175 | template <typename T> |
| 176 | struct Base { |
| 177 | using contained_type = T; |
| 178 | |
| 179 | Base() = default; |
| 180 | Base(T* p) : p_{p} { |
| 181 | if (!p) |
| 182 | ORT_CXX_API_THROW("Allocation failure", ORT_FAIL); |
| 183 | } |
| 184 | ~Base() { OrtRelease(p_); } |
| 185 | |
| 186 | operator T*() { return p_; } |
| 187 | operator const T*() const { return p_; } |
| 188 | |
| 189 | T* release() { |
| 190 | T* p = p_; |
| 191 | p_ = nullptr; |
| 192 | return p; |
| 193 | } |
| 194 | |
| 195 | protected: |
| 196 | Base(const Base&) = delete; |
| 197 | Base& operator=(const Base&) = delete; |
| 198 | Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } |
| 199 | void operator=(Base&& v) noexcept { |
| 200 | OrtRelease(p_); |
| 201 | p_ = v.p_; |
| 202 | v.p_ = nullptr; |
| 203 | } |
| 204 | |
| 205 | T* p_{}; |
| 206 | |
| 207 | template <typename> |
| 208 | friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error |
| 209 | }; |
| 210 | |
| 211 | template <typename T> |
| 212 | struct Base<const T> { |
| 213 | using contained_type = const T; |
| 214 | |
| 215 | Base() = default; |
| 216 | Base(const T* p) : p_{p} { |
| 217 | if (!p) |
| 218 | ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT); |
| 219 | } |
| 220 | ~Base() = default; |
| 221 | |
| 222 | operator const T*() const { return p_; } |
| 223 | |
| 224 | protected: |
| 225 | Base(const Base&) = delete; |
| 226 | Base& operator=(const Base&) = delete; |
| 227 | Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; } |
| 228 | void operator=(Base&& v) noexcept { |
| 229 | p_ = v.p_; |
| 230 | v.p_ = nullptr; |
| 231 | } |
| 232 | |
| 233 | const T* p_{}; |
| 234 | }; |
| 235 | |
// Wraps a pointer in a T without owning it: the destructor calls release() on the wrapper,
// so the pointer has already been given up by the time T's own destructor runs.
// T must expose a p_ member and a release() member function.
template <typename T>
struct Unowned : T {
  // Wraps p by forwarding it to T's pointer constructor.
  Unowned(decltype(T::p_) p)
      : T{p} {
  }

  // The new object refers to the same pointer as v; v is left unchanged.
  Unowned(Unowned&& v)
      : T{v.p_} {
  }

  ~Unowned() {
    this->release();
  }
};
| 242 | |
| 243 | struct AllocatorWithDefaultOptions; |
| 244 | struct MemoryInfo; |
| 245 | struct Env; |
| 246 | struct TypeInfo; |
| 247 | struct Value; |
| 248 | struct ModelMetadata; |
| 249 | |
| 250 | struct Env : Base<OrtEnv> { |
| 251 | Env(std::nullptr_t) {} |
| 252 | Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); |
| 253 | Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); |
| 254 | Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param); |
| 255 | Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param, |
| 256 | OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = ""); |
| 257 | explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {} |
| 258 | |
| 259 | Env& EnableTelemetryEvents(); |
| 260 | Env& DisableTelemetryEvents(); |
| 261 | |
| 262 | Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg); |
| 263 | |
| 264 | static const OrtApi* s_api; |
| 265 | }; |
| 266 | |
| 267 | struct CustomOpDomain : Base<OrtCustomOpDomain> { |
| 268 | explicit CustomOpDomain(std::nullptr_t) {} |
| 269 | explicit CustomOpDomain(const char* domain); |
| 270 | |
| 271 | void Add(OrtCustomOp* op); |
| 272 | }; |
| 273 | |
| 274 | struct RunOptions : Base<OrtRunOptions> { |
| 275 | RunOptions(std::nullptr_t) {} |
| 276 | RunOptions(); |
| 277 | |
| 278 | RunOptions& SetRunLogVerbosityLevel(int); |
| 279 | int GetRunLogVerbosityLevel() const; |
| 280 | |
| 281 | RunOptions& SetRunLogSeverityLevel(int); |
| 282 | int GetRunLogSeverityLevel() const; |
| 283 | |
| 284 | RunOptions& SetRunTag(const char* run_tag); |
| 285 | const char* GetRunTag() const; |
| 286 | |
| 287 | // terminate ALL currently executing Session::Run calls that were made using this RunOptions instance |
| 288 | RunOptions& SetTerminate(); |
| 289 | // unset the terminate flag so this RunOptions instance can be used in a new Session::Run call |
| 290 | RunOptions& UnsetTerminate(); |
| 291 | }; |
| 292 | |
| 293 | struct SessionOptions : Base<OrtSessionOptions> { |
| 294 | explicit SessionOptions(std::nullptr_t) {} |
| 295 | SessionOptions(); |
| 296 | explicit SessionOptions(OrtSessionOptions* p) : Base<OrtSessionOptions>{p} {} |
| 297 | |
| 298 | SessionOptions Clone() const; |
| 299 | |
| 300 | SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads); |
| 301 | SessionOptions& SetInterOpNumThreads(int inter_op_num_threads); |
| 302 | SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level); |
| 303 | |
| 304 | SessionOptions& EnableCpuMemArena(); |
| 305 | SessionOptions& DisableCpuMemArena(); |
| 306 | |
| 307 | SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file); |
| 308 | |
| 309 | SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix); |
| 310 | SessionOptions& DisableProfiling(); |
| 311 | |
| 312 | SessionOptions& EnableMemPattern(); |
| 313 | SessionOptions& DisableMemPattern(); |
| 314 | |
| 315 | SessionOptions& SetExecutionMode(ExecutionMode execution_mode); |
| 316 | |
| 317 | SessionOptions& SetLogId(const char* logid); |
| 318 | SessionOptions& SetLogSeverityLevel(int level); |
| 319 | |
| 320 | SessionOptions& Add(OrtCustomOpDomain* custom_op_domain); |
| 321 | |
| 322 | SessionOptions& DisablePerSessionThreads(); |
| 323 | |
| 324 | SessionOptions& AddConfigEntry(const char* config_key, const char* config_value); |
| 325 | SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val); |
| 326 | |
| 327 | SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options); |
| 328 | SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); |
| 329 | }; |
| 330 | |
| 331 | struct ModelMetadata : Base<OrtModelMetadata> { |
| 332 | explicit ModelMetadata(std::nullptr_t) {} |
| 333 | explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {} |
| 334 | |
| 335 | char* GetProducerName(OrtAllocator* allocator) const; |
| 336 | char* GetGraphName(OrtAllocator* allocator) const; |
| 337 | char* GetDomain(OrtAllocator* allocator) const; |
| 338 | char* GetDescription(OrtAllocator* allocator) const; |
| 339 | char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const; |
| 340 | char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const; |
| 341 | int64_t GetVersion() const; |
| 342 | }; |
| 343 | |
| 344 | struct Session : Base<OrtSession> { |
| 345 | explicit Session(std::nullptr_t) {} |
| 346 | Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options); |
| 347 | Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options); |
| 348 | |
| 349 | // Run that will allocate the output values |
| 350 | std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, |
| 351 | const char* const* output_names, size_t output_count); |
| 352 | // Run for when there is a list of preallocated outputs |
| 353 | void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count, |
| 354 | const char* const* output_names, Value* output_values, size_t output_count); |
| 355 | |
| 356 | void Run(const RunOptions& run_options, const struct IoBinding&); |
| 357 | |
| 358 | size_t GetInputCount() const; |
| 359 | size_t GetOutputCount() const; |
| 360 | size_t GetOverridableInitializerCount() const; |
| 361 | |
| 362 | char* GetInputName(size_t index, OrtAllocator* allocator) const; |
| 363 | char* GetOutputName(size_t index, OrtAllocator* allocator) const; |
| 364 | char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const; |
| 365 | char* EndProfiling(OrtAllocator* allocator) const; |
| 366 | uint64_t GetProfilingStartTimeNs() const; |
| 367 | ModelMetadata GetModelMetadata() const; |
| 368 | |
| 369 | TypeInfo GetInputTypeInfo(size_t index) const; |
| 370 | TypeInfo GetOutputTypeInfo(size_t index) const; |
| 371 | TypeInfo GetOverridableInitializerTypeInfo(size_t index) const; |
| 372 | }; |
| 373 | |
| 374 | struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> { |
| 375 | explicit TensorTypeAndShapeInfo(std::nullptr_t) {} |
| 376 | explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base<OrtTensorTypeAndShapeInfo>{p} {} |
| 377 | |
| 378 | ONNXTensorElementDataType GetElementType() const; |
| 379 | size_t GetElementCount() const; |
| 380 | |
| 381 | size_t GetDimensionsCount() const; |
| 382 | void GetDimensions(int64_t* values, size_t values_count) const; |
| 383 | void GetSymbolicDimensions(const char** values, size_t values_count) const; |
| 384 | |
| 385 | std::vector<int64_t> GetShape() const; |
| 386 | }; |
| 387 | |
| 388 | struct SequenceTypeInfo : Base<OrtSequenceTypeInfo> { |
| 389 | explicit SequenceTypeInfo(std::nullptr_t) {} |
| 390 | explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base<OrtSequenceTypeInfo>{p} {} |
| 391 | |
| 392 | TypeInfo GetSequenceElementType() const; |
| 393 | }; |
| 394 | |
| 395 | struct MapTypeInfo : Base<OrtMapTypeInfo> { |
| 396 | explicit MapTypeInfo(std::nullptr_t) {} |
| 397 | explicit MapTypeInfo(OrtMapTypeInfo* p) : Base<OrtMapTypeInfo>{p} {} |
| 398 | |
| 399 | ONNXTensorElementDataType GetMapKeyType() const; |
| 400 | TypeInfo GetMapValueType() const; |
| 401 | }; |
| 402 | |
| 403 | struct TypeInfo : Base<OrtTypeInfo> { |
| 404 | explicit TypeInfo(std::nullptr_t) {} |
| 405 | explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {} |
| 406 | |
| 407 | Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const; |
| 408 | Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const; |
| 409 | Unowned<MapTypeInfo> GetMapTypeInfo() const; |
| 410 | |
| 411 | ONNXType GetONNXType() const; |
| 412 | }; |
| 413 | |
| 414 | struct Value : Base<OrtValue> { |
| 415 | template <typename T> |
| 416 | static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len); |
| 417 | static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len, |
| 418 | ONNXTensorElementDataType type); |
| 419 | template <typename T> |
| 420 | static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len); |
| 421 | static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type); |
| 422 | |
| 423 | static Value CreateMap(Value& keys, Value& values); |
| 424 | static Value CreateSequence(std::vector<Value>& values); |
| 425 | |
| 426 | template <typename T> |
| 427 | static Value CreateOpaque(const char* domain, const char* type_name, const T&); |
| 428 | |
| 429 | template <typename T> |
| 430 | void GetOpaqueData(const char* domain, const char* type_name, T&) const; |
| 431 | |
| 432 | explicit Value(std::nullptr_t) {} |
| 433 | explicit Value(OrtValue* p) : Base<OrtValue>{p} {} |
| 434 | Value(Value&&) = default; |
| 435 | Value& operator=(Value&&) = default; |
| 436 | |
| 437 | bool IsTensor() const; |
| 438 | size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements |
| 439 | Value GetValue(int index, OrtAllocator* allocator) const; |
| 440 | |
| 441 | size_t GetStringTensorDataLength() const; |
| 442 | void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const; |
| 443 | |
| 444 | template <typename T> |
| 445 | T* GetTensorMutableData(); |
| 446 | |
| 447 | template <typename T> |
| 448 | const T* GetTensorData() const; |
| 449 | |
| 450 | template <typename T> |
| 451 | T& At(const std::vector<int64_t>& location); |
| 452 | |
| 453 | TypeInfo GetTypeInfo() const; |
| 454 | TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const; |
| 455 | |
| 456 | size_t GetStringTensorElementLength(size_t element_index) const; |
| 457 | void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const; |
| 458 | |
| 459 | void FillStringTensor(const char* const* s, size_t s_len); |
| 460 | void FillStringTensorElement(const char* s, size_t index); |
| 461 | }; |
| 462 | |
| 463 | // Represents native memory allocation |
| 464 | struct MemoryAllocation { |
| 465 | MemoryAllocation(OrtAllocator* allocator, void* p, size_t size); |
| 466 | ~MemoryAllocation(); |
| 467 | MemoryAllocation(const MemoryAllocation&) = delete; |
| 468 | MemoryAllocation& operator=(const MemoryAllocation&) = delete; |
| 469 | MemoryAllocation(MemoryAllocation&&); |
| 470 | MemoryAllocation& operator=(MemoryAllocation&&); |
| 471 | |
| 472 | void* get() { return p_; } |
| 473 | size_t size() const { return size_; } |
| 474 | |
| 475 | private: |
| 476 | OrtAllocator* allocator_; |
| 477 | void* p_; |
| 478 | size_t size_; |
| 479 | }; |
| 480 | |
| 481 | struct AllocatorWithDefaultOptions { |
| 482 | AllocatorWithDefaultOptions(); |
| 483 | |
| 484 | operator OrtAllocator*() { return p_; } |
| 485 | operator const OrtAllocator*() const { return p_; } |
| 486 | |
| 487 | void* Alloc(size_t size); |
| 488 | // The return value will own the allocation |
| 489 | MemoryAllocation GetAllocation(size_t size); |
| 490 | void Free(void* p); |
| 491 | |
| 492 | const OrtMemoryInfo* GetInfo() const; |
| 493 | |
| 494 | private: |
| 495 | OrtAllocator* p_{}; |
| 496 | }; |
| 497 | |
| 498 | template <typename B> |
| 499 | struct BaseMemoryInfo : B { |
| 500 | BaseMemoryInfo() = default; |
| 501 | explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {} |
| 502 | ~BaseMemoryInfo() = default; |
| 503 | BaseMemoryInfo(BaseMemoryInfo&&) = default; |
| 504 | BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default; |
| 505 | |
| 506 | std::string GetAllocatorName() const; |
| 507 | OrtAllocatorType GetAllocatorType() const; |
| 508 | int GetDeviceId() const; |
| 509 | OrtMemType GetMemoryType() const; |
| 510 | template <typename U> |
| 511 | bool operator==(const BaseMemoryInfo<U>& o) const; |
| 512 | }; |
| 513 | |
| 514 | struct UnownedMemoryInfo : BaseMemoryInfo<Base<const OrtMemoryInfo> > { |
| 515 | explicit UnownedMemoryInfo(std::nullptr_t) {} |
| 516 | explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {} |
| 517 | }; |
| 518 | |
| 519 | struct MemoryInfo : BaseMemoryInfo<Base<OrtMemoryInfo> > { |
| 520 | static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1); |
| 521 | |
| 522 | explicit MemoryInfo(std::nullptr_t) {} |
| 523 | explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {} |
| 524 | MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type); |
| 525 | }; |
| 526 | |
| 527 | struct Allocator : public Base<OrtAllocator> { |
| 528 | Allocator(const Session& session, const MemoryInfo&); |
| 529 | |
| 530 | void* Alloc(size_t size) const; |
| 531 | // The return value will own the allocation |
| 532 | MemoryAllocation GetAllocation(size_t size); |
| 533 | void Free(void* p) const; |
| 534 | UnownedMemoryInfo GetInfo() const; |
| 535 | }; |
| 536 | |
| 537 | struct IoBinding : public Base<OrtIoBinding> { |
| 538 | private: |
| 539 | std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const; |
| 540 | std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const; |
| 541 | |
| 542 | public: |
| 543 | explicit IoBinding(Session& session); |
| 544 | void BindInput(const char* name, const Value&); |
| 545 | void BindOutput(const char* name, const Value&); |
| 546 | void BindOutput(const char* name, const MemoryInfo&); |
| 547 | std::vector<std::string> GetOutputNames() const; |
| 548 | std::vector<std::string> GetOutputNames(Allocator&) const; |
| 549 | std::vector<Value> GetOutputValues() const; |
| 550 | std::vector<Value> GetOutputValues(Allocator&) const; |
| 551 | void ClearBoundInputs(); |
| 552 | void ClearBoundOutputs(); |
| 553 | }; |
| 554 | |
| 555 | /*! \struct Ort::ArenaCfg |
| 556 | * \brief it is a structure that represents the configuration of an arena based allocator |
| 557 | * \details Please see docs/C_API.md for details |
| 558 | */ |
| 559 | struct ArenaCfg : Base<OrtArenaCfg> { |
| 560 | explicit ArenaCfg(std::nullptr_t) {} |
| 561 | /** |
| 562 | * \param max_mem - use 0 to allow ORT to choose the default |
| 563 | * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested |
| 564 | * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default |
| 565 | * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default |
| 566 | * \return an instance of ArenaCfg |
| 567 | * See docs/C_API.md for details on what the following parameters mean and how to choose these values |
| 568 | */ |
| 569 | ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk); |
| 570 | }; |
| 571 | |
| 572 | // |
| 573 | // Custom OPs (only needed to implement custom OPs) |
| 574 | // |
| 575 | |
| 576 | struct CustomOpApi { |
| 577 | CustomOpApi(const OrtApi& api) : api_(api) {} |
| 578 | |
| 579 | template <typename T> // T is only implemented for float, int64_t, and string |
| 580 | T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name); |
| 581 | |
| 582 | OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value); |
| 583 | size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info); |
| 584 | ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info); |
| 585 | size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info); |
| 586 | void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length); |
| 587 | void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count); |
| 588 | |
| 589 | template <typename T> |
| 590 | T* GetTensorMutableData(_Inout_ OrtValue* value); |
| 591 | template <typename T> |
| 592 | const T* GetTensorData(_Inout_ const OrtValue* value); |
| 593 | |
| 594 | std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info); |
| 595 | void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input); |
| 596 | size_t KernelContext_GetInputCount(const OrtKernelContext* context); |
| 597 | const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index); |
| 598 | size_t KernelContext_GetOutputCount(const OrtKernelContext* context); |
| 599 | OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count); |
| 600 | |
| 601 | void ThrowOnError(OrtStatus* result); |
| 602 | |
| 603 | private: |
| 604 | const OrtApi& api_; |
| 605 | }; |
| 606 | |
| 607 | template <typename TOp, typename TKernel> |
| 608 | struct CustomOpBase : OrtCustomOp { |
| 609 | CustomOpBase() { |
| 610 | OrtCustomOp::version = ORT_API_VERSION; |
| 611 | OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); }; |
| 612 | OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); }; |
| 613 | |
| 614 | OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); }; |
| 615 | |
| 616 | OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); }; |
| 617 | OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); }; |
| 618 | |
| 619 | OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); }; |
| 620 | OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); }; |
| 621 | |
| 622 | OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); }; |
| 623 | OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); }; |
| 624 | } |
| 625 | |
| 626 | // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider |
| 627 | const char* GetExecutionProviderType() const { return nullptr; } |
| 628 | }; |
| 629 | |
| 630 | } // namespace Ort |
| 631 | |
| 632 | #include "onnxruntime_cxx_inline.h" |
| 633 | |