openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.3.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_store.py

448 lines · mode: blame

f688d870victor-openai8 months ago1import sqlite3
2from abc import ABC, abstractmethod
3from datetime import datetime, timedelta
4
5import pytest
6from helpers.mock_store import SQLiteStore
7from pydantic import AnyUrl
8
9from chatkit.store import NotFoundError, Store
10from chatkit.types import (
11AssistantMessageContent,
12AssistantMessageItem,
13FileAttachment,
14ImageAttachment,
15InferenceOptions,
16ThreadItem,
17ThreadMetadata,
18UserMessageItem,
19UserMessageTextContent,
20WidgetItem,
21)
22from chatkit.widgets import Card, Col, Text
23from tests._types import RequestContext
24
25
def make_thread(thread_id="test_thread", created_at=None):
    """Return a ``ThreadMetadata`` fixture titled "Test Thread".

    ``thread_id`` defaults to ``"test_thread"``. When ``created_at`` is not
    supplied, the current local (naive) time from ``datetime.now()`` is used.
    The metadata dict is always ``{"test": "test"}``.
    """
    timestamp = datetime.now() if created_at is None else created_at
    return ThreadMetadata(
        created_at=timestamp,
        id=thread_id,
        metadata={"test": "test"},
        title="Test Thread",
    )
35
36
def make_thread_items(thread_id: str = "test_thread") -> list[ThreadItem]:
    """Build a user message, an assistant reply and a widget for one thread.

    Args:
        thread_id: Value stamped on every item's ``thread_id`` field. The
            default ``"test_thread"`` matches the default id used by
            ``make_thread``, so existing callers are unaffected. Pass
            ``thread.id`` when the items belong to a differently named thread,
            so the items agree with the thread they are stored under.

    Returns:
        ``[user_msg, assistant_msg, widget]``. Their ``created_at`` values are
        the current local time plus 0, 1 and 2 seconds respectively.
    """
    now = datetime.now()
    # NOTE: the message ids do not sort lexicographically in creation order
    # ("msg_100000" > "msg_000001"). This is presumably intentional, so that
    # ordering checks on these items depend on created_at rather than on id.
    user_msg = UserMessageItem(
        id="msg_100000",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread_id,
        created_at=now,
    )
    assistant_msg = AssistantMessageItem(
        id="msg_000001",
        content=[AssistantMessageContent(text="Hi there!")],
        thread_id=thread_id,
        created_at=now + timedelta(seconds=1),
    )
    widget = WidgetItem(
        id="widget_1",
        type="widget",
        thread_id=thread_id,
        created_at=now + timedelta(seconds=2),
        # A small nested widget tree, so storage round-trips exercise
        # non-trivial widget serialization.
        widget=Card(
            children=[
                Col(
                    padding={"x": 4, "y": 3},
                    border={"bottom": 1},
                    children=[
                        Text(
                            value="Title",
                            weight="medium",
                            size="sm",
                            color="secondary",
                        ),
                        Text(
                            value="test",
                        ),
                    ],
                ),
            ]
        ),
    )

    return [user_msg, assistant_msg, widget]
