microsoft/onnxruntime-extensions

Public

Mirrored from https://github.com/microsoft/onnxruntime-extensions (available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d1c657486d908aaf4c494d9a02871cae7131401e

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

ocos/ortcustomops.cc

89lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include <set>
5
6#include "kernels/op_equal.hpp"
7#include "kernels/op_segment_sum.hpp"
8#include "kernels/string_hash.hpp"
9#include "kernels/string_join.hpp"
10#include "kernels/string_regex_replace.hpp"
11#include "kernels/string_split.hpp"
12#include "kernels/string_upper.hpp"
13#include "kernels/test_output.hpp"
14#include "utils.h"
15
16CustomOpNegPos c_CustomOpNegPos;
17CustomOpSegmentSum c_CustomOpSegmentSum;
18CustomOpStringEqual c_CustomOpStringEqual;
19CustomOpStringHash c_CustomOpStringHash;
20CustomOpStringHashFast c_CustomOpStringHashFast;
21CustomOpStringJoin c_CustomOpStringJoin;
22CustomOpStringRegexReplace c_CustomOpStringRegexReplace;
23CustomOpStringSplit c_CustomOpStringSplit;
24CustomOpStringUpper c_CustomOpStringUpper;
25CustomOpOne c_CustomOpOne;
26CustomOpTwo c_CustomOpTwo;
27
28OrtCustomOp* operator_lists[] = {
29 &c_CustomOpNegPos,
30 &c_CustomOpSegmentSum,
31 &c_CustomOpStringEqual,
32 &c_CustomOpStringHash,
33 &c_CustomOpStringHashFast,
34 &c_CustomOpStringJoin,
35 &c_CustomOpStringRegexReplace,
36 &c_CustomOpStringSplit,
37 &c_CustomOpStringUpper,
38 &c_CustomOpOne,
39 &c_CustomOpTwo,
40 nullptr};
41
42extern "C" OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api) {
43 OrtCustomOpDomain* domain = nullptr;
44 const OrtApi* ortApi = api->GetApi(ORT_API_VERSION);
45 std::set<std::string> pyop_nameset;
46
47 if (auto status = ortApi->CreateCustomOpDomain(c_OpDomain, &domain)) {
48 return status;
49 }
50
51#if defined(PYTHON_OP_SUPPORT)
52 size_t count = 0;
53 const OrtCustomOp* c_ops = FetchPyCustomOps(count);
54 while (c_ops != nullptr) {
55 if (auto status = ortApi->CustomOpDomain_Add(domain, c_ops)) {
56 return status;
57 }
58 else {
59 pyop_nameset.emplace(c_ops->GetName(c_ops));
60 }
61 ++count;
62 c_ops = FetchPyCustomOps(count);
63 }
64#endif
65
66 OrtCustomOp** ops = operator_lists;
67 while (*ops != nullptr) {
68 if (pyop_nameset.find((*ops)->GetName(*ops)) == pyop_nameset.end()) {
69 if (auto status = ortApi->CustomOpDomain_Add(domain, *ops)) {
70 return status;
71 }
72 }
73 ++ops;
74 }
75
76#if defined(ENABLE_TOKENIZER)
77 const OrtCustomOp** t_ops = LoadTokenizerSchemaList();
78 while (*t_ops != nullptr) {
79 if (pyop_nameset.find((*t_ops)->GetName(*t_ops)) == pyop_nameset.end()) {
80 if (auto status = ortApi->CustomOpDomain_Add(domain, *t_ops)){
81 return status;
82 }
83 }
84 t_ops++;
85 }
86#endif
87
88 return ortApi->AddCustomOpDomain(options, domain);
89}
90