openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.5.3

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2000lines · modecode

1import asyncio
2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14 ChatKitServer,
15 NonStreamingResult,
16 StreamingResult,
17 stream_widget,
18)
19from chatkit.store import (
20 AttachmentStore,
21 NotFoundError,
22 StoreItemType,
23 default_generate_id,
24)
25from chatkit.types import (
26 AssistantMessageContent,
27 AssistantMessageContentPartTextDelta,
28 AssistantMessageItem,
29 Attachment,
30 AttachmentCreateParams,
31 AttachmentDeleteParams,
32 AttachmentsCreateReq,
33 AttachmentsDeleteReq,
34 AttachmentUploadDescriptor,
35 ClientToolCallItem,
36 CustomTask,
37 FeedbackKind,
38 FileAttachment,
39 ImageAttachment,
40 InferenceOptions,
41 ItemFeedbackParams,
42 ItemsFeedbackReq,
43 ItemsListParams,
44 ItemsListReq,
45 LockedStatus,
46 Page,
47 ProgressUpdateEvent,
48 Thread,
49 ThreadAddClientToolOutputParams,
50 ThreadAddUserMessageParams,
51 ThreadCreatedEvent,
52 ThreadCreateParams,
53 ThreadCustomActionParams,
54 ThreadDeleteParams,
55 ThreadGetByIdParams,
56 ThreadItem,
57 ThreadItemAddedEvent,
58 ThreadItemDoneEvent,
59 ThreadItemRemovedEvent,
60 ThreadItemReplacedEvent,
61 ThreadItemUpdatedEvent,
62 ThreadListParams,
63 ThreadMetadata,
64 ThreadRetryAfterItemParams,
65 ThreadsAddClientToolOutputReq,
66 ThreadsAddUserMessageReq,
67 ThreadsCreateReq,
68 ThreadsCustomActionReq,
69 ThreadsDeleteReq,
70 ThreadsGetByIdReq,
71 ThreadsListReq,
72 ThreadsRetryAfterItemReq,
73 ThreadStreamEvent,
74 ThreadsUpdateReq,
75 ThreadUpdatedEvent,
76 ThreadUpdateParams,
77 ToolChoice,
78 UserMessageInput,
79 UserMessageItem,
80 UserMessageTextContent,
81 WidgetItem,
82 WidgetRootUpdated,
83 Workflow,
84 WorkflowItem,
85 WorkflowTaskAdded,
86)
87from chatkit.widgets import Card, Text
88from tests._types import RequestContext
89from tests.test_store import make_thread_items
90
# Counter incremented by make_server(); its value is embedded in the SQLite
# URI so each server built in this module gets a distinct in-memory database.
server_id = 0


# Request context used by the test server helpers when a test passes none.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
95
96
class InMemoryFileStore(AttachmentStore):
    """Dictionary-backed ``AttachmentStore`` used as a test double.

    Created attachments are kept in ``self.files`` (attachment id ->
    attachment object) so tests can inspect what the store was asked to do.
    """

    def __init__(self):
        # attachment id -> attachment returned by ``create_attachment``
        self.files = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Drop ``attachment_id`` from the store.

        Raises:
            KeyError: if no attachment with that id is stored.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, remember and return an attachment described by ``input``.

        MIME types starting with ``image/`` produce an ``ImageAttachment``
        (with a preview URL); everything else produces a ``FileAttachment``.
        Both carry the same upload descriptor and ``{"source": "test"}``
        metadata.
        """
        # Start from the historical "count + 1" id, but skip ids still in use.
        # Without the skip, create/create/delete(atc_1)/create would hand out
        # "atc_2" a second time and silently overwrite the live attachment.
        sequence = len(self.files) + 1
        while f"atc_{sequence}" in self.files:
            sequence += 1
        attachment_id = f"atc_{sequence}"

        # Identical for both attachment kinds, so build it once.
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        metadata = {"source": "test"}

        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        self.files[attachment.id] = attachment
        return attachment
137
138
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one streamed chunk of the form ``...data: <json>`` into an event.

    Everything after the first ``b"data: "`` marker is validated as a
    ``ThreadStreamEvent`` JSON document.

    Raises:
        IndexError: if ``event`` contains no ``b"data: "`` marker.
    """
    # maxsplit=1: only the first marker is the field prefix. A later
    # occurrence of "data: " inside the JSON payload (e.g. in message text)
    # must stay part of the payload instead of truncating it.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
