microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
ecec6489db2ad85ccd1be99c0860dc3b057af2b1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/fineTuning/unsloth/batchInfer.py

76 lines · mode: code

1# Copyright (c) Microsoft Corporation and Henry Lucco.
2# Licensed under the MIT License.
3
4from unsloth import FastLanguageModel
5import sys
6import os
7from knowledgePrompt import get_knowledge_prompt
8from argparse import ArgumentParser
# Command-line options. Every option has a default, so the script can be run
# with no arguments.
parser = ArgumentParser(description="Run batch inference with a language model.")
parser.add_argument("--model_path", type=str, default="/data/phi-4-lora-3200", help="Path to the pre-trained model.")
parser.add_argument("--dataset_path", type=str, default='/data/npr_chunks_no_embedding_seed127_samples5000_test.json', help="Path to the dataset file.")
parser.add_argument("--output_file", type=str, default='batchOutput.txt', help="Path to the output file.")
# Previously hard-coded to 8192; exposed so models with a different context
# window can be used without editing the script.
parser.add_argument("--max_seq_length", type=int, default=8192, help="Maximum sequence length the model is loaded with.")
args = parser.parse_args(sys.argv[1:])
model_path = args.model_path
dataset_path = args.dataset_path
output_file = args.output_file

# Load the model and its tokenizer through Unsloth. dtype=None leaves the dtype
# choice to Unsloth; load_in_4bit=True requests 4-bit quantized weights.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=args.max_seq_length,
    dtype=None,
    load_in_4bit=True,
)

FastLanguageModel.for_inference(model)  # Enable native 2x faster inference
25
26
# Load the test set: a JSON array of objects, each with 'speaker' and 'content'
# fields. Each record becomes a single-turn chat whose user message is the
# prompt returned by get_knowledge_prompt for "<speaker>: <content>".
import json
from datasets import Dataset
import pandas as pd

# JSON text is UTF-8; don't depend on the platform's default encoding
# (e.g. cp1252 on Windows), which breaks on non-ASCII transcripts.
with open(dataset_path, encoding='utf-8') as f:
    rawData = json.load(f)

# chats[i] corresponds to rawData[i]; the generation loop relies on this order.
chats = [
    [{"role": "user", "content": get_knowledge_prompt(record['speaker'] + ": " + record['content'])}]
    for record in rawData
]
38
def print_responses(responses, time_taken, file, tok=None):
    """Write each response's answer text and token-throughput stats to *file*.

    The answer is the text after the first ``'Response:'`` marker in each
    response. As called by this script, each response is the full decoded
    generation (prompt plus continuation).

    Args:
        responses: Decoded model outputs.
        time_taken: Elapsed seconds for producing the outputs. Used to compute
            tokens/sec, so the rate is end-to-end and includes prompt processing.
        file: Writable text stream that receives the report lines.
        tok: Tokenizer used to count answer tokens; must provide
            ``encode(text, add_special_tokens=...)``. Defaults to the
            module-level ``tokenizer`` loaded by this script.
    """
    if tok is None:
        tok = tokenizer
    marker = 'Response:'
    for response in responses:
        start = response.find(marker)
        if start == -1:
            # Make a missing answer visible in the report instead of silently
            # dropping it.
            file.write(f"[no '{marker}' marker found in model output]\n")
            continue
        text = response[start + len(marker):].strip()
        if not text:
            file.write("[empty response]\n")
            continue
        file.write(text + "\n")
        # Count only the answer's own tokens. By default the tokenizer may add
        # special tokens (e.g. BOS), which would inflate the count.
        token_count = len(tok.encode(text, add_special_tokens=False))
        # Guard against a zero/negative duration rather than raising
        # ZeroDivisionError.
        rate = token_count / time_taken if time_taken > 0 else float('inf')
        file.write(f"Token count: {token_count} at {rate:.2f} tokens/sec\n")
50
import time

# Generate one response per chat. For each, write a report containing the
# original message, the extracted answer with token stats, the elapsed time,
# and a separator line.
count = 0
with open(output_file, 'w', encoding='utf-8') as f:
    # chats was built from rawData in the same order, so zip pairs each prompt
    # with the record it came from (no parallel index bookkeeping needed).
    for record, message in zip(rawData, chats):
        # perf_counter is monotonic and high-resolution; time.time() is wall
        # clock and can jump, which would corrupt the measured duration.
        start = time.perf_counter()
        # A batch of one conversation; inputs go to whatever device the model
        # was loaded on instead of a hard-coded "cuda".
        model_inputs = tokenizer.apply_chat_template(
            [message],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)

        output = model.generate(model_inputs, max_new_tokens=1024, pad_token_id=tokenizer.eos_token_id)
        # For a decoder-only model, generate() returns the prompt tokens
        # followed by the new ones. The decoded text therefore contains the
        # prompt, and print_responses extracts the part after 'Response:'.
        output_text = tokenizer.decode(output[0], skip_special_tokens=True)
        elapsed = time.perf_counter() - start

        f.write(record['speaker'] + ": " + record['content'] + "\n")
        count += 1
        print_responses([output_text], elapsed, f)
        f.write("Time taken: {:.2f}sec\n".format(elapsed))
        f.write("\n" + "=" * 80 + "\n\n")

print(f"Batch inference complete. Results written to {output_file}")
print(f"Processed {count} messages")
77