openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
4e1156dd81f931ed2f455197737892dd59437ab2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2045 lines · mode: code

1import asyncio
2import base64
3import sqlite3
4from contextlib import contextmanager
5from datetime import datetime
6from typing import Any, AsyncIterator, Callable, cast
7
8import pytest
9from helpers.mock_store import SQLiteStore
10from pydantic import AnyUrl, TypeAdapter
11
12from chatkit.actions import Action
13from chatkit.errors import ErrorCode
14from chatkit.server import (
15 ChatKitServer,
16 NonStreamingResult,
17 StreamingResult,
18 stream_widget,
19)
20from chatkit.store import (
21 AttachmentStore,
22 NotFoundError,
23 StoreItemType,
24 default_generate_id,
25)
26from chatkit.types import (
27 AssistantMessageContent,
28 AssistantMessageContentPartTextDelta,
29 AssistantMessageItem,
30 Attachment,
31 AttachmentCreateParams,
32 AttachmentDeleteParams,
33 AttachmentsCreateReq,
34 AttachmentsDeleteReq,
35 AttachmentUploadDescriptor,
36 AudioInput,
37 ClientToolCallItem,
38 CustomTask,
39 FeedbackKind,
40 FileAttachment,
41 ImageAttachment,
42 InferenceOptions,
43 InputTranscribeParams,
44 InputTranscribeReq,
45 ItemFeedbackParams,
46 ItemsFeedbackReq,
47 ItemsListParams,
48 ItemsListReq,
49 LockedStatus,
50 Page,
51 ProgressUpdateEvent,
52 Thread,
53 ThreadAddClientToolOutputParams,
54 ThreadAddUserMessageParams,
55 ThreadCreatedEvent,
56 ThreadCreateParams,
57 ThreadCustomActionParams,
58 ThreadDeleteParams,
59 ThreadGetByIdParams,
60 ThreadItem,
61 ThreadItemAddedEvent,
62 ThreadItemDoneEvent,
63 ThreadItemRemovedEvent,
64 ThreadItemReplacedEvent,
65 ThreadItemUpdatedEvent,
66 ThreadListParams,
67 ThreadMetadata,
68 ThreadRetryAfterItemParams,
69 ThreadsAddClientToolOutputReq,
70 ThreadsAddUserMessageReq,
71 ThreadsCreateReq,
72 ThreadsCustomActionReq,
73 ThreadsDeleteReq,
74 ThreadsGetByIdReq,
75 ThreadsListReq,
76 ThreadsRetryAfterItemReq,
77 ThreadStreamEvent,
78 ThreadsUpdateReq,
79 ThreadUpdatedEvent,
80 ThreadUpdateParams,
81 ToolChoice,
82 TranscriptionResult,
83 UserMessageInput,
84 UserMessageItem,
85 UserMessageTextContent,
86 WidgetItem,
87 WidgetRootUpdated,
88 Workflow,
89 WorkflowItem,
90 WorkflowTaskAdded,
91)
92from chatkit.widgets import Card, Text
93from tests._types import RequestContext
94from tests.test_store import make_thread_items
95
# Counter that make_server() reads and increments so that every server it
# builds gets a distinct in-memory SQLite database name.
server_id = 0


