microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

CodeCommitsIssuesPull requestsActionsInsightsSecurity
b661d5f22f396e757eb1de6e1ab28f2a50f0e81b

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

include/custom_op/tensor_tuple.inc

137lines · modecode

1template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
2static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type_def>*>::value, std::tuple<T, Ts...>>::type
3CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
4 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
5 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
6 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
7 return std::tuple_cat(current, next);
8}
9
10template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
11static typename std::enable_if<std::is_same<T, const Custom::Tensor<data_type_def>&>::value, std::tuple<T, Ts...>>::type
12CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
13 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
14 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
15 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
16 return std::tuple_cat(current, next);
17}
18
19template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
20static typename std::enable_if<std::is_same<T, std::optional<const Custom::Tensor<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
21CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
22 if (ith_input < num_input) {
23 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
24 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())};
25 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
26 return std::tuple_cat(current, next);
27 } else {
28 std::tuple<T> current = std::tuple<T>{};
29 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
30 return std::tuple_cat(current, next);
31 }
32}
33
34template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
35static typename std::enable_if<std::is_same<T, const Custom::Span<data_type_def>*>::value, std::tuple<T, Ts...>>::type
36CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
37 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
38 if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
39 ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
40 }
41 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
42 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
43 return std::tuple_cat(current, next);
44}
45
46template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
47static typename std::enable_if<std::is_same<T, const Custom::Span<data_type_def>&>::value, std::tuple<T, Ts...>>::type
48CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
49 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
50 if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
51 ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
52 }
53 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
54 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
55 return std::tuple_cat(current, next);
56}
57
58template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
59static typename std::enable_if<std::is_same<T, std::optional<const Custom::Span<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
60CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
61 if (ith_input < num_input) {
62 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
63 if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
64 ORTX_CXX_API_THROW("span input could only be applied to CPU tensor", ORT_FAIL);
65 }
66 std::tuple<T> current = std::tuple<T>{&reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsSpan()};
67 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
68 return std::tuple_cat(current, next);
69 } else {
70 std::tuple<T> current = std::tuple<T>{};
71 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
72 return std::tuple_cat(current, next);
73 }
74}
75
76template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
77static typename std::enable_if<std::is_same<T, data_type_def>::value, std::tuple<T, Ts...>>::type
78CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
79 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
80 if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
81 ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);
82 }
83 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsScalar()};
84 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
85 return std::tuple_cat(current, next);
86}
87
88template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
89static typename std::enable_if<std::is_same<T, std::optional<data_type_def>>::value, std::tuple<T, Ts...>>::type
90CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
91 if (ith_input < num_input) {
92 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_input, true));
93 if (!reinterpret_cast<Custom::OrtTensor<data_type_def>*>(tensors.back().get())->IsCpuTensor()) {
94 ORTX_CXX_API_THROW("scalar input could only be applied to CPU tensor", ORT_FAIL);
95 }
96 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())->AsScalar()};
97 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
98 return std::tuple_cat(current, next);
99 } else {
100 std::tuple<T> current = std::tuple<T>{};
101 auto next = CreateTuple<ith_input + 1, ith_output, Ts...>(api, context, tensors, num_input, num_output, ep);
102 return std::tuple_cat(current, next);
103 }
104}
105
106template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
107static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type_def>*>::value, std::tuple<T, Ts...>>::type
108CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
109 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
110 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(tensors.back().get())};
111 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
112 return std::tuple_cat(current, next);
113}
114
115template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
116static typename std::enable_if<std::is_same<T, Custom::Tensor<data_type_def>&>::value, std::tuple<T, Ts...>>::type
117CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
118 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
119 std::tuple<T> current = std::tuple<T>{reinterpret_cast<T>(*tensors.back().get())};
120 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
121 return std::tuple_cat(current, next);
122}
123
124template <size_t ith_input, size_t ith_output, typename T, typename... Ts>
125static typename std::enable_if<std::is_same<T, std::optional<Custom::Tensor<data_type_def>*>>::value, std::tuple<T, Ts...>>::type
126CreateTuple(const OrtW::CustomOpApi* api, OrtKernelContext* context, std::vector<TensorPtr>& tensors, size_t num_input, size_t num_output, const std::string& ep) {
127 if (ith_output < num_output) {
128 tensors.push_back(std::make_unique<Custom::OrtTensor<data_type_def>>(*api, *context, ith_output, false));
129 std::tuple<T> current = std::tuple<T>{reinterpret_cast<Custom::Tensor<data_type_def>*>(tensors.back().get())};
130 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
131 return std::tuple_cat(current, next);
132 } else {
133 std::tuple<T> current = std::tuple<T>{};
134 auto next = CreateTuple<ith_input, ith_output + 1, Ts...>(api, context, tensors, num_input, num_output, ep);
135 return std::tuple_cat(current, next);
136 }
137}
138