openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
b79fb0495b3e0fffb0cc647d4be8bb031b7c584e

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2042lines · modecode

1import asyncio
2import base64
3import sqlite3
4from contextlib import contextmanager
5from datetime import datetime
6from typing import Any, AsyncIterator, Callable, cast
7
8import pytest
9from helpers.mock_store import SQLiteStore
10from pydantic import AnyUrl, TypeAdapter
11
12from chatkit.actions import Action
13from chatkit.errors import ErrorCode
14from chatkit.server import (
15 ChatKitServer,
16 NonStreamingResult,
17 StreamingResult,
18 stream_widget,
19)
20from chatkit.store import (
21 AttachmentStore,
22 NotFoundError,
23 StoreItemType,
24 default_generate_id,
25)
26from chatkit.types import (
27 AssistantMessageContent,
28 AssistantMessageContentPartTextDelta,
29 AssistantMessageItem,
30 Attachment,
31 AttachmentCreateParams,
32 AttachmentDeleteParams,
33 AttachmentsCreateReq,
34 AttachmentsDeleteReq,
35 AttachmentUploadDescriptor,
36 ClientToolCallItem,
37 CustomTask,
38 FeedbackKind,
39 FileAttachment,
40 ImageAttachment,
41 InferenceOptions,
42 InputTranscribeParams,
43 InputTranscribeReq,
44 ItemFeedbackParams,
45 ItemsFeedbackReq,
46 ItemsListParams,
47 ItemsListReq,
48 LockedStatus,
49 Page,
50 ProgressUpdateEvent,
51 Thread,
52 ThreadAddClientToolOutputParams,
53 ThreadAddUserMessageParams,
54 ThreadCreatedEvent,
55 ThreadCreateParams,
56 ThreadCustomActionParams,
57 ThreadDeleteParams,
58 ThreadGetByIdParams,
59 ThreadItem,
60 ThreadItemAddedEvent,
61 ThreadItemDoneEvent,
62 ThreadItemRemovedEvent,
63 ThreadItemReplacedEvent,
64 ThreadItemUpdatedEvent,
65 ThreadListParams,
66 ThreadMetadata,
67 ThreadRetryAfterItemParams,
68 ThreadsAddClientToolOutputReq,
69 ThreadsAddUserMessageReq,
70 ThreadsCreateReq,
71 ThreadsCustomActionReq,
72 ThreadsDeleteReq,
73 ThreadsGetByIdReq,
74 ThreadsListReq,
75 ThreadsRetryAfterItemReq,
76 ThreadStreamEvent,
77 ThreadsUpdateReq,
78 ThreadUpdatedEvent,
79 ThreadUpdateParams,
80 ToolChoice,
81 TranscriptionResult,
82 UserMessageInput,
83 UserMessageItem,
84 UserMessageTextContent,
85 WidgetItem,
86 WidgetRootUpdated,
87 Workflow,
88 WorkflowItem,
89 WorkflowTaskAdded,
90)
91from chatkit.widgets import Card, Text
92from tests._types import RequestContext
93from tests.test_store import make_thread_items
94
# Counter read and incremented by make_server so that every server it builds
# gets a differently named in-memory SQLite database.
server_id = 0


# Request context used by the server test helpers when the caller passes none.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
99
100
class InMemoryFileStore(AttachmentStore):
    """In-memory ``AttachmentStore`` used by the tests in this module.

    Every created attachment is kept in ``self.files``, keyed by its id.
    """

    def __init__(self):
        self.files = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Forget ``attachment_id``; raises ``KeyError`` if it is not stored."""
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Build, record and return an attachment described by ``input``.

        MIME types starting with ``image/`` produce an ``ImageAttachment``
        (with a preview URL); everything else produces a ``FileAttachment``.
        Both carry the same upload descriptor and test metadata.
        """
        # Start from the "count + 1" numbering, but skip ids that are still in
        # use: after a deletion, count + 1 can name a live attachment, which
        # would then be silently overwritten.
        sequence = len(self.files) + 1
        while f"atc_{sequence}" in self.files:
            sequence += 1
        attachment_id = f"atc_{sequence}"

        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        metadata = {"source": "test"}

        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        self.files[attachment.id] = attachment
        return attachment
