microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
9eef22cb81d762f6c093a4740c992582267a783f

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/blingfire_sentencebreaker.cc

94lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
#include "blingfire_sentencebreaker.hpp"
#include "string_tensor.h"
#include <vector>
#include <locale>
#include <codecvt>
#include <algorithm>
#include <limits>
10
11KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
12 model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
13 if (model_data_.empty()) {
14 ORT_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
15 }
16
17 void* model_ptr = SetModel(reinterpret_cast<unsigned char*>(model_data_.data()), model_data_.size());
18
19 if (model_ptr == nullptr) {
20 ORT_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
21 }
22
23 model_ = std::shared_ptr<void>(model_ptr, FreeModel);
24
25 if (HasAttribute("max_sentence")) {
26 max_sentence = ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence");
27 }
28}
29
30void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
31 // Setup inputs
32 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
33 OrtTensorDimensions dimensions(ort_, input);
34
35 if (dimensions.Size() != 1 && dimensions[0] != 1) {
36 ORT_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
37 }
38
39 std::vector<std::string> input_data;
40 GetTensorMutableDataString(api_, ort_, context, input, input_data);
41
42 std::string& input_string = input_data[0];
43 int max_length = 2 * input_string.size() + 1;
44 std::string output_str;
45 output_str.reserve(max_length);
46
47 int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), input_string.size(), output_str.data(), nullptr, nullptr, max_length, model_.get());
48 if (output_length < 0) {
49 ORT_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
50 }
51
52 // inline split output_str by newline '\n'
53 std::vector<char*> output_sentences;
54 bool head_flag = true;
55 for (int i = 0; i < output_length; i++) {
56 if (head_flag) {
57 output_sentences.push_back(&output_str[i]);
58 head_flag = false;
59 }
60
61 if (output_str[i] == '\n') {
62 head_flag = true;
63 output_str[i] = '\0';
64 }
65 }
66
67 std::vector<int64_t> output_dimensions(1);
68 output_dimensions[0] = output_sentences.size();
69
70 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
71 Ort::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
72}
73
74void* CustomOpBlingFireSentenceBreaker::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
75 return new KernelBlingFireSentenceBreaker(api, info);
76};
77
78const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; };
79
80size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {
81 return 1;
82};
83
84ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetInputType(size_t /*index*/) const {
85 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
86};
87
88size_t CustomOpBlingFireSentenceBreaker::GetOutputTypeCount() const {
89 return 1;
90};
91
92ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetOutputType(size_t /*index*/) const {
93 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
94};