microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
223400118478b0d6512ae8491292723cf61c8fe8

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/ocos.h

138 lines · mode code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <algorithm>
7#include <cassert>
8#include <functional>
9#include <iterator>
10#include <string>
11#include <vector>
12
13#include "onnxruntime_customop.hpp"
14
15// A helper API to support test kernels.
16// Must be invoked before RegisterCustomOps.
17extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
18
// Custom-op domain name strings. Each is a constexpr pointer to a string
// literal; at namespace scope that makes it const with internal linkage, so
// every translation unit including this header gets its own pointer object.
constexpr const char* c_OpDomain{"ai.onnx.contrib"};
constexpr const char* c_ComMsExtOpDomain{"com.microsoft.extensions"};

21template <typename... Args>
22class CuopContainer {
23 public:
24 CuopContainer() : op_instances_({[]() { return std::make_shared<Args>(); }()...}) {
25 ocos_list_.reserve(op_instances_.size());
26 std::transform(op_instances_.begin(), op_instances_.end(), std::back_inserter(ocos_list_),
27 [](const std::shared_ptr<OrtCustomOp>& custom_op) { return custom_op.get(); });
28 }
29
30 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
31 return ocos_list_;
32 }
33
34 private:
35 std::vector<const OrtCustomOp*> ocos_list_;
36 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_; // use shared_ptr to capture type specific deleter
37};
38
// Factory macros for lite custom ops. Each expands to a capture-less lambda
// that, when called, creates the op through ortc::CreateLiteCustomOp (V1) or
// ortc::CreateLiteCustomOpV2 (V2) for the named execution provider and returns
// it as std::shared_ptr<ortc::OrtLiteCustomOp>.
//   name - op name, forwarded as the first argument.
//   f    - forwarded as a function argument (the *Func variants).
//   s    - forwarded as the template type argument (the *Struct variants).

// V1 factories: CPU and Azure execution providers.
#define CustomCpuFunc(name, f) \
  []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp(name, "CPUExecutionProvider", f)); }
#define CustomCpuStruct(name, s) \
  []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }
#define CustomAzureStruct(name, s) \
  []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "AzureExecutionProvider")); }

// V2 factories: CPU execution provider.
#define CustomCpuFuncV2(name, f) \
  []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2(name, "CPUExecutionProvider", f)); }
#define CustomCpuStructV2(name, s) \
  []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CPUExecutionProvider")); }

// V2 factories: CUDA execution provider.
#define CustomCudaFuncV2(name, f) \
  []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2(name, "CUDAExecutionProvider", f)); }
#define CustomCudaStructV2(name, s) \
  []() { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CUDAExecutionProvider")); }
48
49template <typename F>
50void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
51 F arg) {
52 ops.emplace_back(std::move(arg()));
53}
54
55template <typename T, typename... Args>
56void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
57 T arg, Args... args) {
58 AppendCustomOp(ops, arg);
59 AppendCustomOp(ops, args...);
60}
61
62class OrtOpLoader {
63 public:
64 template <typename... Args>
65 OrtOpLoader(Args... args) {
66 LoadOps(args...);
67 for (auto& ptr : op_instances_) {
68 if (ptr)
69 ocos_list_.push_back(ptr.get());
70 }
71 }
72
73 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
74 return ocos_list_;
75 }
76
77 private:
78 template <typename T>
79 void LoadOps(T fn) {
80 AppendCustomOp(op_instances_, fn);
81 }
82
83 template <typename T, typename... Args>
84 void LoadOps(T fn, Args... args) {
85 AppendCustomOp(op_instances_, fn);
86 AppendCustomOp(op_instances_, args...);
87 }
88
89 std::vector<const OrtCustomOp*> ocos_list_;
90 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;
91};
92
// Empty tag type. LoadCustomOpClasses uses it as the default for its leading
// placeholder template parameter, which is not forwarded to CuopContainer, so
// this type is never instantiated as an op.
struct CustomOpClassNull {};
95
96template <typename _Begin_place_holder = CustomOpClassNull, typename... Args>
97const std::vector<const OrtCustomOp*>& LoadCustomOpClasses() {
98 static CuopContainer<Args...> ctr; // Let C++ runtime take cares of the MP initializing.
99 return ctr.GetCustomOps();
100}
101
102using CustomOpArray = const std::vector<const OrtCustomOp*>;
103using FxLoadCustomOpFactory = std::function<CustomOpArray&()>;
104
105#if defined(PYTHON_OP_SUPPORT)
106const OrtCustomOp* FetchPyCustomOps(size_t& count);
107OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
108#endif
109
110#ifdef ENABLE_MATH
111extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
112#endif // ENABLE_MATH
113
114#ifdef ENABLE_TOKENIZER
115extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
116#endif // ENABLE_TOKENIZER
117
118#ifdef ENABLE_TF_STRING
119extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
120#endif // ENABLE_TF_STRING
121
122#ifdef ENABLE_CV2
123extern FxLoadCustomOpFactory LoadCustomOpClasses_CV2;
124#endif // ENABLE_OPENCV
125
126#ifdef ENABLE_VISION
127extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
128#endif
129
130#ifdef ENABLE_DR_LIBS
131extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
132#endif
133
134#if ENABLE_AZURE
135extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
136#endif
137
138extern FxLoadCustomOpFactory LoadCustomOpClasses_Contrib;