openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2d5dc36d5cfc388541118dffa6451d9b37db52ef

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2075lines · modecode

1import asyncio
2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14 ChatKitServer,
15 NonStreamingResult,
16 StreamingResult,
17 stream_widget,
18)
19from chatkit.store import (
20 AttachmentStore,
21 NotFoundError,
22 StoreItemType,
23 default_generate_id,
24)
25from chatkit.types import (
26 AssistantMessageContent,
27 AssistantMessageContentPartTextDelta,
28 AssistantMessageItem,
29 Attachment,
30 AttachmentCreateParams,
31 AttachmentDeleteParams,
32 AttachmentsCreateReq,
33 AttachmentsDeleteReq,
34 AttachmentUploadDescriptor,
35 ClientToolCallItem,
36 CustomTask,
37 FeedbackKind,
38 FileAttachment,
39 ImageAttachment,
40 InferenceOptions,
41 ItemFeedbackParams,
42 ItemsFeedbackReq,
43 ItemsListParams,
44 ItemsListReq,
45 LockedStatus,
46 Page,
47 ProgressUpdateEvent,
48 Thread,
49 ThreadAddClientToolOutputParams,
50 ThreadAddUserMessageParams,
51 ThreadCreatedEvent,
52 ThreadCreateParams,
53 ThreadCustomActionParams,
54 ThreadDeleteParams,
55 ThreadGetByIdParams,
56 ThreadItem,
57 ThreadItemAddedEvent,
58 ThreadItemDoneEvent,
59 ThreadItemRemovedEvent,
60 ThreadItemReplacedEvent,
61 ThreadItemUpdatedEvent,
62 ThreadListParams,
63 ThreadMetadata,
64 ThreadRetryAfterItemParams,
65 ThreadsAddClientToolOutputReq,
66 ThreadsAddUserMessageReq,
67 ThreadsCreateReq,
68 ThreadsCustomActionReq,
69 ThreadsDeleteReq,
70 ThreadsGetByIdReq,
71 ThreadsListReq,
72 ThreadsRetryAfterItemReq,
73 ThreadStreamEvent,
74 ThreadsUpdateReq,
75 ThreadUpdatedEvent,
76 ThreadUpdateParams,
77 ToolChoice,
78 UserMessageInput,
79 UserMessageItem,
80 UserMessageTextContent,
81 WidgetItem,
82 WidgetRootUpdated,
83 Workflow,
84 WorkflowItem,
85 WorkflowTaskAdded,
86)
87from chatkit.widgets import Card, Text
88from tests._types import RequestContext
89from tests.test_store import make_thread_items
90
# Counter read and incremented by make_server() so that every server gets its
# own SQLite database name.
server_id = 0


# Request context the server test helpers fall back to when a test passes none.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
95
96
class InMemoryFileStore(AttachmentStore):
    """Attachment store test double that keeps every attachment in a dict.

    ``files`` maps attachment id -> attachment object and is inspected
    directly by the tests.
    """

    def __init__(self):
        # attachment id -> Attachment
        self.files = {}
        # Monotonic counter for id generation. Deriving the id from
        # ``len(self.files)`` would hand out an id that is still in use once
        # any attachment has been deleted, silently overwriting it.
        self._next_attachment_number = 1

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from the store.

        Raises:
            KeyError: if no attachment with that id is stored.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return an attachment described by ``input``.

        An ``ImageAttachment`` (with a preview URL) is produced when the mime
        type starts with ``image/``; otherwise a ``FileAttachment``. Both carry
        the same upload descriptor shape and ``{"source": "test"}`` metadata.
        """
        attachment_id = self._allocate_attachment_id()
        metadata = {"source": "test"}
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        self.files[attachment.id] = attachment
        return attachment

    def _allocate_attachment_id(self) -> str:
        """Return the next ``atc_<n>`` id that is not currently stored."""
        while True:
            candidate = f"atc_{self._next_attachment_number}"
            self._next_attachment_number += 1
            # Also skip ids a test may have inserted into ``files`` by hand.
            if candidate not in self.files:
                return candidate
137
138
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one streamed ``data: <json>`` frame into a ``ThreadStreamEvent``.

    Raises:
        IndexError: if ``event`` does not contain the ``data: `` marker.
    """
    # Split on the first marker only: the JSON payload may itself contain the
    # bytes "data: " (e.g. inside message text), and an unbounded split would
    # truncate the payload at that point.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
