microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
8e4d15793609897fcd2cbe30ffa1d16e8ea4ee81

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

shared/ortcustomops.cc

137 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include <mutex>
5#include <set>
6
7#include "onnxruntime_extensions.h"
8#include "ocos.h"
9
10struct OrtCustomOpDomainDeleter {
11 explicit OrtCustomOpDomainDeleter(const OrtApi* ort_api) {
12 ort_api_ = ort_api;
13 }
14 void operator()(OrtCustomOpDomain* domain) const {
15 ort_api_->ReleaseCustomOpDomain(domain);
16 }
17
18 const OrtApi* ort_api_;
19};
20
21using OrtCustomOpDomainUniquePtr = std::unique_ptr<OrtCustomOpDomain, OrtCustomOpDomainDeleter>;
22static std::vector<OrtCustomOpDomainUniquePtr> ort_custom_op_domain_container;
23static std::mutex ort_custom_op_domain_mutex;
24
25static void AddOrtCustomOpDomainToContainer(OrtCustomOpDomain* domain, const OrtApi* ort_api) {
26 std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
27 auto ptr = std::unique_ptr<OrtCustomOpDomain, OrtCustomOpDomainDeleter>(domain, OrtCustomOpDomainDeleter(ort_api));
28 ort_custom_op_domain_container.push_back(std::move(ptr));
29}
30
31class ExternalCustomOps {
32 public:
33 ExternalCustomOps() {
34 }
35
36 static ExternalCustomOps& instance() {
37 static ExternalCustomOps g_instance;
38 return g_instance;
39 }
40
41 void Add(const OrtCustomOp* c_op) {
42 op_array_.push_back(c_op);
43 }
44
45 const OrtCustomOp* GetNextOp(size_t& idx) {
46 if (idx >= op_array_.size()) {
47 return nullptr;
48 }
49
50 return op_array_[idx++];
51 }
52
53 ExternalCustomOps(ExternalCustomOps const&) = delete;
54 void operator=(ExternalCustomOps const&) = delete;
55
56 private:
57 std::vector<const OrtCustomOp*> op_array_;
58};
59
60extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
61 ExternalCustomOps::instance().Add(c_op);
62 return true;
63}
64
65extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
66 OrtCustomOpDomain* domain = nullptr;
67 const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
68 std::set<std::string> pyop_nameset;
69 OrtStatus* status = nullptr;
70
71#if defined(PYTHON_OP_SUPPORT)
72 if (status = RegisterPythonDomainAndOps(options, ortApi)){
73 return status;
74 }
75#endif // PYTHON_OP_SUPPORT
76
77 if (status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
78 return status;
79 }
80
81 AddOrtCustomOpDomainToContainer(domain, ortApi);
82
83#if defined(PYTHON_OP_SUPPORT)
84 size_t count = 0;
85 const OrtCustomOp* c_ops = FetchPyCustomOps(count);
86 while (c_ops != nullptr) {
87 if (status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
88 return status;
89 } else {
90 pyop_nameset.emplace(c_ops->GetName(c_ops));
91 }
92 ++count;
93 c_ops = FetchPyCustomOps(count);
94 }
95#endif
96
97 static std::vector<FxLoadCustomOpFactory> c_factories = {
98 LoadCustomOpClasses<CustomOpClassBegin>
99#if defined(ENABLE_TF_STRING)
100 , LoadCustomOpClasses_Text
101#endif // ENABLE_TF_STRING
102#if defined(ENABLE_MATH)
103 , LoadCustomOpClasses_Math
104#endif
105#if defined(ENABLE_TOKENIZER)
106 , LoadCustomOpClasses_Tokenizer
107#endif
108#if defined(ENABLE_OPENCV)
109 , LoadCustomOpClasses_OpenCV
110#endif
111 };
112
113 for (auto fx : c_factories) {
114 auto ops = fx();
115 while (*ops != nullptr) {
116 if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
117 if (status = ortApi->CustomOpDomain_Add(domain, *ops)) {
118 return status;
119 }
120 }
121 ++ops;
122 }
123 }
124
125 size_t idx = 0;
126 const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);
127 while (e_ops != nullptr) {
128 if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) {
129 if (status = ortApi->CustomOpDomain_Add(domain, e_ops)) {
130 return status;
131 }
132 e_ops = ExternalCustomOps::instance().GetNextOp(idx);
133 }
134 }
135
136 return ortApi->AddCustomOpDomain(options, domain);
137}