microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions

Code · Commits · Issues · Pull requests · Actions · Insights · Security
68b9d1dc47663a9017c55d136c804417c8efec7d

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/ocos.h

185 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <algorithm>
7#include <cassert>
8#include <functional>
9#include <iterator>
10#include <string>
11#include <vector>
12
13#include "onnxruntime_customop.hpp"
14
15// A helper API to support test kernels.
16// Must be invoked before RegisterCustomOps.
17extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
18
19constexpr const char* c_OpDomain = "ai.onnx.contrib";
20constexpr const char* c_ComMsExtOpDomain = "com.microsoft.extensions";
21
22struct BaseKernel {
23 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept
24 : api_(api), info_(info), ort_(api_) {
25 }
26
27 template <class T>
28 bool TryToGetAttribute(const char* name, T& value) const noexcept;
29
30 template <class T>
31 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
32 T result = default_value;
33 TryToGetAttribute(name, result);
34 return result;
35 }
36
37 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
38 const std::vector<int64_t>& data);
39
40 protected:
41 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
42
43 const OrtApi& api_;
44 OrtW::CustomOpApi ort_;
45 const OrtKernelInfo& info_;
46};
47
48struct OrtTensorDimensions : std::vector<int64_t> {
49 OrtTensorDimensions() = default;
50 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
51 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
52 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
53 ort.ReleaseTensorTypeAndShapeInfo(info);
54 }
55
56 int64_t Size() const {
57 int64_t s = 1;
58 for (auto it = begin(); it != end(); ++it)
59 s *= *it;
60 return s;
61 }
62
63 bool IsScalar() const {
64 return empty();
65 }
66
67 bool IsVector() const {
68 return size() == 1;
69 }
70};
71
72template <typename... Args>
73class CuopContainer {
74 public:
75 CuopContainer() : op_instances_({[]() { return std::make_shared<Args>(); }()...}) {
76 ocos_list_.reserve(op_instances_.size());
77 std::transform(op_instances_.begin(), op_instances_.end(), std::back_inserter(ocos_list_),
78 [](const std::shared_ptr<OrtCustomOp>& custom_op) { return custom_op.get(); });
79 }
80
81 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
82 return ocos_list_;
83 }
84
85 private:
86 std::vector<const OrtCustomOp*> ocos_list_;
87 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_; // use shared_ptr to capture type specific deleter
88};
89
// Factory macros. Each one expands to a capture-less lambda. When called, the
// lambda wraps the result of ortc::CreateLiteCustomOp (or CreateLiteCustomOpV2
// for the *V2 variants) in a std::shared_ptr<ortc::OrtLiteCustomOp>, registered
// for the execution provider named in the macro. `name` is the op name. `f` is
// the callable implementing the op; `s` is the struct type implementing it.
#define CustomCpuFunc(name, f) \
  [] { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp(name, "CPUExecutionProvider", f)); }
#define CustomCpuStruct(name, s) \
  [] { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "CPUExecutionProvider")); }
#define CustomAzureStruct(name, s) \
  [] { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOp<s>(name, "AzureExecutionProvider")); }

#define CustomCpuFuncV2(name, f) \
  [] { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2(name, "CPUExecutionProvider", f)); }
#define CustomCpuStructV2(name, s) \
  [] { return std::shared_ptr<ortc::OrtLiteCustomOp>(ortc::CreateLiteCustomOpV2<s>(name, "CPUExecutionProvider")); }
96
97
98template <typename F>
99void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
100 F arg) {
101 ops.emplace_back(std::move(arg()));
102}
103
104template <typename T, typename... Args>
105void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
106 T arg, Args... args) {
107 AppendCustomOp(ops, arg);
108 AppendCustomOp(ops, args...);
109}
110
111class OrtOpLoader {
112 public:
113 template <typename... Args>
114 OrtOpLoader(Args... args) {
115 LoadOps(args...);
116 for (auto& ptr : op_instances_) {
117 if (ptr)
118 ocos_list_.push_back(ptr.get());
119 }
120 }
121
122 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
123 return ocos_list_;
124 }
125
126 private:
127 template <typename T>
128 void LoadOps(T fn) {
129 AppendCustomOp(op_instances_, fn);
130 }
131
132 template <typename T, typename... Args>
133 void LoadOps(T fn, Args... args) {
134 AppendCustomOp(op_instances_, fn);
135 AppendCustomOp(op_instances_, args...);
136 }
137
138 std::vector<const OrtCustomOp*> ocos_list_;
139 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;
140};
141
142struct CustomOpClassNull {
143};
144
145template <typename _Begin_place_holder = CustomOpClassNull, typename... Args>
146const std::vector<const OrtCustomOp*>& LoadCustomOpClasses() {
147 static CuopContainer<Args...> ctr; // Let C++ runtime take cares of the MP initializing.
148 return ctr.GetCustomOps();
149}
150
151using CustomOpArray = const std::vector<const OrtCustomOp*>;
152using FxLoadCustomOpFactory = std::function<CustomOpArray&()>;
153
#if defined(PYTHON_OP_SUPPORT)
const OrtCustomOp* FetchPyCustomOps(size_t& count);
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
#endif  // PYTHON_OP_SUPPORT

// Per-category op-loader objects, each defined elsewhere (extern). Each one is
// declared only when its feature macro is enabled.

#ifdef ENABLE_MATH
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif  // ENABLE_MATH

#ifdef ENABLE_TOKENIZER
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif  // ENABLE_TOKENIZER

#ifdef ENABLE_TF_STRING
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
#endif  // ENABLE_TF_STRING

#ifdef ENABLE_CV2
extern FxLoadCustomOpFactory LoadCustomOpClasses_CV2;
#endif  // ENABLE_CV2

#ifdef ENABLE_VISION
extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
#endif  // ENABLE_VISION

#ifdef ENABLE_DR_LIBS
extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
#endif  // ENABLE_DR_LIBS

// NOTE(review): this guard uses `#if`, unlike the `#ifdef` guards above. The
// declaration is therefore included only when ENABLE_AZURE expands to a
// non-zero integer. An undefined macro counts as 0, and a macro defined with
// an empty value is a preprocessing error here. Confirm this is intended.
#if ENABLE_AZURE
extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
#endif  // ENABLE_AZURE
186