microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
5e1641e77d4742e44fcbe8e8b59e303392ecc023

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/test/fixtures.py

303 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4from collections.abc import AsyncGenerator, Iterator
5import os
6import tempfile
7from typing import assert_never
8
9import pytest
10import pytest_asyncio
11
12from typeagent.aitools import utils
13from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
14from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
15from typeagent.knowpro import kplib
16from typeagent.knowpro.collections import (
17 MemoryMessageCollection,
18 MemorySemanticRefCollection,
19)
20from typeagent.knowpro.convsettings import ConversationSettings
21from typeagent.knowpro.interfaces import (
22 DeletionInfo,
23 ICollection,
24 IConversation,
25 IConversationSecondaryIndexes,
26 IMessage,
27 IMessageCollection,
28 ISemanticRefCollection,
29 IStorageProvider,
30 ITermToSemanticRefIndex,
31 SemanticRef,
32 ScoredSemanticRefOrdinal,
33 TextLocation,
34 TextRange,
35)
36from typeagent.knowpro.kplib import KnowledgeResponse
37from typeagent.knowpro.messageindex import MessageTextIndexSettings
38from typeagent.knowpro.reltermsindex import RelatedTermIndexSettings
39from typeagent.knowpro.secindex import ConversationSecondaryIndexes
40from typeagent.storage.memorystore import MemoryStorageProvider
41from typeagent.storage.sqlitestore import SqliteStorageProvider
42
43
@pytest.fixture(scope="session")
def needs_auth() -> None:
    """Run ``utils.load_dotenv()`` once for the whole test session.

    Tests request this fixture for its side effect only; the fixture value
    itself is always None.
    """
    utils.load_dotenv()
47
48
@pytest.fixture(scope="session")
def embedding_model() -> AsyncEmbeddingModel:
    """Provide one session-wide embedding model built from ``TEST_MODEL_NAME``.

    ``TEST_MODEL_NAME`` is intended to select a small embedding size so
    tests run faster.
    """
    model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
    return model
53
54
@pytest.fixture
def temp_dir() -> Iterator[str]:
    """Yield the path of a fresh temporary directory.

    ``TemporaryDirectory`` deletes the directory and its contents when the
    ``with`` block exits, i.e. during fixture teardown.
    """
    with tempfile.TemporaryDirectory() as tmp_path:
        yield tmp_path
59
60
@pytest.fixture
def temp_db_path() -> Iterator[str]:
    """Create a temporary SQLite database file for testing.

    Yields the path of an empty file created by ``tempfile.mkstemp``.
    On teardown, the file is removed together with any SQLite sidecar files
    (``-journal``, ``-wal``, ``-shm``) that an unclosed or failed connection
    may have left next to it. This keeps the temp directory from
    accumulating leftovers.
    """
    import contextlib

    fd, path = tempfile.mkstemp(suffix=".sqlite")
    # Only the path is handed out; close the OS-level handle so it isn't leaked.
    os.close(fd)
    yield path
    for leftover in (path, path + "-journal", path + "-wal", path + "-shm"):
        # Missing files are expected (sidecars are usually never created or
        # already cleaned up), so ignore them instead of checking first.
        with contextlib.suppress(FileNotFoundError):
            os.remove(leftover)
69
70
@pytest_asyncio.fixture
async def memory_storage(embedding_model: AsyncEmbeddingModel) -> MemoryStorageProvider:
    """Provide an in-memory storage provider wired to the test embedding model.

    The message-text index settings and the related-terms index settings
    share a single ``TextEmbeddingIndexSettings`` built from
    ``embedding_model``.
    """
    shared_embedding = TextEmbeddingIndexSettings(embedding_model)
    provider = await MemoryStorageProvider.create(
        message_text_settings=MessageTextIndexSettings(shared_embedding),
        related_terms_settings=RelatedTermIndexSettings(shared_embedding),
    )
    return provider
