openai/chatkit-python
Public · mirrored from https://github.com/openai/chatkit-python · Available
tests/test_store.py
457 lines · mode: code
| 1 | import sqlite3 |
| 2 | from abc import ABC, abstractmethod |
| 3 | from datetime import datetime, timedelta |
| 4 | |
| 5 | import pytest |
| 6 | from helpers.mock_store import SQLiteStore |
| 7 | from pydantic import AnyUrl |
| 8 | |
| 9 | from chatkit.store import NotFoundError, Store |
| 10 | from chatkit.types import ( |
| 11 | AssistantMessageContent, |
| 12 | AssistantMessageItem, |
| 13 | FileAttachment, |
| 14 | ImageAttachment, |
| 15 | InferenceOptions, |
| 16 | ThreadItem, |
| 17 | ThreadMetadata, |
| 18 | UserMessageItem, |
| 19 | UserMessageTextContent, |
| 20 | WidgetItem, |
| 21 | ) |
| 22 | from chatkit.widgets import Card, Col, Text |
| 23 | from tests._types import RequestContext |
| 24 | |
| 25 | |
def make_thread(thread_id="test_thread", created_at=None):
    """Build a ``ThreadMetadata`` fixture for the store tests.

    Args:
        thread_id: Identifier to give the thread.
        created_at: Creation timestamp; ``datetime.now()`` is used when omitted.

    Returns:
        A ``ThreadMetadata`` titled ``"Test Thread"`` whose metadata is
        ``{"test": "test"}``.
    """
    timestamp = datetime.now() if created_at is None else created_at
    return ThreadMetadata(
        id=thread_id,
        title="Test Thread",
        created_at=timestamp,
        metadata={"test": "test"},
    )
| 35 | |
| 36 | |
def make_thread_items(thread_id: str = "test_thread") -> list[ThreadItem]:
    """Build a user message, an assistant reply and a widget for one thread.

    The items are created one second apart (user, then assistant, then
    widget), so a store that orders by creation time returns them in list
    order. Note that the user message id ``"msg_100000"`` sorts *after* the
    assistant id ``"msg_000001"`` lexicographically. This is presumably
    intentional: a store that orders by id instead of ``created_at`` will
    not return this list unchanged.

    Args:
        thread_id: Thread id stamped on every item. Defaults to
            ``"test_thread"``, which is also ``make_thread()``'s default id.
            Pass the id of the thread the items will be added to.

    Returns:
        ``[user_msg, assistant_msg, widget]`` in creation order.
    """
    now = datetime.now()
    user_msg = UserMessageItem(
        id="msg_100000",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[
            FileAttachment(
                id="file_1",
                type="file",
                mime_type="text/plain",
                name="test.txt",
                metadata={"source": "test"},
            )
        ],
        inference_options=InferenceOptions(),
        thread_id=thread_id,
        created_at=now,
    )
    assistant_msg = AssistantMessageItem(
        id="msg_000001",
        content=[AssistantMessageContent(text="Hi there!")],
        thread_id=thread_id,
        created_at=now + timedelta(seconds=1),
    )
    widget = WidgetItem(
        id="widget_1",
        type="widget",
        thread_id=thread_id,
        created_at=now + timedelta(seconds=2),
        widget=Card(
            children=[
                Col(
                    padding={"x": 4, "y": 3},
                    border={"bottom": 1},
                    children=[
                        Text(
                            value="Title",
                            weight="medium",
                            size="sm",
                            color="secondary",
                        ),
                        Text(
                            value="test",
                        ),
                    ],
                ),
            ]
        ),
    )

    return [user_msg, assistant_msg, widget]
