openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
1860286bf7fca86e6b70e7e87ef4681a0f482145

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1938 lines · mode: code

1import asyncio
2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14 ChatKitServer,
15 NonStreamingResult,
16 StreamingResult,
17 stream_widget,
18)
19from chatkit.store import (
20 AttachmentStore,
21 NotFoundError,
22 StoreItemType,
23 default_generate_id,
24)
25from chatkit.types import (
26 AssistantMessageContent,
27 AssistantMessageContentPartTextDelta,
28 AssistantMessageItem,
29 Attachment,
30 AttachmentCreateParams,
31 AttachmentDeleteParams,
32 AttachmentsCreateReq,
33 AttachmentsDeleteReq,
34 ClientToolCallItem,
35 FeedbackKind,
36 FileAttachment,
37 ImageAttachment,
38 InferenceOptions,
39 ItemFeedbackParams,
40 ItemsFeedbackReq,
41 ItemsListParams,
42 ItemsListReq,
43 LockedStatus,
44 Page,
45 ProgressUpdateEvent,
46 Thread,
47 ThreadAddClientToolOutputParams,
48 ThreadAddUserMessageParams,
49 ThreadCreatedEvent,
50 ThreadCreateParams,
51 ThreadCustomActionParams,
52 ThreadDeleteParams,
53 ThreadGetByIdParams,
54 ThreadItem,
55 ThreadItemAddedEvent,
56 ThreadItemDoneEvent,
57 ThreadItemRemovedEvent,
58 ThreadItemReplacedEvent,
59 ThreadItemUpdatedEvent,
60 ThreadListParams,
61 ThreadMetadata,
62 ThreadRetryAfterItemParams,
63 ThreadsAddClientToolOutputReq,
64 ThreadsAddUserMessageReq,
65 ThreadsCreateReq,
66 ThreadsCustomActionReq,
67 ThreadsDeleteReq,
68 ThreadsGetByIdReq,
69 ThreadsListReq,
70 ThreadsRetryAfterItemReq,
71 ThreadStreamEvent,
72 ThreadsUpdateReq,
73 ThreadUpdatedEvent,
74 ThreadUpdateParams,
75 ToolChoice,
76 UserMessageInput,
77 UserMessageItem,
78 UserMessageTextContent,
79 WidgetItem,
80 WidgetRootUpdated,
81)
82from chatkit.widgets import Card, Text
83from tests._types import RequestContext
84from tests.test_store import make_thread_items
85
# Incremented by make_server() on every call; its value names a distinct
# shared-cache in-memory SQLite database for each test server.
server_id = 0


# Request context used by process_streaming / process_non_streaming when the
# caller does not pass one explicitly.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
90
91
class InMemoryFileStore(AttachmentStore):
    """Attachment store that only keeps attachment metadata in a dict.

    No file bytes are stored. Each created attachment receives fake
    ``https://example.com/<id>/...`` upload (and, for images, preview) URLs.
    """

    def __init__(self):
        # attachment id -> attachment. Tests may also insert entries directly,
        # so values are not guaranteed to be Attachment instances.
        self.files: dict[str, Any] = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Forget ``attachment_id``; raises KeyError if it is not stored."""
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, register and return an attachment described by ``input``.

        ``image/*`` MIME types produce an ImageAttachment with a preview URL;
        every other MIME type produces a FileAttachment.
        """
        attachment_id = self._next_id()
        upload_url = AnyUrl(f"https://example.com/{attachment_id}/upload")
        attachment: Attachment
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_url=upload_url,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_url=upload_url,
            )
        self.files[attachment.id] = attachment
        return attachment

    def _next_id(self) -> str:
        """Return the first unused ``atc_<n>`` id, starting at n = len(files) + 1.

        Using ``len(self.files) + 1`` alone would hand out an id that is still
        in use after a deletion and silently overwrite that attachment.
        """
        n = len(self.files) + 1
        while f"atc_{n}" in self.files:
            n += 1
        return f"atc_{n}"
