microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
00b54739c555efeeeb29bcad687d5fab5caf0bc9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

pyop/pykernel.h

97 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#include "ocos.h"
9
/// Definition of a custom operator whose implementation lives on the Python side.
///
/// PyCustomOpFactory keeps a raw pointer to one of these and answers every
/// ONNX Runtime type/attribute query from its fields, so a definition must
/// outlive any factory built from it.
struct PyCustomOpDef {
  std::string op_type{};                // operator type name
  uint64_t obj_id{0};                   // forwarded to the kernel by the factory; presumably identifies the Python implementation
  std::vector<int> input_types{};       // element-type codes, one per input
  std::vector<int> output_types{};      // element-type codes, one per output
  std::vector<std::string> attrs{};     // attribute names reported by the factory and passed to the kernel

  // Presumably registers `cod` with the Python op registry; the definition is
  // not in this header. NOTE(review): confirm who owns `cod` afterwards.
  static void AddOp(const PyCustomOpDef* cod);

  // Element-type codes usable in input_types / output_types. The factory
  // static_casts these ints to ONNXTensorElementDataType, so each constant
  // presumably mirrors the matching ONNX_TENSOR_ELEMENT_DATA_TYPE_* value —
  // confirm against the out-of-line definitions.
  //
  // Intentionally declared WITHOUT in-class initializers: per the original
  // author this avoids a gcc whole-archive problem. The values are defined out
  // of line (presumably in the matching .cc, not visible here).
  static const int undefined;
  static const int dt_float;
  static const int dt_uint8;
  static const int dt_int8;
  static const int dt_uint16;
  static const int dt_int16;
  static const int dt_int32;
  static const int dt_int64;
  static const int dt_string;
  static const int dt_bool;
  static const int dt_float16;
  static const int dt_double;
  static const int dt_uint32;
  static const int dt_uint64;
  static const int dt_complex64;
  static const int dt_complex128;
  static const int dt_bfloat16;
};
38
39struct PyCustomOpKernel {
40 PyCustomOpKernel(const OrtApi& api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs);
41 void Compute(OrtKernelContext* context);
42
43 private:
44 const OrtApi& api_;
45 OrtW::CustomOpApi ort_;
46 uint64_t obj_id_;
47 std::map<std::string, std::string> attrs_values_;
48};
49
50struct PyCustomOpFactory : OrtW::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
51 PyCustomOpFactory() {
52 // STL vector needs it.
53 }
54
55 PyCustomOpFactory(const PyCustomOpDef* opdef, const std::string& domain, const std::string& op) {
56 if (opdef == nullptr)
57 throw std::runtime_error("Python definition is empty.");
58 opdef_ = opdef;
59 op_domain_ = domain;
60 op_type_ = op;
61 }
62
63 void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
64 return CreateKernelImpl(api, info, opdef_->obj_id, opdef_->attrs);
65 };
66
67 const char* GetName() const {
68 return op_type_.c_str();
69 };
70
71 size_t GetInputTypeCount() const {
72 return opdef_->input_types.size();
73 };
74
75 ONNXTensorElementDataType GetInputType(size_t idx) const {
76 return static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
77 };
78
79 const std::vector<std::string>& GetAttributesNames() const {
80 return opdef_->attrs;
81 }
82
83 size_t GetOutputTypeCount() const {
84 return opdef_->output_types.size();
85 };
86
87 ONNXTensorElementDataType GetOutputType(size_t idx) const {
88 return static_cast<ONNXTensorElementDataType>(opdef_->output_types[idx]);
89 }
90
91 const PyCustomOpDef* opdef_ = nullptr;
92 std::string op_type_;
93 std::string op_domain_;
94};
95
96
/// Presumably switches support for Python-implemented custom ops on or off
/// (inferred from the name; the definition is not in this header).
/// @param enable  Requested state; defaults to true.
/// @return NOTE(review): meaning not visible here — presumably the previous
///         state; confirm against the definition.
auto EnablePyCustomOps(bool enable = true) -> bool;