| 88 | |
| 89 | |
# Request contexts for two distinct users. DEFAULT_CONTEXT owns the data each
# test creates; ALTERNATIVE_CONTEXT is used by the isolation tests to act as a
# different user.
DEFAULT_CONTEXT, ALTERNATIVE_CONTEXT = (
    RequestContext(user_id="test_user"),
    RequestContext(user_id="alternative_user"),
)
| 92 | |
| 93 | |
class TestStore(ABC):
    """Backend-agnostic contract tests for ``Store`` implementations.

    Concrete subclasses assign ``self.store`` in ``setup_method`` and clean up
    in ``teardown_method``. Several tests (e.g. ``test_load_threads``) assert
    the exact set of threads returned, so each test needs an empty store.
    """

    store: Store

    @abstractmethod
    def setup_method(self, method):
        """Create ``self.store`` for the test ``method`` about to run."""

    @abstractmethod
    def teardown_method(self, method):
        """Release whatever ``setup_method`` acquired."""

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata(self):
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.id == thread.id
        assert loaded_meta.title == thread.title
        assert loaded_meta.metadata == thread.metadata

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata_null_title(self):
        thread = ThreadMetadata(
            id="test_thread",
            title=None,
            created_at=datetime.now(),
            metadata={"test": "test"},
        )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.id == thread.id
        assert loaded_meta.title is None
        assert loaded_meta.metadata == thread.metadata

    @pytest.mark.asyncio
    async def test_update_thread_metadata(self):
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        thread.title = "Updated Title"
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.title == "Updated Title"

    @pytest.mark.asyncio
    async def test_save_and_load_thread_items(self):
        thread = make_thread()
        items = make_thread_items()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)
        loaded_items = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", DEFAULT_CONTEXT
            )
        ).data
        assert loaded_items == items

    @pytest.mark.asyncio
    async def test_overwrite_thread_metadata(self):
        # Goes beyond test_update_thread_metadata: a second save of the same
        # thread id must replace every field (not only the title) and must not
        # leave a duplicate thread behind.
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        thread.title = "Updated Title"
        thread.metadata = {"test": "updated"}
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.title == "Updated Title"
        assert loaded_meta.metadata == {"test": "updated"}

        threads = await self.store.load_threads(
            limit=10, after=None, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in threads.data] == [thread.id]

    @pytest.mark.asyncio
    async def test_save_and_load_file(self):
        file = FileAttachment(
            id="file_1",
            type="file",
            mime_type="text/plain",
            name="test.txt",
        )
        await self.store.save_attachment(file, DEFAULT_CONTEXT)
        loaded = await self.store.load_attachment(file.id, DEFAULT_CONTEXT)
        assert loaded == file

    @pytest.mark.asyncio
    async def test_save_and_load_image(self):
        image = ImageAttachment(
            id="image_1",
            type="image",
            mime_type="image/png",
            name="test.png",
            preview_url=AnyUrl("https://example.com/test.png"),
            metadata={"source": "test"},
        )
        await self.store.save_attachment(image, DEFAULT_CONTEXT)
        loaded = await self.store.load_attachment(image.id, DEFAULT_CONTEXT)
        assert loaded == image

    @pytest.mark.asyncio
    async def test_load_threads(self):
        now = datetime.now()
        thread1 = make_thread(thread_id="thread1", created_at=now)
        thread2 = make_thread(
            thread_id="thread2", created_at=now + timedelta(seconds=1)
        )
        thread3 = make_thread(
            thread_id="thread3", created_at=now + timedelta(seconds=2)
        )
        await self.store.save_thread(thread1, DEFAULT_CONTEXT)
        await self.store.save_thread(thread2, DEFAULT_CONTEXT)
        await self.store.save_thread(thread3, DEFAULT_CONTEXT)

        page1 = await self.store.load_threads(
            limit=2, after=None, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in page1.data] == ["thread1", "thread2"]
        assert page1.has_more is True
        assert page1.after == "thread2"

        # Continue from the cursor the store handed back.
        page2 = await self.store.load_threads(
            limit=2, after=page1.after, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in page2.data] == ["thread3"]
        assert page2.has_more is False
        assert not page2.after

    @pytest.mark.asyncio
    async def test_thread_items_ordering(self):
        thread = make_thread()
        now = datetime.now()
        items = [
            UserMessageItem(
                id="msg1",
                content=[UserMessageTextContent(text="A")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now,
            ),
            UserMessageItem(
                id="msg2",
                content=[UserMessageTextContent(text="B")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=1),
            ),
            UserMessageItem(
                id="msg3",
                content=[UserMessageTextContent(text="C")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=2),
            ),
        ]
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        asc = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        desc = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="desc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in asc.data] == ["msg1", "msg2", "msg3"]
        assert [i.id for i in desc.data] == ["msg3", "msg2", "msg1"]

    @pytest.mark.asyncio
    async def test_thread_items_offset(self):
        thread = make_thread()
        now = datetime.now()
        items = [
            UserMessageItem(
                id=f"msg{i}",
                content=[UserMessageTextContent(text=str(i))],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=i),
            )
            for i in range(5)
        ]
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        # Starting after msg1 with limit=3 reaches the last item exactly, so
        # there is no further page.
        after = await self.store.load_thread_items(
            thread.id, after="msg1", limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in after.data] == ["msg2", "msg3", "msg4"]
        assert after.has_more is False
        assert not after.after

        after_limit = await self.store.load_thread_items(
            thread.id, after=None, limit=2, order="asc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in after_limit.data] == ["msg0", "msg1"]
        assert after_limit.has_more is True
        assert after_limit.after == "msg1"

    @pytest.mark.asyncio
    async def test_save_and_load_item(self):
        # save_item on an id that was already added must store the new content,
        # for both message and widget items.
        thread = make_thread()
        now = datetime.now()
        assistant_msg = AssistantMessageItem(
            id="msg_000001",
            content=[],
            thread_id=thread.id,
            created_at=now,
        )
        widget = WidgetItem(
            id="widget_1",
            type="widget",
            thread_id=thread.id,
            created_at=now + timedelta(seconds=1),
            widget=Card(children=[Text(value="Test")]),
        )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, assistant_msg, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, widget, DEFAULT_CONTEXT)

        assistant_msg.content = [
            AssistantMessageContent(text="This is an assistant message.")
        ]
        await self.store.save_item(thread.id, assistant_msg, DEFAULT_CONTEXT)
        loaded_assistant = await self.store.load_item(
            thread.id, assistant_msg.id, DEFAULT_CONTEXT
        )
        assert isinstance(loaded_assistant, AssistantMessageItem)
        assert loaded_assistant.id == assistant_msg.id
        assert loaded_assistant.content[0] == AssistantMessageContent(
            text="This is an assistant message."
        )

        await self.store.save_item(thread.id, widget, DEFAULT_CONTEXT)
        loaded_widget = await self.store.load_item(
            thread.id, widget.id, DEFAULT_CONTEXT
        )
        assert isinstance(loaded_widget, WidgetItem)
        assert loaded_widget.id == widget.id

    @pytest.mark.asyncio
    async def test_load_nonexistent_item(self):
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        with pytest.raises(NotFoundError):
            await self.store.load_item(thread.id, "does_not_exist", DEFAULT_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_isolation_by_user(self):
        # Create a thread for DEFAULT_CONTEXT
        thread = make_thread(thread_id="user1_thread")
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # Should be accessible by the same user
        loaded_default = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_default.title == thread.title

        # Should not be accessible by another user
        with pytest.raises(NotFoundError):
            await self.store.load_thread(thread.id, ALTERNATIVE_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_items_isolation_by_user(self):
        # make_thread_items() stamps thread_id="test_thread" on every item, so
        # the thread uses that same (default) id; otherwise the items would
        # claim to belong to a different thread than the one they're added to.
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # Add items for DEFAULT_CONTEXT
        items_default = make_thread_items()
        for item in items_default:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        # Should be accessible by the same user
        loaded_items_default = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", DEFAULT_CONTEXT
            )
        ).data
        assert loaded_items_default == items_default

        # Should not be accessible by another user (should return empty list)
        loaded_items_alternative = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", ALTERNATIVE_CONTEXT
            )
        ).data
        assert loaded_items_alternative == []

        # Loading a specific item with another user's context should raise NotFoundError
        with pytest.raises(NotFoundError):
            await self.store.load_item(
                thread.id, items_default[0].id, ALTERNATIVE_CONTEXT
            )
