microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d8581da434a333e573cdbc51b9558142203c9c8c

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/blingfire_sentencebreaker.cc

101lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "blingfire_sentencebreaker.hpp"
5#include "string_tensor.h"
6#include <vector>
7#include <locale>
8#include <codecvt>
9#include <algorithm>
10#include <memory>
11
12KernelBlingFireSentenceBreaker::KernelBlingFireSentenceBreaker(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info), max_sentence(-1) {
13 model_data_ = ort_.KernelInfoGetAttribute<std::string>(info, "model");
14 if (model_data_.empty()) {
15 ORTX_CXX_API_THROW("vocabulary shouldn't be empty.", ORT_INVALID_ARGUMENT);
16 }
17
18 void* model_ptr = SetModel(reinterpret_cast<const unsigned char*>(model_data_.data()), static_cast<int>(model_data_.size()));
19
20 if (model_ptr == nullptr) {
21 ORTX_CXX_API_THROW("Invalid model", ORT_INVALID_ARGUMENT);
22 }
23
24 model_ = std::shared_ptr<void>(model_ptr, FreeModel);
25
26 if (HasAttribute("max_sentence")) {
27 max_sentence = static_cast<int>(ort_.KernelInfoGetAttribute<int64_t>(info, "max_sentence"));
28 }
29}
30
31void KernelBlingFireSentenceBreaker::Compute(OrtKernelContext* context) {
32 // Setup inputs
33 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
34 OrtTensorDimensions dimensions(ort_, input);
35
36 // TODO: fix this scalar check.
37 if (dimensions.Size() != 1 && dimensions[0] != 1) {
38 ORTX_CXX_API_THROW("We only support string scalar.", ORT_INVALID_ARGUMENT);
39 }
40
41 std::vector<std::string> input_data;
42 GetTensorMutableDataString(api_, ort_, context, input, input_data);
43
44 std::string& input_string = input_data[0];
45 int max_length = static_cast<int>(2 * input_string.size() + 1);
46 std::unique_ptr<char[]> output_str = std::make_unique<char[]>(max_length);
47
48 int output_length = TextToSentencesWithOffsetsWithModel(input_string.data(), static_cast<int>(input_string.size()), output_str.get(), nullptr, nullptr, max_length, model_.get());
49 if (output_length < 0) {
50 ORTX_CXX_API_THROW(MakeString("splitting input:\"", input_string, "\" failed"), ORT_INVALID_ARGUMENT);
51 }
52
53 // inline split output_str by newline '\n'
54 std::vector<const char*> output_sentences;
55
56 if (output_length == 0) {
57 // put one empty string if output_length is 0
58 output_sentences.push_back("");
59 } else {
60 bool head_flag = true;
61 for (int i = 0; i < output_length; i++) {
62 if (head_flag) {
63 output_sentences.push_back(&output_str[i]);
64 head_flag = false;
65 }
66
67 if (output_str[i] == '\n') {
68 head_flag = true;
69 output_str[i] = '\0';
70 }
71 }
72 }
73
74 std::vector<int64_t> output_dimensions(1);
75 output_dimensions[0] = output_sentences.size();
76
77 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, output_dimensions.data(), output_dimensions.size());
78 OrtW::ThrowOnError(api_, api_.FillStringTensor(output, output_sentences.data(), output_sentences.size()));
79}
80
81void* CustomOpBlingFireSentenceBreaker::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
82 return CreateKernelImpl(api, info);
83};
84
85const char* CustomOpBlingFireSentenceBreaker::GetName() const { return "BlingFireSentenceBreaker"; };
86
87size_t CustomOpBlingFireSentenceBreaker::GetInputTypeCount() const {
88 return 1;
89};
90
91ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetInputType(size_t /*index*/) const {
92 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
93};
94
95size_t CustomOpBlingFireSentenceBreaker::GetOutputTypeCount() const {
96 return 1;
97};
98
99ONNXTensorElementDataType CustomOpBlingFireSentenceBreaker::GetOutputType(size_t /*index*/) const {
100 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
101};