microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent (Available)

CodeCommitsIssuesPull requestsActionsInsightsSecurity
ca5f385b23f2caef43ded65cb3d1fed21117e6cb

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/test/fixtures.py

276 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4from collections.abc import AsyncGenerator, Iterator
5import os
6import tempfile
7from typing import assert_never
8
9import pytest
10import pytest_asyncio
11
12from typeagent.aitools import utils
13from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
14from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
15from typeagent.knowpro.collections import (
16 MemoryMessageCollection,
17 MemorySemanticRefCollection,
18)
19from typeagent.knowpro.convsettings import ConversationSettings
20from typeagent.knowpro.interfaces import (
21 DeletionInfo,
22 IConversation,
23 IConversationSecondaryIndexes,
24 IMessage,
25 IMessageCollection,
26 ISemanticRefCollection,
27 IStorageProvider,
28 ITermToSemanticRefIndex,
29 SemanticRef,
30 ScoredSemanticRefOrdinal,
31 TextLocation,
32)
33from typeagent.knowpro.kplib import KnowledgeResponse
34from typeagent.knowpro.messageindex import MessageTextIndexSettings
35from typeagent.knowpro.reltermsindex import RelatedTermIndexSettings
36from typeagent.knowpro.secindex import ConversationSecondaryIndexes
37from typeagent.storage.memorystore import MemoryStorageProvider
38from typeagent.storage.sqlitestore import SqliteStorageProvider
39
40
@pytest.fixture(scope="session")
def needs_auth() -> None:
    """Call ``utils.load_dotenv()`` once per test session.

    The fixture yields no value; tests request it purely for this side
    effect (presumably to make credentials available -- see
    ``typeagent.aitools.utils`` for what is actually loaded).
    """
    utils.load_dotenv()
44
45
@pytest.fixture(scope="session")
def embedding_model() -> AsyncEmbeddingModel:
    """Session-wide test embedding model (small embedding size, for faster tests).

    Returns:
        An ``AsyncEmbeddingModel`` constructed with ``TEST_MODEL_NAME``.
    """
    model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
    return model
50
51
@pytest.fixture
def temp_dir() -> Iterator[str]:
    """Yield the path of a fresh temporary directory.

    The directory is managed by ``tempfile.TemporaryDirectory`` and is
    removed when the fixture is torn down.
    """
    with tempfile.TemporaryDirectory() as tmp_path:
        yield tmp_path
56
57
@pytest.fixture
def temp_db_path() -> Iterator[str]:
    """Yield the path of a temporary ``.sqlite`` file for a test to use.

    The file is created empty via ``tempfile.mkstemp``; its descriptor is
    closed right away so the test can open the path itself. On teardown the
    file is deleted if it still exists.
    """
    handle, db_file = tempfile.mkstemp(suffix=".sqlite")
    # Only the path is needed; release the OS-level descriptor immediately.
    os.close(handle)

    yield db_file

    # The test may already have removed the file, so check before deleting.
    if os.path.exists(db_file):
        os.remove(db_file)
66
67
@pytest_asyncio.fixture
async def memory_storage(embedding_model: AsyncEmbeddingModel) -> MemoryStorageProvider:
    """Build a ``MemoryStorageProvider`` wired to the shared test embedding model.

    Both the message-text settings and the related-terms settings are derived
    from one ``TextEmbeddingIndexSettings`` instance.
    """
    index_settings = TextEmbeddingIndexSettings(embedding_model)
    return MemoryStorageProvider(
        message_text_settings=MessageTextIndexSettings(index_settings),
        related_terms_settings=RelatedTermIndexSettings(index_settings),
    )
