openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
d14ce4602e2002ef97ccff685ec14fb27c49c94c

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2227 lines · mode: code

1import asyncio
2import base64
3import sqlite3
4from contextlib import contextmanager
5from datetime import datetime
6from typing import Any, AsyncIterator, Callable, cast
7
8import pytest
9from helpers.mock_store import SQLiteStore
10from pydantic import AnyUrl, TypeAdapter
11
12from chatkit.actions import Action
13from chatkit.errors import ErrorCode
14from chatkit.server import (
15 ChatKitServer,
16 NonStreamingResult,
17 StreamingResult,
18 stream_widget,
19)
20from chatkit.store import (
21 AttachmentStore,
22 NotFoundError,
23 StoreItemType,
24 default_generate_id,
25)
26from chatkit.types import (
27 AssistantMessageContent,
28 AssistantMessageContentPartTextDelta,
29 AssistantMessageItem,
30 Attachment,
31 AttachmentCreateParams,
32 AttachmentDeleteParams,
33 AttachmentsCreateReq,
34 AttachmentsDeleteReq,
35 AttachmentUploadDescriptor,
36 AudioInput,
37 ClientToolCallItem,
38 CustomTask,
39 FeedbackKind,
40 FileAttachment,
41 ImageAttachment,
42 InferenceOptions,
43 InputTranscribeParams,
44 InputTranscribeReq,
45 ItemFeedbackParams,
46 ItemsFeedbackReq,
47 ItemsListParams,
48 ItemsListReq,
49 LockedStatus,
50 Page,
51 ProgressUpdateEvent,
52 SyncCustomActionResponse,
53 Thread,
54 ThreadAddClientToolOutputParams,
55 ThreadAddUserMessageParams,
56 ThreadCreatedEvent,
57 ThreadCreateParams,
58 ThreadCustomActionParams,
59 ThreadDeleteParams,
60 ThreadGetByIdParams,
61 ThreadItem,
62 ThreadItemAddedEvent,
63 ThreadItemDoneEvent,
64 ThreadItemRemovedEvent,
65 ThreadItemReplacedEvent,
66 ThreadItemUpdatedEvent,
67 ThreadListParams,
68 ThreadMetadata,
69 ThreadRetryAfterItemParams,
70 ThreadsAddClientToolOutputReq,
71 ThreadsAddUserMessageReq,
72 ThreadsCreateReq,
73 ThreadsCustomActionReq,
74 ThreadsDeleteReq,
75 ThreadsGetByIdReq,
76 ThreadsListReq,
77 ThreadsRetryAfterItemReq,
78 ThreadsSyncCustomActionReq,
79 ThreadStreamEvent,
80 ThreadsUpdateReq,
81 ThreadUpdatedEvent,
82 ThreadUpdateParams,
83 ToolChoice,
84 TranscriptionResult,
85 UserMessageInput,
86 UserMessageItem,
87 UserMessageTextContent,
88 WidgetItem,
89 WidgetRootUpdated,
90 Workflow,
91 WorkflowItem,
92 WorkflowTaskAdded,
93)
94from chatkit.widgets import Card, Text
95from tests._types import RequestContext
96from tests.test_store import make_thread_items
97
# Counter read and incremented by make_server so that every server it builds
# gets its own uniquely named in-memory SQLite database.
server_id = 0


# Request context used by the test server helpers when a test passes none.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
102
103
class InMemoryFileStore(AttachmentStore):
    """Attachment store test double that keeps attachments in a plain dict.

    Created attachments live in ``self.files`` keyed by attachment id, so tests
    can inspect the dict directly.
    """

    def __init__(self):
        self.files: dict[str, Attachment] = {}

    def _next_attachment_id(self) -> str:
        """Return the first unused ``atc_<n>`` id, starting at ``len(files) + 1``.

        Starting from the count alone would hand out an id that is still in use
        once an earlier attachment has been deleted, silently overwriting the
        surviving entry; probing upwards avoids that while producing the same
        ids as before whenever no collision would have occurred.
        """
        index = len(self.files) + 1
        while f"atc_{index}" in self.files:
            index += 1
        return f"atc_{index}"

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Forget the attachment; raises ``KeyError`` if the id is not stored."""
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Build, remember and return an attachment for ``input``.

        Mime types starting with ``image/`` produce an ``ImageAttachment`` (with
        a preview URL); everything else produces a ``FileAttachment``. Both get
        the same example upload descriptor and ``{"source": "test"}`` metadata.
        """
        # Simple in-memory attachment creation
        attachment_id = self._next_attachment_id()
        metadata = {"source": "test"}
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )

        attachment: Attachment
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )

        self.files[attachment.id] = attachment
        return attachment