# Context passed to the server by the process_* helpers when a test does not
# supply its own.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
100
101
class InMemoryFileStore(AttachmentStore):
    """Attachment store for tests that keeps every attachment in a dict.

    ``self.files`` maps attachment id -> attachment object, so tests can
    inspect exactly what was created or deleted.
    """

    def __init__(self):
        # attachment id -> Attachment created through this store
        self.files = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Forget the attachment with ``attachment_id``.

        Raises:
            KeyError: if no attachment with that id is stored.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return an attachment described by ``input``.

        MIME types starting with ``image/`` produce an ``ImageAttachment``
        (with a preview URL); everything else produces a ``FileAttachment``.
        Both carry the same upload descriptor and test metadata.
        """
        # Ids start from the current count (as before), but an id that is
        # still in use is skipped. Without this, deleting an early attachment
        # and then creating a new one would reuse the id of a live attachment
        # and silently overwrite it.
        index = len(self.files) + 1
        while f"atc_{index}" in self.files:
            index += 1
        attachment_id = f"atc_{index}"

        metadata = {"source": "test"}
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        self.files[attachment.id] = attachment
        return attachment
142
143
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one raw ``data: <json>`` stream chunk into a ThreadStreamEvent.

    Raises:
        IndexError: if ``event`` contains no ``data: `` marker.
    """
    # Split only on the first marker: the JSON payload may itself contain the
    # bytes "data: " (e.g. inside message text), and an unbounded split would
    # truncate the payload at that point.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
146
147
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every chunk in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
152
153
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
    transcribe_callback: Callable[[AudioInput, Any], TranscriptionResult] | None = None,
):
    """Yield a ChatKitServer subclass backed by a fresh in-memory SQLite store.

    Each optional callback is plugged into the matching server hook:

    * ``responder`` -> ``respond``; when omitted, an empty stream is returned.
    * ``handle_feedback`` -> ``add_feedback``; when omitted, feedback is ignored.
    * ``action_callback`` -> ``action``; when omitted, ``action`` raises ValueError.
    * ``transcribe_callback`` -> ``transcribe``; when omitted, the base class
      implementation is used.
    * ``file_store`` is passed straight to the ChatKitServer constructor.

    The yielded server also has ``process_streaming`` / ``process_non_streaming``
    helpers that serialize a request object, run it through ``process`` and
    assert on the kind of result returned.
    """
    global server_id
    # Every call gets its own database name, taken from the module-level counter.
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last connection is closed
    db = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        """ChatKitServer whose hooks delegate to the callbacks given to make_server."""

        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            """Forward the action to ``action_callback``; ValueError if none was given."""
            if action_callback is None:
                raise ValueError("action_callback not wired up")
            return action_callback(thread, action, sender, context)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            """Forward to ``responder``, or return an empty stream when it is None."""
            if responder is None:
                return self._empty_responder()
            return responder(thread, input_user_message, context)

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            """Async generator that finishes without yielding anything."""
            # The unreachable ``yield`` is what makes this an async generator.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            """Forward feedback to ``handle_feedback`` when one was given."""
            if handle_feedback is None:
                return
            handle_feedback(thread_id, item_ids, feedback, context)

        async def transcribe(
            self, audio_input: AudioInput, context: Any
        ) -> TranscriptionResult:
            """Use ``transcribe_callback`` if given, else the base implementation."""
            if transcribe_callback is None:
                return await super().transcribe(audio_input, context)
            return transcribe_callback(audio_input, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process a request that must stream and return its decoded events.

            A falsy ``context`` is replaced by DEFAULT_CONTEXT.
            """
            result = await self.process(
                request_obj.model_dump_json(), context or DEFAULT_CONTEXT
            )
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process a request that must not stream and return the raw result.

            A falsy ``context`` is replaced by DEFAULT_CONTEXT.
            """
            result = await self.process(
                request_obj.model_dump_json(), context or DEFAULT_CONTEXT
            )
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        # Release the keep-alive connection opened above.
        db.close()
242
243
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream keeps the partial assistant text and records hidden context.

    The responder adds an assistant message, streams one text delta into it and
    then raises CancelledError. Afterwards the store must contain an
    ``sdk_hidden_context`` item as the newest item and the assistant message
    with the accumulated text.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Open a pending assistant message with the first half of the text ...
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="assistant-message-pending",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, ")],
                thread_id=thread.id,
            )
        )
        # ... stream the second half as a delta on content part 0 ...
        yield ThreadItemUpdatedEvent(
            item_id="assistant-message-pending",
            update=AssistantMessageContentPartTextDelta(
                content_index=0,
                delta="World!",
            ),
        )
        # ... and cancel before any done event for the message is emitted.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store
        original_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            # Only hidden-context ids are special-cased; everything else is
            # delegated to the store's original generator.
            if item_type == "sdk_hidden_context":
                return default_generate_id("sdk_hidden_context")
            return original_generate_item_id(item_type, thread, context)

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        # Consume the stream by hand so the events received before the
        # cancellation are still available afterwards.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")
        # Limit 1 in descending order: only the newest stored item is loaded.
        items = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = items.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The pending message was stored with the delta applied.
        assistant_message_item = await server.store.load_item(
            thread.id, "assistant-message-pending", DEFAULT_CONTEXT
        )
        assert assistant_message_item.type == "assistant_message"
        assert assistant_message_item.content[0].text == "Hello, World!"
