microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
e0d48e255f28e5465f63e7fc141df1e1d533cc40

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/ocos.h

121 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4#pragma once
5
6#include <algorithm>
7#include <functional>
8#include <iterator>
9#include <vector>
10
11#include "onnxruntime_customop.hpp"
12
13// A helper API to support test kernels.
14// Must be invoked before RegisterCustomOps.
15extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
16
// Operator-domain names used by this library's custom ops.
constexpr char const* c_OpDomain = "ai.onnx.contrib";
constexpr char const* c_ComMsExtOpDomain = "com.microsoft.extensions";
19
20struct BaseKernel {
21 BaseKernel(const OrtApi& api) : api_(api), info_(nullptr), ort_(api_) {}
22 BaseKernel(const OrtApi& api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
23
24 bool HasAttribute(const char* name) const;
25
26 template <class T>
27 bool TryToGetAttribute(const char* name, T& value);
28
29 template <class T>
30 T TryToGetAttributeWithDefault(const char* name, T default_value) {
31 T& result = default_value;
32 TryToGetAttribute(name, result);
33 return result;
34 }
35
36 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data);
37
38 protected:
39 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
40 const OrtApi& api_;
41 OrtW::CustomOpApi ort_;
42 const OrtKernelInfo* info_;
43};
44
45struct OrtTensorDimensions : std::vector<int64_t> {
46 OrtTensorDimensions() = default;
47 OrtTensorDimensions(const OrtW::CustomOpApi& ort, const OrtValue* value) {
48 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
49 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
50 ort.ReleaseTensorTypeAndShapeInfo(info);
51 }
52
53 int64_t Size() const {
54 int64_t s = 1;
55 for (auto it = begin(); it != end(); ++it)
56 s *= *it;
57 return s;
58 }
59
60 bool IsScalar() const {
61 return empty();
62 }
63
64 bool IsVector() const {
65 return size() == 1;
66 }
67};
68
69template <typename... Args>
70class CuopContainer {
71 public:
72 CuopContainer() : op_instances_({[]() { return std::make_shared<Args>(); }()...}) {
73 ocos_list_.reserve(op_instances_.size());
74 std::transform(op_instances_.begin(), op_instances_.end(), std::back_inserter(ocos_list_),
75 [](const std::shared_ptr<OrtCustomOp>& custom_op) { return custom_op.get(); });
76 }
77
78 const std::vector<const OrtCustomOp*>& GetCustomOps() const {
79 return ocos_list_;
80 }
81
82 private:
83 std::vector<const OrtCustomOp*> ocos_list_;
84 std::vector<std::shared_ptr<OrtCustomOp>> op_instances_; // use shared_ptr to capture type specific deleter
85};
86
87struct CustomOpClassBegin {
88};
89
90using FxLoadCustomOpFactory = std::function<const std::vector<const OrtCustomOp*>&()>;
91
92template <typename _Begin_place_holder, typename... Args>
93const std::vector<const OrtCustomOp*>& LoadCustomOpClasses() {
94 static CuopContainer<Args...> ctr; // Let C++ runtime take cares of the MP initializing.
95 return ctr.GetCustomOps();
96}
97
#ifdef PYTHON_OP_SUPPORT
// Python custom-op hooks; declared only when PYTHON_OP_SUPPORT is defined.
// NOTE(review): `count` looks like an out-parameter receiving the number of ops — confirm.
const OrtCustomOp* FetchPyCustomOps(size_t& count);
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions* options, const OrtApi* api);
#endif  // PYTHON_OP_SUPPORT
102
// Per-feature custom-op factories. Each one is declared only when its feature
// macro is defined; every #endif is labelled with the macro it closes.
#ifdef ENABLE_MATH
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif  // ENABLE_MATH

#ifdef ENABLE_TOKENIZER
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif  // ENABLE_TOKENIZER

#ifdef ENABLE_TF_STRING
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
#endif  // ENABLE_TF_STRING

#ifdef ENABLE_CV2
extern FxLoadCustomOpFactory LoadCustomOpClasses_CV2;
#endif  // ENABLE_CV2

#ifdef ENABLE_VISION
extern FxLoadCustomOpFactory LoadCustomOpClasses_Vision;
#endif  // ENABLE_VISION