openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
cec87b20f1817daa284c63d63ca7cf3f69a38f07

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1956lines · modecode

1import asyncio
2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14 ChatKitServer,
15 NonStreamingResult,
16 StreamingResult,
17 stream_widget,
18)
19from chatkit.store import (
20 AttachmentStore,
21 NotFoundError,
22 StoreItemType,
23 default_generate_id,
24)
25from chatkit.types import (
26 AssistantMessageContent,
27 AssistantMessageContentPartTextDelta,
28 AssistantMessageItem,
29 Attachment,
30 AttachmentCreateParams,
31 AttachmentDeleteParams,
32 AttachmentsCreateReq,
33 AttachmentsDeleteReq,
34 AttachmentUploadDescriptor,
35 ClientToolCallItem,
36 CustomTask,
37 FeedbackKind,
38 FileAttachment,
39 ImageAttachment,
40 InferenceOptions,
41 ItemFeedbackParams,
42 ItemsFeedbackReq,
43 ItemsListParams,
44 ItemsListReq,
45 LockedStatus,
46 Page,
47 ProgressUpdateEvent,
48 Thread,
49 ThreadAddClientToolOutputParams,
50 ThreadAddUserMessageParams,
51 ThreadCreatedEvent,
52 ThreadCreateParams,
53 ThreadCustomActionParams,
54 ThreadDeleteParams,
55 ThreadGetByIdParams,
56 ThreadItem,
57 ThreadItemAddedEvent,
58 ThreadItemDoneEvent,
59 ThreadItemRemovedEvent,
60 ThreadItemReplacedEvent,
61 ThreadItemUpdatedEvent,
62 ThreadListParams,
63 ThreadMetadata,
64 ThreadRetryAfterItemParams,
65 ThreadsAddClientToolOutputReq,
66 ThreadsAddUserMessageReq,
67 ThreadsCreateReq,
68 ThreadsCustomActionReq,
69 ThreadsDeleteReq,
70 ThreadsGetByIdReq,
71 ThreadsListReq,
72 ThreadsRetryAfterItemReq,
73 ThreadStreamEvent,
74 ThreadsUpdateReq,
75 ThreadUpdatedEvent,
76 ThreadUpdateParams,
77 ToolChoice,
78 UserMessageInput,
79 UserMessageItem,
80 UserMessageTextContent,
81 WidgetItem,
82 WidgetRootUpdated,
83 Workflow,
84 WorkflowItem,
85 WorkflowTaskAdded,
86)
87from chatkit.widgets import Card, Text
88from tests._types import RequestContext
89from tests.test_store import make_thread_items
90
# Counter read and incremented by make_server() so that every server it builds
# gets its own uniquely named shared in-memory SQLite database.
server_id = 0


# Request context used by the test server helpers whenever a test does not
# pass an explicit context of its own.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
95
96
class InMemoryFileStore(AttachmentStore):
    """Minimal in-memory ``AttachmentStore`` used by the attachment tests.

    Created attachments are kept in ``self.files`` keyed by attachment id, so
    tests can inspect (or pre-seed) the mapping directly.
    """

    def __init__(self):
        self.files = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from the store.

        Raises:
            KeyError: if no attachment is stored under ``attachment_id``.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, record and return a new attachment described by ``input``.

        Mime types starting with ``image/`` produce an ``ImageAttachment``
        (with a preview URL); everything else produces a ``FileAttachment``.
        Both carry the same fake ``PUT`` upload descriptor.
        """
        attachment_id = self._next_attachment_id()
        # The upload descriptor is identical for both attachment kinds.
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
            )
        self.files[attachment.id] = attachment
        return attachment

    def _next_attachment_id(self) -> str:
        """Return an ``atc_<n>`` id that is not currently stored.

        Numbering starts at ``len(self.files) + 1``. After a deletion that
        number can already be in use, so probe forward until a free id is
        found instead of silently overwriting an existing attachment.
        """
        sequence = len(self.files) + 1
        while f"atc_{sequence}" in self.files:
            sequence += 1
        return f"atc_{sequence}"
134
135
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one ``data: <json>`` stream frame into a ``ThreadStreamEvent``.

    Only the first ``data: `` marker is treated as the frame prefix, so a JSON
    payload that itself contains the text ``data: `` is decoded intact rather
    than being cut off at the second occurrence.

    Raises:
        IndexError: if ``event`` contains no ``data: `` marker at all.
    """
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
138
139
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every frame, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for frame in streaming_result.json_events:
        decoded.append(decode_event(frame))
    return decoded