144
145
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one streamed ``data: <json>`` frame into a ``ThreadStreamEvent``.

    Only the first ``data: `` marker is treated as the separator, so a JSON
    payload that itself contains the text ``data: `` is validated in full
    instead of being cut off at the second occurrence.

    Raises:
        IndexError: if ``event`` contains no ``data: `` marker.
    """
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
148
149
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every frame, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
154
155
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    sync_action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        SyncCustomActionResponse,
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
    transcribe_callback: Callable[[AudioInput, Any], TranscriptionResult] | None = None,
):
    """Yield a ``ChatKitServer`` subclass instance backed by a fresh SQLite store.

    Each server hook delegates to the matching callback argument:

    * ``respond`` -> ``responder`` (an empty stream when omitted)
    * ``action`` -> ``action_callback`` (``ValueError`` when omitted)
    * ``sync_action`` -> ``sync_action_callback`` (``ValueError`` when omitted)
    * ``add_feedback`` -> ``handle_feedback`` (no-op when omitted)
    * ``transcribe`` -> ``transcribe_callback`` (base implementation when omitted)

    ``file_store`` is handed to the ``ChatKitServer`` constructor unchanged. The
    yielded server also offers ``process_streaming`` / ``process_non_streaming``
    helpers that serialise a request object, run ``process`` and check the kind
    of result returned.
    """
    global server_id
    database_index = server_id
    server_id = database_index + 1
    db_path = f"file:{database_index}?mode=memory&cache=shared"
    # Keep the shared in-memory database from being deleted when the last
    # connection is closed: this connection stays open for the whole context.
    keepalive_connection = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        async def sync_action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> SyncCustomActionResponse:
            if sync_action_callback is not None:
                return sync_action_callback(thread, action, sender, context)
            raise ValueError("sync_action_callback not wired up")

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def transcribe(
            self, audio_input: AudioInput, context: Any
        ) -> TranscriptionResult:
            if transcribe_callback is not None:
                return transcribe_callback(audio_input, context)
            return await super().transcribe(audio_input, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process ``request_obj`` and return its decoded stream events."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process ``request_obj`` and return the ``NonStreamingResult``."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        keepalive_connection.close()
260
261
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream keeps the partial assistant text and adds a hidden-context item."""
    pending_id = "assistant-message-pending"

    async def cancelled_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(item_id=pending_id, update=text_delta)

        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(cancelled_responder) as server:
        # Allow hidden_context id generation in this test store
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        # Collect with an explicit loop so events seen before the cancellation
        # are still available afterwards.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        expected_note = (
            "The user cancelled the stream. Stop responding to the prior request."
        )
        assert hidden_context_item.content == expected_note

        stored_message = await server.store.load_item(
            thread.id, pending_id, DEFAULT_CONTEXT
        )
        assert stored_message.type == "assistant_message"
        assert stored_message.content[0].text == "Hello, World!"
