microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d8982a75d68ff326c7355970158bdf91d6590eac

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/nprData/btt_chunk.py

115lines · modecode

1# Copyright (c) Microsoft Corporation and Henry Lucco.
2# Licensed under the MIT License.
3
4from structs import Episode
5from dotenv import load_dotenv
6from concurrent.futures import ThreadPoolExecutor, as_completed
7from structs import Chunk
8from qdrant_client import QdrantClient
9from qdrant_client.models import VectorParams, Distance
10from embedding import Embedding
11from generate_chunks import process_turn
12import json
13import os
14from tqdm import tqdm
15
# Default settings. Each one can be overridden from the command line; see
# parse_args() below.
EPISODE_PATH = "btt_podcast.txt"
COLLECTION_NAME = "btt_llm_generic"
CHUNK_PATH = "btt_chunks_llm_generic.json"
USE_LLM = True
UPSERT_BATCH_SIZE = 64


def parse_args(argv=None):
    """Parse command-line options, defaulting to the module-level constants.

    Args:
        argv: Argument list to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``episode_path``, ``collection_name``,
        ``chunk_path``, ``use_llm`` and ``batch_size`` attributes.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "Chunk a podcast transcript, index the chunks in Qdrant and "
            "start an interactive search prompt."
        )
    )
    parser.add_argument(
        "--episode-path",
        default=EPISODE_PATH,
        help="transcript text file to chunk (default: %(default)s)",
    )
    parser.add_argument(
        "--collection-name",
        default=COLLECTION_NAME,
        help="Qdrant collection to create and query (default: %(default)s)",
    )
    parser.add_argument(
        "--chunk-path",
        default=CHUNK_PATH,
        help="JSON file used to cache generated chunks (default: %(default)s)",
    )
    # Two flags share one destination so the default can be either value.
    parser.add_argument(
        "--llm",
        dest="use_llm",
        action="store_true",
        default=USE_LLM,
        help="pass use_llm=True to process_turn",
    )
    parser.add_argument(
        "--no-llm",
        dest="use_llm",
        action="store_false",
        default=USE_LLM,
        help="pass use_llm=False to process_turn",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=UPSERT_BATCH_SIZE,
        help="number of points sent per upsert request (default: %(default)s)",
    )
    return parser.parse_args(argv)


def iter_batches(items, size):
    """Yield consecutive slices of *items* holding at most *size* elements.

    Raises:
        ValueError: If *size* is smaller than 1.
    """
    if size < 1:
        raise ValueError(f"batch size must be at least 1, got {size}")
    for start in range(0, len(items), size):
        yield items[start:start + size]


def format_hit(rank, hit):
    """Render one search result as a single display line.

    Args:
        rank: 1-based position of the result.
        hit: Object exposing ``id`` and a dict-like ``payload``.

    A missing payload or a missing ``speaker`` entry is rendered as an empty
    speaker instead of raising ``AttributeError``.
    """
    payload = hit.payload or {}
    speaker = (payload.get("speaker") or "").title()
    return (
        f"{rank}. {hit.id} {speaker} : {payload.get('content')} "
        f"[{payload.get('section_title')}]"
    )


def _generate_chunk_file(episode_path, chunk_path, use_llm):
    """Chunk the first section of the episode and cache the result as JSON.

    Raises:
        ValueError: If the parsed episode has no sections.
    """
    episode_data = Episode.from_text_file(episode_path)
    if not episode_data.sections:
        raise ValueError(f"No sections found in {episode_path}")
    section = episode_data.sections[0]

    chunks = [
        process_turn(episode_data.id, section, turn, use_llm)
        for turn in tqdm(section.transcript)
    ]

    # Removed concurrent processing due to rate limiting from LLM API.
    # If wanting to generate chunks where use_llm = False, the loop above can
    # be replaced with the following block:
    #
    #     chunks = []
    #     with ThreadPoolExecutor() as executor:
    #         futures = []
    #         for section in episode_data.sections:
    #             print(f"Section: {section.title}")
    #             for turn in section.transcript:
    #                 futures.append(executor.submit(
    #                     process_turn, episode_data.id, section, turn, use_llm))
    #
    #         for future in tqdm(as_completed(futures), total=len(futures)):
    #             chunks.append(future.result())

    # Serialize before opening the file: if to_dict() fails, no empty cache
    # file is left behind to be mistaken for a finished run next time.
    records = [chunk.to_dict() for chunk in chunks]
    with open(chunk_path, "w", encoding="utf-8") as f:
        json.dump(records, f, indent=4)


def _load_chunks(chunk_path):
    """Read the cached chunk file and return a list of ``Chunk`` objects."""
    with open(chunk_path, "r", encoding="utf-8") as f:
        return [Chunk.from_dict(record) for record in json.load(f)]


def _index_chunks(client, collection_name, chunks, batch_size):
    """Create *collection_name* and upsert one point per chunk.

    The vector size is taken from the first chunk's embedding.

    Raises:
        ValueError: If *chunks* is empty, since the vector size is unknown.
    """
    if not chunks:
        raise ValueError(
            "No chunks available; cannot determine the embedding dimension"
        )

    client.create_collection(
        collection_name,
        vectors_config=VectorParams(
            size=chunks[0].embedding.dimension,
            distance=Distance.COSINE,
        ),
    )

    points = [
        {
            "id": i,
            "vector": chunk.embedding.values,
            "payload": {
                "speaker": chunk.speaker,
                "content": chunk.content,
                "episode_id": chunk.episode_id,
                "section_id": chunk.section_id,
                "section_title": chunk.section_title,
                "speaker_role": chunk.speaker_role,
            },
        }
        for i, chunk in enumerate(chunks)
    ]

    # Send points in batches rather than one blocking request per point.
    with tqdm(total=len(points)) as progress:
        for batch in iter_batches(points, batch_size):
            client.upsert(
                collection_name=collection_name,
                wait=True,
                points=batch,
            )
            progress.update(len(batch))

    print(f"Upserted {len(points)} points")


def _search_loop(client, collection_name):
    """Prompt for queries and print matches until the user quits.

    Typing ``exit``, ``q`` or ``quit``, or sending end-of-file, ends the loop.
    Blank input is ignored.
    """
    import shutil

    while True:
        try:
            query = input("> ").strip()
        except EOFError:
            # Ctrl-D / closed stdin: leave the prompt cleanly.
            print()
            break

        if query in ("exit", "q", "quit"):
            break
        if not query:
            continue

        query_vector = Embedding.from_text(query).values
        results = client.search(collection_name, query_vector)

        # shutil.get_terminal_size() falls back to a default width when stdout
        # is not a terminal, where os.get_terminal_size() raises OSError.
        rule = "=" * shutil.get_terminal_size().columns
        print(rule)
        for rank, hit in enumerate(results, start=1):
            print(format_hit(rank, hit))
        print(rule)


def main(argv=None):
    """Build the chunk cache and collection if missing, then run the prompt.

    Args:
        argv: Command-line arguments; ``None`` means ``sys.argv[1:]``.

    Raises:
        ValueError: If ``VECTOR_DB_URI`` is not set, or no chunks are available
            to index.
    """
    args = parse_args(argv)
    load_dotenv("./env_vars")

    # The transcript is only parsed when the chunk cache has to be built.
    if not os.path.exists(args.chunk_path):
        _generate_chunk_file(args.episode_path, args.chunk_path, args.use_llm)

    uri = os.environ.get("VECTOR_DB_URI")
    if not uri:
        raise ValueError("VECTOR_DB_URI environment variable is not set")

    client = QdrantClient(uri)

    if not client.collection_exists(args.collection_name):
        print("Loading chunks...")
        chunks = _load_chunks(args.chunk_path)
        print(f"{len(chunks)} Chunks loaded")

        _index_chunks(client, args.collection_name, chunks, args.batch_size)
        print("Collection created")

    collection_info = client.get_collection(args.collection_name)
    print(collection_info)

    _search_loop(client, args.collection_name)


if __name__ == "__main__":
    main()