microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
e0d48e255f28e5465f63e7fc141df1e1d533cc40

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/onnxprocess/_builder.py

52 lines · mode: code

1import json
2import pathlib
3from ._onnx_ops import make_model_ex
4from .._ortapi2 import SingleOpGraph, default_opset_domain, GPT2Tokenizer, VectorToString
5
6
def is_path(name_or_buffer):
    """Return True when *name_or_buffer* is a filesystem path rather than a buffer.

    Only ``str`` and ``pathlib.Path`` instances count as paths; anything else
    (e.g. a writable file-like object) is treated as a buffer by callers.
    """
    path_types = (str, pathlib.Path)
    return isinstance(name_or_buffer, path_types)
9
10
class _GPT2Tokenizer(GPT2Tokenizer):
    """GPT2Tokenizer op builder that takes its attributes from a tokenizer object."""

    @classmethod
    def serialize_attr(cls, kwargs):
        """Build the ``vocab`` and ``merges`` attributes from ``kwargs['model']``.

        :param kwargs: builder keyword arguments; must contain ``'model'``, an
            object exposing ``encoder`` (a JSON-serializable mapping, presumably
            token -> id — confirm against callers) and ``bpe_ranks`` (a mapping
            from a two-element merge pair to its integer rank).
        :return: dict with ``'vocab'`` (compact JSON of ``encoder``) and
            ``'merges'`` (newline-separated ``"first second"`` lines, ordered by
            ascending rank).
        :raises AssertionError: if ``'model'`` is missing from *kwargs*.
        """
        assert 'model' in kwargs, "Need model parameter to build the tokenizer"
        hf_gpt2_tokenizer = kwargs['model']
        # Compact separators keep the serialized vocab attribute as small as possible.
        attrs = {'vocab': json.dumps(hf_gpt2_tokenizer.encoder, separators=(',', ':'))}
        # Order merge pairs by rank directly. Indexing an inverted dict with
        # range(len(...)) raised KeyError for sparse / non-zero-based ranks and
        # silently dropped pairs that shared a rank. sorted() is stable, so tied
        # ranks keep their original insertion order.
        ranked_merges = sorted(hf_gpt2_tokenizer.bpe_ranks.items(), key=lambda item: item[1])
        attrs['merges'] = '\n'.join("{} {}".format(*pair) for pair, _rank in ranked_merges)
        return attrs
20
21
class _VectorToString(VectorToString):
    """VectorToString op builder that derives its ``map`` attribute from a decoder dict."""

    @classmethod
    def serialize_attr(cls, kwargs):
        """Invert ``kwargs['decoder']`` and hand it to the base serializer.

        Each decoder value becomes a key whose value is a one-element list holding
        the original decoder key; if several keys share a value, the last one wins.
        The unknown-token string is fixed to ``'<unknown>'``.

        :param kwargs: builder keyword arguments; must contain ``'decoder'``.
        :return: whatever ``VectorToString.serialize_attr`` produces for the
            ``map``/``unk`` attributes.
        :raises AssertionError: if ``'decoder'`` is missing from *kwargs*.
        """
        assert 'decoder' in kwargs, "Need decoder parameter to build the tokenizer"
        inverted = {}
        for key, value in kwargs['decoder'].items():
            inverted[value] = [key]
        return super().serialize_attr({'map': inverted, 'unk': '<unknown>'})
30
31
# Registry of local builder subclasses keyed by their custom op type name.
# build_customop_model prefers these over the default registered op class.
customop_mbuilder = {
    builder_cls.op_type(): builder_cls
    for builder_cls in (_GPT2Tokenizer, _VectorToString)
}
38
39
def build_customop_model(op_type, f, opset_version=11, **attrs):
    """Build a model containing a single custom op and write it to *f*.

    :param op_type: custom op type name. A builder from ``customop_mbuilder`` is
        used when one is registered for it; otherwise the class returned by
        ``SingleOpGraph.get_op_class`` is used.
    :param f: destination; either a ``str``/``pathlib.Path`` file path (opened
        and overwritten) or a writable binary file-like object (written to and
        flushed, but not closed).
    :param opset_version: opset version passed to ``make_model_ex``.
    :param attrs: keyword arguments forwarded to ``SingleOpGraph.build_my_graph``.
    """
    # The registered class is looked up first, as before; a local builder
    # subclass, when present, overrides it.
    default_class = SingleOpGraph.get_op_class(op_type)
    op_class = customop_mbuilder.get(op_type, default_class)

    graph = SingleOpGraph.build_my_graph(op_class, **attrs)
    model = make_model_ex(graph, [(default_opset_domain(), 1)], opset_version)
    # Serialize before touching the destination: opening a path with 'wb'
    # truncates it, so a serialization failure (e.g. the protobuf 2 GB limit)
    # must not leave an existing file emptied.
    payload = model.SerializeToString()

    if is_path(f):
        with open(f, 'wb') as out_file:
            out_file.write(payload)
    else:
        f.write(payload)
        f.flush()
53