openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
f43533b0dad1fc1013dbc615dd4bc192d484b94e

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1944 lines · mode: code

1import asyncio
2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14 ChatKitServer,
15 NonStreamingResult,
16 StreamingResult,
17 stream_widget,
18)
19from chatkit.store import (
20 AttachmentStore,
21 NotFoundError,
22 StoreItemType,
23 default_generate_id,
24)
25from chatkit.types import (
26 AssistantMessageContent,
27 AssistantMessageContentPartTextDelta,
28 AssistantMessageItem,
29 Attachment,
30 AttachmentCreateParams,
31 AttachmentDeleteParams,
32 AttachmentsCreateReq,
33 AttachmentsDeleteReq,
34 ClientToolCallItem,
35 FeedbackKind,
36 FileAttachment,
37 ImageAttachment,
38 InferenceOptions,
39 ItemFeedbackParams,
40 ItemsFeedbackReq,
41 ItemsListParams,
42 ItemsListReq,
43 LockedStatus,
44 Page,
45 ProgressUpdateEvent,
46 Thread,
47 ThreadAddClientToolOutputParams,
48 ThreadAddUserMessageParams,
49 ThreadCreatedEvent,
50 ThreadCreateParams,
51 ThreadCustomActionParams,
52 ThreadDeleteParams,
53 ThreadGetByIdParams,
54 ThreadItem,
55 ThreadItemAddedEvent,
56 ThreadItemDoneEvent,
57 ThreadItemRemovedEvent,
58 ThreadItemReplacedEvent,
59 ThreadItemUpdatedEvent,
60 ThreadListParams,
61 ThreadMetadata,
62 ThreadRetryAfterItemParams,
63 ThreadsAddClientToolOutputReq,
64 ThreadsAddUserMessageReq,
65 ThreadsCreateReq,
66 ThreadsCustomActionReq,
67 ThreadsDeleteReq,
68 ThreadsGetByIdReq,
69 ThreadsListReq,
70 ThreadsRetryAfterItemReq,
71 ThreadStreamEvent,
72 ThreadsUpdateReq,
73 ThreadUpdatedEvent,
74 ThreadUpdateParams,
75 ToolChoice,
76 UserMessageInput,
77 UserMessageItem,
78 UserMessageTextContent,
79 WidgetItem,
80 WidgetRootUpdated,
81)
82from chatkit.widgets import Card, Text
83from tests._types import RequestContext
84from tests.test_store import make_thread_items
85
# Counter read and incremented by make_server() so that every server it
# builds gets its own in-memory SQLite database name.
server_id = 0