| 383 | |
| 384 | |
class TestSqliteStore(TestStore):
    """Run the shared ``TestStore`` suite against the SQLite-backed mock store."""

    def setup_method(self, method):
        # A named, shared-cache, in-memory database per test method, so tests
        # do not see each other's data.
        uri = f"file:{method.__name__}?mode=memory&cache=shared"
        # SQLite deletes a shared in-memory database once its last connection
        # closes; holding this connection open keeps it alive for the test.
        self.db = sqlite3.connect(uri, uri=True)
        self.store = SQLiteStore(uri)

    def teardown_method(self, method):
        # Drop the keep-alive connection opened in setup_method.
        self.db.close()

    @pytest.mark.asyncio
    async def test_default_id_generators_use_default_prefixes(self):
        thread_id = self.store.generate_thread_id(DEFAULT_CONTEXT)
        assert thread_id.startswith("thr_")

        thread = make_thread(thread_id=thread_id)
        # Item type -> prefix the default generator is expected to emit.
        expected_prefixes = {
            "message": "msg_",
            "tool_call": "tc_",
            "task": "tsk_",
            "workflow": "wf_",
        }
        for item_type, prefix in expected_prefixes.items():
            generated = self.store.generate_item_id(item_type, thread, DEFAULT_CONTEXT)
            assert generated.startswith(prefix)
