openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
529aea8a1f29b8efdb46abbf2711f3b6c39366d0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1877 lines · mode: code

1import asyncio
2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14 ChatKitServer,
15 NonStreamingResult,
16 StreamingResult,
17 stream_widget,
18)
19from chatkit.store import (
20 AttachmentStore,
21 NotFoundError,
22 StoreItemType,
23 default_generate_id,
24)
25from chatkit.types import (
26 AssistantMessageContent,
27 AssistantMessageContentPartTextDelta,
28 AssistantMessageItem,
29 Attachment,
30 AttachmentCreateParams,
31 AttachmentDeleteParams,
32 AttachmentsCreateReq,
33 AttachmentsDeleteReq,
34 ClientToolCallItem,
35 FeedbackKind,
36 FileAttachment,
37 ImageAttachment,
38 InferenceOptions,
39 ItemFeedbackParams,
40 ItemsFeedbackReq,
41 ItemsListParams,
42 ItemsListReq,
43 LockedStatus,
44 Page,
45 ProgressUpdateEvent,
46 Thread,
47 ThreadAddClientToolOutputParams,
48 ThreadAddUserMessageParams,
49 ThreadCreatedEvent,
50 ThreadCreateParams,
51 ThreadCustomActionParams,
52 ThreadDeleteParams,
53 ThreadGetByIdParams,
54 ThreadItem,
55 ThreadItemAddedEvent,
56 ThreadItemDoneEvent,
57 ThreadItemRemovedEvent,
58 ThreadItemReplacedEvent,
59 ThreadItemUpdatedEvent,
60 ThreadListParams,
61 ThreadMetadata,
62 ThreadRetryAfterItemParams,
63 ThreadsAddClientToolOutputReq,
64 ThreadsAddUserMessageReq,
65 ThreadsCreateReq,
66 ThreadsCustomActionReq,
67 ThreadsDeleteReq,
68 ThreadsGetByIdReq,
69 ThreadsListReq,
70 ThreadsRetryAfterItemReq,
71 ThreadStreamEvent,
72 ThreadsUpdateReq,
73 ThreadUpdatedEvent,
74 ThreadUpdateParams,
75 ToolChoice,
76 UserMessageInput,
77 UserMessageItem,
78 UserMessageTextContent,
79 WidgetItem,
80 WidgetRootUpdated,
81)
82from chatkit.widgets import Card, Text
83from tests._types import RequestContext
84from tests.test_store import make_thread_items
85
# Module-level counter that make_server() bumps on every call so each test
# server gets its own shared in-memory SQLite database name.
server_id: int = 0


# Request context used whenever a test does not pass one of its own.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
90
91
class InMemoryFileStore(AttachmentStore):
    """Dict-backed ``AttachmentStore`` for tests.

    ``files`` maps an attachment id to whatever was stored under it: the
    ``Attachment`` built by ``create_attachment`` or a placeholder a test
    inserted directly. Nothing is uploaded; the URLs point at example.com.
    """

    def __init__(self):
        self.files = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Forget ``attachment_id``; raises ``KeyError`` if it is not stored."""
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Record and return attachment metadata for ``input``.

        ``image/*`` MIME types yield an ``ImageAttachment`` (with a preview
        URL); everything else yields a ``FileAttachment``.
        """
        # Number from len(files) + 1, but skip ids that are already taken:
        # after a deletion, len(files) + 1 can equal a live id, and reusing it
        # would silently overwrite that attachment.
        n = len(self.files) + 1
        while f"atc_{n}" in self.files:
            n += 1
        attachment_id = f"atc_{n}"

        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            )
        self.files[attachment.id] = attachment
        return attachment