331
332
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream adds the hidden-context item but drops an empty pending message."""
    pending_id = "assistant-message-pending"

    async def cancelled_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(cancelled_responder) as server:
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        # Collect with an explicit loop so events seen before the cancellation
        # are still available afterwards.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        expected_note = (
            "The user cancelled the stream. Stop responding to the prior request."
        )
        assert hidden_context_item.content == expected_note

        with pytest.raises(NotFoundError):
            await server.store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
393
394
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying pending item updates must not mutate the
    original item object that the responder streamed.

    The responder announces a task through an update event and then appends
    the same task to its own item before sending the done event; both the
    emitted and the stored workflow must end up with exactly one task.
    """

    async def workflow_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        streamed_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=streamed_item)

        new_task = CustomTask(title="foo", content="bar")
        task_added = WorkflowTaskAdded(task=new_task, task_index=0)
        yield ThreadItemUpdatedEvent(item_id=streamed_item.id, update=task_added)

        streamed_item.workflow.tasks.append(new_task)
        yield ThreadItemDoneEvent(item=streamed_item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(workflow_responder) as server:
        events = await server.process_streaming(create_request)

        thread = next(e.thread for e in events if e.type == "thread.created")
        done_event = next(
            e
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WorkflowItem)
        )
        emitted_item = cast(WorkflowItem, done_event.item)
        emitted_tasks = emitted_item.workflow.tasks
        assert len(emitted_tasks) == 1
        assert emitted_tasks[0].title == "foo"

        stored = await server.store.load_item(
            thread.id, emitted_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(stored, WorkflowItem)
        stored_tasks = stored.workflow.tasks
        assert len(stored_tasks) == 1
        assert stored_tasks[0].title == "foo"
452
453
async def test_flows_context_to_responder():
    """The context given to the server reaches both the responder and the feedback hook."""
    # Contexts observed by each hook, keyed by hook name.
    seen_contexts: dict[str, Any] = {}

    async def recording_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_contexts["responder"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def recording_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen_contexts["add_feedback"] = context

    context = RequestContext(user_id="text_ctx")
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(recording_responder, handle_feedback=recording_feedback) as server:
        events = await server.process_streaming(create_request, context)
        assert seen_contexts.get("responder") == context

        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        assistant_item = next(
            event.item
            for event in events
            if event.type == "thread.item.done"
            and event.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen_contexts.get("add_feedback") == context
513
514
async def test_creates_thread():
    """Creating a thread emits an untitled thread and echoes the user message as done."""

    async def silent_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
                quoted_text="Quoted text",
            )
        )
    )

    with make_server(silent_responder) as server:
        events = await server.process_streaming(create_request)

        created = next(event for event in events if event.type == "thread.created")
        assert created.thread.title is None

        first_done_item = next(
            event.item for event in events if event.type == "thread.item.done"
        )
        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
548
549
async def test_saves_thread_title():
    """A title set by the responder is stored and announced in a final thread.updated event."""

    async def titling_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        # Async generator that produces no events.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(titling_responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.title == "Updated title"

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.title == "Updated title"
580
581
async def test_saves_thread_metadata():
    """Metadata written by the responder is stored and followed by a thread.updated event."""

    async def metadata_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        # Async generator that produces no events.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(metadata_responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
609
610
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is stored and announced via thread.updated."""

    async def locking_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        # Async generator that produces no events.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(locking_responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.status == LockedStatus(reason="Because")

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.status == LockedStatus(reason="Because")
640
641
async def test_emits_thread_updated_mid_stream_and_persists():
    """A title change made between events yields thread.updated after the next event."""

    def assistant_done(
        thread: ThreadMetadata, item_id: str, text: str
    ) -> ThreadItemDoneEvent:
        # Build a finished assistant message carrying ``text``.
        return ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    async def retitling_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # First assistant message
        yield assistant_done(thread, "assistant_a", "A")

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done(thread, "assistant_b", "B")

        # Continue streaming after the update event
        yield assistant_done(thread, "assistant_c", "C")

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(retitling_responder) as server:
        events = await server.process_streaming(create_request)

        # Find the created thread
        thread = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        # Ensure a thread.updated event occurs right after assistant_b
        position_b = next(
            index
            for index, candidate in enumerate(events)
            if isinstance(candidate, ThreadItemDoneEvent)
            and isinstance(candidate.item, AssistantMessageItem)
            and candidate.item.id == "assistant_b"
        )
        after_b = events[position_b + 1]
        assert isinstance(after_b, ThreadUpdatedEvent)
        assert after_b.thread.title == "Mid-stream title"

        # Streaming continues after the update event
        after_update = events[position_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        assert isinstance(after_update.item, AssistantMessageItem)
        assert after_update.item.id == "assistant_c"

        # Persisted in the store
        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.title == "Mid-stream title"
716
717
async def test_respond_with_tool_call():
    """A client tool call is streamed first; submitting its output triggers a follow-up reply."""

    async def tool_calling_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if isinstance(user_message, UserMessageItem):
            # A user message is present: ask the client to run a tool.
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )
        elif user_message is None:
            # No user message: answer with a plain assistant message.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(tool_calling_responder) as server:
        events = await server.process_streaming(create_request)

        assert len(events) == 4
        created, user_done, options, tool_done = events

        assert created.type == "thread.created"
        thread = created.thread

        assert user_done.type == "thread.item.done"
        assert user_done.item.type == "user_message"

        assert options.type == "stream_options"

        assert tool_done.type == "thread.item.done"
        tool_call = tool_done.item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        tool_output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(tool_output_request)

        assert len(events) == 2
        options, reply_done = events
        assert options.type == "stream_options"
        assert reply_done.type == "thread.item.done"
        assert reply_done.item.type == "assistant_message"
787
788
async def test_respond_with_tool_status():
    """A progress update from the responder is streamed ahead of the assistant message."""

    async def progress_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_4",
                content=[AssistantMessageContent(text="Hello, world!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(progress_responder) as server:
        events = await server.process_streaming(create_request)

        # Comparing the full list checks both the count and the order.
        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
822
823
async def test_list_threads_response():
    """Listing 25 threads returns a newest-first page of 20, then the remaining 5."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for i in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {i}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20

        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
