microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
operators/text/re2_strings/string_regex_split.cc
112 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | #include "string_regex_split.hpp" |
| 5 | #include "string_regex_split_re.hpp" |
| 6 | #include "string_tensor.h" |
| 7 | #include <vector> |
| 8 | #include <cmath> |
| 9 | |
| 10 | KernelStringRegexSplitWithOffsets::KernelStringRegexSplitWithOffsets(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) { |
| 11 | } |
| 12 | |
| 13 | void KernelStringRegexSplitWithOffsets::Compute(OrtKernelContext* context) { |
| 14 | // Setup inputs |
| 15 | const OrtValue* input = ort_.KernelContext_GetInput(context, 0); |
| 16 | const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1); |
| 17 | const OrtValue* keep_pattern = ort_.KernelContext_GetInput(context, 2); |
| 18 | |
| 19 | std::vector<std::string> str_input, str_pattern, str_keep_pattern; |
| 20 | GetTensorMutableDataString(api_, ort_, context, input, str_input); |
| 21 | GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern); |
| 22 | GetTensorMutableDataString(api_, ort_, context, keep_pattern, str_keep_pattern); |
| 23 | |
| 24 | // Verifications |
| 25 | OrtTensorDimensions keep_pattern_dimensions(ort_, keep_pattern); |
| 26 | if (str_pattern.size() != 1) |
| 27 | ORTX_CXX_API_THROW(MakeString( |
| 28 | "pattern (second input) must contain only one element. It has ", |
| 29 | str_pattern.size(), " values."), ORT_INVALID_ARGUMENT); |
| 30 | if (str_keep_pattern.size() > 1) |
| 31 | ORTX_CXX_API_THROW(MakeString( |
| 32 | "Third input must contain only one element. It has ", |
| 33 | str_keep_pattern.size(), " values."), ORT_INVALID_ARGUMENT); |
| 34 | if (str_pattern[0].empty()) |
| 35 | ORTX_CXX_API_THROW("Splitting pattern cannot be empty.", ORT_INVALID_ARGUMENT); |
| 36 | |
| 37 | OrtTensorDimensions dimensions(ort_, input); |
| 38 | bool include_delimiter = (str_keep_pattern.size() == 1) && (!str_keep_pattern[0].empty()); |
| 39 | |
| 40 | re2::RE2 reg(str_pattern[0]); |
| 41 | re2::RE2 keep_reg(include_delimiter ? str_keep_pattern[0] : ""); |
| 42 | |
| 43 | std::vector<std::string> all_tokens; |
| 44 | std::vector<int64_t> all_begin_offsets, all_end_offsets; |
| 45 | std::vector<int64_t> row_offsets; |
| 46 | |
| 47 | for (int64_t i = 0; i < dimensions[0]; i++) { |
| 48 | row_offsets.push_back(all_begin_offsets.size()); |
| 49 | std::vector<std::string_view> tokens; |
| 50 | std::vector<int64_t> begin_offsets; |
| 51 | std::vector<int64_t> end_offsets; |
| 52 | RegexSplitImpl(str_input[static_cast<size_t>(i)], reg, |
| 53 | include_delimiter, keep_reg, |
| 54 | tokens, begin_offsets, end_offsets); |
| 55 | all_tokens.insert(all_tokens.end(), tokens.begin(), tokens.end()); |
| 56 | for (size_t j = 0; j < begin_offsets.size(); ++j) { |
| 57 | all_begin_offsets.push_back(begin_offsets[j]); |
| 58 | all_end_offsets.push_back(end_offsets[j]); |
| 59 | } |
| 60 | } |
| 61 | row_offsets.push_back(all_begin_offsets.size()); |
| 62 | |
| 63 | // Setup output |
| 64 | std::vector<int64_t> dim_out{(int64_t)all_tokens.size()}; |
| 65 | OrtValue* output_text = ort_.KernelContext_GetOutput(context, 0, dim_out.data(), dim_out.size()); |
| 66 | FillTensorDataString(api_, ort_, context, all_tokens, output_text); |
| 67 | |
| 68 | OrtValue* output = ort_.KernelContext_GetOutput(context, 1, dim_out.data(), dim_out.size()); |
| 69 | int64_t* p_output = ort_.GetTensorMutableData<int64_t>(output); |
| 70 | memcpy(p_output, all_begin_offsets.data(), all_begin_offsets.size() * sizeof(int64_t)); |
| 71 | |
| 72 | output = ort_.KernelContext_GetOutput(context, 2, dim_out.data(), dim_out.size()); |
| 73 | p_output = ort_.GetTensorMutableData<int64_t>(output); |
| 74 | memcpy(p_output, all_end_offsets.data(), all_end_offsets.size() * sizeof(int64_t)); |
| 75 | |
| 76 | std::vector<int64_t> dim_out_row{(int64_t)row_offsets.size()}; |
| 77 | output = ort_.KernelContext_GetOutput(context, 3, dim_out_row.data(), dim_out_row.size()); |
| 78 | p_output = ort_.GetTensorMutableData<int64_t>(output); |
| 79 | memcpy(p_output, row_offsets.data(), row_offsets.size() * sizeof(int64_t)); |
| 80 | } |
| 81 | |
| 82 | void* CustomOpStringRegexSplitWithOffsets::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const { |
| 83 | return CreateKernelImpl(api, info); |
| 84 | }; |
| 85 | |
| 86 | const char* CustomOpStringRegexSplitWithOffsets::GetName() const { return "StringRegexSplitWithOffsets"; }; |
| 87 | |
| 88 | size_t CustomOpStringRegexSplitWithOffsets::GetInputTypeCount() const { |
| 89 | return 3; |
| 90 | }; |
| 91 | |
| 92 | ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetInputType(size_t /*index*/) const { |
| 93 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 94 | }; |
| 95 | |
| 96 | size_t CustomOpStringRegexSplitWithOffsets::GetOutputTypeCount() const { |
| 97 | return 4; |
| 98 | }; |
| 99 | |
| 100 | ONNXTensorElementDataType CustomOpStringRegexSplitWithOffsets::GetOutputType(size_t index) const { |
| 101 | switch (index) { |
| 102 | case 0: |
| 103 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING; |
| 104 | case 1: |
| 105 | case 2: |
| 106 | case 3: |
| 107 | return ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64; |
| 108 | default: |
| 109 | ORTX_CXX_API_THROW(MakeString( |
| 110 | "StringRegexSplitWithOffsets has 4 outputs but index is ", index, "."), ORT_INVALID_ARGUMENT); |
| 111 | } |
| 112 | }; |
| 113 | |