121
122
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one server-sent-events frame (``data: <json>``) into an event.

    Only the first ``data: `` marker is treated as the field prefix
    (``maxsplit=1``). A JSON payload that itself contains the text ``data: ``
    (e.g. a message whose text includes it) is therefore passed through intact
    instead of being cut at its second occurrence. A frame without the marker
    raises ``IndexError``.
    """
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
125
126
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every frame, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for frame in streaming_result.json_events:
        decoded.append(decode_event(frame))
    return decoded
131
132
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` test double backed by a fresh SQLite store.

    Every call uses its own shared-cache in-memory SQLite database, named from
    the module-level ``server_id`` counter.

    Args:
        responder: Called by ``respond()``. When ``None``, ``respond()`` returns
            an async iterator that yields nothing.
        handle_feedback: Called by ``add_feedback()``. When ``None``, feedback
            is ignored.
        action_callback: Called by ``action()``. When ``None``, ``action()``
            raises ``ValueError``.
        file_store: Passed through to ``ChatKitServer`` as its attachment
            store (may be ``None``).

    Yields:
        The server. It also exposes ``process_streaming`` and
        ``process_non_streaming``, which serialize a request model, run it
        through ``process()`` and assert the kind of result returned.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last
    # connection is closed.
    keepalive = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is None:
                raise ValueError("action_callback not wired up")
            return action_callback(thread, action, sender, context)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is None:
                return self._empty_responder()
            return responder(thread, input_user_message, context)

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # Async generator that finishes without yielding anything.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is None:
                return
            handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Run ``request_obj`` and return its decoded streamed events."""
            # Compare with None rather than using `or`, so a caller-supplied
            # context that happens to be falsy is not replaced by the default.
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(
            self, request_obj, context: Any | None = None
        ) -> NonStreamingResult:
            """Run ``request_obj`` and return its non-streaming result."""
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        keepalive.close()
213
214
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream keeps the partially streamed assistant message.

    The responder adds an assistant message, streams one text delta into it and
    then raises ``CancelledError``. Afterwards the item loaded with order
    "desc" / limit 1 must be an ``sdk_hidden_context`` note, and the pending
    message must be stored with the merged text "Hello, World!".
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id=pending_id,
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, ")],
                thread_id=thread.id,
            )
        )
        yield ThreadItemUpdatedEvent(
            item_id=pending_id,
            update=AssistantMessageContentPartTextDelta(
                content_index=0,
                delta="World!",
            ),
        )
        # Simulate the consumer cancelling mid-response.
        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        store = server.store
        # Allow hidden_context id generation in this test store: route that
        # item type to default_generate_id, everything else to the store's
        # own generator.
        delegate = store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return delegate(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        newest = await store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert hidden_context_item.content == (
            "The user cancelled the stream. Stop responding to the prior request."
        )

        pending = await store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
        assert pending.type == "assistant_message"
        assert pending.content[0].text == "Hello, World!"
284
285
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """Cancelling before any assistant text arrives drops the empty message.

    The hidden-context cancellation note must still be written; the pending
    assistant message (which has no content) must not be stored.
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id=pending_id,
                created_at=datetime.now(),
                content=[],
                thread_id=thread.id,
            )
        )
        # Cancel before any content has been streamed into the message.
        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        store = server.store
        delegate = store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return delegate(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")

        newest = await store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert hidden_context_item.content == (
            "The user cancelled the stream. Stop responding to the prior request."
        )

        with pytest.raises(NotFoundError):
            await store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
346
347
async def test_flows_context_to_responder():
    """The request context reaches respond() and add_feedback() intact.

    Both callbacks must receive a context equal to the one passed to
    ``process``.
    """
    seen: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen["respond"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen["add_feedback"] = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, add_feedback) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ),
            context,
        )
        assert seen.get("respond") == context

        thread = next(e.thread for e in events if e.type == "thread.created")
        reply = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "assistant_message"
        )

        await server.process_non_streaming(
            ItemsFeedbackReq(
                params=ItemFeedbackParams(
                    thread_id=thread.id,
                    item_ids=[reply.id],
                    kind="positive",
                )
            ),
            context,
        )
        assert seen.get("add_feedback") == context
407
408
async def test_creates_thread():
    """threads.create echoes the user message, including its quoted text.

    The created thread has no title yet.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Yields nothing; only the server-generated events are under test.
        return
        yield

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
        quoted_text="Quoted text",
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(e for e in events if e.type == "thread.created")
        assert created.thread.title is None

        user_item = next(e.item for e in events if e.type == "thread.item.done")
        assert user_item.type == "user_message"
        assert user_item.content[0].text == "Hello, world!"
        assert user_item.quoted_text == "Quoted text"
442
443
async def test_saves_thread_title():
    """A title set inside respond() is persisted and reported by the last event."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        stored = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored.title == "Updated title"

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.title == "Updated title"
474
475
async def test_saves_thread_metadata():
    """Metadata written inside respond() is persisted and ends with thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        stored = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
503
504
async def test_saves_thread_locked_fields():
    """A locked status set inside respond() is persisted and reported last."""
    expected_status = LockedStatus(reason="Because")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        loaded = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert loaded.status == expected_status

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.status == expected_status
534
535
async def test_emits_thread_updated_mid_stream_and_persists():
    """Mutating the thread mid-stream emits thread.updated after the next item.

    The responder changes the title between its first and second assistant
    messages. ``thread.updated`` must immediately follow the second message,
    the third message must still stream, and the new title must be persisted.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        yield assistant_done("assistant_a", "A")
        # Update the thread between two items.
        thread.title = "Mid-stream title"
        yield assistant_done("assistant_b", "B")
        # Keep streaming after the update has been announced.
        yield assistant_done("assistant_c", "C")

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )

        thread = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        b_index = next(
            position
            for position, event in enumerate(events)
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, AssistantMessageItem)
            and event.item.id == "assistant_b"
        )

        # thread.updated immediately follows assistant_b ...
        assert isinstance(events[b_index + 1], ThreadUpdatedEvent)
        updated_event = cast(ThreadUpdatedEvent, events[b_index + 1])
        assert updated_event.thread.title == "Mid-stream title"

        # ... and streaming continues with assistant_c afterwards.
        assert isinstance(events[b_index + 2], ThreadItemDoneEvent)
        done_event = cast(ThreadItemDoneEvent, events[b_index + 2])
        assert isinstance(done_event.item, AssistantMessageItem)
        assert done_event.item.id == "assistant_c"

        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"
610
611
612async def test_respond_with_tool_call():
613 async def responder(
614 thread: ThreadMetadata, input: UserMessageItem | None, context: Any
615 ) -> AsyncIterator[ThreadStreamEvent]:
616 if isinstance(input, UserMessageItem):
617 yield ThreadItemDoneEvent(
618 item=ClientToolCallItem(
619 id="msg_1",
620 created_at=datetime.now(),
621 name="tool_call_1",
622 arguments={"arg1": "val1", "arg2": False},
623 call_id="tool_call_1",
624 thread_id=thread.id,
625 ),
626 )
627 elif input is None:
628 yield ThreadItemDoneEvent(
629 item=AssistantMessageItem(
630 id="msg_2",
631 content=[
632 AssistantMessageContent(text="Glad the tool call succeeded!")
633 ],
634 created_at=datetime.now(),
635 thread_id=thread.id,
636 ),
637 )
638
639 with make_server(responder) as server:
640 events = await server.process_streaming(
641 ThreadsCreateReq(
642 params=ThreadCreateParams(
643 input=UserMessageInput(
644 content=[UserMessageTextContent(text="Hello, world!")],
645 attachments=[],
646 inference_options=InferenceOptions(),
647 )
648 )
649 )
650 )
651
652 assert len(events) == 4
653 assert events[0].type == "thread.created"
654 thread = events[0].thread
655
656 assert events[1].type == "thread.item.done"
657 assert events[1].item.type == "user_message"
658
659 assert events[2].type == "stream_options"
660
661 assert events[3].type == "thread.item.done"
662 assert events[3].item.type == "client_tool_call"
663 assert events[3].item.id == "msg_1"
664 assert events[3].item.name == "tool_call_1"
665 assert events[3].item.arguments == {"arg1": "val1", "arg2": False}
666 assert events[3].item.call_id == "tool_call_1"
667
668 events = await server.process_streaming(
669 ThreadsAddClientToolOutputReq(
670 params=ThreadAddClientToolOutputParams(
671 thread_id=thread.id,
672 result={"text": "Wow!"},
673 )
674 )
675 )
676
677 assert len(events) == 2
678 assert events[0].type == "stream_options"
679 assert events[1].type == "thread.item.done"
680 assert events[1].item.type == "assistant_message"
681
async def test_respond_with_tool_status():
    """A ProgressUpdateEvent streams between stream_options and the reply."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )

        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
715
716
async def test_list_threads_response():
    """threads.list returns 20 threads newest-first, then the remaining 5."""
    with make_server() as server:
        for n in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {n}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
749
750
async def test_list_items_response():
    """items.list pages newest-first: 20 items, then the remaining 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for n in range(25):
            yield ThreadItemDoneEvent(
                item=UserMessageItem(
                    id=f"msg_{n}",
                    content=[UserMessageTextContent(text=f"Message {n}")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    thread_id=thread.id,
                    created_at=datetime.now(),
                ),
            )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread 1")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")
        page_adapter = TypeAdapter(Page[ThreadItem])

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 6  # +1 for user message
        assert second_page.has_more is False
800
801
async def test_calls_action():
    """threads.custom_action calls action() with the sender widget.

    The updates that action() yields are streamed back to the client.
    """
    received: list[tuple[Action[str, Any], WidgetItem | None]] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        initial_widget = Card(children=[Text(key="text_1", value="test")])
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_widget,
            ),
        )

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received.append((action, sender))
        assert sender

        replacement = Card(children=[Text(value="Email sent!")])
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    with make_server(responder, action_callback=action) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Show widget")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")
        widget_item = next(
            e.item
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WidgetItem)
        )

        events = await server.process_streaming(
            ThreadsCustomActionReq(
                params=ThreadCustomActionParams(
                    thread_id=thread.id,
                    item_id=widget_item.id,
                    action=Action(type="create_user", payload={"user_id": "123"}),
                )
            )
        )

        assert received
        assert received[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert len(events) == 2
        options_event, updated_event = events
        assert options_event.type == "stream_options"
        assert updated_event.type == "thread.item.updated"
        assert isinstance(updated_event, ThreadItemUpdatedEvent)
        assert updated_event.update.type == "widget.root.updated"
        assert updated_event.update.widget == Card(children=[Text(value="Email sent!")])
891
892
async def test_add_feedback():
    """items.feedback forwards the thread id, item ids and kind to add_feedback()."""
    recorded = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        recorded.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1",
                content=[AssistantMessageContent(text="Feedback received!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                widget=Card(children=[Text(key="text", value="Widget content")]),
                thread_id=thread.id,
            ),
        )

    with make_server(responder, handle_feedback=handle_feedback) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test feedback")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        submissions: list[tuple[str, FeedbackKind]] = [
            ("assistant_msg_1", "positive"),
            ("widget_1", "negative"),
        ]
        for item_id, kind in submissions:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=thread.id,
                        item_ids=[item_id],
                        kind=kind,
                    )
                )
            )

        # Compare (thread_id, item_ids, feedback); the context is not checked here.
        assert [entry[:3] for entry in recorded] == [
            (thread.id, ["assistant_msg_1"], "positive"),
            (thread.id, ["widget_1"], "negative"),
        ]
967
968
async def test_get_thread_by_id():
    """threads.get_by_id returns the thread and its first user message."""
    # NOTE(review): the result of make_thread_items() is discarded. Unless it is
    # needed for a side effect, this call looks vestigial — confirm, then
    # remove it together with its import.
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test thread")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Retrieve the thread by id
        fetched = await server.process_non_streaming(
            ThreadsGetByIdReq(params=ThreadGetByIdParams(thread_id=thread.id))
        )
        loaded_thread = TypeAdapter(Thread).validate_json(fetched.json)
        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title

        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1000
1001
async def test_create_file():
    """attachments.create with a non-image MIME type returns a file attachment."""
    store = InMemoryFileStore()
    name = "test-file-name"
    mime_type = "text/plain"

    with make_server(file_store=store) as server:
        created = await server.process_non_streaming(
            AttachmentsCreateReq(
                params=AttachmentCreateParams(name=name, size=1024, mime_type=mime_type)
            )
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(created.json)

        # Verify response includes file
        assert attachment.id is not None
        assert (attachment.mime_type, attachment.name, attachment.type) == (
            mime_type,
            name,
            "file",
        )
        assert attachment.upload_url == AnyUrl(
            f"https://example.com/{attachment.id}/upload"
        )
        assert attachment.id in store.files
1026
1027
async def test_create_image_file():
    """attachments.create with an image MIME type returns an image attachment."""
    store = InMemoryFileStore()
    name = "test-file-name"
    mime_type = "image/png"

    with make_server(file_store=store) as server:
        created = await server.process_non_streaming(
            AttachmentsCreateReq(
                params=AttachmentCreateParams(
                    name=name,
                    size=1024,
                    mime_type=mime_type,
                )
            )
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(created.json)

        assert attachment.id is not None
        assert (attachment.mime_type, attachment.name, attachment.type) == (
            mime_type,
            name,
            "image",
        )

        base_url = f"https://example.com/{attachment.id}"
        assert attachment.preview_url == AnyUrl(f"{base_url}/preview")
        assert attachment.upload_url == AnyUrl(f"{base_url}/upload")

        assert attachment.id in store.files
1057
1058
async def test_create_file_without_filestore():
    """attachments.create raises RuntimeError when the server has no file store.

    ``pytest.raises`` wraps only the request, so a RuntimeError raised while
    building or tearing down the server cannot make this test pass by accident.
    """
    with make_server() as server:
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsCreateReq(
                    params=AttachmentCreateParams(
                        name="test-file-name",
                        size=1024,
                        mime_type="text/plain",
                    )
                )
            )
1071
1072
async def test_delete_file():
    """attachments.delete removes the entry from the attachment store."""
    attachment_id = "test-file-id"
    store = InMemoryFileStore()
    # Seed the store directly; only the key matters for deletion.
    store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=attachment_id)
        )
        await server.process_non_streaming(delete_request)
        assert store.files == {}
1083
1084
async def test_delete_file_without_filestore():
    """attachments.delete raises RuntimeError when the server has no file store.

    ``pytest.raises`` wraps only the request, so a RuntimeError raised while
    building or tearing down the server cannot make this test pass by accident.
    """
    with make_server() as server:
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsDeleteReq(
                    params=AttachmentDeleteParams(attachment_id="test-file-id")
                )
            )
1093
1094
async def test_list_threads_paginated_response():
    """Listing 25 threads with limit=20 yields a full page, then the last 5."""
    with make_server() as server:
        for n in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {n}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1123
1124
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    with make_server() as server:
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        created = next(e.thread for e in create_events if e.type == "thread.created")

        update_result = await server.process_non_streaming(
            ThreadsUpdateReq(
                params=ThreadUpdateParams(thread_id=created.id, title="New Title")
            )
        )
        renamed = TypeAdapter[Thread](Thread).validate_json(update_result.json)
        assert renamed.title == "New Title"
1148
1149
async def test_returns_widget_item_generator():
    """stream_widget emits added, three text value deltas, then done."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    full_text = "Hello, world"

    def render_widget(length: int) -> Card:
        return Card(children=[Text(id="text", value=full_text[:length])])

    async def widget_generator():
        for length in (0, 3, 6, 12):
            yield render_widget(length)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5

    added = events[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    # Renders that only extend the text are expected as streaming-text deltas.
    for update_event, expected_delta in zip(events[1:4], ["Hel", "lo,", " world"]):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    done = events[4]
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(children=[Text(id="text", value="Hello, world")])
1191
1192
async def test_returns_widget_item_generator_full_replace():
    """Changing the text from "Hello" to "World" streams as a root replacement."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card(value: str) -> Card:
        # A fresh widget each call, so no comparison relies on object identity.
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        yield card("Hello")
        yield card("World")

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card("Hello")

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == card("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card("World")
1216
1217
async def test_delete_thread():
    """Deleting a thread returns ``{}``; loading it afterwards raises NotFoundError."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # pytest.raises fails the test if nothing is raised, and lets any other
        # exception type propagate. This replaces the manual try/assert False
        # whose message wrongly named ValueError.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1249
1250
async def test_thread_item_removed_event_removes_item():
    """A ThreadItemRemovedEvent deletes a previously completed item from the store."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Trigger removal")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        removal_streamed = any(
            e.type == "thread.item.removed" and e.item_id == removed_id for e in events
        )
        assert removal_streamed

        # The removed item must no longer be listed for the thread.
        listed = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listed.json).data
        assert removed_id not in {item.id for item in stored_items}
1295
1296
async def test_raising_in_responder_yields_error_event():
    """A responder exception surfaces as a STREAM_ERROR event, then the stream ends."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; makes this function an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named `stream` rather than `iter` to avoid shadowing the builtin.
        stream = result.__aiter__()

        # First, the responder's exception is reported as an error event...
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # ...after which the stream must be exhausted.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1338
1339
async def test_thread_item_replaced_event_replaces_item():
    """A ThreadItemReplacedEvent swaps the stored item sharing the same id."""
    original_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Store a user message first...
        yield ThreadItemDoneEvent(
            item=UserMessageItem(
                id=original_id,
                content=[UserMessageTextContent(text="Original content")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
        # ...then replace it with an assistant message under the same id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[AssistantMessageContent(text=replacement_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Trigger replacement")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        replacement_streamed = any(
            e.type == "thread.item.replaced" and e.item.id == original_id
            for e in events
        )
        assert replacement_streamed

        # The store must now hold the assistant message under the original id.
        listed = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listed.json).data

        replaced_item = next(
            (item for item in stored_items if item.id == original_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, AssistantMessageItem)
        first_part = replaced_item.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == replacement_text
1402
1403
async def test_thread_item_replaced_event_with_widget():
    """A replaced widget item is streamed and persisted with the new widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the original widget, then replace it under the same id.
        for event_type, widget in (
            (ThreadItemDoneEvent, original_widget),
            (ThreadItemReplacedEvent, replacement_widget),
        ):
            yield event_type(
                item=WidgetItem(
                    id=widget_id,
                    widget=widget,
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[
                            UserMessageTextContent(text="Trigger widget replacement")
                        ],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        replacement_event = next(
            (
                e
                for e in events
                if e.type == "thread.item.replaced" and e.item.id == widget_id
            ),
            None,
        )
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The persisted widget item must carry the replacement content.
        listed = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listed.json).data

        replaced_item = next(
            (item for item in stored_items if item.id == widget_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, WidgetItem)
        stored_widget = replaced_item.widget
        assert isinstance(stored_widget, Card)
        first_child = stored_widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1477
1478
async def test_retry_after_item():
    """Retrying after a user message regenerates the assistant reply in place."""
    responder_calls: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        # The first call answers the original message; any later call is the retry.
        first_call = len(responder_calls) == 1
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1" if first_call else "assistant_msg_2",
                content=[
                    AssistantMessageContent(
                        text="First response" if first_call else "Retried response"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    items_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in create_events if e.type == "thread.created")

        async def list_items() -> list[ThreadItem]:
            # Items come back newest first.
            listed = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return items_adapter.validate_json(listed.json).data

        before = await list_items()
        assert len(before) == 2
        assert before[0].type == "assistant_message"
        assert before[0].content[0].type == "output_text"
        assert before[0].content[0].text == "First response"
        assert before[1].type == "user_message"
        assert before[1].content[0].text == "Hello, world!"

        # Retry after the user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=before[1].id
                )
            )
        )

        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder calls received the same user message.
        assert len(responder_calls) == 2
        original_input, retried_input = responder_calls
        assert original_input == retried_input
        assert original_input.type == "user_message"
        assert original_input.content[0].text == "Hello, world!"

        # Storage still holds one user message plus the regenerated reply.
        after = await list_items()
        assert len(after) == 2
        assert after[0].type == "assistant_message"
        assert after[0].content[0].type == "output_text"
        assert after[0].content[0].text == "Retried response"
        assert after[1].type == "user_message"
        assert after[1].content[0].text == "Hello, world!"
1576
1577
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in create_events if e.type == "thread.created")

        listed = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        # The newest item is the assistant reply.
        newest_item_id = (
            TypeAdapter(Page[ThreadItem]).validate_json(listed.json).data[0].id
        )

        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(
                ThreadsRetryAfterItemReq(
                    params=ThreadRetryAfterItemParams(
                        thread_id=thread.id, item_id=newest_item_id
                    )
                )
            )
1624
1625
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the latest reply.

    Items are listed newest first, so after two exchanges the thread reads
    ``[assistant2, user2, assistant1, user1]``.
    """
    responder_calls: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # 4 items, newest first: assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        second_user_message_id = items_before.data[1].id

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # The retry must have been driven by the second user message.
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        retried_input = responder_calls[2]
        assert retried_input is not None
        assert retried_input.type == "user_message"
        assert retried_input.content[0].text == "Second message"

        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4

        # The first exchange is untouched: same ids and same content.
        assert items_after.data[3].id == items_before.data[3].id
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].id == items_before.data[2].id
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # The second user message is kept; only the reply after it is new.
        assert items_after.data[1].id == second_user_message_id
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1730
1731
async def test_threads_create_passes_tools_to_responder():
    """The tool choice given on thread creation reaches the responder's input.

    The responder only records its input and the checks run in the test body.
    An exception raised inside the responder is turned into an ``error``
    stream event by the server instead of failing the test. In that case the
    user message's ``thread.item.done`` event would still satisfy a plain
    sanity check.
    """
    tool_choice = ToolChoice(id="web_search")
    received_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_create",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )

        # The request must have streamed through without a responder error.
        assert not any(e.type == "error" for e in events)
        assert any(e.type == "thread.item.done" for e in events)

        assert len(received_inputs) == 1
        received = received_inputs[0]
        assert received is not None
        assert received.inference_options.tool_choice is not None
        assert received.inference_options.tool_choice.id == tool_choice.id
1765
1766
async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message hands that message's tool choice to the responder.

    This replaces a placeholder ``pass`` body that always succeeded. The
    responder only records its inputs, and the checks run in the test body,
    because responder exceptions become ``error`` stream events rather than
    test failures.
    """
    tool_choice = ToolChoice(id="web_search")
    received_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_tools_retry_{len(received_inputs)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Retry with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        user_message = next(item for item in items if item.type == "user_message")

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message.id
                )
            )
        )
        assert not any(e.type == "error" for e in retry_events)
        assert any(e.type == "thread.item.done" for e in retry_events)

        # One call for the original message, one for the retry.
        assert len(received_inputs) == 2
        retried = received_inputs[1]
        assert retried is not None
        assert retried.id == user_message.id
        assert retried.inference_options.tool_choice is not None
        assert retried.inference_options.tool_choice.id == tool_choice.id
1769
1770
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice on an added user message reaches the responder's input.

    The responder only records its inputs and the checks run in the test body,
    because responder exceptions become ``error`` stream events rather than
    test failures.
    """
    tool_choice = ToolChoice(id="retrieval")
    received_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_add",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        assert not any(e.type == "error" for e in create_events)
        thread = next(e.thread for e in create_events if e.type == "thread.created")

        # Add a message with tools and verify they reach the responder.
        add_events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )
        assert not any(e.type == "error" for e in add_events)
        assert any(e.type == "thread.item.done" for e in add_events)

        # One responder call for the create, one for the added message.
        assert len(received_inputs) == 2
        for received in received_inputs:
            assert received is not None
            assert received.inference_options.tool_choice is not None
            assert received.inference_options.tool_choice.id == tool_choice.id
        latest = received_inputs[-1]
        assert latest is not None
        assert latest.content[0].text == "Second"
1820
1821
async def test_custom_id_generators_on_server():
    """Ids the server assigns come from the store's id-generator overrides.

    A single counter is shared by both generators. Creating a thread
    therefore yields ``thr_custom_1`` for the thread and
    ``message_custom_2_thr_custom_1`` for the stored user message.
    """

    class UniqueIdStore(SQLiteStore):
        """SQLiteStore whose generated ids are deterministic, counter-based strings."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits no events: this test only checks ids the server assigns itself
        # (the thread and the user message it stores).
        return
        yield  # unreachable; makes this function an async generator

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        db_path = cast(SQLiteStore, server.store).db_path
        server.store = UniqueIdStore(db_path)

        # Create a thread and verify the thread id and the user message id.
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        thread_created = next(e for e in events if e.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        user_message = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
1878