82
83
@pytest_asyncio.fixture
async def sqlite_storage(
    temp_db_path: str, embedding_model: AsyncEmbeddingModel
) -> AsyncGenerator[SqliteStorageProvider, None]:
    """Yield a SQLite-backed storage provider whose database is ``temp_db_path``.

    Index settings share one ``TextEmbeddingIndexSettings`` built from
    ``embedding_model``. The provider is closed after the test finishes.
    """
    shared_embedding = TextEmbeddingIndexSettings(embedding_model)
    provider = await SqliteStorageProvider.create(
        MessageTextIndexSettings(shared_embedding),
        RelatedTermIndexSettings(shared_embedding),
        temp_db_path,
    )
    yield provider
    # Teardown: runs after the test body that used the provider.
    await provider.close()
98
99
@pytest_asyncio.fixture(params=["memory", "sqlite"])
async def storage_provider_type(
    request: pytest.FixtureRequest,
    embedding_model: AsyncEmbeddingModel,
    temp_db_path: str,
) -> AsyncGenerator[tuple[IStorageProvider, str], None]:
    """Yield ``(provider, backend_name)`` once for each storage backend.

    Parametrized over ``"memory"`` and ``"sqlite"``, so every dependent test
    runs against both providers. ``backend_name`` is the parameter string.
    ``temp_db_path`` is requested for both parameters but used only by the
    SQLite backend, which is also the only one explicitly closed on teardown.
    """
    backend = request.param
    shared_embedding = TextEmbeddingIndexSettings(embedding_model)
    text_settings = MessageTextIndexSettings(shared_embedding)
    terms_settings = RelatedTermIndexSettings(shared_embedding)

    if backend == "memory":
        memory_provider = await MemoryStorageProvider.create(
            message_text_settings=text_settings,
            related_terms_settings=terms_settings,
        )
        yield memory_provider, backend
    elif backend == "sqlite":
        sqlite_provider = await SqliteStorageProvider.create(
            text_settings, terms_settings, temp_db_path
        )
        yield sqlite_provider, backend
        await sqlite_provider.close()
    else:
        # Unreachable for the declared params; fails loudly if a new one is added.
        assert_never(backend)
126
127
128# Unified fake message and conversation classes for testing
129
130
class FakeMessage(IMessage):
    """Unified message implementation for testing purposes.

    Stores a list of text chunks plus the metadata fields set in
    ``__init__`` (``timestamp``, ``tags``, ``deletion_info``,
    ``text_location`` and, when an ordinal is given, ``ordinal``).
    """

    def __init__(
        self, text_chunks: list[str] | str, message_ordinal: int | None = None
    ):
        """Create a fake message.

        Args:
            text_chunks: The message text. A single string is wrapped in a
                one-element list; a list is stored as-is (not copied).
            message_ordinal: Optional ordinal. When given, it is stored as
                ``self.ordinal`` and yields a deterministic ISO timestamp
                ``message_ordinal`` hours after 2020-01-01T00:00:00. When
                omitted, ``timestamp`` is None and ``ordinal`` is not set.
        """
        from datetime import datetime, timedelta

        if isinstance(text_chunks, str):
            self.text_chunks = [text_chunks]
        else:
            self.text_chunks = text_chunks

        # Handle timestamp - for compatibility with mock message pattern.
        if message_ordinal is not None:
            self.ordinal = message_ordinal
            # Use real date arithmetic instead of formatting the ordinal as
            # the hour field. The old f"T{n:02d}:00:00" gave invalid ISO
            # strings for n >= 24 or n < 0 (e.g. "T30:00:00", "T100:00:00"),
            # and those sorted out of order as strings. For 0 <= n <= 23 the
            # result is unchanged ("2020-01-01T05:00:00" for n == 5).
            stamp = datetime(2020, 1, 1) + timedelta(hours=message_ordinal)
            self.timestamp = stamp.isoformat()
        else:
            self.timestamp = None

        self.tags: list[str] = []
        self.deletion_info: DeletionInfo | None = None
        self.text_location = TextLocation(0, 0)

    def get_knowledge(self) -> KnowledgeResponse:
        """Return an empty KnowledgeResponse (no entities, actions or topics)."""
        return KnowledgeResponse(
            entities=[],
            actions=[],
            inverse_actions=[],
            topics=[],
        )

    def get_text(self) -> str:
        """Return the text chunks joined with single spaces."""
        return " ".join(self.text_chunks)

    def get_text_location(self) -> TextLocation:
        """Return ``self.text_location`` (initialized to ``TextLocation(0, 0)``)."""
        return self.text_location
