microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
0169129b19715e12031e1f6378121bd671ea7ce3

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/op_ragged_tensor.cc

208 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3#include "string_utils.h"
4#include "string_tensor.h"
5#include "op_ragged_tensor.hpp"
6
7void KernelRaggedTensorToSparse::Compute(OrtKernelContext* context) {
8 const OrtValue* n_elements = ort_.KernelContext_GetInput(context, 0);
9 const int64_t* p_n_elements = ort_.GetTensorData<int64_t>(n_elements);
10
11 OrtTensorDimensions d_length(ort_, n_elements);
12
13 if (d_length.size() != 1)
14 ORTX_CXX_API_THROW(MakeString(
15 "First input must have one dimension not ", d_length.size(), "."), ORT_INVALID_ARGUMENT);
16 int64_t n_els = d_length[0] - 1;
17 int64_t n_values = p_n_elements[n_els];
18 std::vector<int64_t> shape{n_values, 2};
19 std::vector<int64_t> shape2{2};
20
21 OrtValue* v0 = ort_.KernelContext_GetOutput(context, 0, shape.data(), shape.size());
22 int64_t* out0 = ort_.GetTensorMutableData<int64_t>(v0);
23 OrtValue* v1 = ort_.KernelContext_GetOutput(context, 1, shape2.data(), shape2.size());
24 int64_t* out1 = ort_.GetTensorMutableData<int64_t>(v1);
25 out1[0] = n_els;
26 out1[1] = 0;
27 int64_t row = 0;
28 int64_t i, j, length;
29 for (i = 1; i < d_length[0]; ++i) {
30 length = p_n_elements[i] - p_n_elements[i - 1];
31 if (length > out1[1])
32 out1[1] = length;
33 for (j = 0; j < length; ++j) {
34 *out0++ = row;
35 *out0++ = j;
36 }
37 ++row;
38 }
39}
40
41size_t CustomOpRaggedTensorToSparse::GetInputTypeCount() const {
42 return 1;
43};
44
45size_t CustomOpRaggedTensorToSparse::GetOutputTypeCount() const {
46 return 2;
47};
48
49ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetOutputType(size_t /*index*/) const {
50 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
51};
52
53const char* CustomOpRaggedTensorToSparse::GetName() const {
54 return "RaggedTensorToSparse";
55};
56
57ONNXTensorElementDataType CustomOpRaggedTensorToSparse::GetInputType(size_t /*index*/) const {
58 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
59};
60
61CommonRaggedTensorToDense::CommonRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
62}
63
64void CommonRaggedTensorToDense::GetInputDims(OrtKernelContext* context, const OrtValue** inputs, OrtTensorDimensions* dims) {
65 for (int i = 0; i < 4; ++i) {
66 inputs[i] = ort_.KernelContext_GetInput(context, i);
67 dims[i] = OrtTensorDimensions(ort_, inputs[i]);
68 }
69}
70
71int64_t CommonRaggedTensorToDense::GetMaxCol(int64_t n, const int64_t* p_indices) {
72 int64_t size = n;
73 int64_t max_col = 0;
74 for (int64_t i = 1; i < size; ++i) {
75 max_col = std::max(max_col, p_indices[i] - p_indices[i - 1]);
76 }
77 return max_col;
78}
79
80KernelRaggedTensorToDense::KernelRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
81 missing_value_ = HasAttribute("missing_value") ? ort_.KernelInfoGetAttribute<int64_t>(info, "missing_value") : -1;
82}
83
84void KernelRaggedTensorToDense::Compute(OrtKernelContext* context) {
85 const OrtValue* inputs[4];
86 OrtTensorDimensions dims[4];
87 GetInputDims(context, inputs, dims);
88
89 const int64_t* p_values = ort_.GetTensorData<int64_t>(inputs[1]);
90 const int64_t* p_missing = ort_.GetTensorData<int64_t>(inputs[2]);
91 const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
92
93 int64_t size = dims[3].Size();
94 int64_t max_col = GetMaxCol(size, p_indices);
95
96 std::vector<int64_t> shape_out{size - 1, max_col};
97 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
98 int64_t* dense = ort_.GetTensorMutableData<int64_t>(output);
99
100 int64_t pos = 0;
101 int64_t j, pos_end;
102 int64_t shape_out_size = shape_out[0] * shape_out[1];
103 for (int64_t i = 0; i < size - 1; ++i) {
104 pos_end = pos + max_col;
105 if (pos_end > shape_out_size)
106 ORTX_CXX_API_THROW(MakeString(
107 "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
108 " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
109 for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
110 dense[pos] = p_values[j];
111 }
112 for (; pos < pos_end; ++pos) {
113 dense[pos] = p_missing[0];
114 }
115 }
116}
117
118size_t CustomOpRaggedTensorToDense::GetInputTypeCount() const {
119 return 4;
120};
121
122size_t CustomOpRaggedTensorToDense::GetOutputTypeCount() const {
123 return 1;
124};
125
126ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
127 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
128};
129
130void* CustomOpRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
131 return CreateKernelImpl(api, info);
132};
133
134const char* CustomOpRaggedTensorToDense::GetName() const {
135 return "RaggedTensorToDense";
136};
137
138ONNXTensorElementDataType CustomOpRaggedTensorToDense::GetInputType(size_t /*index*/) const {
139 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
140};
141
142KernelStringRaggedTensorToDense::KernelStringRaggedTensorToDense(const OrtApi& api, const OrtKernelInfo* info) : CommonRaggedTensorToDense(api, info) {
143}
144
145void KernelStringRaggedTensorToDense::Compute(OrtKernelContext* context) {
146 const OrtValue* inputs[4];
147 OrtTensorDimensions dims[4];
148 GetInputDims(context, inputs, dims);
149
150 std::vector<std::string> input;
151 GetTensorMutableDataString(api_, ort_, context, inputs[1], input);
152 const int64_t* p_indices = ort_.GetTensorData<int64_t>(inputs[3]);
153 int64_t size = dims[3].Size();
154 int64_t max_col = GetMaxCol(size, p_indices);
155 std::vector<int64_t> shape_out{size - 1, max_col};
156
157 int64_t shape_out_size = shape_out[0] * shape_out[1];
158 std::vector<std::string> dense(static_cast<size_t>(max_col * (size - 1)));
159 int64_t pos = 0;
160 int64_t j, pos_end;
161 for (int64_t i = 0; i < size - 1; ++i) {
162 pos_end = pos + max_col;
163 if (pos_end > shape_out_size)
164 ORTX_CXX_API_THROW(MakeString(
165 "Unexpected index ", pos_end, " greather than ", shape_out[0], "x", shape_out[1],
166 " - i=", i, " size=", size, "."), ORT_INVALID_ARGUMENT);
167 for (j = p_indices[i]; j < p_indices[i + 1]; ++j, ++pos) {
168 dense[static_cast<size_t>(pos)] = input[static_cast<size_t>(j)];
169 }
170 pos = pos_end;
171 }
172
173 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, shape_out.data(), shape_out.size());
174 FillTensorDataString(api_, ort_, context, dense, output);
175}
176
177size_t CustomOpStringRaggedTensorToDense::GetInputTypeCount() const {
178 return 4;
179};
180
181size_t CustomOpStringRaggedTensorToDense::GetOutputTypeCount() const {
182 return 1;
183};
184
185ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetOutputType(size_t /*index*/) const {
186 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
187};
188
189void* CustomOpStringRaggedTensorToDense::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
190 return CreateKernelImpl(api, info);
191};
192
193const char* CustomOpStringRaggedTensorToDense::GetName() const {
194 return "StringRaggedTensorToDense";
195};
196
197ONNXTensorElementDataType CustomOpStringRaggedTensorToDense::GetInputType(size_t index) const {
198 switch (index) {
199 case 1:
200 case 2:
201 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
202 case 0:
203 case 3:
204 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
205 default:
206 ORTX_CXX_API_THROW(MakeString("[StringRaggedTensorToDense] Unexpected output index ", index, "."), ORT_INVALID_ARGUMENT);
207 }
208};
209