microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
c765421e8f1db0fb5df3e6aa28254a56e6e33e0a

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/fineTuning/eval.py

100 lines · mode: code

# Copyright (c) Microsoft Corporation and Henry Lucco.
# Licensed under the MIT License.

4from chaparral.util.datareader import DataReader
5from chaparral.train.hf_model import HFModel
6from chaparral.train.hf_params import HFParams
7import argparse
8import torch
9from unsloth import FastLanguageModel
10from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer
11
12
def parse_args(argv=None):
    """Parse command-line arguments for the evaluation script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, matching the previous behaviour.

    Returns:
        argparse.Namespace with ``dataset_file``, ``model_name`` and ``params``.
    """
    parser = argparse.ArgumentParser(
        description="Run a sample generation with a fine-tuned model.")
    # Both files are consumed unconditionally by the script, so fail fast at
    # argument parsing instead of deep inside the loaders.
    parser.add_argument("--dataset_file", required=True,
                        help="Path to the dataset file.")
    parser.add_argument("--model_name",
                        help="Name or path of the model checkpoint.")
    parser.add_argument("--params", required=True,
                        help="Path to params file")
    return parser.parse_args(argv)
20
21
# Checkpoint evaluated when --model_name is not supplied on the command line.
DEFAULT_CHECKPOINT = "./llama_peft_1/checkpoint-20"

# Sample input that is run through the dataset's prompt template.
SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog"


def main():
    """Stream one chat-formatted generation from a fine-tuned checkpoint.

    Loads the params file and dataset named on the command line, builds the
    dataset's filled prompt for SAMPLE_TEXT, loads the checkpoint (4-bit) and
    its tokenizer, and streams up to 128 newly generated tokens to stdout.
    """
    args = parse_args()

    # NOTE(review): params are not consumed yet; they were meant for the
    # HFModel-based loading path, which is currently disabled in favour of
    # loading the checkpoint directly with transformers.
    params = HFParams.from_file(args.params)

    # The dataset is only needed for its prompt template here.
    dataset = DataReader().load_text_file(args.dataset_file)

    checkpoint = args.model_name or DEFAULT_CHECKPOINT
    # NOTE(review): newer transformers releases deprecate passing
    # load_in_4bit directly in favour of
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True); confirm
    # against the pinned transformers version.
    model = AutoModelForCausalLM.from_pretrained(checkpoint, load_in_4bit=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    messages = [
        {"role": "user", "content": dataset.get_filled_prompt(SAMPLE_TEXT)},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)  # follow the model instead of hard-coding "cuda"

    # skip_prompt=True streams only the newly generated tokens.
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    model.generate(
        input_ids,
        # A single unpadded sequence, so every position is attended to;
        # passing the mask explicitly avoids generate()'s missing-mask warning.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id,
    )


if __name__ == "__main__":
    main()