microsoft/onnxruntime-extensions
Public · mirrored from https://github.com/microsoft/onnxruntime-extensions · Available
tutorials/gpt2_e2e.py
125 lines · mode: code
| 1 | import os |
| 2 | import numpy |
| 3 | from transformers import AutoConfig |
| 4 | from onnxruntime_extensions.onnxprocess import torch_wrapper as torch |
| 5 | from onnxruntime_extensions.onnxprocess import trace_for_onnx, pyfunc_from_model, build_customop_model |
| 6 | |
| 7 | |
# Device string handed to tensor `.to(...)` calls and to the ONNX export helper.
device = 'cpu'
# Model identifier passed to the `from_pretrained` loaders below.
model_name_or_path = 'gpt2'
# File paths of the individual ONNX models and of the combined all-in-one model.
gpt2_core_model_path = './gpt2.onnx'
gpt2_encoder_model_path = './gpt2_tok.onnx'
gpt2_decoder_model_path = './gpt2_dec.onnx'
gpt2_full_model_path = './gpt2_full.onnx'


# Sample prompts used both for tracing and for checking the all-in-one model.
input_text = ['best hotel in bay area', 'here is an example of gpt2 model']
# Value supplied as the second traced input (the loop count for token generation).
num_tokens_to_produce = 10
| 18 | |
| 19 | |
def get_cache_directory():
    """Return the model cache directory, creating it when it is missing.

    The directory is ``../../cache_models``, resolved relative to the current
    working directory.

    Returns:
        str: The (relative) path of the cache directory.
    """
    cache_dir = os.path.join("../..", "cache_models")
    # exist_ok=True replaces the separate os.path.exists() test, which could
    # race with another process creating the directory in between.
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir
| 25 | |
| 26 | |
def convert_models():
    """Build the tokenizer, decoder and core GPT-2 ONNX models.

    Loads the pretrained GPT-2 tokenizer, configuration and LM-head model
    (using the shared cache directory), then hands them to
    ``build_customop_model`` and ``Gpt2Helper.export_onnx`` together with the
    module-level output paths.
    """
    # Imported lazily so these packages are only needed when a conversion runs.
    from transformers import GPT2Tokenizer  # noqa
    from onnxruntime.transformers.gpt2_helper import Gpt2Helper, MyGPT2LMHeadModel  # noqa

    models_cache = get_cache_directory()

    # Pre-processing (tokenizer) and post-processing (decoder) custom-op models.
    hf_tokenizer = GPT2Tokenizer.from_pretrained(model_name_or_path, cache_dir=models_cache)
    build_customop_model('GPT2Tokenizer', gpt2_encoder_model_path, model=hf_tokenizer)
    build_customop_model('VectorToString', gpt2_decoder_model_path, decoder=hf_tokenizer.decoder)

    # Core language model; the helper classes come from onnxruntime.transformers.
    gpt2_config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=models_cache)
    lm_head_model = MyGPT2LMHeadModel.from_pretrained(
        model_name_or_path, config=gpt2_config, cache_dir=models_cache)
    Gpt2Helper.export_onnx(lm_head_model, device, gpt2_core_model_path)