121
122
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one ``data: ``-prefixed frame from ``StreamingResult.json_events``.

    Only the first ``data: `` marker is treated as the separator, so a JSON
    payload whose text itself contains ``data: `` (e.g. in message content) is
    decoded intact instead of being truncated.

    Raises:
        IndexError: if ``event`` contains no ``data: `` marker.
    """
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
125
126
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result`` and return every decoded event, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
131
132
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ChatKitServer test double backed by its own in-memory SQLite DB.

    Each call consumes the next value of the module-level ``server_id`` counter
    and uses it to name a shared-cache in-memory database. This keeps servers
    created by different tests from sharing data.

    Args:
        responder: Delegate for ``respond``. When omitted, ``respond`` yields
            no events.
        handle_feedback: Delegate for ``add_feedback``. When omitted, feedback
            is ignored.
        action_callback: Delegate for ``action``. When omitted, ``action``
            raises ``ValueError``.
        file_store: Attachment store handed to the server; may be ``None``.

    Yields:
        A ``TestChatKitServer`` exposing ``process_streaming`` and
        ``process_non_streaming`` convenience wrappers around ``process``.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep one connection open for the server's lifetime: a shared in-memory
    # SQLite database is discarded as soon as its last connection closes.
    keepalive = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            """Forward to ``action_callback``; raise if none was supplied."""
            if action_callback is None:
                raise ValueError("action_callback not wired up")
            return action_callback(thread, action, sender, context)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            """Forward to ``responder``, or stream nothing when none was supplied."""
            if responder is None:
                return self._empty_responder()
            return responder(thread, input_user_message, context)

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes immediately without producing any events.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            """Forward to ``handle_feedback`` when one was supplied."""
            if handle_feedback is None:
                return
            handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Serialize and process ``request_obj``; return its decoded events.

            ``DEFAULT_CONTEXT`` is substituted only when ``context`` is None.
            """
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(
            self, request_obj, context: Any | None = None
        ) -> NonStreamingResult:
            """Serialize and process ``request_obj``; return the non-streaming result.

            ``DEFAULT_CONTEXT`` is substituted only when ``context`` is None.
            """
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        keepalive.close()
213
214
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream keeps the partial assistant message and notes the cancel.

    The responder adds an assistant message, streams a text delta into it and
    then raises CancelledError. The store must contain the merged text and,
    as its newest item, an ``sdk_hidden_context`` item describing the cancel.
    """
    pending_id = "assistant-message-pending"

    async def cancelling_responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        partial_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=partial_message)
        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(item_id=pending_id, update=text_delta)
        raise asyncio.CancelledError()

    with make_server(cancelling_responder) as server:
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            # Allow hidden_context id generation in this test store; all other
            # item types keep the store's own id scheme.
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw_event in stream.json_events:
                received.append(decode_event(raw_event))

        created_thread = next(
            event.thread for event in received if event.type == "thread.created"
        )
        newest_page = await server.store.load_thread_items(
            created_thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context = newest_page.data[-1]
        assert hidden_context.type == "sdk_hidden_context"
        assert hidden_context.content == "The user cancelled the stream."

        persisted = await server.store.load_item(
            created_thread.id, pending_id, DEFAULT_CONTEXT
        )
        assert persisted.type == "assistant_message"
        assert persisted.content[0].text == "Hello, World!"
281
282
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream discards a pending assistant message with no content.

    The cancellation is still recorded as the newest ``sdk_hidden_context``
    item, but loading the empty assistant message must raise NotFoundError.
    """
    pending_id = "assistant-message-pending"

    async def cancelling_responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        raise asyncio.CancelledError()

    with make_server(cancelling_responder) as server:
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            # Hidden-context ids come from the default generator; every other
            # item type is delegated to the store's original implementation.
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw_event in stream.json_events:
                received.append(decode_event(raw_event))

        created_thread = next(
            event.thread for event in received if event.type == "thread.created"
        )
        newest_page = await server.store.load_thread_items(
            created_thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context = newest_page.data[-1]
        assert hidden_context.type == "sdk_hidden_context"
        assert hidden_context.content == "The user cancelled the stream."

        with pytest.raises(NotFoundError):
            await server.store.load_item(
                created_thread.id, pending_id, DEFAULT_CONTEXT
            )
340
341
async def test_flows_context_to_responder():
    """The request context reaches both ``respond`` and ``add_feedback`` unchanged."""
    observed: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        observed["respond"] = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        observed["feedback"] = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, add_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request, context)
        assert observed.get("respond") == context

        created_thread = next(e.thread for e in events if e.type == "thread.created")
        assistant_item = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=created_thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert observed.get("feedback") == context
401
402
async def test_creates_thread():
    """Creating a thread emits an untitled thread.created and the done user message.

    The user message keeps both its text and its quoted text.
    """

    async def silent_responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that yields nothing.
        return
        yield

    with make_server(silent_responder) as server:
        message = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
            quoted_text="Quoted text",
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )

        created = next(e for e in events if e.type == "thread.created")
        assert created.thread.title is None

        user_item = next(e.item for e in events if e.type == "thread.item.done")
        assert user_item.type == "user_message"
        assert user_item.content[0].text == "Hello, world!"
        assert user_item.quoted_text == "Quoted text"
