openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.6.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2120lines · modecode

1import asyncio
2import base64
3import sqlite3
4from contextlib import contextmanager
5from datetime import datetime
6from typing import Any, AsyncIterator, Callable, cast
7
8import pytest
9from helpers.mock_store import SQLiteStore
10from pydantic import AnyUrl, TypeAdapter
11
12from chatkit.actions import Action
13from chatkit.errors import ErrorCode
14from chatkit.server import (
15 ChatKitServer,
16 NonStreamingResult,
17 StreamingResult,
18 stream_widget,
19)
20from chatkit.store import (
21 AttachmentStore,
22 NotFoundError,
23 StoreItemType,
24 default_generate_id,
25)
26from chatkit.types import (
27 AssistantMessageContent,
28 AssistantMessageContentPartTextDelta,
29 AssistantMessageItem,
30 Attachment,
31 AttachmentCreateParams,
32 AttachmentDeleteParams,
33 AttachmentsCreateReq,
34 AttachmentsDeleteReq,
35 AttachmentUploadDescriptor,
36 AudioInput,
37 ClientToolCallItem,
38 CustomTask,
39 FeedbackKind,
40 FileAttachment,
41 ImageAttachment,
42 InferenceOptions,
43 InputTranscribeParams,
44 InputTranscribeReq,
45 ItemFeedbackParams,
46 ItemsFeedbackReq,
47 ItemsListParams,
48 ItemsListReq,
49 LockedStatus,
50 Page,
51 ProgressUpdateEvent,
52 Thread,
53 ThreadAddClientToolOutputParams,
54 ThreadAddUserMessageParams,
55 ThreadCreatedEvent,
56 ThreadCreateParams,
57 ThreadCustomActionParams,
58 ThreadDeleteParams,
59 ThreadGetByIdParams,
60 ThreadItem,
61 ThreadItemAddedEvent,
62 ThreadItemDoneEvent,
63 ThreadItemRemovedEvent,
64 ThreadItemReplacedEvent,
65 ThreadItemUpdatedEvent,
66 ThreadListParams,
67 ThreadMetadata,
68 ThreadRetryAfterItemParams,
69 ThreadsAddClientToolOutputReq,
70 ThreadsAddUserMessageReq,
71 ThreadsCreateReq,
72 ThreadsCustomActionReq,
73 ThreadsDeleteReq,
74 ThreadsGetByIdReq,
75 ThreadsListReq,
76 ThreadsRetryAfterItemReq,
77 ThreadStreamEvent,
78 ThreadsUpdateReq,
79 ThreadUpdatedEvent,
80 ThreadUpdateParams,
81 ToolChoice,
82 TranscriptionResult,
83 UserMessageInput,
84 UserMessageItem,
85 UserMessageTextContent,
86 WidgetItem,
87 WidgetRootUpdated,
88 Workflow,
89 WorkflowItem,
90 WorkflowTaskAdded,
91)
92from chatkit.widgets import Card, Text
93from tests._types import RequestContext
94from tests.test_store import make_thread_items
95
# Counter read and incremented by make_server so that every server it builds
# gets a differently named shared in-memory SQLite database.
server_id = 0


# Context the process_streaming / process_non_streaming helpers fall back to
# when a test does not pass one explicitly.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
100
101
class InMemoryFileStore(AttachmentStore):
    """Attachment store test double that keeps created attachments in a dict."""

    def __init__(self):
        # attachment id -> attachment object produced by create_attachment
        self.files = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from ``files``; an unknown id raises KeyError."""
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Build an attachment for ``input``, remember it and return it.

        Mime types starting with ``image/`` produce an ``ImageAttachment``
        (with a preview URL); everything else produces a ``FileAttachment``.
        Ids are ``atc_<n>`` where ``n`` is the current number of stored files
        plus one.
        """
        new_id = f"atc_{len(self.files) + 1}"
        extra_metadata = {"source": "test"}
        # Both attachment kinds carry the same upload descriptor.
        descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{new_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )

        is_image = input.mime_type.startswith("image/")
        if not is_image:
            attachment = FileAttachment(
                id=new_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=descriptor,
                metadata=extra_metadata,
            )
        else:
            attachment = ImageAttachment(
                id=new_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{new_id}/preview"),
                upload_descriptor=descriptor,
                metadata=extra_metadata,
            )

        self.files[attachment.id] = attachment
        return attachment
