microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
tutorials/gpt2bs.py
138 lines · mode: code
| 1 | import os |
| 2 | from transformers import AutoConfig |
| 3 | from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model, build_customop_model |
| 4 | from onnxruntime_extensions.onnxprocess import torch_wrapper as torch |
| 5 | |
# Create a cache directory to store the pretrained model.
# `exist_ok=True` makes this one idempotent call: there is no window between
# an "exists?" check and the creation in which another process could create
# the directory and make `makedirs` fail.
cache_dir = os.path.join(".", "cache_models")
os.makedirs(cache_dir, exist_ok=True)
| 10 | |
| 11 | |
# Model selection and beam-search settings used by the functions below.
model_name_or_path = "gpt2"
device = "cpu"
beam_width = 4

# Read the architecture sizes from the pretrained configuration; they are used
# to build the past-state shapes in inference_and_dump_full_model.
config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
num_attention_heads, hidden_size, num_layer = config.n_head, config.n_embd, config.n_layer

# Where the traced end-to-end model is written (and later re-loaded from).
gpt2_full_model_path = "./gpt2_full.onnx"

# The one-step search model must already exist on disk; it was generated by this notebook:
# https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2-OneStepSearch_OnnxRuntime_CPU.ipynb
onnx_model_path = "gpt2_one_step_search.onnx"
func_one_step = pyfunc_from_model(onnx_model_path)
| 25 | |
| 26 | |
def get_tokenizer(model_name_or_path, cache_dir):
    """Load a GPT-2 tokenizer and build an ONNX callable from it.

    The tokenizer is configured to pad on the left and to reuse its EOS token
    as the pad token. It is then exported as a 'GPT2Tokenizer' custom-op model
    to './gpt2_tok.onnx', and that file is loaded back as a callable.

    Args:
        model_name_or_path: Name or path forwarded to
            ``GPT2Tokenizer.from_pretrained``.
        cache_dir: Cache directory forwarded to
            ``GPT2Tokenizer.from_pretrained``.

    Returns:
        A ``(tokenizer, func_tokenizer)`` pair: the configured tokenizer
        object and the callable loaded from the exported ONNX model.
    """
    from transformers import GPT2Tokenizer  # noqa

    hf_tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    hf_tokenizer.padding_side = "left"
    hf_tokenizer.pad_token = hf_tokenizer.eos_token

    # Export the tokenizer as an ONNX custom-op model, then load it back.
    encoder_onnx_path = './gpt2_tok.onnx'
    build_customop_model('GPT2Tokenizer', encoder_onnx_path, model=hf_tokenizer)
    onnx_tokenizer = pyfunc_from_model(encoder_onnx_path)
    return hf_tokenizer, onnx_tokenizer