436
437
async def test_saves_thread_title():
    """A title set by the responder is persisted and announced via thread.updated."""

    async def retitling_responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield  # unreachable; makes this function an async generator

    with make_server(retitling_responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        created_thread = next(e.thread for e in events if e.type == "thread.created")

        # DEFAULT_CONTEXT is RequestContext(user_id="test_user").
        stored = await server.store.load_thread(created_thread.id, DEFAULT_CONTEXT)
        assert stored.title == "Updated title"

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.title == "Updated title"
468
469
async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted and ends the stream with thread.updated."""

    async def tagging_responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield  # unreachable; makes this function an async generator

    with make_server(tagging_responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        created_thread = next(e.thread for e in events if e.type == "thread.created")

        stored = await server.store.load_thread(created_thread.id, DEFAULT_CONTEXT)
        assert stored.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
497
498
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and sent in thread.updated."""

    async def locking_responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield  # unreachable; makes this function an async generator

    with make_server(locking_responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        created_thread = next(e.thread for e in events if e.type == "thread.created")

        # DEFAULT_CONTEXT is RequestContext(user_id="test_user").
        stored = await server.store.load_thread(created_thread.id, DEFAULT_CONTEXT)
        assert stored.status == LockedStatus(reason="Because")

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.status == LockedStatus(reason="Because")
528
529
async def test_emits_thread_updated_mid_stream_and_persists():
    """A thread change mid-stream emits thread.updated after the next done item.

    The responder retitles the thread between assistant messages A and B. The
    thread.updated event must follow B directly, streaming must continue with
    C, and the new title must be persisted.
    """

    def assistant_done(thread_id: str, item_id: str, text: str) -> ThreadItemDoneEvent:
        # Build a finished assistant message event with a single text part.
        return ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread_id,
            ),
        )

    async def responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield assistant_done(thread.id, "assistant_a", "A")
        # Update the thread mid-stream.
        thread.title = "Mid-stream title"
        # thread.updated is expected right after this item.
        yield assistant_done(thread.id, "assistant_b", "B")
        # Streaming continues after the update event.
        yield assistant_done(thread.id, "assistant_c", "C")

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        created_thread = next(
            e.thread for e in events if isinstance(e, ThreadCreatedEvent)
        )

        position_of_b = next(
            index
            for index, e in enumerate(events)
            if isinstance(e, ThreadItemDoneEvent)
            and isinstance(e.item, AssistantMessageItem)
            and e.item.id == "assistant_b"
        )

        following = events[position_of_b + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        updated_event = cast(ThreadUpdatedEvent, following)
        assert updated_event.thread.title == "Mid-stream title"

        after_update = events[position_of_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        continued = cast(ThreadItemDoneEvent, after_update)
        assert isinstance(continued.item, AssistantMessageItem)
        assert continued.item.id == "assistant_c"

        stored = await server.store.load_thread(created_thread.id, DEFAULT_CONTEXT)
        assert stored.title == "Mid-stream title"
604
605
606async def test_respond_with_tool_call():
607 async def responder(
608 thread: ThreadMetadata, input: UserMessageItem | None, context: Any
609 ) -> AsyncIterator[ThreadStreamEvent]:
610 if isinstance(input, UserMessageItem):
611 yield ThreadItemDoneEvent(
612 item=ClientToolCallItem(
613 id="msg_1",
614 created_at=datetime.now(),
615 name="tool_call_1",
616 arguments={"arg1": "val1", "arg2": False},
617 call_id="tool_call_1",
618 thread_id=thread.id,
619 ),
620 )
621 elif input is None:
622 yield ThreadItemDoneEvent(
623 item=AssistantMessageItem(
624 id="msg_2",
625 content=[
626 AssistantMessageContent(text="Glad the tool call succeeded!")
627 ],
628 created_at=datetime.now(),
629 thread_id=thread.id,
630 ),
631 )
632
633 with make_server(responder) as server:
634 events = await server.process_streaming(
635 ThreadsCreateReq(
636 params=ThreadCreateParams(
637 input=UserMessageInput(
638 content=[UserMessageTextContent(text="Hello, world!")],
639 attachments=[],
640 inference_options=InferenceOptions(),
641 )
642 )
643 )
644 )
645
646 assert len(events) == 4
647 assert events[0].type == "thread.created"
648 thread = events[0].thread
649
650 assert events[1].type == "thread.item.done"
651 assert events[1].item.type == "user_message"
652
653 assert events[2].type == "stream_options"
654
655 assert events[3].type == "thread.item.done"
656 assert events[3].item.type == "client_tool_call"
657 assert events[3].item.id == "msg_1"
658 assert events[3].item.name == "tool_call_1"
659 assert events[3].item.arguments == {"arg1": "val1", "arg2": False}
660 assert events[3].item.call_id == "tool_call_1"
661
662 events = await server.process_streaming(
663 ThreadsAddClientToolOutputReq(
664 params=ThreadAddClientToolOutputParams(
665 thread_id=thread.id,
666 result={"text": "Wow!"},
667 )
668 )
669 )
670
671 assert len(events) == 2
672 assert events[0].type == "stream_options"
673 assert events[1].type == "thread.item.done"
674 assert events[1].item.type == "assistant_message"
675
676
async def test_removes_tool_call_if_no_output_provided():
    """An unanswered client tool call is dropped once the user sends a new message.

    "Message 1" produces a tool call. "Message 2" is then sent without any
    tool output. Afterwards the thread lists only the assistant reply and the
    two user messages.
    """

    async def responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert isinstance(user_input, UserMessageItem)
        first_part = user_input.content[0]
        assert first_part.type == "input_text"
        if first_part.text != "Message 1":
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[AssistantMessageContent(text="All done!")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        else:
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message 1")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        created_thread = next(e.thread for e in events if e.type == "thread.created")

        follow_up = ThreadsAddUserMessageReq(
            params=ThreadAddUserMessageParams(
                thread_id=created_thread.id,
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Message 2")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                ),
            )
        )
        await server.process_streaming(follow_up)

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=created_thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        assert [item.type for item in page.data] == [
            "assistant_message",
            "user_message",
            "user_message",
        ]
741
742
async def test_respond_with_tool_status():
    """A progress_update from the responder is streamed before its final reply."""

    async def responder(
        thread: ThreadMetadata, user_input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
776
777
async def test_list_threads_response():
    """Thread listing is newest-first and paginates via an ``after`` cursor.

    Creates 25 threads and checks a first page of 20 (``has_more`` set), then
    a follow-up page with the remaining 5 (``has_more`` cleared).
    """
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for n in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {n}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20

        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        cursor = first_page.data[-1].id
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
810
811
async def test_list_items_response():
    """Thread items are listed newest-first and paginate via an ``after`` cursor.

    The responder adds 25 items on top of the initial user message (26 in
    total). The default page holds 20 of them with ``has_more`` set, and the
    follow-up page holds the remaining 6.
    """
    from datetime import timedelta

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Give every item a strictly increasing timestamp. Back-to-back
        # datetime.now() calls can return identical values on coarse clocks,
        # which made the strict newest-first assertion below flaky. The +1 ms
        # offset also places these items after the initial user message,
        # assuming the server stamps that message with the current time
        # before calling the responder.
        base_time = datetime.now()
        for i in range(25):
            yield ThreadItemDoneEvent(
                item=UserMessageItem(
                    id=f"msg_{i}",
                    content=[UserMessageTextContent(text=f"Message {i}")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    thread_id=thread.id,
                    created_at=base_time + timedelta(milliseconds=i + 1),
                ),
            )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread 1")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items.data) == 20
        # Newest first
        assert items.data[0].created_at > items.data[1].created_at
        assert items.has_more is True
        assert items.after is not None

        list_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(thread_id=thread.id, after=items.data[-1].id)
            )
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items.data) == 6  # +1 for user message
        assert items.has_more is False
861
862
async def test_calls_action():
    """A custom widget action reaches the action handler along with its sender.

    The handler records the (action, sender) pair and replaces the widget's
    root. The action request streams stream_options followed by that update.
    """
    recorded: list[tuple[Action[str, Any], WidgetItem | None]] = []

    async def responder(
        thread: ThreadMetadata,
        user_input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(key="text_1", value="test")]),
        )
        yield ThreadItemDoneEvent(item=widget_item)

    async def on_action(
        thread: ThreadMetadata,
        requested: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        recorded.append((requested, sender))
        assert sender

        replacement = Card(children=[Text(value="Email sent!")])
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    with make_server(responder, action_callback=on_action) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Show widget")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        created_thread = next(
            e.thread for e in created_events if e.type == "thread.created"
        )
        widget_item = next(
            e.item
            for e in created_events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=created_thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        action_events = await server.process_streaming(action_request)

        assert recorded
        assert recorded[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert len(action_events) == 2
        options_event, update_event = action_events
        assert options_event.type == "stream_options"
        assert update_event.type == "thread.item.updated"
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(children=[Text(value="Email sent!")])
952
953
async def test_add_feedback():
    """items.feedback requests are forwarded to add_feedback with ids and kind."""
    calls: list[tuple[Any, ...]] = []

    def record_feedback(thread_id, item_ids, feedback, context):
        calls.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        user_input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Feedback received!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_reply)

        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget_item)

    with make_server(responder, handle_feedback=record_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test feedback")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        created_thread = next(e.thread for e in events if e.type == "thread.created")

        # One feedback request per item, sent in this order.
        feedback_plan: list[tuple[str, FeedbackKind]] = [
            ("assistant_msg_1", "positive"),
            ("widget_1", "negative"),
        ]
        for item_id, kind in feedback_plan:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=created_thread.id,
                        item_ids=[item_id],
                        kind=kind,
                    )
                )
            )

        assert len(calls) == 2
        first_call, second_call = calls

        first_thread_id, first_item_ids, first_kind, _ = first_call
        assert first_thread_id == created_thread.id
        assert first_item_ids == ["assistant_msg_1"]
        assert first_kind == "positive"

        second_thread_id, second_item_ids, second_kind, _ = second_call
        assert second_thread_id == created_thread.id
        assert second_item_ids == ["widget_1"]
        assert second_kind == "negative"