313
314
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream does not store an assistant message that has no content.

    The responder adds an assistant message with an empty content list and
    then raises CancelledError. The hidden-context item must still be stored,
    but loading the assistant message must raise NotFoundError.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Pending assistant message with no content parts at all.
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="assistant-message-pending",
                created_at=datetime.now(),
                content=[],
                thread_id=thread.id,
            )
        )
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        original_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            # Only hidden-context ids are special-cased; everything else is
            # delegated to the store's original generator.
            if item_type == "sdk_hidden_context":
                return default_generate_id("sdk_hidden_context")
            return original_generate_item_id(item_type, thread, context)

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        # Consume the stream by hand so the events received before the
        # cancellation are still available afterwards.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")
        # Limit 1 in descending order: only the newest stored item is loaded.
        items = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = items.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The empty pending message must not have been stored.
        with pytest.raises(NotFoundError):
            await server.store.load_item(
                thread.id, "assistant-message-pending", DEFAULT_CONTEXT
            )
375
376
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test to make sure pending item updates do not modify the
    origin item that was streamed.

    The responder streams a WorkflowTaskAdded update for a workflow item and
    then appends the same task to that item itself before emitting the done
    event. Both the emitted done item and the stored item must contain the
    task exactly once.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        workflow_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )

        yield ThreadItemAddedEvent(item=workflow_item)

        task = CustomTask(title="foo", content="bar")
        yield ThreadItemUpdatedEvent(
            item_id=workflow_item.id,
            update=WorkflowTaskAdded(task=task, task_index=0),
        )

        # The responder adds the task to its own object as well; if the server
        # had already applied the update to this same object, the done item
        # would carry the task twice.
        workflow_item.workflow.tasks.append(task)
        yield ThreadItemDoneEvent(item=workflow_item)

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )

        thread = next(e.thread for e in events if e.type == "thread.created")
        # First done event whose item is a workflow (the user message also
        # produces a done event).
        workflow_done_event = next(
            e
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WorkflowItem)
        )
        workflow_done_item = cast(WorkflowItem, workflow_done_event.item)
        assert len(workflow_done_item.workflow.tasks) == 1
        assert workflow_done_item.workflow.tasks[0].title == "foo"

        # The stored copy must agree with what was streamed.
        stored = await server.store.load_item(
            thread.id, workflow_done_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(stored, WorkflowItem)
        stored_workflow = cast(WorkflowItem, stored)
        assert len(stored_workflow.workflow.tasks) == 1
        assert stored_workflow.workflow.tasks[0].title == "foo"
434
435
async def test_flows_context_to_responder():
    """The request context reaches both the responder and the feedback hook."""
    # Contexts observed by each hook; None until the hook has been invoked.
    seen: dict[str, Any] = {"responder": None, "feedback": None}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen["responder"] = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen["feedback"] = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, add_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request, context)
        assert seen["responder"] == context

        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        assistant_item = next(
            ev.item
            for ev in events
            if ev.type == "thread.item.done" and ev.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen["feedback"] == context
495
496
async def test_creates_thread():
    """Creating a thread emits thread.created and a done event for the user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Empty async generator: the responder contributes no events.
        return
        yield

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
        quoted_text="Quoted text",
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(ev for ev in events if ev.type == "thread.created")
        assert created.thread.title is None

        first_done_item = next(
            ev.item for ev in events if ev.type == "thread.item.done"
        )

        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
530
531
async def test_saves_thread_title():
    """A title set by the responder is stored and announced via thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        # Empty async generator: only the mutation above matters.
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.title == "Updated title"

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.title == "Updated title"
562
563
async def test_saves_thread_metadata():
    """Metadata written by the responder is stored and triggers thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        # Empty async generator: only the mutation above matters.
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
591
592
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is stored and sent in thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        # Empty async generator: only the mutation above matters.
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.status == LockedStatus(reason="Because")

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.status == LockedStatus(reason="Because")
622
623
async def test_emits_thread_updated_mid_stream_and_persists():
    """A title change between two yields surfaces as thread.updated mid-stream.

    The update event must directly follow the first item emitted after the
    change, streaming must continue afterwards, and the new title must be
    stored.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            # Built at yield time so created_at reflects emission order.
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        # First assistant message
        yield assistant_done("assistant_a", "A")

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done("assistant_b", "B")

        # Continue streaming after the update event
        yield assistant_done("assistant_c", "C")

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Find the created thread
        thread = next(ev.thread for ev in events if isinstance(ev, ThreadCreatedEvent))

        # Ensure a thread.updated event occurs right after assistant_b
        position_of_b = next(
            index
            for index, ev in enumerate(events)
            if isinstance(ev, ThreadItemDoneEvent)
            and isinstance(ev.item, AssistantMessageItem)
            and ev.item.id == "assistant_b"
        )
        following = events[position_of_b + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        assert cast(ThreadUpdatedEvent, following).thread.title == "Mid-stream title"

        # Streaming continues after the update event
        after_update = events[position_of_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        continued = cast(ThreadItemDoneEvent, after_update)
        assert isinstance(continued.item, AssistantMessageItem)
        assert continued.item.id == "assistant_c"

        # Persisted in the store
        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.title == "Mid-stream title"
698
699
async def test_respond_with_tool_call():
    """A client tool call ends the first stream; its output starts a second one.

    The responder answers a user message with a client tool call, and answers
    the follow-up invocation (where ``input`` is None) with an assistant
    message.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # Invoked again after the client tool output was added.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
            return
        if isinstance(input, UserMessageItem):
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        first_stream = await server.process_streaming(create_request)

        assert len(first_stream) == 4
        created, user_done, options, tool_done = first_stream

        assert created.type == "thread.created"
        thread = created.thread

        assert user_done.type == "thread.item.done"
        assert user_done.item.type == "user_message"

        assert options.type == "stream_options"

        assert tool_done.type == "thread.item.done"
        tool_call = tool_done.item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        second_stream = await server.process_streaming(output_request)

        assert len(second_stream) == 2
        options, assistant_done = second_stream
        assert options.type == "stream_options"
        assert assistant_done.type == "thread.item.done"
        assert assistant_done.item.type == "assistant_message"
769
770
async def test_respond_with_tool_status():
    """Progress updates from the responder are streamed in order with its items."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Exact sequence: length and per-position type in a single comparison.
        assert [ev.type for ev in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
804
805
async def test_list_threads_response():
    """Listing 25 threads yields a page of 20 (newest first), then a page of 5."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20

        # Newest first
        newest, runner_up = first_page.data[0], first_page.data[1]
        assert newest.created_at > runner_up.created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        # Continue from the last thread of the first page.
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
838
839
async def test_list_items_response():
    """Listing 26 items yields a page of 20 (newest first), then a page of 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for index in range(25):
            message = UserMessageItem(
                id=f"msg_{index}",
                content=[UserMessageTextContent(text=f"Message {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    page_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Thread 1")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        first_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20
        # Newest first
        newest, runner_up = first_page.data[0], first_page.data[1]
        assert newest.created_at > runner_up.created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        # Continue from the last item of the first page.
        second_response = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 6  # +1 for user message
        assert second_page.has_more is False
889
890
async def test_calls_action():
    """A custom action request reaches the action hook with the sending widget.

    The hook's widget-root update must be streamed back after stream_options.
    """
    # (action, sender) pairs in the order the hook received them.
    received = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(
                children=[
                    Text(
                        key="text_1",
                        value="test",
                    )
                ],
            ),
        )
        yield ThreadItemDoneEvent(item=widget)

    async def handle_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received.append((action, sender))
        assert sender

        replacement = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    with make_server(responder, action_callback=handle_action) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Show widget")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        create_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in create_events if ev.type == "thread.created"
        )
        widget_item = next(
            ev.item
            for ev in create_events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        action_events = await server.process_streaming(action_request)

        assert received
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert received[0] == expected_call

        assert len(action_events) == 2
        options, widget_update = action_events
        assert options.type == "stream_options"
        assert widget_update.type == "thread.item.updated"
        assert isinstance(widget_update, ThreadItemUpdatedEvent)
        assert widget_update.update.type == "widget.root.updated"
        assert widget_update.update.widget == Card(
            children=[Text(value="Email sent!")]
        )
980
981
async def test_add_feedback():
    """Each feedback request is forwarded with its thread id, item ids and kind."""
    # (thread_id, item_ids, feedback, context) for every forwarded call.
    recorded = []

    def record_feedback(thread_id, item_ids, feedback, context):
        recorded.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_message = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_message)

        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget)

    with make_server(responder, handle_feedback=record_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test feedback")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        positive_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["assistant_msg_1"],
                kind="positive",
            )
        )
        await server.process_non_streaming(positive_request)

        negative_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["widget_1"],
                kind="negative",
            )
        )
        await server.process_non_streaming(negative_request)

        assert len(recorded) == 2
        # Compare everything except the trailing context element.
        assert [call[:3] for call in recorded] == [
            (thread.id, ["assistant_msg_1"], "positive"),
            (thread.id, ["widget_1"], "negative"),
        ]
