microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions (available)

CodeCommitsIssuesPull requestsActionsInsightsSecurity
rel-0.4

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/string_ecmaregex_replace.cc

90 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "string_ecmaregex_replace.hpp"
5#include <vector>
6#include <algorithm>
7#include <regex>
8#include "string_tensor.h"
9
10KernelStringECMARegexReplace::KernelStringECMARegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
11 global_replace_ = TryToGetAttributeWithDefault("global_replace", true);
12 ignore_case_ = TryToGetAttributeWithDefault("ignore_case", false);
13
14}
15
16void KernelStringECMARegexReplace::Compute(OrtKernelContext* context) {
17 // Setup inputs
18 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
19 const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
20 const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);
21
22 std::vector<std::string> str_input, str_pattern, str_rewrite;
23 GetTensorMutableDataString(api_, ort_, context, input, str_input);
24 GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
25 GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);
26
27
28 // Verifications
29 OrtTensorDimensions pattern_dimensions(ort_, pattern);
30 OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
31 if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1)
32 ORT_CXX_API_THROW(MakeString(
33 "pattern (second input) must contain only one element. It has ",
34 pattern_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
35 if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1)
36 ORT_CXX_API_THROW(MakeString(
37 "rewrite (third input) must contain only one element. It has ",
38 rewrite_dimensions.size(), " dimensions."), ORT_INVALID_GRAPH);
39 if (str_pattern[0].empty())
40 ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_GRAPH);
41
42 // Setup output
43 OrtTensorDimensions dimensions(ort_, input);
44 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
45
46 OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
47 int64_t size = ort_.GetTensorShapeElementCount(output_info);
48 ort_.ReleaseTensorTypeAndShapeInfo(output_info);
49
50 auto regex_flag = std::regex_constants::optimize | std::regex_constants::ECMAScript;
51 if (ignore_case_) {
52 regex_flag |= std::regex_constants::icase;
53 }
54
55 std::regex reg(str_pattern[0], regex_flag);
56
57 if (global_replace_) {
58 for (int64_t i = 0; i < size; i++) {
59 str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0]);
60 }
61 } else {
62 for (int64_t i = 0; i < size; i++) {
63 str_input[i] = std::regex_replace(str_input[i], reg, str_rewrite[0], std::regex_constants::format_first_only);
64 }
65 }
66
67 FillTensorDataString(api_, ort_, context, str_input, output);
68}
69
70void* CustomOpStringECMARegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
71 return new KernelStringECMARegexReplace(api, info);
72};
73
74const char* CustomOpStringECMARegexReplace::GetName() const { return "StringECMARegexReplace"; };
75
76size_t CustomOpStringECMARegexReplace::GetInputTypeCount() const {
77 return 3;
78};
79
80ONNXTensorElementDataType CustomOpStringECMARegexReplace::GetInputType(size_t /*index*/) const {
81 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
82};
83
84size_t CustomOpStringECMARegexReplace::GetOutputTypeCount() const {
85 return 1;
86};
87
88ONNXTensorElementDataType CustomOpStringECMARegexReplace::GetOutputType(size_t /*index*/) const {
89 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
90};
91