microsoft/onnxruntime-extensions

Public

mirrored from https://github.com/microsoft/onnxruntime-extensions Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
352b2003bc6c4604f4285d64133d8bdf11549253

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tutorials/gpt2bs.py

138 lines · mode: code

1import os
2from transformers import AutoConfig
3from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model, build_customop_model
4from onnxruntime_extensions.onnxprocess import torch_wrapper as torch
5
# Create a cache directory to store the pretrained model.
cache_dir = os.path.join(".", "cache_models")
# exist_ok=True makes the creation idempotent and avoids the check-then-create
# race of testing os.path.exists() before calling os.makedirs().
os.makedirs(cache_dir, exist_ok=True)


model_name_or_path = "gpt2"
device = "cpu"
beam_width = 4
# Model dimensions are read from the pretrained configuration.
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
num_attention_heads = config.n_head
hidden_size = config.n_embd
num_layer = config.n_layer
# Path that the traced full model is saved to.
gpt2_full_model_path = "./gpt2_full.onnx"

# The one-step search model is loaded from disk here; it was generated by this notebook:
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2-OneStepSearch_OnnxRuntime_CPU.ipynb
onnx_model_path = "gpt2_one_step_search.onnx"
func_one_step = pyfunc_from_model(onnx_model_path)
25
26
def get_tokenizer(model_name_or_path, cache_dir):
    """Load the GPT-2 tokenizer and build an ONNX custom-op model from it.

    The tokenizer is configured for left padding with the EOS token as the
    pad token, a 'GPT2Tokenizer' custom-op model is written to
    './gpt2_tok.onnx', and that model is loaded back as a callable.

    Returns:
        A ``(tokenizer, onnx_tokenizer)`` tuple: the configured GPT2Tokenizer
        and the function created from the saved ONNX tokenizer model.
    """
    from transformers import GPT2Tokenizer  # noqa

    hf_tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    hf_tokenizer.padding_side = "left"
    hf_tokenizer.pad_token = hf_tokenizer.eos_token

    # Export the tokenizer as an ONNX custom-op model, then load it back.
    onnx_tok_path = './gpt2_tok.onnx'
    build_customop_model('GPT2Tokenizer', onnx_tok_path, model=hf_tokenizer)
    onnx_tokenizer = pyfunc_from_model(onnx_tok_path)
    return hf_tokenizer, onnx_tokenizer
