microsoft/onnxruntime-extensions
Publicmirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable
onnxruntime_extensions/onnxprocess/_builder.py
53lines · modecode
| 1 | import json |
| 2 | import pathlib |
| 3 | from ._onnx_ops import make_model_ex |
| 4 | from .._cuops import SingleOpGraph, GPT2Tokenizer, VectorToString |
| 5 | from .._ortapi2 import default_opset_domain |
| 6 | |
| 7 | |
def is_path(name_or_buffer):
    """Return True when *name_or_buffer* is a ``str`` or ``pathlib.Path``.

    Anything else (for example an already opened file object) yields False.
    """
    path_types = (str, pathlib.Path)
    return isinstance(name_or_buffer, path_types)
| 10 | |
| 11 | |
class _GPT2Tokenizer(GPT2Tokenizer):
    """GPT2Tokenizer variant whose attributes are derived from a tokenizer object."""

    @classmethod
    def serialize_attr(cls, kwargs):
        """Build the ``vocab`` and ``merges`` attributes from ``kwargs['model']``.

        The model object must expose ``encoder`` (dumped as compact JSON) and
        ``bpe_ranks`` (a mapping whose keys are formatted as ``"{} {}"`` lines).

        Returns a dict with the keys ``'vocab'`` and ``'merges'``.
        """
        assert 'model' in kwargs, "Need model parameter to build the tokenizer"
        tokenizer = kwargs['model']

        vocab_json = json.dumps(tokenizer.encoder, separators=(',', ':'))

        # Invert bpe_ranks so that each rank points back at its merge pair.
        pair_by_rank = {}
        for pair, rank in tokenizer.bpe_ranks.items():
            pair_by_rank[rank] = pair

        # Emit one line per rank, in rank order 0..n-1; a missing rank in that
        # range raises KeyError, so ranks are expected to be contiguous.
        merge_lines = []
        for rank in range(len(pair_by_rank)):
            merge_lines.append("{} {}".format(*pair_by_rank[rank]))

        return {'vocab': vocab_json, 'merges': '\n'.join(merge_lines)}
| 21 | |
| 22 | |
class _VectorToString(VectorToString):
    """VectorToString variant whose attributes are derived from a decoder mapping."""

    @classmethod
    def serialize_attr(cls, kwargs):
        """Build the op attributes from ``kwargs['decoder']``.

        The decoder mapping is inverted: each of its values becomes a key that
        maps to a one-element list holding the original key. The inverted
        mapping and the fixed ``unk`` string are then handed to the parent
        class's ``serialize_attr``, whose result is returned.
        """
        assert 'decoder' in kwargs, "Need decoder parameter to build the tokenizer"
        decoder_map = kwargs['decoder']

        # If two entries share a value, the later one wins (plain dict assignment).
        inverted = {}
        for key, value in decoder_map.items():
            inverted[value] = [key]

        return super().serialize_attr({'map': inverted, 'unk': '<unknown>'})
| 31 | |
| 32 | |
# Registry keyed by op type name (as reported by each class's ``op_type()``),
# holding the builder classes that override the default op class lookup in
# ``build_customop_model``.
customop_mbuilder = dict(
    (builder_cls.op_type(), builder_cls)
    for builder_cls in (_GPT2Tokenizer, _VectorToString)
)
| 39 | |
| 40 | |
def build_customop_model(op_type, f, opset_version=11, **attrs):
    """Build a single-op model for *op_type* and write its serialized bytes to *f*.

    :param op_type: name of the custom op; entries in ``customop_mbuilder``
        take precedence over the class returned by ``SingleOpGraph.get_op_class``.
    :param f: destination, either a path (``str`` / ``pathlib.Path``), which is
        opened in binary write mode, or an object providing ``write`` and ``flush``.
    :param opset_version: forwarded to ``make_model_ex``.
    :param attrs: keyword arguments forwarded to ``SingleOpGraph.build_my_graph``.
    """
    # The default class is always looked up first, even if it is overridden below.
    default_class = SingleOpGraph.get_op_class(op_type)
    op_class = customop_mbuilder.get(op_type, default_class)

    graph = SingleOpGraph.build_my_graph(op_class, **attrs)
    model = make_model_ex(graph, [(default_opset_domain(), 1)], opset_version)

    if not is_path(f):
        # Caller owns the file object: write and flush, but do not close it.
        f.write(model.SerializeToString())
        f.flush()
        return

    with open(f, 'wb') as out_file:
        out_file.write(model.SerializeToString())
| 54 | |