1056
1057
async def test_get_thread_by_id():
    """Fetching a thread by id returns it together with its user message."""
    # NOTE(review): the return value is discarded; this looks like a leftover,
    # but it is kept in case make_thread_items has side effects — confirm.
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test thread")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Retrieve the thread by id
        lookup_request = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        response = await server.process_non_streaming(lookup_request)
        fetched = TypeAdapter(Thread).validate_json(response.json)

        assert fetched.id == thread.id
        assert fetched.title == thread.title
        first_item = fetched.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1089
1090
async def test_create_file():
    """Creating a non-image attachment returns a file attachment with an upload descriptor."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=file_name, size=1024, mime_type=file_content_type
            )
        )
        response = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in store.files

        # The `metadata` field is excluded from serialization by default but still stored server-side
        assert attachment.metadata is None
        stored_attachment = store.files[attachment.id]
        assert stored_attachment.metadata == {"source": "test"}
1122
1123
async def test_create_image_file():
    """Create a PNG attachment and check it comes back as an image."""
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime = "image/png"

    with make_server(file_store=file_store) as server:
        create_params = AttachmentCreateParams(
            name=expected_name,
            size=1024,
            mime_type=expected_mime,
        )
        response = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert attachment.id is not None
        assert attachment.mime_type == expected_mime
        assert attachment.name == expected_name
        assert attachment.type == "image"
        assert attachment.preview_url == AnyUrl(
            f"https://example.com/{attachment.id}/preview"
        )

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(
            f"https://example.com/{attachment.id}/upload"
        )
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        assert attachment.id in file_store.files
1156
1157
async def test_user_message_attachment_metadata_not_serialized():
    """Listed user-message attachments must come back with `metadata` unset."""
    stored_attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(stored_attachment, DEFAULT_CONTEXT)

        # Create a thread whose first message references the saved attachment.
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Message with file")],
            attachments=[stored_attachment.id],
            inference_options=InferenceOptions(),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)

        user_messages = [entry for entry in page.data if entry.type == "user_message"]
        assert user_messages
        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        assert all(att.metadata is None for att in listed_attachments)
1193
1194
async def test_create_file_without_filestore():
    """Creating an attachment when `make_server` gets no file store raises RuntimeError."""
    with pytest.raises(RuntimeError):
        with make_server() as server:
            create_params = AttachmentCreateParams(
                name="test-file-name",
                size=1024,
                mime_type="text/plain",
            )
            await server.process_non_streaming(
                AttachmentsCreateReq(params=create_params)
            )
1207
1208
async def test_delete_file():
    """Deleting an attachment leaves the file store's `files` mapping empty."""
    attachment_id = "test-file-id"
    file_store = InMemoryFileStore()
    # Seed the store with a single entry for the server to delete.
    file_store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=file_store) as server:
        delete_params = AttachmentDeleteParams(attachment_id=attachment_id)
        await server.process_non_streaming(
            AttachmentsDeleteReq(params=delete_params)
        )
        assert file_store.files == {}
