microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
1ae69c0f7aeaab9911cf8ebf86ee92b34dadd26e

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/onnxruntime/onnxruntime_cxx_api.h

632lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4// Summary: The Ort C++ API is a header only wrapper around the Ort C API.
5//
6// The C++ API simplifies usage by returning values directly instead of error codes, throwing exceptions on errors
7// and automatically releasing resources in the destructors.
8//
9// Each of the C++ wrapper classes holds only a pointer to the C internal object. Treat them like smart pointers.
10// To create an empty object, pass 'nullptr' to the constructor (for example, Env e{nullptr};).
11//
12// Only move assignment between objects is allowed, there are no copy constructors. Some objects have explicit 'Clone'
13// methods for this purpose.
14
15#pragma once
#include "onnxruntime_c_api.h"

#include <cstddef>
#include <cstdlib>

#include <array>
#include <memory>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>

#ifdef ORT_NO_EXCEPTIONS
#include <iostream>
#endif
29
30namespace Ort {
31
32// All C++ methods that can fail will throw an exception of this type
33struct Exception : std::exception {
34 Exception(std::string&& string, OrtErrorCode code) : message_{std::move(string)}, code_{code} {}
35
36 OrtErrorCode GetOrtErrorCode() const { return code_; }
37 const char* what() const noexcept override { return message_.c_str(); }
38
39 private:
40 std::string message_;
41 OrtErrorCode code_;
42};
43
// Error-reporting hook used by the C++ API wherever a call can fail.
// - Default build: throws Ort::Exception(string, code).
// - ORT_NO_EXCEPTIONS build: throwing is unavailable, so the message is written to std::cerr and the
//   process is terminated with std::abort(). std::abort is declared in <cstdlib>, which is included
//   explicitly at the top of this header rather than relying on <iostream> to pull it in.
#ifdef ORT_NO_EXCEPTIONS
#define ORT_CXX_API_THROW(string, code)              \
  do {                                               \
    std::cerr << Ort::Exception(string, code).what() \
              << std::endl;                          \
    std::abort();                                    \
  } while (false)
#else
#define ORT_CXX_API_THROW(string, code) \
  throw Ort::Exception(string, code)
#endif
56
57// This is used internally by the C++ API. This class holds the global variable that points to the OrtApi, it's in a template so that we can define a global variable in a header and make
58// it transparent to the users of the API.
59template <typename T>
60struct Global {
61 static const OrtApi* api_;
62};
63
64// If macro ORT_API_MANUAL_INIT is defined, no static initialization will be performed. Instead, user must call InitApi() before using it.
65
66template <typename T>
67#ifdef ORT_API_MANUAL_INIT
68const OrtApi* Global<T>::api_{};
69inline void InitApi() { Global<void>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION); }
70#else
71const OrtApi* Global<T>::api_ = OrtGetApiBase()->GetApi(ORT_API_VERSION);
72#endif
73
74// This returns a reference to the OrtApi interface in use, in case someone wants to use the C API functions
75inline const OrtApi& GetApi() { return *Global<void>::api_; }
76
// C++ wrapper for the C API's GetAvailableProviders(): yields one string per execution provider
// available in this build. Only declared here; the definition is expected to come from
// onnxruntime_cxx_inline.h, which this header includes at its end.
auto GetAvailableProviders() -> std::vector<std::string>;
80
81// This is used internally by the C++ API. This macro is to make it easy to generate overloaded methods for all of the various OrtRelease* functions for every Ort* type
82// This can't be done in the C API since C doesn't have function overloading.
83#define ORT_DEFINE_RELEASE(NAME) \
84 inline void OrtRelease(Ort##NAME* ptr) { GetApi().Release##NAME(ptr); }
85
86ORT_DEFINE_RELEASE(Allocator);
87ORT_DEFINE_RELEASE(MemoryInfo);
88ORT_DEFINE_RELEASE(CustomOpDomain);
89ORT_DEFINE_RELEASE(Env);
90ORT_DEFINE_RELEASE(RunOptions);
91ORT_DEFINE_RELEASE(Session);
92ORT_DEFINE_RELEASE(SessionOptions);
93ORT_DEFINE_RELEASE(TensorTypeAndShapeInfo);
94ORT_DEFINE_RELEASE(SequenceTypeInfo);
95ORT_DEFINE_RELEASE(MapTypeInfo);
96ORT_DEFINE_RELEASE(TypeInfo);
97ORT_DEFINE_RELEASE(Value);
98ORT_DEFINE_RELEASE(ModelMetadata);
99ORT_DEFINE_RELEASE(ThreadingOptions);
100ORT_DEFINE_RELEASE(IoBinding);
101ORT_DEFINE_RELEASE(ArenaCfg);
102
/*! \class Ort::Float16_t
 * \brief A structure that represents float16 data.
 * \details It is necessary for type dispatching to make use of the C++ API.
 * The type is implicitly convertible to/from uint16_t.
 * The size of the structure matches uint16_t (checked by the static_assert below), so one can
 * freely cast uint16_t buffers to/from Ort::Float16_t to feed and retrieve data.
 *
 * Generally, you can feed any of your types as float16/bfloat16 data to create a tensor
 * on top of it, provided it forms a contiguous buffer of 16-bit elements with no padding.
 * You can also feed an array of uint16_t elements directly. For example,
 *
 * \code{.unparsed}
 * uint16_t values[] = { 15360, 16384, 16896, 17408, 17664};
 * constexpr size_t values_length = sizeof(values) / sizeof(values[0]);
 * std::vector<int64_t> dims = {values_length}; // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Note we are passing bytes count in this api, not number of elements -> sizeof(values)
 * auto float16_tensor = Ort::Value::CreateTensor(info, values, sizeof(values),
 *                                                dims.data(), dims.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16);
 * \endcode
 *
 * Here is another example, a little bit more elaborate. Let's assume that you use your own float16 type and you want to use
 * a templated version of the API above so the type is automatically set based on your type. You will need to supply an extra
 * template specialization.
 *
 * \code{.unparsed}
 * namespace yours { struct half {}; } // assume this is your type, define this:
 * namespace Ort {
 * template<>
 * struct TypeToTensorType<yours::half> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16; };
 * } //namespace Ort
 *
 * std::vector<yours::half> values;
 * // values.size() is not a constant expression, so it must be converted explicitly:
 * // brace-initializing an int64_t from a size_t would otherwise be a narrowing error.
 * std::vector<int64_t> dims = {static_cast<int64_t>(values.size())}; // one dimensional example
 * Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
 * // Here we are passing element count -> values.size()
 * auto float16_tensor = Ort::Value::CreateTensor<yours::half>(info, values.data(), values.size(), dims.data(), dims.size());
 *
 * \endcode
 */
struct Float16_t {
  uint16_t value;
  constexpr Float16_t() noexcept : value(0) {}
  constexpr Float16_t(uint16_t v) noexcept : value(v) {}
  constexpr operator uint16_t() const noexcept { return value; }
  // Bitwise comparison of the raw 16-bit patterns (not IEEE float comparison).
  constexpr bool operator==(const Float16_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const Float16_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(Float16_t) == sizeof(uint16_t), "Sizes must match");
153
/*! \class Ort::BFloat16_t
 * \brief A structure that represents bfloat16 data.
 * \details It is necessary for type dispatching to make use of the C++ API.
 * The type is implicitly convertible to/from uint16_t.
 * The size of the structure matches uint16_t (checked by the static_assert below), so one can
 * freely cast uint16_t buffers to/from Ort::BFloat16_t to feed and retrieve data.
 *
 * See also code examples for Float16_t above.
 */
struct BFloat16_t {
  uint16_t value;
  constexpr BFloat16_t() noexcept : value(0) {}
  constexpr BFloat16_t(uint16_t v) noexcept : value(v) {}
  constexpr operator uint16_t() const noexcept { return value; }
  // Bitwise comparison of the raw 16-bit patterns (not IEEE float comparison).
  constexpr bool operator==(const BFloat16_t& rhs) const noexcept { return value == rhs.value; }
  constexpr bool operator!=(const BFloat16_t& rhs) const noexcept { return value != rhs.value; }
};

static_assert(sizeof(BFloat16_t) == sizeof(uint16_t), "Sizes must match");
173
174// This is used internally by the C++ API. This is the common base class used by the wrapper objects.
175template <typename T>
176struct Base {
177 using contained_type = T;
178
179 Base() = default;
180 Base(T* p) : p_{p} {
181 if (!p)
182 ORT_CXX_API_THROW("Allocation failure", ORT_FAIL);
183 }
184 ~Base() { OrtRelease(p_); }
185
186 operator T*() { return p_; }
187 operator const T*() const { return p_; }
188
189 T* release() {
190 T* p = p_;
191 p_ = nullptr;
192 return p;
193 }
194
195 protected:
196 Base(const Base&) = delete;
197 Base& operator=(const Base&) = delete;
198 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
199 void operator=(Base&& v) noexcept {
200 OrtRelease(p_);
201 p_ = v.p_;
202 v.p_ = nullptr;
203 }
204
205 T* p_{};
206
207 template <typename>
208 friend struct Unowned; // This friend line is needed to keep the centos C++ compiler from giving an error
209};
210
211template <typename T>
212struct Base<const T> {
213 using contained_type = const T;
214
215 Base() = default;
216 Base(const T* p) : p_{p} {
217 if (!p)
218 ORT_CXX_API_THROW("Invalid instance ptr", ORT_INVALID_ARGUMENT);
219 }
220 ~Base() = default;
221
222 operator const T*() const { return p_; }
223
224 protected:
225 Base(const Base&) = delete;
226 Base& operator=(const Base&) = delete;
227 Base(Base&& v) noexcept : p_{v.p_} { v.p_ = nullptr; }
228 void operator=(Base&& v) noexcept {
229 p_ = v.p_;
230 v.p_ = nullptr;
231 }
232
233 const T* p_{};
234};
235
// Non-owning view built on an owning wrapper type T: it is constructed around an existing pointer,
// and its destructor detaches that pointer via T::release() before the base destructor runs.
// As a result, the base never releases the underlying object.
template <typename T>
struct Unowned : T {
  ~Unowned() { this->release(); }

  Unowned(decltype(T::p_) ptr) : T{ptr} {}
  // Both the source and the new object are non-owning, so the source's pointer is not cleared.
  Unowned(Unowned&& other) : T{other.p_} {}
};
242
// Forward declarations of wrapper types that are referenced before (or alongside) their definitions.
struct AllocatorWithDefaultOptions;
struct Env;
struct MemoryInfo;
struct ModelMetadata;
struct TypeInfo;
struct Value;
249
250struct Env : Base<OrtEnv> {
251 Env(std::nullptr_t) {}
252 Env(OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
253 Env(const OrtThreadingOptions* tp_options, OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
254 Env(OrtLoggingLevel logging_level, const char* logid, OrtLoggingFunction logging_function, void* logger_param);
255 Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
256 OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
257 explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
258
259 Env& EnableTelemetryEvents();
260 Env& DisableTelemetryEvents();
261
262 Env& CreateAndRegisterAllocator(const OrtMemoryInfo* mem_info, const OrtArenaCfg* arena_cfg);
263
264 static const OrtApi* s_api;
265};
266
267struct CustomOpDomain : Base<OrtCustomOpDomain> {
268 explicit CustomOpDomain(std::nullptr_t) {}
269 explicit CustomOpDomain(const char* domain);
270
271 void Add(OrtCustomOp* op);
272};
273
274struct RunOptions : Base<OrtRunOptions> {
275 RunOptions(std::nullptr_t) {}
276 RunOptions();
277
278 RunOptions& SetRunLogVerbosityLevel(int);
279 int GetRunLogVerbosityLevel() const;
280
281 RunOptions& SetRunLogSeverityLevel(int);
282 int GetRunLogSeverityLevel() const;
283
284 RunOptions& SetRunTag(const char* run_tag);
285 const char* GetRunTag() const;
286
287 // terminate ALL currently executing Session::Run calls that were made using this RunOptions instance
288 RunOptions& SetTerminate();
289 // unset the terminate flag so this RunOptions instance can be used in a new Session::Run call
290 RunOptions& UnsetTerminate();
291};
292
293struct SessionOptions : Base<OrtSessionOptions> {
294 explicit SessionOptions(std::nullptr_t) {}
295 SessionOptions();
296 explicit SessionOptions(OrtSessionOptions* p) : Base<OrtSessionOptions>{p} {}
297
298 SessionOptions Clone() const;
299
300 SessionOptions& SetIntraOpNumThreads(int intra_op_num_threads);
301 SessionOptions& SetInterOpNumThreads(int inter_op_num_threads);
302 SessionOptions& SetGraphOptimizationLevel(GraphOptimizationLevel graph_optimization_level);
303
304 SessionOptions& EnableCpuMemArena();
305 SessionOptions& DisableCpuMemArena();
306
307 SessionOptions& SetOptimizedModelFilePath(const ORTCHAR_T* optimized_model_file);
308
309 SessionOptions& EnableProfiling(const ORTCHAR_T* profile_file_prefix);
310 SessionOptions& DisableProfiling();
311
312 SessionOptions& EnableMemPattern();
313 SessionOptions& DisableMemPattern();
314
315 SessionOptions& SetExecutionMode(ExecutionMode execution_mode);
316
317 SessionOptions& SetLogId(const char* logid);
318 SessionOptions& SetLogSeverityLevel(int level);
319
320 SessionOptions& Add(OrtCustomOpDomain* custom_op_domain);
321
322 SessionOptions& DisablePerSessionThreads();
323
324 SessionOptions& AddConfigEntry(const char* config_key, const char* config_value);
325 SessionOptions& AddInitializer(const char* name, const OrtValue* ort_val);
326
327 SessionOptions& AppendExecutionProvider_CUDA(const OrtCUDAProviderOptions& provider_options);
328 SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options);
329};
330
331struct ModelMetadata : Base<OrtModelMetadata> {
332 explicit ModelMetadata(std::nullptr_t) {}
333 explicit ModelMetadata(OrtModelMetadata* p) : Base<OrtModelMetadata>{p} {}
334
335 char* GetProducerName(OrtAllocator* allocator) const;
336 char* GetGraphName(OrtAllocator* allocator) const;
337 char* GetDomain(OrtAllocator* allocator) const;
338 char* GetDescription(OrtAllocator* allocator) const;
339 char** GetCustomMetadataMapKeys(OrtAllocator* allocator, _Out_ int64_t& num_keys) const;
340 char* LookupCustomMetadataMap(const char* key, OrtAllocator* allocator) const;
341 int64_t GetVersion() const;
342};
343
344struct Session : Base<OrtSession> {
345 explicit Session(std::nullptr_t) {}
346 Session(Env& env, const ORTCHAR_T* model_path, const SessionOptions& options);
347 Session(Env& env, const void* model_data, size_t model_data_length, const SessionOptions& options);
348
349 // Run that will allocate the output values
350 std::vector<Value> Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
351 const char* const* output_names, size_t output_count);
352 // Run for when there is a list of preallocated outputs
353 void Run(const RunOptions& run_options, const char* const* input_names, const Value* input_values, size_t input_count,
354 const char* const* output_names, Value* output_values, size_t output_count);
355
356 void Run(const RunOptions& run_options, const struct IoBinding&);
357
358 size_t GetInputCount() const;
359 size_t GetOutputCount() const;
360 size_t GetOverridableInitializerCount() const;
361
362 char* GetInputName(size_t index, OrtAllocator* allocator) const;
363 char* GetOutputName(size_t index, OrtAllocator* allocator) const;
364 char* GetOverridableInitializerName(size_t index, OrtAllocator* allocator) const;
365 char* EndProfiling(OrtAllocator* allocator) const;
366 uint64_t GetProfilingStartTimeNs() const;
367 ModelMetadata GetModelMetadata() const;
368
369 TypeInfo GetInputTypeInfo(size_t index) const;
370 TypeInfo GetOutputTypeInfo(size_t index) const;
371 TypeInfo GetOverridableInitializerTypeInfo(size_t index) const;
372};
373
374struct TensorTypeAndShapeInfo : Base<OrtTensorTypeAndShapeInfo> {
375 explicit TensorTypeAndShapeInfo(std::nullptr_t) {}
376 explicit TensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* p) : Base<OrtTensorTypeAndShapeInfo>{p} {}
377
378 ONNXTensorElementDataType GetElementType() const;
379 size_t GetElementCount() const;
380
381 size_t GetDimensionsCount() const;
382 void GetDimensions(int64_t* values, size_t values_count) const;
383 void GetSymbolicDimensions(const char** values, size_t values_count) const;
384
385 std::vector<int64_t> GetShape() const;
386};
387
388struct SequenceTypeInfo : Base<OrtSequenceTypeInfo> {
389 explicit SequenceTypeInfo(std::nullptr_t) {}
390 explicit SequenceTypeInfo(OrtSequenceTypeInfo* p) : Base<OrtSequenceTypeInfo>{p} {}
391
392 TypeInfo GetSequenceElementType() const;
393};
394
395struct MapTypeInfo : Base<OrtMapTypeInfo> {
396 explicit MapTypeInfo(std::nullptr_t) {}
397 explicit MapTypeInfo(OrtMapTypeInfo* p) : Base<OrtMapTypeInfo>{p} {}
398
399 ONNXTensorElementDataType GetMapKeyType() const;
400 TypeInfo GetMapValueType() const;
401};
402
403struct TypeInfo : Base<OrtTypeInfo> {
404 explicit TypeInfo(std::nullptr_t) {}
405 explicit TypeInfo(OrtTypeInfo* p) : Base<OrtTypeInfo>{p} {}
406
407 Unowned<TensorTypeAndShapeInfo> GetTensorTypeAndShapeInfo() const;
408 Unowned<SequenceTypeInfo> GetSequenceTypeInfo() const;
409 Unowned<MapTypeInfo> GetMapTypeInfo() const;
410
411 ONNXType GetONNXType() const;
412};
413
414struct Value : Base<OrtValue> {
415 template <typename T>
416 static Value CreateTensor(const OrtMemoryInfo* info, T* p_data, size_t p_data_element_count, const int64_t* shape, size_t shape_len);
417 static Value CreateTensor(const OrtMemoryInfo* info, void* p_data, size_t p_data_byte_count, const int64_t* shape, size_t shape_len,
418 ONNXTensorElementDataType type);
419 template <typename T>
420 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len);
421 static Value CreateTensor(OrtAllocator* allocator, const int64_t* shape, size_t shape_len, ONNXTensorElementDataType type);
422
423 static Value CreateMap(Value& keys, Value& values);
424 static Value CreateSequence(std::vector<Value>& values);
425
426 template <typename T>
427 static Value CreateOpaque(const char* domain, const char* type_name, const T&);
428
429 template <typename T>
430 void GetOpaqueData(const char* domain, const char* type_name, T&) const;
431
432 explicit Value(std::nullptr_t) {}
433 explicit Value(OrtValue* p) : Base<OrtValue>{p} {}
434 Value(Value&&) = default;
435 Value& operator=(Value&&) = default;
436
437 bool IsTensor() const;
438 size_t GetCount() const; // If a non tensor, returns 2 for map and N for sequence, where N is the number of elements
439 Value GetValue(int index, OrtAllocator* allocator) const;
440
441 size_t GetStringTensorDataLength() const;
442 void GetStringTensorContent(void* buffer, size_t buffer_length, size_t* offsets, size_t offsets_count) const;
443
444 template <typename T>
445 T* GetTensorMutableData();
446
447 template <typename T>
448 const T* GetTensorData() const;
449
450 template <typename T>
451 T& At(const std::vector<int64_t>& location);
452
453 TypeInfo GetTypeInfo() const;
454 TensorTypeAndShapeInfo GetTensorTypeAndShapeInfo() const;
455
456 size_t GetStringTensorElementLength(size_t element_index) const;
457 void GetStringTensorElement(size_t buffer_length, size_t element_index, void* buffer) const;
458
459 void FillStringTensor(const char* const* s, size_t s_len);
460 void FillStringTensorElement(const char* s, size_t index);
461};
462
463// Represents native memory allocation
464struct MemoryAllocation {
465 MemoryAllocation(OrtAllocator* allocator, void* p, size_t size);
466 ~MemoryAllocation();
467 MemoryAllocation(const MemoryAllocation&) = delete;
468 MemoryAllocation& operator=(const MemoryAllocation&) = delete;
469 MemoryAllocation(MemoryAllocation&&);
470 MemoryAllocation& operator=(MemoryAllocation&&);
471
472 void* get() { return p_; }
473 size_t size() const { return size_; }
474
475 private:
476 OrtAllocator* allocator_;
477 void* p_;
478 size_t size_;
479};
480
481struct AllocatorWithDefaultOptions {
482 AllocatorWithDefaultOptions();
483
484 operator OrtAllocator*() { return p_; }
485 operator const OrtAllocator*() const { return p_; }
486
487 void* Alloc(size_t size);
488 // The return value will own the allocation
489 MemoryAllocation GetAllocation(size_t size);
490 void Free(void* p);
491
492 const OrtMemoryInfo* GetInfo() const;
493
494 private:
495 OrtAllocator* p_{};
496};
497
498template <typename B>
499struct BaseMemoryInfo : B {
500 BaseMemoryInfo() = default;
501 explicit BaseMemoryInfo(typename B::contained_type* p) : B(p) {}
502 ~BaseMemoryInfo() = default;
503 BaseMemoryInfo(BaseMemoryInfo&&) = default;
504 BaseMemoryInfo& operator=(BaseMemoryInfo&&) = default;
505
506 std::string GetAllocatorName() const;
507 OrtAllocatorType GetAllocatorType() const;
508 int GetDeviceId() const;
509 OrtMemType GetMemoryType() const;
510 template <typename U>
511 bool operator==(const BaseMemoryInfo<U>& o) const;
512};
513
514struct UnownedMemoryInfo : BaseMemoryInfo<Base<const OrtMemoryInfo> > {
515 explicit UnownedMemoryInfo(std::nullptr_t) {}
516 explicit UnownedMemoryInfo(const OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
517};
518
519struct MemoryInfo : BaseMemoryInfo<Base<OrtMemoryInfo> > {
520 static MemoryInfo CreateCpu(OrtAllocatorType type, OrtMemType mem_type1);
521
522 explicit MemoryInfo(std::nullptr_t) {}
523 explicit MemoryInfo(OrtMemoryInfo* p) : BaseMemoryInfo(p) {}
524 MemoryInfo(const char* name, OrtAllocatorType type, int id, OrtMemType mem_type);
525};
526
527struct Allocator : public Base<OrtAllocator> {
528 Allocator(const Session& session, const MemoryInfo&);
529
530 void* Alloc(size_t size) const;
531 // The return value will own the allocation
532 MemoryAllocation GetAllocation(size_t size);
533 void Free(void* p) const;
534 UnownedMemoryInfo GetInfo() const;
535};
536
537struct IoBinding : public Base<OrtIoBinding> {
538 private:
539 std::vector<std::string> GetOutputNamesHelper(OrtAllocator*) const;
540 std::vector<Value> GetOutputValuesHelper(OrtAllocator*) const;
541
542 public:
543 explicit IoBinding(Session& session);
544 void BindInput(const char* name, const Value&);
545 void BindOutput(const char* name, const Value&);
546 void BindOutput(const char* name, const MemoryInfo&);
547 std::vector<std::string> GetOutputNames() const;
548 std::vector<std::string> GetOutputNames(Allocator&) const;
549 std::vector<Value> GetOutputValues() const;
550 std::vector<Value> GetOutputValues(Allocator&) const;
551 void ClearBoundInputs();
552 void ClearBoundOutputs();
553};
554
555/*! \struct Ort::ArenaCfg
556 * \brief it is a structure that represents the configuration of an arena based allocator
557 * \details Please see docs/C_API.md for details
558 */
559struct ArenaCfg : Base<OrtArenaCfg> {
560 explicit ArenaCfg(std::nullptr_t) {}
561 /**
562 * \param max_mem - use 0 to allow ORT to choose the default
563 * \param arena_extend_strategy - use -1 to allow ORT to choose the default, 0 = kNextPowerOfTwo, 1 = kSameAsRequested
564 * \param initial_chunk_size_bytes - use -1 to allow ORT to choose the default
565 * \param max_dead_bytes_per_chunk - use -1 to allow ORT to choose the default
566 * \return an instance of ArenaCfg
567 * See docs/C_API.md for details on what the following parameters mean and how to choose these values
568 */
569 ArenaCfg(size_t max_mem, int arena_extend_strategy, int initial_chunk_size_bytes, int max_dead_bytes_per_chunk);
570};
571
572//
573// Custom OPs (only needed to implement custom OPs)
574//
575
576struct CustomOpApi {
577 CustomOpApi(const OrtApi& api) : api_(api) {}
578
579 template <typename T> // T is only implemented for float, int64_t, and string
580 T KernelInfoGetAttribute(_In_ const OrtKernelInfo* info, _In_ const char* name);
581
582 OrtTensorTypeAndShapeInfo* GetTensorTypeAndShape(_In_ const OrtValue* value);
583 size_t GetTensorShapeElementCount(_In_ const OrtTensorTypeAndShapeInfo* info);
584 ONNXTensorElementDataType GetTensorElementType(const OrtTensorTypeAndShapeInfo* info);
585 size_t GetDimensionsCount(_In_ const OrtTensorTypeAndShapeInfo* info);
586 void GetDimensions(_In_ const OrtTensorTypeAndShapeInfo* info, _Out_ int64_t* dim_values, size_t dim_values_length);
587 void SetDimensions(OrtTensorTypeAndShapeInfo* info, _In_ const int64_t* dim_values, size_t dim_count);
588
589 template <typename T>
590 T* GetTensorMutableData(_Inout_ OrtValue* value);
591 template <typename T>
592 const T* GetTensorData(_Inout_ const OrtValue* value);
593
594 std::vector<int64_t> GetTensorShape(const OrtTensorTypeAndShapeInfo* info);
595 void ReleaseTensorTypeAndShapeInfo(OrtTensorTypeAndShapeInfo* input);
596 size_t KernelContext_GetInputCount(const OrtKernelContext* context);
597 const OrtValue* KernelContext_GetInput(const OrtKernelContext* context, _In_ size_t index);
598 size_t KernelContext_GetOutputCount(const OrtKernelContext* context);
599 OrtValue* KernelContext_GetOutput(OrtKernelContext* context, _In_ size_t index, _In_ const int64_t* dim_values, size_t dim_count);
600
601 void ThrowOnError(OrtStatus* result);
602
603 private:
604 const OrtApi& api_;
605};
606
607template <typename TOp, typename TKernel>
608struct CustomOpBase : OrtCustomOp {
609 CustomOpBase() {
610 OrtCustomOp::version = ORT_API_VERSION;
611 OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* api, const OrtKernelInfo* info) { return static_cast<const TOp*>(this_)->CreateKernel(*api, info); };
612 OrtCustomOp::GetName = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetName(); };
613
614 OrtCustomOp::GetExecutionProviderType = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetExecutionProviderType(); };
615
616 OrtCustomOp::GetInputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetInputTypeCount(); };
617 OrtCustomOp::GetInputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetInputType(index); };
618
619 OrtCustomOp::GetOutputTypeCount = [](const OrtCustomOp* this_) { return static_cast<const TOp*>(this_)->GetOutputTypeCount(); };
620 OrtCustomOp::GetOutputType = [](const OrtCustomOp* this_, size_t index) { return static_cast<const TOp*>(this_)->GetOutputType(index); };
621
622 OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { static_cast<TKernel*>(op_kernel)->Compute(context); };
623 OrtCustomOp::KernelDestroy = [](void* op_kernel) { delete static_cast<TKernel*>(op_kernel); };
624 }
625
626 // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
627 const char* GetExecutionProviderType() const { return nullptr; }
628};
629
630} // namespace Ort
631
632#include "onnxruntime_cxx_inline.h"
633