openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
cec87b20f1817daa284c63d63ca7cf3f69a38f07

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_store.py

448 lines · mode: code

1import sqlite3
2from abc import ABC, abstractmethod
3from datetime import datetime, timedelta
4
5import pytest
6from helpers.mock_store import SQLiteStore
7from pydantic import AnyUrl
8
9from chatkit.store import NotFoundError, Store
10from chatkit.types import (
11 AssistantMessageContent,
12 AssistantMessageItem,
13 FileAttachment,
14 ImageAttachment,
15 InferenceOptions,
16 ThreadItem,
17 ThreadMetadata,
18 UserMessageItem,
19 UserMessageTextContent,
20 WidgetItem,
21)
22from chatkit.widgets import Card, Col, Text
23from tests._types import RequestContext
24
25
def make_thread(thread_id="test_thread", created_at=None):
    """Build a ``ThreadMetadata`` fixture with a fixed title and metadata.

    Args:
        thread_id: Identifier assigned to the thread.
        created_at: Creation timestamp; when ``None`` the value of
            ``datetime.now()`` at call time is used.

    Returns:
        A ``ThreadMetadata`` titled ``"Test Thread"`` carrying
        ``{"test": "test"}`` as metadata.
    """
    timestamp = datetime.now() if created_at is None else created_at
    return ThreadMetadata(
        id=thread_id,
        title="Test Thread",
        created_at=timestamp,
        metadata={"test": "test"},
    )
35
36
def make_thread_items() -> list[ThreadItem]:
    """Return a user message, an assistant reply and a widget item.

    All three items belong to ``"test_thread"`` and are stamped one second
    apart, in the order they are returned.
    """
    base_time = datetime.now()
    owner_thread = "test_thread"

    greeting = UserMessageItem(
        id="msg_100000",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=owner_thread,
        created_at=base_time,
    )

    reply = AssistantMessageItem(
        id="msg_000001",
        content=[AssistantMessageContent(text="Hi there!")],
        thread_id=owner_thread,
        created_at=base_time + timedelta(seconds=1),
    )

    # Assemble the widget tree from the leaves up: two texts in a column,
    # wrapped in a card.
    heading = Text(
        value="Title",
        weight="medium",
        size="sm",
        color="secondary",
    )
    body = Text(value="test")
    column = Col(
        padding={"x": 4, "y": 3},
        border={"bottom": 1},
        children=[heading, body],
    )
    card_item = WidgetItem(
        id="widget_1",
        type="widget",
        thread_id=owner_thread,
        created_at=base_time + timedelta(seconds=2),
        widget=Card(children=[column]),
    )

    return [greeting, reply, card_item]