80
81
# Request context of the user that owns the data created by most tests.
DEFAULT_CONTEXT: RequestContext = RequestContext(user_id="test_user")
# A second, distinct user, used to check that data is isolated per user.
ALTERNATIVE_CONTEXT: RequestContext = RequestContext(user_id="alternative_user")
84
85
class TestStore(ABC):
    """Contract tests shared by every ``Store`` implementation.

    Concrete subclasses implement ``setup_method``/``teardown_method``.
    ``setup_method`` assigns a fresh ``self.store`` before each test, and
    ``teardown_method`` cleans up afterwards. Every ``test_*`` coroutine below
    then runs against that store.
    """

    store: Store

    @abstractmethod
    def setup_method(self, method):
        """Create ``self.store`` before the test ``method`` runs."""

    @abstractmethod
    def teardown_method(self, method):
        """Release whatever ``setup_method`` acquired."""

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata(self):
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.id == thread.id
        assert loaded_meta.title == thread.title
        assert loaded_meta.metadata == thread.metadata

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata_null_title(self):
        thread = ThreadMetadata(
            id="test_thread",
            title=None,
            created_at=datetime.now(),
            metadata={"test": "test"},
        )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.id == thread.id
        assert loaded_meta.title is None
        assert loaded_meta.metadata == thread.metadata

    @pytest.mark.asyncio
    async def test_update_thread_metadata(self):
        # Re-saving a thread after changing one field must update that field
        # and leave the other fields intact.
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        thread.title = "Updated Title"
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.title == "Updated Title"
        assert loaded_meta.id == thread.id
        assert loaded_meta.metadata == thread.metadata

    @pytest.mark.asyncio
    async def test_save_and_load_thread_items(self):
        thread = make_thread()
        items = make_thread_items()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)
        loaded_items = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", DEFAULT_CONTEXT
            )
        ).data
        assert loaded_items == items

    @pytest.mark.asyncio
    async def test_overwrite_thread_metadata(self):
        # Saving under an existing thread id must replace the stored thread,
        # not add a second thread next to the first one.
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        thread.title = "Updated Title"
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.title == "Updated Title"

        threads = await self.store.load_threads(
            limit=10, after=None, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in threads.data] == [thread.id]
        assert threads.has_more is False

    @pytest.mark.asyncio
    async def test_save_and_load_file(self):
        file = FileAttachment(
            id="file_1",
            type="file",
            mime_type="text/plain",
            name="test.txt",
        )
        await self.store.save_attachment(file, DEFAULT_CONTEXT)
        loaded = await self.store.load_attachment(file.id, DEFAULT_CONTEXT)
        assert loaded == file

    @pytest.mark.asyncio
    async def test_save_and_load_image(self):
        image = ImageAttachment(
            id="image_1",
            type="image",
            mime_type="image/png",
            name="test.png",
            preview_url=AnyUrl("https://example.com/test.png"),
        )
        await self.store.save_attachment(image, DEFAULT_CONTEXT)
        loaded = await self.store.load_attachment(image.id, DEFAULT_CONTEXT)
        assert loaded == image

    @pytest.mark.asyncio
    async def test_load_threads(self):
        now = datetime.now()
        thread1 = make_thread(thread_id="thread1", created_at=now)
        thread2 = make_thread(
            thread_id="thread2", created_at=now + timedelta(seconds=1)
        )
        thread3 = make_thread(
            thread_id="thread3", created_at=now + timedelta(seconds=2)
        )
        await self.store.save_thread(thread1, DEFAULT_CONTEXT)
        await self.store.save_thread(thread2, DEFAULT_CONTEXT)
        await self.store.save_thread(thread3, DEFAULT_CONTEXT)

        page1 = await self.store.load_threads(
            limit=2, after=None, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in page1.data] == ["thread1", "thread2"]
        assert page1.has_more is True
        assert page1.after == "thread2"

        # Continue from the cursor returned by the first page.
        page2 = await self.store.load_threads(
            limit=2, after=page1.data[-1].id, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in page2.data] == ["thread3"]
        assert page2.has_more is False
        assert not page2.after

    @pytest.mark.asyncio
    async def test_thread_items_ordering(self):
        thread = make_thread()
        now = datetime.now()
        items = [
            UserMessageItem(
                id="msg1",
                content=[UserMessageTextContent(text="A")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now,
            ),
            UserMessageItem(
                id="msg2",
                content=[UserMessageTextContent(text="B")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=1),
            ),
            UserMessageItem(
                id="msg3",
                content=[UserMessageTextContent(text="C")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=2),
            ),
        ]
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        asc = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        desc = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="desc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in asc.data] == ["msg1", "msg2", "msg3"]
        assert [i.id for i in desc.data] == ["msg3", "msg2", "msg1"]

    @pytest.mark.asyncio
    async def test_thread_items_offset(self):
        thread = make_thread()
        now = datetime.now()
        items = [
            UserMessageItem(
                id=f"msg{i}",
                content=[UserMessageTextContent(text=str(i))],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=i),
            )
            for i in range(5)
        ]
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        # Cursor in the middle: the remaining three items fit in one page.
        after = await self.store.load_thread_items(
            thread.id, after="msg1", limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in after.data] == ["msg2", "msg3", "msg4"]
        assert after.has_more is False
        assert not after.after

        # No cursor, limit smaller than the item count: more pages remain.
        after_limit = await self.store.load_thread_items(
            thread.id, after=None, limit=2, order="asc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in after_limit.data] == ["msg0", "msg1"]
        assert after_limit.has_more is True
        assert after_limit.after == "msg1"

    @pytest.mark.asyncio
    async def test_save_and_load_item(self):
        thread = make_thread()
        now = datetime.now()
        assistant_msg = AssistantMessageItem(
            id="msg_000001",
            content=[],
            thread_id=thread.id,
            created_at=now,
        )
        widget = WidgetItem(
            id="widget_1",
            type="widget",
            thread_id=thread.id,
            created_at=now + timedelta(seconds=1),
            widget=Card(children=[Text(value="Test")]),
        )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, assistant_msg, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, widget, DEFAULT_CONTEXT)

        # Mutate an already-added item and save it again: load_item must
        # return the updated content.
        assistant_msg.content = [
            AssistantMessageContent(text="This is an assistant message.")
        ]
        await self.store.save_item(thread.id, assistant_msg, DEFAULT_CONTEXT)
        loaded_assistant = await self.store.load_item(
            thread.id, assistant_msg.id, DEFAULT_CONTEXT
        )
        assert isinstance(loaded_assistant, AssistantMessageItem)
        assert loaded_assistant.id == assistant_msg.id
        assert loaded_assistant.content[0] == AssistantMessageContent(
            text="This is an assistant message."
        )

        await self.store.save_item(thread.id, widget, DEFAULT_CONTEXT)
        loaded_widget = await self.store.load_item(
            thread.id, widget.id, DEFAULT_CONTEXT
        )
        assert isinstance(loaded_widget, WidgetItem)
        assert loaded_widget.id == widget.id

    @pytest.mark.asyncio
    async def test_load_nonexistent_item(self):
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        with pytest.raises(NotFoundError):
            await self.store.load_item(thread.id, "does_not_exist", DEFAULT_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_isolation_by_user(self):
        # Create a thread for DEFAULT_CONTEXT
        thread = make_thread(thread_id="user1_thread")
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # Should be accessible by the same user
        loaded_default = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_default.title == thread.title

        # Should not be accessible by another user
        with pytest.raises(NotFoundError):
            await self.store.load_thread(thread.id, ALTERNATIVE_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_items_isolation_by_user(self):
        thread = make_thread(thread_id="shared_thread")
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # Add items for DEFAULT_CONTEXT
        items_default = make_thread_items()
        for item in items_default:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        # Should be accessible by the same user
        loaded_items_default = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", DEFAULT_CONTEXT
            )
        ).data
        assert loaded_items_default == items_default

        # Should not be accessible by another user (should return empty list)
        loaded_items_alternative = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", ALTERNATIVE_CONTEXT
            )
        ).data
        assert loaded_items_alternative == []

        # Loading a specific item with another user's context must raise
        # NotFoundError.
        with pytest.raises(NotFoundError):
            await self.store.load_item(
                thread.id, items_default[0].id, ALTERNATIVE_CONTEXT
            )
