microsoft/TypeAgent

Public

mirrored from https://github.com/microsoft/TypeAgent · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
c00d4d0b224106d7aedb86e5161cb4ff31875909

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

python/ta/test/fixtures.py

101 lines · mode: code

1# Copyright (c) Microsoft Corporation.
2# Licensed under the MIT License.
3
4from collections.abc import AsyncGenerator, Iterator
5import os
6import tempfile
7from typing import assert_never
8
9import pytest
10import pytest_asyncio
11
12from typeagent.aitools import utils
13from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME
14from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings
15from typeagent.knowpro.interfaces import IStorageProvider
16from typeagent.knowpro.messageindex import MessageTextIndexSettings
17from typeagent.knowpro.reltermsindex import RelatedTermIndexSettings
18from typeagent.storage.memorystore import MemoryStorageProvider
19from typeagent.storage.sqlitestore import SqliteStorageProvider
20
21
@pytest.fixture(scope="session")
def needs_auth() -> None:
    """Run ``utils.load_dotenv()`` once for the whole test session.

    Tests that need credentials request this fixture by name so that the
    environment configuration is loaded before they run. (Presumably this
    reads API keys from a ``.env`` file — confirm in
    ``typeagent.aitools.utils``.) The fixture provides no value.
    """
    utils.load_dotenv()
    return None
25
26
@pytest.fixture(scope="session")
def embedding_model() -> AsyncEmbeddingModel:
    """Return one ``AsyncEmbeddingModel`` shared by every test in the session.

    The model is built from ``TEST_MODEL_NAME``, which is intended to select
    a model with a small embedding size so that tests run quickly.
    """
    model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME)
    return model
31
32
@pytest.fixture
def temp_dir() -> Iterator[str]:
    """Yield the path of a fresh temporary directory for a single test.

    ``tempfile.TemporaryDirectory`` removes the directory and everything in
    it when the ``with`` block exits. That exit happens during fixture
    teardown, after the test body has finished.
    """
    # Named ``tmp_dir`` rather than ``dir`` to avoid shadowing the builtin.
    with tempfile.TemporaryDirectory() as tmp_dir:
        yield tmp_dir
37
38
@pytest.fixture
def temp_db_path() -> Iterator[str]:
    """Create a temporary SQLite database file for testing.

    ``tempfile.mkstemp`` creates an empty ``.sqlite`` file and returns an
    open descriptor for it. The descriptor is closed immediately, so the test
    receives only the path.

    Teardown removes the database file. It also removes any SQLite sidecar
    files (``-journal``, ``-wal``, ``-shm``) that may have been created next
    to it. Whether those files appear depends on the journal mode the storage
    provider uses, which is not visible here. Files that are already gone are
    ignored, which covers tests that delete the database themselves.
    """
    fd, path = tempfile.mkstemp(suffix=".sqlite")
    os.close(fd)
    yield path
    for leftover in (path, path + "-journal", path + "-wal", path + "-shm"):
        # EAFP instead of exists()+remove(): avoids a check-then-act race and
        # deliberately treats "already removed" as success.
        try:
            os.remove(leftover)
        except FileNotFoundError:
            pass
47
48
@pytest_asyncio.fixture
async def memory_storage(embedding_model) -> MemoryStorageProvider:
    """Build an in-memory storage provider that uses the test embedding model.

    A single ``TextEmbeddingIndexSettings`` instance backs both the
    message-text index settings and the related-terms index settings.
    """
    shared_embedding = TextEmbeddingIndexSettings(embedding_model)
    provider = await MemoryStorageProvider.create(
        message_text_settings=MessageTextIndexSettings(shared_embedding),
        related_terms_settings=RelatedTermIndexSettings(shared_embedding),
    )
    return provider
60
61
@pytest_asyncio.fixture
async def sqlite_storage(
    temp_db_path, embedding_model
) -> AsyncGenerator[SqliteStorageProvider, None]:
    """Yield a SQLite-backed storage provider whose database lives at ``temp_db_path``.

    As with the in-memory fixture, a single ``TextEmbeddingIndexSettings``
    instance feeds both index settings. The provider is closed once the test
    finishes. Removing the database file itself is left to the
    ``temp_db_path`` fixture.
    """
    shared_embedding = TextEmbeddingIndexSettings(embedding_model)
    store = await SqliteStorageProvider.create(
        MessageTextIndexSettings(shared_embedding),
        RelatedTermIndexSettings(shared_embedding),
        temp_db_path,
    )
    yield store
    # Teardown: close the provider (presumably releasing its SQLite connection).
    await store.close()
76
77
@pytest_asyncio.fixture(params=["memory", "sqlite"])
async def storage_provider_type(
    request, embedding_model, temp_db_path
) -> AsyncGenerator[tuple[IStorageProvider, str], None]:
    """Yield ``(provider, kind)`` once for each storage backend.

    Every test that requests this fixture runs twice:

    * ``kind == "memory"``: a ``MemoryStorageProvider``. It is not closed
      afterwards.
    * ``kind == "sqlite"``: a ``SqliteStorageProvider`` on ``temp_db_path``.
      It is closed after the test.

    Any other parameter value reaches ``assert_never``, which raises.
    """
    kind = request.param
    embedding_settings = TextEmbeddingIndexSettings(embedding_model)
    message_text_settings = MessageTextIndexSettings(embedding_settings)
    related_terms_settings = RelatedTermIndexSettings(embedding_settings)

    if kind == "memory":
        memory_provider = await MemoryStorageProvider.create(
            message_text_settings=message_text_settings,
            related_terms_settings=related_terms_settings,
        )
        yield memory_provider, kind
    elif kind == "sqlite":
        sqlite_provider = await SqliteStorageProvider.create(
            message_text_settings, related_terms_settings, temp_db_path
        )
        yield sqlite_provider, kind
        await sqlite_provider.close()
    else:
        # Unreachable with the declared params; fails loudly if a new
        # backend name is added to ``params`` without a matching branch.
        assert_never(kind)
102