141
142
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one serialized stream chunk into a ``ThreadStreamEvent``.

    The JSON payload is everything after the first ``b"data: "`` marker.

    Raises:
        IndexError: if ``event`` contains no ``b"data: "`` marker.
    """
    # Split only once: the JSON payload may itself contain the text "data: "
    # (for example inside a message), and an unbounded split would truncate it.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
145
146
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode each chunk in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
151
152
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
    transcribe_callback: Callable[[bytes, str, Any], TranscriptionResult] | None = None,
):
    """Yield a ``ChatKitServer`` backed by a fresh shared in-memory SQLite store.

    Each optional callback stands in for the matching server hook:

    * ``responder`` -> ``respond`` (an empty stream when omitted)
    * ``handle_feedback`` -> ``add_feedback`` (a no-op when omitted)
    * ``action_callback`` -> ``action`` (``ValueError`` when omitted)
    * ``transcribe_callback`` -> ``transcribe`` (base implementation when omitted)

    ``file_store`` is passed to the server as its attachment store. The yielded
    server also has ``process_streaming`` / ``process_non_streaming`` helpers
    that serialize a request object, run it through ``process`` and check the
    kind of result returned.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Hold one connection open for the lifetime of the server so the shared
    # in-memory database is not deleted when the last other connection closes.
    keepalive_connection = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # produces no events.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def transcribe(
            self, audio_bytes: bytes, mime_type: str, context: Any
        ) -> TranscriptionResult:
            if transcribe_callback is not None:
                return transcribe_callback(audio_bytes, mime_type, context)
            return await super().transcribe(audio_bytes, mime_type, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        keepalive_connection.close()
241
242
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream stores the partial assistant text and a hidden-context note."""
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(item_id=pending_id, update=text_delta)

        # Simulate the stream being cancelled part-way through the response.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for chunk in stream.json_events:
                received.append(decode_event(chunk))

        thread = next(
            event.thread for event in received if event.type == "thread.created"
        )

        # The newest stored item should be the cancellation note.
        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        stored_message = await server.store.load_item(
            thread.id, pending_id, DEFAULT_CONTEXT
        )
        assert stored_message.type == "assistant_message"
        assert stored_message.content[0].text == "Hello, World!"
312
313
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream stores the hidden-context note but not an empty pending message."""
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        # Cancel before any content has been streamed into the message.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for chunk in stream.json_events:
                received.append(decode_event(chunk))

        thread = next(
            event.thread for event in received if event.type == "thread.created"
        )
        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The empty assistant message must not have been saved.
        with pytest.raises(NotFoundError):
            await server.store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
374
375
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying pending item updates must not modify the
    original item object that the responder streamed.
    """

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        streamed_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=streamed_item)

        new_task = CustomTask(title="foo", content="bar")
        task_added = WorkflowTaskAdded(task=new_task, task_index=0)
        yield ThreadItemUpdatedEvent(item_id=streamed_item.id, update=task_added)

        # The responder appends the task itself only after the update event;
        # the asserts below check the task still appears exactly once.
        streamed_item.workflow.tasks.append(new_task)
        yield ThreadItemDoneEvent(item=streamed_item)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        workflow_done_event = next(
            event
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WorkflowItem)
        )
        done_item = cast(WorkflowItem, workflow_done_event.item)
        emitted_tasks = done_item.workflow.tasks
        assert len(emitted_tasks) == 1
        assert emitted_tasks[0].title == "foo"

        stored = await server.store.load_item(
            thread.id, done_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(stored, WorkflowItem)
        stored_tasks = cast(WorkflowItem, stored).workflow.tasks
        assert len(stored_tasks) == 1
        assert stored_tasks[0].title == "foo"
433
434
async def test_flows_context_to_responder():
    """The context given to the server reaches both the responder and the feedback hook."""
    # Contexts observed by each hook, keyed by hook name.
    seen_contexts: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_contexts["responder"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen_contexts["add_feedback"] = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, add_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request, context)
        assert seen_contexts.get("responder") == context

        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        assistant_item = next(
            event.item
            for event in events
            if event.type == "thread.item.done"
            and event.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen_contexts.get("add_feedback") == context
494
495
async def test_creates_thread():
    """Creating a thread emits an untitled thread and the user's message, quote included."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    quoted_text="Quoted text",
                )
            )
        )
        events = await server.process_streaming(create_request)

        created_event = next(
            event for event in events if event.type == "thread.created"
        )
        assert created_event.thread.title is None

        first_done_item = next(
            event.item for event in events if event.type == "thread.item.done"
        )

        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
