microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
55e9c4965e1dcb0102960c101f7ff2c1b2384c31

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

ocos/pyfunc/pykernel.h

88 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#include "ocos.h"
9
// Definition of a custom operator whose implementation lives on the Python
// side: the op type name, an object id (presumably identifying the Python
// callable — confirm against the .cc), and the input/output tensor element
// types. The types are stored as plain ints; PyCustomOpFactory casts them to
// ONNXTensorElementDataType when ONNX Runtime queries the schema.
struct PyCustomOpDef {
  std::string op_type;
  // Value-initialized so a default-constructed definition never carries an
  // indeterminate id. PyCustomOpFactory::CreateKernel forwards this value
  // verbatim into every kernel it creates. Brace/aggregate initialization such
  // as PyCustomOpDef{"Op", id, {...}, {...}} keeps working (C++14 and later).
  uint64_t obj_id{};
  std::vector<int> input_types;
  std::vector<int> output_types;

  // Defined out of line.
  static void AddOp(const PyCustomOpDef* cod);

  // Named element-type codes (dt_*), defined out of line.
  // no initializer here to avoid gcc whole-archive
  static const int undefined;
  static const int dt_float;
  static const int dt_uint8;
  static const int dt_int8;
  static const int dt_uint16;
  static const int dt_int16;
  static const int dt_int32;
  static const int dt_int64;
  static const int dt_string;
  static const int dt_bool;
  static const int dt_float16;
  static const int dt_double;
  static const int dt_uint32;
  static const int dt_uint64;
  static const int dt_complex64;
  static const int dt_complex128;
  static const int dt_bfloat16;
};
37
38struct PyCustomOpKernel {
39 PyCustomOpKernel(OrtApi api, uint64_t id)
40 : api_(api),
41 ort_(api_),
42 obj_id_(id) {
43 }
44
45 void Compute(OrtKernelContext* context);
46
47 private:
48 OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
49 Ort::CustomOpApi ort_;
50 uint64_t obj_id_;
51};
52
53struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
54 PyCustomOpFactory(const PyCustomOpDef* opdef) {
55 if (opdef == nullptr)
56 throw std::runtime_error("Python definition is empty.");
57 opdef_ = opdef;
58 }
59
60 void* CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
61 return new PyCustomOpKernel(api, opdef_->obj_id);
62 };
63
64 const char* GetName() const {
65 return opdef_->op_type.c_str();
66 };
67
68 size_t GetInputTypeCount() const {
69 return opdef_->input_types.size();
70 };
71
72 ONNXTensorElementDataType GetInputType(size_t idx) const {
73 return static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
74 };
75
76 size_t GetOutputTypeCount() const {
77 return opdef_->output_types.size();
78 };
79
80 ONNXTensorElementDataType GetOutputType(size_t idx) const {
81 return static_cast<ONNXTensorElementDataType>(opdef_->output_types[idx]);
82 }
83
84 const PyCustomOpDef* opdef_;
85};
86
87std::vector<PyCustomOpFactory>& PyCustomOpDef_python_operator_list();
88const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count);
89