microsoft/onnxruntime-extensions

Public

mirrored fromhttps://github.com/microsoft/onnxruntime-extensionsAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
367f59c6fad2e820b9b1bb5807065c5d6c3886d0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

onnxruntime_extensions/onnxprocess/_builder.py

53lines · modecode

1import json
2import pathlib
3from ._onnx_ops import make_model_ex
4from .._cuops import SingleOpGraph, GPT2Tokenizer, VectorToString
5from .._ortapi2 import default_opset_domain
6
7
def is_path(name_or_buffer):
    """Return True when *name_or_buffer* is a ``str`` or a ``pathlib.Path``.

    Anything else (for example an already-open file-like object) yields False.
    """
    return isinstance(name_or_buffer, (str, pathlib.Path))
10
11
class _GPT2Tokenizer(GPT2Tokenizer):
    """GPT2Tokenizer op whose attributes are derived from a tokenizer object."""

    @classmethod
    def serialize_attr(cls, kwargs):
        """Build the ``vocab`` and ``merges`` attributes from ``kwargs['model']``.

        The model object must expose ``encoder`` (dumped to compact JSON) and
        ``bpe_ranks`` (a mapping whose values are used as 0-based positions).
        NOTE(review): presumably a Hugging Face GPT-2 tokenizer -- confirm
        against callers.

        Returns a dict with the keys ``'vocab'`` and ``'merges'``.
        """
        assert 'model' in kwargs, "Need model parameter to build the tokenizer"
        tokenizer = kwargs['model']

        # Compact separators keep the serialized vocabulary free of padding.
        vocab_json = json.dumps(tokenizer.encoder, separators=(',', ':'))

        # Flip pair -> rank into rank -> pair so the pairs can be written out
        # in rank order, one per line.
        pair_by_rank = {}
        for pair, rank in tokenizer.bpe_ranks.items():
            pair_by_rank[rank] = pair

        merge_lines = []
        for rank in range(len(pair_by_rank)):
            merge_lines.append("{} {}".format(*pair_by_rank[rank]))

        return {'vocab': vocab_json, 'merges': '\n'.join(merge_lines)}
21
22
class _VectorToString(VectorToString):
    """VectorToString op whose mapping is derived from a decoder dict."""

    @classmethod
    def serialize_attr(cls, kwargs):
        """Turn ``kwargs['decoder']`` into the parent's ``map``/``unk`` attributes.

        The decoder mapping is inverted: every value becomes a key that maps
        to a one-element list holding its original key. The result, together
        with ``unk='<unknown>'``, is handed to the parent ``serialize_attr``.
        """
        assert 'decoder' in kwargs, "Need decoder parameter to build the tokenizer"

        inverted = {}
        for key, value in kwargs['decoder'].items():
            inverted[value] = [key]

        return super().serialize_attr({'map': inverted, 'unk': '<unknown>'})
31
32
# Override table: maps each class's ``op_type()`` to the local subclass that
# knows how to build its attributes from higher-level objects.
customop_mbuilder = dict(
    (op_cls.op_type(), op_cls)
    for op_cls in (_GPT2Tokenizer, _VectorToString)
)
39
40
def build_customop_model(op_type, f, opset_version=11, **attrs):
    """Build a single-op model for *op_type* and write its serialized bytes to *f*.

    :param op_type: op type name; resolved via ``SingleOpGraph.get_op_class``
        unless ``customop_mbuilder`` holds an override for it.
    :param f: a ``str``/``pathlib.Path`` (opened in binary write mode) or an
        object with ``write`` and ``flush`` methods.
    :param opset_version: forwarded to ``make_model_ex``.
    :param attrs: forwarded to ``SingleOpGraph.build_my_graph``.
    """
    # The default class is looked up first, then replaced when an override exists.
    op_class = customop_mbuilder.get(op_type, SingleOpGraph.get_op_class(op_type))

    graph = SingleOpGraph.build_my_graph(op_class, **attrs)
    model = make_model_ex(graph, [(default_opset_domain(), 1)], opset_version)

    if is_path(f):
        with open(f, 'wb') as sink:
            sink.write(model.SerializeToString())
        return

    # Caller-owned stream: write and flush, but leave closing to the caller.
    f.write(model.SerializeToString())
    f.flush()
54