microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
62d8598b6b9fa462a440ade891017eaafd4bfaee

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/ocos.h

179 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
7#include <string>
8#include <algorithm>
9#include <functional>
10#include <iterator>
11#include <vector>
12
13#include "onnxruntime_customop.hpp"
14
15// A helper API to support test kernels.
16// Must be invoked before RegisterCustomOps.
17extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
18
// Operator-domain name strings.
constexpr char const* c_OpDomain = "ai.onnx.contrib";
constexpr char const* c_ComMsExtOpDomain = "com.microsoft.extensions";
21
22struct BaseKernel {
23 BaseKernel(const OrtApi& api, const OrtKernelInfo& info) noexcept : api_(api), info_(info), ort_(api_) {
24 }
25
26 template <class T>
27 bool TryToGetAttribute(const char* name, T& value) const noexcept;
28
29 template <class T>
30 T TryToGetAttributeWithDefault(const char* name, const T& default_value) const noexcept {
31 T result = default_value;
32 TryToGetAttribute(name, result);
33 return result;
34 }
35
36 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim,
37 const std::vector<int64_t>& data);
38
39 protected:
40 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status) const noexcept;
41 const OrtApi& api_;
42 OrtW::CustomOpApi ort_;
43 const OrtKernelInfo& info_;
44};
45
46struct OrtTensorDimensions : std::vector<int64_t> {
47 OrtTensorDimensions() = default;
48 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
49 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
50 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
51 ort.ReleaseTensorTypeAndShapeInfo(info);
52 }
53
54 int64_t Size() const {
55 int64_t s = 1;
56 for (auto it = begin(); it != end(); ++it)
57 s *= *it;
58 return s;
59 }
60
61 bool IsScalar() const {
62 return empty();
63 }
64
65 bool IsVector() const {
66 return size() == 1;
67 }
68};
69
70template <typename... Args>
71class CuopContainer {
72 public:
73 CuopContainer() : op_instances_({[]() { return std::make_shared<Args>(); }()...}) {
74 ocos_list_.reserve(op_instances_.size());
75 std::transform(op_instances_.begin(), op_instances_.end(), std::back_inserter(ocos_list_),
76 [](const std::shared_ptr<OrtCustomOp>& custom_op) { return custom_op.get(); });
77 }
78
79 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
80 return ocos_list_;
81 }
82
83 private:
84 std::vector<const OrtCustomOp*> ocos_list_;
85 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_; // use shared_ptr to capture type specific deleter
86};
87
// Op-factory macros. Each one expands to a captureless lambda that, when
// called, invokes ortc::CreateLiteCustomOp and returns the result wrapped in a
// std::shared_ptr<ortc::OrtLiteCustomOp>. Because the lambda captures nothing,
// the macro arguments must be usable without capture (e.g. literals, function
// names, types).
//   CustomCpuFunc(op_name, fn)            - `fn` forwarded, "CPUExecutionProvider"
//   CustomCpuStruct(op_name, op_struct)   - struct type, "CPUExecutionProvider"
//   CustomAzureStruct(op_name, op_struct) - struct type, "AzureExecutionProvider"
#define CustomCpuFunc(op_name, fn)                                       \
  []() {                                                                 \
    return std::shared_ptr<ortc::OrtLiteCustomOp>(                       \
        ortc::CreateLiteCustomOp(op_name, "CPUExecutionProvider", fn));  \
  }
#define CustomCpuStruct(op_name, op_struct)                                  \
  []() {                                                                     \
    return std::shared_ptr<ortc::OrtLiteCustomOp>(                           \
        ortc::CreateLiteCustomOp<op_struct>(op_name, "CPUExecutionProvider")); \
  }
#define CustomAzureStruct(op_name, op_struct)                                  \
  []() {                                                                       \
    return std::shared_ptr<ortc::OrtLiteCustomOp>(                             \
        ortc::CreateLiteCustomOp<op_struct>(op_name, "AzureExecutionProvider")); \
  }
91
92template <typename F>
93void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
94 F arg) {
95 ops.emplace_back(std::move(arg()));
96}
97
98template <typename T, typename... Args>
99void AppendCustomOp(std::vector<std::shared_ptr<OrtCustomOp>>& ops,
100 T arg, Args... args) {
101 AppendCustomOp(ops, arg);
102 AppendCustomOp(ops, args...);
103}
104
105class OrtOpLoader {
106 public:
107 template <typename... Args>
108 OrtOpLoader(Args... args) {
109 LoadOps(args...);
110 for (auto& ptr : op_instances_) {
111 if (ptr)
112 ocos_list_.push_back(ptr.get());
113 }
114 }
115
116 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
117 return ocos_list_;
118 }
119
120 private:
121 template <typename T>
122 void LoadOps(T fn) {
123 AppendCustomOp(op_instances_, fn);
124 }
125
126 template <typename T, typename... Args>
127 void LoadOps(T fn, Args... args) {
128 AppendCustomOp(op_instances_, fn);
129 AppendCustomOp(op_instances_, args...);
130 }
131
132 std::vector<const OrtCustomOp*> ocos_list_;
133 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_;
134};
135
/// Empty tag type; serves as the default (ignored) first template argument of
/// LoadCustomOpClasses.
struct CustomOpClassNull {};
138
139template <typename _Begin_place_holder = CustomOpClassNull, typename... Args>
140const std::vector<const OrtCustomOp*>& LoadCustomOpClasses() {
141 static CuopContainer<Args...> ctr; // Let C++ runtime take cares of the MP initializing.
142 return ctr.GetCustomOps();
143}
144
145using CustomOpArray = const std::vector<const OrtCustomOp*>;
146using FxLoadCustomOpFactory = std::function<CustomOpArray&()>;
147
// Per-feature op declarations. Each declaration is visible only when its
// feature macro is enabled; the definitions live in other translation units.

#if defined(PYTHON_OP_SUPPORT)
const OrtCustomOp* FetchPyCustomOps(size_t& count);
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
#endif  // PYTHON_OP_SUPPORT

#ifdef ENABLE_MATH
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif  // ENABLE_MATH

#ifdef ENABLE_TOKENIZER
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif  // ENABLE_TOKENIZER

#ifdef ENABLE_TF_STRING
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
#endif  // ENABLE_TF_STRING

#ifdef ENABLE_CV2
extern FxLoadCustomOpFactory LoadCustomOpClasses_CV2;
#endif  // ENABLE_CV2

#ifdef ENABLE_VISION
extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
#endif  // ENABLE_VISION

#ifdef ENABLE_DR_LIBS
extern FxLoadCustomOpFactory LoadCustomOpClasses_Audio;
#endif  // ENABLE_DR_LIBS

// NOTE(review): unlike the guards above, this one uses `#if`, so it is active
// only when ENABLE_AZURE is defined to a nonzero value. A bare
// `#define ENABLE_AZURE` (empty body) would make this line a preprocessing
// error. Confirm this is intended.
#if ENABLE_AZURE
extern FxLoadCustomOpFactory LoadCustomOpClasses_Azure;
#endif  // ENABLE_AZURE
180