microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/fineTuning/eval.py
100 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation and Henry Lucco. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | from chaparral.util.datareader import DataReader |
| 5 | from chaparral.train.hf_model import HFModel |
| 6 | from chaparral.train.hf_params import HFParams |
| 7 | import argparse |
| 8 | import torch |
| 9 | from unsloth import FastLanguageModel |
| 10 | from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer |
| 11 | |
| 12 | |
def parse_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        The ``argparse.Namespace`` produced by the parser. It carries the
        attributes ``dataset_file``, ``model_name`` and ``params``; each one
        is ``None`` when the corresponding option is not given.
    """
    # The description/help previously said "fine-tune" (carried over from the
    # training script); this file is the evaluation entry point.
    parser = argparse.ArgumentParser(
        description="Evaluate a model with given dataset.")
    parser.add_argument("--dataset_file", help="Path to the dataset file.")
    parser.add_argument("--model_name", help="Name of the model to evaluate.")
    parser.add_argument("--params", help="Path to params file")
    return parser.parse_args()
| 20 | |
| 21 | |
def main():
    """Stream one sample completion from the fine-tuned checkpoint.

    Reads the params file and the dataset named on the command line, loads
    the model and tokenizer from a hard-coded checkpoint directory, wraps a
    fixed sentence in the dataset's prompt, and lets ``TextStreamer`` emit the
    generated reply while ``model.generate`` runs.
    """
    args = parse_args()

    # The parsed params are not consulted below; the call is kept so that any
    # error it raises still surfaces before the model is loaded.
    HFParams.from_file(args.params)

    dataset = DataReader().load_text_file(args.dataset_file)

    # The train/eval split is not used by this script. The call is kept in
    # case it changes the dataset's state -- TODO confirm and drop if not.
    dataset.create_train_eval_sets()

    # NOTE(review): --model_name is accepted by parse_args() but is not used
    # here; the checkpoint path is fixed.
    checkpoint = "./llama_peft_1/checkpoint-20"
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        load_in_4bit=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    messages = [
        {
            "role": "user",
            "content": dataset.get_filled_prompt(
                "The quick brown fox jumps over the lazy dog"),
        },
    ]
    # Put the prompt on whatever device the model was loaded onto instead of
    # assuming a bare "cuda" device.
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # skip_prompt=True: only the newly generated text is streamed.
    text_streamer = TextStreamer(tokenizer, skip_prompt=True)
    model.generate(
        input_ids,
        streamer=text_streamer,
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id,
    )


if __name__ == "__main__":
    main()