141
142
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every frame in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
147
148
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` test subclass wired to the given callbacks.

    Each call uses its own named in-memory SQLite database (the name comes
    from the module-level ``server_id`` counter).

    Args:
        responder: async generator used for ``respond``; when omitted the
            server responds with an empty stream.
        handle_feedback: synchronous callback invoked from ``add_feedback``;
            when omitted feedback is ignored.
        action_callback: async generator used for ``action``; when omitted
            ``action`` raises ``ValueError``.
        file_store: attachment store passed through to ``ChatKitServer``.
    """
    global server_id
    db_uri = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Hold one connection open for the lifetime of the context so the shared
    # in-memory database is not deleted when the last other connection closes.
    keepalive_connection = sqlite3.connect(db_uri, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_uri), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process a request and return its decoded stream events."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process a request and return its ``NonStreamingResult``."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        keepalive_connection.close()
229
230
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream stores the partial assistant text and a hidden-context note."""
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(item_id=pending_id, update=text_delta)

        # Simulate the stream being cancelled mid-response.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store
        fallback_generate_item_id = server.store.generate_item_id

        def patched_generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = patched_generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        # Collect with an explicit loop so events received before the
        # cancellation are kept.
        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for frame in stream.json_events:
                received.append(decode_event(frame))

        thread = next(ev.thread for ev in received if ev.type == "thread.created")
        page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        expected_note = (
            "The user cancelled the stream. Stop responding to the prior request."
        )
        assert hidden_context_item.content == expected_note

        stored_message = await server.store.load_item(
            thread.id, pending_id, DEFAULT_CONTEXT
        )
        assert stored_message.type == "assistant_message"
        assert stored_message.content[0].text == "Hello, World!"
300
301
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream stores the hidden-context note but not an empty assistant message."""
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)

        # Simulate the stream being cancelled before any content arrives.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        fallback_generate_item_id = server.store.generate_item_id

        def patched_generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = patched_generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        # Collect with an explicit loop so events received before the
        # cancellation are kept.
        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for frame in stream.json_events:
                received.append(decode_event(frame))

        thread = next(ev.thread for ev in received if ev.type == "thread.created")
        page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        expected_note = (
            "The user cancelled the stream. Stop responding to the prior request."
        )
        assert hidden_context_item.content == expected_note

        # The empty pending message must not have been written to the store.
        with pytest.raises(NotFoundError):
            await server.store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
362
363
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying a pending item update must not mutate the
    original item object that the responder streamed.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        streamed_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=streamed_item)

        new_task = CustomTask(title="foo", content="bar")
        yield ThreadItemUpdatedEvent(
            item_id=streamed_item.id,
            update=WorkflowTaskAdded(task=new_task, task_index=0),
        )

        # The responder appends the task itself before emitting the done
        # event; the final item must still contain exactly one task.
        streamed_item.workflow.tasks.append(new_task)
        yield ThreadItemDoneEvent(item=streamed_item)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        done_event = next(
            ev
            for ev in events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WorkflowItem)
        )
        done_item = cast(WorkflowItem, done_event.item)
        done_tasks = done_item.workflow.tasks
        assert len(done_tasks) == 1
        assert done_tasks[0].title == "foo"

        stored = await server.store.load_item(thread.id, done_item.id, DEFAULT_CONTEXT)
        assert isinstance(stored, WorkflowItem)
        stored_tasks = cast(WorkflowItem, stored).workflow.tasks
        assert len(stored_tasks) == 1
        assert stored_tasks[0].title == "foo"
421
422
async def test_flows_context_to_responder():
    """The context given to the server reaches both the responder and add_feedback."""
    # Records the context each callback was invoked with.
    seen_contexts: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_contexts["responder"] = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen_contexts["add_feedback"] = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, add_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request, context)
        assert seen_contexts.get("responder") == context

        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        assistant_item = next(
            ev.item
            for ev in events
            if ev.type == "thread.item.done" and ev.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen_contexts.get("add_feedback") == context
482
483
async def test_creates_thread():
    """Creating a thread emits thread.created and a done event for the user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    quoted_text="Quoted text",
                )
            )
        )
        events = await server.process_streaming(create_request)

        created = next(ev for ev in events if ev.type == "thread.created")
        assert created.thread.title is None

        first_done_item = next(
            ev.item for ev in events if ev.type == "thread.item.done"
        )
        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
517
518
async def test_saves_thread_title():
    """A title set by the responder is persisted and reported via thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        # Read back with the same context the request was processed under
        # (process_streaming falls back to DEFAULT_CONTEXT) instead of an
        # inline copy that could drift from it.
        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Updated title"
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.title == "Updated title"
549
550
async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted and a thread.updated event ends the stream."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
578
579
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and reported via thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        # Read back with the same context the request was processed under
        # (process_streaming falls back to DEFAULT_CONTEXT) instead of an
        # inline copy that could drift from it.
        loaded = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded.status == LockedStatus(reason="Because")
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.status == LockedStatus(reason="Because")
609
610
async def test_emits_thread_updated_mid_stream_and_persists():
    """A mid-stream title change yields thread.updated after the next item, then streaming continues."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        # First assistant message
        yield assistant_done("assistant_a", "A")

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done("assistant_b", "B")

        # Continue streaming after the update event
        yield assistant_done("assistant_c", "C")

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Find the created thread
        thread = next(ev.thread for ev in events if isinstance(ev, ThreadCreatedEvent))

        # Ensure a thread.updated event occurs right after assistant_b
        position_of_b = next(
            index
            for index, ev in enumerate(events)
            if isinstance(ev, ThreadItemDoneEvent)
            and isinstance(ev.item, AssistantMessageItem)
            and ev.item.id == "assistant_b"
        )
        following = events[position_of_b + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        assert cast(ThreadUpdatedEvent, following).thread.title == "Mid-stream title"

        # Streaming continues after the update event
        after_update = events[position_of_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        resumed_item = cast(ThreadItemDoneEvent, after_update).item
        assert isinstance(resumed_item, AssistantMessageItem)
        assert resumed_item.id == "assistant_c"

        # Persisted in the store
        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"
685
686
async def test_respond_with_tool_call():
    """A client tool call is streamed first; submitting its output triggers the follow-up reply."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if isinstance(input, UserMessageItem):
            # Turn started by a user message: ask the client to run a tool.
            tool_call = ClientToolCallItem(
                id="msg_1",
                created_at=datetime.now(),
                name="tool_call_1",
                arguments={"arg1": "val1", "arg2": False},
                call_id="tool_call_1",
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=tool_call)
        elif input is None:
            # Turn with no user message (here: after the tool output was added).
            reply = AssistantMessageItem(
                id="msg_2",
                content=[AssistantMessageContent(text="Glad the tool call succeeded!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        assert [ev.type for ev in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "thread.item.done",
        ]
        thread = events[0].thread
        assert events[1].item.type == "user_message"

        tool_item = events[3].item
        assert tool_item.type == "client_tool_call"
        assert tool_item.id == "msg_1"
        assert tool_item.name == "tool_call_1"
        assert tool_item.arguments == {"arg1": "val1", "arg2": False}
        assert tool_item.call_id == "tool_call_1"

        output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(output_request)

        assert [ev.type for ev in events] == ["stream_options", "thread.item.done"]
        assert events[1].item.type == "assistant_message"
756
757
async def test_respond_with_tool_status():
    """A progress update from the responder is streamed ahead of the assistant message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Exactly five events, in this order.
        assert [ev.type for ev in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
791
792
async def test_list_threads_response():
    """Listing 25 threads returns a page of 20 (newest first) and then the remaining 5."""
    with make_server() as server:
        for index in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        threads = page_adapter.validate_json(first_result.json)
        assert len(threads.data) == 20

        # Newest first
        newest, second_newest = threads.data[0], threads.data[1]
        assert newest.created_at > second_newest.created_at
        assert threads.has_more is True
        assert threads.after is not None

        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=threads.data[-1].id))
        )
        threads = page_adapter.validate_json(second_result.json)
        assert len(threads.data) == 5
        assert threads.has_more is False
        assert not threads.after
825
826
async def test_list_items_response():
    """Listing a thread's items returns 20 (newest first), then the remaining 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for index in range(25):
            message = UserMessageItem(
                id=f"msg_{index}",
                content=[UserMessageTextContent(text=f"Message {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Thread 1")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        page_adapter = TypeAdapter(Page[ThreadItem])

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = page_adapter.validate_json(first_result.json)
        assert len(items.data) == 20
        # Newest first
        newest, second_newest = items.data[0], items.data[1]
        assert newest.created_at > second_newest.created_at
        assert items.has_more is True
        assert items.after is not None

        second_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(thread_id=thread.id, after=items.data[-1].id)
            )
        )
        items = page_adapter.validate_json(second_result.json)
        assert len(items.data) == 6  # +1 for user message
        assert items.has_more is False
876
877
async def test_calls_action():
    """A custom action request reaches the action callback with the widget that sent it."""
    # (action, sender) pairs recorded by the action callback.
    recorded = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(
                children=[
                    Text(
                        key="text_1",
                        value="test",
                    )
                ],
            ),
        )
        yield ThreadItemDoneEvent(item=widget)

    async def on_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        recorded.append((action, sender))
        assert sender

        replacement = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    with make_server(responder, action_callback=on_action) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Show widget")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        widget_item = next(
            ev.item
            for ev in events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert recorded
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert recorded[0] == expected_call

        assert len(events) == 2
        assert events[0].type == "stream_options"
        update_event = events[1]
        assert update_event.type == "thread.item.updated"
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(children=[Text(value="Email sent!")])
967
968
async def test_add_feedback():
    """Feedback requests are forwarded to the handler with thread id, item ids and kind."""
    # (thread_id, item_ids, feedback, context) tuples recorded by the handler.
    called = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        called.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_message = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_message)

        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget)

    with make_server(responder, handle_feedback=handle_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test feedback")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        positive_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["assistant_msg_1"],
                kind="positive",
            )
        )
        await server.process_non_streaming(positive_request)

        negative_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["widget_1"],
                kind="negative",
            )
        )
        await server.process_non_streaming(negative_request)

        assert len(called) == 2

        first_thread_id, first_item_ids, first_kind, _ = called[0]
        assert first_thread_id == thread.id
        assert first_item_ids == ["assistant_msg_1"]
        assert first_kind == "positive"

        second_thread_id, second_item_ids, second_kind, _ = called[1]
        assert second_thread_id == thread.id
        assert second_item_ids == ["widget_1"]
        assert second_kind == "negative"
1043
1044
async def test_get_thread_by_id():
    """Fetching a thread by id returns it together with its first user message."""
    # NOTE(review): the return value is discarded; this looks like a leftover
    # call — confirm make_thread_items() has no needed side effect before removing.
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test thread")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Retrieve the thread by id
        get_request = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        result = await server.process_non_streaming(get_request)
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)
        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title

        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1076
1077
async def test_create_file():
    """Creating a text/plain attachment returns a file attachment and stores it."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=file_name, size=1024, mime_type=file_content_type
            )
        )
        result = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in store.files

        # The `metadata` field is excluded from serialization by default but still stored server-side
        assert attachment.metadata is None
        assert store.files[attachment.id].metadata == {"source": "test"}
1109
1110
async def test_create_image_file():
    """Creating a PNG attachment yields an "image" attachment with a preview URL."""
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime_type = "image/png"

    with make_server(file_store=file_store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=expected_name,
                size=1024,
                mime_type=expected_mime_type,
            )
        )
        response = await server.process_non_streaming(create_request)
        created = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert created.id is not None
        assert created.mime_type == expected_mime_type
        assert created.name == expected_name
        assert created.type == "image"
        assert created.preview_url == AnyUrl(
            f"https://example.com/{created.id}/preview"
        )

        upload = created.upload_descriptor
        assert upload is not None
        assert upload.url == AnyUrl(f"https://example.com/{created.id}/upload")
        assert upload.method == "PUT"
        assert upload.headers == {"X-My-Header": "my-value"}

        # The file store must have recorded the new attachment as well.
        assert created.id in file_store.files
1143
1144
async def test_user_message_attachment_metadata_not_serialized():
    """Attachment metadata must not appear on user messages returned by items.list."""
    stored_attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(stored_attachment, DEFAULT_CONTEXT)

        # Start a thread whose first message references the saved attachment.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[stored_attachment.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        list_request = ItemsListReq(
            params=ItemsListParams(thread_id=created_thread.id)
        )
        response = await server.process_non_streaming(list_request)
        page = TypeAdapter(Page[ThreadItem]).validate_json(response.json)

        user_messages = [
            entry for entry in page.data if entry.type == "user_message"
        ]
        assert user_messages
        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        assert not any(att.metadata is not None for att in listed_attachments)
1180
1181
async def test_user_message_attachment_thread_id_persisted_on_thread_create():
    """An attachment saved without a thread gets the new thread's id on create."""
    pending_attachment = FileAttachment(
        id="file_thread_id_on_create",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        thread_id=None,
    )

    with make_server() as server:
        # The attachment is saved before any thread exists, hence thread_id=None.
        await server.store.save_attachment(pending_attachment, DEFAULT_CONTEXT)

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[pending_attachment.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # Once the user message has been processed, the stored attachment
        # record should point at the thread it was attached to.
        reloaded = await server.store.load_attachment(
            pending_attachment.id, DEFAULT_CONTEXT
        )
        assert reloaded.thread_id == created_thread.id
1212
1213
async def test_user_message_attachment_thread_id_persisted_on_add_user_message():
    """An attachment added to an existing thread gets that thread's id."""
    with make_server() as server:
        # Step 1: open a thread with a message that has no attachments.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Start")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        existing_thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # Step 2: save an attachment that is not yet tied to any thread.
        pending_attachment = FileAttachment(
            id="file_thread_id_on_add",
            type="file",
            mime_type="text/plain",
            name="test.txt",
            thread_id=None,
        )
        await server.store.save_attachment(pending_attachment, DEFAULT_CONTEXT)

        # Step 3: reference it from a follow-up message in the existing thread.
        add_message_request = ThreadsAddUserMessageReq(
            params=ThreadAddUserMessageParams(
                thread_id=existing_thread.id,
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Second")],
                    attachments=[pending_attachment.id],
                    inference_options=InferenceOptions(),
                ),
            )
        )
        await server.process_streaming(add_message_request)

        reloaded = await server.store.load_attachment(
            pending_attachment.id, DEFAULT_CONTEXT
        )
        assert reloaded.thread_id == existing_thread.id
1255
1256
async def test_create_file_without_filestore():
    """Creating an attachment on a server built without a file store raises.

    `pytest.raises` wraps only the request, so a RuntimeError raised while
    setting up or tearing down the server cannot make this test pass.
    """
    with make_server() as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name="test-file-name",
                size=1024,
                mime_type="text/plain",
            )
        )
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(create_request)
1269
1270
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    attachment_id = "test-file-id"
    file_store = InMemoryFileStore()
    # Seed the store directly with a minimal record for the attachment.
    file_store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=file_store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=attachment_id)
        )
        await server.process_non_streaming(delete_request)
        assert file_store.files == {}
1281
1282
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server built without a file store raises.

    `pytest.raises` wraps only the request, so a RuntimeError raised while
    setting up or tearing down the server cannot make this test pass.
    """
    with make_server() as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id="test-file-id")
        )
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(delete_request)
1291
1292
async def test_list_threads_paginated_response():
    """Listing 25 threads returns a page of 20 and then a page of 5."""
    with make_server() as server:
        # Populate the store with 25 threads.
        for index in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {index}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        # First page: capped at the requested limit, with more to come.
        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1321
1322
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        update_request = ThreadsUpdateReq(
            params=ThreadUpdateParams(thread_id=created_thread.id, title="New Title")
        )
        update_response = await server.process_non_streaming(update_request)
        renamed_thread = TypeAdapter[Thread](Thread).validate_json(
            update_response.json
        )
        assert renamed_thread.title == "New Title"
1346
1347
async def test_returns_widget_item_generator():
    """stream_widget turns a growing text widget into added/delta/done events."""
    thread_meta = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    full_text = "Hello, world"

    def card_with_prefix(length: int) -> Card:
        # Card whose single Text child shows the first `length` characters.
        return Card(children=[Text(id="text", value=full_text[:length])])

    async def growing_widgets():
        for length in (0, 3, 6, 12):
            yield card_with_prefix(length)

    events = [ev async for ev in stream_widget(thread_meta, growing_widgets())]

    assert len(events) == 5
    added_event, *delta_events, done_event = events

    # The first widget is announced as a new item with empty text.
    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="")])

    # Each later widget is reported as only the text appended since the last one.
    for delta_event, expected_delta in zip(delta_events, ("Hel", "lo,", " world")):
        assert isinstance(delta_event, ThreadItemUpdatedEvent)
        assert delta_event.update.type == "widget.streaming_text.value_delta"
        assert delta_event.update.component_id == "text"
        assert delta_event.update.delta == expected_delta

    # The closing event carries the complete widget.
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1389
1390
async def test_returns_widget_item_generator_full_replace():
    """When the text is replaced rather than extended, the root widget is resent."""
    thread_meta = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card_with_text(value: str) -> Card:
        return Card(children=[Text(id="text", value=value)])

    async def replaced_widgets():
        for value in ("Hello", "World"):
            yield card_with_text(value)

    events = [ev async for ev in stream_widget(thread_meta, replaced_widgets())]

    assert len(events) == 3
    added_event, updated_event, done_event = events

    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="Hello")])

    # "World" does not extend "Hello", so a full root update is expected.
    assert isinstance(updated_event, ThreadItemUpdatedEvent)
    assert updated_event.update.type == "widget.root.updated"
    assert updated_event.update.widget == Card(
        children=[Text(id="text", value="World")]
    )

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(children=[Text(id="text", value="World")])
1414
1415
async def test_delete_thread():
    """Deleting a thread returns `{}` and makes later lookups raise NotFoundError."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Looking the thread up again must fail with NotFoundError. Using
        # pytest.raises reports a missing exception with the right type name.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1447
1448
async def test_thread_item_removed_event_removes_item():
    """An item that the responder adds and then removes is not left in the store."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a finished item, then immediately ask for its removal.
        short_lived_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=short_lived_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger removal")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The removal event must be forwarded to the client.
        assert any(
            ev.type == "thread.item.removed" and ev.item_id == removed_id
            for ev in stream_events
        )

        # The removed item must no longer be listed for the thread.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created_thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)
        assert removed_id not in [entry.id for entry in page.data]
1493
1494
async def test_raising_in_responder_yields_error_event():
    """A responder exception surfaces as a STREAM_ERROR event, then the stream ends."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        # Unreachable; its presence makes this function an async generator.
        yield

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        stream = result.__aiter__()

        # Read events until the error event shows up.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Nothing may follow the error event: the stream must be exhausted.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1536
1537
async def test_thread_item_replaced_event_replaces_item():
    """A replaced event overwrites the stored item that has the same id."""
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a user message first ...
        initial_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=initial_item)

        # ... then swap it for an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger replacement")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The replacement event must be forwarded to the client.
        assert any(
            ev.type == "thread.item.replaced" and ev.item.id == original_id
            for ev in stream_events
        )

        # The stored item with that id must now be the assistant message.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created_thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_response.json).data
        )
        stored = next(
            (entry for entry in stored_items if entry.id == original_id), None
        )
        assert stored is not None
        assert isinstance(stored, AssistantMessageItem)
        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1600
1601
async def test_thread_item_replaced_event_with_widget():
    """A replaced event carrying a widget item overwrites the stored widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def widget_item(widget: Card, thread: ThreadMetadata) -> WidgetItem:
        # Both items share `widget_id`, so the second one replaces the first.
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(item=widget_item(original_widget, thread))
        yield ThreadItemReplacedEvent(item=widget_item(replacement_widget, thread))

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[
                        UserMessageTextContent(text="Trigger widget replacement")
                    ],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The replacement event must be streamed with the new widget.
        matching_replacements = (
            ev
            for ev in stream_events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        )
        replacement_event = next(matching_replacements, None)
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The stored widget item must hold the replacement content.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created_thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_response.json).data
        )
        stored = next((entry for entry in stored_items if entry.id == widget_id), None)
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        assert isinstance(stored.widget, Card)
        first_child = stored.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1675
1676
async def test_retry_after_item():
    """Retrying after a user message regenerates the assistant reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first call answers normally; any later call is the retry.
        is_first_call = len(responder_calls) == 1
        reply_id = "assistant_msg_1" if is_first_call else "assistant_msg_2"
        reply_text = "First response" if is_first_call else "Retried response"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=reply_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        items_adapter = TypeAdapter(Page[ThreadItem])
        list_request = ItemsListReq(
            params=ItemsListParams(thread_id=created_thread.id)
        )

        # Before the retry the listing holds the assistant reply followed by
        # the user message.
        before_response = await server.process_non_streaming(list_request)
        items_before = items_adapter.validate_json(before_response.json)
        assert len(items_before.data) == 2
        reply_before, user_before = items_before.data
        assert reply_before.type == "assistant_message"
        assert reply_before.content[0].type == "output_text"
        assert reply_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Ask the server to retry from the user message.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=created_thread.id, item_id=user_before.id
            )
        )
        retry_events = await server.process_streaming(retry_request)

        # The retry stream holds stream options plus the regenerated reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder invocations received the same user message.
        assert len(responder_calls) == 2
        assert responder_calls[0] == responder_calls[1]
        assert responder_calls[0].type == "user_message"
        assert responder_calls[0].content[0].text == "Hello, world!"

        # The stored reply has been swapped for the retried one.
        after_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created_thread.id))
        )
        items_after = items_adapter.validate_json(after_response.json)
        assert len(items_after.data) == 2  # Still user message + assistant message
        reply_after, user_after = items_after.data
        assert reply_after.type == "assistant_message"
        assert reply_after.content[0].type == "output_text"
        assert reply_after.content[0].text == "Retried response"  # New response
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1774
1775
async def test_retry_after_item_invalid_item():
    """Retrying after a non-user item (an assistant message) raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        created_thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # The first listed item is used as the (invalid) retry target.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created_thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)
        assistant_item_id = page.data[0].id

        # A retry anchored on an assistant message must be rejected.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=created_thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1822
1823
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only its reply.

    The first exchange (user message plus assistant reply) must be left
    untouched, and the regenerated reply must have a new id.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The call count is part of the id, so a regenerated reply gets a new id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # The assertions below index the listing newest-first:
        # [assistant2, user2, assistant1, user1].
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second (most recent) user message.
        second_user_message_id = items_before.data[1].id

        # Retry after the second user message
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # First user message and response should remain unchanged
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # Second user message should remain, but assistant response should be new
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1928
1929
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with threads.create reaches the responder's input."""
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Check the tool choice on the user message handed to the responder.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_create",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello with tools")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        # Sanity check: at least one item completed during the request.
        assert any(ev.type == "thread.item.done" for ev in stream_events)
1963
1964
async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message hands its tool choice to the responder again.

    NOTE(review): this assumes the retried user message is reloaded from the
    store with its `inference_options` intact; confirm against the store
    implementation used by `make_server`.
    """
    tool_choice = ToolChoice(id="web_search")
    responder_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Only record the input here; the expectations are checked in the
        # test body after each request has completed.
        responder_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_tools_retry_{len(responder_inputs)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The user message given to the responder is the retry anchor.
        assert len(responder_inputs) == 1
        original_input = responder_inputs[0]
        assert original_input is not None

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=original_input.id
                )
            )
        )
        # Sanity check that the retry produced a response.
        assert any(e.type == "thread.item.done" for e in retry_events)

        # The retry must call the responder again with the same tool choice.
        assert len(responder_inputs) == 2
        retried_input = responder_inputs[1]
        assert retried_input is not None
        assert retried_input.inference_options.tool_choice is not None
        assert retried_input.inference_options.tool_choice.id == tool_choice.id
1967
1968
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent when adding a user message reaches the responder."""
    tool_choice = ToolChoice(id="retrieval")

    def message_with_tools(text: str) -> UserMessageInput:
        # Both requests in this test carry the same tool choice.
        return UserMessageInput(
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(tool_choice=tool_choice),
        )

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Check the tool choice on the user message handed to the responder.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_add",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # A thread has to exist before a message can be added to it.
        created_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(input=message_with_tools("Start"))
            )
        )
        existing_thread = next(
            ev.thread for ev in created_events if ev.type == "thread.created"
        )

        # Add a follow-up message that carries the tool choice.
        follow_up_events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=existing_thread.id,
                    input=message_with_tools("Second"),
                )
            )
        )
        # Sanity check: at least one item completed during the request.
        assert any(ev.type == "thread.item.done" for ev in follow_up_events)
2018
2019
async def test_custom_id_generators_on_server():
    """Thread and item ids come from the store's `generate_*_id` overrides."""

    class UniqueIdStore(SQLiteStore):
        """SQLiteStore that numbers every generated id from one shared counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            sequence = self.counter
            return f"thr_custom_{sequence}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            sequence = self.counter
            return f"{item_type}_custom_{sequence}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits no events; the unreachable `yield` only makes this function
        # an async generator.
        return
        yield

    with make_server(responder) as server:
        # Replace the server's store with the custom one, on the same database.
        existing_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(existing_store.db_path)

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        result = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        decoded_events = await decode_streaming_result(result)

        # The thread id is the first id drawn from the counter.
        thread_created = next(
            ev for ev in decoded_events if ev.type == "thread.created"
        )
        assert thread_created.thread.id == "thr_custom_1"

        # The user message id is the second, and embeds the thread id.
        user_message = next(
            ev.item
            for ev in decoded_events
            if ev.type == "thread.item.done" and ev.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"