microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
352b2003bc6c4604f4285d64133d8bdf11549253

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

pyop/pykernel.h

118 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4#pragma once
5
6#include <vector>
7#include <map>
8#include "ocos.h"
9
// Description of one custom operator whose implementation lives in Python:
// its op type name, an object id, per-input and per-output element-type codes,
// and the names of the string attributes each kernel instance should read.
// NOTE: the order of the non-static members below is significant (layout and
// aggregate initialization); do not reorder them.
struct PyCustomOpDef {
  std::string op_type;             // op type name (what the factory's GetName() returns)
  uint64_t obj_id;                 // id handed to every kernel created for this op
  std::vector<int> input_types;    // one element-type code per input
  std::vector<int> output_types;   // one element-type code per output
  std::vector<std::string> attrs;  // names of string attributes read by each kernel

  // NOTE(review): implementation not visible here; presumably registers `cod`.
  static void AddOp(const PyCustomOpDef* cod);

  // Element-type codes. Declared without in-class initializers to avoid gcc
  // whole-archive issues; the definitions are out of class (not in this header).
  static const int undefined, dt_float, dt_uint8, dt_int8, dt_uint16, dt_int16,
      dt_int32, dt_int64, dt_string, dt_bool, dt_float16, dt_double,
      dt_uint32, dt_uint64, dt_complex64, dt_complex128, dt_bfloat16;
};
38
39struct PyCustomOpKernel {
40 PyCustomOpKernel(OrtApi api, const OrtKernelInfo* info, uint64_t id, const std::vector<std::string>& attrs)
41 : api_(api),
42 ort_(api_),
43 obj_id_(id) {
44 size_t size;
45 for (std::vector<std::string>::const_iterator it = attrs.begin(); it != attrs.end(); ++it) {
46 size = 0;
47 OrtStatus* status = api_.KernelInfoGetAttribute_string(info, it->c_str(), nullptr, &size);
48 if (api_.GetErrorCode(status) != ORT_INVALID_ARGUMENT) {
49 std::string error_message(api_.GetErrorMessage(status));
50 api_.ReleaseStatus(status);
51 throw std::runtime_error(MakeString(
52 "Unable to find attribute '", *it, "' due to '",
53 error_message, "'."));
54 }
55 api_.ReleaseStatus(status);
56 attrs_values_[*it] = "";
57 attrs_values_[*it].resize(size);
58 status = api_.KernelInfoGetAttribute_string(info, it->c_str(), &(attrs_values_[*it][0]), &size);
59 if ((status != nullptr) && (api_.GetErrorCode(status) != ORT_OK)) {
60 api_.ReleaseStatus(status);
61 throw std::runtime_error(MakeString(
62 "Unable to retrieve attribute '", *it, "' due to '",
63 api_.GetErrorMessage(status), "'."));
64 }
65 attrs_values_[*it].resize(size - 1);
66 api_.ReleaseStatus(status);
67 }
68 }
69
70 void Compute(OrtKernelContext* context);
71
72 private:
73 OrtApi api_; // keep a copy of the struct, whose ref is used in the ort_
74 Ort::CustomOpApi ort_;
75 uint64_t obj_id_;
76 std::map<std::string, std::string> attrs_values_;
77};
78
79struct PyCustomOpFactory : Ort::CustomOpBase<PyCustomOpFactory, PyCustomOpKernel> {
80 PyCustomOpFactory(const PyCustomOpDef* opdef) {
81 if (opdef == nullptr)
82 throw std::runtime_error("Python definition is empty.");
83 opdef_ = opdef;
84 }
85
86 void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
87 return new PyCustomOpKernel(api, info, opdef_->obj_id, opdef_->attrs);
88 };
89
90 const char* GetName() const {
91 return opdef_->op_type.c_str();
92 };
93
94 size_t GetInputTypeCount() const {
95 return opdef_->input_types.size();
96 };
97
98 ONNXTensorElementDataType GetInputType(size_t idx) const {
99 return static_cast<ONNXTensorElementDataType>(opdef_->input_types[idx]);
100 };
101
102 const std::vector<std::string>& GetAttributesNames() const {
103 return opdef_->attrs;
104 }
105
106 size_t GetOutputTypeCount() const {
107 return opdef_->output_types.size();
108 };
109
110 ONNXTensorElementDataType GetOutputType(size_t idx) const {
111 return static_cast<ONNXTensorElementDataType>(opdef_->output_types[idx]);
112 }
113
114 const PyCustomOpDef* opdef_;
115};
116
117std::vector<PyCustomOpFactory>& PyCustomOpDef_python_operator_list();
118const PyCustomOpFactory* PyCustomOpDef_FetchPyCustomOps(size_t count);
119