microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
13d9e27ccd8a0de9a1225756fbf6860a1931484f

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/blingfire_sentencebreaker.cc

101lines · modepreview

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "blingfire_sentencebreaker.hpp"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>
#include <memory>

// Builds the kernel from the node attributes:
//   - "model": read as a std::string, must be non-empty, and is handed to
//     SetModel as a raw byte buffer.
//   - "max_sentence": optional; max_sentence keeps its initial value of -1
//     when the attribute is absent.
// Raises ORT_INVALID_ARGUMENT when "model" is empty or SetModel returns nullptr.
KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info)
    : BaseKernel(api, info), max_sentence(-1) {
  model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
  if (model_data_.empty()) {
    ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
  }

  // The bytes stay owned by model_data_; SetModel only receives a view of them.
  const auto* raw_bytes = reinterpret_cast<const unsigned char*>(model_data_.data());
  const int raw_size = static_cast<int>(model_data_.size());
  void* loaded_model = SetModel(raw_bytes, raw_size);
  if (loaded_model == nullptr) {
    ORTX_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
  }

  // FreeModel is installed as the deleter, so the handle is released together
  // with the last owner of model_.
  model_ = std::shared_ptr<void>(loaded_model, FreeModel);

  if (!HasAttribute("max_sentence")) {
    return;
  }
  max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence"));
}

// Splits the input string into sentences with the model loaded in the
// constructor and writes them to output 0.
//
// Input 0:  string tensor; only its first element is processed.
// Output 0: 1-D string tensor with one element per '\n'-separated chunk of the
//           sentence breaker's output (a single empty string when the breaker
//           produces no output).
// Raises ORT_INVALID_ARGUMENT when the shape check fails, when the input holds
// no string at all, or when the sentence breaker reports a failure or an output
// length that does not fit the buffer allocated here.
//
// NOTE(review): max_sentence, set in the constructor, is not consulted by this
// function.
void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
  // Setup inputs
  const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
  OrtTensorDimensions dimensions(ort_, input);

  // TODO: fix this scalar check.
  if (dimensions.Size() != 1 && dimensions[0] != 1) {
    ORTX_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
  }

  std::vector<std::string> input_data;
  GetTensorMutableDataString(api_, ort_, context, input, input_data);

  // The check above passes whenever the first dimension is 1, whatever the
  // remaining dimensions are, so it does not guarantee that a string was
  // actually read. Indexing an empty vector would be undefined behaviour.
  if (input_data.empty()) {
    ORTX_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
  }

  const std::string& input_string = input_data[0];
  const int max_length = static_cast<int>(2 * input_string.size() + 1);
  // make_unique<char[]> value-initializes the buffer, so it starts zero-filled.
  std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);

  const int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
  // A negative length is a reported failure; a length beyond max_length would
  // make the scan below read and write past the end of output_str.
  if (output_length < 0 || output_length > max_length) {
    ORTX_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\"  failed"), ORT_INVALID_ARGUMENT);
  }

  // inline split output_str by newline '\n'
  std::vector<const char*> output_sentences;

  if (output_length == 0) {
    // put one empty string if output_length is 0
    output_sentences.push_back("");
  } else {
    // Each '\n' is overwritten with '\0' so that every recorded pointer refers
    // to its own NUL-terminated sentence inside output_str.
    bool head_flag = true;
    for (int i = 0; i < output_length; i++) {
      if (head_flag) {
        output_sentences.push_back(&output_str[i]);
        head_flag = false;
      }

      if (output_str[i] == '\n') {
        head_flag = true;
        output_str[i] = '\0';
      }
    }
  }

  std::vector<int64_t> output_dimensions(1);
  output_dimensions[0] = static_cast<int64_t>(output_sentences.size());

  OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
  OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
}

// Kernel factory hook: forwards both arguments to CreateKernelImpl and returns
// whatever pointer it yields.
void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api,
                                                     const OrtKernelInfo* info) const {
  void* kernel = CreateKernelImpl(api, info);
  return kernel;
}

// Name reported for this custom operator.
const char* CustomOpBlingFireSentenceBreaker::GetName() const {
  return "BlingFireSentenceBreaker";
}

// The operator declares exactly one input.
size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const { return 1; }

// Element type of an input; the index is ignored and the string tensor element
// type is reported for every position.
ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetInputType(size_t /*index*/) const {
  const ONNXTensorElementDataType input_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return input_type;
}

// The operator declares exactly one output.
size_t CustomOpBlingFireSentenceBreaker::GetOutputTypeCount() const { return 1; }

// Element type of an output; the index is ignored and the string tensor element
// type is reported for every position.
ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetOutputType(size_t /*index*/) const {
  const ONNXTensorElementDataType output_type = ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
  return output_type;
}