microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.4.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/ocos.h

117 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <vector>
7#include <functional>
8
9#define ORT_API_MANUAL_INIT
10#include "onnxruntime_cxx_api.h"
11#undef ORT_API_MANUAL_INIT
12
13
14// A helper API to support test kernels.
15// Must be invoked before RegisterCustomOps.
16extern "C" bool ORT_API_CALL AddExternalCustomOp(const OrtCustomOp* c_op);
17
// Custom-op domain identifier string used by this library.
constexpr char c_OpDomain[] = "ai.onnx.contrib";
19
20struct BaseKernel {
21 BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {}
22 BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
23
24 bool HasAttribute(const char* name) const;
25
26 template <class T>
27 bool TryToGetAttribute(const char* name, T& value);
28
29 template <class T>
30 T TryToGetAttributeWithDefault(const char* name, T default_value) {
31 T& result = default_value;
32 TryToGetAttribute(name, result);
33 return result;
34 }
35
36 void SetOutput(OrtKernelContext* ctx, size_t output_idx, const std::vector<int64_t>& dim, const std::vector<int64_t>& data);
37
38 protected:
39 OrtErrorCode GetErrorCodeAndRelease(OrtStatusPtr status);
40 OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
41 Ort::CustomOpApi ort_;
42 const OrtKernelInfo* info_;
43};
44
45struct OrtTensorDimensions : std::vector<int64_t> {
46 OrtTensorDimensions() = default;
47 OrtTensorDimensions(Ort::CustomOpApi& ort, const OrtValue* value) {
48 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
49 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
50 ort.ReleaseTensorTypeAndShapeInfo(info);
51 }
52 const std::vector<int64_t>& GetDims() const { return *this; }
53 int64_t Size() const {
54 if (empty()) {
55 return 0;
56 }
57
58 int64_t s = 1.;
59 for (auto it = begin(); it != end(); ++it)
60 s *= *it;
61 return s;
62 }
63};
64
65
66template <typename... Args>
67class CuopContainer {
68 public:
69 CuopContainer() : ocos_list_({[]() { return new Args; }()...}) {
70 ocos_list_.push_back(nullptr);
71 }
72
73 ~CuopContainer() {
74 // skip the last null pointer.
75 for (auto i = 0; i < ocos_list_.size() - 1; i++) {
76 delete ocos_list_[i];
77 }
78
79 ocos_list_.clear();
80 }
81
82 const OrtCustomOp** GetList() {
83 return &const_cast<const OrtCustomOp*&>(ocos_list_.front());
84 }
85
86 private:
87 std::vector<OrtCustomOp*> ocos_list_;
88};
89
90struct CustomOpClassBegin{
91};
92
93typedef std::function<const OrtCustomOp**()> FxLoadCustomOpFactory;
94
95template <typename _Begin_place_holder, typename... Args>
96const OrtCustomOp** LoadCustomOpClasses() {
97 static CuopContainer<Args...> ctr; // Let C++ runtime take cares of the MP initializing.
98 return ctr.GetList();
99}
100
#ifdef PYTHON_OP_SUPPORT
// Python-backed custom-op hooks; implementations live outside this header.
const OrtCustomOp* FetchPyCustomOps(size_t& count);
OrtStatusPtr RegisterPythonDomainAndOps(OrtSessionOptions*, const OrtApi*);
bool EnablePyCustomOps(bool enable = true);
#endif  // PYTHON_OP_SUPPORT

// Per-feature op-list loaders, each declared only when its build flag is defined.
#if defined(ENABLE_MATH)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Math;
#endif  // ENABLE_MATH

#if defined(ENABLE_TOKENIZER)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Tokenizer;
#endif  // ENABLE_TOKENIZER

#if defined(ENABLE_TF_STRING)
extern FxLoadCustomOpFactory LoadCustomOpClasses_Text;
#endif  // ENABLE_TF_STRING
118