1028
1029
async def test_get_thread_by_id():
    """threads.get_by_id returns the thread together with its items."""
    # NOTE(review): the return value is discarded, so this call looks like a
    # leftover. It is kept to preserve behavior; confirm it has no required
    # side effects before removing it.
    make_thread_items()

    with make_server() as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Test thread")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        created_thread = next(e.thread for e in events if e.type == "thread.created")

        lookup = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=created_thread.id),
        )
        response = await server.process_non_streaming(lookup)
        fetched = TypeAdapter(Thread).validate_json(response.json)

        assert fetched.id == created_thread.id
        assert fetched.title == created_thread.title
        first_item = fetched.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1061
1062
async def test_create_file():
    """attachments.create with a non-image MIME type yields a stored file attachment."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=file_name, size=1024, mime_type=file_content_type
        )
        response = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        # The response describes the newly created file.
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"
        assert attachment.upload_url == AnyUrl(
            f"https://example.com/{attachment.id}/upload"
        )
        assert attachment.id in store.files
1088
1089async def test_create_image_file():
1090 store = InMemoryFileStore()
1091 file_name = "test-file-name"
1092 file_content_type = "image/png"
1093
1094 with make_server(file_store=store) as server:
1095 result = await server.process_non_streaming(
1096 AttachmentsCreateReq(
1097 params=AttachmentCreateParams(
1098 name=file_name,
1099 size=1024,
1100 mime_type=file_content_type,
1101 )
1102 )
1103 )
1104 attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)
1105
1106 assert attachment.id is not None
1107 assert attachment.mime_type == file_content_type
1108 assert attachment.name == file_name
1109 assert attachment.type == "image"
1110 assert attachment.preview_url == AnyUrl(
1111 f"https://example.com/{attachment.id}/preview"
1112 )
1113 assert attachment.upload_url == AnyUrl(
1114 f"https://example.com/{attachment.id}/upload"
1115 )
1116
1117 assert attachment.id in store.files
1118
1119
async def test_create_file_without_filestore():
    """Creating an attachment with no file store configured raises ``RuntimeError``.

    Only the request itself sits inside ``pytest.raises``. Server construction
    happens outside it, so a ``RuntimeError`` raised during setup cannot make
    this test pass by accident.
    """
    with make_server() as server:
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsCreateReq(
                    params=AttachmentCreateParams(
                        name="test-file-name",
                        size=1024,
                        mime_type="text/plain",
                    )
                )
            )
1132
1133
async def test_delete_file():
    """Deleting an attachment removes it from the configured file store."""
    file_store = InMemoryFileStore()
    attachment_id = "test-file-id"
    file_store.files[attachment_id] = {"id": attachment_id}

    delete_req = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id=attachment_id)
    )
    with make_server(file_store=file_store) as server:
        await server.process_non_streaming(delete_req)
        # Exactly one file was seeded, so the store must now be empty.
        assert file_store.files == {}
1144
1145
async def test_delete_file_without_filestore():
    """Deleting an attachment with no file store configured raises ``RuntimeError``.

    ``pytest.raises`` wraps only the delete request, so a ``RuntimeError``
    raised while building the server cannot satisfy the expectation.
    """
    with make_server() as server:
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsDeleteReq(
                    params=AttachmentDeleteParams(attachment_id="test-file-id")
                )
            )
1154
1155
async def test_list_threads_paginated_response():
    """Thread listing pages through results using the returned ``after`` cursor.

    Creates 25 threads. The first page (limit 20) must be full and report
    ``has_more``. Requesting the next page with the cursor the server returned
    must yield the remaining 5 threads, and the two pages together must cover
    every thread exactly once.
    """
    total_threads = 25
    page_size = 20

    with make_server() as server:
        for i in range(total_threads):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {i}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=page_size))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == page_size
        assert first_page.has_more is True
        assert first_page.after is not None

        # Follow the cursor the server handed back instead of re-deriving it
        # from the last item, so the pagination contract itself is exercised.
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.after))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == total_threads - page_size
        assert second_page.has_more is False

        # No thread may appear on both pages.
        first_ids = {t.id for t in first_page.data}
        second_ids = {t.id for t in second_page.data}
        assert first_ids.isdisjoint(second_ids)
        assert len(first_ids | second_ids) == total_threads
1184
1185
async def test_threads_update_request():
    """A ``threads.update`` request renames the thread and returns it."""
    with make_server() as server:
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        created = next(
            e.thread for e in create_events if e.type == "thread.created"
        )

        rename = ThreadsUpdateReq(
            params=ThreadUpdateParams(thread_id=created.id, title="New Title")
        )
        raw_result = await server.process_non_streaming(rename)
        # Keep the raw result and the parsed thread in separate variables.
        renamed = TypeAdapter[Thread](Thread).validate_json(raw_result.json)
        assert renamed.title == "New Title"
1209
1210
async def test_returns_widget_item_generator():
    """``stream_widget`` reports growing text as streaming-text deltas.

    The generator yields the same card with progressively longer prefixes of
    "Hello, world". The stream must contain:

    * an added event with the initial (empty-text) widget;
    * one ``widget.streaming_text.value_delta`` update per later step, each
      carrying only the newly appended characters;
    * a done event with the final widget.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    full_text = "Hello, world"

    def render_widget(length: int) -> Card:
        return Card(children=[Text(id="text", value=full_text[:length])])

    async def widget_generator():
        for length in (0, 3, 6, 12):
            yield render_widget(length)

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 5

    added, *updates, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    # Each intermediate step is a text delta against the previous value.
    for update_event, expected_delta in zip(updates, ["Hel", "lo,", " world"]):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1252
