microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
rel-0.4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

operators/tokenizer/bert_tokenizer.hpp

101lines · modecode

1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5
6#include <unordered_map>
7#include <vector>
8#include "ocos.h"
9#include "ustring.h"
10#include "string_utils.h"
11#include "string_tensor.h"
12#include "basic_tokenizer.hpp"
13
14class BertTokenizerVocab {
15 public:
16 explicit BertTokenizerVocab(std::string vocab);
17 bool FindToken(const ustring& token);
18 bool FindTokenId(const ustring& token, int32_t& token_id);
19 int32_t FindTokenId(const ustring& token);
20
21 private:
22 std::string raw_vocab_;
23 std::unordered_map<std::string_view, int32_t> vocab_;
24};
25
// Shortening policy for id sequences.
//
// An instance is constructed from a strategy name and keeps one
// TruncateStrategyType value in strategy_. How the name maps to an enumerator
// and what each enumerator does is defined in the implementation file, not in
// this header.
class TruncateStrategy {
 public:
  // strategy_name: textual name of the strategy.
  // NOTE(review): the accepted spellings are not visible here.
  explicit TruncateStrategy(std::string strategy_name);

  // Single-sequence form. `ids` is taken by non-const reference, so it may be
  // modified in place; presumably it is cut down to at most `max_len` entries
  // -- TODO confirm.
  void Truncate(std::vector<int64_t>& ids, int64_t max_len);

  // Two-sequence form. Both vectors are taken by non-const reference and may
  // be modified in place. Presumably `max_len` bounds the combined length and
  // strategy_ decides which input is shortened -- TODO confirm.
  void Truncate(std::vector<int64_t>& input1,
                std::vector<int64_t>& input2,
                int64_t max_len);

 private:
  // Unscoped on purpose: the enumerators are visible by their bare names
  // inside the class.
  enum TruncateStrategyType {
    LONGEST_FIRST,
    ONLY_FIRST,
    ONLY_SECOND,
    LONGEST_FROM_BACK
  };

  TruncateStrategyType strategy_;
};
40
41// TODO: merge with the implementation of word piece tokenizer
42class WordpieceTokenizer {
43 public:
44 WordpieceTokenizer(std::shared_ptr<BertTokenizerVocab> vocab, ustring unk_token, ustring suffix_indicator, int max_input_chars_per_word = 100);
45 std::vector<ustring> Tokenize(const ustring& text);
46 std::vector<ustring> Tokenize(const std::vector<ustring>& tokens);
47 std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
48
49 private:
50 int64_t max_input_chars_per_word_;
51 ustring suffix_indicator_;
52 ustring unk_token_;
53 int32_t unk_token_id_;
54 std::shared_ptr<BertTokenizerVocab> vocab_;
55
56 void GreedySearch(const ustring& token, std::vector<ustring>& tokenized_result);
57};
58
59class BertTokenizer {
60 public:
61 BertTokenizer(std::string vocab, bool do_lower_case, bool do_basic_tokenize,
62 ustring unk_token, ustring sep_token, ustring pad_token, ustring cls_token,
63 ustring mask_token, bool tokenize_chinese_chars, bool strip_accents,
64 ustring suffix_indicator);
65 std::vector<ustring> Tokenize(const ustring& text);
66 std::vector<int64_t> Encode(const std::vector<ustring>& tokens);
67 std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids);
68 std::vector<int64_t> AddSpecialToken(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
69 std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids);
70 std::vector<int64_t> GenerateTypeId(const std::vector<int64_t>& ids1, const std::vector<int64_t>& ids2);
71
72 private:
73 int32_t unk_token_id_;
74 int32_t sep_token_id_;
75 int32_t pad_token_id_;
76 int32_t cls_token_id_;
77 int32_t mask_token_id_;
78 bool do_basic_tokenize_;
79 std::shared_ptr<BertTokenizerVocab> vocab_;
80 std::shared_ptr<BasicTokenizer> basic_tokenizer_;
81 std::shared_ptr<WordpieceTokenizer> wordpiece_tokenizer_;
82};
83
84struct KernelBertTokenizer : BaseKernel {
85 KernelBertTokenizer(OrtApi api, const OrtKernelInfo* info);
86 void Compute(OrtKernelContext* context);
87
88 private:
89 std::shared_ptr<BertTokenizer> tokenizer_;
90 std::shared_ptr<TruncateStrategy> truncate_;
91 int max_length_;
92};
93
94struct CustomOpBertTokenizer : Ort::CustomOpBase<CustomOpBertTokenizer, KernelBertTokenizer> {
95 void* CreateKernel(OrtApi api, const OrtKernelInfo* info) const;
96 const char* GetName() const;
97 size_t GetInputTypeCount() const;
98 ONNXTensorElementDataType GetInputType(size_t index) const;
99 size_t GetOutputTypeCount() const;
100 ONNXTensorElementDataType GetOutputType(size_t index) const;
101};