856
857
async def test_list_items_response():
    """Listing a 26-item thread returns a newest-first page of 20, then the remaining 6."""
    page_adapter = TypeAdapter(Page[ThreadItem])

    async def bulk_responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for i in range(25):
            generated_item = UserMessageItem(
                id=f"msg_{i}",
                content=[UserMessageTextContent(text=f"Message {i}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=generated_item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Thread 1")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(bulk_responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 6  # +1 for user message
        assert second_page.has_more is False
907
908
async def test_calls_action():
    """A streamed custom action reaches the action hook with its widget and streams the update."""
    # (action, sender) pairs seen by the action hook.
    received = []

    async def widget_responder(
        thread: ThreadMetadata,
        user_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        initial_card = Card(
            children=[
                Text(
                    key="text_1",
                    value="test",
                )
            ],
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_card,
            ),
        )

    async def handle_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received.append((action, sender))
        assert sender

        replacement_card = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement_card),
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Show widget")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(widget_responder, action_callback=handle_action) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        widget_item = next(
            event.item
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert received
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert received[0] == expected_call

        assert len(events) == 2
        options, widget_update = events
        assert options.type == "stream_options"
        assert widget_update.type == "thread.item.updated"
        assert isinstance(widget_update, ThreadItemUpdatedEvent)
        assert widget_update.update.type == "widget.root.updated"
        assert widget_update.update.widget == Card(
            children=[Text(value="Email sent!")]
        )
998
999
async def test_calls_action_sync():
    """A sync custom action reaches the sync hook and returns the updated widget item."""
    # (action, sender) pairs seen by the sync action hook.
    received = []

    async def widget_responder(
        thread: ThreadMetadata,
        user_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        initial_card = Card(
            children=[
                Text(
                    key="text_1",
                    value="test",
                )
            ],
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_card,
            ),
        )

    def handle_sync_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> SyncCustomActionResponse:
        received.append((action, sender))
        assert sender

        replacement_card = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        updated_sender = sender.model_copy(update={"widget": replacement_card})
        return SyncCustomActionResponse(updated_item=updated_sender)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Show widget")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(
        widget_responder, sync_action_callback=handle_sync_action
    ) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        widget_item = next(
            event.item
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WidgetItem)
        )

        action_request = ThreadsSyncCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        result = await server.process_non_streaming(action_request)
        response = TypeAdapter(SyncCustomActionResponse).validate_json(result.json)

        assert received
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert received[0] == expected_call

        expected_item = widget_item.model_copy(
            update={"widget": Card(children=[Text(value="Email sent!")])}
        )
        assert response.updated_item == expected_item