1253
async def test_returns_widget_item_generator_full_replace():
    """A text change that does not extend the old value replaces the whole widget.

    "Hello" -> "World" is not an append, so the update must be a
    ``widget.root.updated`` event carrying the complete new widget.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card(value: str) -> Card:
        # Build a fresh Card on every call so no instance is shared.
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        for value in ("Hello", "World"):
            yield card(value)

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 3
    added, replaced, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card("Hello")

    assert isinstance(replaced, ThreadItemUpdatedEvent)
    assert replaced.update.type == "widget.root.updated"
    assert replaced.update.widget == card("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card("World")
1277
1278
async def test_delete_thread():
    """Deleting a thread returns ``{}``, and the thread can no longer be loaded.

    After deletion, fetching the thread by id must raise ``NotFoundError``
    naming the thread. ``pytest.raises`` fails the test if nothing is raised,
    replacing the previous hand-rolled try/``assert False`` (whose message
    wrongly mentioned ``ValueError``).
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1310
1311
async def test_thread_item_removed_event_removes_item():
    """An item removed via ``ThreadItemRemovedEvent`` is dropped from the store.

    The responder first completes an item, then removes it by id. The removal
    must be forwarded to the client, and listing the thread afterwards must
    not return the item.
    """
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Trigger removal")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        removal_forwarded = any(
            e.type == "thread.item.removed" and e.item_id == removed_id
            for e in events
        )
        assert removal_forwarded

        # The store must no longer contain the removed item.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        stored_ids = {item.id for item in page.data}
        assert removed_id not in stored_ids