# Request context used by the server helpers when a test passes no context.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
90
91
class InMemoryFileStore(AttachmentStore):
    """Attachment store test double that keeps attachments in a plain dict.

    Attachments live in ``self.files`` keyed by their id. Ids have the form
    ``atc_<n>``. ``n`` never goes backwards, so deleting an attachment cannot
    cause a later attachment to reuse (and silently overwrite) a live id.
    """

    def __init__(self):
        self.files = {}
        # Highest numeric suffix handed out so far by create_attachment().
        self._last_issued = 0

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from the store.

        Raises:
            KeyError: if no attachment is stored under ``attachment_id``.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create an attachment for ``input``, record it and return it.

        MIME types starting with ``image/`` produce an ``ImageAttachment``
        (with preview and upload URLs); everything else produces a
        ``FileAttachment`` (upload URL only).
        """
        # Without deletions this equals the historical ``len(self.files) + 1``
        # (including when ``files`` was pre-seeded by a test); after a deletion
        # the counter wins so ids stay unique.
        self._last_issued = max(self._last_issued, len(self.files)) + 1
        attachment_id = f"atc_{self._last_issued}"
        upload_url = AnyUrl(f"https://example.com/{attachment_id}/upload")
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_url=upload_url,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_url=upload_url,
            )
        self.files[attachment.id] = attachment
        return attachment
121
122
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one raw streamed event into a ``ThreadStreamEvent``.

    The JSON payload is taken to be everything after the first ``b"data: "``
    marker in ``event``.

    Raises:
        IndexError: if ``event`` contains no ``b"data: "`` marker.
    """
    # maxsplit=1 keeps the payload intact even when the JSON text itself
    # contains the byte sequence "data: " (e.g. inside a message body).
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
125
126
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Consume ``streaming_result.json_events`` and decode each event, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
131
132
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` test subclass backed by its own SQLite database.

    Each call derives a database name from the module-level ``server_id``
    counter (and bumps it), so servers created by different calls do not share
    a database.

    Args:
        responder: Called by ``respond``; when ``None`` the server responds
            with an empty stream.
        handle_feedback: Called by ``add_feedback``; when ``None`` feedback is
            ignored.
        action_callback: Called by ``action``; when ``None`` ``action`` raises
            ``ValueError``.
        file_store: Attachment store handed to ``ChatKitServer``.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # This connection is never queried: holding it open keeps the shared
    # in-memory database from being deleted when the last other connection
    # is closed.
    keepalive_connection = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            # Serialize the request, run it, and return the decoded events.
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            # Serialize the request, run it, and return the raw result object.
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        keepalive_connection.close()
213
214
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """Cancelling mid-message stores the partial message and a hidden-context item."""
    pending_id = "assistant-message-pending"

    async def cancelled_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Start a message, stream one text delta into it, then get cancelled.
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id=pending_id,
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, ")],
                thread_id=thread.id,
            )
        )
        yield ThreadItemUpdatedEvent(
            item_id=pending_id,
            update=AssistantMessageContentPartTextDelta(content_index=0, delta="World!"),
        )
        raise asyncio.CancelledError()

    with make_server(cancelled_responder) as server:
        # Let this test's store generate ids for hidden-context items; every
        # other item type is delegated to the store's own generator.
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        create_request = ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        # The newest stored item must be the cancellation note.
        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The partially streamed message is stored with the delta applied.
        assistant_message_item = await server.store.load_item(
            thread.id, pending_id, DEFAULT_CONTEXT
        )
        assert assistant_message_item.type == "assistant_message"
        assert assistant_message_item.content[0].text == "Hello, World!"
284
285
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """Cancelling before any content arrives stores only the hidden-context item."""
    pending_id = "assistant-message-pending"

    async def cancelled_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Announce a message with no content, then get cancelled.
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id=pending_id,
                created_at=datetime.now(),
                content=[],
                thread_id=thread.id,
            )
        )
        raise asyncio.CancelledError()

    with make_server(cancelled_responder) as server:
        # Hidden-context ids come from default_generate_id; all other item
        # types still use the store's own generator.
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        create_request = ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The empty pending message must not have been stored.
        with pytest.raises(NotFoundError):
            await server.store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
346
347
async def test_flows_context_to_responder():
    """The request context reaches both the responder and the feedback hook."""
    context = RequestContext(user_id="text_ctx")
    # Contexts observed by each callback; None means "not called yet".
    seen: dict[str, Any] = {"responder": None, "feedback": None}

    async def recording_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen["responder"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def recording_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen["feedback"] = context

    with make_server(recording_responder, recording_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input)),
            context,
        )
        assert seen["responder"] == context

        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        item = next(
            event.item
            for event in events
            if event.type == "thread.item.done"
            and event.item.type == "assistant_message"
        )

        feedback_params = ItemFeedbackParams(
            thread_id=thread.id,
            item_ids=[item.id],
            kind="positive",
        )
        await server.process_non_streaming(
            ItemsFeedbackReq(params=feedback_params),
            context,
        )
        assert seen["feedback"] == context
407
408
async def test_creates_thread():
    """Creating a thread emits thread.created and echoes the user message."""

    async def silent_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    with make_server(silent_responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
            quoted_text="Quoted text",
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        thread_created_event = next(
            event for event in events if event.type == "thread.created"
        )
        assert thread_created_event.thread.title is None

        # The first finished item is the user message that was just sent.
        done_item = next(
            event.item for event in events if event.type == "thread.item.done"
        )
        assert done_item.type == "user_message"
        assert done_item.content[0].text == "Hello, world!"
        assert done_item.quoted_text == "Quoted text"
442
443
async def test_saves_thread_title():
    """A title set by the responder is stored and reported via thread.updated."""

    async def titling_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        # Produce no events; the unreachable yield makes this a generator.
        return
        yield

    with make_server(titling_responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.title == "Updated title"

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.title == "Updated title"
474
475
async def test_saves_thread_metadata():
    """Metadata written by the responder is stored and triggers thread.updated."""

    async def tagging_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        # Produce no events; the unreachable yield makes this a generator.
        return
        yield

    with make_server(tagging_responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
503
504
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is stored and reported."""
    expected_status = LockedStatus(reason="Because")

    async def locking_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        # Produce no events; the unreachable yield makes this a generator.
        return
        yield

    with make_server(locking_responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        loaded = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert loaded.status == expected_status

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.status == expected_status
534
535
async def test_emits_thread_updated_mid_stream_and_persists():
    """A mid-stream title change yields thread.updated and is stored."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            # One finished assistant message carrying a single text part.
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        yield assistant_done("assistant_a", "A")

        # Change the thread while the stream is still running.
        thread.title = "Mid-stream title"

        # thread.updated is expected right after this message...
        yield assistant_done("assistant_b", "B")

        # ...and streaming must carry on afterwards.
        yield assistant_done("assistant_c", "C")

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        thread = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        # Position of the "assistant_b" done event in the stream.
        idx_b = next(
            position
            for position, candidate in enumerate(events)
            if isinstance(candidate, ThreadItemDoneEvent)
            and isinstance(candidate.item, AssistantMessageItem)
            and candidate.item.id == "assistant_b"
        )

        # The event right after it is the thread update.
        after_b = events[idx_b + 1]
        assert isinstance(after_b, ThreadUpdatedEvent)
        updated_event = cast(ThreadUpdatedEvent, after_b)
        assert updated_event.thread.title == "Mid-stream title"

        # The event after that is the next assistant message.
        after_update = events[idx_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        done_event = cast(ThreadItemDoneEvent, after_update)
        assert isinstance(done_event.item, AssistantMessageItem)
        assert done_event.item.id == "assistant_c"

        # The new title is also in the store.
        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"
610
611
async def test_respond_with_tool_call():
    """A client tool call is streamed, then its output resumes the response."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # No user message: this is the turn following the tool output.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(input, UserMessageItem):
            # A user message: answer with a client tool call.
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert len(events) == 4
        created, user_done, options, tool_done = events

        assert created.type == "thread.created"
        thread = created.thread

        assert user_done.type == "thread.item.done"
        assert user_done.item.type == "user_message"

        assert options.type == "stream_options"

        assert tool_done.type == "thread.item.done"
        tool_call = tool_done.item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        # Send the tool output back; the responder now runs without input.
        output_params = ThreadAddClientToolOutputParams(
            thread_id=thread.id,
            result={"text": "Wow!"},
        )
        events = await server.process_streaming(
            ThreadsAddClientToolOutputReq(params=output_params)
        )

        assert len(events) == 2
        options, assistant_done = events
        assert options.type == "stream_options"
        assert assistant_done.type == "thread.item.done"
        assert assistant_done.item.type == "assistant_message"
681
682
async def test_removes_tool_call_if_no_output_provided():
    """A tool call that never got an output is absent from the item list."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert isinstance(input, UserMessageItem)
        assert input.content[0].type == "input_text"

        if input.content[0].text != "Message 1":
            # Any later message gets a plain assistant reply.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[AssistantMessageContent(text="All done!")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
            return

        # "Message 1" is answered with a client tool call.
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id="msg_1",
                created_at=datetime.now(),
                name="tool_call_1",
                arguments={"arg1": "val1", "arg2": False},
                call_id="tool_call_1",
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        first_input = UserMessageInput(
            content=[UserMessageTextContent(text="Message 1")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=first_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # Send a second user message without answering the tool call.
        second_input = UserMessageInput(
            content=[UserMessageTextContent(text="Message 2")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=second_input,
                )
            )
        )

        items_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(items_result.json)
        assert [item.type for item in items.data] == [
            "assistant_message",
            "user_message",
            "user_message",
        ]
747
748
async def test_respond_with_tool_status():
    """A progress update from the responder is streamed before its message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_4",
                content=[AssistantMessageContent(text="Hello, world!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        # Exactly these five events, in this order.
        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
782
783
async def test_list_threads_response():
    """Listing 25 threads returns 20 newest-first, then the remaining 5."""
    with make_server() as server:
        for index in range(25):
            user_input = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
            )

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        threads = page_adapter.validate_json(first_result.json)
        assert len(threads.data) == 20

        # Newest first
        newest, second_newest = threads.data[0], threads.data[1]
        assert newest.created_at > second_newest.created_at
        assert threads.has_more is True
        assert threads.after is not None

        # Continue from the last thread of the first page.
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=threads.data[-1].id))
        )
        threads = page_adapter.validate_json(second_result.json)
        assert len(threads.data) == 5
        assert threads.has_more is False
        assert not threads.after
816
817
async def test_list_items_response():
    """Listing a 26-item thread returns 20 newest-first, then the remaining 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit 25 finished user-message items.
        for index in range(25):
            message = UserMessageItem(
                id=f"msg_{index}",
                content=[UserMessageTextContent(text=f"Message {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Thread 1")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        page_adapter = TypeAdapter(Page[ThreadItem])

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = page_adapter.validate_json(first_result.json)
        assert len(items.data) == 20
        # Newest first
        newest, second_newest = items.data[0], items.data[1]
        assert newest.created_at > second_newest.created_at
        assert items.has_more is True
        assert items.after is not None

        # Continue from the last item of the first page.
        next_params = ItemsListParams(thread_id=thread.id, after=items.data[-1].id)
        second_result = await server.process_non_streaming(
            ItemsListReq(params=next_params)
        )
        items = page_adapter.validate_json(second_result.json)
        assert len(items.data) == 6  # +1 for user message
        assert items.has_more is False
867
868
async def test_calls_action():
    """A custom action reaches the action handler with its widget as sender."""
    received = []

    async def widget_responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Respond with one finished widget item.
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(key="text_1", value="test")]),
            ),
        )

    async def on_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received.append((action, sender))
        assert sender

        # Replace the sending widget's root.
        replacement = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    with make_server(widget_responder, action_callback=on_action) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        widget_item = next(
            event.item
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WidgetItem)
        )

        action_params = ThreadCustomActionParams(
            thread_id=thread.id,
            item_id=widget_item.id,
            action=Action(type="create_user", payload={"user_id": "123"}),
        )
        events = await server.process_streaming(
            ThreadsCustomActionReq(params=action_params)
        )

        # The handler saw the action together with the stored widget.
        assert received
        assert received[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert len(events) == 2
        options, updated = events
        assert options.type == "stream_options"
        assert updated.type == "thread.item.updated"
        assert isinstance(updated, ThreadItemUpdatedEvent)
        assert updated.update.type == "widget.root.updated"
        assert updated.update.widget == Card(children=[Text(value="Email sent!")])
958
959
async def test_add_feedback():
    """Each feedback request calls the handler with thread, item ids and kind."""
    calls = []

    def record_feedback(thread_id, item_ids, feedback, context):
        calls.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # One assistant message followed by one widget.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1",
                content=[
                    AssistantMessageContent(text="Feedback received!"),
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                widget=Card(children=[Text(key="text", value="Widget content")]),
                thread_id=thread.id,
            ),
        )

    with make_server(responder, handle_feedback=record_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test feedback")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        positive_params = ItemFeedbackParams(
            thread_id=thread.id,
            item_ids=["assistant_msg_1"],
            kind="positive",
        )
        await server.process_non_streaming(ItemsFeedbackReq(params=positive_params))

        negative_params = ItemFeedbackParams(
            thread_id=thread.id,
            item_ids=["widget_1"],
            kind="negative",
        )
        await server.process_non_streaming(ItemsFeedbackReq(params=negative_params))

        assert len(calls) == 2
        first_call, second_call = calls

        # Only the first three fields are checked; the context is not.
        assert first_call[0] == thread.id
        assert first_call[1] == ["assistant_msg_1"]
        assert first_call[2] == "positive"

        assert second_call[0] == thread.id
        assert second_call[1] == ["widget_1"]
        assert second_call[2] == "negative"
1034
1035
async def test_get_thread_by_id():
    """Fetching a thread by id returns it together with its first user message."""
    # NOTE(review): the return value is discarded; the call is kept as-is
    # because any side effects of make_thread_items() are not visible here.
    make_thread_items()

    with make_server() as server:
        # Create a thread to look up afterwards.
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test thread")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # Fetch the same thread through the get-by-id request.
        lookup = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        result = await server.process_non_streaming(lookup)
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)

        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title
        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1067
1068
async def test_create_file():
    """Creating a text/plain attachment returns and stores a file attachment."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=file_name, size=1024, mime_type=file_content_type
        )
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # The response describes the created file...
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert attachment.upload_url == expected_upload_url
        # ...and the attachment store holds it under the same id.
        assert attachment.id in store.files
1093
1094
async def test_create_image_file():
    """Creating an image/png attachment returns and stores an image attachment."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "image/png"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=file_name,
            size=1024,
            mime_type=file_content_type,
        )
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "image"

        # Image attachments carry a preview URL as well as an upload URL.
        url_base = f"https://example.com/{attachment.id}"
        assert attachment.preview_url == AnyUrl(f"{url_base}/preview")
        assert attachment.upload_url == AnyUrl(f"{url_base}/upload")

        assert attachment.id in store.files
1124
1125
async def test_create_file_without_filestore():
    """Creating an attachment on a server built without a file_store raises RuntimeError."""
    # pytest.raises is entered first, so it also covers server setup and teardown.
    with pytest.raises(RuntimeError), make_server() as server:
        create_params = AttachmentCreateParams(
            name="test-file-name",
            size=1024,
            mime_type="text/plain",
        )
        await server.process_non_streaming(AttachmentsCreateReq(params=create_params))
1138
1139
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    file_id = "test-file-id"
    file_store = InMemoryFileStore()
    # Seed the store directly so the delete request has something to remove.
    file_store.files[file_id] = {"id": file_id}

    with make_server(file_store=file_store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=file_id)
        )
        await server.process_non_streaming(delete_request)
        assert file_store.files == {}
1150
1151
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server built without a file_store raises RuntimeError."""
    # pytest.raises is entered first, so it also covers server setup and teardown.
    with pytest.raises(RuntimeError), make_server() as server:
        delete_params = AttachmentDeleteParams(attachment_id="test-file-id")
        await server.process_non_streaming(AttachmentsDeleteReq(params=delete_params))
1160
1161
async def test_list_threads_paginated_response():
    """Listing 25 threads with limit=20 returns a full page, then the remaining 5."""
    with make_server() as server:
        # Populate the store with 25 threads.
        for index in range(25):
            user_input = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
            )

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        # First page: capped at the requested limit, with more to come.
        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        last_seen_id = first_page.data[-1].id
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=last_seen_id))
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1190
1191
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            evt.thread for evt in created_events if evt.type == "thread.created"
        )

        update_params = ThreadUpdateParams(thread_id=thread.id, title="New Title")
        update_response = await server.process_non_streaming(
            ThreadsUpdateReq(params=update_params)
        )
        renamed_thread = TypeAdapter[Thread](Thread).validate_json(
            update_response.json
        )
        assert renamed_thread.title == "New Title"
1215
1216
async def test_returns_widget_item_generator():
    """stream_widget emits added, text-delta updates, then done for a growing Text value."""
    full_text = "Hello, world"
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(i: int) -> Card:
        # Card showing the first ``i`` characters of the full text.
        return Card(children=[Text(id="text", value=full_text[:i])])

    async def widget_generator():
        for prefix_length in (0, 3, 6, 12):
            yield render_widget(prefix_length)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5
    added_event, *delta_events, done_event = events

    # The first event introduces the widget with the empty text.
    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="")])

    # Each subsequent widget only appends text, so each update is a delta.
    for update_event, expected_delta in zip(delta_events, ("Hel", "lo,", " world")):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    # The final event carries the complete widget.
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1258
1259
async def test_returns_widget_item_generator_full_replace():
    """stream_widget emits a root update when the text changes rather than grows."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    hello_card = Card(children=[Text(id="text", value="Hello")])
    world_card = Card(children=[Text(id="text", value="World")])

    async def widget_generator():
        # Fresh widgets are yielded each time; the cards above are only the
        # expected values to compare against.
        yield Card(children=[Text(id="text", value="Hello")])
        yield Card(children=[Text(id="text", value="World")])

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added_event, replaced_event, done_event = events

    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == hello_card

    # "World" is not an extension of "Hello", so the whole root is replaced.
    assert isinstance(replaced_event, ThreadItemUpdatedEvent)
    assert replaced_event.update.type == "widget.root.updated"
    assert replaced_event.update.widget == world_card

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == world_card
1283
1284
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and makes later lookups fail.

    After the delete request, fetching the thread by id must raise
    ``NotFoundError`` mentioning the thread id.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # pytest.raises fails the test with an accurate message if nothing is
        # raised; the previous hand-written check reported "Expected ValueError"
        # even though NotFoundError is what is expected.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1316
1317
async def test_thread_item_removed_event_removes_item():
    """An item that the responder adds and then removes is absent from the item listing."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a finished item, then immediately ask for its removal.
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        trigger_input = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger removal")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=trigger_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        # The removal event is forwarded to the client...
        assert any(
            evt.type == "thread.item.removed" and evt.item_id == removed_id
            for evt in events
        )

        # ...and the item is not returned when listing the thread's items.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        assert removed_id not in [stored.id for stored in stored_items]
1362
1363
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder surfaces as a STREAM_ERROR event.

    The stream must contain an ``error`` event with ``ErrorCode.STREAM_ERROR``
    and must be exhausted right after it.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; its presence makes this function an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Keep one iterator so the second check resumes where the loop stopped.
        stream = result.__aiter__()

        # The stream yields an error event.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Then it finishes: nothing may follow the error event.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1405
1406
async def test_thread_item_replaced_event_replaces_item():
    """A replaced event swaps the item listed under the same id for the replacement."""
    original_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: emit the item that will later be replaced.
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Step 2: replace it with an assistant message that reuses the same id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[AssistantMessageContent(text=replacement_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        trigger_input = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger replacement")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=trigger_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        # The replacement event reached the client.
        assert any(
            evt.type == "thread.item.replaced" and evt.item.id == original_id
            for evt in events
        )

        # The item listing reflects the replacement as well.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        replaced_item = next(
            (stored for stored in stored_items if stored.id == original_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, AssistantMessageItem)
        first_part = replaced_item.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1469
1470
async def test_thread_item_replaced_event_with_widget():
    """A replaced event carrying a widget item updates both the stream and the item listing."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the original widget first, then a replacement under the same id.
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=widget_id,
                widget=original_widget,
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )
        yield ThreadItemReplacedEvent(
            item=WidgetItem(
                id=widget_id,
                widget=replacement_widget,
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        trigger_input = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger widget replacement")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=trigger_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        # The stream contains a replacement event carrying the new widget.
        matching_events = (
            evt
            for evt in events
            if evt.type == "thread.item.replaced" and evt.item.id == widget_id
        )
        replacement_event = next(matching_events, None)
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The item listing now returns the replacement widget for that id.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        replaced_item = next(
            (stored for stored in stored_items if stored.id == widget_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, WidgetItem)
        assert isinstance(replaced_item.widget, Card)
        first_child = replaced_item.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1544
1545
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the assistant reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first invocation answers normally; any later one is the retry.
        is_first_call = len(responder_calls) == 1
        message_id = "assistant_msg_1" if is_first_call else "assistant_msg_2"
        message_text = "First response" if is_first_call else "Retried response"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=message_id,
                content=[AssistantMessageContent(text=message_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        async def list_items():
            # Fetch and parse the current items of the thread under test.
            listing = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(listing.json)

        # Before the retry: assistant reply first, user message second.
        items_before = await list_items()
        assert len(items_before.data) == 2
        assistant_before, user_before = items_before.data
        assert assistant_before.type == "assistant_message"
        assert assistant_before.content[0].type == "output_text"
        assert assistant_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry after the user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_before.id
                )
            )
        )

        # The retry stream holds stream options followed by the new reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # The responder ran twice and received an equal user message both times.
        assert len(responder_calls) == 2
        first_call, second_call = responder_calls
        assert first_call == second_call
        assert first_call.type == "user_message"
        assert first_call.content[0].text == "Hello, world!"

        # After the retry: still two items, with the new reply in place of the old.
        items_after = await list_items()
        assert len(items_after.data) == 2
        assistant_after, user_after = items_after.data
        assert assistant_after.type == "assistant_message"
        assert assistant_after.content[0].type == "output_text"
        assert assistant_after.content[0].text == "Retried response"
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1643
1644
async def test_retry_after_item_invalid_item():
    """Retrying after an assistant message is rejected with a ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        # The first listed item is used as the assistant reply.
        assistant_item_id = items.data[0].id

        # Only user messages are valid retry anchors, so this must fail.
        invalid_retry = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(invalid_retry)
1691
1692
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the latest reply.

    The thread holds two user/assistant exchanges. After retrying from the
    second user message, the first exchange must be unchanged and the second
    assistant reply must be a newly generated item.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The call count goes into the id so a regenerated reply gets a new id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Verify we have 4 items. The listing is newest first:
        # assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second user message; check before using it as the anchor.
        assert items_before.data[1].type == "user_message"
        second_user_message_id = items_before.data[1].id

        # Retry after the second user message
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # First user message and response should remain unchanged
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # Second user message should remain, but assistant response should be new
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1797
1798
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with a thread-create request reaches the responder.

    The responder records what it received and the check runs in the test
    body. test_raising_in_responder_yields_error_event shows that exceptions
    raised in a responder can be turned into error events, so an assertion
    placed only inside the responder may not fail the test by itself.
    """
    tool_choice = ToolChoice(id="web_search")
    # One entry per responder call: the tool-choice id received, or None.
    observed_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received = input.inference_options.tool_choice if input is not None else None
        observed_tool_choice_ids.append(received.id if received is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_create",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        # Sanity check that the responder's own item flowed through; a bare
        # "thread.item.done" check could also match the user-message event.
        assert any(
            e.type == "thread.item.done" and e.item.type == "assistant_message"
            for e in events
        )
        # The responder ran exactly once and saw the requested tool choice.
        assert observed_tool_choice_ids == [tool_choice.id]
1832
1833
async def test_retry_after_item_passes_tools_to_responder():
    """The tool choice of the retried user message reaches the responder again.

    A thread is created with a tool choice, then a retry is requested after
    that user message. The responder records the tool-choice id it receives
    on each call and the test body checks both calls.

    NOTE(review): this assumes the store round-trips ``inference_options`` on
    stored user messages — confirm against the store used by ``make_server``.
    """
    tool_choice = ToolChoice(id="web_search")
    # One entry per responder call: the tool-choice id received, or None.
    observed_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received = input.inference_options.tool_choice if input is not None else None
        observed_tool_choice_ids.append(received.id if received is not None else None)

        # The call count goes into the id so the retried reply gets a new id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_tools_retry_{len(observed_tool_choice_ids)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Look up the stored user message to use as the retry anchor.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        user_message = next(i for i in items if i.type == "user_message")

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message.id
                )
            )
        )

        # The retry produced a fresh assistant reply...
        assert any(
            e.type == "thread.item.done" and e.item.type == "assistant_message"
            for e in retry_events
        )
        # ...and both the original call and the retry saw the tool choice.
        assert observed_tool_choice_ids == [tool_choice.id, tool_choice.id]
1836
1837
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with an add-user-message request reaches the responder.

    The responder records what it received and the check runs in the test
    body. test_raising_in_responder_yields_error_event shows that exceptions
    raised in a responder can be turned into error events, so an assertion
    placed only inside the responder may not fail the test by itself.
    """
    tool_choice = ToolChoice(id="retrieval")
    # One entry per responder call: the tool-choice id received, or None.
    observed_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received = input.inference_options.tool_choice if input is not None else None
        observed_tool_choice_ids.append(received.id if received is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_add",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Add a message with tools and verify they reach the responder.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )
        # Sanity check that the responder's own item flowed through; a bare
        # "thread.item.done" check could also match the user-message event.
        assert any(
            e.type == "thread.item.done" and e.item.type == "assistant_message"
            for e in events
        )
        # One call for the create request and one for the added message.
        assert observed_tool_choice_ids == [tool_choice.id, tool_choice.id]
1887
1888
async def test_custom_id_generators_on_server():
    """Ids produced by a store's generate_* overrides appear in the emitted events."""

    class UniqueIdStore(SQLiteStore):
        """SQLiteStore whose thread and item ids come from one shared counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter = self.counter + 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter = self.counter + 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits nothing: the bare yield after return only makes this an
        # async generator.
        return
        yield

    with make_server(responder) as server:
        # Replace the server's store with the custom one, pointing at the same
        # database path.
        existing_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(existing_store.db_path)

        # Create a thread through the raw JSON entry point.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        result = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is the first value drawn from the counter.
        thread_created = next(evt for evt in events if evt.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        # The user message id is the second, and embeds the thread id.
        user_message = next(
            evt.item
            for evt in events
            if evt.type == "thread.item.done" and evt.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
1944 assert user_message.id == "message_custom_2_thr_custom_1"