144
145
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` test double backed by a fresh in-memory SQLite DB.

    Args:
        responder: async generator used for ``respond``; when omitted the
            server responds with an empty stream.
        handle_feedback: synchronous hook invoked from ``add_feedback``; when
            omitted feedback is ignored.
        action_callback: async generator used for ``action``; when omitted,
            calling ``action`` raises ``ValueError``.
        file_store: attachment store handed to the ``ChatKitServer`` constructor.

    The yielded server also offers ``process_streaming`` and
    ``process_non_streaming`` helpers that serialize a request object, run it
    through ``process`` and assert the kind of result that comes back.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Hold one connection open for the whole context so the shared in-memory
    # database is not deleted when the last connection is closed.
    keepalive_connection = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        keepalive_connection.close()
226
227
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream stores the partial assistant text plus a hidden-context note.

    The responder adds an assistant message, streams one text delta and then
    raises ``asyncio.CancelledError``. Afterwards the store is expected to hold
    an ``sdk_hidden_context`` item and the assistant message with the delta
    applied.
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(item_id=pending_id, update=text_delta)
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Let the test store mint ids for hidden-context items; every other
        # item type still goes through the store's own generator.
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        # Collect whatever was emitted before the cancellation propagates.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        # Positional arguments are presumably after=None, limit=1, order="desc".
        page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        assistant_message_item = await server.store.load_item(
            thread.id, "assistant-message-pending", DEFAULT_CONTEXT
        )
        assert assistant_message_item.type == "assistant_message"
        assert assistant_message_item.content[0].text == "Hello, World!"
297
298
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream does not store an assistant message that has no content.

    The responder adds an assistant message with an empty ``content`` list and
    then raises ``asyncio.CancelledError``. The hidden-context item is still
    expected, but loading the assistant message must raise ``NotFoundError``.
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Hidden-context ids are generated here; all other item types are
        # delegated to the store's own generator.
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        # Collect whatever was emitted before the cancellation propagates.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        # Positional arguments are presumably after=None, limit=1, order="desc".
        page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        with pytest.raises(NotFoundError):
            await server.store.load_item(
                thread.id, "assistant-message-pending", DEFAULT_CONTEXT
            )
359
360
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying pending item updates must not modify the
    original item object that the responder streamed.

    The responder announces a task through an update event, appends the same
    task to its own workflow item and then emits that item as done. Both the
    emitted done item and the stored item must contain the task exactly once.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        workflow_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=workflow_item)

        task = CustomTask(title="foo", content="bar")
        task_added = WorkflowTaskAdded(task=task, task_index=0)
        yield ThreadItemUpdatedEvent(item_id=workflow_item.id, update=task_added)

        # The responder mutates its own copy only after the update was emitted.
        workflow_item.workflow.tasks.append(task)
        yield ThreadItemDoneEvent(item=workflow_item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        thread = next(e.thread for e in events if e.type == "thread.created")
        done_item = next(
            e.item
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WorkflowItem)
        )
        workflow_done_item = cast(WorkflowItem, done_item)
        emitted_tasks = workflow_done_item.workflow.tasks
        assert len(emitted_tasks) == 1
        assert emitted_tasks[0].title == "foo"

        stored = await server.store.load_item(
            thread.id, workflow_done_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(stored, WorkflowItem)
        stored_tasks = cast(WorkflowItem, stored).workflow.tasks
        assert len(stored_tasks) == 1
        assert stored_tasks[0].title == "foo"
418
419
async def test_flows_context_to_responder():
    """The context given to the server reaches both ``respond`` and ``add_feedback``."""
    # Contexts observed by each hook; ``None`` until the hook has run.
    seen: dict[str, Any] = {"responder": None, "feedback": None}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen["responder"] = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen["feedback"] = context

    context = RequestContext(user_id="text_ctx")
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, add_feedback) as server:
        events = await server.process_streaming(create_request, context)
        assert seen["responder"] == context

        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        item = next(
            event.item
            for event in events
            if event.type == "thread.item.done"
            and event.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen["feedback"] == context
479
480
async def test_creates_thread():
    """Creating a thread emits an untitled thread and echoes the user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
        quoted_text="Quoted text",
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(event for event in events if event.type == "thread.created")
        assert created.thread.title is None

        # The first completed item is expected to be the user's own message.
        first_done_item = next(
            event.item for event in events if event.type == "thread.item.done"
        )
        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
514
515
async def test_saves_thread_title():
    """A title set by the responder is persisted and reported via ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        # No events are produced; only the thread metadata is touched.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        persisted = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert persisted.title == "Updated title"

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.title == "Updated title"
546
547
async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted and triggers ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        # No events are produced; only the thread metadata is touched.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        persisted = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert persisted.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
575
576
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and reported via ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        # No events are produced; only the thread status is touched.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        loaded = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert loaded.status == LockedStatus(reason="Because")

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.status == LockedStatus(reason="Because")
606
607
async def test_emits_thread_updated_mid_stream_and_persists():
    """A title change made between events is emitted mid-stream and persisted.

    The responder emits message A, changes the title, then emits B and C. The
    test expects ``thread.updated`` directly after B, message C directly after
    that, and the new title in the store.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            # Built at yield time so each item gets its own creation timestamp.
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        # First assistant message
        yield assistant_done("assistant_a", "A")

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done("assistant_b", "B")

        # Continue streaming after the update event
        yield assistant_done("assistant_c", "C")

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        # Find the created thread
        thread = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        # Locate assistant_b; thread.updated must be the very next event
        idx_b = next(
            position
            for position, e in enumerate(events)
            if isinstance(e, ThreadItemDoneEvent)
            and isinstance(e.item, AssistantMessageItem)
            and e.item.id == "assistant_b"
        )
        following = events[idx_b + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        updated_event = cast(ThreadUpdatedEvent, following)
        assert updated_event.thread.title == "Mid-stream title"

        # Streaming continues after the update event
        trailing = events[idx_b + 2]
        assert isinstance(trailing, ThreadItemDoneEvent)
        done_event = cast(ThreadItemDoneEvent, trailing)
        assert isinstance(done_event.item, AssistantMessageItem)
        assert done_event.item.id == "assistant_c"

        # Persisted in the store
        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"
682
683
async def test_respond_with_tool_call():
    """A client tool call is streamed, then its output resumes the response.

    The responder emits a client tool call when given a user message and an
    assistant message when called with no input, which is the case exercised
    by the add-client-tool-output request.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            follow_up = AssistantMessageItem(
                id="msg_2",
                content=[
                    AssistantMessageContent(text="Glad the tool call succeeded!")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=follow_up)
        elif isinstance(input, UserMessageItem):
            tool_call = ClientToolCallItem(
                id="msg_1",
                created_at=datetime.now(),
                name="tool_call_1",
                arguments={"arg1": "val1", "arg2": False},
                call_id="tool_call_1",
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=tool_call)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "thread.item.done",
        ]
        thread = events[0].thread
        assert events[1].item.type == "user_message"

        tool_call_item = events[3].item
        assert tool_call_item.type == "client_tool_call"
        assert tool_call_item.id == "msg_1"
        assert tool_call_item.name == "tool_call_1"
        assert tool_call_item.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call_item.call_id == "tool_call_1"

        # Report the tool output; the responder now runs with input=None.
        output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(output_request)

        assert [event.type for event in events] == [
            "stream_options",
            "thread.item.done",
        ]
        assert events[1].item.type == "assistant_message"
753
754
async def test_respond_with_tool_status():
    """A progress update from the responder is forwarded between the other events."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        # Exactly five events, in this order.
        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
788
789
async def test_list_threads_response():
    """Listing 25 threads returns a page of 20 followed by a final page of 5."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for i in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {i}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        threads = page_adapter.validate_json(first_response.json)
        assert len(threads.data) == 20

        # Newest first
        newest, second_newest = threads.data[0], threads.data[1]
        assert newest.created_at > second_newest.created_at
        assert threads.has_more is True
        assert threads.after is not None

        # Continue from the last thread of the first page.
        cursor = threads.data[-1].id
        more_threads = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        threads = page_adapter.validate_json(more_threads.json)
        assert len(threads.data) == 5
        assert threads.has_more is False
        assert not threads.after
822
823
async def test_list_items_response():
    """Listing a thread's 26 items returns a page of 20 and then the remaining 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for i in range(25):
            message = UserMessageItem(
                id=f"msg_{i}",
                content=[UserMessageTextContent(text=f"Message {i}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    page_adapter = TypeAdapter(Page[ThreadItem])
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Thread 1")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        first_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = page_adapter.validate_json(first_response.json)
        assert len(items.data) == 20
        # Newest first
        newest, second_newest = items.data[0], items.data[1]
        assert newest.created_at > second_newest.created_at
        assert items.has_more is True
        assert items.after is not None

        # Continue from the last item of the first page.
        cursor = items.data[-1].id
        second_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id, after=cursor))
        )
        items = page_adapter.validate_json(second_response.json)
        assert len(items.data) == 6  # +1 for user message
        assert items.has_more is False
873
874
async def test_calls_action():
    """A custom action request reaches ``action`` with the sending widget item.

    The responder first emits a widget; the test then fires a custom action
    against it and checks both what the callback received and the update event
    the callback streamed back.
    """
    # (action, sender) pairs recorded by the action callback.
    actions = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(
                children=[
                    Text(
                        key="text_1",
                        value="test",
                    )
                ],
            ),
        )
        yield ThreadItemDoneEvent(item=widget_item)

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        actions.append((action, sender))
        assert sender

        replacement = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Show widget")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, action_callback=action) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        widget_item = next(
            event.item
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert actions
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert actions[0] == expected_call

        assert [event.type for event in events] == [
            "stream_options",
            "thread.item.updated",
        ]
        update_event = events[1]
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(children=[Text(value="Email sent!")])
964
965
async def test_add_feedback():
    """Feedback requests are forwarded to ``add_feedback`` with thread, items and kind."""
    # One (thread_id, item_ids, feedback, context) tuple per feedback call.
    called = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        called.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_message = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_message)

        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test feedback")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, handle_feedback=handle_feedback) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        positive_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["assistant_msg_1"],
                kind="positive",
            )
        )
        await server.process_non_streaming(positive_request)

        negative_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["widget_1"],
                kind="negative",
            )
        )
        await server.process_non_streaming(negative_request)

        assert len(called) == 2

        # The context element of each recorded call is deliberately not compared.
        first_thread_id, first_item_ids, first_kind, _ = called[0]
        assert first_thread_id == thread.id
        assert first_item_ids == ["assistant_msg_1"]
        assert first_kind == "positive"

        second_thread_id, second_item_ids, second_kind, _ = called[1]
        assert second_thread_id == thread.id
        assert second_item_ids == ["widget_1"]
        assert second_kind == "negative"
1040
1041
async def test_get_thread_by_id():
    """Fetching a thread by id returns its metadata and the original user message."""
    # NOTE(review): the return value is discarded; presumably this call is a
    # leftover and has no side effects - confirm before removing.
    make_thread_items()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test thread")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        # Create a thread by sending a thread-create request
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # Retrieve the thread by id
        lookup_request = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        result = await server.process_non_streaming(lookup_request)
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)

        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title
        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1073
1074
async def test_create_file():
    """Creating a ``text/plain`` attachment returns a file attachment and stores it."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    create_request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name=file_name, size=1024, mime_type=file_content_type
        )
    )

    with make_server(file_store=store) as server:
        result = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        # The attachment was recorded in the backing store as well.
        assert attachment.id in store.files
1102
1103
async def test_create_image_file():
    """Creating a PNG attachment returns an image attachment with a preview URL."""
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime = "image/png"

    with make_server(file_store=file_store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=expected_name,
                size=1024,
                mime_type=expected_mime,
            )
        )
        response = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert attachment.id is not None
        assert attachment.mime_type == expected_mime
        assert attachment.name == expected_name
        assert attachment.type == "image"
        assert attachment.preview_url == AnyUrl(
            f"https://example.com/{attachment.id}/preview"
        )

        # The upload descriptor tells the client where and how to upload.
        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        # The attachment was registered with the file store.
        assert attachment.id in file_store.files
1136
1137
async def test_create_file_without_filestore():
    """Creating an attachment on a server built without a file store raises ``RuntimeError``."""
    request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name="test-file-name",
            size=1024,
            mime_type="text/plain",
        )
    )
    with make_server() as server:
        # Keep the expectation scoped to the request itself so that a
        # RuntimeError raised while building or tearing down the server
        # cannot make this test pass by accident.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(request)
1150
1151
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    attachment_id = "test-file-id"
    file_store = InMemoryFileStore()
    file_store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=file_store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=attachment_id)
        )
        await server.process_non_streaming(delete_request)

        # The only stored file is gone.
        assert file_store.files == {}
1162
1163
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server built without a file store raises ``RuntimeError``."""
    request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id="test-file-id")
    )
    with make_server() as server:
        # Keep the expectation scoped to the request itself so that a
        # RuntimeError raised while building or tearing down the server
        # cannot make this test pass by accident.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(request)
1172
1173
async def test_list_threads_paginated_response():
    """Listing 25 threads yields a page of 20 followed by a final page of 5."""
    with make_server() as server:
        # Seed the store with 25 threads.
        for index in range(25):
            seed_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(seed_request)

        # First page: explicit limit of 20, more results remain.
        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            first_response.json
        )
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        cursor = first_page.data[-1].id
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            second_response.json
        )
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1202
1203
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # Rename the thread and parse the returned payload.
        update_request = ThreadsUpdateReq(
            params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
        )
        update_response = await server.process_non_streaming(update_request)
        renamed = TypeAdapter[Thread](Thread).validate_json(update_response.json)

        assert renamed.title == "New Title"
1227
1228
async def test_returns_widget_item_generator():
    """``stream_widget`` emits text deltas as a widget's text value grows."""
    full_text = "Hello, world"
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(i: int) -> Card:
        # Card whose single text node holds the first ``i`` characters.
        return Card(children=[Text(id="text", value=full_text[:i])])

    async def widget_generator():
        for prefix_len in (0, 3, 6, 12):
            yield render_widget(prefix_len)

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 5

    added, *updates, done = events

    # The first snapshot is announced as a new item with empty text.
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    # Each later snapshot is reported as the text appended since the last one.
    for updated, expected_delta in zip(updates, ("Hel", "lo,", " world")):
        assert isinstance(updated, ThreadItemUpdatedEvent)
        assert updated.update.type == "widget.streaming_text.value_delta"
        assert updated.update.component_id == "text"
        assert updated.update.delta == expected_delta

    # The final event carries the complete widget.
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1270
1271
async def test_returns_widget_item_generator_full_replace():
    """``stream_widget`` emits a root update when the text is replaced, not extended."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    async def widget_generator():
        for value in ("Hello", "World"):
            yield Card(children=[Text(id="text", value=value)])

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 3

    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="Hello")])

    # "World" does not extend "Hello", so the whole widget root is sent.
    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == Card(children=[Text(id="text", value="World")])

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(children=[Text(id="text", value="World")])
1295
1296
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and makes later lookups fail.

    After deletion, fetching the thread by id must raise ``NotFoundError``
    with a message naming the missing thread.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # The lookup must now fail. pytest.raises reports a missing exception
        # with the correct expected type, unlike a hand-written failure message.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1328
1329
async def test_thread_item_removed_event_removes_item():
    """An item announced as done and then removed is absent from the thread's items."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit an item, then immediately retract it.
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger removal")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The removal is visible in the event stream...
        assert any(
            ev.type == "thread.item.removed" and ev.item_id == removed_id
            for ev in events
        )

        # ...and the item is no longer listed for the thread.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        assert removed_id not in [item.id for item in stored_items]
1374
1375
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder surfaces as an ``error`` stream event.

    The stream must contain an error event with ``ErrorCode.STREAM_ERROR``
    and must end right after it.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; its presence makes this an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        stream = result.__aiter__()

        # Consume the stream until the error event shows up.
        async for chunk in stream:
            event = decode_event(chunk)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Nothing may follow the error event: the stream has to be exhausted.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1417
1418
async def test_thread_item_replaced_event_replaces_item():
    """A replaced event swaps the listed item for the replacement with the same id."""
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: emit the item that will later be replaced.
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Step 2: replace it with an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger replacement")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The replacement shows up in the event stream.
        assert any(
            ev.type == "thread.item.replaced" and ev.item.id == original_id
            for ev in events
        )

        # The thread's item listing now holds the replacement under the same id.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        matches = (item for item in stored_items if item.id == original_id)
        stored = next(matches, None)
        assert stored is not None
        assert isinstance(stored, AssistantMessageItem)

        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1481
1482
async def test_thread_item_replaced_event_with_widget():
    """A replaced event carrying a widget item updates the listed widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: emit the widget that will later be replaced.
        initial_item = WidgetItem(
            id=widget_id,
            widget=original_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=initial_item)

        # Step 2: replace it with a widget item that reuses the id.
        updated_item = WidgetItem(
            id=widget_id,
            widget=replacement_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=updated_item)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[
                        UserMessageTextContent(text="Trigger widget replacement")
                    ],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The replacement shows up in the event stream with the new widget.
        replaced_events = (
            ev
            for ev in events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        )
        replacement_event = next(replaced_events, None)
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The thread's item listing now holds the replacement widget.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        stored = next((item for item in stored_items if item.id == widget_id), None)
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        assert isinstance(stored.widget, Card)

        first_child = stored.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1556
1557
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first invocation answers normally; any later one is the retry.
        is_first_call = len(responder_calls) == 1
        message_id = "assistant_msg_1" if is_first_call else "assistant_msg_2"
        message_text = "First response" if is_first_call else "Retried response"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=message_id,
                content=[AssistantMessageContent(text=message_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # Snapshot the thread before retrying: assistant reply first, then
        # the user message.
        listing_before = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(
            listing_before.json
        )
        assert len(items_before.data) == 2
        reply_before, user_before = items_before.data
        assert reply_before.type == "assistant_message"
        assert reply_before.content[0].type == "output_text"
        assert reply_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry from the user message.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=user_before.id
            )
        )
        retry_events = await server.process_streaming(retry_request)

        # The retry stream holds stream options plus the new reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder invocations received the same user message.
        assert len(responder_calls) == 2
        assert responder_calls[0] == responder_calls[1]
        assert responder_calls[0].type == "user_message"
        assert responder_calls[0].content[0].text == "Hello, world!"

        # The listing still has two items, with the reply replaced.
        listing_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(listing_after.json)
        assert len(items_after.data) == 2
        reply_after, user_after = items_after.data
        assert reply_after.type == "assistant_message"
        assert reply_after.content[0].type == "output_text"
        assert reply_after.content[0].text == "Retried response"
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1655
1656
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ``ValueError``."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # The first listed item is used as the (invalid) retry anchor.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        assistant_item_id = page.data[0].id

        # Retrying from an assistant message must be rejected.
        bad_retry = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(bad_retry)
1703
1704
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the second user message regenerates only the last reply.

    The thread holds two user/assistant exchanges. After the retry, the first
    exchange must be untouched and the second assistant reply must be a new
    item. The index-based assertions assume the item listing is returned
    newest-first (index 0 is the latest item).
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # Echo the user text so each reply can be traced back to its message;
        # the call count keeps assistant ids unique across the retry.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Verify we have 4 items: user1, assistant1, user2, assistant2
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is treated as the second user message.
        second_user_message_id = items_before.data[1].id

        # Retry after the second user message
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items
        # First user message and response should remain unchanged
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"
        # Second user message should remain, but assistant response should be new
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1809
1810
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with a thread-create request reaches the responder."""
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # The responder itself checks what it was handed.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_create",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello with tools")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Sanity check that the request flowed through.
        assert any(ev.type == "thread.item.done" for ev in events)
1844
1845
async def test_retry_after_item_passes_tools_to_responder():
    """Placeholder for the retry-path variant of the ``*_passes_tools_to_responder`` tests.

    The test has no body yet. It is reported as skipped rather than passed so
    the missing coverage stays visible in test reports.
    """
    # TODO: implement; presumably it should assert that the tool choice of the
    # retried user message reaches the responder -- confirm intended behavior.
    pytest.skip("retry-after-item tool forwarding test is not implemented yet")
1848
1849
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with an add-user-message request reaches the responder."""
    tool_choice = ToolChoice(id="retrieval")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # The responder itself checks what it was handed, on every call.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_add",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        # A thread has to exist before a message can be added to it.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Start")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # Add a message with tools and verify they reach the responder.
        add_request = ThreadsAddUserMessageReq(
            params=ThreadAddUserMessageParams(
                thread_id=thread.id,
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Second")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                ),
            )
        )
        added_events = await server.process_streaming(add_request)

        # Sanity check that the request flowed through.
        assert any(ev.type == "thread.item.done" for ev in added_events)
1899
1900
async def test_custom_id_generators_on_server():
    """Ids produced by a store's custom generators show up in the emitted events."""

    class UniqueIdStore(SQLiteStore):
        """Store whose thread and item ids come from one shared counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Produce no events; the unreachable ``yield`` only makes this an
        # async generator.
        return
        yield

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        existing_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(existing_store.db_path)

        # Create a thread through the raw JSON entry point.
        payload = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        ).model_dump_json()
        result = await server.process(payload, DEFAULT_CONTEXT)
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is the first value drawn from the counter.
        thread_created = next(ev for ev in events if ev.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        # The user message id is the second value and embeds the thread id.
        user_message = next(
            ev.item
            for ev in events
            if ev.type == "thread.item.done" and ev.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
1957