1087
1088
async def test_add_feedback():
    """Each feedback request reaches the feedback hook with its thread, items and kind."""
    # One (thread_id, item_ids, feedback, context) tuple per hook invocation.
    recorded_calls = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        recorded_calls.append((thread_id, item_ids, feedback, context))

    async def two_item_responder(
        thread: ThreadMetadata,
        user_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1",
                content=[
                    AssistantMessageContent(text="Feedback received!"),
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                widget=Card(children=[Text(key="text", value="Widget content")]),
                thread_id=thread.id,
            ),
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test feedback")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    feedback_cases: list[tuple[str, FeedbackKind]] = [
        ("assistant_msg_1", "positive"),
        ("widget_1", "negative"),
    ]

    with make_server(two_item_responder, handle_feedback=handle_feedback) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        for item_id, kind in feedback_cases:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=thread.id,
                        item_ids=[item_id],
                        kind=kind,
                    )
                )
            )

        # Compare everything except the context; list equality also pins the
        # number of invocations to two.
        assert [call[:3] for call in recorded_calls] == [
            (thread.id, ["assistant_msg_1"], "positive"),
            (thread.id, ["widget_1"], "negative"),
        ]
1163
1164
async def test_get_thread_by_id():
    """A thread created via the server can be fetched by id with its first user message."""
    make_thread_items()

    with make_server() as server:
        # Create the thread through the regular thread-creation request.
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test thread")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        created = next(e.thread for e in events if e.type == "thread.created")

        # Load it back by id and compare against what was created.
        lookup_req = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=created.id),
        )
        result = await server.process_non_streaming(lookup_req)
        loaded = TypeAdapter(Thread).validate_json(result.json)

        assert loaded.id == created.id
        assert loaded.title == created.title

        first_item = loaded.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1196
1197
async def test_create_file():
    """Creating a text attachment returns a file attachment with an upload descriptor."""
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime = "text/plain"

    with make_server(file_store=file_store) as server:
        create_req = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=expected_name, size=1024, mime_type=expected_mime
            )
        )
        result = await server.process_non_streaming(create_req)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # The response describes the newly created file.
        assert attachment.id is not None
        assert attachment.mime_type == expected_mime
        assert attachment.name == expected_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(
            f"https://example.com/{attachment.id}/upload"
        )
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in file_store.files

        # `metadata` is absent from the serialized response, yet the store
        # still holds it for the saved attachment.
        assert attachment.metadata is None
        assert file_store.files[attachment.id].metadata == {"source": "test"}
1229
1230
async def test_create_image_file():
    """Creating a PNG attachment returns an image attachment with preview and upload URLs."""
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime = "image/png"

    with make_server(file_store=file_store) as server:
        create_req = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=expected_name,
                size=1024,
                mime_type=expected_mime,
            )
        )
        result = await server.process_non_streaming(create_req)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        assert attachment.id is not None
        assert attachment.mime_type == expected_mime
        assert attachment.name == expected_name
        assert attachment.type == "image"
        assert attachment.preview_url == AnyUrl(
            f"https://example.com/{attachment.id}/preview"
        )

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(
            f"https://example.com/{attachment.id}/upload"
        )
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        assert attachment.id in file_store.files
1263
1264
async def test_user_message_attachment_metadata_not_serialized():
    """Attachment metadata does not appear on user messages returned by the items list."""
    file_with_metadata = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(file_with_metadata, DEFAULT_CONTEXT)

        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[file_with_metadata.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_req = ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        list_result = await server.process_non_streaming(list_req)
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)

        user_messages = [entry for entry in page.data if entry.type == "user_message"]
        assert user_messages

        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        for listed in listed_attachments:
            assert listed.metadata is None
1300
1301
async def test_user_message_attachment_thread_id_persisted_on_thread_create():
    """An attachment saved without a thread gets the new thread's id once a thread is created with it."""
    orphan_attachment = FileAttachment(
        id="file_thread_id_on_create",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        thread_id=None,
    )

    with make_server() as server:
        # The attachment is stored before any thread exists, so thread_id is None.
        await server.store.save_attachment(orphan_attachment, DEFAULT_CONTEXT)

        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[orphan_attachment.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Reload the record: it should now point at the thread it was used in.
        reloaded = await server.store.load_attachment(
            orphan_attachment.id, DEFAULT_CONTEXT
        )
        assert reloaded.thread_id == thread.id
1332
1333
async def test_user_message_attachment_thread_id_persisted_on_add_user_message():
    """An attachment used in a follow-up message gets the existing thread's id."""
    with make_server() as server:
        # Start with a thread that has no attachments.
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Start")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        late_attachment = FileAttachment(
            id="file_thread_id_on_add",
            type="file",
            mime_type="text/plain",
            name="test.txt",
            thread_id=None,
        )
        await server.store.save_attachment(late_attachment, DEFAULT_CONTEXT)

        # Send a second message to the same thread that references the attachment.
        follow_up = UserMessageInput(
            content=[UserMessageTextContent(text="Second")],
            attachments=[late_attachment.id],
            inference_options=InferenceOptions(),
        )
        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=follow_up,
                )
            )
        )

        reloaded = await server.store.load_attachment(
            late_attachment.id, DEFAULT_CONTEXT
        )
        assert reloaded.thread_id == thread.id
