microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
fb22ace26e4ae2010db0861d7afaa20698e14528

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/test/fixtures.py

292 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4from collections.abc import AsyncGenerator, Iterator
5import os
6import tempfile
7from typing import Any, assert_never
8
9import pytest
10import pytest_asyncio
11
12from typeagent.aitools import utils
13from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
14from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
15from typeagent.storage.memory.collections import (
16 MemoryMessageCollection,
17 MemorySemanticRefCollection,
18)
19from typeagent.knowpro.convsettings import ConversationSettings
20from typeagent.knowpro.interfaces import (
21 DeletionInfo,
22 IConversation,
23 IConversationSecondaryIndexes,
24 IMessage,
25 IMessageCollection,
26 ISemanticRefCollection,
27 IStorageProvider,
28 ITermToSemanticRefIndex,
29 SemanticRef,
30 ScoredSemanticRefOrdinal,
31 TextLocation,
32)
33from typeagent.knowpro.kplib import KnowledgeResponse
34from typeagent.knowpro.convsettings import (
35 MessageTextIndexSettings,
36 RelatedTermIndexSettings,
37)
38from typeagent.knowpro.secindex import ConversationSecondaryIndexes
39from typeagent.storage.memory import MemoryStorageProvider
40from typeagent.storage import SqliteStorageProvider
41
42
@pytest.fixture(scope="session")
def needs_auth() -> None:
    """Session-scoped fixture that calls ``utils.load_dotenv()``.

    It provides no value; tests request it by name for that side effect.
    """
    utils.load_dotenv()
    return None
46
47
@pytest.fixture(scope="session")
def embedding_model() -> AsyncEmbeddingModel:
    """Provide the test embedding model, built once per test session.

    The model is constructed from ``TEST_MODEL_NAME``.
    """
    test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
    return test_model
52
53
@pytest.fixture
def temp_dir() -> Iterator[str]:
    """Yield the path of a fresh temporary directory.

    The directory is managed by ``tempfile.TemporaryDirectory`` and is cleaned
    up when the context manager exits after the test.
    """
    with tempfile.TemporaryDirectory() as tmp_dir_path:
        yield tmp_dir_path
58
59
@pytest.fixture
def temp_db_path() -> Iterator[str]:
    """Create a temporary SQLite database file for testing.

    Yields:
        The path of an empty file with a ``.sqlite`` suffix.

    After the test, the file is removed together with any ``-journal``,
    ``-wal`` or ``-shm`` companion files that exist next to it.
    """
    fd, path = tempfile.mkstemp(suffix=".sqlite")
    # Only the path is handed out, so release the descriptor immediately.
    os.close(fd)
    try:
        yield path
    finally:
        # NOTE(review): SQLite can leave "-journal"/"-wal"/"-shm" files beside
        # the database depending on journal mode; whether the provider under
        # test does so is not visible here, so removal is best-effort.
        for leftover in (path, path + "-journal", path + "-wal", path + "-shm"):
            try:
                os.remove(leftover)
            except FileNotFoundError:
                # Already gone (or never created); nothing to clean up.
                pass
68
69
@pytest.fixture
def memory_storage(
    embedding_model: AsyncEmbeddingModel,
) -> MemoryStorageProvider:
    """Build a ``MemoryStorageProvider`` configured with index settings.

    The message-text and related-terms settings are both created from the
    same ``TextEmbeddingIndexSettings`` instance, which wraps the
    ``embedding_model`` fixture.
    """
    shared_embedding = TextEmbeddingIndexSettings(embedding_model=embedding_model)
    provider = MemoryStorageProvider(
        message_text_settings=MessageTextIndexSettings(
            embedding_index_settings=shared_embedding
        ),
        related_terms_settings=RelatedTermIndexSettings(
            embedding_index_settings=shared_embedding
        ),
    )
    return provider