1219
1220
async def test_delete_file_without_filestore():
    """Deleting an attachment when `make_server` gets no file store raises RuntimeError."""
    with pytest.raises(RuntimeError):
        with make_server() as server:
            delete_params = AttachmentDeleteParams(attachment_id="test-file-id")
            await server.process_non_streaming(
                AttachmentsDeleteReq(params=delete_params)
            )
1229
1230
async def test_list_threads_paginated_response():
    """List 25 threads in two pages: 20 with more pending, then the last 5."""
    with make_server() as server:
        for index in range(25):
            message = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=message))
            )

        # First page: explicit limit of 20.
        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            first_response.json
        )
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        cursor = first_page.data[-1].id
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            second_response.json
        )
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1259
1260
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        update_params = ThreadUpdateParams(thread_id=thread.id, title="New Title")
        response = await server.process_non_streaming(
            ThreadsUpdateReq(params=update_params)
        )
        renamed = TypeAdapter[Thread](Thread).validate_json(response.json)
        assert renamed.title == "New Title"
1284
1285
async def test_returns_widget_item_generator():
    """Growing text in a streamed widget produces added, delta, and done events."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(i: int) -> Card:
        # Card holding the first `i` characters of the final text.
        return Card(children=[Text(id="text", value="Hello, world"[:i])])

    async def widget_generator():
        for prefix_length in (0, 3, 6, 12):
            yield render_widget(prefix_length)

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 5

    # The first event adds the widget with empty text.
    added = events[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    # Each later render is reported as a text delta on the same component.
    for update_event, expected_delta in zip(events[1:4], ("Hel", "lo,", " world")):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    # The last event carries the completed widget.
    done = events[4]
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1327
1328
async def test_returns_widget_item_generator_full_replace():
    """Switching the text from "Hello" to "World" produces a root-updated event."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card_with(value: str) -> Card:
        # Build a fresh card each time so comparisons never reuse an instance.
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        yield card_with("Hello")
        yield card_with("World")

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 3

    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card_with("Hello")

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == card_with("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card_with("World")
1352
1353
async def test_delete_thread():
    """Deleting a thread returns `{}` and makes later lookups raise NotFoundError."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Looking up the deleted thread must fail with NotFoundError.
        # pytest.raises reports a missing exception itself, so no manual
        # `assert False` is needed.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        # Substring check rather than `match=` so the id is not read as a regex.
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1385
1386
async def test_thread_item_removed_event_removes_item():
    """An item that is completed and then removed must not show up in the listing."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a finished item, then immediately ask for its removal.
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger removal")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The removal must be visible in the event stream.
        assert any(
            ev.type == "thread.item.removed" and ev.item_id == removed_id
            for ev in events
        )

        # The listing must no longer contain the removed item.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        assert all(entry.id != removed_id for entry in listed_items)
1431
1432
async def test_raising_in_responder_yields_error_event():
    """A responder that raises yields a STREAM_ERROR event, then the stream ends."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # Unreachable; its presence makes this function an async generator.

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named `stream` so the builtin `iter` is not shadowed.
        stream = result.__aiter__()

        # The stream must yield an error event.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # After the error event the stream must be exhausted.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1474
1475
async def test_thread_item_replaced_event_replaces_item():
    """Replacing a user message with an assistant message updates the listed item."""
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: emit the item that will later be replaced.
        initial_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=initial_item)

        # Step 2: replace it with an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger replacement")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The replacement must be visible in the event stream.
        assert any(
            ev.type == "thread.item.replaced" and ev.item.id == original_id
            for ev in events
        )

        # The listing must now return the replacement under the original id.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        replaced_item = next(
            (entry for entry in listed_items if entry.id == original_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, AssistantMessageItem)
        first_part = replaced_item.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1538
1539
async def test_thread_item_replaced_event_with_widget():
    """Replacing a widget item updates both the event stream and the listed item."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: emit the widget that will later be replaced.
        initial_item = WidgetItem(
            id=widget_id,
            widget=original_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=initial_item)

        # Step 2: replace it with a widget that reuses the id.
        substitute_item = WidgetItem(
            id=widget_id,
            widget=replacement_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=substitute_item)

    with make_server(responder) as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger widget replacement")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The event stream must contain a replacement for this widget id.
        matching_events = [
            ev
            for ev in events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        ]
        assert matching_events
        replacement_event = matching_events[0]
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The listing must return the replaced widget.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        replaced_item = next(
            (entry for entry in listed_items if entry.id == widget_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, WidgetItem)
        assert isinstance(replaced_item.widget, Card)
        first_child = replaced_item.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1613
1614
async def test_retry_after_item():
    """Retrying after the user message re-runs the responder and swaps the reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first call answers with one message; any later call (the retry)
        # answers with a different id and text.
        is_first_call = len(responder_calls) == 1
        reply_id = "assistant_msg_1" if is_first_call else "assistant_msg_2"
        reply_text = "First response" if is_first_call else "Retried response"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=reply_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        async def list_items():
            # Fetch and parse the current items of the thread.
            listing = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(listing.json)

        items_before = await list_items()
        assert len(items_before.data) == 2
        reply_before, user_before = items_before.data
        assert reply_before.type == "assistant_message"
        assert reply_before.content[0].type == "output_text"
        assert reply_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry using the user message as the anchor.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_before.id
                )
            )
        )

        # The retry stream holds the stream options and the new reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder calls received the same user message.
        assert len(responder_calls) == 2
        assert responder_calls[0] == responder_calls[1]
        assert responder_calls[0].type == "user_message"
        assert responder_calls[0].content[0].text == "Hello, world!"

        # The listing still has two items, now with the retried reply.
        items_after = await list_items()
        assert len(items_after.data) == 2
        reply_after, user_after = items_after.data
        assert reply_after.type == "assistant_message"
        assert reply_after.content[0].type == "output_text"
        assert reply_after.content[0].text == "Retried response"
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1712
1713
async def test_retry_after_item_invalid_item():
    """Retrying after an assistant message raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        assistant_item_id = page.data[0].id

        # Using an assistant message as the retry anchor is rejected.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1760
1761
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the second user message regenerates only the last reply.

    The first exchange must stay untouched. The second user message must be
    handed to the responder again, and its reply must get a new id.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The reply id carries the call count, so a retried reply gets a new id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Four items are expected. The assertions below index the page
        # newest-first: [assistant2, user2, assistant1, user1].
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second user message; check that before retrying after it.
        assert items_before.data[1].type == "user_message"
        second_user_message_id = items_before.data[1].id

        # Retry after the second user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # The retry must hand the second user message to the responder.
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Only the last assistant response is removed and regenerated.
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # The first user message and its response remain unchanged.
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # The second user message remains, but its assistant response is new.
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1866
1867
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with thread creation reaches the responder's input."""
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Check the forwarded tool choice before emitting any event.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_create",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Hello with tools")],
            attachments=[],
            inference_options=InferenceOptions(tool_choice=tool_choice),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        # Sanity check: at least one completed item came back.
        assert any(ev.type == "thread.item.done" for ev in events)