1375
1376
async def test_create_file_without_filestore():
    """Creating an attachment on a server built without a file store raises RuntimeError."""
    with pytest.raises(RuntimeError), make_server() as server:
        create_params = AttachmentCreateParams(
            name="test-file-name",
            size=1024,
            mime_type="text/plain",
        )
        await server.process_non_streaming(AttachmentsCreateReq(params=create_params))
1389
1390
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    target_id = "test-file-id"
    file_store = InMemoryFileStore()
    file_store.files[target_id] = {"id": "test-file-id"}

    with make_server(file_store=file_store) as server:
        delete_req = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=target_id)
        )
        await server.process_non_streaming(delete_req)
        assert file_store.files == {}
1401
1402
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server built without a file store raises RuntimeError."""
    with pytest.raises(RuntimeError), make_server() as server:
        delete_params = AttachmentDeleteParams(attachment_id="test-file-id")
        await server.process_non_streaming(AttachmentsDeleteReq(params=delete_params))
1411
1412
async def test_list_threads_paginated_response():
    """Listing 25 threads with limit=20 yields a full first page and a 5-item second page."""
    total_threads = 25

    with make_server() as server:
        for index in range(total_threads):
            message = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=message))
            )

        # First page: capped at the requested limit, with more to come.
        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            first_result.json
        )
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        last_seen_id = first_page.data[-1].id
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=last_seen_id))
        )
        second_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            second_result.json
        )
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1441
1442
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        update_req = ThreadsUpdateReq(
            params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
        )
        update_result = await server.process_non_streaming(update_req)
        renamed = TypeAdapter[Thread](Thread).validate_json(update_result.json)
        assert renamed.title == "New Title"
1466
1467
async def test_returns_widget_item_generator():
    """stream_widget turns a growing text widget into added, text-delta and done events."""
    full_text = "Hello, world"
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(i: int) -> Card:
        # Show only the first `i` characters of the text.
        return Card(children=[Text(id="text", value=full_text[:i])])

    async def widget_generator():
        for cut in (0, 3, 6, 12):
            yield render_widget(cut)

    events = []
    async for event in stream_widget(thread, widget_generator()):
        events.append(event)

    assert len(events) == 5
    added, *deltas, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    # Each later render only appends text, so it is reported as a delta.
    for update_event, expected_delta in zip(deltas, ("Hel", "lo,", " world")):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1509
1510
async def test_returns_widget_item_generator_full_replace():
    """When the text is replaced, not extended, stream_widget emits a root update."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card_with(text: str) -> Card:
        return Card(children=[Text(id="text", value=text)])

    async def widget_generator():
        yield card_with("Hello")
        yield card_with("World")

    events = []
    async for event in stream_widget(thread, widget_generator()):
        events.append(event)

    assert len(events) == 3
    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card_with("Hello")

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == card_with("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card_with("World")
1534
1535
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and makes later lookups fail.

    After the delete request succeeds, fetching the same thread by id must
    raise ``NotFoundError`` with a message that names the thread.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # pytest.raises fails the test by itself when nothing is raised and
        # names the exception type that was actually expected.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1567
1568
async def test_thread_item_removed_event_removes_item():
    """An item that the responder adds and then removes is streamed as removed and not stored."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger removal")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The removal must be visible in the event stream...
        streamed_removals = [
            e.item_id for e in events if e.type == "thread.item.removed"
        ]
        assert removed_id in streamed_removals

        # ...and the item must be gone from the stored thread items.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        stored_ids = [entry.id for entry in page.data]
        assert removed_id not in stored_ids
1613
1614
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder surfaces as a stream error event.

    The stream must contain an ``error`` event with ``ErrorCode.STREAM_ERROR``
    and must be exhausted right after that event.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        # Unreachable, but the bare yield makes this function an async generator.
        yield

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        stream = result.__aiter__()

        # Yield an error
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Then finish: nothing may follow the error event.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1656
1657
async def test_thread_item_replaced_event_replaces_item():
    """A replaced event swaps the stored item for the replacement carrying the same id."""
    original_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the original item first.
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Replace it with an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[AssistantMessageContent(text=replacement_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger replacement")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The stream must carry a replaced event for our item.
        replaced_events = [
            e
            for e in events
            if e.type == "thread.item.replaced" and e.item.id == original_id
        ]
        assert replaced_events

        # The store must now hold the replacement under the original id.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        )

        matches = [entry for entry in stored_items if entry.id == original_id]
        stored = matches[0] if matches else None
        assert stored is not None
        assert isinstance(stored, AssistantMessageItem)
        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1720
1721
async def test_thread_item_replaced_event_with_widget():
    """A replaced event for a widget item updates both the stream and the stored widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the original widget, then a replacement that reuses its id.
        first_version = WidgetItem(
            id=widget_id,
            widget=original_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=first_version)

        second_version = WidgetItem(
            id=widget_id,
            widget=replacement_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=second_version)

    with make_server(responder) as server:
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[
                        UserMessageTextContent(text="Trigger widget replacement")
                    ],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The stream must contain the replacement event with the new widget.
        matching_events = [
            e
            for e in events
            if e.type == "thread.item.replaced" and e.item.id == widget_id
        ]
        replacement_event = matching_events[0] if matching_events else None
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The store must hold the replaced widget under the same id.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        )

        stored_matches = [entry for entry in stored_items if entry.id == widget_id]
        stored = stored_matches[0] if stored_matches else None
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        assert isinstance(stored.widget, Card)
        first_child = stored.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1795