529
530
async def test_saves_thread_title():
    """A title set by the responder is persisted and reported in a final thread.updated event."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.title == "Updated title"

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.title == "Updated title"
561
562
async def test_saves_thread_metadata():
    """Metadata set by the responder is persisted and followed by a thread.updated event."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
590
591
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and reported in thread.updated."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.status == LockedStatus(reason="Because")

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.status == LockedStatus(reason="Because")
621
622
async def test_emits_thread_updated_mid_stream_and_persists():
    """A mid-stream title change yields thread.updated after the next item and is saved."""

    def assistant_done(
        item_id: str, text: str, thread: ThreadMetadata
    ) -> ThreadItemDoneEvent:
        # Build a finished assistant message event for the given thread.
        return ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # First assistant message
        yield assistant_done("assistant_a", "A", thread)

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done("assistant_b", "B", thread)

        # Continue streaming after the update event
        yield assistant_done("assistant_c", "C", thread)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Find the created thread
        thread = next(
            event.thread for event in events if isinstance(event, ThreadCreatedEvent)
        )

        # Ensure a thread.updated event occurs right after assistant_b
        position_b = next(
            position
            for position, event in enumerate(events)
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, AssistantMessageItem)
            and event.item.id == "assistant_b"
        )
        following = events[position_b + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        updated_event = cast(ThreadUpdatedEvent, following)
        assert updated_event.thread.title == "Mid-stream title"

        # Streaming continues after the update event
        after_update = events[position_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        done_event = cast(ThreadItemDoneEvent, after_update)
        assert isinstance(done_event.item, AssistantMessageItem)
        assert done_event.item.id == "assistant_c"

        # Persisted in the store
        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.title == "Mid-stream title"
697
698
async def test_respond_with_tool_call():
    """A client tool call is streamed, and its output triggers a follow-up assistant reply."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if user_message is None:
            # No user message: this is the run that follows the tool output.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(user_message, UserMessageItem):
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        assert len(events) == 4
        assert events[0].type == "thread.created"
        thread = events[0].thread

        assert events[1].type == "thread.item.done"
        assert events[1].item.type == "user_message"

        assert events[2].type == "stream_options"

        assert events[3].type == "thread.item.done"
        tool_call = events[3].item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        tool_output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(tool_output_request)

        assert len(events) == 2
        assert events[0].type == "stream_options"
        assert events[1].type == "thread.item.done"
        assert events[1].item.type == "assistant_message"
768
769
async def test_respond_with_tool_status():
    """A progress update from the responder is streamed in order with the other events."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Exactly these five events, in this order.
        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
803
804
async def test_list_threads_response():
    """Listing 25 threads returns 20 newest-first, then the remaining 5 on the next page."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20

        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        # Continue from the last thread of the first page.
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
837
838
async def test_list_items_response():
    """Listing a thread's 26 items returns 20 newest-first, then the remaining 6."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for index in range(25):
            message = UserMessageItem(
                id=f"msg_{index}",
                content=[UserMessageTextContent(text=f"Message {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    page_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Thread 1")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 6  # +1 for user message
        assert second_page.has_more is False
888
889
async def test_calls_action():
    """A custom action reaches the action hook with its widget, and the hook's update is streamed."""
    # (action, sender) pairs received by the action hook.
    received_actions = []

    async def responder(
        thread: ThreadMetadata,
        user_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(
                children=[
                    Text(
                        key="text_1",
                        value="test",
                    )
                ],
            ),
        )
        yield ThreadItemDoneEvent(item=widget_item)

    async def on_action(
        thread: ThreadMetadata,
        received: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received_actions.append((received, sender))
        assert sender

        replacement_root = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement_root),
        )

    with make_server(responder, action_callback=on_action) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Show widget")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        widget_item = next(
            event.item
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert received_actions
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert received_actions[0] == expected_call

        assert len(events) == 2
        assert events[0].type == "stream_options"
        assert events[1].type == "thread.item.updated"
        assert isinstance(events[1], ThreadItemUpdatedEvent)
        assert events[1].update.type == "widget.root.updated"
        assert events[1].update.widget == Card(children=[Text(value="Email sent!")])
979
980
async def test_add_feedback():
    """Each feedback request reaches the feedback hook with its thread, items and kind."""
    # One (thread_id, item_ids, feedback, context) tuple per hook invocation.
    feedback_calls = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        feedback_calls.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        user_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_message = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_message)

        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget_item)

    with make_server(responder, handle_feedback=handle_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test feedback")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        positive_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["assistant_msg_1"],
                kind="positive",
            )
        )
        await server.process_non_streaming(positive_request)

        negative_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["widget_1"],
                kind="negative",
            )
        )
        await server.process_non_streaming(negative_request)

        assert len(feedback_calls) == 2

        first_thread_id, first_item_ids, first_kind, _ = feedback_calls[0]
        assert first_thread_id == thread.id
        assert first_item_ids == ["assistant_msg_1"]
        assert first_kind == "positive"

        second_thread_id, second_item_ids, second_kind, _ = feedback_calls[1]
        assert second_thread_id == thread.id
        assert second_item_ids == ["widget_1"]
        assert second_kind == "negative"