| 36 | |
| 37 | |
def inference_and_dump_full_model(tokenizer, func_tokenizer, input_text, num_tokens_to_produce = 30):
    """Trace tokenizer + one-step search loop, save it as ONNX, and print the result.

    Inside a ``trace_for_onnx`` session the function tokenizes the traced
    input, builds the initial loop state, and repeatedly calls the module-level
    ``func_one_step`` through ``torch.control_flow().loop``. The selected loop
    result is saved to ``gpt2_full_model_path`` and its first row is decoded
    and printed.

    Args:
        tokenizer: Object whose ``decode`` method turns the resulting token ids
            back into text for the final ``print``.
        func_tokenizer: Callable tokenizer model; its ``input_names`` name the
            trace inputs and it is called on the traced inputs.
        input_text: Input passed to ``trace_for_onnx`` as the first traced value.
        num_tokens_to_produce: Passed to ``trace_for_onnx`` as the second traced
            value; it comes back as ``num_tokens`` and is the first argument of
            ``cfg.loop``.

    Returns:
        None. The side effects are the ONNX file written to
        ``gpt2_full_model_path`` and one line printed to stdout.
    """

    with trace_for_onnx(input_text, num_tokens_to_produce, names=func_tokenizer.input_names) as tc_sess:
        # The session hands back its two traced inputs (text and token count).
        inputs, num_tokens = tc_sess.get_inputs()
        input_ids, attention_mask = func_tokenizer(inputs, padding=True, padding_side='left')
        attention_mask = attention_mask.type(torch.float)
        # Running count of the mask minus one, used as the initial position ids.
        position_ids = (attention_mask.long().cumsum(-1) - 1)
        # position_ids.masked_fill_(position_ids < 0, 0)

        # Empty Past State for generating first word
        # batch_size = input_ids.size()[0]
        # NOTE: the batch size is fixed to 1 here; the dynamic version above is disabled.
        batch_size = 1
        # The 4th dimension is 0, so each initial past tensor holds no elements.
        past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
        empty_past = []
        for _ in range(num_layer):
            empty_past.append(torch.empty(*past_shape).type(torch.float32).to(device))

        # Initial values of the remaining loop-carried state.
        beam_select_idx = torch.zeros([1, batch_size]).long()
        input_log_probs = torch.zeros([batch_size, 1])
        input_unfinished_sents = torch.ones([batch_size, 1], dtype=torch.bool)
        prev_step_scores = torch.zeros([batch_size, 1])
        beam_size = beam_width
        prev_step_results = input_ids.clone().detach().to(device)

        cfg = torch.control_flow()
        # `states[0]` is read as the step counter below; `states[1:]` line up with the
        # values passed to `cfg.loop` after the initial condition (input_ids,
        # position_ids, attention_mask, ...) -- presumably in that order; the uses of
        # states[3] and states[7] below are consistent with it.
        for states in cfg.loop(num_tokens, torch.tensor(True), input_ids, position_ids,
                               attention_mask, beam_select_idx, input_log_probs,
                               input_unfinished_sents, prev_step_results, prev_step_scores, *empty_past):
            step = states[0]
            # Attach symbolic dimension names to the loop inputs.
            states[1].symbolic_shape = ['batch_size', 'seq_len']
            states[2].symbolic_shape = ['batch_size', 'seq_len']
            states[3].symbolic_shape = ['batch_size', 'all_seq_len']
            states[4].symbolic_shape = [1, 'batch_size']

            # prev_step_results
            states[7].symbolic_shape = ['batch_size', 'total_seq_len']

            # The trailing `num_layer` states are the per-layer past tensors.
            for st_ in states[-num_layer:]:
                st_.symbolic_shape = [2, 'batch_size', num_attention_heads, 'past_seq_len', hidden_size // num_attention_heads]

            prev_attention_mask = states[3]
            # Run one search step on everything except the step counter.
            outputs = func_one_step(*states[1:])
            # outputs[0] is reshaped to one row per beam and becomes the next input_ids.
            last_state = outputs[0].clone().detach().cpu()
            input_ids = last_state.reshape([batch_size * beam_size, -1]).to(device)

            # Outputs are read from the end: [-1] scores, [-2] results,
            # [-3] unfinished flags, [-4] log probs, [-5] beam selection index.
            input_unfinished_sents_id = -3
            prev_step_results = outputs[-2].clone().detach().to(device)
            # position_ids = (torch.tensor([context_length + step - 1
            # ]).unsqueeze(0).repeat(batch_size * beam_size, 1).to(device))
            # NOTE(review): this reads the name `attention_mask` as bound before this
            # line, not `prev_attention_mask` (states[3]) -- confirm that is intended.
            position_ids = torch.zeros([batch_size * beam_size, 1], dtype=torch.int64) + attention_mask.size()[-1]
            # `factor` is 1 when `step` is 0 and 0 otherwise, so the mask rows are
            # repeated `batch_size * beam_size` times on step 0 and left as-is
            # (repeat count 1) on every later step.
            factor = (~step.type(torch.bool)).type(torch.int64)
            prev_attention_mask = prev_attention_mask.repeat(factor * (batch_size * beam_size - 1) + 1, 1).to(device)
            # Extend the mask by one column of ones along dimension 1.
            attention_mask = torch.cat(
                [
                    prev_attention_mask,
                    torch.ones([batch_size * beam_size, 1], dtype=torch.float),
                ],
                1,
            ).to(device)

            beam_select_idx = outputs[input_unfinished_sents_id - 2].clone().detach().to(device)
            input_log_probs = outputs[input_unfinished_sents_id - 1].clone().detach().to(device)
            input_unfinished_sents = outputs[input_unfinished_sents_id].clone().detach().to(device)
            prev_step_scores = outputs[-1].clone().detach().to(device)

            # outputs[1 .. num_layer] are collected as the next past states.
            past = []
            for i in range(num_layer):
                past_i = outputs[i + 1].clone().detach()
                past.append(past_i.to(device))


            # Loop condition for the next iteration: true while any flag is still set.
            any_unfinished = input_unfinished_sents.any()
            # Attach symbolic dimension names to the loop outputs.
            input_ids.symbolic_shape = ['total_batch_size', 'seq_len']
            position_ids.symbolic_shape = ['total_batch_size', 'seq_len']
            attention_mask.symbolic_shape = ['total_batch_size', 'all_seq_len']
            prev_step_results.symbolic_shape = ['total_batch_size', 'step_seq_len']
            for st_ in past:
                st_.symbolic_shape = [2, 'total_batch_size', num_attention_heads, 'all_seq_len', hidden_size // num_attention_heads]
            # Hand the condition and the updated state back to the loop, in the same
            # order as the values passed to `cfg.loop`.
            cfg.flow_output(any_unfinished, input_ids,
                            position_ids, attention_mask, beam_select_idx,
                            input_log_probs, input_unfinished_sents, prev_step_results, prev_step_scores, *past)

        # Presumably index 6 selects `prev_step_results` (7th loop-carried value when
        # the condition is not counted) -- TODO confirm against `finalize`.
        result_id = 6
        all_token_ids = cfg.finalize()[result_id]
        tc_sess.save_as_onnx(gpt2_full_model_path, all_token_ids)

    # Decode and show the first row of the traced result.
    print(tokenizer.decode(all_token_ids.t[0], skip_special_tokens=True))
| 125 | |
| 126 | |
def verify_bsfull_model(input_text, num_tokens_to_produce=30, decode_tokenizer=None):
    """Run the exported end-to-end model and print its first decoded output.

    Loads the model saved at ``gpt2_full_model_path``, calls it with the input
    text and the requested token count, and prints ``outputs[0]`` decoded with
    special tokens skipped.

    Args:
        input_text: Input passed as the first argument of the loaded model.
        num_tokens_to_produce: Token count passed as the second argument of the
            loaded model. Defaults to 30, the value that was previously
            hard-coded, and mirrors the parameter of
            ``inference_and_dump_full_model``.
        decode_tokenizer: Object whose ``decode`` method is used for printing.
            When ``None``, the module-level ``tokenizer`` is used; that name is
            only assigned when this file runs as a script, so callers importing
            this module should pass a tokenizer explicitly.

    Raises:
        NameError: If ``decode_tokenizer`` is ``None`` and no module-level
            ``tokenizer`` exists.
    """
    from onnxruntime_extensions import PyOrtFunction

    if decode_tokenizer is None:
        # Backward-compatible fallback: use the global published by the
        # `__main__` block. Resolved up front so a missing tokenizer fails
        # before the model is loaded and run.
        decode_tokenizer = tokenizer

    gpt2_all = PyOrtFunction.from_model(gpt2_full_model_path)
    outputs = gpt2_all(input_text, num_tokens_to_produce)
    print(decode_tokenizer.decode(outputs[0], skip_special_tokens=True))
| 132 | |
| 133 | |
if __name__ == "__main__":
    # `tokenizer` has to remain a module-level name: verify_bsfull_model reads
    # it as a global when decoding its output.
    tokenizer, func_tokenizer = get_tokenizer(model_name_or_path=model_name_or_path,
                                              cache_dir=cache_dir)

    # Single-prompt batch used both for tracing/export and for verification.
    input_text = ['best hotel in bay area.']

    # 1) trace and export the full model, 2) reload it and run it again.
    inference_and_dump_full_model(tokenizer, func_tokenizer, input_text)
    verify_bsfull_model(input_text)