86
87
88# Unified fake message and conversation classes for testing
89
90
class FakeMessage(IMessage):
    """Unified message implementation for testing purposes.

    Args:
        text_chunks: Either a single string (wrapped into a one-element list)
            or a list of strings, which is stored as given.
        message_ordinal: Optional ordinal. When given, it is stored as
            ``self.ordinal`` and used to derive a synthetic timestamp; when
            omitted, ``self.ordinal`` is not set and ``self.timestamp`` is
            ``None``.
    """

    def __init__(
        self, text_chunks: list[str] | str, message_ordinal: int | None = None
    ):
        # Imported locally so this class does not depend on a change to the
        # module's import block.
        from datetime import datetime, timedelta

        if isinstance(text_chunks, str):
            self.text_chunks = [text_chunks]
        else:
            self.text_chunks = text_chunks

        # Handle timestamp - for compatibility with mock message pattern
        self.timestamp: str | None
        if message_ordinal is not None:
            self.ordinal = message_ordinal
            # One hour per ordinal, starting at 2020-01-01T00:00:00. Going
            # through datetime keeps the string a valid ISO timestamp when the
            # ordinal is 24 or more (it rolls over into the following days);
            # ordinals 0-23 yield "2020-01-01THH:00:00".
            stamp = datetime(2020, 1, 1) + timedelta(hours=message_ordinal)
            self.timestamp = stamp.isoformat()
        else:
            self.timestamp = None

        self.tags: list[str] = []
        self.deletion_info: DeletionInfo | None = None
        self.text_location = TextLocation(0, 0)

    def get_knowledge(self) -> KnowledgeResponse:
        """Return a ``KnowledgeResponse`` whose four lists are all empty."""
        return KnowledgeResponse(
            entities=[],
            actions=[],
            inverse_actions=[],
            topics=[],
        )

    def get_text(self) -> str:
        """Return all text chunks joined by single spaces."""
        return " ".join(self.text_chunks)

    def get_text_location(self) -> TextLocation:
        """Return the ``TextLocation`` created in ``__init__``."""
        return self.text_location
126
127
@pytest_asyncio.fixture
async def sqlite_storage(
    temp_db_path: str, embedding_model: AsyncEmbeddingModel
) -> AsyncGenerator[SqliteStorageProvider[FakeMessage], None]:
    """Yield a ``SqliteStorageProvider`` for ``FakeMessage`` backed by ``temp_db_path``.

    Both index settings objects are built from one shared
    ``TextEmbeddingIndexSettings``. The provider is closed after the test.
    """
    shared_embedding = TextEmbeddingIndexSettings(embedding_model)
    storage = SqliteStorageProvider(
        db_path=temp_db_path,
        message_type=FakeMessage,
        message_text_index_settings=MessageTextIndexSettings(shared_embedding),
        related_term_index_settings=RelatedTermIndexSettings(shared_embedding),
    )
    yield storage
    await storage.close()
145
146
class FakeMessageCollection(MemoryMessageCollection[FakeMessage]):
    """Message collection for testing.

    Adds no members of its own; it is ``MemoryMessageCollection`` specialised
    to ``FakeMessage``.
    """
151
152
class FakeTermIndex(ITermToSemanticRefIndex):
    """Simple dict-backed term index for testing.

    ``term_to_refs`` maps each term to the list of scored semantic-ref
    ordinals recorded for it. Serialization is not supported.
    """

    def __init__(
        self, term_to_refs: dict[str, list[ScoredSemanticRefOrdinal]] | None = None
    ):
        # Compare against None rather than relying on truthiness, so an empty
        # dict supplied by the caller is kept (and shared) instead of being
        # silently replaced by a fresh one.
        self.term_to_refs: dict[str, list[ScoredSemanticRefOrdinal]] = (
            {} if term_to_refs is None else term_to_refs
        )

    async def size(self) -> int:
        """Return the number of distinct terms in the index."""
        return len(self.term_to_refs)

    async def get_terms(self) -> list[str]:
        """Return all indexed terms as a new list."""
        return list(self.term_to_refs)

    async def add_term(
        self,
        term: str,
        semantic_ref_ordinal: int | ScoredSemanticRefOrdinal,
    ) -> str:
        """Append a reference for ``term`` and return ``term``.

        A bare ``int`` ordinal is wrapped in a ``ScoredSemanticRefOrdinal``
        with score 1.0; an already-scored ordinal is stored as given.
        """
        if isinstance(semantic_ref_ordinal, int):
            scored_ref = ScoredSemanticRefOrdinal(semantic_ref_ordinal, 1.0)
        else:
            scored_ref = semantic_ref_ordinal
        self.term_to_refs.setdefault(term, []).append(scored_ref)
        return term

    async def remove_term(self, term: str, semantic_ref_ordinal: int) -> None:
        """Drop every reference to ``semantic_ref_ordinal`` under ``term``.

        Unknown terms are ignored. A term left with no references is removed
        from the index entirely.
        """
        existing = self.term_to_refs.get(term)
        if existing is None:
            return
        remaining = [
            ref for ref in existing if ref.semantic_ref_ordinal != semantic_ref_ordinal
        ]
        if remaining:
            self.term_to_refs[term] = remaining
        else:
            del self.term_to_refs[term]

    async def clear(self) -> None:
        """Clear all terms from the index."""
        self.term_to_refs.clear()

    async def lookup_term(self, term: str) -> list[ScoredSemanticRefOrdinal] | None:
        """Return the references stored for ``term``, or ``None`` if absent."""
        return self.term_to_refs.get(term)

    async def serialize(self) -> Any:
        """Not supported by this fake.

        Raises:
            NotImplementedError: Always. It is a ``RuntimeError`` subclass, so
                callers that catch ``RuntimeError`` still catch it.
        """
        raise NotImplementedError("FakeTermIndex does not support serialize()")

    async def deserialize(self, data: Any) -> None:
        """Not supported by this fake.

        Raises:
            NotImplementedError: Always. It is a ``RuntimeError`` subclass, so
                callers that catch ``RuntimeError`` still catch it.
        """
        raise NotImplementedError("FakeTermIndex does not support deserialize()")