142
143
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one serialized stream chunk back into a ``ThreadStreamEvent``.

    Args:
        event: Raw bytes of a single event; the JSON payload is whatever
            follows the first ``b"data: "`` marker.

    Returns:
        The validated event model.

    Raises:
        IndexError: If ``event`` contains no ``b"data: "`` marker.
    """
    # maxsplit=1 keeps the payload intact when the JSON itself contains the
    # text "data: " (e.g. inside a message); an unbounded split would cut the
    # payload at that point and validation would fail on truncated JSON.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
146
147
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every chunk in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
152
153
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
    transcribe_callback: Callable[[AudioInput, Any], TranscriptionResult] | None = None,
):
    """Yield a ``ChatKitServer`` subclass wired to the given test callbacks.

    Each call uses a fresh shared in-memory SQLite database (named after the
    module-level ``server_id`` counter, which is incremented here). Callbacks
    left as ``None`` fall back as follows: ``responder`` -> an empty stream,
    ``handle_feedback`` -> no-op, ``transcribe_callback`` -> the base class
    implementation, ``action_callback`` -> ``ValueError`` when an action is
    dispatched.

    The yielded server also offers ``process_streaming`` and
    ``process_non_streaming`` helpers that serialize a request object, run it
    through ``process`` and assert the kind of result that came back.
    """
    global server_id
    db_uri = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last
    # connection is closed: this connection stays open until the context exits.
    anchor_connection = sqlite3.connect(db_uri, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_uri), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            """Forward to ``action_callback``; fail loudly when none was given."""
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            """Forward to ``responder``, or stream nothing when it is absent."""
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            """Pass feedback on to ``handle_feedback`` when one was supplied."""
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def transcribe(
            self, audio_input: AudioInput, context: Any
        ) -> TranscriptionResult:
            """Use ``transcribe_callback`` if given, else the base implementation."""
            if transcribe_callback is not None:
                return transcribe_callback(audio_input, context)
            return await super().transcribe(audio_input, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process ``request_obj`` and return the decoded stream events."""
            outcome = await self.process(
                request_obj.model_dump_json(), context or DEFAULT_CONTEXT
            )
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process ``request_obj`` and return the ``NonStreamingResult``."""
            outcome = await self.process(
                request_obj.model_dump_json(), context or DEFAULT_CONTEXT
            )
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        anchor_connection.close()
242
243
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A stream cancelled mid-message keeps the partial assistant text.

    The responder adds an assistant message ("Hello, "), streams a "World!"
    delta and then raises ``asyncio.CancelledError``. Afterwards the newest
    stored item must be the ``sdk_hidden_context`` cancellation note and the
    pending assistant message must be stored with the combined text.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id="assistant-message-pending",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(
            item_id="assistant-message-pending",
            update=text_delta,
        )
        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        # Collect with an explicit loop so events seen before the cancellation
        # are still available afterwards.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        stored_message = await server.store.load_item(
            thread.id, "assistant-message-pending", DEFAULT_CONTEXT
        )
        assert stored_message.type == "assistant_message"
        assert stored_message.content[0].text == "Hello, World!"
313
314
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream drops a pending assistant message that has no content.

    The responder adds an assistant message with empty content and then raises
    ``asyncio.CancelledError``. The cancellation note must still be stored as
    an ``sdk_hidden_context`` item, while loading the pending message must
    raise ``NotFoundError``.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id="assistant-message-pending",
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            # Only hidden-context ids are generated here; everything else is
            # delegated to the store's own generator.
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        with pytest.raises(NotFoundError):
            await server.store.load_item(
                thread.id, "assistant-message-pending", DEFAULT_CONTEXT
            )
375
376
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test to make sure pending item updates do not modify the
    origin item that was streamed.

    The responder streams a workflow item, a task-added update, then appends
    the same task to its own item before emitting the done event; both the
    done event and the stored item must end up with exactly one task.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        streamed_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=streamed_item)

        new_task = CustomTask(title="foo", content="bar")
        yield ThreadItemUpdatedEvent(
            item_id=streamed_item.id,
            update=WorkflowTaskAdded(task=new_task, task_index=0),
        )

        # Mutate the responder's own item only after the update was streamed.
        streamed_item.workflow.tasks.append(new_task)
        yield ThreadItemDoneEvent(item=streamed_item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        workflow_done_event = next(
            ev
            for ev in events
            if isinstance(ev, ThreadItemDoneEvent)
            and isinstance(ev.item, WorkflowItem)
        )
        done_item = cast(WorkflowItem, workflow_done_event.item)
        done_tasks = done_item.workflow.tasks
        assert len(done_tasks) == 1
        assert done_tasks[0].title == "foo"

        stored = await server.store.load_item(
            thread.id, done_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(stored, WorkflowItem)
        stored_tasks = cast(WorkflowItem, stored).workflow.tasks
        assert len(stored_tasks) == 1
        assert stored_tasks[0].title == "foo"
434
435
async def test_flows_context_to_responder():
    """The context given to ``process`` reaches the responder and feedback hook."""
    # Contexts observed by each callback; None until the callback runs.
    seen: dict[str, Any] = {"responder": None, "feedback": None}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen["responder"] = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen["feedback"] = context

    context = RequestContext(user_id="text_ctx")
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, add_feedback) as server:
        events = await server.process_streaming(create_request, context)
        assert seen["responder"] == context

        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        assistant_item = next(
            ev.item
            for ev in events
            if ev.type == "thread.item.done" and ev.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen["feedback"] == context
495
496
async def test_creates_thread():
    """Creating a thread emits an untitled thread and echoes the user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
        quoted_text="Quoted text",
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(ev for ev in events if ev.type == "thread.created")
        assert created.thread.title is None

        # First finished item in the stream is expected to be the user message.
        first_done_item = next(
            ev.item for ev in events if ev.type == "thread.item.done"
        )
        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
