microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/fineTuning/tune.py
44 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation and Henry Lucco. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | from chaparral.util.datareader import DataReader |
| 5 | from chaparral.train.hf_model import HFModel |
| 6 | from chaparral.train.hf_params import HFParams |
| 7 | import argparse |
def parse_args(argv=None):
    """Parse the command-line options for the fine-tuning script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, which matches
            the previous zero-argument behaviour.

    Returns:
        An ``argparse.Namespace`` with the attributes ``dataset_file``,
        ``model_name`` and ``params``.

    Raises:
        SystemExit: if ``--dataset_file`` or ``--params`` is missing or an
            unrecognised option is supplied (argparse prints usage first).
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune a model with given dataset."
    )
    # The script body reads both of these paths unconditionally, so require
    # them: argparse then fails fast with a usage message instead of letting
    # ``None`` flow into the params/dataset loaders.
    parser.add_argument(
        "--dataset_file", required=True, help="Path to the dataset file."
    )
    # NOTE(review): --model_name is accepted but not read anywhere in this
    # script; kept optional so existing command lines remain valid.
    parser.add_argument("--model_name", help="Name of the model to fine-tune.")
    parser.add_argument("--params", required=True, help="Path to params file")
    return parser.parse_args(argv)
| 14 | |
if __name__ == "__main__":
    # Script entry point: parse the CLI, load the configuration and the
    # dataset, build and train the model, then save it to ./test_output.
    cli = parse_args()

    # Configuration object built from the file passed via --params.
    hf_params = HFParams.from_file(cli.params)

    # Read the text dataset named by --dataset_file and split it.
    # NOTE(review): neither split is used below -- HFModel only receives
    # hf_params and init_data() is called without arguments, so the training
    # data presumably comes from the params file. Confirm whether
    # --dataset_file is actually meant to feed the model.
    text_dataset = DataReader().load_text_file(cli.dataset_file)
    _train_split, _eval_split = text_dataset.create_train_eval_sets()

    tuner = HFModel(hf_params)
    tuner.load_model()
    tuner.init_data()
    print("Model loaded")

    # Fine-tune via sft_train() and persist the result to a fixed directory.
    tuner.sft_train()
    tuner.save_model("./test_output")