1055
1056
async def test_get_thread_by_id():
    """Fetching a thread by id returns the thread together with its user message."""
    # NOTE(review): the return value is discarded; kept because it is unclear
    # from here whether the call has side effects - confirm and remove if not.
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test thread")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # Retrieve the thread by id
        get_request = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        result = await server.process_non_streaming(get_request)
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)
        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title

        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1088
1089
async def test_create_file():
    """Creating a non-image attachment returns a file attachment and records it in the store."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=file_name, size=1024, mime_type=file_content_type
            )
        )
        result = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in store.files

        # The `metadata` field is excluded from serialization by default but still stored server-side
        assert attachment.metadata is None
        assert store.files[attachment.id].metadata == {"source": "test"}
1121
1122
async def test_create_image_file():
    """Creating a PNG attachment yields an image attachment with a preview URL."""
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime = "image/png"
    create_request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name=expected_name,
            size=1024,
            mime_type=expected_mime,
        )
    )

    with make_server(file_store=file_store) as server:
        response = await server.process_non_streaming(create_request)
        created = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert created.id is not None
        assert created.mime_type == expected_mime
        assert created.name == expected_name
        assert created.type == "image"
        assert created.preview_url == AnyUrl(
            f"https://example.com/{created.id}/preview"
        )

        descriptor = created.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(f"https://example.com/{created.id}/upload")
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        assert created.id in file_store.files
1155
1156
async def test_user_message_attachment_metadata_not_serialized():
    """Attachment metadata does not appear on user messages returned by items.list."""
    stored_attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(stored_attachment, DEFAULT_CONTEXT)

        # Create a thread whose first message references the saved attachment.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[stored_attachment.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        creation_events = await server.process_streaming(create_request)
        created = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)

        user_messages = [
            entry for entry in page.data if entry.type == "user_message"
        ]
        assert user_messages

        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        assert all(listed.metadata is None for listed in listed_attachments)
1192
1193
async def test_create_file_without_filestore():
    """Creating an attachment on a server with no file store raises RuntimeError."""
    create_request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name="test-file-name",
            size=1024,
            mime_type="text/plain",
        )
    )

    with pytest.raises(RuntimeError), make_server() as server:
        await server.process_non_streaming(create_request)
1206
1207
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    attachment_id = "test-file-id"
    file_store = InMemoryFileStore()
    file_store.files[attachment_id] = {"id": attachment_id}

    delete_request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id=attachment_id)
    )
    with make_server(file_store=file_store) as server:
        await server.process_non_streaming(delete_request)
        assert file_store.files == {}
1218
1219
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server with no file store raises RuntimeError."""
    delete_request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id="test-file-id")
    )

    with pytest.raises(RuntimeError), make_server() as server:
        await server.process_non_streaming(delete_request)