530
531
async def test_saves_thread_title():
    """A title set by the responder is stored and sent as ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        persisted = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert persisted.title == "Updated title"

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.title == "Updated title"
562
563
async def test_saves_thread_metadata():
    """Metadata written by the responder is stored and triggers ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        persisted = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert persisted.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
591
592
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is stored and sent to the client."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        loaded = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert loaded.status == LockedStatus(reason="Because")

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.status == LockedStatus(reason="Because")
622
623
async def test_emits_thread_updated_mid_stream_and_persists():
    """A title change between events yields ``thread.updated`` without ending the stream.

    The responder emits messages A, B and C and changes the thread title
    between A and B. The update event must directly follow B, C must directly
    follow the update, and the new title must be stored.
    """

    def assistant_done(
        thread: ThreadMetadata, item_id: str, text: str
    ) -> ThreadItemDoneEvent:
        # Built at call time, so created_at is taken when the event is yielded.
        return ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # First assistant message
        yield assistant_done(thread, "assistant_a", "A")

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done(thread, "assistant_b", "B")

        # Continue streaming after the update event
        yield assistant_done(thread, "assistant_c", "C")

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        # Find the created thread
        thread = next(
            ev.thread for ev in events if isinstance(ev, ThreadCreatedEvent)
        )

        # Ensure a thread.updated event occurs right after assistant_b
        idx_b = next(
            position
            for position, ev in enumerate(events)
            if isinstance(ev, ThreadItemDoneEvent)
            and isinstance(ev.item, AssistantMessageItem)
            and ev.item.id == "assistant_b"
        )
        after_b = events[idx_b + 1]
        assert isinstance(after_b, ThreadUpdatedEvent)
        assert cast(ThreadUpdatedEvent, after_b).thread.title == "Mid-stream title"

        # Streaming continues after the update event
        after_update = events[idx_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        done_event = cast(ThreadItemDoneEvent, after_update)
        assert isinstance(done_event.item, AssistantMessageItem)
        assert done_event.item.id == "assistant_c"

        # Persisted in the store
        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"
698
699
async def test_respond_with_tool_call():
    """A client tool call is streamed, and its output triggers a follow-up reply.

    The responder answers a user message with a ``ClientToolCallItem`` and a
    ``None`` input (the tool-output round) with an assistant message.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if isinstance(input, UserMessageItem):
            tool_call = ClientToolCallItem(
                id="msg_1",
                created_at=datetime.now(),
                name="tool_call_1",
                arguments={"arg1": "val1", "arg2": False},
                call_id="tool_call_1",
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=tool_call)
        elif input is None:
            follow_up = AssistantMessageItem(
                id="msg_2",
                content=[
                    AssistantMessageContent(text="Glad the tool call succeeded!")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=follow_up)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        assert len(events) == 4
        created_event, user_event, options_event, tool_event = events

        assert created_event.type == "thread.created"
        thread = created_event.thread

        assert user_event.type == "thread.item.done"
        assert user_event.item.type == "user_message"

        assert options_event.type == "stream_options"

        assert tool_event.type == "thread.item.done"
        tool_item = tool_event.item
        assert tool_item.type == "client_tool_call"
        assert tool_item.id == "msg_1"
        assert tool_item.name == "tool_call_1"
        assert tool_item.arguments == {"arg1": "val1", "arg2": False}
        assert tool_item.call_id == "tool_call_1"

        # Second round: hand the tool output back to the server.
        events = await server.process_streaming(
            ThreadsAddClientToolOutputReq(
                params=ThreadAddClientToolOutputParams(
                    thread_id=thread.id,
                    result={"text": "Wow!"},
                )
            )
        )

        assert len(events) == 2
        options_event, reply_event = events
        assert options_event.type == "stream_options"
        assert reply_event.type == "thread.item.done"
        assert reply_event.item.type == "assistant_message"
769
770
async def test_respond_with_tool_status():
    """Progress updates from the responder are passed through in stream order."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        # One comparison covers both the event count and every event type.
        assert [ev.type for ev in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
804
805
async def test_list_threads_response():
    """Listing 25 threads pages as 20 (newest first) followed by the last 5."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20

        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        # Continue from the last thread id of the first page.
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
838
839
async def test_list_items_response():
    """Listing thread items pages as 20 (newest first) followed by the rest.

    The responder emits 25 items; with the initial user message the thread
    holds 26, so the second page is expected to contain 6.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for index in range(25):
            message = UserMessageItem(
                id=f"msg_{index}",
                content=[UserMessageTextContent(text=f"Message {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    page_adapter = TypeAdapter(Page[ThreadItem])
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Thread 1")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 6  # +1 for user message
        assert second_page.has_more is False
889
890
async def test_calls_action():
    """A custom action request reaches the action hook with its sender widget.

    The responder streams one widget; the action hook records what it gets and
    streams a widget-root update, which must come back to the caller.
    """
    received = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        initial_widget = Card(
            children=[
                Text(
                    key="text_1",
                    value="test",
                )
            ],
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_widget,
            ),
        )

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received.append((action, sender))
        assert sender

        replacement = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Show widget")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, action_callback=action) as server:
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        widget_item = next(
            ev.item
            for ev in events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        events = await server.process_streaming(
            ThreadsCustomActionReq(
                params=ThreadCustomActionParams(
                    thread_id=thread.id,
                    item_id=widget_item.id,
                    action=Action(type="create_user", payload={"user_id": "123"}),
                )
            )
        )

        assert received
        assert received[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert len(events) == 2
        options_event, update_event = events
        assert options_event.type == "stream_options"
        assert update_event.type == "thread.item.updated"
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(
            children=[Text(value="Email sent!")]
        )
980
981
async def test_add_feedback():
    """Feedback requests reach the hook with thread id, item ids and kind."""
    called = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        called.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_item = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_item)

        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget_item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test feedback")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, handle_feedback=handle_feedback) as server:
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        positive_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["assistant_msg_1"],
                kind="positive",
            )
        )
        await server.process_non_streaming(positive_request)

        negative_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["widget_1"],
                kind="negative",
            )
        )
        await server.process_non_streaming(negative_request)

        assert len(called) == 2
        first_call, second_call = called

        # The recorded context (fourth element) is not checked here.
        assert first_call[0] == thread.id
        assert first_call[1] == ["assistant_msg_1"]
        assert first_call[2] == "positive"

        assert second_call[0] == thread.id
        assert second_call[1] == ["widget_1"]
        assert second_call[2] == "negative"
1056
1057
async def test_get_thread_by_id():
    """Fetching a thread by id returns it together with its first user message."""
    # NOTE(review): the return value is discarded, so this looks like a
    # leftover call - confirm it has no required side effect before removing.
    make_thread_items()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test thread")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Retrieve the thread by id
        lookup_request = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        result = await server.process_non_streaming(lookup_request)
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)

        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title

        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1089