166
167
class FakeMessageCollection(MemoryMessageCollection[FakeMessage]):
    """In-memory collection of ``FakeMessage`` objects for tests.

    Adds no behavior to ``MemoryMessageCollection``; it only gives the
    fake message type a concrete, named collection class.
    """
172
173
class FakeTermIndex(ITermToSemanticRefIndex):
    """Simple dict-backed term index for testing.

    Maps each term string (stored exactly as given, no normalization) to the
    list of ``ScoredSemanticRefOrdinal`` entries recorded for it.
    """

    def __init__(
        self, term_to_refs: dict[str, list[ScoredSemanticRefOrdinal]] | None = None
    ):
        # Test only for None. `term_to_refs or {}` silently replaced a
        # caller-supplied *empty* dict with a fresh one, so the caller's dict
        # was shared when non-empty but not when empty.
        self.term_to_refs: dict[str, list[ScoredSemanticRefOrdinal]] = (
            {} if term_to_refs is None else term_to_refs
        )

    async def size(self) -> int:
        """Return the number of distinct terms (not the number of refs)."""
        return len(self.term_to_refs)

    async def get_terms(self) -> list[str]:
        """Return all terms in dict (insertion) order."""
        return list(self.term_to_refs.keys())

    async def add_term(
        self,
        term: str,
        semantic_ref_ordinal: int | ScoredSemanticRefOrdinal,
    ) -> str:
        """Append a ref under ``term`` and return ``term`` unchanged.

        A bare int ordinal is wrapped with a score of 1.0. Duplicate refs are
        not filtered.
        """
        if isinstance(semantic_ref_ordinal, int):
            scored_ref = ScoredSemanticRefOrdinal(semantic_ref_ordinal, 1.0)
        else:
            scored_ref = semantic_ref_ordinal
        self.term_to_refs.setdefault(term, []).append(scored_ref)
        return term

    async def remove_term(self, term: str, semantic_ref_ordinal: int) -> None:
        """Remove every ref to ``semantic_ref_ordinal`` recorded under ``term``.

        The term is dropped entirely once it has no refs left. Unknown terms
        are ignored.
        """
        refs = self.term_to_refs.get(term)
        if refs is None:
            return
        remaining = [
            ref for ref in refs if ref.semantic_ref_ordinal != semantic_ref_ordinal
        ]
        if remaining:
            self.term_to_refs[term] = remaining
        else:
            del self.term_to_refs[term]

    async def lookup_term(self, term: str) -> list[ScoredSemanticRefOrdinal] | None:
        """Return the stored ref list for ``term`` (not a copy), or None if absent."""
        return self.term_to_refs.get(term)
