microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.4.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/text/re2_strings/string_regex_replace.cc

85lines · modepreview

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_regex_replace.hpp"
#include <vector>
#include <cmath>
#include <algorithm>
#include "re2/re2.h"
#include "string_tensor.h"

// Builds the kernel and reads the optional "global_replace" attribute.
// When the attribute is absent the value defaults to 1, which makes Compute
// take its global-replace branch.
KernelStringRegexReplace::KernelStringRegexReplace(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
  if (HasAttribute("global_replace")) {
    global_replace_ = ort_.KernelInfoGetAttribute<int64_t>(info_, "global_replace");
  } else {
    global_replace_ = 1;
  }
}

// Applies a regular-expression replacement to every string of the first input
// and writes the result to output 0, which has the same shape as that input.
//
// Inputs (read from `context`):
//   0: string tensor holding the texts to transform.
//   1: 1-D string tensor with exactly one element: the RE2 pattern.
//   2: 1-D string tensor with exactly one element: the RE2 rewrite string.
//
// When `global_replace_` is non-zero each string goes through
// re2::RE2::GlobalReplace, otherwise through re2::RE2::Replace.
//
// Throws (via ORT_CXX_API_THROW, code ORT_INVALID_ARGUMENT) when:
//   - pattern or rewrite is not a 1-D tensor with a single element,
//   - the pattern is empty or fails to compile,
//   - the rewrite string is not valid for the compiled pattern.
void KernelStringRegexReplace::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
  const OrtValue* pattern = ort_.KernelContext_GetInput(context, 1);
  const OrtValue* rewrite = ort_.KernelContext_GetInput(context, 2);

  std::vector<std::string> str_input, str_pattern, str_rewrite;
  GetTensorMutableDataString(api_, ort_, context, input, str_input);
  GetTensorMutableDataString(api_, ort_, context, pattern, str_pattern);
  GetTensorMutableDataString(api_, ort_, context, rewrite, str_rewrite);

  // Verifications
  OrtTensorDimensions pattern_dimensions(ort_, pattern);
  OrtTensorDimensions rewrite_dimensions(ort_, rewrite);
  if (pattern_dimensions.size() != 1 || pattern_dimensions[0] != 1) {
    ORT_CXX_API_THROW(MakeString(
        "pattern (second input) must contain only one element. It has ",
        pattern_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
  }
  if (rewrite_dimensions.size() != 1 || rewrite_dimensions[0] != 1) {
    ORT_CXX_API_THROW(MakeString(
        "rewrite (third input) must contain only one element. It has ",
        rewrite_dimensions.size(), " dimensions."), ORT_INVALID_ARGUMENT);
  }
  if (str_pattern[0].empty()) {
    ORT_CXX_API_THROW("pattern (second input) cannot be empty.", ORT_INVALID_ARGUMENT);
  }

  // Compile the pattern and validate the rewrite string up front. Without
  // these checks a malformed pattern or rewrite makes RE2 report failure from
  // Replace/GlobalReplace and leave every string untouched, so the operator
  // would silently return its input.
  re2::StringPiece piece(str_rewrite[0]);
  re2::RE2 reg(str_pattern[0]);
  if (!reg.ok()) {
    ORT_CXX_API_THROW(MakeString(
        "pattern (second input) is not a valid regular expression: ",
        reg.error()), ORT_INVALID_ARGUMENT);
  }
  std::string rewrite_error;
  if (!reg.CheckRewriteString(piece, &rewrite_error)) {
    ORT_CXX_API_THROW(MakeString(
        "rewrite (third input) is not valid for the pattern: ",
        rewrite_error), ORT_INVALID_ARGUMENT);
  }

  // Setup output: same shape as the first input.
  OrtTensorDimensions dimensions(ort_, input);
  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());

  OrtTensorTypeAndShapeInfo* output_info = ort_.GetTensorTypeAndShape(output);
  int64_t size = ort_.GetTensorShapeElementCount(output_info);
  ort_.ReleaseTensorTypeAndShapeInfo(output_info);

  // Do computation: the strings are rewritten in place inside str_input.
  for (int64_t i = 0; i < size; i++) {
    if (global_replace_) {
      re2::RE2::GlobalReplace(&(str_input[i]), reg, piece);
    } else {
      re2::RE2::Replace(&(str_input[i]), reg, piece);
    }
  }

  FillTensorDataString(api_, ort_, context, str_input, output);
}

// Allocates a new KernelStringRegexReplace for the given kernel info and
// returns it as an opaque pointer.
// NOTE(review): the kernel is created with `new` and not freed here;
// presumably the caller releases it - confirm against the custom-op base.
void* CustomOpStringRegexReplace::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
  auto* kernel = new KernelStringRegexReplace(api, info);
  return kernel;
}

// Name under which this custom operator is exposed.
const char* CustomOpStringRegexReplace::GetName() const {
  return "StringRegexReplace";
}

// Number of inputs: the text tensor, the pattern and the rewrite string
// (inputs 0, 1 and 2 read by KernelStringRegexReplace::Compute).
size_t CustomOpStringRegexReplace::GetInputTypeCount() const {
  constexpr size_t kInputCount = 3;
  return kInputCount;
}

// Element type of an input. The index is ignored: every input is reported
// as a string tensor.
ONNXTensorElementDataType CustomOpStringRegexReplace::GetInputType(size_t /*index*/) const {
  const ONNXTensorElementDataType input_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return input_type;
}

// Number of outputs: a single tensor (output 0 written by
// KernelStringRegexReplace::Compute).
size_t CustomOpStringRegexReplace::GetOutputTypeCount() const {
  constexpr size_t kOutputCount = 1;
  return kOutputCount;
}

// Element type of an output. The index is ignored: the output is reported
// as a string tensor.
ONNXTensorElementDataType CustomOpStringRegexReplace::GetOutputType(size_t /*index*/) const {
  const ONNXTensorElementDataType output_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return output_type;
}