microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
3d6af4fbf8c1941877d71e910be2718ab9fd3bf6

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/nprData/qdrant_handler.py

85 lines · mode: code

# Copyright (c) Microsoft Corporation and Henry Lucco.
# Licensed under the MIT License.

"""Load NPR transcript chunks into a Qdrant collection and search it interactively.

Run as a script. If the ``npr`` collection does not exist yet, the chunks in
``npr_chunks.json`` are uploaded into a new cosine-distance collection. The
script then prompts for queries, embeds each one, and prints the closest chunks.
"""

import json
import os
import shutil

from dotenv import load_dotenv
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams
from tqdm import tqdm

from embedding import Embedding
from structs import Chunk

COLLECTION_NAME = "npr"
CHUNKS_PATH = "npr_chunks.json"
ENV_FILE = "env_vars"
# Points sent per upsert request. Sending one point per request is very slow.
# Very large batches risk exceeding the server's request-size limit.
UPSERT_BATCH_SIZE = 64
EXIT_COMMANDS = frozenset({"exit", "q", "quit"})


def _load_chunks(path: str) -> list:
    """Read the JSON list at *path* and rebuild each entry via ``Chunk.from_dict``."""
    # JSON is UTF-8 by spec, so don't depend on the platform default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return [Chunk.from_dict(item) for item in json.load(f)]


def _chunk_to_point(point_id: int, chunk: Chunk) -> dict:
    """Build the Qdrant point (id, vector, payload) for a single chunk."""
    return {
        "id": point_id,
        "vector": chunk.embedding.values,
        "payload": {
            "speaker": chunk.speaker,
            "content": chunk.content,
            "episode_id": chunk.episode_id,
            "section_id": chunk.section_id,
            "section_title": chunk.section_title,
            "speaker_role": chunk.speaker_role,
        },
    }


def _create_collection(
    client: QdrantClient, chunks: list, batch_size: int = UPSERT_BATCH_SIZE
) -> None:
    """Create the collection, sized from the first chunk's embedding, and upload all chunks.

    Args:
        client: Connected Qdrant client.
        chunks: Chunks to upload. Their position in the list becomes the point id.
        batch_size: Number of points per upsert request.

    Raises:
        ValueError: If *chunks* is empty (the vector size cannot be inferred)
            or *batch_size* is not positive.
    """
    if not chunks:
        raise ValueError(
            f"No chunks to upload; cannot infer the vector size for '{COLLECTION_NAME}'"
        )
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    client.create_collection(
        COLLECTION_NAME,
        vectors_config=VectorParams(
            size=chunks[0].embedding.dimension,
            distance=Distance.COSINE,
        ),
    )

    points = [_chunk_to_point(i, chunk) for i, chunk in enumerate(chunks)]
    # Upload in batches. The progress bar still counts individual points.
    with tqdm(total=len(points)) as progress:
        for start in range(0, len(points), batch_size):
            batch = points[start:start + batch_size]
            client.upsert(collection_name=COLLECTION_NAME, wait=True, points=batch)
            progress.update(len(batch))

    print(f"Upserted {len(points)} points")


def _search_loop(client: QdrantClient) -> None:
    """Prompt for queries until an exit command or end of input; print the hits."""
    while True:
        try:
            query = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D, Ctrl-C, or a closed stdin ends the session without a traceback.
            print()
            break
        if not query:
            continue
        if query in EXIT_COMMANDS:
            break

        query_vector = Embedding.from_text(query).values
        # NOTE(review): newer qdrant-client releases deprecate `search` in favour
        # of `query_points`. Kept as-is for the client version this was written
        # against; confirm before upgrading.
        results = client.search(COLLECTION_NAME, query_vector)

        # shutil falls back to $COLUMNS or 80 columns when stdout is not a
        # terminal. os.get_terminal_size raises OSError in that case.
        separator = "=" * shutil.get_terminal_size().columns
        print(separator)
        for rank, result in enumerate(results, start=1):
            payload = result.payload or {}
            speaker = (payload.get("speaker") or "unknown").title()
            print(
                f"{rank}. {result.id} {speaker} : {payload.get('content')} "
                f"[{payload.get('section_title')}]"
            )
            print(separator)


def main() -> None:
    """Connect to Qdrant, populate the collection if needed, then run the search prompt.

    Raises:
        ValueError: If ``VECTOR_DB_URI`` is not set, or the chunk file is empty
            when the collection must be created.
    """
    load_dotenv(ENV_FILE)
    uri = os.environ.get("VECTOR_DB_URI")
    if not uri:
        raise ValueError("VECTOR_DB_URI environment variable is not set")

    client = QdrantClient(uri)

    if client.collection_exists(COLLECTION_NAME):
        print(f"Using existing collection '{COLLECTION_NAME}'")
    else:
        print("Loading chunks...")
        chunks = _load_chunks(CHUNKS_PATH)
        print(f"{len(chunks)} Chunks loaded")
        _create_collection(client, chunks)
        print("Collection created")

    print(client.get_collection(COLLECTION_NAME))
    _search_loop(client)


if __name__ == "__main__":
    main()