microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
shared/ortcustomops.cc
155 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | #include <set> |
| 5 | |
| 6 | #include "kernels/op_equal.hpp" |
| 7 | #include "kernels/op_segment_sum.hpp" |
| 8 | #include "kernels/op_ragged_tensor.hpp" |
| 9 | #include "kernels/string_hash.hpp" |
| 10 | #include "kernels/string_join.hpp" |
| 11 | #include "kernels/string_regex_replace.hpp" |
| 12 | #include "kernels/string_split.hpp" |
| 13 | #include "kernels/string_upper.hpp" |
| 14 | #include "kernels/negpos.hpp" |
| 15 | #include "kernels/vector_to_string.hpp" |
| 16 | #include "utils/string_utils.h" |
| 17 | |
| 18 | #ifdef ENABLE_SPM_TOKENIZER |
| 19 | #include "sentencepiece_tokenizer.hpp" |
| 20 | #endif |
| 21 | |
// File-scope instances of the built-in custom operators. Their addresses are
// collected in operator_lists below, so each one exists only when the
// matching feature macro is defined.

// SentencePiece tokenizer operator.
#if defined(ENABLE_SPM_TOKENIZER)
CustomOpSentencepieceTokenizer c_CustomOpSentencepieceTokenizer;
#endif  // ENABLE_SPM_TOKENIZER

// Operators compiled in together under ENABLE_TF_STRING.
#if defined(ENABLE_TF_STRING)
CustomOpNegPos c_CustomOpNegPos;
CustomOpSegmentSum c_CustomOpSegmentSum;
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
CustomOpStringEqual c_CustomOpStringEqual;
CustomOpStringHash c_CustomOpStringHash;
CustomOpStringHashFast c_CustomOpStringHashFast;
CustomOpStringJoin c_CustomOpStringJoin;
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
CustomOpStringSplit c_CustomOpStringSplit;
CustomOpStringUpper c_CustomOpStringUpper;
CustomOpVectorToString c_CustomOpVectorToString;
#endif  // ENABLE_TF_STRING
| 39 | |
| 40 | OrtCustomOp* operator_lists[] = { |
| 41 | #ifdef ENABLE_SPM_TOKENIZER |
| 42 | &c_CustomOpSentencepieceTokenizer, |
| 43 | #endif |
| 44 | |
| 45 | #ifdef ENABLE_TF_STRING |
| 46 | &c_CustomOpNegPos, |
| 47 | &c_CustomOpRaggedTensorToSparse, |
| 48 | &c_CustomOpSegmentSum, |
| 49 | &c_CustomOpStringEqual, |
| 50 | &c_CustomOpStringHash, |
| 51 | &c_CustomOpStringHashFast, |
| 52 | &c_CustomOpStringJoin, |
| 53 | &c_CustomOpStringRegexReplace, |
| 54 | &c_CustomOpStringSplit, |
| 55 | &c_CustomOpStringUpper, |
| 56 | &c_CustomOpVectorToString, |
| 57 | #endif |
| 58 | nullptr}; |
| 59 | |
| 60 | |
| 61 | class ExternalCustomOps |
| 62 | { |
| 63 | public: |
| 64 | ExternalCustomOps(){ |
| 65 | } |
| 66 | |
| 67 | static ExternalCustomOps& instance() { |
| 68 | static ExternalCustomOps g_instance; |
| 69 | return g_instance; |
| 70 | } |
| 71 | |
| 72 | void Add(const OrtCustomOp* c_op) { |
| 73 | op_array_.push_back(c_op); |
| 74 | } |
| 75 | |
| 76 | const OrtCustomOp* GetNextOp(size_t& idx) { |
| 77 | if (idx >= op_array_.size()) { |
| 78 | return nullptr; |
| 79 | } |
| 80 | |
| 81 | return op_array_[idx ++]; |
| 82 | } |
| 83 | |
| 84 | ExternalCustomOps(ExternalCustomOps const&) = delete; |
| 85 | void operator=(ExternalCustomOps const&) = delete; |
| 86 | |
| 87 | private: |
| 88 | std::vector<const OrtCustomOp*> op_array_; |
| 89 | }; |
| 90 | |
| 91 | |
| 92 | extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op) { |
| 93 | ExternalCustomOps::instance().Add(c_op); |
| 94 | return true; |
| 95 | } |
| 96 | |
| 97 | |
| 98 | extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) { |
| 99 | OrtCustomOpDomain* domain = nullptr; |
| 100 | const OrtApi* ortApi = api->GetApi(ORT_API_VERSION); |
| 101 | std::set<std::string> pyop_nameset; |
| 102 | |
| 103 | if (auto status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) { |
| 104 | return status; |
| 105 | } |
| 106 | |
| 107 | #if defined(PYTHON_OP_SUPPORT) |
| 108 | size_t count = 0; |
| 109 | const OrtCustomOp* c_ops = FetchPyCustomOps(count); |
| 110 | while (c_ops != nullptr) { |
| 111 | if (auto status = ortApi->CustomOpDomain_Add(domain, c_ops)) { |
| 112 | return status; |
| 113 | } else { |
| 114 | pyop_nameset.emplace(c_ops->GetName(c_ops)); |
| 115 | } |
| 116 | ++count; |
| 117 | c_ops = FetchPyCustomOps(count); |
| 118 | } |
| 119 | #endif |
| 120 | |
| 121 | OrtCustomOp** ops = operator_lists; |
| 122 | while (*ops != nullptr) { |
| 123 | if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) { |
| 124 | if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) { |
| 125 | return status; |
| 126 | } |
| 127 | } |
| 128 | ++ops; |
| 129 | } |
| 130 | |
| 131 | #if defined(ENABLE_GPT2_TOKENIZER) |
| 132 | const OrtCustomOp** t_ops = LoadTokenizerSchemaList(); |
| 133 | while (*t_ops != nullptr) { |
| 134 | if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) { |
| 135 | if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)) { |
| 136 | return status; |
| 137 | } |
| 138 | } |
| 139 | t_ops++; |
| 140 | } |
| 141 | #endif |
| 142 | |
| 143 | size_t idx = 0; |
| 144 | const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx); |
| 145 | while (e_ops != nullptr) { |
| 146 | if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) { |
| 147 | if (auto status = ortApi->CustomOpDomain_Add(domain, e_ops)){ |
| 148 | return status; |
| 149 | } |
| 150 | e_ops = ExternalCustomOps::instance().GetNextOp(idx); |
| 151 | } |
| 152 | } |
| 153 | |
| 154 | return ortApi->AddCustomOpDomain(options, domain); |
| 155 | } |
| 156 | |