microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
352b2003bc6c4604f4285d64133d8bdf11549253

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tutorials/gpt2_e2e.py

125 lines · mode: code

1import os
2import numpy
3from transformers import AutoConfig
4from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
5from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model, build_customop_model
6
7
# Device used both for exporting the core model and for the traced inference.
device = 'cpu'
# Hugging Face model id for the tokenizer, config and pretrained weights.
model_name_or_path = 'gpt2'

# Output locations of the generated ONNX models; the './' paths resolve
# against the current working directory.
(
    gpt2_core_model_path,     # core GPT-2 LM network exported by Gpt2Helper
    gpt2_encoder_model_path,  # GPT2Tokenizer custom-op pre-processing model
    gpt2_decoder_model_path,  # VectorToString custom-op post-processing model
    gpt2_full_model_path,     # traced all-in-one model
) = ('./gpt2.onnx', './gpt2_tok.onnx', './gpt2_dec.onnx', './gpt2_full.onnx')

# Prompts pushed through the end-to-end pipeline, and the number of
# generation steps requested for them.
input_text = [
    'best hotel in bay area',
    'here is an example of gpt2 model',
]
num_tokens_to_produce = 10
18
19
def get_cache_directory():
    """Return the directory used to cache downloaded Hugging Face artifacts.

    The path is ``../../cache_models`` relative to the current working
    directory (not to this file). The directory, including any missing
    parents, is created if it does not exist yet.

    Returns:
        str: the cache directory path, exactly as built with ``os.path.join``.

    Raises:
        FileExistsError: if the path already exists but is not a directory.
    """
    cache_dir = os.path.join("../..", "cache_models")
    # exist_ok=True replaces the previous os.path.exists() + os.makedirs()
    # pair. That pair could raise FileExistsError when the directory was
    # created between the check and the call (e.g. two runs started at once).
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
25
26
def convert_models():
    """Build the three standalone ONNX models used by this tutorial.

    Writes, in this order:
      * ``gpt2_encoder_model_path``: a ``GPT2Tokenizer`` custom-op model
        built from the Hugging Face tokenizer;
      * ``gpt2_decoder_model_path``: a ``VectorToString`` custom-op model
        built from the tokenizer's ``decoder`` (id -> token) mapping;
      * ``gpt2_core_model_path``: the GPT-2 LM network exported with
        ``Gpt2Helper.export_onnx`` on ``device``.

    Pretrained files are read from / downloaded into ``get_cache_directory()``.
    """
    from transformers import GPT2Tokenizer  # noqa
    from onnxruntime.transformers.gpt2_helper import Gpt2Helper, MyGPT2LMHeadModel  # noqa

    hf_cache = get_cache_directory()

    # Pre- and post-processing models derived from the Hugging Face tokenizer.
    hf_tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=hf_cache)
    build_customop_model('GPT2Tokenizer', gpt2_encoder_model_path, model=hf_tokenizer)
    build_customop_model('VectorToString', gpt2_decoder_model_path, decoder=hf_tokenizer.decoder)

    # Core network. MyGPT2LMHeadModel comes from onnxruntime.transformers, so
    # the export does not depend on onnxruntime-tools.
    gpt2_config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=hf_cache)
    lm_model = MyGPT2LMHeadModel.from_pretrained(
        model_name_or_path, config=gpt2_config, cache_dir=hf_cache)
    Gpt2Helper.export_onnx(lm_model, device, gpt2_core_model_path)
