openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d14ce4602e2002ef97ccff685ec14fb27c49c94c

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_store.py

457 lines · mode: code

1import sqlite3
2from abc import ABC, abstractmethod
3from datetime import datetime, timedelta
4
5import pytest
6from helpers.mock_store import SQLiteStore
7from pydantic import AnyUrl
8
9from chatkit.store import NotFoundError, Store
10from chatkit.types import (
11 AssistantMessageContent,
12 AssistantMessageItem,
13 FileAttachment,
14 ImageAttachment,
15 InferenceOptions,
16 ThreadItem,
17 ThreadMetadata,
18 UserMessageItem,
19 UserMessageTextContent,
20 WidgetItem,
21)
22from chatkit.widgets import Card, Col, Text
23from tests._types import RequestContext
24
25
def make_thread(thread_id="test_thread", created_at=None):
    """Return a ``ThreadMetadata`` fixture titled "Test Thread".

    ``created_at`` falls back to the current local time when omitted, and the
    thread always carries the metadata ``{"test": "test"}``.
    """
    return ThreadMetadata(
        id=thread_id,
        title="Test Thread",
        created_at=datetime.now() if created_at is None else created_at,
        metadata={"test": "test"},
    )
35
36
def make_thread_items(thread_id: str = "test_thread") -> list[ThreadItem]:
    """Build a representative three-item conversation belonging to ``thread_id``.

    The items are spaced one second apart by ``created_at``:

    1. a user message with a plain-text file attachment,
    2. an assistant text reply,
    3. a widget item rendering a small card.

    Args:
        thread_id: Thread every item is stamped with. Defaults to
            ``"test_thread"``, so existing zero-argument callers get exactly
            the same items as before.

    Returns:
        ``[user_msg, assistant_msg, widget]`` in that order.
    """
    now = datetime.now()
    # NOTE(review): the user message id ("msg_100000") sorts after the
    # assistant id ("msg_000001"), so id order disagrees with created_at
    # order — presumably deliberate, to catch stores that order by id; confirm.
    user_msg = UserMessageItem(
        id="msg_100000",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[
            FileAttachment(
                id="file_1",
                type="file",
                mime_type="text/plain",
                name="test.txt",
                metadata={"source": "test"},
            )
        ],
        inference_options=InferenceOptions(),
        thread_id=thread_id,
        created_at=now,
    )
    assistant_msg = AssistantMessageItem(
        id="msg_000001",
        content=[AssistantMessageContent(text="Hi there!")],
        thread_id=thread_id,
        created_at=now + timedelta(seconds=1),
    )
    widget = WidgetItem(
        id="widget_1",
        type="widget",
        thread_id=thread_id,
        created_at=now + timedelta(seconds=2),
        widget=Card(
            children=[
                Col(
                    padding={"x": 4, "y": 3},
                    border={"bottom": 1},
                    children=[
                        Text(
                            value="Title",
                            weight="medium",
                            size="sm",
                            color="secondary",
                        ),
                        Text(
                            value="test",
                        ),
                    ],
                ),
            ]
        ),
    )

    return [user_msg, assistant_msg, widget]