1796
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the stored reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first call answers normally; any later call is the retry.
        if len(responder_calls) == 1:
            reply_id, reply_text = "assistant_msg_1", "First response"
        else:
            reply_id, reply_text = "assistant_msg_2", "Retried response"

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=reply_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        async def list_items():
            # Fetch and parse the thread's current items.
            listed = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(listed.json)

        items_before = await list_items()
        assert len(items_before.data) == 2
        assistant_before, user_before = items_before.data
        assert assistant_before.type == "assistant_message"
        assert assistant_before.content[0].type == "output_text"
        assert assistant_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry the response that followed the user message.
        retry_req = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=user_before.id
            )
        )
        retry_events = await server.process_streaming(retry_req)

        # The retry stream carries stream options plus the regenerated reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder calls received the same user message.
        assert len(responder_calls) == 2
        original_input, retried_input = responder_calls
        assert original_input == retried_input
        assert original_input.type == "user_message"
        assert original_input.content[0].text == "Hello, world!"

        # Storage now holds the regenerated reply in place of the first one.
        items_after = await list_items()
        assert len(items_after.data) == 2  # Still user message + assistant message
        assistant_after, user_after = items_after.data
        assert assistant_after.type == "assistant_message"
        assert assistant_after.content[0].type == "output_text"
        assert assistant_after.content[0].text == "Retried response"  # New response
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1894
1895
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_req)
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assistant_item_id = page.data[0].id

        # Retrying after the assistant message must be rejected.
        bad_retry = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(bad_retry)
1942
1943
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the latest reply.

    The thread holds two user messages, each with an assistant reply. After
    retrying after the second user message, the first exchange is unchanged
    and the newest assistant item has a different id than before.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The call count makes each reply id unique, including the retry's.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Verify we have 4 items. The assertions below rely on the listing
        # being newest first: assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second user message; retry after it.
        second_user_message_id = items_before.data[1].id

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items
        # First user message and response should remain unchanged
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"
        # Second user message should remain, but assistant response should be new
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
2048
2049
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with a thread-create request reaches the responder's input."""
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_create",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_req = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello with tools")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        events = await server.process_streaming(create_req)

        # Sanity check: the stream contains at least one completed item.
        done_events = [e for e in events if e.type == "thread.item.done"]
        assert done_events
