microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0d5d19f67b28024de0b88d4a61bcc4157dc06248

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/ocos.h

143 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <algorithm>
7#include <cassert>
8#include <functional>
9#include <iterator>
10#include <string>
11#include <vector>
12
13#include "op_def_struct.h"
14#include "ext_status.h"
15
16// A helper API to support test kernels.
17// Must be invoked before RegisterCustomOps.
18extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
19
// Custom-op domain names used by this library.
constexpr char const* c_OpDomain = "ai.onnx.contrib";
constexpr char const* c_ComMsExtOpDomain = "com.microsoft.extensions";
22
23
24template <typename... Args>
25class CuopContainer {
26 public:
27 CuopContainer() : op_instances_({[]() { return std::make_shared<Args>(); }()...}) {
28 ocos_list_.reserve(op_instances_.size());
29 std::transform(op_instances_.begin(), op_instances_.end(), std::back_inserter(ocos_list_),
30 [](const std::shared_ptr<OrtCustomOp>& custom_op) { return custom_op.get(); });
31 }
32
33 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
34 return ocos_list_;
35 }
36
37 private:
38 std::vector<const OrtCustomOp*> ocos_list_;
39 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_; // use shared_ptr to capture type specific deleter
40};
41
// Factory macros producing OrtOpLoader arguments. Each one expands to a
// capture-less lambda that, when invoked, builds one lite custom op named
// `op_name` for a fixed execution provider and returns it wrapped in
// std::shared_ptr<ortc::OrtLiteCustomOp>.
//   *Func*   variants pass the callable `fn` to the ortc factory function.
//   *Struct* variants pass `op_struct` as the factory's template argument.
//   *V2*     variants call ortc::CreateLiteCustomOpV2 instead of
//            ortc::CreateLiteCustomOp.
// The lambdas capture nothing, so `fn` must be nameable without a capture
// (for example a free function), not a local variable.
#define CustomCpuFunc(op_name, fn) \
  []() { \
    return std::shared_ptr<ortc::OrtLiteCustomOp>( \
        ortc::CreateLiteCustomOp(op_name, "CPUExecutionProvider", fn)); \
  }
#define CustomCpuStruct(op_name, op_struct) \
  []() { \
    return std::shared_ptr<ortc::OrtLiteCustomOp>( \
        ortc::CreateLiteCustomOp<op_struct>(op_name, "CPUExecutionProvider")); \
  }
#define CustomAzureStruct(op_name, op_struct) \
  []() { \
    return std::shared_ptr<ortc::OrtLiteCustomOp>( \
        ortc::CreateLiteCustomOp<op_struct>(op_name, "AzureExecutionProvider")); \
  }

#define CustomCpuFuncV2(op_name, fn) \
  []() { \
    return std::shared_ptr<ortc::OrtLiteCustomOp>( \
        ortc::CreateLiteCustomOpV2(op_name, "CPUExecutionProvider", fn)); \
  }
#define CustomCpuStructV2(op_name, op_struct) \
  []() { \
    return std::shared_ptr<ortc::OrtLiteCustomOp>( \
        ortc::CreateLiteCustomOpV2<op_struct>(op_name, "CPUExecutionProvider")); \
  }

#define CustomCudaFuncV2(op_name, fn) \
  []() { \
    return std::shared_ptr<ortc::OrtLiteCustomOp>( \
        ortc::CreateLiteCustomOpV2(op_name, "CUDAExecutionProvider", fn)); \
  }
#define CustomCudaStructV2(op_name, op_struct) \
  []() { \
    return std::shared_ptr<ortc::OrtLiteCustomOp>( \
        ortc::CreateLiteCustomOpV2<op_struct>(op_name, "CUDAExecutionProvider")); \
  }
51
52template <typename F>
53void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
54 F arg) {
55 ops.emplace_back(std::move(arg()));
56}
57
58template <typename T, typename... Args>
59void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
60 T arg, Args... args) {
61 AppendCustomOp(ops, arg);
62 AppendCustomOp(ops, args...);
63}
64
65class OrtOpLoader {
66 public:
67 template <typename... Args>
68 OrtOpLoader(Args... args) {
69 LoadOps(args...);
70 for (auto& ptr : op_instances_) {
71 if (ptr)
72 ocos_list_.push_back(ptr.get());
73 }
74 }
75
76 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
77 return ocos_list_;
78 }
79
80 private:
81 template <typename T>
82 void LoadOps(T fn) {
83 AppendCustomOp(op_instances_, fn);
84 }
85
86 template <typename T, typename... Args>
87 void LoadOps(T fn, Args... args) {
88 AppendCustomOp(op_instances_, fn);
89 AppendCustomOp(op_instances_, args...);
90 }
91
92 std::vector<const OrtCustomOp*> ocos_list_;
93 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;
94};
95
// Empty tag type; serves as the default for the leading placeholder template
// parameter of LoadCustomOpClasses.
struct CustomOpClassNull {};
98
99template <typename _Begin_place_holder = CustomOpClassNull, typename... Args>
100const std::vector<const OrtCustomOp*>& LoadCustomOpClasses() {
101 static CuopContainer<Args...> ctr; // Let C++ runtime take cares of the MP initializing.
102 return ctr.GetCustomOps();
103}
104
105using CustomOpArray = const std::vector<const OrtCustomOp*>;
106using FxLoadCustomOpFactory = std::function<CustomOpArray&()>;
107
#ifdef PYTHON_OP_SUPPORT
// Python-backed custom ops; declared only when PYTHON_OP_SUPPORT is defined.
// NOTE(review): presumably `count` receives the number of ops in the returned
// array — confirm against the definition.
const OrtCustomOp* FetchPyCustomOps(size_t& count);
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
#endif  // PYTHON_OP_SUPPORT

// Per-category op factories. Each one is declared only when its feature macro
// is enabled; the definitions are not in this header.
#if defined(ENABLE_MATH)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif  // ENABLE_MATH

#if defined(ENABLE_TOKENIZER)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif  // ENABLE_TOKENIZER

#if defined(ENABLE_TF_STRING)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
#endif  // ENABLE_TF_STRING

#if defined(ENABLE_CV2)
extern FxLoadCustomOpFactory LoadCustomOpClasses_CV2;
#endif  // ENABLE_CV2

#if defined(ENABLE_VISION)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
#endif  // ENABLE_VISION

#if defined(ENABLE_DR_LIBS)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
#endif  // ENABLE_DR_LIBS

// NOTE(review): unlike the guards above, ENABLE_AZURE and USE_CUDA are tested
// with `#if`, not `#ifdef`. They take effect only when they expand to a
// non-zero integer; a defined-but-empty macro is a preprocessing error. Kept
// as-is to preserve the existing behavior.
#if ENABLE_AZURE
extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
#endif  // ENABLE_AZURE

#if USE_CUDA
extern FxLoadCustomOpFactory LoadCustomOpClasses_Contrib;
#endif  // USE_CUDA
144