microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b5ba84a185c5c85c436e744c40ac9a8cbd9d3f1f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

shared/ortcustomops.cc

155 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include <set>
5
6#include "kernels/op_equal.hpp"
7#include "kernels/op_segment_sum.hpp"
8#include "kernels/op_ragged_tensor.hpp"
9#include "kernels/string_hash.hpp"
10#include "kernels/string_join.hpp"
11#include "kernels/string_regex_replace.hpp"
12#include "kernels/string_split.hpp"
13#include "kernels/string_upper.hpp"
14#include "kernels/negpos.hpp"
15#include "kernels/vector_to_string.hpp"
16#include "utils/string_utils.h"
17
18#ifdef ENABLE_SPM_TOKENIZER
19#include "sentencepiece_tokenizer.hpp"
20#endif
21
#if defined(ENABLE_SPM_TOKENIZER)
// SentencePiece tokenizer kernel instance; only compiled in when
// ENABLE_SPM_TOKENIZER is defined.
CustomOpSentencepieceTokenizer c_CustomOpSentencepieceTokenizer;
#endif  // ENABLE_SPM_TOKENIZER
25
#if defined(ENABLE_TF_STRING)
// Kernel instances for the TensorFlow-compatible string / tensor ops.
// They are only compiled in when ENABLE_TF_STRING is defined. The definition
// order is kept as-is so their initialization order within this file is unchanged.
CustomOpNegPos c_CustomOpNegPos;
CustomOpSegmentSum c_CustomOpSegmentSum;
CustomOpRaggedTensorToSparse c_CustomOpRaggedTensorToSparse;
CustomOpStringEqual c_CustomOpStringEqual;
CustomOpStringHash c_CustomOpStringHash;
CustomOpStringHashFast c_CustomOpStringHashFast;
CustomOpStringJoin c_CustomOpStringJoin;
CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
CustomOpStringSplit c_CustomOpStringSplit;
CustomOpStringUpper c_CustomOpStringUpper;
CustomOpVectorToString c_CustomOpVectorToString;
#endif  // ENABLE_TF_STRING
39
40OrtCustomOp* operator_lists[] = {
41#ifdef ENABLE_SPM_TOKENIZER
42 &c_CustomOpSentencepieceTokenizer,
43#endif
44
45#ifdef ENABLE_TF_STRING
46 &c_CustomOpNegPos,
47 &c_CustomOpRaggedTensorToSparse,
48 &c_CustomOpSegmentSum,
49 &c_CustomOpStringEqual,
50 &c_CustomOpStringHash,
51 &c_CustomOpStringHashFast,
52 &c_CustomOpStringJoin,
53 &c_CustomOpStringRegexReplace,
54 &c_CustomOpStringSplit,
55 &c_CustomOpStringUpper,
56 &c_CustomOpVectorToString,
57#endif
58 nullptr};
59
60
61class ExternalCustomOps
62{
63 public:
64 ExternalCustomOps(){
65 }
66
67 static ExternalCustomOps& instance() {
68 static ExternalCustomOps g_instance;
69 return g_instance;
70 }
71
72 void Add(const OrtCustomOp* c_op) {
73 op_array_.push_back(c_op);
74 }
75
76 const OrtCustomOp* GetNextOp(size_t& idx) {
77 if (idx >= op_array_.size()) {
78 return nullptr;
79 }
80
81 return op_array_[idx ++];
82 }
83
84 ExternalCustomOps(ExternalCustomOps const&) = delete;
85 void operator=(ExternalCustomOps const&) = delete;
86
87 private:
88 std::vector<const OrtCustomOp*> op_array_;
89};
90
91
92extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op) {
93 ExternalCustomOps::instance().Add(c_op);
94 return true;
95}
96
97
98extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
99 OrtCustomOpDomain* domain = nullptr;
100 const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
101 std::set<std::string> pyop_nameset;
102
103 if (auto status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
104 return status;
105 }
106
107#if defined(PYTHON_OP_SUPPORT)
108 size_t count = 0;
109 const OrtCustomOp* c_ops = FetchPyCustomOps(count);
110 while (c_ops != nullptr) {
111 if (auto status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
112 return status;
113 } else {
114 pyop_nameset.emplace(c_ops->GetName(c_ops));
115 }
116 ++count;
117 c_ops = FetchPyCustomOps(count);
118 }
119#endif
120
121 OrtCustomOp** ops = operator_lists;
122 while (*ops != nullptr) {
123 if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
124 if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
125 return status;
126 }
127 }
128 ++ops;
129 }
130
131#if defined(ENABLE_GPT2_TOKENIZER)
132 const OrtCustomOp** t_ops = LoadTokenizerSchemaList();
133 while (*t_ops != nullptr) {
134 if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) {
135 if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)) {
136 return status;
137 }
138 }
139 t_ops++;
140 }
141#endif
142
143 size_t idx = 0;
144 const OrtCustomOp* e_ops = ExternalCustomOps::instance().GetNextOp(idx);
145 while (e_ops != nullptr) {
146 if (pyop_nameset.find(e_ops->GetName(e_ops)) == pyop_nameset.end()) {
147 if (auto status = ortApi->CustomOpDomain_Add(domain, e_ops)){
148 return status;
149 }
150 e_ops = ExternalCustomOps::instance().GetNextOp(idx);
151 }
152 }
153
154 return ortApi->AddCustomOpDomain(options, domain);
155}
156