microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.7.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/re2_strings/string_regex_replace.cc

84 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "string_regex_replace.hpp"
5#include <vector>
6#include <cmath>
7#include <algorithm>
8#include "re2/re2.h"
9#include "string_tensor.h"
10
11KernelStringRegexReplace::KernelStringRegexReplace(const OrtApi& api, const OrtKernelInfo& info)
12 : BaseKernel(api, info) {
13 global_replace_ = HasAttribute("global_replace") ? ort_.KernelInfoGetAttribute<int64_t>(&info_, "global_replace") : 1;
14}
15
16void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
17 // Setup inputs
18 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
19 const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
20 const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
21
22 std::vector<std::string> str_input, str_pattern, str_rewrite;
23 GetTensorMutableDataString(api_, ort_, context, input, str_input);
24 GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
25 GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
26
27 // Verifications
28 OrtTensorDimensions pattern_dimensions(ort_, pattern);
29 OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
30 if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
31 ORTX_CXX_API_THROW(MakeString(
32 "pattern (second input) must contain only one element. It has ",
33 pattern_dimensions.size(), " dimensions."),
34 ORT_INVALID_ARGUMENT);
35 if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
36 ORTX_CXX_API_THROW(MakeString(
37 "rewrite (third input) must contain only one element. It has ",
38 rewrite_dimensions.size(), " dimensions."),
39 ORT_INVALID_ARGUMENT);
40 if (str_pattern[0].empty())
41 ORTX_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
42
43 // Setup output
44 OrtTensorDimensions dimensions(ort_, input);
45 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
46
47 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
48 size_t size = ort_.GetTensorShapeElementCount(output_info);
49 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
50
51 re2::StringPiece piece(str_rewrite[0]);
52 re2::RE2 reg(str_pattern[0]);
53
54 // Do computation
55 if (global_replace_) {
56 for (size_t i = 0; i < size; i++) {
57 re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
58 }
59 } else {
60 for (size_t i = 0; i < size; i++) {
61 re2::RE2::Replace(&(str_input[i]), reg, piece);
62 }
63 }
64
65 FillTensorDataString(api_, ort_, context, str_input, output);
66}
67
68const char* CustomOpStringRegexReplace::GetName() const { return "StringRegexReplace"; };
69
70size_t CustomOpStringRegexReplace::GetInputTypeCount() const {
71 return 3;
72};
73
74ONNXTensorElementDataType CustomOpStringRegexReplace::GetInputType(size_t /*index*/) const {
75 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
76};
77
78size_t CustomOpStringRegexReplace::GetOutputTypeCount() const {
79 return 1;
80};
81
82ONNXTensorElementDataType CustomOpStringRegexReplace::GetOutputType(size_t /*index*/) const {
83 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
84};
85