microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
ba200b4a0e391b45c0df4f9b1a506f0a9f574dd4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

shared/ortcustomops.cc

189 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include <set>
5
6#include "string_utils.h"
7
8#include "text/op_equal.hpp"
9#include "text/op_segment_sum.hpp"
10#include "text/op_ragged_tensor.hpp"
11#include "text/string_hash.hpp"
12#include "text/string_join.hpp"
13#include "text/string_lower.hpp"
14#include "text/string_regex_replace.hpp"
15#include "text/string_regex_split.hpp"
16#include "text/string_split.hpp"
17#include "text/string_to_vector.hpp"
18#include "text/string_upper.hpp"
19#include "text/vector_to_string.hpp"
20#include "text/string_length.hpp"
21#include "text/string_concat.hpp"
22
23
24#ifdef ENABLE_SPM_TOKENIZER
25#include "sentencepiece_tokenizer.hpp"
26#endif
27
28#ifdef ENABLE_BERT_TOKENIZER
29#include "wordpiece_tokenizer.hpp"
30#endif
31
32#ifdef ENABLE_SPM_TOKENIZER
33CustomOpSentencepieceTokenizer c_CustomOpSentencepieceTokenizer;
34#endif
35
36#ifdef ENABLE_BERT_TOKENIZER
37CustomOpWordpieceTokenizer c_CustomOpWordpieceTokenizer;
38#endif
39
40#ifdef ENABLE_TF_STRING
41CustomOpSegmentSum c_CustomOpSegmentSum;
42CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense;
43CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
44CustomOpStringEqual c_CustomOpStringEqual;
45CustomOpStringHash c_CustomOpStringHash;
46CustomOpStringHashFast c_CustomOpStringHashFast;
47CustomOpStringJoin c_CustomOpStringJoin;
48CustomOpStringLower c_CustomOpStringLower;
49CustomOpStringRaggedTensorToDense c_CustomOpStringRaggedTensorToDense;
50CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
51CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets;
52CustomOpStringSplit c_CustomOpStringSplit;
53CustomOpStringToVector c_CustomOpStringToVector;
54CustomOpStringUpper c_CustomOpStringUpper;
55CustomOpVectorToString c_CustomOpVectorToString;
56CustomOpStringLength c_CustomOpStringLength;
57CustomOpStringConcat c_CustomOpStringConcat;
58#endif
59
60OrtCustomOp* operator_lists[] = {
61#ifdef ENABLE_SPM_TOKENIZER
62 &c_CustomOpSentencepieceTokenizer,
63#endif
64
65#ifdef ENABLE_BERT_TOKENIZER
66 &c_CustomOpWordpieceTokenizer,
67#endif
68
69#ifdef ENABLE_TF_STRING
70 &c_CustomOpRaggedTensorToDense,
71 &c_CustomOpRaggedTensorToSparse,
72 &c_CustomOpSegmentSum,
73 &c_CustomOpStringEqual,
74 &c_CustomOpStringHash,
75 &c_CustomOpStringHashFast,
76 &c_CustomOpStringJoin,
77 &c_CustomOpStringLower,
78 &c_CustomOpStringRaggedTensorToDense,
79 &c_CustomOpStringRegexReplace,
80 &c_CustomOpStringRegexSplitWithOffsets,
81 &c_CustomOpStringSplit,
82 &c_CustomOpStringToVector,
83 &c_CustomOpStringUpper,
84 &c_CustomOpVectorToString,
85 &c_CustomOpStringLength,
86 &c_CustomOpStringConcat,
87#endif
88 nullptr};
89
// Factory declared here and defined in another translation unit. The guard
// must match the `#if defined(ENABLE_MATH)` test used where this factory is
// referenced in RegisterCustomOps(). A plain `#if ENABLE_MATH` would skip this
// declaration when the macro is defined as 0 (leaving the use undeclared) and
// is ill-formed when the macro is defined with an empty value.
#if defined(ENABLE_MATH)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif  // ENABLE_MATH
93
94class ExternalCustomOps {
95 public:
96 ExternalCustomOps() {
97 }
98
99 static ExternalCustomOps& instance() {
100 static ExternalCustomOps g_instance;
101 return g_instance;
102 }
103
104 void Add(const OrtCustomOp* c_op) {
105 op_array_.push_back(c_op);
106 }
107
108 const OrtCustomOp* GetNextOp(size_t& idx) {
109 if (idx >= op_array_.size()) {
110 return nullptr;
111 }
112
113 return op_array_[idx++];
114 }
115
116 ExternalCustomOps(ExternalCustomOps const&) = delete;
117 void operator=(ExternalCustomOps const&) = delete;
118
119 private:
120 std::vector<const OrtCustomOp*> op_array_;
121};
122
123extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) {
124 ExternalCustomOps::instance().Add(c_op);
125 return true;
126}
127
128
129
130extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
131 OrtCustomOpDomain* domain = nullptr;
132 const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
133 std::set<std::string> pyop_nameset;
134
135 if (auto status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
136 return status;
137 }
138
139#if defined(PYTHON_OP_SUPPORT)
140 size_t count = 0;
141 const OrtCustomOp* c_ops = FetchPyCustomOps(count);
142 while (c_ops != nullptr) {
143 if (auto status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
144 return status;
145 } else {
146 pyop_nameset.emplace(c_ops->GetName(c_ops));
147 }
148 ++count;
149 c_ops = FetchPyCustomOps(count);
150 }
151#endif
152
153 static std::vector<FxLoadCustomOpFactory> c_factories = {
154 []() { return const_cast<const OrtCustomOp**>(operator_lists); }
155#if defined(ENABLE_MATH)
156 ,
157 LoadCustomOpClasses_Math
158#endif
159#if defined(ENABLE_GPT2_TOKENIZER)
160 ,
161 LoadTokenizerSchemaList
162#endif
163 };
164
165 for (auto fx : c_factories) {
166 auto ops = fx();
167 while (*ops != nullptr) {
168 if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
169 if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
170 return status;
171 }
172 }
173 ++ops;
174 }
175 }
176
177 size_t idx = 0;
178 const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);
179 while (e_ops != nullptr) {
180 if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) {
181 if (auto status = ortApi->CustomOpDomain_Add(domain, e_ops)) {
182 return status;
183 }
184 e_ops = ExternalCustomOps::instance().GetNextOp(idx);
185 }
186 }
187
188 return ortApi->AddCustomOpDomain(options, domain);
189}
190