1356
1357
async def test_raising_in_responder_yields_error_event():
    """An exception in the responder becomes an error event, then the stream ends.

    The responder raises before yielding anything. The server must report this
    as an ``error`` event with ``ErrorCode.STREAM_ERROR`` and then terminate
    the stream cleanly, without propagating the exception to the consumer.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; its presence makes this an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named `stream` rather than `iter` to avoid shadowing the builtin.
        stream = result.__aiter__()

        # First, the stream must yield an error event.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Then it must finish without emitting anything further.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1399
1400
async def test_thread_item_replaced_event_replaces_item():
    """``ThreadItemReplacedEvent`` swaps a stored item for a new one with the same id.

    The responder completes a user message and then replaces it with an
    assistant message under the same id. The replacement must be forwarded
    to the client and persisted in place of the original.
    """
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Persist an initial item under `original_id` ...
        yield ThreadItemDoneEvent(
            item=UserMessageItem(
                id=original_id,
                content=[UserMessageTextContent(text="Original content")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
        # ... then replace it with an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Trigger replacement")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        replacements = [
            e
            for e in events
            if e.type == "thread.item.replaced" and e.item.id == original_id
        ]
        assert replacements

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data

        # The stored item under `original_id` must now be the assistant message.
        matches = [item for item in items if item.id == original_id]
        assert matches
        stored = matches[0]
        assert isinstance(stored, AssistantMessageItem)
        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1463
1464
async def test_thread_item_replaced_event_with_widget():
    """Widget items can be replaced, both in the event stream and in the store."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def make_widget_item(widget: Card, thread_id: str) -> WidgetItem:
        # Both the original and the replacement share `widget_id`.
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread_id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(item=make_widget_item(original_widget, thread.id))
        yield ThreadItemReplacedEvent(
            item=make_widget_item(replacement_widget, thread.id)
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[
                            UserMessageTextContent(text="Trigger widget replacement")
                        ],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The client must see the replacement carrying the new widget.
        replacement_event = next(
            (
                e
                for e in events
                if e.type == "thread.item.replaced" and e.item.id == widget_id
            ),
            None,
        )
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The store must hold the replacement widget under the same id.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        stored = next((item for item in items if item.id == widget_id), None)
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        assert isinstance(stored.widget, Card)
        first_child = stored.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1538
1539
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and replaces its reply.

    The first responder call answers "First response"; any later call (the
    retry) answers "Retried response". After the retry the thread must still
    hold exactly the user message plus one assistant message, now the retried
    one.
    """
    responder_calls: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        first_call = len(responder_calls) == 1
        item_id = "assistant_msg_1" if first_call else "assistant_msg_2"
        text = "First response" if first_call else "Retried response"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        created_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            e.thread for e in created_events if e.type == "thread.created"
        )

        async def list_items() -> Page[ThreadItem]:
            listing = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(listing.json)

        # Items are listed newest first: [assistant reply, user message].
        before = await list_items()
        assert len(before.data) == 2
        reply, user_msg = before.data
        assert reply.type == "assistant_message"
        assert reply.content[0].type == "output_text"
        assert reply.content[0].text == "First response"
        assert user_msg.type == "user_message"
        assert user_msg.content[0].text == "Hello, world!"

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_msg.id
                )
            )
        )

        # The retry streams stream options followed by the regenerated reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # The responder ran twice, both times with the same user message.
        assert len(responder_calls) == 2
        first_input, retry_input = responder_calls
        assert first_input == retry_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        # The stored reply was replaced rather than appended to.
        after = await list_items()
        assert len(after.data) == 2
        new_reply, kept_user_msg = after.data
        assert new_reply.type == "assistant_message"
        assert new_reply.content[0].type == "output_text"
        assert new_reply.content[0].text == "Retried response"
        assert kept_user_msg.type == "user_message"
        assert kept_user_msg.content[0].text == "Hello, world!"
1637
1638
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ``ValueError``."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        created_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            e.thread for e in created_events if e.type == "thread.created"
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        # The newest item is the assistant reply produced by the responder.
        assistant_item_id = page.data[0].id

        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(
                ThreadsRetryAfterItemReq(
                    params=ThreadRetryAfterItemParams(
                        thread_id=thread.id, item_id=assistant_item_id
                    )
                )
            )
1685
1686
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the final reply.

    Builds a four-item conversation (user1, assistant1, user2, assistant2),
    then retries after user2. The test checks that:

    * the responder is re-run with user2;
    * only assistant2 is replaced;
    * user1, assistant1 and user2 keep both their content and their ids.
    """
    responder_calls: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # Echo the prompt so each reply identifies which message it answers,
        # and number the id so a regenerated reply gets a fresh one.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Items are listed newest first: [assistant2, user2, assistant1, user1].
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4
        assert items_before.data[1].type == "user_message"
        second_user_message_id = items_before.data[1].id

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )
        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # The retry (third call) must have been driven by the second message.
        assert len(responder_calls) == 3
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4

        # user1 and its reply are untouched.
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # user2 remains, but its reply is regenerated.
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"

        # Kept items must be the same stored items, not recreated copies;
        # the regenerated reply must have a new id.
        kept_ids_before = [item.id for item in items_before.data[1:]]
        kept_ids_after = [item.id for item in items_after.data[1:]]
        assert kept_ids_after == kept_ids_before
        assert items_after.data[0].id != items_before.data[0].id