1090
async def test_create_file():
    """Creating a non-image attachment returns a file with its upload descriptor."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    create_request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name=file_name, size=1024, mime_type=file_content_type
        )
    )

    with make_server(file_store=store) as server:
        result = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in store.files

        # The `metadata` field is excluded from serialization by default but still stored server-side
        assert attachment.metadata is None
        assert store.files[attachment.id].metadata == {"source": "test"}
1122
1123
async def test_create_image_file():
    """Creating a PNG attachment yields an image attachment with preview and upload URLs."""
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime_type = "image/png"

    with make_server(file_store=file_store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=expected_name,
                size=1024,
                mime_type=expected_mime_type,
            )
        )
        response = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert attachment.id is not None
        assert attachment.mime_type == expected_mime_type
        assert attachment.name == expected_name
        assert attachment.type == "image"

        # Image attachments additionally expose a preview URL.
        expected_preview_url = AnyUrl(f"https://example.com/{attachment.id}/preview")
        assert attachment.preview_url == expected_preview_url

        assert attachment.upload_descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert attachment.upload_descriptor.url == expected_upload_url
        assert attachment.upload_descriptor.method == "PUT"
        assert attachment.upload_descriptor.headers == {"X-My-Header": "my-value"}

        assert attachment.id in file_store.files
1156
1157
async def test_user_message_attachment_metadata_not_serialized():
    """Attachment metadata must not appear on user-message attachments in an items listing."""
    attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        # Persist the attachment first so the new message can reference it by id.
        await server.store.save_attachment(attachment, DEFAULT_CONTEXT)

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[attachment.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_request = ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        list_response = await server.process_non_streaming(list_request)
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)

        user_messages = [entry for entry in page.data if entry.type == "user_message"]
        assert user_messages

        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        for listed in listed_attachments:
            assert listed.metadata is None
1193
1194
async def test_user_message_attachment_thread_id_persisted_on_thread_create():
    """An attachment saved without a thread gets the new thread's id on thread creation."""
    attachment = FileAttachment(
        id="file_thread_id_on_create",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        thread_id=None,
    )

    with make_server() as server:
        # The attachment is stored before any thread exists, hence thread_id=None.
        await server.store.save_attachment(attachment, DEFAULT_CONTEXT)

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[attachment.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in created_events if event.type == "thread.created"
        )

        # Reload the attachment record: it should now point at the created thread.
        reloaded = await server.store.load_attachment(attachment.id, DEFAULT_CONTEXT)
        assert reloaded.thread_id == thread.id
