microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.7

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer.hpp

120lines · modecode

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
3
4#pragma once
5
6#include "ocos.h"
7#include "ustring.h"
8#include "string_utils.h"
9#include "string_tensor.h"
10#include "basic_tokenizer.hpp"
11
12#include <unordered_map>
13
14
15class BertTokenizerVocab final {
16 public:
17 explicit BertTokenizerVocab(std::string_view vocab);
18 bool FindToken(const ustring& token);
19 bool FindTokenId(const ustring& token, int32_t& token_id);
20 int32_t FindTokenId(const ustring& token);
21
22 private:
23 std::string raw_vocab_;
24 std::unordered_map<std::string_view, int32_t> vocab_;
25};
26
// Selects how token-id sequences are shortened to a maximum length.
// The strategy is chosen once, by name, at construction time.
class TruncateStrategy final {
 public:
  // `strategy_name` picks one of the TruncateStrategyType values below; the
  // accepted spellings are defined by the implementation file.
  explicit TruncateStrategy(std::string_view strategy_name);

  // Single-sequence form: operates on `ids` in place with limit `max_len`.
  void Truncate(std::vector<int64_t>& ids, int32_t max_len);

  // Pair form: operates on both sequences in place with limit `max_len`.
  // Presumably `max_len` bounds the combined length -- verify in the
  // implementation.
  void Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2, int32_t max_len);

 private:
  // Available strategies; the selected one is stored in strategy_.
  enum TruncateStrategyType {
    LONGEST_FIRST,
    ONLY_FIRST,
    ONLY_SECOND,
    LONGEST_FROM_BACK
  } strategy_;
};
41
42// TODO: merge with the implementation of word piece tokenizer
43class WordpieceTokenizer final {
44 public:
45 WordpieceTokenizer(
46 std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
47 ustring suffix_indicator, int max_input_chars_per_word = 100);
48 std::vector<ustring> Tokenize(const ustring& text);
49 std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
50 std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
51
52 private:
53 int64_t max_input_chars_per_word_;
54 ustring suffix_indicator_;
55 ustring unk_token_;
56 int32_t unk_token_id_;
57 std::shared_ptr<BertTokenizerVocab> vocab_;
58
59 void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
60};
61
62class BertTokenizer final {
63 public:
64 BertTokenizer(const std::string& vocab, bool do_lower_case, bool do_basic_tokenize,
65 ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
66 ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
67 ustring suffix_indicator, int32_t max_len, const std::string& truncation_strategy);
68 std::vector<ustring> Tokenize(const ustring& text);
69 std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
70
71 void Truncate(std::vector<int64_t>& ids);
72 void Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2);
73
74 std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
75 std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
76 std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
77 std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
78
79 private:
80 int32_t unk_token_id_ = 0;
81 int32_t sep_token_id_ = 0;
82 int32_t pad_token_id_ = 0;
83 int32_t cls_token_id_ = 0;
84 int32_t mask_token_id_ = 0;
85 int32_t max_length_ = 0;
86 bool do_basic_tokenize_ = false;
87 std::unique_ptr<TruncateStrategy> truncate_;
88 std::shared_ptr<BertTokenizerVocab> vocab_;
89 std::unique_ptr<BasicTokenizer> basic_tokenizer_;
90 std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
91};
92
93struct KernelBertTokenizer : BaseKernel {
94 KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
95 void Compute(OrtKernelContext* context);
96
97 protected:
98 std::unique_ptr<BertTokenizer> tokenizer_;
99};
100
101struct CustomOpBertTokenizer : OrtW::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
102 const char* GetName() const;
103 size_t GetInputTypeCount() const;
104 ONNXTensorElementDataType GetInputType(size_t index) const;
105 size_t GetOutputTypeCount() const;
106 ONNXTensorElementDataType GetOutputType(size_t index) const;
107};
108
109struct KernelHfBertTokenizer : KernelBertTokenizer {
110 KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo& info);
111 void Compute(OrtKernelContext* context);
112};
113
114struct CustomOpHfBertTokenizer : OrtW::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
115 const char* GetName() const;
116 size_t GetInputTypeCount() const;
117 ONNXTensorElementDataType GetInputType(size_t index) const;
118 size_t GetOutputTypeCount() const;
119 ONNXTensorElementDataType GetOutputType(size_t index) const;
120};
121