214
215
class FakeConversation(IConversation[FakeMessage, FakeTermIndex]):
    """Unified conversation implementation for testing purposes.

    Two construction modes:

    * With ``storage_provider``: ``settings`` and (optionally)
      ``secondary_indexes`` are built immediately in ``__init__``.
    * Without it: ``settings`` is not set and ``secondary_indexes`` is None
      until ``ensure_initialized()`` is awaited. That call creates the
      settings and obtains a storage provider asynchronously.
    """

    def __init__(
        self,
        name_tag: str = "FakeConversation",
        messages: list[FakeMessage] | None = None,
        semantic_refs: list[SemanticRef] | None = None,
        storage_provider: IStorageProvider | None = None,
        has_secondary_indexes: bool = True,
    ):
        """Create a fake conversation.

        Args:
            name_tag: Conversation name tag.
            messages: Initial messages. Defaults to a single
                ``FakeMessage("Hello world")``; an explicit empty list stays
                empty.
            semantic_refs: Initial semantic refs (default: none).
            storage_provider: If given, used immediately; otherwise one is
                obtained later by ``ensure_initialized()``.
            has_secondary_indexes: Whether to build
                ``ConversationSecondaryIndexes``, either now or during
                ``ensure_initialized()``.
        """
        self.name_tag = name_tag
        self.tags: list[str] = []

        # Set up messages.
        if messages is None:
            messages = [FakeMessage("Hello world")]
        self.messages: IMessageCollection[FakeMessage] = FakeMessageCollection(messages)

        # Set up semantic refs.
        self.semantic_refs: ISemanticRefCollection = MemorySemanticRefCollection(
            semantic_refs or []
        )

        # Set up term index (always a fresh in-memory fake).
        self.semantic_ref_index: FakeTermIndex | None = FakeTermIndex()

        # Recorded in both modes so the instance has the same attributes
        # either way; secondary_indexes is declared once with its full type.
        self._has_secondary_indexes = has_secondary_indexes
        self.secondary_indexes: IConversationSecondaryIndexes[FakeMessage] | None = None

        if storage_provider is None:
            # Default storage provider will be created lazily in async context.
            self._needs_async_init = True
            self._storage_provider: IStorageProvider | None = None
        else:
            test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
            # Store settings with the storage provider so it is reachable via
            # conversation.settings.storage_provider.
            self.settings = ConversationSettings(test_model, storage_provider)
            self._needs_async_init = False
            self._storage_provider = storage_provider
            if has_secondary_indexes:
                self.secondary_indexes = self._build_secondary_indexes(
                    storage_provider, test_model
                )

    @staticmethod
    def _build_secondary_indexes(
        storage_provider: IStorageProvider, model: AsyncEmbeddingModel
    ) -> IConversationSecondaryIndexes[FakeMessage]:
        """Build ConversationSecondaryIndexes over ``storage_provider`` using ``model``."""
        embedding_settings = TextEmbeddingIndexSettings(model)
        related_terms_settings = RelatedTermIndexSettings(embedding_settings)
        return ConversationSecondaryIndexes(storage_provider, related_terms_settings)

    async def ensure_initialized(self) -> None:
        """Finish lazy initialization; does nothing once it has completed.

        Does work only when the conversation was constructed without a
        storage provider. It builds ``settings`` from a fresh test embedding
        model and awaits ``settings.get_storage_provider()``. It fills in
        ``semantic_ref_index`` only if that was reset to None, and builds
        secondary indexes if they were requested.
        """
        if not self._needs_async_init:
            return
        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
        self.settings = ConversationSettings(test_model)
        storage_provider = await self.settings.get_storage_provider()
        self._storage_provider = storage_provider
        # __init__ always installs a FakeTermIndex, so this only triggers if
        # semantic_ref_index was set to None before this call.
        if self.semantic_ref_index is None:
            self.semantic_ref_index = await storage_provider.get_semantic_ref_index()  # type: ignore

        if self._has_secondary_indexes:
            self.secondary_indexes = self._build_secondary_indexes(
                storage_provider, test_model
            )
        else:
            self.secondary_indexes = None

        self._needs_async_init = False
290
291
@pytest.fixture
def fake_conversation() -> FakeConversation:
    """Provide a FakeConversation built with all default arguments.

    No storage provider is passed, so ``settings`` only exists after the
    test awaits ``ensure_initialized()`` on it.
    """
    conversation = FakeConversation()
    return conversation
296
297
@pytest_asyncio.fixture
async def fake_conversation_with_storage(
    memory_storage: MemoryStorageProvider,
) -> FakeConversation:
    """Fixture to create a FakeConversation backed by the ``memory_storage`` provider.

    This uses ``pytest_asyncio.fixture``, like the other async fixtures in
    this module. Under pytest-asyncio's strict mode, a plain
    ``@pytest.fixture`` on an ``async def`` is not awaited, so the test
    would receive a coroutine object instead of a FakeConversation.
    """
    return FakeConversation(storage_provider=memory_storage)
304