microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/re2_strings/string_regex_replace.cc

85 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "string_regex_replace.hpp"
5#include <vector>
6#include <cmath>
7#include <algorithm>
8#include "re2/re2.h"
9#include "string_tensor.h"
10
11KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
12 global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace") : 1;
13}
14
15void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
16 // Setup inputs
17 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
18 const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
19 const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
20
21 std::vector<std::string> str_input, str_pattern, str_rewrite;
22 GetTensorMutableDataString(api_, ort_, context, input, str_input);
23 GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
24 GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
25
26 // Verifications
27 OrtTensorDimensions pattern_dimensions(ort_, pattern);
28 OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
29 if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
30 ORT_CXX_API_THROW(MakeString(
31 "pattern (second input) must contain only one element. It has ",
32 pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
33 if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
34 ORT_CXX_API_THROW(MakeString(
35 "rewrite (third input) must contain only one element. It has ",
36 rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
37 if (str_pattern[0].empty())
38 ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
39
40 // Setup output
41 OrtTensorDimensions dimensions(ort_, input);
42 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
43
44 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
45 int64_t size = ort_.GetTensorShapeElementCount(output_info);
46 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
47
48 re2::StringPiece piece(str_rewrite[0]);
49 re2::RE2 reg(str_pattern[0]);
50
51 // Do computation
52 if (global_replace_) {
53 for (int64_t i = 0; i < size; i++) {
54 re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
55 }
56 } else {
57 for (int64_t i = 0; i < size; i++) {
58 re2::RE2::Replace(&(str_input[i]), reg, piece);
59 }
60 }
61
62 FillTensorDataString(api_, ort_, context, str_input, output);
63}
64
65void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
66 return new KernelStringRegexReplace(api, info);
67};
68
69const char* CustomOpStringRegexReplace::GetName() const { return "StringRegexReplace"; };
70
71size_t CustomOpStringRegexReplace::GetInputTypeCount() const {
72 return 3;
73};
74
75ONNXTensorElementDataType CustomOpStringRegexReplace::GetInputType(size_t /*index*/) const {
76 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
77};
78
79size_t CustomOpStringRegexReplace::GetOutputTypeCount() const {
80 return 1;
81};
82
83ONNXTensorElementDataType CustomOpStringRegexReplace::GetOutputType(size_t /*index*/) const {
84 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
85};
86