1901
1902
async def test_input_transcribe_decodes_base64_and_normalizes_mime_type():
    """The transcribe callback gets decoded bytes plus the raw and bare mime types."""
    audio_bytes = b"hello audio"
    audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
    seen: dict[str, Any] = {}

    def transcribe_callback(
        audio_input: AudioInput, context: Any
    ) -> TranscriptionResult:
        # Record everything the server passed in for the assertions below.
        seen.update(
            audio=audio_input.data,
            mime=audio_input.mime_type,
            media_type=audio_input.media_type,
            context=context,
        )
        return TranscriptionResult(text="ok")

    with make_server(transcribe_callback=transcribe_callback) as server:
        transcribe_params = InputTranscribeParams(
            audio_base64=audio_b64,
            mime_type="audio/webm;codecs=opus",
        )
        response = await server.process_non_streaming(
            InputTranscribeReq(params=transcribe_params)
        )
        parsed = TypeAdapter(TranscriptionResult).validate_json(response.json)
        assert parsed.text == "ok"

        assert seen["audio"] == audio_bytes
        assert seen["mime"] == "audio/webm;codecs=opus"
        assert seen["media_type"] == "audio/webm"
        assert seen["context"] == DEFAULT_CONTEXT
1933
1934
async def test_retry_after_item_passes_tools_to_responder():
    """Placeholder test: it has no body and makes no assertions.

    TODO(review): by analogy with the neighbouring `*_passes_tools_to_responder`
    tests, this presumably should check that a retry forwards the tool choice
    to the responder. Confirm the intent before implementing.
    """