141
142
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode each chunk, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for chunk in streaming_result.json_events:
        decoded.append(decode_event(chunk))
    return decoded
147
148
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` test subclass backed by a fresh in-memory SQLite DB.

    Args:
        responder: used as the implementation of ``respond``; when omitted the
            server responds with an empty stream.
        handle_feedback: invoked from ``add_feedback``; when omitted feedback
            is ignored.
        action_callback: used as the implementation of ``action``; when
            omitted ``action`` raises ``ValueError``.
        file_store: attachment store handed to the ``ChatKitServer`` constructor.
    """
    global server_id
    database_uri = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Hold one connection open for the lifetime of the context so the shared
    # in-memory database is not deleted when the last connection is closed.
    keepalive = sqlite3.connect(database_uri, uri=True)

    class TestChatKitServer(ChatKitServer):
        """Server whose hooks delegate to the callables given to ``make_server``."""

        def __init__(self):
            super().__init__(SQLiteStore(database_uri), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process a request that must stream and return its decoded events."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process a request that must not stream and return the raw result."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        keepalive.close()
229
230
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream keeps the partially streamed assistant text and stores a hidden-context item."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id="assistant-message-pending",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(
            item_id="assistant-message-pending",
            update=text_delta,
        )
        # Simulate the consumer cancelling after the delta was streamed.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        # Collect events one by one so those emitted before the cancellation
        # are still available afterwards.
        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for chunk in stream.json_events:
                received.append(decode_event(chunk))

        thread = next(ev.thread for ev in received if ev.type == "thread.created")
        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_item = newest_page.data[-1]
        assert hidden_item.type == "sdk_hidden_context"
        assert (
            hidden_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        persisted_message = await server.store.load_item(
            thread.id, "assistant-message-pending", DEFAULT_CONTEXT
        )
        assert persisted_message.type == "assistant_message"
        assert persisted_message.content[0].text == "Hello, World!"
300
301
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream stores the hidden-context item but not an assistant message with no content."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id="assistant-message-pending",
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        # Simulate cancellation before any content was streamed.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        # Collect incrementally so events seen before the cancellation survive.
        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for chunk in stream.json_events:
                received.append(decode_event(chunk))

        thread = next(ev.thread for ev in received if ev.type == "thread.created")
        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_item = newest_page.data[-1]
        assert hidden_item.type == "sdk_hidden_context"
        assert (
            hidden_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        with pytest.raises(NotFoundError):
            await server.store.load_item(
                thread.id, "assistant-message-pending", DEFAULT_CONTEXT
            )
362
363
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying streamed updates to a pending item must not
    mutate the original item object that the responder yielded.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # The same object is yielded in the added and done events and mutated
        # in between; that sharing is what the regression is about.
        workflow_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=workflow_item)

        custom_task = CustomTask(title="foo", content="bar")
        task_added = WorkflowTaskAdded(task=custom_task, task_index=0)
        yield ThreadItemUpdatedEvent(item_id=workflow_item.id, update=task_added)

        workflow_item.workflow.tasks.append(custom_task)
        yield ThreadItemDoneEvent(item=workflow_item)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        done_event = next(
            ev
            for ev in events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WorkflowItem)
        )
        streamed_item = cast(WorkflowItem, done_event.item)
        streamed_tasks = streamed_item.workflow.tasks
        assert len(streamed_tasks) == 1
        assert streamed_tasks[0].title == "foo"

        stored = await server.store.load_item(
            thread.id, streamed_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(stored, WorkflowItem)
        stored_tasks = cast(WorkflowItem, stored).workflow.tasks
        assert len(stored_tasks) == 1
        assert stored_tasks[0].title == "foo"
421
422
async def test_flows_context_to_responder():
    """The context passed to the server reaches both the responder and ``add_feedback``."""
    seen_by_responder = None
    seen_by_feedback = None

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        nonlocal seen_by_responder
        seen_by_responder = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        nonlocal seen_by_feedback
        seen_by_feedback = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, add_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request, context)
        assert seen_by_responder == context

        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        assistant_item = next(
            ev.item
            for ev in events
            if ev.type == "thread.item.done" and ev.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen_by_feedback == context
482
483
async def test_creates_thread():
    """Creating a thread emits an untitled ``thread.created`` and echoes the user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
            quoted_text="Quoted text",
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(ev for ev in events if ev.type == "thread.created")
        assert created.thread.title is None

        first_done_item = next(
            ev.item for ev in events if ev.type == "thread.item.done"
        )
        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
517
518
async def test_saves_thread_title():
    """A title set by the responder is persisted and reported in a final ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        persisted = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert persisted.title == "Updated title"

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.title == "Updated title"
549
550
async def test_saves_thread_metadata():
    """Metadata set by the responder is persisted and a final ``thread.updated`` is emitted."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        persisted = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert persisted.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
578
579
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and reported in ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        persisted = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert persisted.status == LockedStatus(reason="Because")

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.status == LockedStatus(reason="Because")
609
610
async def test_emits_thread_updated_mid_stream_and_persists():
    """A mid-stream title change yields ``thread.updated`` after the next item, then streaming continues."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(message_id: str, text: str) -> ThreadItemDoneEvent:
            # Built at call time so ``created_at`` is taken when the event is yielded.
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=message_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        # First assistant message
        yield assistant_done("assistant_a", "A")

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done("assistant_b", "B")

        # Continue streaming after the update event
        yield assistant_done("assistant_c", "C")

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        # Find the created thread
        thread = next(ev.thread for ev in events if isinstance(ev, ThreadCreatedEvent))

        # Ensure a thread.updated event occurs right after assistant_b
        position_b = next(
            index
            for index, ev in enumerate(events)
            if isinstance(ev, ThreadItemDoneEvent)
            and isinstance(ev.item, AssistantMessageItem)
            and ev.item.id == "assistant_b"
        )
        following = events[position_b + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        assert cast(ThreadUpdatedEvent, following).thread.title == "Mid-stream title"

        # Streaming continues after the update event
        resumed = events[position_b + 2]
        assert isinstance(resumed, ThreadItemDoneEvent)
        resumed_item = cast(ThreadItemDoneEvent, resumed).item
        assert isinstance(resumed_item, AssistantMessageItem)
        assert resumed_item.id == "assistant_c"

        # Persisted in the store
        persisted = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert persisted.title == "Mid-stream title"
685
686
async def test_respond_with_tool_call():
    """A client tool call ends the first stream; submitting its output triggers a follow-up response."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if isinstance(input, UserMessageItem):
            # Called with a user message: answer with a client tool call.
            tool_call = ClientToolCallItem(
                id="msg_1",
                created_at=datetime.now(),
                name="tool_call_1",
                arguments={"arg1": "val1", "arg2": False},
                call_id="tool_call_1",
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=tool_call)
        elif input is None:
            # Called without a user message: answer with plain assistant text.
            follow_up = AssistantMessageItem(
                id="msg_2",
                content=[
                    AssistantMessageContent(text="Glad the tool call succeeded!")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=follow_up)

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert len(events) == 4
        created, user_done, options, tool_done = events

        assert created.type == "thread.created"
        thread = created.thread

        assert user_done.type == "thread.item.done"
        assert user_done.item.type == "user_message"

        assert options.type == "stream_options"

        assert tool_done.type == "thread.item.done"
        assert tool_done.item.type == "client_tool_call"
        assert tool_done.item.id == "msg_1"
        assert tool_done.item.name == "tool_call_1"
        assert tool_done.item.arguments == {"arg1": "val1", "arg2": False}
        assert tool_done.item.call_id == "tool_call_1"

        output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(output_request)

        assert len(events) == 2
        options, assistant_done = events
        assert options.type == "stream_options"
        assert assistant_done.type == "thread.item.done"
        assert assistant_done.item.type == "assistant_message"
756
757
async def test_respond_with_tool_status():
    """Progress updates from the responder are streamed in order with the other events."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert len(events) == 5
        expected_types = (
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        )
        for position, expected_type in enumerate(expected_types):
            assert events[position].type == expected_type
791
792
async def test_list_threads_response():
    """Listing 25 threads returns a page of 20 (newest first) and a final page of 5."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            user_input = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
            )

        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20

        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
825
826
async def test_list_items_response():
    """Listing a thread's 26 items returns a page of 20 (newest first) and a final page of 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for index in range(25):
            message = UserMessageItem(
                id=f"msg_{index}",
                content=[UserMessageTextContent(text=f"Message {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    page_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Thread 1")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        first_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_response = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 6  # +1 for user message
        assert second_page.has_more is False
876
877
async def test_calls_action():
    """A custom action request invokes ``action`` with the sender widget and streams its events."""
    recorded_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        initial_card = Card(
            children=[
                Text(
                    key="text_1",
                    value="test",
                )
            ],
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_card,
            ),
        )

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        recorded_calls.append((action, sender))
        assert sender

        replacement_card = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement_card),
        )

    with make_server(responder, action_callback=action) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        widget_item = next(
            ev.item
            for ev in events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert recorded_calls
        assert recorded_calls[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert len(events) == 2
        options, updated = events
        assert options.type == "stream_options"
        assert updated.type == "thread.item.updated"
        assert isinstance(updated, ThreadItemUpdatedEvent)
        assert updated.update.type == "widget.root.updated"
        assert updated.update.widget == Card(children=[Text(value="Email sent!")])
967
968
async def test_add_feedback():
    """Each feedback request reaches the handler with its thread id, item ids and kind."""
    received = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        received.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_message = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_message)

        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget)

    with make_server(responder, handle_feedback=handle_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test feedback")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        positive_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["assistant_msg_1"],
                kind="positive",
            )
        )
        await server.process_non_streaming(positive_request)

        negative_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["widget_1"],
                kind="negative",
            )
        )
        await server.process_non_streaming(negative_request)

        assert len(received) == 2
        first_call, second_call = received

        assert first_call[0] == thread.id
        assert first_call[1] == ["assistant_msg_1"]
        assert first_call[2] == "positive"

        assert second_call[0] == thread.id
        assert second_call[1] == ["widget_1"]
        assert second_call[2] == "negative"
1043
1044
async def test_get_thread_by_id():
    """Fetching a thread by id returns its metadata together with its items."""
    # NOTE(review): the return value is discarded; confirm whether this call
    # is needed for a side effect or is left over.
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test thread")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Retrieve the thread by id
        get_request = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        response = await server.process_non_streaming(get_request)
        loaded_thread = TypeAdapter(Thread).validate_json(response.json)

        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title
        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1076
1077
async def test_create_file():
    """Creating a non-image attachment returns a file attachment and records it in the store."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=file_name, size=1024, mime_type=file_content_type
        )
        response = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(
            f"https://example.com/{attachment.id}/upload"
        )
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in store.files

        # The `metadata` field is excluded from serialization by default but still stored server-side
        assert attachment.metadata is None
        assert store.files[attachment.id].metadata == {"source": "test"}
1109
1110
async def test_create_image_file():
    """Creating an ``image/png`` attachment yields an ``image`` attachment.

    Besides the fields checked for plain files, the response must carry a
    preview URL, and the attachment must be recorded in the file store.
    """
    file_store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime = "image/png"

    with make_server(file_store=file_store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=expected_name,
                size=1024,
                mime_type=expected_mime,
            )
        )
        response = await server.process_non_streaming(create_request)
        created = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert created.id is not None
        assert created.mime_type == expected_mime
        assert created.name == expected_name
        assert created.type == "image"
        assert created.preview_url == AnyUrl(
            f"https://example.com/{created.id}/preview"
        )

        # The upload descriptor tells the client where and how to upload.
        descriptor = created.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(f"https://example.com/{created.id}/upload")
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        assert created.id in file_store.files
1143
1144
async def test_user_message_attachment_metadata_not_serialized():
    """Attachment ``metadata`` is absent from user messages returned by items.list.

    An attachment with metadata is saved directly to the store, referenced by
    id in a new thread's first message, and the listed user message must then
    expose attachments whose ``metadata`` is ``None``.
    """
    stored_attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(stored_attachment, DEFAULT_CONTEXT)

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message with file")],
                    attachments=[stored_attachment.id],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)

        user_messages = [entry for entry in page.data if entry.type == "user_message"]
        assert user_messages
        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        for listed in listed_attachments:
            assert listed.metadata is None
1180
1181
async def test_create_file_without_filestore():
    """Creating an attachment raises ``RuntimeError`` when the server was
    built without a ``file_store``."""
    with make_server() as server:
        # Scope the expectation to the request itself, so that a RuntimeError
        # raised while building the server cannot satisfy this test.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsCreateReq(
                    params=AttachmentCreateParams(
                        name="test-file-name",
                        size=1024,
                        mime_type="text/plain",
                    )
                )
            )
1194
1195
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    attachment_id = "test-file-id"
    file_store = InMemoryFileStore()
    # Seed the store with a single entry so there is something to delete.
    file_store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=file_store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=attachment_id)
        )
        await server.process_non_streaming(delete_request)
        assert file_store.files == {}
1206
1207
async def test_delete_file_without_filestore():
    """Deleting an attachment raises ``RuntimeError`` when the server was
    built without a ``file_store``."""
    with make_server() as server:
        # Scope the expectation to the request itself, so that a RuntimeError
        # raised while building the server cannot satisfy this test.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsDeleteReq(
                    params=AttachmentDeleteParams(attachment_id="test-file-id")
                )
            )
1216
1217
async def test_list_threads_paginated_response():
    """Listing 25 threads returns a full page of 20 followed by a page of 5.

    The first page must report ``has_more`` and an ``after`` cursor; the
    second page, requested after the last thread of the first, must not
    report ``has_more``.
    """
    with make_server() as server:
        for index in range(25):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {index}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            first_response.json
        )
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Continue from the last thread of the first page.
        last_seen_id = first_page.data[-1].id
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=last_seen_id))
        )
        second_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            second_response.json
        )
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1246
1247
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        update_request = ThreadsUpdateReq(
            params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
        )
        update_response = await server.process_non_streaming(update_request)
        renamed = TypeAdapter[Thread](Thread).validate_json(update_response.json)
        assert renamed.title == "New Title"
1271
1272
async def test_returns_widget_item_generator():
    """``stream_widget`` turns growing text into add / delta / done events.

    The generator yields the same card four times with an ever longer prefix
    of "Hello, world"; the stream must contain one added event, three
    streaming-text deltas and a final done event with the full text.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    full_text = "Hello, world"

    def render_widget(i: int) -> Card:
        return Card(children=[Text(id="text", value=full_text[:i])])

    async def widget_generator():
        for prefix_length in (0, 3, 6, 12):
            yield render_widget(prefix_length)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5

    # The first event adds the widget with empty text.
    added = events[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    # The three middle events each append the newly revealed characters.
    for update_event, expected_delta in zip(events[1:4], ("Hel", "lo,", " world")):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    # The last event carries the finished widget.
    done = events[4]
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(children=[Text(id="text", value="Hello, world")])
1314
1315
async def test_returns_widget_item_generator_full_replace():
    """``stream_widget`` emits a root update when the new text is not an extension.

    "World" does not extend "Hello", so the middle event must be a
    ``widget.root.updated`` update carrying the whole new widget.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card_with(text: str) -> Card:
        return Card(children=[Text(id="text", value=text)])

    async def widget_generator():
        yield card_with("Hello")
        yield card_with("World")

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card_with("Hello")

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == card_with("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card_with("World")
1339
1340
async def test_delete_thread():
    """Deleting a thread returns ``{}`` and later lookups raise ``NotFoundError``."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Looking the deleted thread up again must fail with NotFoundError.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        # Plain substring check: the thread id is not escaped for regex use.
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1372
1373
async def test_thread_item_removed_event_removes_item():
    """An item that the responder adds and then removes is absent from items.list.

    The responder yields a done event for an item and immediately a removed
    event for the same id; the removed event must be streamed to the client
    and the item must not show up when the thread's items are listed.
    """
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger removal")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The removal must have been streamed to the client.
        assert any(
            ev.type == "thread.item.removed" and ev.item_id == removed_id
            for ev in stream_events
        )

        # The item should not be present in the store.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json).data
        assert removed_id not in [entry.id for entry in listed]
1418
1419
async def test_raising_in_responder_yields_error_event():
    """A responder that raises produces a ``STREAM_ERROR`` event, then the stream ends."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; its presence makes this function an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        stream = result.__aiter__()

        # Yield an error
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Then finish: nothing may follow the error event.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1461
1462
async def test_thread_item_replaced_event_replaces_item():
    """A replaced event swaps the stored item for the replacement with the same id.

    The responder first emits a user message and then replaces it with an
    assistant message that reuses the id; items.list must return the
    assistant message under that id.
    """
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the item that will be replaced.
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Replace it with an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger replacement")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The replacement must have been streamed to the client.
        assert any(
            ev.type == "thread.item.replaced" and ev.item.id == original_id
            for ev in stream_events
        )

        # The store must now hold the replacement under the original id.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json).data

        stored = next((entry for entry in listed if entry.id == original_id), None)
        assert stored is not None
        assert isinstance(stored, AssistantMessageItem)
        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1525
1526
async def test_thread_item_replaced_event_with_widget():
    """A replaced event carrying a widget item updates both stream and store.

    The responder emits a widget item and then replaces it with another widget
    under the same id; the replaced event must carry the new widget and
    items.list must return it.
    """
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def widget_item(widget: Card, thread: ThreadMetadata) -> WidgetItem:
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the original widget, then replace it under the same id.
        yield ThreadItemDoneEvent(item=widget_item(original_widget, thread))
        yield ThreadItemReplacedEvent(item=widget_item(replacement_widget, thread))

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[
                        UserMessageTextContent(text="Trigger widget replacement")
                    ],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        # The replacement must have been streamed with the new widget.
        matching = (
            ev
            for ev in stream_events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        )
        replacement_event = next(matching, None)
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The store must now hold the replacement widget.
        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        listed = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json).data

        stored = next((entry for entry in listed if entry.id == widget_id), None)
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        assert isinstance(stored.widget, Card)
        first_child = stored.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1600
1601
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the reply.

    The responder answers "First response" on its first call and
    "Retried response" afterwards. After the retry the thread must still hold
    two items, with the retried reply in place of the first one.
    """
    responder_inputs = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_inputs.append(input)

        # The first call produces the original reply; later calls the retry.
        if len(responder_inputs) == 1:
            reply_id, reply_text = "assistant_msg_1", "First response"
        else:
            reply_id, reply_text = "assistant_msg_2", "Retried response"

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=reply_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)
        assert len(items_before.data) == 2
        assistant_before, user_before = items_before.data
        assert assistant_before.type == "assistant_message"
        assert assistant_before.content[0].type == "output_text"
        assert assistant_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry after the user message.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=user_before.id
            )
        )
        retry_events = await server.process_streaming(retry_request)

        # The retry must stream exactly the options event and the new reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # The responder must have been called twice with the same user message.
        assert len(responder_inputs) == 2
        first_input, second_input = responder_inputs
        assert first_input == second_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        # The stored assistant reply must have been replaced.
        list_response_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_response_after.json
        )
        assert len(items_after.data) == 2  # Still user message + assistant message
        assistant_after, user_after = items_after.data
        assert assistant_after.type == "assistant_message"
        assert assistant_after.content[0].type == "output_text"
        assert assistant_after.content[0].text == "Retried response"  # New response
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1699
1700
async def test_retry_after_item_invalid_item():
    """Retrying after an assistant message raises ``ValueError``.

    The id passed to the retry request is that of the first listed item, which
    is expected to be the assistant reply produced by the responder.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in stream_events if ev.type == "thread.created"
        )

        list_response = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_response.json)
        assistant_item_id = page.data[0].id

        # Try to retry after an assistant message (should fail).
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1747
1748
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the second user message regenerates only its reply.

    Two user messages are sent, each answered by the responder. Retrying after
    the second one must call the responder with that message again, keep the
    first exchange untouched and store a new assistant reply with a new id.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The call count is part of the id, so a regenerated reply gets a new id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Verify we have 4 items. As the indices used below show, the page is
        # listed newest first: assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second user message; retry after it.
        second_user_message_id = items_before.data[1].id
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items
        new_reply, second_user, first_reply, first_user = items_after.data

        # First user message and response should remain unchanged
        assert first_user.type == "user_message"
        assert first_user.content[0].type == "input_text"
        assert first_user.content[0].text == "First message"
        assert first_reply.type == "assistant_message"
        assert first_reply.content[0].type == "output_text"
        assert first_reply.content[0].text == "Response to: First message"

        # Second user message should remain, but assistant response should be new
        assert second_user.type == "user_message"
        assert second_user.content[0].text == "Second message"
        assert new_reply.type == "assistant_message"
        assert new_reply.content[0].type == "output_text"
        assert new_reply.content[0].text == "Response to: Second message"
        assert new_reply.id != items_before.data[0].id
1853
1854
async def test_threads_create_passes_tools_to_responder():
    """The ``tool_choice`` sent with threads.create reaches the responder.

    The check itself lives in the responder, which asserts on the inference
    options of the user message it receives.
    """
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_create",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello with tools")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        stream_events = await server.process_streaming(create_request)
        # Sanity check that the request flowed through.
        assert any(ev.type == "thread.item.done" for ev in stream_events)
1888
1889
async def test_retry_after_item_passes_tools_to_responder():
    """Placeholder: the retry path has no tool-choice assertions yet."""
    # An empty body is reported as a pass although nothing is checked;
    # skipping keeps the missing coverage visible in test reports.
    pytest.skip("test_retry_after_item_passes_tools_to_responder is not implemented")
1892
1893
async def test_threads_add_message_passes_tools_to_responder():
    """The ``tool_choice`` sent with a follow-up user message reaches the responder.

    The responder asserts on the inference options of every message it
    receives, so both the creating and the added message send the tool choice.
    """
    tool_choice = ToolChoice(id="retrieval")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_add",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        # Create a thread first.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Start")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        create_events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in create_events if ev.type == "thread.created")

        # Add a message with tools and verify they reach the responder.
        add_request = ThreadsAddUserMessageReq(
            params=ThreadAddUserMessageParams(
                thread_id=thread.id,
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Second")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                ),
            )
        )
        add_events = await server.process_streaming(add_request)
        # Sanity check that the request flowed through.
        assert any(ev.type == "thread.item.done" for ev in add_events)
1943
1944
async def test_custom_id_generators_on_server():
    """Ids produced by a store's ``generate_*_id`` overrides show up in events.

    The server's store is swapped for a subclass with counter-based id
    generators; the created thread and its user message must carry those ids.
    """

    class UniqueIdStore(SQLiteStore):
        # Both generators share one counter, so ids are numbered in the order
        # in which they are requested.
        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits nothing; only the thread id and user message id are checked below.
        return
        yield  # unreachable; its presence makes this function an async generator

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        existing_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(existing_store.db_path)

        # Create a thread through the raw JSON entry point.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        result = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        stream_events = await decode_streaming_result(result)

        # The thread id is the first id requested from the store.
        thread_created = next(
            ev for ev in stream_events if ev.type == "thread.created"
        )
        assert thread_created.thread.id == "thr_custom_1"

        # The user message id is the second and embeds the thread id.
        user_message = next(
            ev.item
            for ev in stream_events
            if ev.type == "thread.item.done" and ev.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
2001