79
80
81# Unified fake message and conversation classes for testing
82
83
class FakeMessage(IMessage):
    """In-memory ``IMessage`` stand-in shared by the tests.

    Holds a list of text chunks plus the bookkeeping attributes set in
    ``__init__`` (``timestamp``, ``tags``, ``deletion_info``,
    ``text_location``). It never reports any extracted knowledge.
    """

    def __init__(
        self, text_chunks: list[str] | str, message_ordinal: int | None = None
    ):
        """Create a fake message.

        Args:
            text_chunks: Either a list of chunks (stored as given, not
                copied) or a single string, which becomes a one-item list.
            message_ordinal: Optional ordinal. When given it is stored as
                ``self.ordinal`` and used as the hour field of a synthetic
                ``2020-01-01`` timestamp; when omitted, ``timestamp`` is
                ``None`` and ``ordinal`` is not set at all.
        """
        self.text_chunks = (
            [text_chunks] if isinstance(text_chunks, str) else text_chunks
        )

        # Timestamp handling mirrors the older mock-message pattern.
        if message_ordinal is None:
            self.timestamp = None
        else:
            self.ordinal = message_ordinal
            self.timestamp = f"2020-01-01T{message_ordinal:02d}:00:00"

        self.tags: list[str] = []
        self.deletion_info: DeletionInfo | None = None
        self.text_location = TextLocation(0, 0)

    def get_knowledge(self) -> KnowledgeResponse:
        """Return an empty ``KnowledgeResponse`` (no entities, actions or topics)."""
        empty = KnowledgeResponse(
            entities=[],
            actions=[],
            inverse_actions=[],
            topics=[],
        )
        return empty

    def get_text(self) -> str:
        """Return all text chunks joined with single spaces."""
        separator = " "
        return separator.join(self.text_chunks)

    def get_text_location(self) -> TextLocation:
        """Return the ``TextLocation`` created in ``__init__``."""
        return self.text_location
119
120
@pytest_asyncio.fixture
async def sqlite_storage(
    temp_db_path: str, embedding_model: AsyncEmbeddingModel
) -> AsyncGenerator[SqliteStorageProvider[FakeMessage], None]:
    """Yield a ``SqliteStorageProvider`` backed by the temporary database file.

    The provider is created for ``FakeMessage`` messages, using settings
    derived from the shared test embedding model, and is closed on teardown.
    """
    index_settings = TextEmbeddingIndexSettings(embedding_model)
    storage = await SqliteStorageProvider.create(
        MessageTextIndexSettings(index_settings),
        RelatedTermIndexSettings(index_settings),
        temp_db_path,
        FakeMessage,
    )
    yield storage
    # Teardown: close the provider once the test is done with it.
    await storage.close()
135
136
class FakeMessageCollection(MemoryMessageCollection[FakeMessage]):
    """``MemoryMessageCollection`` specialised to ``FakeMessage``; adds no behavior."""
141
142
class FakeTermIndex(ITermToSemanticRefIndex):
    """Dictionary-backed term index for tests.

    ``term_to_refs`` maps each term to the list of scored semantic-ref
    ordinals that were added for it. No normalisation of terms is performed.
    """

    def __init__(
        self, term_to_refs: dict[str, list[ScoredSemanticRefOrdinal]] | None = None
    ):
        """Create the index.

        Args:
            term_to_refs: Optional initial mapping. It is used as-is (not
                copied), so the caller and the index share the same dict.
                ``None`` starts with a fresh empty mapping.
        """
        # Test against None rather than truthiness: an *empty* dict supplied
        # by the caller must be kept and shared, not silently swapped for a
        # new one.
        self.term_to_refs: dict[str, list[ScoredSemanticRefOrdinal]] = (
            {} if term_to_refs is None else term_to_refs
        )

    async def size(self) -> int:
        """Return the number of distinct terms in the index."""
        return len(self.term_to_refs)

    async def get_terms(self) -> list[str]:
        """Return all indexed terms as a new list."""
        return list(self.term_to_refs)

    async def add_term(
        self,
        term: str,
        semantic_ref_ordinal: int | ScoredSemanticRefOrdinal,
    ) -> str:
        """Append a reference for ``term`` and return the term.

        A plain ``int`` ordinal is wrapped in a ``ScoredSemanticRefOrdinal``
        with score ``1.0``; an already-scored ordinal is stored unchanged.
        Duplicates are not filtered.
        """
        if isinstance(semantic_ref_ordinal, int):
            scored_ref = ScoredSemanticRefOrdinal(semantic_ref_ordinal, 1.0)
        else:
            scored_ref = semantic_ref_ordinal
        self.term_to_refs.setdefault(term, []).append(scored_ref)
        return term

    async def remove_term(self, term: str, semantic_ref_ordinal: int) -> None:
        """Remove every reference to ``semantic_ref_ordinal`` under ``term``.

        Unknown terms are ignored. A term whose reference list becomes empty
        is dropped from the index entirely.
        """
        refs = self.term_to_refs.get(term)
        if refs is None:
            return
        remaining = [
            ref for ref in refs if ref.semantic_ref_ordinal != semantic_ref_ordinal
        ]
        if remaining:
            self.term_to_refs[term] = remaining
        else:
            del self.term_to_refs[term]

    async def clear(self) -> None:
        """Clear all terms from the index."""
        self.term_to_refs.clear()

    async def lookup_term(self, term: str) -> list[ScoredSemanticRefOrdinal] | None:
        """Return the stored reference list for ``term``, or ``None`` if absent."""
        return self.term_to_refs.get(term)