1937
1938
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with a follow-up message reaches the responder's input."""
    tool_choice = ToolChoice(id="retrieval")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Runs for both the create and the add-message request.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_add",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        # A thread must exist before a message can be added to it.
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Start")],
            attachments=[],
            inference_options=InferenceOptions(tool_choice=tool_choice),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # Send a second message that carries the same tool choice.
        follow_up = UserMessageInput(
            content=[UserMessageTextContent(text="Second")],
            attachments=[],
            inference_options=InferenceOptions(tool_choice=tool_choice),
        )
        add_events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(thread_id=thread.id, input=follow_up)
            )
        )
        # Sanity check: at least one completed item came back.
        assert any(ev.type == "thread.item.done" for ev in add_events)
1988
1989
async def test_custom_id_generators_on_server():
    """Thread and user-message ids come from the store's overridden generators."""

    class UniqueIdStore(SQLiteStore):
        # One counter feeds both thread ids and item ids.
        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits nothing. The unreachable `yield` makes this an async generator.
        return
        yield

    with make_server(responder) as server:
        # Replace the store with the custom one, pointed at the same database path.
        existing_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(existing_store.db_path)

        # Create a thread through the raw JSON entry point.
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        payload = ThreadsCreateReq(
            params=ThreadCreateParams(input=message)
        ).model_dump_json()
        result = await server.process(payload, DEFAULT_CONTEXT)
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is the first value the counter produced.
        thread_created = next(ev for ev in events if ev.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        # The user message id is the second value and embeds the thread id.
        user_message = next(
            ev.item
            for ev in events
            if ev.type == "thread.item.done" and ev.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
2046