1225
1226
async def test_user_message_attachment_thread_id_persisted_on_add_user_message():
    """An attachment added via a follow-up message gets the existing thread's id."""
    with make_server() as server:
        # Start with a thread that has no attachments.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Start")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in created_events if event.type == "thread.created"
        )

        # Store an attachment that is not yet associated with any thread.
        attachment = FileAttachment(
            id="file_thread_id_on_add",
            type="file",
            mime_type="text/plain",
            name="test.txt",
            thread_id=None,
        )
        await server.store.save_attachment(attachment, DEFAULT_CONTEXT)

        # Reference that attachment from a second message in the existing thread.
        add_message_request = ThreadsAddUserMessageReq(
            params=ThreadAddUserMessageParams(
                thread_id=thread.id,
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Second")],
                    attachments=[attachment.id],
                    inference_options=InferenceOptions(),
                ),
            )
        )
        await server.process_streaming(add_message_request)

        reloaded = await server.store.load_attachment(attachment.id, DEFAULT_CONTEXT)
        assert reloaded.thread_id == thread.id
1268
1269
async def test_create_file_without_filestore():
    """Creating an attachment on a server built without a file store raises RuntimeError."""
    with make_server() as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name="test-file-name",
                size=1024,
                mime_type="text/plain",
            )
        )
        # Keep the expectation scoped to the request itself so that an error
        # raised while setting up or tearing down the server cannot make this
        # test pass by accident.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(create_request)