88
89
# Context for the primary user that every test acts as unless stated otherwise.
DEFAULT_CONTEXT: RequestContext = RequestContext(user_id="test_user")
# A second, distinct user, used by the isolation tests to confirm one user
# cannot see another user's threads or items.
ALTERNATIVE_CONTEXT: RequestContext = RequestContext(user_id="alternative_user")
92
93
class TestStore(ABC):
    """Backend-agnostic contract tests that every ``Store`` implementation must pass.

    Concrete subclasses implement ``setup_method`` to assign ``self.store``
    and ``teardown_method`` to release whatever that setup acquired; pytest
    calls both around each test method (xunit-style hooks).
    """

    store: Store

    @abstractmethod
    def setup_method(self, method):
        """Assign ``self.store`` before the test ``method`` runs."""

    @abstractmethod
    def teardown_method(self, method):
        """Release resources acquired by ``setup_method``."""

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata(self):
        """A saved thread loads back with the same id, title and metadata."""
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.id == thread.id
        assert loaded_meta.title == thread.title
        assert loaded_meta.metadata == thread.metadata

    @pytest.mark.asyncio
    async def test_save_and_load_thread_metadata_null_title(self):
        """A ``None`` title survives the save/load round trip as ``None``."""
        thread = ThreadMetadata(
            id="test_thread",
            title=None,
            created_at=datetime.now(),
            metadata={"test": "test"},
        )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.id == thread.id
        assert loaded_meta.title is None
        assert loaded_meta.metadata == thread.metadata

    @pytest.mark.asyncio
    async def test_update_thread_metadata(self):
        """Saving a thread again after changing its title updates the stored title."""
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        thread.title = "Updated Title"
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.title == "Updated Title"

    @pytest.mark.asyncio
    async def test_save_and_load_thread_items(self):
        """Items added to a thread load back (ascending) equal to the originals."""
        thread = make_thread()
        items = make_thread_items()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)
        loaded_items = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", DEFAULT_CONTEXT
            )
        ).data
        assert loaded_items == items

    @pytest.mark.asyncio
    async def test_overwrite_thread_metadata(self):
        """Re-saving a thread replaces the stored record instead of duplicating it.

        Unlike ``test_update_thread_metadata``, this also checks that a
        non-title field is overwritten and that the thread is listed once.
        """
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        thread.title = "Updated Title"
        thread.metadata = {"test": "updated"}
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        loaded_meta = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_meta.title == "Updated Title"
        assert loaded_meta.metadata == {"test": "updated"}

        # The second save must not have created a second copy of the thread.
        page = await self.store.load_threads(
            limit=10, after=None, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in page.data] == [thread.id]

    @pytest.mark.asyncio
    async def test_save_and_load_file(self):
        """A file attachment round-trips through save/load unchanged."""
        file = FileAttachment(
            id="file_1",
            type="file",
            mime_type="text/plain",
            name="test.txt",
        )
        await self.store.save_attachment(file, DEFAULT_CONTEXT)
        loaded = await self.store.load_attachment(file.id, DEFAULT_CONTEXT)
        assert loaded == file

    @pytest.mark.asyncio
    async def test_save_and_load_image(self):
        """An image attachment (with preview URL and metadata) round-trips unchanged."""
        image = ImageAttachment(
            id="image_1",
            type="image",
            mime_type="image/png",
            name="test.png",
            preview_url=AnyUrl("https://example.com/test.png"),
            metadata={"source": "test"},
        )
        await self.store.save_attachment(image, DEFAULT_CONTEXT)
        loaded = await self.store.load_attachment(image.id, DEFAULT_CONTEXT)
        assert loaded == image

    @pytest.mark.asyncio
    async def test_load_threads(self):
        """``load_threads`` pages in ascending order and reports ``has_more``/``after``."""
        now = datetime.now()
        thread1 = make_thread(thread_id="thread1", created_at=now)
        thread2 = make_thread(
            thread_id="thread2", created_at=now + timedelta(seconds=1)
        )
        thread3 = make_thread(
            thread_id="thread3", created_at=now + timedelta(seconds=2)
        )
        await self.store.save_thread(thread1, DEFAULT_CONTEXT)
        await self.store.save_thread(thread2, DEFAULT_CONTEXT)
        await self.store.save_thread(thread3, DEFAULT_CONTEXT)

        page1 = await self.store.load_threads(
            limit=2, after=None, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in page1.data] == ["thread1", "thread2"]
        assert page1.has_more is True
        assert page1.after == "thread2"

        page2 = await self.store.load_threads(
            limit=2, after=page1.data[-1].id, order="asc", context=DEFAULT_CONTEXT
        )
        assert [t.id for t in page2.data] == ["thread3"]
        assert page2.has_more is False
        assert not page2.after

    @pytest.mark.asyncio
    async def test_thread_items_ordering(self):
        """Items come back oldest-first for ``"asc"`` and newest-first for ``"desc"``."""
        thread = make_thread()
        now = datetime.now()
        items = [
            UserMessageItem(
                id="msg1",
                content=[UserMessageTextContent(text="A")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now,
            ),
            UserMessageItem(
                id="msg2",
                content=[UserMessageTextContent(text="B")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=1),
            ),
            UserMessageItem(
                id="msg3",
                content=[UserMessageTextContent(text="C")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=2),
            ),
        ]
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        asc = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        desc = await self.store.load_thread_items(
            thread.id, after=None, limit=3, order="desc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in asc.data] == ["msg1", "msg2", "msg3"]
        assert [i.id for i in desc.data] == ["msg3", "msg2", "msg1"]

    @pytest.mark.asyncio
    async def test_thread_items_offset(self):
        """The ``after`` cursor and ``limit`` slice the item list and set paging fields."""
        thread = make_thread()
        now = datetime.now()
        items = [
            UserMessageItem(
                id=f"msg{i}",
                content=[UserMessageTextContent(text=str(i))],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=now + timedelta(seconds=i),
            )
            for i in range(5)
        ]
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        for item in items:
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        after = await self.store.load_thread_items(
            thread.id, after="msg1", limit=3, order="asc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in after.data] == ["msg2", "msg3", "msg4"]
        assert after.has_more is False
        assert not after.after

        after_limit = await self.store.load_thread_items(
            thread.id, after=None, limit=2, order="asc", context=DEFAULT_CONTEXT
        )
        assert [i.id for i in after_limit.data] == ["msg0", "msg1"]
        assert after_limit.has_more is True
        assert after_limit.after == "msg1"

    @pytest.mark.asyncio
    async def test_save_and_load_item(self):
        """``save_item`` overwrites an added item and ``load_item`` returns the new version."""
        thread = make_thread()
        now = datetime.now()
        assistant_msg = AssistantMessageItem(
            id="msg_000001",
            content=[],
            thread_id=thread.id,
            created_at=now,
        )
        widget = WidgetItem(
            id="widget_1",
            type="widget",
            thread_id=thread.id,
            created_at=now + timedelta(seconds=1),
            widget=Card(children=[Text(value="Test")]),
        )
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, assistant_msg, DEFAULT_CONTEXT)
        await self.store.add_thread_item(thread.id, widget, DEFAULT_CONTEXT)

        assistant_msg.content = [
            AssistantMessageContent(text="This is an assistant message.")
        ]
        await self.store.save_item(thread.id, assistant_msg, DEFAULT_CONTEXT)
        loaded_assistant = await self.store.load_item(
            thread.id, assistant_msg.id, DEFAULT_CONTEXT
        )
        assert isinstance(loaded_assistant, AssistantMessageItem)
        assert loaded_assistant.id == assistant_msg.id
        assert loaded_assistant.content[0] == AssistantMessageContent(
            text="This is an assistant message."
        )

        await self.store.save_item(thread.id, widget, DEFAULT_CONTEXT)
        loaded_widget = await self.store.load_item(
            thread.id, widget.id, DEFAULT_CONTEXT
        )
        assert isinstance(loaded_widget, WidgetItem)
        assert loaded_widget.id == widget.id

    @pytest.mark.asyncio
    async def test_load_nonexistent_item(self):
        """``load_item`` raises ``NotFoundError`` for an unknown item id."""
        thread = make_thread()
        await self.store.save_thread(thread, DEFAULT_CONTEXT)
        with pytest.raises(NotFoundError):
            await self.store.load_item(thread.id, "does_not_exist", DEFAULT_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_isolation_by_user(self):
        """A thread saved by one user cannot be loaded by another user."""
        # Create a thread for DEFAULT_CONTEXT
        thread = make_thread(thread_id="user1_thread")
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # Should be accessible by the same user
        loaded_default = await self.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded_default.title == thread.title

        # Should not be accessible by another user
        with pytest.raises(NotFoundError):
            await self.store.load_thread(thread.id, ALTERNATIVE_CONTEXT)

    @pytest.mark.asyncio
    async def test_thread_items_isolation_by_user(self):
        """Thread items are visible only to the user who added them."""
        thread = make_thread(thread_id="shared_thread")
        await self.store.save_thread(thread, DEFAULT_CONTEXT)

        # Add items for DEFAULT_CONTEXT. make_thread_items() stamps its items
        # with thread_id="test_thread"; re-point each one at this thread so
        # the item's own thread_id agrees with the thread it is stored under.
        items_default = make_thread_items()
        for item in items_default:
            item.thread_id = thread.id
            await self.store.add_thread_item(thread.id, item, DEFAULT_CONTEXT)

        # Should be accessible by the same user
        loaded_items_default = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", DEFAULT_CONTEXT
            )
        ).data
        assert loaded_items_default == items_default

        # Should not be accessible by another user (should return empty list)
        loaded_items_alternative = (
            await self.store.load_thread_items(
                thread.id, None, 10, "asc", ALTERNATIVE_CONTEXT
            )
        ).data
        assert loaded_items_alternative == []

        # Loading a specific item with another user's context should raise NotFoundError
        with pytest.raises(NotFoundError):
            await self.store.load_item(
                thread.id, items_default[0].id, ALTERNATIVE_CONTEXT
            )
383
384
class TestSqliteStore(TestStore):
    """Run the shared Store contract tests against the SQLite-backed mock store."""

    def setup_method(self, method):
        # Each test method gets its own named, shared-cache in-memory database.
        uri = f"file:{method.__name__}?mode=memory&cache=shared"
        # Hold a connection for the whole test: a shared in-memory database is
        # deleted as soon as its last connection closes.
        self.db = sqlite3.connect(uri, uri=True)
        self.store = SQLiteStore(uri)

    def teardown_method(self, method):
        # Drop the keep-alive connection opened in setup_method.
        self.db.close()

    @pytest.mark.asyncio
    async def test_default_id_generators_use_default_prefixes(self):
        """Without overrides, generated ids carry the store's built-in prefixes."""
        context = DEFAULT_CONTEXT
        new_thread_id = self.store.generate_thread_id(context)
        assert new_thread_id.startswith("thr_")

        thread = make_thread(thread_id=new_thread_id)
        prefix_by_item_type = {
            "message": "msg_",
            "tool_call": "tc_",
            "task": "tsk_",
            "workflow": "wf_",
        }
        for item_type, prefix in prefix_by_item_type.items():
            generated = self.store.generate_item_id(item_type, thread, context)
            assert generated.startswith(prefix)
406
407
class TestSqliteStoreCustomIds(TestStore):
    """Run the Store contract tests against a SQLite store with overridden id generators."""

    def setup_method(self, method):
        # The "_custom" suffix keeps these databases separate from the ones
        # TestSqliteStore opens for identically named test methods.
        uri = f"file:{method.__name__}_custom?mode=memory&cache=shared"
        # Hold a connection for the whole test: a shared in-memory database is
        # deleted as soon as its last connection closes.
        self.db = sqlite3.connect(uri, uri=True)

        class CustomSQLiteStore(SQLiteStore):
            """SQLiteStore whose thread and item ids share one incrementing counter."""

            def __init__(self, path: str):
                super().__init__(path)
                self.counter = 0

            def generate_thread_id(self, context) -> str:
                self.counter += 1
                return "thr_custom_" + str(self.counter)

            def generate_item_id(
                self, item_type: str, thread: ThreadMetadata, context
            ) -> str:
                self.counter += 1
                serial = self.counter
                return f"{item_type}_custom_{serial}_{thread.id}"

        self.store = CustomSQLiteStore(uri)

    def teardown_method(self, method):
        # Drop the keep-alive connection opened in setup_method.
        self.db.close()

    @pytest.mark.asyncio
    async def test_overridden_id_generators_are_used(self):
        """The subclass's generators are used, sharing one counter across threads and items."""
        ctx = DEFAULT_CONTEXT
        # For each generated thread: its expected id, then the expected ids of
        # the message / tool_call / task items generated for it (in that order).
        expectations = [
            (
                "thr_custom_1",
                {
                    "message": "message_custom_2_thr_custom_1",
                    "tool_call": "tool_call_custom_3_thr_custom_1",
                    "task": "task_custom_4_thr_custom_1",
                },
            ),
            (
                "thr_custom_5",
                {
                    "message": "message_custom_6_thr_custom_5",
                    "tool_call": "tool_call_custom_7_thr_custom_5",
                    "task": "task_custom_8_thr_custom_5",
                },
            ),
        ]
        for expected_thread_id, expected_item_ids in expectations:
            thread = make_thread(thread_id=self.store.generate_thread_id(ctx))
            assert thread.id == expected_thread_id
            for item_type, expected_id in expected_item_ids.items():
                assert self.store.generate_item_id(item_type, thread, ctx) == expected_id
458