36
37
def inference_and_dump_full_model(tokenizer, func_tokenizer, input_text, num_tokens_to_produce = 30):
    """Trace tokenization plus the one-step search loop and save it as one ONNX model.

    Inside a ``trace_for_onnx`` session the function tokenizes the input,
    builds the initial loop state, repeatedly calls the module-level
    ``func_one_step`` inside a ``torch.control_flow()`` loop, and saves the
    result to ``gpt2_full_model_path``. Finally it prints the decoded tokens.

    Note that ``torch`` in this file is ``onnxruntime_extensions.onnxprocess``'s
    ``torch_wrapper``, not the plain PyTorch module.

    Args:
        tokenizer: object whose ``decode`` is used to print the generated ids.
        func_tokenizer: callable tokenizer model; its ``input_names`` name the
            traced inputs and it is called to produce ``input_ids`` and
            ``attention_mask``.
        input_text: text input handed to the trace session.
        num_tokens_to_produce: second traced input; it is passed to
            ``cfg.loop`` as its first argument (presumably the maximum
            iteration count -- TODO confirm against the control_flow API).
    """

    with trace_for_onnx(input_text, num_tokens_to_produce, names=func_tokenizer.input_names) as tc_sess:
        # The session hands back its own versions of the two arguments given above.
        inputs, num_tokens = tc_sess.get_inputs()
        input_ids, attention_mask = func_tokenizer(inputs, padding=True, padding_side='left')
        attention_mask = attention_mask.type(torch.float)
        # Running count of the mask along the last axis, minus one.
        position_ids = (attention_mask.long().cumsum(-1) - 1)
        # position_ids.masked_fill_(position_ids < 0, 0)

        # Empty Past State for generating first word
        # batch_size = input_ids.size()[0]
        # NOTE(review): batch size is fixed to 1 here rather than read from input_ids.
        batch_size = 1
        # The fourth dimension is 0, i.e. the initial past tensors hold no elements.
        past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
        empty_past = []
        for _ in range(num_layer):
            empty_past.append(torch.empty(*past_shape).type(torch.float32).to(device))

        # Initial values for the remaining loop-carried state.
        beam_select_idx = torch.zeros([1, batch_size]).long()
        input_log_probs = torch.zeros([batch_size, 1])
        input_unfinished_sents = torch.ones([batch_size, 1], dtype=torch.bool)
        prev_step_scores = torch.zeros([batch_size, 1])
        beam_size = beam_width
        prev_step_results = input_ids.clone().detach().to(device)

        cfg = torch.control_flow()
        # states[0] is the step value; states[1:] follow the order of the
        # values passed after torch.tensor(True): input_ids, position_ids,
        # attention_mask, beam_select_idx, input_log_probs,
        # input_unfinished_sents, prev_step_results, prev_step_scores, then
        # one past tensor per layer.
        for states in cfg.loop(num_tokens, torch.tensor(True), input_ids, position_ids,
                               attention_mask, beam_select_idx, input_log_probs,
                               input_unfinished_sents, prev_step_results, prev_step_scores, *empty_past):
            step = states[0]
            # Attach symbolic dimension names to the loop inputs.
            states[1].symbolic_shape = ['batch_size', 'seq_len']
            states[2].symbolic_shape = ['batch_size', 'seq_len']
            states[3].symbolic_shape = ['batch_size', 'all_seq_len']
            states[4].symbolic_shape = [1, 'batch_size']

            # prev_step_results
            states[7].symbolic_shape = ['batch_size', 'total_seq_len']

            # The trailing num_layer entries are the per-layer past tensors.
            for st_ in states[-num_layer:]:
                st_.symbolic_shape = [2, 'batch_size', num_attention_heads, 'past_seq_len', hidden_size // num_attention_heads]

            prev_attention_mask = states[3]
            # Run the one-step model on everything except the step value.
            outputs = func_one_step(*states[1:])
            last_state = outputs[0].clone().detach().cpu()
            input_ids = last_state.reshape([batch_size * beam_size, -1]).to(device)

            # Negative index into `outputs`; the two entries before it are read
            # as beam_select_idx and input_log_probs further down.
            input_unfinished_sents_id = -3
            prev_step_results = outputs[-2].clone().detach().to(device)
            # position_ids = (torch.tensor([context_length + step - 1
            # ]).unsqueeze(0).repeat(batch_size * beam_size, 1).to(device))
            # NOTE(review): on the first pass through this body `attention_mask`
            # is the tensor bound before the loop (it is only rebound a few
            # lines below), not states[3] -- confirm this is intended.
            position_ids = torch.zeros([batch_size * beam_size, 1], dtype=torch.int64) + attention_mask.size()[-1]
            # factor is 1 when `step` converts to False (i.e. step == 0) and 0
            # otherwise, so the mask rows are repeated batch_size * beam_size
            # times on step 0 and exactly once on every other step.
            factor = (~step.type(torch.bool)).type(torch.int64)
            prev_attention_mask = prev_attention_mask.repeat(factor * (batch_size * beam_size - 1) + 1, 1).to(device)
            # Extend the mask by one column of ones along dimension 1.
            attention_mask = torch.cat(
                [
                    prev_attention_mask,
                    torch.ones([batch_size * beam_size, 1], dtype=torch.float),
                ],
                1,
            ).to(device)

            # outputs[-5], outputs[-4] and outputs[-3] respectively.
            beam_select_idx = outputs[input_unfinished_sents_id - 2].clone().detach().to(device)
            input_log_probs = outputs[input_unfinished_sents_id - 1].clone().detach().to(device)
            input_unfinished_sents = outputs[input_unfinished_sents_id].clone().detach().to(device)
            prev_step_scores = outputs[-1].clone().detach().to(device)

            # outputs[1] .. outputs[num_layer] become the next past tensors.
            past = []
            for i in range(num_layer):
                past_i = outputs[i + 1].clone().detach()
                past.append(past_i.to(device))


            # Value handed to flow_output first (presumably the loop's
            # continue condition -- TODO confirm).
            any_unfinished = input_unfinished_sents.any()
            # Attach symbolic dimension names to the loop outputs.
            input_ids.symbolic_shape = ['total_batch_size', 'seq_len']
            position_ids.symbolic_shape = ['total_batch_size', 'seq_len']
            attention_mask.symbolic_shape = ['total_batch_size', 'all_seq_len']
            prev_step_results.symbolic_shape = ['total_batch_size', 'step_seq_len']
            for st_ in past:
                st_.symbolic_shape = [2, 'total_batch_size', num_attention_heads, 'all_seq_len', hidden_size // num_attention_heads]
            # Same ordering as the values given to cfg.loop above.
            cfg.flow_output(any_unfinished, input_ids,
                            position_ids, attention_mask, beam_select_idx,
                            input_log_probs, input_unfinished_sents, prev_step_results, prev_step_scores, *past)

        # Index into the finalized loop results; presumably selects
        # prev_step_results (position 6 when counting from input_ids = 0) --
        # TODO confirm against cfg.finalize().
        result_id = 6
        all_token_ids = cfg.finalize()[result_id]
        tc_sess.save_as_onnx(gpt2_full_model_path, all_token_ids)

    # `.t` is indexed here to get the first row of token ids for decoding.
    print(tokenizer.decode(all_token_ids.t[0], skip_special_tokens=True))
125
126
def verify_bsfull_model(input_text, tokenizer=None, num_tokens_to_produce=30):
    """Run the saved full ONNX model on *input_text* and print the decoded text.

    Args:
        input_text: text input passed to the model loaded from
            ``gpt2_full_model_path``.
        tokenizer: object whose ``decode`` turns the model's first output into
            text. When ``None``, a GPT-2 tokenizer is loaded from
            ``model_name_or_path`` / ``cache_dir`` so that the function does
            not rely on a module-level ``tokenizer`` that only exists when
            this file is executed as a script.
        num_tokens_to_produce: second model input; the default of 30 is the
            value that was previously hard-coded.
    """
    from onnxruntime_extensions import PyOrtFunction

    if tokenizer is None:
        from transformers import GPT2Tokenizer  # noqa
        tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)

    gpt2_all = PyOrtFunction.from_model(gpt2_full_model_path)
    outputs = gpt2_all(input_text, num_tokens_to_produce)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
132
133
if __name__ == "__main__":
    # Script entry point: build the tokenizers, trace and save the full model
    # for a sample prompt, then run the saved model on the same prompt.
    input_text = ["best hotel in bay area."]
    tokenizer, func_tokenizer = get_tokenizer(model_name_or_path, cache_dir)
    inference_and_dump_full_model(tokenizer, func_tokenizer, input_text)
    verify_bsfull_model(input_text)