1282
1283
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    attachment_id = "test-file-id"
    file_store = InMemoryFileStore()
    # Seed the store with a single entry keyed by the id we are about to delete.
    file_store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=file_store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=attachment_id)
        )
        await server.process_non_streaming(delete_request)
        assert file_store.files == {}
1294
1295
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server built without a file store raises RuntimeError."""
    with make_server() as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id="test-file-id")
        )
        # Keep the expectation scoped to the request itself so that an error
        # raised while setting up or tearing down the server cannot make this
        # test pass by accident.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(delete_request)
1304
1305
async def test_list_threads_paginated_response():
    """Listing 25 threads with limit=20 returns a page of 20, then a page of 5."""
    thread_count = 25
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(thread_count):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        # First page: capped by the explicit limit, with more results pending.
        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        cursor = first_page.data[-1].id
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1334
1335
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(e.thread for e in created_events if e.type == "thread.created")

        update_request = ThreadsUpdateReq(
            params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
        )
        update_response = await server.process_non_streaming(update_request)

        # The response body is the serialized, updated thread.
        updated_thread = TypeAdapter[Thread](Thread).validate_json(update_response.json)
        assert updated_thread.title == "New Title"
1359
1360
async def test_returns_widget_item_generator():
    """Streaming a widget whose text grows emits added, text-delta and done events."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(i: int) -> Card:
        # Render the card with only the first `i` characters of the message.
        return Card(children=[Text(id="text", value="Hello, world"[:i])])

    async def widget_generator():
        for prefix_length in (0, 3, 6, 12):
            yield render_widget(prefix_length)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5
    added_event, *update_events, done_event = events

    # The first yielded widget (empty text) is announced as a new item.
    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="")])

    # Each later widget only appends text, so it is sent as a text delta.
    expected_deltas = ("Hel", "lo,", " world")
    for update_event, expected_delta in zip(update_events, expected_deltas):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    # The final event carries the fully rendered widget.
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1402
1403
async def test_returns_widget_item_generator_full_replace():
    """When the text is replaced rather than extended, a root-updated event is emitted."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    async def widget_generator():
        for value in ("Hello", "World"):
            yield Card(children=[Text(id="text", value=value)])

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added_event, update_event, done_event = events

    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="Hello")])

    # "World" is not an extension of "Hello", so the whole widget root is resent.
    assert isinstance(update_event, ThreadItemUpdatedEvent)
    assert update_event.update.type == "widget.root.updated"
    assert update_event.update.widget == Card(
        children=[Text(id="text", value="World")]
    )

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(children=[Text(id="text", value="World")])
1427
1428
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and makes later lookups fail."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Looking the thread up again must raise NotFoundError naming the thread.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1460
1461
async def test_thread_item_removed_event_removes_item():
    """An item that the responder adds and then removes is absent from the items listing."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a finished item, then immediately ask for its removal.
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger removal")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The removal event itself must be part of the stream.
        removal_events = [
            e
            for e in events
            if e.type == "thread.item.removed" and e.item_id == removed_id
        ]
        assert removal_events

        # Listing the thread's items must no longer return the removed item.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)
        remaining_ids = [entry.id for entry in page.data]
        assert removed_id not in remaining_ids
