microsoft/onnxruntime-extensions

Public

Mirrored from https://github.com/microsoft/onnxruntime-extensions (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.7

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/basic_tokenizer.cc

127lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "basic_tokenizer.hpp"
5#include "string_utils.h"
6#include "string_tensor.h"
7#include <vector>
8#include <locale>
9#include <algorithm>
10
11BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents,
12 bool tokenize_punctuation, bool remove_control_chars)
13 : do_lower_case_(do_lower_case),
14 strip_accents_(strip_accents),
15 tokenize_chinese_chars_(tokenize_chinese_chars),
16 tokenize_punctuation_(tokenize_punctuation),
17 remove_control_chars_(remove_control_chars) {}
18
19std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
20 std::vector<ustring> result;
21 ustring token;
22 auto push_current_token_and_clear = [&result, &token]() {
23 if (!token.empty()) {
24 result.push_back(token);
25 token.clear();
26 }
27 };
28
29 auto push_single_char_and_clear = [&result, &token](char32_t c) {
30 token.push_back(c);
31 result.push_back(token);
32 token.clear();
33 };
34
35 // strip accent first
36 if (strip_accents_) {
37 for (auto& c : text) {
38 c = StripAccent(c);
39 }
40 }
41
42 if (do_lower_case_) {
43 for (auto& c : text) {
44 c = ToLower(c);
45 }
46 }
47
48 for (auto c : text) {
49 if (tokenize_chinese_chars_ && IsCJK(c)) {
50 push_current_token_and_clear();
51 push_single_char_and_clear(c);
52 continue;
53 }
54
55 if (strip_accents_ && IsAccent(c)) {
56 continue;
57 }
58
59 // 0x2019 unicode is not punctuation in some Linux platform,
60 // to be consistent, take it as punctuation.
61 if (tokenize_punctuation_ && IsPunct(c)) {
62 push_current_token_and_clear();
63 push_single_char_and_clear(c);
64 continue;
65 }
66
67 // split by space
68 if (IsSpace(c)) {
69 push_current_token_and_clear();
70 continue;
71 }
72
73 if (remove_control_chars_ && IsControl(c)) {
74 continue;
75 }
76
77 token.push_back(c);
78 }
79
80 push_current_token_and_clear();
81 return result;
82}
83
84KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo& info) : BaseKernel(api, info) {
85 bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
86 bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
87 bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
88 bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
89 bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
90
91 tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
92}
93
94void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
95 // Setup inputs
96 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
97 std::vector<std::string> input_data;
98 GetTensorMutableDataString(api_, ort_, context, input, input_data);
99
100 OrtTensorDimensions dimensions(ort_, input);
101 if (dimensions.size() != 1 && dimensions[0] != 1) {
102 ORTX_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
103 }
104
105 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
106 std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
107
108 FillTensorDataString(api_, ort_, context, result, output);
109}
110
111const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
112
113size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
114 return 1;
115};
116
117ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
118 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
119};
120
121size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
122 return 1;
123};
124
125ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
126 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
127};
128