microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
92f6b51106c9e9143c452e537cb5e41d2dcaa266

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

includes/ocos.h

51 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#define ORT_API_MANUAL_INIT
7#include "onnxruntime_cxx_api.h"
8#undef ORT_API_MANUAL_INIT
9
10#if defined(ENABLE_GPT2_TOKENIZER)
11const OrtCustomOp** LoadTokenizerSchemaList();
12#endif // ENABLE_GPT2_TOKENIZER
13
14#if defined(PYTHON_OP_SUPPORT)
15const OrtCustomOp* FetchPyCustomOps(size_t& count);
16bool EnablePyCustomOps(bool enable = true);
17#endif
18
19// A helper API to support test kernels.
20// Must be invoked before RegisterCustomOps.
21extern "C" bool AddExternalCustomOp(const OrtCustomOp* c_op);
22
// Custom-op domain name string used by this library.
// NOTE(review): presumably the domain this library's custom ops are registered
// under — confirm against RegisterCustomOps.
constexpr char c_OpDomain[] = "ai.onnx.contrib";
24
25struct BaseKernel {
26 BaseKernel(OrtApi api) : api_(api), info_(nullptr), ort_(api_) {}
27 BaseKernel(OrtApi api, const OrtKernelInfo* info) : api_(api), info_(info), ort_(api_) {}
28
29 bool HasAttribute(const char* name) const;
30
31 protected:
32 OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
33 Ort::CustomOpApi ort_;
34 const OrtKernelInfo* info_;
35};
36
37struct OrtTensorDimensions : std::vector<int64_t> {
38 OrtTensorDimensions() = default;
39 OrtTensorDimensions(Ort::CustomOpApi& ort, const OrtValue* value) {
40 OrtTensorTypeAndShapeInfo* info = ort.GetTensorTypeAndShape(value);
41 std::vector<int64_t>::operator=(ort.GetTensorShape(info));
42 ort.ReleaseTensorTypeAndShapeInfo(info);
43 }
44 const std::vector<int64_t>& GetDims() const { return *this; }
45 int64_t Size() const {
46 int64_t s = 1.;
47 for (auto it = begin(); it != end(); ++it)
48 s *= *it;
49 return s;
50 }
51};
52