microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
899cedcc36e55e38e164b77ce6d9f37dc5245684

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/fineTuning/unsloth/batchInfer.py

73lines · modecode

1from unsloth import FastLanguageModel
2import sys
3import os
4from knowledgePrompt import get_knowledge_prompt
5from argparse import ArgumentParser
# Command-line interface: every option has a default, so the script can be
# run with no arguments at all.
parser = ArgumentParser(description="Run batch inference with a language model.")
parser.add_argument(
    "--model_path",
    type=str,
    default="/data/phi-4-lora-3200",
    help="Path to the pre-trained model.",
)
parser.add_argument(
    "--dataset_path",
    type=str,
    default='/data/npr_chunks_no_embedding_seed127_samples5000_test.json',
    help="Path to the dataset file.",
)
parser.add_argument(
    "--output_file",
    type=str,
    default='batchOutput.txt',
    help="Path to the output file.",
)
args = parser.parse_args(sys.argv[1:])

# Expose the parsed options as module-level names used by the rest of the script.
model_path = args.model_path
dataset_path = args.dataset_path
output_file = args.output_file

# Load the model and its tokenizer from model_path with the settings below.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=8192,
    dtype=None,
    load_in_4bit=True,
)

# Enable Unsloth's native 2x faster inference mode before generating.
FastLanguageModel.for_inference(model)
22
23
# load an array of JSON objects with properties speaker and content
25import json
26from datasets import Dataset
27import pandas as pd
28
# Read the dataset: a JSON array of objects, of which only the 'speaker' and
# 'content' fields are used here. The file is decoded as UTF-8 explicitly
# (the encoding JSON interchange uses, and the one this script writes its
# output in) instead of relying on the platform's locale default, which
# would mis-decode non-ASCII transcript text on some systems.
with open(dataset_path, encoding="utf-8") as f:
    rawData = json.load(f)

# One single-turn chat per record, in the same order as rawData: the user
# message is the knowledge prompt built from "<speaker>: <content>".
chats = [
    [{"role": "user", "content": get_knowledge_prompt(entry['speaker'] + ": " + entry['content'])}]
    for entry in rawData
]
35
def print_responses(responses, time_taken, file, encode=None):
    """Write the answer part of each decoded response, plus token statistics.

    For every response, the text following the first occurrence of the
    marker ``'Response:'`` is stripped and, if non-empty, written to *file*
    on its own line, followed by a line giving its token count and
    throughput. Responses without the marker, or with nothing after it,
    produce no output.

    Args:
        responses: Iterable of decoded model output strings.
        time_taken: Elapsed generation time in seconds, used for the
            tokens-per-second figure. A value that is zero or negative is
            reported as ``inf`` tokens/sec rather than raising
            ZeroDivisionError.
        file: Writable text stream that receives the output.
        encode: Optional callable mapping a string to a sequence of tokens.
            Defaults to the module-level ``tokenizer.encode``.
    """
    if encode is None:
        # Resolved at call time so the module-level tokenizer remains the default.
        encode = tokenizer.encode

    marker = 'Response:'
    for response in responses:
        marker_pos = response.find(marker)
        if marker_pos == -1:
            continue
        text = response[marker_pos + len(marker):].strip()
        if not text:
            continue
        file.write(text + "\n")
        token_count = len(encode(text))
        # Guard the division: a zero elapsed time must not abort the whole batch.
        rate = token_count / time_taken if time_taken > 0 else float("inf")
        file.write(f"Token count: {token_count} at {rate:.2f} tokens/sec\n")
47
48import time
49
# Number of records processed so far; reported in the final summary.
count = 0
with open(output_file, 'w', encoding='utf-8') as f:
    # rawData and chats are parallel lists, so walk them in lockstep.
    for record, chat in zip(rawData, chats):
        # The timed span covers chat templating, generation and decoding.
        started = time.time()
        prompt_ids = tokenizer.apply_chat_template(
            [chat],
            add_generation_prompt=True,
            return_tensors="pt",
        )
        prompt_ids = prompt_ids.to("cuda")

        generated = model.generate(
            prompt_ids,
            max_new_tokens=1024,
            pad_token_id=tokenizer.eos_token_id,
        )
        output_text = tokenizer.decode(generated[0], skip_special_tokens=True)
        elapsed = time.time() - started

        # Echo the source utterance, then the extracted response and timing.
        f.write(record['speaker'] + ": " + record['content'] + "\n")
        count += 1
        print_responses([output_text], elapsed, f)
        f.write(f"Time taken: {elapsed:.2f}sec\n")
        # Visual separator between records.
        f.write("\n" + "=" * 80 + "\n\n")

print(f"Batch inference complete. Results written to {output_file}")
print(f"Processed {count} messages")
74