microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
onnxruntime_extensions/cvt.py
65 lines · mode: code
| 1 | import json |
| 2 | from ._cuops import CustomOpConverter |
| 3 | |
| 4 | |
class HFTokenizerConverter(CustomOpConverter):
    """Build custom-op attribute dicts from a Hugging Face tokenizer object.

    Each public method reads tables from the wrapped tokenizer and returns a
    ``dict`` of attributes, with any caller-supplied ``**kwargs`` merged in.
    """

    def __init__(self, tokenizer):
        # NOTE(review): presumably a transformers "slow" BPE tokenizer exposing
        # encoder / bpe_ranks / decoder / byte_decoder; confirm against callers.
        self.tokenizer = tokenizer

    def _bpe_attrs(self, kwargs):
        """Return ``vocab`` and ``merges`` attributes for a BPE tokenizer.

        Shared by the GPT-2, CLIP and RoBERTa converters, which previously
        each carried an identical copy of this logic.

        Args:
            kwargs: extra attributes; they override the computed entries.

        Returns:
            dict with ``vocab`` (``tokenizer.encoder`` as compact JSON) and
            ``merges`` (one ``"<left> <right>"`` pair per line, in ascending
            rank order), updated with *kwargs*.
        """
        hf_tokenizer = self.tokenizer
        attrs = {'vocab': json.dumps(
            hf_tokenizer.encoder, separators=(',', ':'))}
        # Sort by rank instead of indexing ranks 0..len-1. The output is the
        # same when ranks are contiguous, and there is no KeyError when they
        # have gaps (e.g. a duplicated merge pair collapsed in the dict).
        ranked_merges = sorted(hf_tokenizer.bpe_ranks.items(),
                               key=lambda item: item[1])
        attrs['merges'] = '\n'.join(
            "{} {}".format(*pair) for pair, _rank in ranked_merges)
        attrs.update(kwargs)
        return attrs

    def bpe_tokenizer(self, **kwargs):
        """Return GPT-2 BPE tokenizer attributes (``vocab``, ``merges``)."""
        return self._bpe_attrs(kwargs)

    def bpe_decoder(self, **kwargs):
        """Return BPE decoder attributes merged into *kwargs*.

        Keys set on the returned dict:
          - ``id_vocab``: decoder tokens, one per line, in sorted-id order.
          - ``byte_decoder``: TAB-separated ``ord(char)`` / value lines taken
            from ``tokenizer.byte_decoder``.
          - ``added_tokens``: TAB-separated ``id`` / token lines taken from
            ``tokenizer.added_tokens_decoder``.
          - ``all_special_ids``: special-token ids, one per line.
          - ``skip_special_tokens``: the caller's value, or ``False``.
        """
        decoder = self.tokenizer.decoder
        # NOTE(review): the output is written line-by-line in sorted-id order;
        # presumably the consumer maps line index -> id, which only holds if
        # the ids are contiguous from 0. Confirm.
        id_vocab = "\n".join([decoder[_idx] for _idx in sorted(decoder)])
        byte_decoder = self.tokenizer.byte_decoder
        str_byte_decoder = "\n".join(["{}\t{}".format(
            ord(_c), str(byte_decoder[_c])) for _c in byte_decoder])
        all_special_ids = self.tokenizer.all_special_ids
        added_tokens = self.tokenizer.added_tokens_decoder
        str_all_special_ids = "\n".join([str(_id) for _id in all_special_ids])
        str_added_tokens = "\n".join(
            ["{}\t{}".format(str(_id), added_tokens[_id]) for _id in added_tokens])
        kwargs.update({
            "id_vocab": id_vocab,
            "byte_decoder": str_byte_decoder,
            "added_tokens": str_added_tokens,
            "all_special_ids": str_all_special_ids,
            "skip_special_tokens": kwargs.get("skip_special_tokens", False)
        })

        return kwargs

    def clip_tokenizer(self, **kwargs):
        """Return CLIP BPE tokenizer attributes (``vocab``, ``merges``)."""
        return self._bpe_attrs(kwargs)

    def roberta_tokenizer(self, **kwargs):
        """Return RoBERTa BPE tokenizer attributes (``vocab``, ``merges``)."""
        return self._bpe_attrs(kwargs)
| 66 | |