microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
8ba3ebd84dd1bb6343ebae028996313cabd70764

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/fineTuning/tune.py

44lines · modecode

1# Copyright (c) Microsoft Corporation and Henry Lucco.
2# Licensed under the MIT License.
3
4from chaparral.util.datareader import DataReader
5from chaparral.train.hf_model import HFModel
6from chaparral.train.hf_params import HFParams
7import argparse
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the fine-tuning script.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default) argparse
            reads ``sys.argv[1:]``, which is exactly what a call with no
            arguments has always done.

    Returns:
        A namespace with the attributes ``dataset_file``, ``model_name``
        and ``params``. None of the options is declared required or given
        a default, so an option left off the command line is ``None``.
    """
    parser = argparse.ArgumentParser(
        description="Fine-tune a model with given dataset."
    )
    parser.add_argument("--dataset_file", help="Path to the dataset file.")
    parser.add_argument("--model_name", help="Name of the model to fine-tune.")
    parser.add_argument("--params", help="Path to params file")
    return parser.parse_args(argv)
14
def main():
    """Run the fine-tuning script end to end.

    The steps run in this fixed order: parse the command-line options,
    build the params object from the ``--params`` file, read the dataset
    from ``--dataset_file``, split it into train and eval sets, construct
    an ``HFModel`` from the params, then call ``load_model``,
    ``init_data``, ``sft_train`` and finally ``save_model`` with the
    output path ``./test_output``.
    """
    args = parse_args()

    # load params
    params = HFParams.from_file(args.params)

    # load dataset
    dataset = DataReader().load_text_file(args.dataset_file)

    # format data into train and eval sets
    # NOTE(review): neither split is handed to the model in this script;
    # eval_set is only referenced by the disabled evaluation at the bottom.
    # The call is kept in case create_train_eval_sets has side effects the
    # rest of the pipeline relies on -- confirm before removing it.
    train_set, eval_set = dataset.create_train_eval_sets()

    model = HFModel(params)
    model.load_model()
    model.init_data()
    print("Model loaded")

    # model.train() is the disabled alternative to the call below.
    model.sft_train()

    model.save_model("./test_output")

    # Disabled evaluation of the saved model:
    # model.load_local_model("./test_output")
    # print(model.evaluate(eval_set))


if __name__ == "__main__":
    main()