375
class TestSqliteStore(TestStore):
    """Run the shared ``TestStore`` contract against ``SQLiteStore``."""

    def setup_method(self, method):
        # One named in-memory database per test method, opened with a shared
        # cache. Presumably SQLiteStore opens its own connections to this same
        # URI and therefore sees the same data.
        db_path = f"file:{method.__name__}?mode=memory&cache=shared"
        # Keep the shared in-memory database from being deleted when the last connection is closed
        self.db = sqlite3.connect(db_path, uri=True)
        self.store = SQLiteStore(db_path)

    def teardown_method(self, method):
        # Close the keep-alive connection opened in setup_method.
        self.db.close()

    def test_default_id_generators_use_default_prefixes(self):
        # The id generators are plain synchronous methods, so this test needs
        # no asyncio marker (pytest-asyncio flags the marker on sync tests).
        ctx = DEFAULT_CONTEXT
        thread_id = self.store.generate_thread_id(ctx)
        assert thread_id.startswith("thr_")

        thread = make_thread(thread_id=thread_id)
        assert self.store.generate_item_id("message", thread, ctx).startswith("msg_")
        assert self.store.generate_item_id("tool_call", thread, ctx).startswith("tc_")
        assert self.store.generate_item_id("task", thread, ctx).startswith("tsk_")
        assert self.store.generate_item_id("workflow", thread, ctx).startswith("wf_")