1790 assert items_after.data[0].id != items_before.data[0].id
1791
1792
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with ``threads.create`` reaches the responder.

    Checks inside the responder would not fail this test on their own:

    * an exception raised there is reported as a stream ``error`` event
      instead of propagating;
    * the user message emits its own ``thread.item.done`` event regardless of
      what the responder does.

    So the responder only records the tool choice it saw. The assertions run
    here, and the test also requires that the responder's own reply was
    emitted with no error events.
    """
    tool_choice = ToolChoice(id="web_search")
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        chosen = input.inference_options.tool_choice
        seen_tool_choice_ids.append(chosen.id if chosen is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_create",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )

        assert seen_tool_choice_ids == [tool_choice.id]
        assert not any(e.type == "error" for e in events)
        assert any(
            e.type == "thread.item.done" and e.item.id == "assistant_msg_tools_create"
            for e in events
        )
1826
1827
async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message re-sends that message's tool choice.

    The thread is created with a ``tool_choice``. Retrying after the stored
    user message must invoke the responder again with the same tool choice.
    The responder records what it saw so the check happens outside the stream,
    where a failure actually fails the test.
    """
    tool_choice = ToolChoice(id="web_search")
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        chosen = input.inference_options.tool_choice
        seen_tool_choice_ids.append(chosen.id if chosen is not None else None)

        # Number the reply id so the regenerated reply is distinguishable.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_tools_retry_{len(seen_tool_choice_ids)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        user_message_id = next(
            item.id for item in items if item.type == "user_message"
        )

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message_id
                )
            )
        )

        # Once for create and once for the retry, both with the tool choice.
        assert seen_tool_choice_ids == [tool_choice.id, tool_choice.id]
        assert not any(e.type == "error" for e in retry_events)
        assert any(
            e.type == "thread.item.done"
            and e.item.id == "assistant_msg_tools_retry_2"
            for e in retry_events
        )
