microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
skottmckay/BuildInfra_AndTestImageLibs

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer.hpp

121 lines · mode: code

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include "ocos.h"
#include "ustring.h"
#include "string_utils.h"
#include "string_tensor.h"
#include "basic_tokenizer.hpp"
13
14class BertTokenizerVocab final {
15 public:
16 explicit BertTokenizerVocab(std::string_view vocab);
17 bool FindToken(const ustring& token);
18 bool FindTokenId(const ustring& token, int32_t& token_id);
19 int32_t FindTokenId(const ustring& token);
20
21 private:
22 std::string raw_vocab_;
23 std::unordered_map<std::string_view, int32_t> vocab_;
24};
25
/// Length-limiting policy for one id sequence or a pair of id sequences.
///
/// The constructor selects the policy from its name. The mapping from names to
/// policies lives in the out-of-line definition.
class TruncateStrategy final {
 public:
  explicit TruncateStrategy(std::string_view strategy_name);

  // Single-sequence form: shortens `ids` in place (presumably to at most `max_len`).
  void Truncate(std::vector<int64_t>& ids, int32_t max_len);

  // Pair form: shortens `ids1` and/or `ids2` in place according to the selected policy.
  void Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2, int32_t max_len);

 private:
  // Available policies; names mirror the enumerator identifiers.
  enum TruncateStrategyType {
    LONGEST_FIRST,
    ONLY_FIRST,
    ONLY_SECOND,
    LONGEST_FROM_BACK
  };

  // Policy chosen at construction time.
  TruncateStrategyType strategy_;
};
40
41// TODO: merge with the implementation of word piece tokenizer
42class WordpieceTokenizer final {
43 public:
44 WordpieceTokenizer(
45 std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token,
46 ustring suffix_indicator, int max_input_chars_per_word = 100);
47 std::vector<ustring> Tokenize(const ustring& text);
48 std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
49 std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
50
51 private:
52 int64_t max_input_chars_per_word_;
53 ustring suffix_indicator_;
54 ustring unk_token_;
55 int32_t unk_token_id_;
56 std::shared_ptr<BertTokenizerVocab> vocab_;
57
58 void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
59};
60
61class BertTokenizer final {
62 public:
63 BertTokenizer(const std::string& vocab, bool do_lower_case, bool do_basic_tokenize,
64 ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
65 ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
66 ustring suffix_indicator, int32_t max_len, const std::string& truncation_strategy);
67 std::vector<ustring> Tokenize(const ustring& text);
68 std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
69
70 void Truncate(std::vector<int64_t>& ids);
71 void Truncate(std::vector<int64_t>& ids1, std::vector<int64_t>& ids2);
72
73 std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
74 std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
75 std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
76 std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
77
78 private:
79 int32_t unk_token_id_ = 0;
80 int32_t sep_token_id_ = 0;
81 int32_t pad_token_id_ = 0;
82 int32_t cls_token_id_ = 0;
83 int32_t mask_token_id_ = 0;
84 int32_t max_length_ = 0;
85 bool do_basic_tokenize_ = false;
86 std::unique_ptr<TruncateStrategy> truncate_;
87 std::shared_ptr<BertTokenizerVocab> vocab_;
88 std::unique_ptr<BasicTokenizer> basic_tokenizer_;
89 std::unique_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
90};
91
92struct KernelBertTokenizer : BaseKernel {
93 KernelBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
94 void Compute(OrtKernelContext* context);
95
96 protected:
97 std::unique_ptr<BertTokenizer> tokenizer_;
98};
99
100struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
101 void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
102 const char* GetName() const;
103 size_t GetInputTypeCount() const;
104 ONNXTensorElementDataType GetInputType(size_t index) const;
105 size_t GetOutputTypeCount() const;
106 ONNXTensorElementDataType GetOutputType(size_t index) const;
107};
108
109struct KernelHfBertTokenizer : KernelBertTokenizer {
110 KernelHfBertTokenizer(const OrtApi& api, const OrtKernelInfo* info);
111 void Compute(OrtKernelContext* context);
112};
113
114struct CustomOpHfBertTokenizer : Ort::CustomOpBase<CustomOpHfBertTokenizer, KernelHfBertTokenizer> {
115 void* CreateKernel(const OrtApi& api, const OrtKernelInfo* info) const;
116 const char* GetName() const;
117 size_t GetInputTypeCount() const;
118 ONNXTensorElementDataType GetInputType(size_t index) const;
119 size_t GetOutputTypeCount() const;
120 ONNXTensorElementDataType GetOutputType(size_t index) const;
121};
122