microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
skottmckay/BuildInfra_AndTestImageLibs

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/basic_tokenizer.cc

132lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#include "string_utils.h"
5#include "basic_tokenizer.hpp"
6#include "string_tensor.h"
7#include <vector>
8#include <locale>
9#include <codecvt>
10#include <algorithm>
11
12BasicTokenizer::BasicTokenizer(bool do_lower_case, bool tokenize_chinese_chars, bool strip_accents, bool tokenize_punctuation, bool remove_control_chars):
13 do_lower_case_(do_lower_case),
14 strip_accents_(strip_accents),
15 tokenize_chinese_chars_(tokenize_chinese_chars),
16 tokenize_punctuation_(tokenize_punctuation),
17 remove_control_chars_(remove_control_chars)
18{}
19
20std::vector<ustring> BasicTokenizer::Tokenize(ustring text) {
21 std::vector<ustring> result;
22 ustring token;
23 auto push_current_token_and_clear = [&result, &token]() {
24 if (!token.empty()) {
25 result.push_back(token);
26 token.clear();
27 }
28 };
29
30 auto push_single_char_and_clear = [&result, &token](char32_t c) {
31 token.push_back(c);
32 result.push_back(token);
33 token.clear();
34 };
35
36 // strip accent first
37 if (strip_accents_) {
38 for (auto& c : text) {
39 c = StripAccent(c);
40 }
41 }
42
43 if (do_lower_case_) {
44 for (auto& c : text) {
45 c = ToLower(c);
46 }
47 }
48
49 for (auto c : text) {
50 if (tokenize_chinese_chars_ && IsCJK(c)) {
51 push_current_token_and_clear();
52 push_single_char_and_clear(c);
53 continue;
54 }
55
56 if (strip_accents_ && IsAccent(c)) {
57 continue;
58 }
59
60 // 0x2019 unicode is not punctuation in some Linux platform,
61 // to be consistent, take it as punctuation.
62 if (tokenize_punctuation_ && IsPunct(c)) {
63 push_current_token_and_clear();
64 push_single_char_and_clear(c);
65 continue;
66 }
67
68 // split by space
69 if (IsSpace(c)) {
70 push_current_token_and_clear();
71 continue;
72 }
73
74 if (remove_control_chars_ && IsControl(c)) {
75 continue;
76 }
77
78 token.push_back(c);
79 }
80
81 push_current_token_and_clear();
82 return result;
83}
84
85KernelBasicTokenizer::KernelBasicTokenizer(const OrtApi& api, const OrtKernelInfo* info) : BaseKernel(api, info) {
86 bool do_lower_case = TryToGetAttributeWithDefault("do_lower_case", true);
87 bool tokenize_chinese_chars = TryToGetAttributeWithDefault("tokenize_chinese_chars", true);
88 bool strip_accents = TryToGetAttributeWithDefault("strip_accents", false);
89 bool tokenize_punctuation = TryToGetAttributeWithDefault("tokenize_punctuation", false);
90 bool remove_control_chars = TryToGetAttributeWithDefault("remove_control_chars", true);
91
92 tokenizer_ = std::make_shared<BasicTokenizer>(do_lower_case, tokenize_chinese_chars, strip_accents, tokenize_punctuation, remove_control_chars);
93}
94
95void KernelBasicTokenizer::Compute(OrtKernelContext* context) {
96 // Setup inputs
97 const OrtValue* input = ort_.KernelContext_GetInput(context, 0);
98 std::vector<std::string> input_data;
99 GetTensorMutableDataString(api_, ort_, context, input, input_data);
100
101 OrtTensorDimensions dimensions(ort_, input);
102 if (dimensions.size() != 1 && dimensions[0] != 1) {
103 ORT_CXX_API_THROW("[BasicTokenizer]: only support string scalar.", ORT_INVALID_GRAPH);
104 }
105
106 OrtValue* output = ort_.KernelContext_GetOutput(context, 0, dimensions.data(), dimensions.size());
107 std::vector<ustring> result = tokenizer_->Tokenize(ustring(input_data[0]));
108
109 FillTensorDataString(api_, ort_, context, result, output);
110}
111
112void* CustomOpBasicTokenizer::CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const {
113 return CreateKernelImpl(api, info);
114};
115
116const char* CustomOpBasicTokenizer::GetName() const { return "BasicTokenizer"; };
117
118size_t CustomOpBasicTokenizer::GetInputTypeCount() const {
119 return 1;
120};
121
122ONNXTensorElementDataType CustomOpBasicTokenizer::GetInputType(size_t /*index*/) const {
123 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
124};
125
126size_t CustomOpBasicTokenizer::GetOutputTypeCount() const {
127 return 1;
128};
129
130ONNXTensorElementDataType CustomOpBasicTokenizer::GetOutputType(size_t /*index*/) const {
131 return ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING;
132};
133