microsoft/onnxruntime-extensions
Public mirrored from https://github.com/microsoft/onnxruntime-extensions Available
shared/ortcustomops.cc
189 lines · mode code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | #include <set> |
| 5 | |
| 6 | #include "string_utils.h" |
| 7 | |
| 8 | #include "text/op_equal.hpp" |
| 9 | #include "text/op_segment_sum.hpp" |
| 10 | #include "text/op_ragged_tensor.hpp" |
| 11 | #include "text/string_hash.hpp" |
| 12 | #include "text/string_join.hpp" |
| 13 | #include "text/string_lower.hpp" |
| 14 | #include "text/string_regex_replace.hpp" |
| 15 | #include "text/string_regex_split.hpp" |
| 16 | #include "text/string_split.hpp" |
| 17 | #include "text/string_to_vector.hpp" |
| 18 | #include "text/string_upper.hpp" |
| 19 | #include "text/vector_to_string.hpp" |
| 20 | #include "text/string_length.hpp" |
| 21 | #include "text/string_concat.hpp" |
| 22 | |
| 23 | |
| 24 | #ifdef ENABLE_SPM_TOKENIZER |
| 25 | #include "sentencepiece_tokenizer.hpp" |
| 26 | #endif |
| 27 | |
| 28 | #ifdef ENABLE_BERT_TOKENIZER |
| 29 | #include "wordpiece_tokenizer.hpp" |
| 30 | #endif |
| 31 | |
// One global instance per built-in custom operator. Each group is compiled
// only when its feature macro is defined; the addresses of these objects are
// what the `operator_lists` table below hands out.

#if defined(ENABLE_SPM_TOKENIZER)
// SentencePiece tokenizer.
CustomOpSentencepieceTokenizer c_CustomOpSentencepieceTokenizer;
#endif  // ENABLE_SPM_TOKENIZER

#if defined(ENABLE_BERT_TOKENIZER)
// WordPiece tokenizer.
CustomOpWordpieceTokenizer c_CustomOpWordpieceTokenizer;
#endif  // ENABLE_BERT_TOKENIZER

#if defined(ENABLE_TF_STRING)
// Segment / ragged-tensor helpers.
CustomOpSegmentSum c_CustomOpSegmentSum;
CustomOpRaggedTensorToDense c_CustomOpRaggedTensorToDense;
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;