| 40 | |
| 41 | |
def inference_and_dump_full_model(inputs):
    """Run greedy GPT-2 generation under an ONNX trace and save the traced model.

    The tokenizer model, the core GPT-2 model and the decoder model are chained
    inside a ``trace_for_onnx`` session; the session is then saved to
    ``gpt2_full_model_path`` with the decoded text as its output.

    Args:
        inputs: The text inputs handed to ``trace_for_onnx`` (the script passes
            a list of prompt strings).

    Returns:
        The output of the decoder op for the generated token ids.
    """
    config = AutoConfig.from_pretrained(model_name_or_path, cache_dir=get_cache_directory())
    core_model = pyfunc_from_model(gpt2_core_model_path)
    gpt2_tokenizer = pyfunc_from_model(gpt2_encoder_model_path)

    with trace_for_onnx(inputs, num_tokens_to_produce, names=gpt2_tokenizer.input_names) as tc_sess:

        # Model dimensions and the end-of-sequence id come from the GPT-2 config.
        num_attention_heads = config.n_head
        hidden_size = config.n_embd
        num_layer = config.n_layer
        eos_token_id = config.eos_token_id

        # From here on `inputs` is rebound to the values returned by the session.
        inputs, num_tokens = tc_sess.get_inputs()
        input_ids, attention_mask = gpt2_tokenizer(inputs, padding=True, padding_side='left')
        attention_mask = attention_mask.type(torch.float)
        # Running count of attended tokens minus one, taken along the last axis.
        position_ids = (attention_mask.long().cumsum(-1) - 1)
        # position_ids.masked_fill_(position_ids < 0, 0)

        # Empty past state for generating the first word: one tensor per layer
        # whose fourth dimension has length 0.
        batch_size = input_ids.size()[0]
        past_shape = [2, batch_size, num_attention_heads, 0, hidden_size // num_attention_heads]
        empty_past = []
        for _ in range(num_layer):
            empty_past.append(torch.empty(*past_shape).type(torch.float32).to(device))

        # Per-row flag recording whether an EOS token has been produced.
        has_eos = torch.zeros(batch_size, dtype=torch.bool)

        all_eos = torch.tensor(False, dtype=torch.bool)
        past = empty_past
        cfg = torch.control_flow()
        # NOTE(review): presumably loops up to `num_tokens` times while the
        # condition (~all_eos) holds, carrying the remaining arguments as loop
        # state - confirm against the torch_wrapper control_flow implementation.
        for states in cfg.loop(num_tokens, ~all_eos, has_eos,
                               input_ids, position_ids, attention_mask.type(torch.float), *past):
            _, has_eos, input_ids, position_ids, attention_mask, *past = states
            # for _ in [1]:
            outputs = core_model(input_ids, position_ids, attention_mask, *past)
            # Greedy decoding: take the arg-max over the logits of the last position.
            next_token_logits = outputs[0][:, -1, :]
            next_tokens = torch.argmax(next_token_logits, dim=-1)

            # Rows that have already hit EOS keep emitting EOS.
            has_eos = has_eos | (next_tokens == eos_token_id)
            tokens_to_add = next_tokens.masked_fill(has_eos, eos_token_id)

            # Update input_ids, attention_mask, position_ids and past
            # NOTE(review): batch_size is hard-coded to 2 here, overriding the
            # value derived from input_ids above; it matches the two sample
            # prompts - confirm whether the tracer needs a constant here.
            batch_size = 2
            input_ids = tokens_to_add.clone().detach().reshape([batch_size, 1]).to(device)
            position_ids = (position_ids[:, -1] + 1).reshape([batch_size, 1])
            attention_mask = torch.cat([attention_mask, torch.ones(batch_size, 1).type(torch.float)], 1).to(device)

            all_eos = torch.all(has_eos)

            # outputs[1:] are carried over as the next iteration's past state;
            # numpy arrays are converted to tensors first.
            past = []
            for i in range(num_layer):
                past_i = torch.from_numpy(outputs[i + 1]) if \
                    isinstance(outputs[i + 1], numpy.ndarray) else outputs[i + 1].clone().detach()
                past.append(past_i.to(device))

            # Same layout as the loop arguments, plus tokens_to_add as an extra value.
            cfg.flow_output(~all_eos, has_eos, input_ids, position_ids, attention_mask, *past, tokens_to_add)

        # The last value returned by finalize() is used as the generated token ids.
        *_, all_token_ids = cfg.finalize()
        gpt2_decoder = torch.op_from_model(gpt2_decoder_model_path)
        # text_out = gpt2_decoder(tokens_to_add.unsqueeze(-1))
        # NOTE(review): only index 0 after the transpose is decoded (presumably
        # the first prompt's tokens) - confirm this is intended.
        text_out = gpt2_decoder(all_token_ids.transpose(0, 1)[0].unsqueeze(-1))
        tc_sess.save_as_onnx(gpt2_full_model_path, text_out)
        return text_out
| 105 | |
| 106 | |
# 1. Convert the GPT-2 models unless every exported file is already on disk.
if not all(os.path.exists(model_path) for model_path in (
        gpt2_core_model_path, gpt2_decoder_model_path, gpt2_encoder_model_path)):
    convert_models()


# 2. Run inference with pre- and post-processing while tracing the computation
#    graph, and save the result as the all-in-one ONNX model.
output_ms = inference_and_dump_full_model(input_text)

# 3. Run inference on the all-in-one model.
from onnxruntime_extensions import PyOrtFunction
full_model = PyOrtFunction.from_model(gpt2_full_model_path)
output_text = full_model(input_text, num_tokens_to_produce)

# 4. Compare the two results and warn (rather than fail) when they differ.
if not numpy.array_equal(output_ms.numpy(), output_text):
    import warnings  # noqa
    warnings.warn(
        "{}: The all-in-one model output is not the same\t{}".format(output_ms.numpy(), output_text))
| 126 | |