microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
13d9e27ccd8a0de9a1225756fbf6860a1931484f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/re2_strings/string_regex_split.cc

112 lines · mode: code

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
#include "string_regex_split.hpp"
#include "string_regex_split_re.hpp"
#include "string_tensor.h"
#include <algorithm>
#include <cmath>
#include <vector>
9
10KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
11}
12
13void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) {
14 // Setup inputs
15 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
16 const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
17 const OrtValue* keep_pattern = ort_.KernelContext_GetInput(context, 2);
18
19 std::vector<std::string> str_input, str_pattern, str_keep_pattern;
20 GetTensorMutableDataString(api_, ort_, context, input, str_input);
21 GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
22 GetTensorMutableDataString(api_, ort_, context, keep_pattern, str_keep_pattern);
23
24 // Verifications
25 OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern);
26 if (str_pattern.size() != 1)
27 ORTX_CXX_API_THROW(MakeString(
28 "pattern (second input) must contain only one element. It has ",
29 str_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
30 if (str_keep_pattern.size() > 1)
31 ORTX_CXX_API_THROW(MakeString(
32 "Third input must contain only one element. It has ",
33 str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT);
34 if (str_pattern[0].empty())
35 ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT);
36
37 OrtTensorDimensions dimensions(ort_, input);
38 bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty());
39
40 re2::RE2 reg(str_pattern[0]);
41 re2::RE2 keep_reg(include_delimiter ? str_keep_pattern[0] : "");
42
43 std::vector<std::string> all_tokens;
44 std::vector<int64_t> all_begin_offsets, all_end_offsets;
45 std::vector<int64_t> row_offsets;
46
47 for (int64_t i = 0; i < dimensions[0]; i++) {
48 row_offsets.push_back(all_begin_offsets.size());
49 std::vector<std::string_view> tokens;
50 std::vector<int64_t> begin_offsets;
51 std::vector<int64_t> end_offsets;
52 RegexSplitImpl(str_input[static_cast<size_t>(i)], reg,
53 include_delimiter, keep_reg,
54 tokens, begin_offsets, end_offsets);
55 all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end());
56 for (size_t j = 0; j < begin_offsets.size(); ++j) {
57 all_begin_offsets.push_back(begin_offsets[j]);
58 all_end_offsets.push_back(end_offsets[j]);
59 }
60 }
61 row_offsets.push_back(all_begin_offsets.size());
62
63 // Setup output
64 std::vector<int64_t> dim_out{(int64_t)all_tokens.size()};
65 OrtValue* output_text = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size());
66 FillTensorDataString(api_, ort_, context, all_tokens, output_text);
67
68 OrtValue* output = ort_.KernelContext_GetOutput(context, 1, dim_out.data(), dim_out.size());
69 int64_t* p_output = ort_.GetTensorMutableData<int64_t>(output);
70 memcpy(p_output, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t));
71
72 output = ort_.KernelContext_GetOutput(context, 2, dim_out.data(), dim_out.size());
73 p_output = ort_.GetTensorMutableData<int64_t>(output);
74 memcpy(p_output, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t));
75
76 std::vector<int64_t> dim_out_row{(int64_t)row_offsets.size()};
77 output = ort_.KernelContext_GetOutput(context, 3, dim_out_row.data(), dim_out_row.size());
78 p_output = ort_.GetTensorMutableData<int64_t>(output);
79 memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t));
80}
81
82void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
83 return CreateKernelImpl(api, info);
84};
85
86const char* CustomOpStringRegexSplitWithOffsets::GetName() const { return "StringRegexSplitWithOffsets"; };
87
88size_t CustomOpStringRegexSplitWithOffsets::GetInputTypeCount() const {
89 return 3;
90};
91
92ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetInputType(size_t /*index*/) const {
93 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
94};
95
96size_t CustomOpStringRegexSplitWithOffsets::GetOutputTypeCount() const {
97 return 4;
98};
99
100ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(size_t index) const {
101 switch (index) {
102 case 0:
103 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
104 case 1:
105 case 2:
106 case 3:
107 return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64;
108 default:
109 ORTX_CXX_API_THROW(MakeString(
110 "StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT);
111 }
112};
113