1228
1229
async def test_list_threads_paginated_response():
    """Listing 25 threads returns a page of 20 followed by a final page of 5."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        # First page: capped by the explicit limit, with more to come.
        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        cursor = first_page.data[-1].id
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1258
1259
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        creation_events = await server.process_streaming(create_request)
        created = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )

        update_request = ThreadsUpdateReq(
            params=ThreadUpdateParams(thread_id=created.id, title="New Title")
        )
        update_response = await server.process_non_streaming(update_request)
        renamed = TypeAdapter[Thread](Thread).validate_json(update_response.json)
        assert renamed.title == "New Title"
1283
1284
async def test_returns_widget_item_generator():
    """stream_widget emits an added event, one text delta per growth step, then done."""
    full_text = "Hello, world"
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card_showing(value: str) -> Card:
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        # Yield cards whose text is a progressively longer prefix of the full text.
        for prefix_length in (0, 3, 6, 12):
            yield card_showing(full_text[:prefix_length])

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 5

    added, *updates, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card_showing("")

    # Each intermediate event carries only the newly appended characters.
    for updated, expected_delta in zip(updates, ("Hel", "lo,", " world")):
        assert isinstance(updated, ThreadItemUpdatedEvent)
        assert updated.update.type == "widget.streaming_text.value_delta"
        assert updated.update.component_id == "text"
        assert updated.update.delta == expected_delta

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card_showing("Hello, world")
1326
1327
async def test_returns_widget_item_generator_full_replace():
    """stream_widget emits a root update when the text changes from "Hello" to "World"."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card_showing(value: str) -> Card:
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        for value in ("Hello", "World"):
            yield card_showing(value)

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 3

    added, replaced, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card_showing("Hello")

    assert isinstance(replaced, ThreadItemUpdatedEvent)
    assert replaced.update.type == "widget.root.updated"
    assert replaced.update.widget == card_showing("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card_showing("World")
1351
1352
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and makes later lookups fail.

    After the delete request succeeds, fetching the same thread by id must
    raise ``NotFoundError`` whose message names the missing thread.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # pytest.raises fails the test by itself when nothing is raised, so no
        # hand-written `assert False` guard is needed.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1384
1385
async def test_thread_item_removed_event_removes_item():
    """An item the responder adds and then removes is not listed afterwards."""
    removed_id = "msg_to_remove"

    async def add_then_remove(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        short_lived_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=short_lived_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger removal")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(add_then_remove) as server:
        stream_events = await server.process_streaming(create_request)
        created = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The removal must be visible in the event stream...
        assert any(
            ev.type == "thread.item.removed" and ev.item_id == removed_id
            for ev in stream_events
        )

        # ...and the item must be gone from the stored thread items.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        assert all(stored.id != removed_id for stored in stored_items)
1430
1431
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder becomes an error event in the stream.

    The stream must contain an ``error`` event with code
    ``ErrorCode.STREAM_ERROR`` and must be exhausted right after it.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # Unreachable; its presence makes this function an async generator.

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        stream = result.__aiter__()

        # The stream must yield an error event.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # After the error event the stream must be finished.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1473
1474
async def test_thread_item_replaced_event_replaces_item():
    """A replaced event swaps the stored item for the replacement with the same id."""
    original_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def add_then_replace(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: store a user message under the id.
        initial_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=initial_item)

        # Step 2: replace it with an assistant message reusing that id.
        substitute = AssistantMessageItem(
            id=original_id,
            content=[AssistantMessageContent(text=replacement_text)],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=substitute)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(add_then_replace) as server:
        stream_events = await server.process_streaming(create_request)
        created = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The replacement must show up in the event stream.
        assert any(
            ev.type == "thread.item.replaced" and ev.item.id == original_id
            for ev in stream_events
        )

        # The stored item under that id must now be the replacement.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        stored = next(
            (entry for entry in stored_items if entry.id == original_id), None
        )
        assert stored is not None
        assert isinstance(stored, AssistantMessageItem)
        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1537
1538
async def test_thread_item_replaced_event_with_widget():
    """A replaced event carrying a widget item overwrites the stored widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def add_then_replace_widget(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the original widget, then a replacement that reuses its id.
        initial_item = WidgetItem(
            id=widget_id,
            widget=original_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=initial_item)

        substitute = WidgetItem(
            id=widget_id,
            widget=replacement_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=substitute)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger widget replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(add_then_replace_widget) as server:
        stream_events = await server.process_streaming(create_request)
        created = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The event stream must contain the replacement for our widget id.
        matching_events = [
            ev
            for ev in stream_events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        ]
        replacement_event = matching_events[0] if matching_events else None
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The stored item must hold the replacement widget as well.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        stored = next((entry for entry in stored_items if entry.id == widget_id), None)
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        assert isinstance(stored.widget, Card)
        first_child = stored.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1612
1613
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first invocation answers normally; any later one is the retry.
        if len(responder_calls) == 1:
            reply_id, reply_text = "assistant_msg_1", "First response"
        else:
            reply_id, reply_text = "assistant_msg_2", "Retried response"

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=reply_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    items_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        creation_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        created = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )
        list_request = ItemsListReq(params=ItemsListParams(thread_id=created.id))

        # Before the retry: the first assistant reply plus the user message.
        listing_before = await server.process_non_streaming(list_request)
        items_before = items_adapter.validate_json(listing_before.json)
        assert len(items_before.data) == 2
        reply_before, user_before = items_before.data
        assert reply_before.type == "assistant_message"
        assert reply_before.content[0].type == "output_text"
        assert reply_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry after the user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=created.id, item_id=user_before.id
                )
            )
        )

        # The retry stream holds the stream options and the new reply only.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # The responder ran twice and received the same user message both times.
        assert len(responder_calls) == 2
        assert responder_calls[0] == responder_calls[1]
        assert responder_calls[0].type == "user_message"
        assert responder_calls[0].content[0].text == "Hello, world!"

        # After the retry: still two items, with the reply replaced.
        listing_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created.id))
        )
        items_after = items_adapter.validate_json(listing_after.json)
        assert len(items_after.data) == 2  # Still user message + assistant message
        reply_after, user_after = items_after.data
        assert reply_after.type == "assistant_message"
        assert reply_after.content[0].type == "output_text"
        assert reply_after.content[0].text == "Retried response"  # New response
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1711
1712
async def test_retry_after_item_invalid_item():
    """Retrying after an assistant message is rejected with a ValueError."""

    async def single_reply(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(single_reply) as server:
        creation_events = await server.process_streaming(create_request)
        created = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        assistant_item_id = page.data[0].id

        # Only user messages are valid retry anchors, so this must fail.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=created.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1759
1760
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the second user message regenerates only the last reply.

    The thread holds two user/assistant exchanges. After the retry the first
    exchange must be unchanged, the second user message must remain, and the
    newest assistant item must have a different id than before the retry.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The call count is part of the id, so a regenerated reply gets a new id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Four items are expected. The assertions below index them newest-first:
        # [0] assistant2, [1] user2, [2] assistant1, [3] user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second user message, the anchor for the retry.
        second_user_message_id = items_before.data[1].id

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # The retry must have been driven by the second user message.
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Only the last assistant response may have been removed and regenerated.
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # First user message and its response remain unchanged.
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # Second user message remains, but its assistant response is new.
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1865
1866
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with threads.create is visible to the responder."""
    tool_choice = ToolChoice(id="web_search")

    async def tool_checking_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_create",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello with tools")],
                attachments=[],
                inference_options=InferenceOptions(tool_choice=tool_choice),
            )
        )
    )

    with make_server(tool_checking_responder) as server:
        stream_events = await server.process_streaming(create_request)
        # Sanity check that the request flowed through.
        assert any(ev.type == "thread.item.done" for ev in stream_events)