// String operators.
CustomOpStringEqual c_CustomOpStringEqual;
CustomOpStringHash c_CustomOpStringHash;
CustomOpStringHashFast c_CustomOpStringHashFast;
CustomOpStringJoin c_CustomOpStringJoin;
CustomOpStringLower c_CustomOpStringLower;
CustomOpStringRaggedTensorToDense c_CustomOpStringRaggedTensorToDense;
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
CustomOpStringRegexSplitWithOffsets c_CustomOpStringRegexSplitWithOffsets;
CustomOpStringSplit c_CustomOpStringSplit;
CustomOpStringToVector c_CustomOpStringToVector;
CustomOpStringUpper c_CustomOpStringUpper;
CustomOpVectorToString c_CustomOpVectorToString;
CustomOpStringLength c_CustomOpStringLength;
CustomOpStringConcat c_CustomOpStringConcat;
#endif  // ENABLE_TF_STRING
| 59 | |
| 60 | OrtCustomOp* operator_lists[] = { |
| 61 | #ifdef ENABLE_SPM_TOKENIZER |
| 62 | &c_CustomOpSentencepieceTokenizer, |
| 63 | #endif |
| 64 | |
| 65 | #ifdef ENABLE_BERT_TOKENIZER |
| 66 | &c_CustomOpWordpieceTokenizer, |
| 67 | #endif |
| 68 | |
| 69 | #ifdef ENABLE_TF_STRING |
| 70 | &c_CustomOpRaggedTensorToDense, |
| 71 | &c_CustomOpRaggedTensorToSparse, |
| 72 | &c_CustomOpSegmentSum, |
| 73 | &c_CustomOpStringEqual, |
| 74 | &c_CustomOpStringHash, |
| 75 | &c_CustomOpStringHashFast, |
| 76 | &c_CustomOpStringJoin, |
| 77 | &c_CustomOpStringLower, |
| 78 | &c_CustomOpStringRaggedTensorToDense, |
| 79 | &c_CustomOpStringRegexReplace, |
| 80 | &c_CustomOpStringRegexSplitWithOffsets, |
| 81 | &c_CustomOpStringSplit, |
| 82 | &c_CustomOpStringToVector, |
| 83 | &c_CustomOpStringUpper, |
| 84 | &c_CustomOpVectorToString, |
| 85 | &c_CustomOpStringLength, |
| 86 | &c_CustomOpStringConcat, |
| 87 | #endif |
| 88 | nullptr}; |
| 89 | |
// Factory for the math custom ops. It is only declared here (`extern`); the
// definition is expected to come from another part of the build.
//
// The guard uses `defined(...)` so that it agrees with the
// `#if defined(ENABLE_MATH)` test around the use of this symbol in
// RegisterCustomOps. A plain `#if ENABLE_MATH` would be a preprocessor error
// when the macro is defined with no value, and would hide this declaration
// (while the use site is still compiled) when it is defined as 0.
#if defined(ENABLE_MATH)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif  // ENABLE_MATH
| 93 | |
| 94 | class ExternalCustomOps { |
| 95 | public: |
| 96 | ExternalCustomOps() { |
| 97 | } |
| 98 | |
| 99 | static ExternalCustomOps& instance() { |
| 100 | static ExternalCustomOps g_instance; |
| 101 | return g_instance; |
| 102 | } |
| 103 | |
| 104 | void Add(const OrtCustomOp* c_op) { |
| 105 | op_array_.push_back(c_op); |
| 106 | } |
| 107 | |
| 108 | const OrtCustomOp* GetNextOp(size_t& idx) { |
| 109 | if (idx >= op_array_.size()) { |
| 110 | return nullptr; |
| 111 | } |
| 112 | |
| 113 | return op_array_[idx++]; |
| 114 | } |
| 115 | |
| 116 | ExternalCustomOps(ExternalCustomOps const&) = delete; |
| 117 | void operator=(ExternalCustomOps const&) = delete; |
| 118 | |
| 119 | private: |
| 120 | std::vector<const OrtCustomOp*> op_array_; |
| 121 | }; |
| 122 | |
| 123 | extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op) { |
| 124 | ExternalCustomOps::instance().Add(c_op); |
| 125 | return true; |
| 126 | } |
| 127 | |
| 128 | |
| 129 | |
| 130 | extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { |
| 131 | OrtCustomOpDomain* domain = nullptr; |
| 132 | const OrtApi* ortApi = api->GetApi(ORT_API_VERSION); |
| 133 | std::set<std::string> pyop_nameset; |
| 134 | |
| 135 | if (auto status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) { |
| 136 | return status; |
| 137 | } |
| 138 | |
| 139 | #if defined(PYTHON_OP_SUPPORT) |
| 140 | size_t count = 0; |
| 141 | const OrtCustomOp* c_ops = FetchPyCustomOps(count); |
| 142 | while (c_ops != nullptr) { |
| 143 | if (auto status = ortApi->CustomOpDomain_Add(domain, c_ops)) { |
| 144 | return status; |
| 145 | } else { |
| 146 | pyop_nameset.emplace(c_ops->GetName(c_ops)); |
| 147 | } |
| 148 | ++count; |
| 149 | c_ops = FetchPyCustomOps(count); |
| 150 | } |
| 151 | #endif |
| 152 | |
| 153 | static std::vector<FxLoadCustomOpFactory> c_factories = { |
| 154 | []() { return const_cast<const OrtCustomOp**>(operator_lists); } |
| 155 | #if defined(ENABLE_MATH) |
| 156 | , |
| 157 | LoadCustomOpClasses_Math |
| 158 | #endif |
| 159 | #if defined(ENABLE_GPT2_TOKENIZER) |
| 160 | , |
| 161 | LoadTokenizerSchemaList |
| 162 | #endif |
| 163 | }; |
| 164 | |
| 165 | for (auto fx : c_factories) { |
| 166 | auto ops = fx(); |
| 167 | while (*ops != nullptr) { |
| 168 | if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) { |
| 169 | if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) { |
| 170 | return status; |
| 171 | } |
| 172 | } |
| 173 | ++ops; |
| 174 | } |
| 175 | } |
| 176 | |
| 177 | size_t idx = 0; |
| 178 | const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx); |
| 179 | while (e_ops != nullptr) { |
| 180 | if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) { |
| 181 | if (auto status = ortApi->CustomOpDomain_Add(domain, e_ops)) { |
| 182 | return status; |
| 183 | } |
| 184 | e_ops = ExternalCustomOps::instance().GetNextOp(idx); |
| 185 | } |
| 186 | } |
| 187 | |
| 188 | return ortApi->AddCustomOpDomain(options, domain); |
| 189 | } |
| 190 | |