80
81
# Request contexts for two different user ids; the isolation tests save data
# under the first and try to read it back under the second.
DEFAULT_CONTEXT: RequestContext = RequestContext(user_id="test_user")
ALTERNATIVE_CONTEXT: RequestContext = RequestContext(user_id="alternative_user")
84
85
class TestStore(ABC):
    """Shared test suite exercised against a concrete ``Store``.

    Subclasses are expected to assign ``self.store`` in ``setup_method`` and
    to release whatever they acquired in ``teardown_method``.
    """

    store: Store

    @abstractmethod
    def setup_method(self, method):
        """Prepare ``self.store`` for a test; implemented by subclasses."""

    @abstractmethod
    def teardown_method(self, method):
        """Release per-test resources; implemented by subclasses."""

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata(self):
        """A saved thread loads back with the same id, title and metadata."""
        original = make_thread()
        await self.store.save_thread(original, DEFAULT_CONTEXT)

        restored = await self.store.load_thread(original.id, DEFAULT_CONTEXT)

        assert restored.id == original.id
        assert restored.title == original.title
        assert restored.metadata == original.metadata

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata_null_title(self):
        """A thread saved with ``title=None`` loads back with a ``None`` title."""
        untitled = ThreadMetadata(
            id="test_thread",
            title=None,
            created_at=datetime.now(),
            metadata={"test": "test"},
        )
        await self.store.save_thread(untitled, DEFAULT_CONTEXT)

        restored = await self.store.load_thread(untitled.id, DEFAULT_CONTEXT)

        assert restored.id == untitled.id
        assert restored.title is None
        assert restored.metadata == untitled.metadata

    @pytest.mark.asyncio
    async def test_update_thread_metadata(self):
        """Saving a thread a second time persists its changed title."""
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        thread.title = "Updated Title"
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        restored = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert restored.title == "Updated Title"

    @pytest.mark.asyncio
    async def test_save_and_load_thread_items(self):
        """Items added to a thread are returned unchanged in ascending order."""
        thread = make_thread()
        expected_items = make_thread_items()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for entry in expected_items:
            await self.store.add_thread_item(thread.id, entry, DEFAULT_CONTEXT)

        page = await self.store.load_thread_items(
            thread.id, None, 10, "asc", DEFAULT_CONTEXT
        )

        assert page.data == expected_items

    @pytest.mark.asyncio
    async def test_overwrite_thread_metadata(self):
        """Re-saving an existing thread replaces the stored title."""
        # NOTE(review): performs the same steps as test_update_thread_metadata.
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        thread.title = "Updated Title"
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        restored = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert restored.title == "Updated Title"

    @pytest.mark.asyncio
    async def test_save_and_load_file(self):
        """A file attachment round-trips through save and load."""
        attachment = FileAttachment(
            id="file_1",
            type="file",
            mime_type="text/plain",
            name="test.txt",
        )
        await self.store.save_attachment(attachment, DEFAULT_CONTEXT)

        restored = await self.store.load_attachment(attachment.id, DEFAULT_CONTEXT)

        assert restored == attachment

    @pytest.mark.asyncio
    async def test_save_and_load_image(self):
        """An image attachment, preview URL included, round-trips unchanged."""
        picture = ImageAttachment(
            id="image_1",
            type="image",
            mime_type="image/png",
            name="test.png",
            preview_url=AnyUrl("https://example.com/test.png"),
        )
        await self.store.save_attachment(picture, DEFAULT_CONTEXT)

        restored = await self.store.load_attachment(picture.id, DEFAULT_CONTEXT)

        assert restored == picture

    @pytest.mark.asyncio
    async def test_load_threads(self):
        """Threads are paged in ascending order using the ``after`` cursor."""
        base_time = datetime.now()
        # thread1..thread3, created one second apart starting at base_time.
        threads = [
            make_thread(
                thread_id=f"thread{number}",
                created_at=base_time + timedelta(seconds=number - 1),
            )
            for number in (1, 2, 3)
        ]
        for thread in threads:
            await self.store.save_thread(thread, DEFAULT_CONTEXT)

        first_page = await self.store.load_threads(
            limit=2, after=None, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in first_page.data] == ["thread1", "thread2"]
        assert first_page.has_more is True
        assert first_page.after == "thread2"

        # Continue from the last thread of the first page.
        cursor = first_page.data[-1].id
        second_page = await self.store.load_threads(
            limit=2, after=cursor, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in second_page.data] == ["thread3"]
        assert second_page.has_more is False
        assert not second_page.after

    @pytest.mark.asyncio
    async def test_thread_items_ordering(self):
        """``order`` selects ascending or descending item order."""
        thread = make_thread()
        base_time = datetime.now()
        # msg1/msg2/msg3 carry texts A/B/C and are stamped one second apart.
        messages = [
            UserMessageItem(
                id=f"msg{position}",
                content=[UserMessageTextContent(text=letter)],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=base_time + timedelta(seconds=position - 1),
            )
            for position, letter in enumerate("ABC", start=1)
        ]
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for message in messages:
            await self.store.add_thread_item(thread.id, message, DEFAULT_CONTEXT)

        ascending = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        descending = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="desc", context=DEFAULT_CONTEXT
        )

        assert [m.id for m in ascending.data] == ["msg1", "msg2", "msg3"]
        assert [m.id for m in descending.data] == ["msg3", "msg2", "msg1"]

    @pytest.mark.asyncio
    async def test_thread_items_offset(self):
        """``after`` and ``limit`` bound the returned page of items."""
        thread = make_thread()
        base_time = datetime.now()
        messages = []
        for index in range(5):
            messages.append(
                UserMessageItem(
                    id=f"msg{index}",
                    content=[UserMessageTextContent(text=str(index))],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    thread_id=thread.id,
                    created_at=base_time + timedelta(seconds=index),
                )
            )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for message in messages:
            await self.store.add_thread_item(thread.id, message, DEFAULT_CONTEXT)

        # Starting after msg1 with room for three items reaches the end.
        tail_page = await self.store.load_thread_items(
            thread.id, after="msg1", limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        assert [m.id for m in tail_page.data] == ["msg2", "msg3", "msg4"]
        assert tail_page.has_more is False
        assert not tail_page.after

        # A limit smaller than the item count leaves more to fetch.
        head_page = await self.store.load_thread_items(
            thread.id, after=None, limit=2, order="asc", context=DEFAULT_CONTEXT
        )
        assert [m.id for m in head_page.data] == ["msg0", "msg1"]
        assert head_page.has_more is True
        assert head_page.after == "msg1"

    @pytest.mark.asyncio
    async def test_save_and_load_item(self):
        """``save_item`` stores changes to an item; ``load_item`` returns it typed."""
        thread = make_thread()
        base_time = datetime.now()
        assistant_item = AssistantMessageItem(
            id="msg_000001",
            content=[],
            thread_id=thread.id,
            created_at=base_time,
        )
        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            thread_id=thread.id,
            created_at=base_time + timedelta(seconds=1),
            widget=Card(children=[Text(value="Test")]),
        )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, assistant_item, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, widget_item, DEFAULT_CONTEXT)

        # Give the initially empty assistant message some content and save it.
        assistant_item.content = [
            AssistantMessageContent(text="This is an assistant message.")
        ]
        await self.store.save_item(thread.id, assistant_item, DEFAULT_CONTEXT)
        restored_assistant = await self.store.load_item(
            thread.id, assistant_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(restored_assistant, AssistantMessageItem)
        assert restored_assistant.id == assistant_item.id
        expected_content = AssistantMessageContent(
            text="This is an assistant message."
        )
        assert restored_assistant.content[0] == expected_content

        await self.store.save_item(thread.id, widget_item, DEFAULT_CONTEXT)
        restored_widget = await self.store.load_item(
            thread.id, widget_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(restored_widget, WidgetItem)
        assert restored_widget.id == widget_item.id

    @pytest.mark.asyncio
    async def test_load_nonexistent_item(self):
        """Loading an unknown item id raises ``NotFoundError``."""
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        with pytest.raises(NotFoundError):
            await self.store.load_item(thread.id, "does_not_exist", DEFAULT_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_isolation_by_user(self):
        """A thread saved by one user cannot be loaded by another."""
        # Saved under DEFAULT_CONTEXT.
        thread = make_thread(thread_id="user1_thread")
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # The saving user can read it back.
        own_view = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert own_view.title == thread.title

        # A different user gets NotFoundError.
        with pytest.raises(NotFoundError):
            await self.store.load_thread(thread.id, ALTERNATIVE_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_items_isolation_by_user(self):
        """Items added by one user are not visible to another."""
        thread = make_thread(thread_id="shared_thread")
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # Items are added under DEFAULT_CONTEXT.
        owner_items = make_thread_items()
        for entry in owner_items:
            await self.store.add_thread_item(thread.id, entry, DEFAULT_CONTEXT)

        # The same user sees all of them.
        owner_page = await self.store.load_thread_items(
            thread.id, None, 10, "asc", DEFAULT_CONTEXT
        )
        assert owner_page.data == owner_items

        # Another user gets an empty page for the same thread id.
        stranger_page = await self.store.load_thread_items(
            thread.id, None, 10, "asc", ALTERNATIVE_CONTEXT
        )
        assert stranger_page.data == []

        # Loading one specific item as the other user raises NotFoundError.
        with pytest.raises(NotFoundError):
            await self.store.load_item(
                thread.id, owner_items[0].id, ALTERNATIVE_CONTEXT
            )
374
375
class TestSqliteStore(TestStore):
    """Runs the shared store tests against ``SQLiteStore`` with stock id generators."""

    def setup_method(self, method):
        """Open a shared in-memory database named after the test method."""
        uri = f"file:{method.__name__}?mode=memory&cache=shared"
        # Hold a connection of our own so the shared in-memory database is
        # not deleted when the last other connection is closed.
        self.db = sqlite3.connect(uri, uri=True)
        self.store = SQLiteStore(uri)

    def teardown_method(self, method):
        """Close the connection opened in ``setup_method``."""
        self.db.close()

    @pytest.mark.asyncio
    async def test_default_id_generators_use_default_prefixes(self):
        """Generated thread and item ids start with the expected prefixes."""
        ctx = DEFAULT_CONTEXT
        generated_thread_id = self.store.generate_thread_id(ctx)
        assert generated_thread_id.startswith("thr_")

        thread = make_thread(thread_id=generated_thread_id)
        expected_prefixes = (
            ("message", "msg_"),
            ("tool_call", "tc_"),
            ("task", "tsk_"),
            ("workflow", "wf_"),
        )
        for item_type, prefix in expected_prefixes:
            item_id = self.store.generate_item_id(item_type, thread, ctx)
            assert item_id.startswith(prefix)
397
398
class TestSqliteStoreCustomIds(TestStore):
    """Runs the shared store tests against a ``SQLiteStore`` subclass that
    overrides both id generators with a counter-based scheme."""

    def setup_method(self, method):
        """Open a shared in-memory database and build the customised store."""
        uri = f"file:{method.__name__}_custom?mode=memory&cache=shared"
        # Hold a connection of our own so the shared in-memory database is
        # not deleted when the last other connection is closed.
        self.db = sqlite3.connect(uri, uri=True)

        class CustomSQLiteStore(SQLiteStore):
            """Store whose ids embed a single counter shared by both generators."""

            def __init__(self, path: str):
                super().__init__(path)
                self.counter = 0

            def generate_thread_id(self, context) -> str:
                self.counter = self.counter + 1
                return f"thr_custom_{self.counter}"

            def generate_item_id(
                self, item_type: str, thread: ThreadMetadata, context
            ) -> str:
                self.counter = self.counter + 1
                return f"{item_type}_custom_{self.counter}_{thread.id}"

        self.store = CustomSQLiteStore(uri)

    def teardown_method(self, method):
        """Close the connection opened in ``setup_method``."""
        self.db.close()

    @pytest.mark.asyncio
    async def test_overridden_id_generators_are_used(self):
        """Ids come from the overriding generators and share one counter."""
        ctx = DEFAULT_CONTEXT
        item_types = ("message", "tool_call", "task")

        first_thread_id = self.store.generate_thread_id(ctx)
        assert first_thread_id == "thr_custom_1"

        first_thread = make_thread(thread_id=first_thread_id)
        first_ids = [
            self.store.generate_item_id(item_type, first_thread, ctx)
            for item_type in item_types
        ]
        assert first_ids == [
            "message_custom_2_thr_custom_1",
            "tool_call_custom_3_thr_custom_1",
            "task_custom_4_thr_custom_1",
        ]

        # A second thread continues the same counter sequence.
        second_thread = make_thread(thread_id=self.store.generate_thread_id(ctx))
        assert second_thread.id == "thr_custom_5"
        second_ids = [
            self.store.generate_item_id(item_type, second_thread, ctx)
            for item_type in item_types
        ]
        assert second_ids == [
            "message_custom_6_thr_custom_5",
            "tool_call_custom_7_thr_custom_5",
            "task_custom_8_thr_custom_5",
        ]
449