microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.4.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/op_ragged_tensor.cc

215 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "string_utils.h"
#include "string_tensor.h"
#include "op_ragged_tensor.hpp"

7KernelRaggedTensorToSparse::KernelRaggedTensorToSparse(OrtApi api) : BaseKernel(api) {
8}
9
10void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
11 const OrtValue* n_elements = ort_.KernelContext_GetInput(context, 0);
12 const int64_t* p_n_elements = ort_.GetTensorData<int64_t>(n_elements);
13
14 OrtTensorDimensions d_length(ort_, n_elements);
15
16 if (d_length.size() != 1)
17 ORT_CXX_API_THROW(MakeString(
18 "First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
19 int64_t n_els = d_length[0] - 1;
20 int64_t n_values = p_n_elements[n_els];
21 std::vector<int64_t> shape{n_values, 2};
22 std::vector<int64_t> shape2{2};
23
24 OrtValue* v0 = ort_.KernelContext_GetOutput(context, 0, shape.data(), shape.size());
25 int64_t* out0 = ort_.GetTensorMutableData<int64_t>(v0);
26 OrtValue* v1 = ort_.KernelContext_GetOutput(context, 1, shape2.data(), shape2.size());
27 int64_t* out1 = ort_.GetTensorMutableData<int64_t>(v1);
28 out1[0] = n_els;
29 out1[1] = 0;
30 int64_t row = 0;
31 int64_t i, j, length;
32 for (i = 1; i < d_length[0]; ++i) {
33 length = p_n_elements[i] - p_n_elements[i - 1];
34 if (length > out1[1])
35 out1[1] = length;
36 for (j = 0; j < length; ++j) {
37 *out0++ = row;
38 *out0++ = j;
39 }
40 ++row;
41 }
42}
43
44size_t CustomOpRaggedTensorToSparse::GetInputTypeCount() const {
45 return 1;
46};
47
48size_t CustomOpRaggedTensorToSparse::GetOutputTypeCount() const {
49 return 2;
50};
51
52ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*index*/) const {
53 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
54};
55
56void* CustomOpRaggedTensorToSparse::CreateKernel(OrtApi api, const OrtKernelInfo* /* info */) const {
57 return new KernelRaggedTensorToSparse(api);
58};
59
60const char* CustomOpRaggedTensorToSparse::GetName() const {
61 return "RaggedTensorToSparse";
62};
63
64ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*index*/) const {
65 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
66};
67
68CommonRaggedTensorToDense::CommonRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
69}
70
71void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
72 for (int i = 0; i < 4; ++i) {
73 inputs[i] = ort_.KernelContext_GetInput(context, i);
74 dims[i] = OrtTensorDimensions(ort_, inputs[i]);
75 }
76}
77
78int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
79 int64_t size = n;
80 int64_t max_col = 0;
81 for (int64_t i = 1; i < size; ++i) {
82 max_col = std::max(max_col, p_indices[i] - p_indices[i - 1]);
83 }
84 return max_col;
85}
86
87KernelRaggedTensorToDense::KernelRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
88 missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
89}
90
91void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
92 const OrtValue* inputs[4];
93 OrtTensorDimensions dims[4];
94 GetInputDims(context, inputs, dims);
95
96 const int64_t* p_values = ort_.GetTensorData<int64_t>(inputs[1]);
97 const int64_t* p_missing = ort_.GetTensorData<int64_t>(inputs[2]);
98 const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
99
100 int64_t size = dims[3].Size();
101 int64_t max_col = GetMaxCol(size, p_indices);
102
103 std::vector<int64_t> shape_out{size - 1, max_col};
104 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
105 int64_t* dense = ort_.GetTensorMutableData<int64_t>(output);
106
107 int64_t pos = 0;
108 int64_t j, pos_end;
109 int64_t shape_out_size = shape_out[0] * shape_out[1];
110 for (int64_t i = 0; i < size - 1; ++i) {
111 pos_end = pos + max_col;
112 if (pos_end > shape_out_size)
113 ORT_CXX_API_THROW(MakeString(
114 "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
115 " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
116 for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
117 dense[pos] = p_values[j];
118 }
119 for (; pos < pos_end; ++pos) {
120 dense[pos] = p_missing[0];
121 }
122 }
123}
124
125size_t CustomOpRaggedTensorToDense::GetInputTypeCount() const {
126 return 4;
127};
128
129size_t CustomOpRaggedTensorToDense::GetOutputTypeCount() const {
130 return 1;
131};
132
133ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
134 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
135};
136
137void* CustomOpRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
138 return new KernelRaggedTensorToDense(api, info);
139};
140
141const char* CustomOpRaggedTensorToDense::GetName() const {
142 return "RaggedTensorToDense";
143};
144
145ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*index*/) const {
146 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
147};
148
149KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(OrtApi api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
150}
151
152void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
153 const OrtValue* inputs[4];
154 OrtTensorDimensions dims[4];
155 GetInputDims(context, inputs, dims);
156
157 std::vector<std::string> input;
158 GetTensorMutableDataString(api_, ort_, context, inputs[1], input);
159 const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
160 int64_t size = dims[3].Size();
161 int64_t max_col = GetMaxCol(size, p_indices);
162 std::vector<int64_t> shape_out{size - 1, max_col};
163
164 int64_t shape_out_size = shape_out[0] * shape_out[1];
165 std::vector<std::string> dense(max_col * (size - 1));
166 int64_t pos = 0;
167 int64_t j, pos_end;
168 for (int64_t i = 0; i < size - 1; ++i) {
169 pos_end = pos + max_col;
170 if (pos_end > shape_out_size)
171 ORT_CXX_API_THROW(MakeString(
172 "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
173 " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
174 for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
175 dense[pos] = input[j];
176 }
177 pos = pos_end;
178 }
179
180 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
181 FillTensorDataString(api_, ort_, context, dense, output);
182}
183
184size_t CustomOpStringRaggedTensorToDense::GetInputTypeCount() const {
185 return 4;
186};
187
188size_t CustomOpStringRaggedTensorToDense::GetOutputTypeCount() const {
189 return 1;
190};
191
192ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
193 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
194};
195
196void* CustomOpStringRaggedTensorToDense::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
197 return new KernelStringRaggedTensorToDense(api, info);
198};
199
200const char* CustomOpStringRaggedTensorToDense::GetName() const {
201 return "StringRaggedTensorToDense";
202};
203
204ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t index) const {
205 switch (index) {
206 case 1:
207 case 2:
208 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
209 case 0:
210 case 3:
211 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
212 default:
213 ORT_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
214 }
215};