1506
1507
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder surfaces as an error event, then the stream ends."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # Unreachable, but required to make this an async generator.

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        stream = result.__aiter__()

        # Consume the stream until the error event shows up.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Nothing may follow the error event: the stream must be exhausted.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1549
1550
async def test_thread_item_replaced_event_replaces_item():
    """A replaced-item event overwrites the item stored under the same id."""
    original_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: add a user-message item under `original_id`.
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Step 2: replace it with an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[AssistantMessageContent(text=replacement_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger replacement")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The replacement event must have been streamed to the client.
        replaced_events = [
            e
            for e in events
            if e.type == "thread.item.replaced" and e.item.id == original_id
        ]
        assert replaced_events

        # The items listing must now return the replacement under the same id.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_response.json).data
        )

        replaced_item = next(
            (entry for entry in listed_items if entry.id == original_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, AssistantMessageItem)
        assert isinstance(replaced_item.content[0], AssistantMessageContent)
        assert replaced_item.content[0].text == "This is the replacement content"
1613
1614
async def test_thread_item_replaced_event_with_widget():
    """A replaced-item event carrying a widget overwrites the stored widget item."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: add the widget item.
        initial_item = WidgetItem(
            id=widget_id,
            widget=original_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=initial_item)

        # Step 2: replace it with a widget item that reuses the same id.
        replacement_item = WidgetItem(
            id=widget_id,
            widget=replacement_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=replacement_item)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[
                        UserMessageTextContent(text="Trigger widget replacement")
                    ],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The streamed replacement event must carry the new widget.
        matching_events = (
            e
            for e in events
            if e.type == "thread.item.replaced" and e.item.id == widget_id
        )
        replacement_event = next(matching_events, None)
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The items listing must return the replacement widget as well.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_response.json).data
        )

        replaced_item = next(
            (entry for entry in listed_items if entry.id == widget_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, WidgetItem)
        assert isinstance(replaced_item.widget, Card)
        assert isinstance(replaced_item.widget.children[0], Text)
        assert replaced_item.widget.children[0].value == "Replaced widget content"
1688
1689
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the assistant reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first invocation answers the original request; later ones are retries.
        if len(responder_calls) == 1:
            item_id, text = "assistant_msg_1", "First response"
        else:
            item_id, text = "assistant_msg_2", "Retried response"

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        async def list_items():
            # Fetch and parse the current items page for the thread.
            response = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(response.json)

        items_before = await list_items()
        assert len(items_before.data) == 2

        assistant_before, user_before = items_before.data
        assert assistant_before.type == "assistant_message"
        assert assistant_before.content[0].type == "output_text"
        assert assistant_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry from the user message, which is the second entry of the listing.
        user_message_id = user_before.id
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=user_message_id
            )
        )
        retry_events = await server.process_streaming(retry_request)

        # The retry stream holds the stream options followed by the new reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder invocations received the same user message.
        assert len(responder_calls) == 2
        assert responder_calls[0] == responder_calls[1]
        assert responder_calls[0].type == "user_message"
        assert responder_calls[0].content[0].text == "Hello, world!"

        # Storage still holds two items, with the assistant reply swapped out.
        items_after = await list_items()
        assert len(items_after.data) == 2  # Still user message + assistant message

        assistant_after, user_after = items_after.data
        assert assistant_after.type == "assistant_message"
        assert assistant_after.content[0].type == "output_text"
        assert assistant_after.content[0].text == "Retried response"  # New response
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1787
1788
async def test_retry_after_item_invalid_item():
    """Retrying after an assistant message is rejected with a ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)
        # This test treats the first listed item as the assistant reply.
        assistant_item_id = page.data[0].id

        # Only user messages are valid retry anchors, so this request must fail.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1835
1836
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the second user message regenerates only the last assistant reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # Echo the user text so each reply can be traced back to its message;
        # the id embeds the call count, so a retried reply gets a fresh id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Verify we have 4 items. The assertions below index the listing
        # newest-first: assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second user message, the anchor for the retry.
        second_user_message_id = items_before.data[1].id

        # Retry after the second user message
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # First user message and response should remain unchanged
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # Second user message should remain, but assistant response should be new
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1941
1942
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with a thread-create request reaches the responder."""
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # The responder itself checks that the tool choice arrived intact.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_create",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello with tools")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Sanity check: at least one item-done event was streamed back.
        done_events = [e for e in events if e.type == "thread.item.done"]
        assert done_events
1976
1977
async def test_input_transcribe_decodes_base64_and_normalizes_mime_type():
    """The transcribe callback gets decoded bytes, the raw MIME type and the bare media type."""
    audio_bytes = b"hello audio"
    encoded_audio = base64.b64encode(audio_bytes).decode("ascii")
    seen: dict[str, Any] = {}

    def transcribe_callback(
        audio_input: AudioInput, context: Any
    ) -> TranscriptionResult:
        # Capture everything the server hands over so it can be checked below.
        seen.update(
            audio=audio_input.data,
            mime=audio_input.mime_type,
            media_type=audio_input.media_type,
            context=context,
        )
        return TranscriptionResult(text="ok")

    with make_server(transcribe_callback=transcribe_callback) as server:
        transcribe_request = InputTranscribeReq(
            params=InputTranscribeParams(
                audio_base64=encoded_audio,
                mime_type="audio/webm;codecs=opus",
            )
        )
        response = await server.process_non_streaming(transcribe_request)
        transcription = TypeAdapter(TranscriptionResult).validate_json(response.json)
        assert transcription.text == "ok"

    # The callback saw the decoded audio, the MIME type as sent (codec
    # parameter included), and the media type without that parameter.
    assert seen["audio"] == audio_bytes
    assert seen["mime"] == "audio/webm;codecs=opus"
    assert seen["media_type"] == "audio/webm"
    assert seen["context"] == DEFAULT_CONTEXT
2008
2009
async def test_retry_after_item_passes_tools_to_responder():
    """Placeholder: this test currently has no body and asserts nothing."""
    # TODO(review): no behaviour is exercised here, so this test always passes.
2012
2013
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with an add-user-message request reaches the responder."""
    tool_choice = ToolChoice(id="retrieval")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Runs for both requests below; each must carry the tool choice.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_add",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def build_input(text: str) -> UserMessageInput:
        # Both requests share the same shape and differ only in their text.
        return UserMessageInput(
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(tool_choice=tool_choice),
        )

    with make_server(responder) as server:
        # A thread has to exist before a message can be added to it.
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=build_input("Start")))
        )
        thread = next(
            event.thread for event in created_events if event.type == "thread.created"
        )

        # Add a second message; the responder asserts on the tool choice.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=build_input("Second"),
                )
            )
        )

        # Sanity check: at least one item-done event was streamed back.
        done_events = [e for e in events if e.type == "thread.item.done"]
        assert done_events
2063
2064
async def test_custom_id_generators_on_server():
    """Thread and item ids come from the store's generate_thread_id / generate_item_id."""

    class UniqueIdStore(SQLiteStore):
        """SQLiteStore whose ids embed one counter shared by threads and items."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits nothing: only ids assigned outside the responder are checked.
        return
        yield  # Unreachable, but required to make this an async generator.

    with make_server(responder) as server:
        # Replace the server's store with the custom one, pointed at the same DB path.
        original_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(original_store.db_path)

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        result = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is the first id drawn from the shared counter.
        thread_created = next(e for e in events if e.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        # The user message takes the next counter value and embeds the thread id.
        user_message = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
2121