203
204
class FakeConversation(IConversation[FakeMessage, FakeTermIndex]):
    """Unified conversation implementation for testing purposes.

    Args:
        name_tag: Name stored on the conversation.
        messages: Initial messages; defaults to a single "Hello world" message.
        semantic_refs: Initial semantic refs; defaults to none.
        storage_provider: When given, ``settings`` and (optionally) the
            secondary indexes are created immediately. When omitted, that
            setup is deferred to ``ensure_initialized``.
        has_secondary_indexes: Whether secondary indexes should be created.
    """

    def __init__(
        self,
        name_tag: str = "FakeConversation",
        messages: list[FakeMessage] | None = None,
        semantic_refs: list[SemanticRef] | None = None,
        storage_provider: IStorageProvider | None = None,
        has_secondary_indexes: bool = True,
    ):
        self.name_tag = name_tag
        self.tags: list[str] = []

        # Set up messages
        if messages is None:
            messages = [FakeMessage("Hello world")]
        self.messages: IMessageCollection[FakeMessage] = FakeMessageCollection(messages)

        # Set up semantic refs
        self.semantic_refs: ISemanticRefCollection = MemorySemanticRefCollection(
            semantic_refs or []
        )

        # Set up term index
        self.semantic_ref_index: FakeTermIndex | None = FakeTermIndex()

        # These attributes exist on every instance regardless of which branch
        # below runs, so later code never hits a missing attribute.
        self._has_secondary_indexes = has_secondary_indexes
        self._storage_provider = storage_provider
        self.secondary_indexes: IConversationSecondaryIndexes[FakeMessage] | None = (
            None
        )

        if storage_provider is None:
            # Settings and the default storage provider are created lazily in
            # an async context; see ensure_initialized().
            self._needs_async_init = True
            return

        # Store settings with storage provider for access via
        # conversation.settings.storage_provider
        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
        self.settings = ConversationSettings(test_model, storage_provider)
        self._needs_async_init = False
        if has_secondary_indexes:
            self.secondary_indexes = self._make_secondary_indexes(
                test_model, storage_provider
            )

    @staticmethod
    def _make_secondary_indexes(
        test_model: AsyncEmbeddingModel, storage_provider: IStorageProvider
    ) -> IConversationSecondaryIndexes[FakeMessage]:
        """Build the secondary indexes for ``storage_provider`` using ``test_model``."""
        embedding_settings = TextEmbeddingIndexSettings(test_model)
        related_terms_settings = RelatedTermIndexSettings(embedding_settings)
        return ConversationSecondaryIndexes(storage_provider, related_terms_settings)

    async def ensure_initialized(self):
        """Ensure async initialization is complete.

        Does nothing unless the conversation was created without a storage
        provider and has not been initialized yet. Otherwise it creates
        ``settings``, obtains the storage provider from them, and creates the
        secondary indexes if they were requested.
        """
        if not self._needs_async_init:
            return

        test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
        self.settings = ConversationSettings(test_model)
        storage_provider = await self.settings.get_storage_provider()
        self._storage_provider = storage_provider
        if self.semantic_ref_index is None:
            self.semantic_ref_index = await storage_provider.get_semantic_ref_index()  # type: ignore

        if self._has_secondary_indexes:
            self.secondary_indexes = self._make_secondary_indexes(
                test_model, storage_provider
            )
        else:
            self.secondary_indexes = None

        self._needs_async_init = False
279
280
@pytest.fixture
def fake_conversation() -> FakeConversation:
    """Provide a ``FakeConversation`` constructed with all default arguments."""
    conversation = FakeConversation()
    return conversation
285
286
# An async fixture must be registered through pytest_asyncio: under plain
# @pytest.fixture in pytest-asyncio strict mode, tests would receive an
# un-awaited coroutine object instead of the conversation.
@pytest_asyncio.fixture
async def fake_conversation_with_storage(
    memory_storage: MemoryStorageProvider,
) -> FakeConversation:
    """Fixture to create a FakeConversation instance with storage provider.

    The conversation is built with the ``memory_storage`` fixture as its
    storage provider.
    """
    return FakeConversation(storage_provider=memory_storage)