| 406 | |
| 407 | |
class TestSqliteStoreCustomIds(TestStore):
    """Run the shared suite against a SQLite store that overrides ID generation.

    It also checks that the overridden ``generate_thread_id`` and
    ``generate_item_id`` are the ones invoked, for every item type that
    ``TestSqliteStore`` covers with the default generators.
    """

    def setup_method(self, method):
        # The "_custom" suffix keeps these databases distinct from the ones
        # TestSqliteStore opens for the same inherited test names.
        db_path = f"file:{method.__name__}_custom?mode=memory&cache=shared"
        # Keep the shared in-memory database from being deleted when the last connection is closed
        self.db = sqlite3.connect(db_path, uri=True)

        class CustomSQLiteStore(SQLiteStore):
            """SQLiteStore whose thread and item IDs share one increasing counter."""

            def __init__(self, path: str):
                super().__init__(path)
                self.counter = 0

            def generate_thread_id(self, context) -> str:
                self.counter += 1
                return f"thr_custom_{self.counter}"

            def generate_item_id(
                self, item_type: str, thread: ThreadMetadata, context
            ) -> str:
                self.counter += 1
                return f"{item_type}_custom_{self.counter}_{thread.id}"

        self.store = CustomSQLiteStore(db_path)

    def teardown_method(self, method):
        self.db.close()

    @pytest.mark.asyncio
    async def test_overridden_id_generators_are_used(self):
        # The counter is shared, so every generated ID (thread or item)
        # advances it by one; the expected numbers below follow call order.
        ctx = DEFAULT_CONTEXT
        thread_id = self.store.generate_thread_id(ctx)
        assert thread_id == "thr_custom_1"

        thread = make_thread(thread_id=thread_id)
        msg_id = self.store.generate_item_id("message", thread, ctx)
        tool_call_id = self.store.generate_item_id("tool_call", thread, ctx)
        task_id = self.store.generate_item_id("task", thread, ctx)
        workflow_id = self.store.generate_item_id("workflow", thread, ctx)

        assert msg_id == "message_custom_2_thr_custom_1"
        assert tool_call_id == "tool_call_custom_3_thr_custom_1"
        assert task_id == "task_custom_4_thr_custom_1"
        assert workflow_id == "workflow_custom_5_thr_custom_1"

        thread2 = make_thread(thread_id=self.store.generate_thread_id(ctx))
        assert thread2.id == "thr_custom_6"
        msg_id2 = self.store.generate_item_id("message", thread2, ctx)
        tool_call_id2 = self.store.generate_item_id("tool_call", thread2, ctx)
        task_id2 = self.store.generate_item_id("task", thread2, ctx)
        workflow_id2 = self.store.generate_item_id("workflow", thread2, ctx)

        assert msg_id2 == "message_custom_7_thr_custom_6"
        assert tool_call_id2 == "tool_call_custom_8_thr_custom_6"
        assert task_id2 == "task_custom_9_thr_custom_6"
        assert workflow_id2 == "workflow_custom_10_thr_custom_6"
| 458 | |