40
41
def inference_and_dump_full_model(inputs):
    """Run greedy GPT-2 generation end to end while tracing it into one ONNX model.

    Inside a ``trace_for_onnx`` session this:
      1. tokenizes the prompts with the GPT2Tokenizer custom-op model
         (left padding);
      2. runs a traced loop (``torch.control_flow().loop``) that, at each step,
         calls the core GPT-2 model, picks the argmax token of the last
         position, and feeds the token, next position id, extended attention
         mask and returned past states back in;
      3. decodes the collected token ids with the VectorToString model;
    then saves the traced graph to ``gpt2_full_model_path``.

    Args:
        inputs: the prompts to trace with (the script passes ``input_text``,
            a list of strings). Inside the trace session the name is rebound
            to the traced inputs returned by ``tc_sess.get_inputs()``.

    Returns:
        The output of the decoder op (``text_out``), which is also the output
        recorded in the saved all-in-one model.
    """
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=get_cache_directory())
    core_model = pyfunc_from_model(gpt2_core_model_path)
    gpt2_tokenizer = pyfunc_from_model(gpt2_encoder_model_path)

    # The prompts and the token budget become the graph inputs of the traced model.
    with trace_for_onnx(inputs, num_tokens_to_produce, names=gpt2_tokenizer.input_names) as tc_sess:

        num_attention_heads = config.n_head
        hidden_size = config.n_embd
        num_layer = config.n_layer
        eos_token_id = config.eos_token_id

        inputs, num_tokens = tc_sess.get_inputs()
        input_ids, attention_mask = gpt2_tokenizer(inputs, padding=True, padding_side='left')
        attention_mask = attention_mask.type(torch.float)
        # Position ids count only real (mask == 1) tokens; left-padded slots get -1.
        position_ids = (attention_mask.long().cumsum(-1) - 1)
        # NOTE(review): the usual clamp of negative (padding) position ids is
        # disabled; presumably the in-place op is not supported by the tracer -- confirm.
        # position_ids.masked_fill_(position_ids < 0, 0)

        # Empty past state (sequence-length axis of size 0) for generating the first token.
        # The leading 2 presumably holds the key/value pair per layer -- confirm
        # against Gpt2Helper's export layout.
        batch_size = input_ids.size()[0]
        past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
        empty_past = []
        for _ in range(num_layer):
            empty_past.append(torch.empty(*past_shape).type(torch.float32).to(device))

        # Per-sequence "already produced EOS" flags.
        has_eos = torch.zeros(batch_size, dtype=torch.bool)

        all_eos = torch.tensor(False, dtype=torch.bool)
        past = empty_past
        cfg = torch.control_flow()
        # Loop-carried state: (continue flag, has_eos, input_ids, position_ids,
        # attention_mask, *past). num_tokens is presumably the maximum trip count
        # and ~all_eos the early-exit condition -- confirm against onnxprocess.
        # The extra .type(torch.float) is redundant (already float above) but harmless.
        for states in cfg.loop(num_tokens, ~all_eos, has_eos,
                               input_ids, position_ids, attention_mask.type(torch.float), *past):
            _, has_eos, input_ids, position_ids, attention_mask, *past = states
            outputs = core_model(input_ids, position_ids, attention_mask, *past)
            # Greedy decoding: argmax over the logits of the last position.
            next_token_logits = outputs[0][:, -1, :]
            next_tokens = torch.argmax(next_token_logits, dim=-1)

            # Once a sequence has produced EOS, every later token for it is forced to EOS.
            has_eos = has_eos | (next_tokens == eos_token_id)
            tokens_to_add = next_tokens.masked_fill(has_eos, eos_token_id)

            # Update input_ids, attention_mask, position_ids and past
            # NOTE(review): this hard-coded value overrides the batch_size computed
            # from input_ids above, so the traced graph only handles exactly two
            # prompts (input_text has two). reshape([-1, 1]) would avoid it for the
            # two reshapes; verify the tracer accepts that before changing.
            batch_size = 2
            input_ids = tokens_to_add.clone().detach().reshape([batch_size, 1]).to(device)
            position_ids = (position_ids[:, -1] + 1).reshape([batch_size, 1])
            attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1).type(torch.float)], 1).to(device)

            all_eos = torch.all(has_eos)

            # The core model may return numpy arrays or tensors for the present
            # states; normalize both to tensors on `device` for the next step.
            past = []
            for i in range(num_layer):
                past_i = torch.from_numpy(outputs[i + 1]) if \
                    isinstance(outputs[i + 1], numpy.ndarray) else outputs[i + 1].clone().detach()
                past.append(past_i.to(device))

            # New loop-carried state, followed by tokens_to_add as the last
            # (per-iteration) output.
            cfg.flow_output(~all_eos, has_eos, input_ids, position_ids, attention_mask, *past, tokens_to_add)

        # The last loop output holds the tokens collected over the iterations;
        # presumably stacked as (iterations, batch) -- the transpose below relies on it.
        *_, all_token_ids = cfg.finalize()
        gpt2_decoder = torch.op_from_model(gpt2_decoder_model_path)
        # NOTE(review): `[0]` keeps only the first sequence of the batch, so only
        # the first prompt's generated tokens are decoded.
        text_out = gpt2_decoder(all_token_ids.transpose(0, 1)[0].unsqueeze(-1))
        tc_sess.save_as_onnx(gpt2_full_model_path, text_out)
        return text_out
105
106
# 1. Export the tokenizer, decoder and core GPT-2 models unless all three
#    ONNX files are already on disk.
if not all(os.path.exists(path) for path in (gpt2_core_model_path,
                                             gpt2_decoder_model_path,
                                             gpt2_encoder_model_path)):
    convert_models()


# 2. Run generation once with pre/post-processing while tracing it, and save
#    the traced graph as the all-in-one ONNX model.
output_ms = inference_and_dump_full_model(input_text)

# 3. Run the same prompts through the saved all-in-one model.
from onnxruntime_extensions import PyOrtFunction
full_model = PyOrtFunction.from_model(gpt2_full_model_path)
output_text = full_model(input_text, num_tokens_to_produce)

# 4. The all-in-one model is expected to reproduce the traced run exactly;
#    warn (rather than fail) when it does not.
expected_text = output_ms.numpy()
if not numpy.array_equal(expected_text, output_text):
    import warnings  # noqa
    warnings.warn("{}: The all-in-one model output is not the same\t{}".format(expected_text, output_text))
126