2083
2084
async def test_input_transcribe_decodes_base64_and_normalizes_mime_type():
    """The transcribe callback gets decoded audio bytes, the raw MIME type and the bare media type."""
    raw_audio = b"hello audio"
    encoded_audio = base64.b64encode(raw_audio).decode("ascii")
    seen: dict[str, Any] = {}

    def transcribe_callback(
        audio_input: AudioInput, context: Any
    ) -> TranscriptionResult:
        # Capture everything the server hands to the callback.
        seen.update(
            audio=audio_input.data,
            mime=audio_input.mime_type,
            media_type=audio_input.media_type,
            context=context,
        )
        return TranscriptionResult(text="ok")

    with make_server(transcribe_callback=transcribe_callback) as server:
        transcribe_req = InputTranscribeReq(
            params=InputTranscribeParams(
                audio_base64=encoded_audio,
                mime_type="audio/webm;codecs=opus",
            )
        )
        result = await server.process_non_streaming(transcribe_req)
        transcription = TypeAdapter(TranscriptionResult).validate_json(result.json)
        assert transcription.text == "ok"

        assert seen["audio"] == raw_audio
        assert seen["mime"] == "audio/webm;codecs=opus"
        assert seen["media_type"] == "audio/webm"
        assert seen["context"] == DEFAULT_CONTEXT
2115
2116
async def test_retry_after_item_passes_tools_to_responder():
    """Placeholder test: it makes no assertions and always passes.

    NOTE(review): judging by the name, this is presumably meant to cover tool
    choice forwarding for the retry-after-item request, like the create and
    add-message tests in this file — confirm and implement.
    """
    return None
2119
2120
async def test_threads_add_message_passes_tools_to_responder():
    """Verify that `threads.add_user_message` forwards the tool choice to the responder.

    A thread is created first, then a second user message is added to it. The
    responder records every input it is called with; the tool-choice
    assertions run in the test body so they stay effective even if the server
    converts exceptions raised inside the responder into stream events instead
    of propagating them (NOTE(review): server error handling is not visible
    from this test — confirm).
    """
    tool_choice = ToolChoice(id="retrieval")
    # Inputs observed by the responder, in call order.
    received_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Record first, so the call is captured before anything is streamed.
        received_inputs.append(input)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_add",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")
        calls_after_create = len(received_inputs)

        # Add a message with tools and verify they reach the responder.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )

        # A "thread.item.done" event is also emitted for the user message, so
        # the event check alone does not prove the responder ran for the
        # add-message request. Require a new recorded call instead.
        assert len(received_inputs) > calls_after_create, (
            "responder was not invoked for the added user message"
        )
        for received in received_inputs:
            assert received is not None
            assert received.inference_options.tool_choice is not None
            assert received.inference_options.tool_choice.id == tool_choice.id

        # Sanity check that the request flowed through.
        assert any(e.type == "thread.item.done" for e in events)
2170
2171
async def test_custom_id_generators_on_server():
    """Check that the server uses the store's id generators for threads and items.

    The server's store is replaced with a subclass whose generators number ids
    from one shared counter. Creating a thread should then produce the thread
    id from the first counter value and the user message id from the second.
    """

    class UniqueIdStore(SQLiteStore):
        """SQLiteStore variant whose thread and item ids share one counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def _next_custom_number(self) -> int:
            # Thread ids and item ids draw from the same sequence.
            self.counter += 1
            return self.counter

        def generate_thread_id(self, context) -> str:
            return f"thr_custom_{self._next_custom_number()}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            return f"{item_type}_custom_{self._next_custom_number()}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits no events; the unreachable `yield` only makes this function an
        # async generator.
        if False:
            yield

    with make_server(responder) as server:
        # Swap in the custom id-generating store, reusing the same DB path.
        original_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(original_store.db_path)

        # Create a thread through the JSON entry point.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        result = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # First generated id: the thread.
        created_event = next(e for e in events if e.type == "thread.created")
        assert created_event.thread.id == "thr_custom_1"

        # Second generated id: the user message, which embeds the thread id.
        finished_user_items = (
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        user_message = next(finished_user_items)
        assert user_message.id == "message_custom_2_thr_custom_1"
2228