microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0d5d19f67b28024de0b88d4a61bcc4157dc06248

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ortx_cpp_helper.h

97 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
#pragma once

#include <memory>   // std::unique_ptr
#include <utility>  // std::forward

#include "ortx_utils.h"
7
8namespace ort_extensions {
9
/**
 * @brief Deleter used with std::unique_ptr to release an OrtxObject handle.
 *
 * A null handle is ignored; any other handle is passed to OrtxDisposeOnly.
 * Whatever status OrtxDisposeOnly reports is not inspected here.
 *
 * @tparam T The handle type being released.
 */
template <typename T>
class OrtxDeleter {
 public:
  void operator()(T* object) const {
    if (object == nullptr) {
      return;  // nothing to release
    }
    OrtxDisposeOnly(object);
  }
};
19
20/**
21 * @brief A smart pointer class that manages the lifetime of an OrtxObject.
22 *
23 * This class is derived from std::unique_ptr and provides additional functionality
24 * specific to OrtxObject. It automatically calls the OrtxDeleter to release the
25 * owned object when it goes out of scope.
26 *
27 * @tparam T The type of the object being managed.
28 */
29template <typename T>
30class OrtxObjectPtr : public std::unique_ptr<T, OrtxDeleter<T>> {
31 public:
32 /**
33 * @brief Default constructor.
34 *
35 * Constructs an OrtxObjectPtr with a null pointer.
36 */
37 OrtxObjectPtr() : std::unique_ptr<T, OrtxDeleter<T>>(nullptr) {}
38
39 /**
40 * @brief Constructor that creates an OrtxObjectPtr from a function call.
41 *
42 * This constructor calls the specified function with the given arguments to
43 * create an OrtxObject. If the function call succeeds, the created object is
44 * owned by the OrtxObjectPtr.
45 *
46 * @tparam TFn The type of the function pointer or function object.
47 * @tparam Args The types of the arguments to be passed to the function.
48 * @param fn The function pointer or function object used to create the OrtxObject.
49 * @param args The arguments to be passed to the function.
50 */
51 template <typename TFn, typename... Args>
52 OrtxObjectPtr(TFn fn, Args&&... args) {
53 OrtxObject* proc = nullptr;
54 err_ = fn(&proc, std::forward<Args>(args)...);
55 if (err_ == kOrtxOK) {
56 this->reset(static_cast<T*>(proc));
57 }
58 }
59
60 /**
61 * @brief Get the error code associated with the creation of the OrtxObject.
62 *
63 * @return The error code.
64 */
65 extError_t Code() const { return err_; }
66
67 private:
68 extError_t err_ = kOrtxOK; /**< The error code associated with the creation of the OrtxObject. */
69};
70
71template <typename T>
72struct PointerAssigner {
73 OrtxObject* obj_{};
74 OrtxObjectPtr<T>& ptr_;
75 PointerAssigner(OrtxObjectPtr<T>& ptr) : ptr_(ptr){};
76
77 ~PointerAssigner() { ptr_.reset(static_cast<T*>(obj_)); };
78
79 operator T**() { return reinterpret_cast<T**>(&obj_); };
80};
81
82/**
83 * @brief A wrapper function for OrtxObjectPtr that can be used as a function parameter on creation.
84 *
85 * This function creates a PointerAssigner object for the given OrtxObjectPtr. The PointerAssigner
86 * object can be used to assign a pointer value to the OrtxObjectPtr.
87 *
88 * @tparam T The type of the object pointed to by the OrtxObjectPtr.
89 * @param ptr The OrtxObjectPtr to create the PointerAssigner for.
90 * @return A PointerAssigner object for the given OrtxObjectPtr.
91 */
92template <typename T>
93PointerAssigner<T> ptr(OrtxObjectPtr<T>& ptr) {
94 return PointerAssigner<T>{ptr};
95};
96
97} // namespace ort_extensions
98