microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/fineTuning/unsloth/batchInfer.py
76 lines · mode code
| 1 | # Copyright (c) Microsoft Corporation and Henry Lucco. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | from unsloth import FastLanguageModel |
| 5 | import sys |
| 6 | import os |
| 7 | from knowledgePrompt import get_knowledge_prompt |
| 8 | from argparse import ArgumentParser |
# Command-line interface: every option has a default, so the script can be
# run with no arguments at all.
parser = ArgumentParser(description="Run batch inference with a language model.")
parser.add_argument(
    "--model_path",
    type=str,
    default="/data/phi-4-lora-3200",
    help="Path to the pre-trained model.",
)
parser.add_argument(
    "--dataset_path",
    type=str,
    default='/data/npr_chunks_no_embedding_seed127_samples5000_test.json',
    help="Path to the dataset file.",
)
parser.add_argument(
    "--output_file",
    type=str,
    default='batchOutput.txt',
    help="Path to the output file.",
)
# parse_args() reads sys.argv[1:] when called without arguments.
args = parser.parse_args()
model_path, dataset_path, output_file = (
    args.model_path,
    args.dataset_path,
    args.output_file,
)

# Load the model and its tokenizer with 4-bit loading enabled and an
# 8192-token maximum sequence length.
# NOTE(review): dtype=None presumably lets the loader choose the dtype --
# confirm against the Unsloth documentation.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=model_path,
    max_seq_length=8192,
    dtype=None,
    load_in_4bit=True,
)

# Enable native 2x faster inference
FastLanguageModel.for_inference(model)
| 25 | |
| 26 | |
| 27 | # load an array of JSON objects; the speaker and content properties of each entry are used |
| 28 | import json |
| 29 | from datasets import Dataset |
| 30 | import pandas as pd |
| 31 | |
# Read the dataset explicitly as UTF-8 so that non-ASCII text decodes the
# same way on every platform instead of depending on the locale's default
# encoding (the results file is likewise written as UTF-8).
with open(dataset_path, encoding="utf-8") as f:
    rawData = json.load(f)

# Build one single-turn conversation per record. The user turn is the
# knowledge prompt generated from the "<speaker>: <content>" line.
chats = [
    [
        {
            "role": "user",
            "content": get_knowledge_prompt(
                record['speaker'] + ": " + record['content']
            ),
        }
    ]
    for record in rawData
]
| 38 | |
def _tokenizer_token_count(text):
    """Return how many tokens the module-level tokenizer produces for *text*."""
    return len(tokenizer.encode(text))


def print_responses(responses, time_taken, file, *, count_tokens=None):
    """Write the answer portion of each decoded response to *file*.

    For every response, the text that follows the first ``'Response:'``
    marker is stripped and written on its own line, followed by a line with
    its token count and the tokens-per-second rate. A response without the
    marker, or with only whitespace after it, produces no output.

    Args:
        responses: Iterable of decoded response strings.
        time_taken: Elapsed seconds used as the denominator of the rate; the
            same value is applied to every response. When it is zero or
            negative the rate is undefined and is reported as ``nan``
            instead of raising ``ZeroDivisionError``.
        file: Writable text stream (any object with a ``write`` method).
        count_tokens: Optional callable mapping the extracted text to its
            token count. Defaults to ``len(tokenizer.encode(text))`` using
            the module-level tokenizer.
    """
    if count_tokens is None:
        count_tokens = _tokenizer_token_count

    for response in responses:
        # partition() splits at the first occurrence of the marker; `found`
        # is empty when the marker is absent.
        _, found, tail = response.partition('Response:')
        text = tail.strip()
        if not found or not text:
            continue

        file.write(text + "\n")
        token_count = count_tokens(text)
        rate = token_count / time_taken if time_taken > 0 else float("nan")
        file.write(f"Token count: {token_count} at {rate:.2f} tokens/sec\n")
| 50 | |
| 51 | import time |
| 52 | |
# Generate a response for each conversation, one at a time. For every record
# the results file receives the source message, the extracted response with
# its token statistics, and the time the generation took.
count = 0  # stays 0 when the dataset is empty
with open(output_file, 'w', encoding='utf-8') as f:
    # rawData and chats are parallel lists (chats was built from rawData),
    # so walk them together instead of indexing rawData with a hand-kept
    # counter.
    for count, (record, conversation) in enumerate(zip(rawData, chats), start=1):
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed by system clock adjustments the way time.time() can.
        start = time.perf_counter()
        model_inputs = tokenizer.apply_chat_template(
            [conversation],
            add_generation_prompt=True,
            return_tensors="pt",
        ).to("cuda")

        # The EOS token id is passed as the padding token id.
        output = model.generate(
            model_inputs,
            max_new_tokens=1024,
            pad_token_id=tokenizer.eos_token_id,
        )
        output_text = tokenizer.decode(output[0], skip_special_tokens=True)
        elapsed = time.perf_counter() - start

        f.write(record['speaker'] + ": " + record['content'] + "\n")
        print_responses([output_text], elapsed, f)
        # elapsed time with two decimal places
        f.write("Time taken: {:.2f}sec\n".format(elapsed))
        f.write("\n" + "=" * 80 + "\n\n")

print(f"Batch inference complete. Results written to {output_file}")
print(f"Processed {count} messages")
| 77 | |