microsoft/onnxruntime-extensions

Public

Mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v0.4.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/basic_tokenizer.cc

130 lines · mode: blame

aef5ef1eMojimi4 years ago1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "string_utils.h"
5#include "basic_tokenizer.hpp"
6#include "string_tensor.h"
7#include <vector>
8#include <locale>
9#include <codecvt>
10#include <algorithm>
11
12BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars):
13do_lower_case_(do_lower_case), tokenize_chinese_chars_(tokenize_chinese_chars), strip_accents_(strip_accents), tokenize_punctuation_(tokenize_punctuation),
14remove_control_chars_(remove_control_chars){}
15
16std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
17std::vector<ustring> result;
18ustring token;
19auto push_current_token_and_clear = [&result, &token]() {
20if (!token.empty()) {
21result.push_back(token);
22token.clear();
23}
24};
25
26auto push_single_char_and_clear = [&result, &token](char32_t c) {
27token.push_back(c);
28result.push_back(token);
29token.clear();
30};
31
32// strip accent first
33if (strip_accents_) {
34for (auto& c : text) {
35c = StripAccent(c);
36}
37}
38
39if (do_lower_case_) {
40for (auto& c : text) {
41c = ::tolower(c);
42}
43}
44
45for (auto c : text) {
46if (tokenize_chinese_chars_ && IsCJK(c)) {
47push_current_token_and_clear();
48push_single_char_and_clear(c);
49continue;
50}
51
52if (strip_accents_ && IsAccent(c)) {
53continue;
54}
55
9f3abe20Wenbing Li4 years ago56// 0x2019 unicode is not punctuation in some Linux platform,
57// to be consistent, take it as punctatuation always.
58if (tokenize_punctuation_ && (::iswpunct(c) || c == wint_t(0x2019))) {
aef5ef1eMojimi4 years ago59push_current_token_and_clear();
60push_single_char_and_clear(c);
61continue;
62}
63
64// split by space
cce66310Mojimi4 years ago65if (::iswspace(c)) {
aef5ef1eMojimi4 years ago66push_current_token_and_clear();
67continue;
68}
69
70// iscntrl will judge \t\f\n\r as control char
71// but it has been filter by isspace(c)
cce66310Mojimi4 years ago72if (remove_control_chars_ && ::iswcntrl(c)) {
aef5ef1eMojimi4 years ago73continue;
74}
75
76token.push_back(c);
77}
78
79push_current_token_and_clear();
80return result;
81}
82
83KernelBasicTokenizer::KernelBasicTokenizer(OrtApi api, const OrtKernelInfo* info) : BaseKernel(api, info) {
84bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
85bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
86bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
87bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
88bool remove_control_chars = TryToGetAttributeWithDefault("strip_accents", true);
89
90tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
91}
92
93void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
94// Setup inputs
95const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
96std::vector<std::string> input_data;
97GetTensorMutableDataString(api_, ort_, context, input, input_data);
98
99OrtTensorDimensions dimensions(ort_, input);
100if (dimensions.size() != 1 && dimensions[0] != 1) {
101ORT_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
102}
103
104OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
105std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
106
107FillTensorDataString(api_, ort_, context, result, output);
108}
109
110void* CustomOpBasicTokenizer::CreateKernel(OrtApi api, const OrtKernelInfo* info) const {
111return new KernelBasicTokenizer(api, info);
112};
113
114const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
115
116size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
117return 1;
118};
119
120ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
121return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
122};
123
124size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
125return 1;
126};
127
128ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
129return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
130};