1830
1831
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with ``threads.add_user_message`` reaches the responder.

    As with thread creation, a check inside the responder would only surface
    as a stream ``error`` event, and the user message's own
    ``thread.item.done`` is always emitted. So the responder records each tool
    choice it sees, and the assertions run here.
    """
    tool_choice = ToolChoice(id="retrieval")
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        chosen = input.inference_options.tool_choice
        seen_tool_choice_ids.append(chosen.id if chosen is not None else None)

        # Number the id so the two replies (create + add) never collide.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_tools_add_{len(seen_tool_choice_ids)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Add a message with tools; they must reach the responder again.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )

        assert seen_tool_choice_ids == [tool_choice.id, tool_choice.id]
        assert not any(e.type == "error" for e in events)
        assert any(
            e.type == "thread.item.done" and e.item.id == "assistant_msg_tools_add_2"
            for e in events
        )
1881
1882
async def test_custom_id_generators_on_server():
    """Server-generated ids come from the store's ``generate_*_id`` hooks.

    A ``SQLiteStore`` subclass overrides the thread and item id generators
    with a single shared counter. The thread created by ``threads.create``
    and its user message must carry those custom ids.
    """

    class UniqueIdStore(SQLiteStore):
        """``SQLiteStore`` whose thread and item ids share one incrementing counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context: Any) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context: Any,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit nothing: only ids the server generates itself (the thread and
        # the user message) are checked by this test.
        return
        yield  # unreachable; its presence makes this an async generator

    with make_server(responder) as server:
        # Swap in the custom id-generating store, reusing the same DB path.
        db_path = cast(SQLiteStore, server.store).db_path
        server.store = UniqueIdStore(db_path)

        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is the first value drawn from the shared counter ...
        thread_created = next(e for e in events if e.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        # ... and the user message id is the second, suffixed with the thread id.
        user_message = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
1939