1900
1901
async def test_input_transcribe_decodes_base64_and_passes_mime_type():
    """input.transcribe hands decoded bytes, MIME type and context to the callback."""
    raw_audio = b"hello audio"
    encoded_audio = base64.b64encode(raw_audio).decode("ascii")
    captured: dict[str, Any] = {}

    def transcribe_callback(
        audio: bytes, mime: str, context: Any
    ) -> TranscriptionResult:
        # Remember every argument so the test body can inspect them afterwards.
        captured.update(audio=audio, mime=mime, context=context)
        return TranscriptionResult(text="ok")

    transcribe_request = InputTranscribeReq(
        params=InputTranscribeParams(
            audio_base64=encoded_audio,
            mime_type="audio/wav",
        )
    )

    with make_server(transcribe_callback=transcribe_callback) as server:
        response = await server.process_non_streaming(transcribe_request)
        transcription = TypeAdapter(TranscriptionResult).validate_json(response.json)
        assert transcription.text == "ok"

    assert captured["audio"] == raw_audio
    assert captured["mime"] == "audio/wav"
    assert captured["context"] == DEFAULT_CONTEXT
1930
1931
async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message hands its tool choice to the responder again.

    A thread is created with an explicit tool choice, then a retry is requested
    after the stored user message. The responder must be invoked twice, and
    both inputs must carry the original tool choice.
    """
    tool_choice = ToolChoice(id="web_search")
    responder_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Only record the input here; the assertions live in the test body.
        responder_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_tools_retry_{len(responder_inputs)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Look up the stored user message; it is the anchor for the retry.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        user_message_id = next(
            item.id for item in items.data if item.type == "user_message"
        )

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message_id
                )
            )
        )
        # Sanity check that the retry flowed through.
        assert any(e.type == "thread.item.done" for e in retry_events)

    # One call for the initial message, one for the retry.
    assert len(responder_inputs) == 2
    for responder_input in responder_inputs:
        assert responder_input is not None
        assert responder_input.inference_options.tool_choice is not None
        assert responder_input.inference_options.tool_choice.id == tool_choice.id
1934
1935
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with threads.add_user_message reaches the responder."""
    tool_choice = ToolChoice(id="retrieval")

    async def tool_checking_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_add",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def message_with_tools(text: str) -> UserMessageInput:
        return UserMessageInput(
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(tool_choice=tool_choice),
        )

    with make_server(tool_checking_responder) as server:
        # A thread has to exist before a message can be added to it.
        creation_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(input=message_with_tools("Start"))
            )
        )
        created = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )

        # Add a second message that also carries the tool choice.
        add_events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=created.id,
                    input=message_with_tools("Second"),
                )
            )
        )
        # Sanity check that the request flowed through.
        assert any(ev.type == "thread.item.done" for ev in add_events)
1985
1986
async def test_custom_id_generators_on_server():
    """Ids produced by a store's custom generators are used for threads and items."""

    class UniqueIdStore(SQLiteStore):
        """SQLiteStore whose ids come from one shared, incrementing counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def silent_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits nothing: the bare `yield` after `return` only serves to make
        # this function an async generator.
        return
        yield

    create_payload = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    ).model_dump_json()

    with make_server(silent_responder) as server:
        # Replace the store with the custom-id variant, pointed at the same DB path.
        existing_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(existing_store.db_path)

        # Creating a thread exercises both the thread-id and item-id generators.
        result = await server.process(create_payload, DEFAULT_CONTEXT)
        assert isinstance(result, StreamingResult)
        decoded_events = await decode_streaming_result(result)

        created_event = next(
            ev for ev in decoded_events if ev.type == "thread.created"
        )
        assert created_event.thread.id == "thr_custom_1"

        user_message = next(
            ev.item
            for ev in decoded_events
            if ev.type == "thread.item.done" and ev.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
2043