397
398
class TestSqliteStoreCustomIds(TestStore):
    """Run the shared ``TestStore`` contract against a ``SQLiteStore``
    subclass that overrides the id generators. Also check that the overrides
    are what the store actually calls.
    """

    def setup_method(self, method):
        # The "_custom" suffix keeps this class's in-memory databases distinct
        # from those opened by other store test classes.
        db_path = f"file:{method.__name__}_custom?mode=memory&cache=shared"
        # Keep the shared in-memory database from being deleted when the last connection is closed
        self.db = sqlite3.connect(db_path, uri=True)

        class CustomSQLiteStore(SQLiteStore):
            """SQLiteStore whose ids come from one shared, incrementing counter."""

            def __init__(self, path: str):
                super().__init__(path)
                self.counter = 0

            def generate_thread_id(self, context) -> str:
                self.counter += 1
                return f"thr_custom_{self.counter}"

            def generate_item_id(
                self, item_type: str, thread: ThreadMetadata, context
            ) -> str:
                # Shares the counter with generate_thread_id, so thread and
                # item ids interleave in one sequence.
                self.counter += 1
                return f"{item_type}_custom_{self.counter}_{thread.id}"

        self.store = CustomSQLiteStore(db_path)

    def teardown_method(self, method):
        # Close the keep-alive connection opened in setup_method.
        self.db.close()

    def test_overridden_id_generators_are_used(self):
        # Synchronous test: the id generators are plain methods, so no asyncio
        # marker is needed (pytest-asyncio flags the marker on sync tests).
        ctx = DEFAULT_CONTEXT
        thread_id = self.store.generate_thread_id(ctx)
        assert thread_id == "thr_custom_1"

        thread = make_thread(thread_id=thread_id)
        msg_id = self.store.generate_item_id("message", thread, ctx)
        tool_call_id = self.store.generate_item_id("tool_call", thread, ctx)
        task_id = self.store.generate_item_id("task", thread, ctx)

        assert msg_id == "message_custom_2_thr_custom_1"
        assert tool_call_id == "tool_call_custom_3_thr_custom_1"
        assert task_id == "task_custom_4_thr_custom_1"

        thread2 = make_thread(thread_id=self.store.generate_thread_id(ctx))
        assert thread2.id == "thr_custom_5"
        msg_id2 = self.store.generate_item_id("message", thread2, ctx)
        tool_call_id2 = self.store.generate_item_id("tool_call", thread2, ctx)
        task_id2 = self.store.generate_item_id("task", thread2, ctx)

        assert msg_id2 == "message_custom_6_thr_custom_5"
        assert tool_call_id2 == "tool_call_custom_7_thr_custom_5"
        assert task_id2 == "task_custom_8_thr_custom_5"