187
188
class FakeConversation(IConversation[FakeMessage, FakeTermIndex]):
    """In-memory ``IConversation`` stand-in shared by the tests.

    When a storage provider is passed to the constructor, the settings and
    (optionally) the secondary indexes are built immediately. Otherwise that
    work is deferred to ``ensure_initialized``, which must be awaited before
    ``settings`` is used.
    """

    def __init__(
        self,
        name_tag: str = "FakeConversation",
        messages: list[FakeMessage] | None = None,
        semantic_refs: list[SemanticRef] | None = None,
        storage_provider: IStorageProvider | None = None,
        has_secondary_indexes: bool = True,
    ):
        """Create a fake conversation.

        Args:
            name_tag: Name stored on the conversation.
            messages: Initial messages; ``None`` yields one "Hello world"
                message.
            semantic_refs: Initial semantic refs; ``None`` or an empty list
                yields an empty collection.
            storage_provider: Provider to build settings from right away.
                ``None`` defers setup to ``ensure_initialized``.
            has_secondary_indexes: Whether secondary indexes are created
                (immediately or during deferred setup).
        """
        self.name_tag = name_tag
        self.tags: list[str] = []

        # Messages: fall back to a single greeting when none are supplied.
        initial_messages = (
            [FakeMessage("Hello world")] if messages is None else messages
        )
        self.messages: IMessageCollection[FakeMessage] = FakeMessageCollection(
            initial_messages
        )

        # Semantic refs and the term index.
        self.semantic_refs: ISemanticRefCollection = MemorySemanticRefCollection(
            semantic_refs or []
        )
        self.semantic_ref_index: FakeTermIndex | None = FakeTermIndex()

        if storage_provider is None:
            # No provider yet: the default one is created lazily, in an async
            # context, by ensure_initialized().
            self._needs_async_init = True
            self.secondary_indexes: (
                IConversationSecondaryIndexes[FakeMessage] | None
            ) = None
            self._storage_provider = None
            self._has_secondary_indexes = has_secondary_indexes
            return

        # A provider was given: build settings now so that it is reachable
        # via conversation.settings.storage_provider.
        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
        self.settings = ConversationSettings(test_model, storage_provider)
        self._needs_async_init = False
        self._storage_provider = storage_provider

        if not has_secondary_indexes:
            self.secondary_indexes = None
            return

        related_terms_settings = RelatedTermIndexSettings(
            TextEmbeddingIndexSettings(test_model)
        )
        self.secondary_indexes = ConversationSecondaryIndexes(
            storage_provider, related_terms_settings
        )

    async def ensure_initialized(self):
        """Run the deferred setup, if any is pending.

        Does nothing when the conversation was constructed with a storage
        provider or when this method already completed. Otherwise it builds
        ``settings`` from a test embedding model, obtains the storage
        provider from those settings, and creates the secondary indexes when
        they were requested at construction time.
        """
        if not self._needs_async_init:
            return

        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
        self.settings = ConversationSettings(test_model)
        provider = await self.settings.get_storage_provider()
        self._storage_provider = provider

        # Only fetch the provider's index if a test cleared the fake one.
        if self.semantic_ref_index is None:
            self.semantic_ref_index = await provider.get_semantic_ref_index()  # type: ignore

        if self._has_secondary_indexes:
            related_terms_settings = RelatedTermIndexSettings(
                TextEmbeddingIndexSettings(test_model)
            )
            self.secondary_indexes = ConversationSecondaryIndexes(
                provider, related_terms_settings
            )
        else:
            self.secondary_indexes = None

        # Mark setup as done so later calls return immediately.
        self._needs_async_init = False
263
264
@pytest.fixture
def fake_conversation() -> FakeConversation:
    """Provide a ``FakeConversation`` built entirely from constructor defaults.

    No storage provider is passed, so settings and secondary indexes are not
    available until the test awaits ``ensure_initialized()``.
    """
    conversation = FakeConversation()
    return conversation
269
270
# Async fixtures must be registered through pytest_asyncio (as the other
# async fixtures in this file are); with plain @pytest.fixture in strict
# mode, tests would receive an un-awaited coroutine instead of the object.
@pytest_asyncio.fixture
async def fake_conversation_with_storage(
    memory_storage: MemoryStorageProvider,
) -> FakeConversation:
    """Provide a ``FakeConversation`` backed by the in-memory storage provider.

    Because a provider is supplied, the conversation's settings and secondary
    indexes are built in the constructor and no deferred setup is pending.
    """
    return FakeConversation(storage_provider=memory_storage)