openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a3c67d6039f8d3b2d0abc05bee4dd02c490454c2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1797 lines · mode: code

1import sqlite3
2from contextlib import contextmanager
3from datetime import datetime
4from typing import Any, AsyncIterator, Callable, cast
5
6import pytest
7from helpers.mock_store import SQLiteStore
8from pydantic import AnyUrl, TypeAdapter
9
10from chatkit.actions import Action
11from chatkit.errors import ErrorCode
12from chatkit.server import (
13 ChatKitServer,
14 NonStreamingResult,
15 StreamingResult,
16 stream_widget,
17)
18from chatkit.store import AttachmentStore, NotFoundError
19from chatkit.types import (
20 AssistantMessageContent,
21 AssistantMessageItem,
22 Attachment,
23 AttachmentCreateParams,
24 AttachmentDeleteParams,
25 AttachmentsCreateReq,
26 AttachmentsDeleteReq,
27 ClientToolCallItem,
28 FeedbackKind,
29 FileAttachment,
30 ImageAttachment,
31 InferenceOptions,
32 ItemFeedbackParams,
33 ItemsFeedbackReq,
34 ItemsListParams,
35 ItemsListReq,
36 LockedStatus,
37 Page,
38 ProgressUpdateEvent,
39 Thread,
40 ThreadAddClientToolOutputParams,
41 ThreadAddUserMessageParams,
42 ThreadCreatedEvent,
43 ThreadCreateParams,
44 ThreadCustomActionParams,
45 ThreadDeleteParams,
46 ThreadGetByIdParams,
47 ThreadItem,
48 ThreadItemAddedEvent,
49 ThreadItemDoneEvent,
50 ThreadItemRemovedEvent,
51 ThreadItemReplacedEvent,
52 ThreadItemUpdatedEvent,
53 ThreadListParams,
54 ThreadMetadata,
55 ThreadRetryAfterItemParams,
56 ThreadsAddClientToolOutputReq,
57 ThreadsAddUserMessageReq,
58 ThreadsCreateReq,
59 ThreadsCustomActionReq,
60 ThreadsDeleteReq,
61 ThreadsGetByIdReq,
62 ThreadsListReq,
63 ThreadsRetryAfterItemReq,
64 ThreadStreamEvent,
65 ThreadsUpdateReq,
66 ThreadUpdatedEvent,
67 ThreadUpdateParams,
68 ToolChoice,
69 UserMessageInput,
70 UserMessageItem,
71 UserMessageTextContent,
72 WidgetItem,
73 WidgetRootUpdated,
74)
75from chatkit.widgets import Card, Text
76from tests._types import RequestContext
77from tests.test_store import make_thread_items
78
# Counter consumed by ``make_server``: each server takes the next value as the
# name of its own shared in-memory SQLite database.
server_id: int = 0


# Request context used by the server helpers when a test passes none.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
83
84
class InMemoryFileStore(AttachmentStore):
    """Dict-backed ``AttachmentStore`` used to exercise the attachment endpoints.

    Attachments live in ``self.files``, keyed by attachment id. Ids are drawn
    from a monotonically increasing counter. This means deleting an attachment
    never lets a later one reuse, and silently overwrite, an id that is still
    stored. The old ``len(self.files) + 1`` scheme had exactly that problem.
    """

    def __init__(self):
        self.files = {}
        # Next numeric suffix for generated ids. Unlike ``len(self.files)``,
        # this never shrinks after a delete.
        self._next_id = 1

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from the store.

        Raises:
            KeyError: if no attachment with that id is stored.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return an attachment described by ``input``.

        ``image/*`` mime types produce an ``ImageAttachment`` that also has a
        preview URL. Any other mime type produces a ``FileAttachment``. All URLs
        are fake ``https://example.com/<id>/...`` addresses.
        """
        attachment_id = self._allocate_id()
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            )
        self.files[attachment.id] = attachment
        return attachment

    def _allocate_id(self) -> str:
        """Return a fresh ``atc_<n>`` id that is not already a key of ``files``."""
        # Skip ids that a test may have seeded directly into ``files``.
        while True:
            candidate = f"atc_{self._next_id}"
            self._next_id += 1
            if candidate not in self.files:
                return candidate
114
115
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one server-sent event frame into a ``ThreadStreamEvent``.

    Everything after the first ``data: `` marker is parsed as JSON. Splitting
    only once matters. A payload whose text itself contains ``data: `` (for
    example, a user message quoting an SSE frame) would otherwise be cut off at
    that inner occurrence and fail validation.
    """
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
118
119
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Consume every SSE frame of ``streaming_result`` and decode it, keeping order."""
    frames = streaming_result.json_events
    return [decode_event(frame) async for frame in frames]
124
125
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` wired to test callbacks and a private SQLite store.

    Each call uses a new shared in-memory SQLite database named after the
    module-level ``server_id`` counter, so servers never share data.

    Args:
        responder: Backs ``respond``. When ``None``, ``respond`` yields no events.
        handle_feedback: Backs ``add_feedback``. When ``None``, feedback is
            ignored.
        action_callback: Backs ``action``. When ``None``, ``action`` raises
            ``ValueError``.
        file_store: Attachment store passed to ``ChatKitServer``. ``None`` means
            no attachment store is configured.

    Yields:
        The server. Its ``process_streaming`` and ``process_non_streaming``
        helpers serialize a request model with ``model_dump_json()`` and run it
        through ``process``, using ``DEFAULT_CONTEXT`` when no context is given.
        They then assert the expected result type.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last
    # connection is closed. Released in the ``finally`` below.
    db = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is None:
                raise ValueError("action_callback not wired up")
            return action_callback(thread, action, sender, context)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is None:
                return self._empty_responder()
            return responder(thread, input_user_message, context)

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # An async generator that finishes without yielding anything.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is None:
                return
            handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Run ``request_obj`` and return its decoded stream events."""
            # Only substitute the default when no context was passed at all.
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(
            self, request_obj, context: Any | None = None
        ) -> NonStreamingResult:
            """Run ``request_obj`` and return its non-streaming result."""
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        db.close()
206
207
async def test_flows_context_to_responder():
    """The per-request context is forwarded to ``respond`` and ``add_feedback``."""
    # The context each hook received, keyed by hook name.
    received: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received["responder"] = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        received["add_feedback"] = context

    request_context = RequestContext(user_id="text_ctx")
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, add_feedback) as server:
        events = await server.process_streaming(create_request, request_context)
        assert received.get("responder") == request_context

        created_thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        assistant_item = next(
            event.item
            for event in events
            if event.type == "thread.item.done"
            and event.item.type == "assistant_message"
        )

        await server.process_non_streaming(
            ItemsFeedbackReq(
                params=ItemFeedbackParams(
                    thread_id=created_thread.id,
                    item_ids=[assistant_item.id],
                    kind="positive",
                )
            ),
            request_context,
        )
        assert received.get("add_feedback") == request_context
267
268
async def test_creates_thread():
    """Creating a thread yields an untitled thread plus the echoed user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Respond with nothing, so only the server's own events appear.
        return
        yield

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
        quoted_text="Quoted text",
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(event for event in events if event.type == "thread.created")
        assert created.thread.title is None

        echoed = next(
            event.item for event in events if event.type == "thread.item.done"
        )
        assert echoed.type == "user_message"
        assert echoed.content[0].text == "Hello, world!"
        assert echoed.quoted_text == "Quoted text"
302
303
async def test_saves_thread_title():
    """A title set by the responder is persisted and announced by thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        # Load with the same context the request ran under. process_streaming
        # falls back to DEFAULT_CONTEXT, as in the other persistence tests.
        stored = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored.title == "Updated title"
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.title == "Updated title"
334
335
async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted and triggers thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        created_thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        persisted = await server.store.load_thread(created_thread.id, DEFAULT_CONTEXT)
        assert persisted.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
363
364
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and broadcast."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        # Load with the same context the request ran under (DEFAULT_CONTEXT).
        loaded = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded.status == LockedStatus(reason="Because")
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.status == LockedStatus(reason="Because")
394
395
async def test_emits_thread_updated_mid_stream_and_persists():
    """A thread change made mid-stream is emitted after the next item and persisted."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            # Build a finished assistant message for this thread.
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        yield assistant_done("assistant_a", "A")
        # Change the thread between items. thread.updated is expected right
        # after the following item.
        thread.title = "Mid-stream title"
        yield assistant_done("assistant_b", "B")
        # Streaming should carry on after the update event.
        yield assistant_done("assistant_c", "C")

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created_thread = next(
            event.thread for event in events if isinstance(event, ThreadCreatedEvent)
        )

        b_index = next(
            index
            for index, event in enumerate(events)
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, AssistantMessageItem)
            and event.item.id == "assistant_b"
        )

        # thread.updated immediately follows assistant_b ...
        assert isinstance(events[b_index + 1], ThreadUpdatedEvent)
        update_event = cast(ThreadUpdatedEvent, events[b_index + 1])
        assert update_event.thread.title == "Mid-stream title"

        # ... and assistant_c follows the update.
        assert isinstance(events[b_index + 2], ThreadItemDoneEvent)
        following = cast(ThreadItemDoneEvent, events[b_index + 2])
        assert isinstance(following.item, AssistantMessageItem)
        assert following.item.id == "assistant_c"

        persisted = await server.store.load_thread(created_thread.id, DEFAULT_CONTEXT)
        assert persisted.title == "Mid-stream title"
470
471
async def test_respond_with_tool_call():
    """A client tool call ends the turn; submitting its output resumes the response."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # Invoked again after the client submitted the tool output.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(input, UserMessageItem):
            # A fresh user turn asks the client to run a tool.
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        first_turn = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )

        assert len(first_turn) == 3
        created, user_done, tool_done = first_turn

        assert created.type == "thread.created"
        thread = created.thread

        assert user_done.type == "thread.item.done"
        assert user_done.item.type == "user_message"

        assert tool_done.type == "thread.item.done"
        tool_call = tool_done.item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        second_turn = await server.process_streaming(
            ThreadsAddClientToolOutputReq(
                params=ThreadAddClientToolOutputParams(
                    thread_id=thread.id,
                    result={"text": "Wow!"},
                )
            )
        )

        assert len(second_turn) == 1
        (reply,) = second_turn
        assert reply.type == "thread.item.done"
        assert reply.item.type == "assistant_message"
538
539
async def test_removes_tool_call_if_no_output_provided():
    """A pending client tool call is dropped when the user sends another message instead."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert isinstance(input, UserMessageItem)
        first_part = input.content[0]
        assert first_part.type == "input_text"
        if first_part.text != "Message 1":
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[AssistantMessageContent(text="All done!")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        else:
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    def user_input(text: str) -> UserMessageInput:
        return UserMessageInput(
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input("Message 1")))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # Send a new message instead of submitting the tool output.
        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=user_input("Message 2"),
                )
            )
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        # Newest first. The unanswered client_tool_call is no longer listed.
        assert len(page.data) == 3
        assert [item.type for item in page.data] == [
            "assistant_message",
            "user_message",
            "user_message",
        ]
604
605
async def test_respond_with_tool_status():
    """Progress updates are streamed in the order the responder yields them."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        answer = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=answer)

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert len(events) == 4
        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "progress_update",
            "thread.item.done",
        ]
638
639
async def test_list_threads_response():
    """Threads are listed newest first and paginate via ``has_more`` / ``after``."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {index}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20

        # Newest first
        newest, runner_up = first_page.data[0], first_page.data[1]
        assert newest.created_at > runner_up.created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
672
673
async def test_list_items_response():
    """Thread items are listed newest first, with the first page holding 20."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for n in range(25):
            yield ThreadItemDoneEvent(
                item=UserMessageItem(
                    id=f"msg_{n}",
                    content=[UserMessageTextContent(text=f"Message {n}")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    thread_id=thread.id,
                    created_at=datetime.now(),
                ),
            )

    page_adapter = TypeAdapter(Page[ThreadItem])
    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Thread 1")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 6  # +1 for user message
        assert second_page.has_more is False
723
724
async def test_calls_action():
    """Custom actions reach ``action`` with the sending widget and stream its updates."""
    actions = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(key="text_1", value="test")]),
        )
        yield ThreadItemDoneEvent(item=widget_item)

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        actions.append((action, sender))
        assert sender

        replacement = Card(children=[Text(value="Email sent!")])
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Show widget")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server(responder, action_callback=action) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        widget_item = next(
            event.item
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WidgetItem)
        )

        action_events = await server.process_streaming(
            ThreadsCustomActionReq(
                params=ThreadCustomActionParams(
                    thread_id=thread.id,
                    item_id=widget_item.id,
                    action=Action(type="create_user", payload={"user_id": "123"}),
                )
            )
        )

        assert actions
        assert actions[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert len(action_events) == 1
        (update_event,) = action_events
        assert update_event.type == "thread.item.updated"
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(
            children=[Text(value="Email sent!")]
        )
813
814
async def test_add_feedback():
    """Feedback requests are forwarded to ``add_feedback`` with their thread, items and kind."""
    called = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        called.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1",
                content=[AssistantMessageContent(text="Feedback received!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                widget=Card(children=[Text(key="text", value="Widget content")]),
                thread_id=thread.id,
            ),
        )

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Test feedback")],
        attachments=[],
        inference_options=InferenceOptions(),
    )
    # (item id, feedback kind) pairs, sent in this order.
    feedback_requests: list[tuple[str, FeedbackKind]] = [
        ("assistant_msg_1", "positive"),
        ("widget_1", "negative"),
    ]

    with make_server(responder, handle_feedback=handle_feedback) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        for item_id, kind in feedback_requests:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=thread.id,
                        item_ids=[item_id],
                        kind=kind,
                    )
                )
            )

        assert len(called) == 2
        first_thread_id, first_item_ids, first_kind, _ = called[0]
        assert first_thread_id == thread.id
        assert first_item_ids == ["assistant_msg_1"]
        assert first_kind == "positive"

        second_thread_id, second_item_ids, second_kind, _ = called[1]
        assert second_thread_id == thread.id
        assert second_item_ids == ["widget_1"]
        assert second_kind == "negative"
889
890
async def test_get_thread_by_id():
    """``threads.get_by_id`` returns the thread together with its first user message."""
    with make_server() as server:
        # Create a thread by sending a ThreadsCreateReq.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test thread")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # Retrieve the thread by id.
        result = await server.process_non_streaming(
            ThreadsGetByIdReq(
                params=ThreadGetByIdParams(thread_id=thread.id),
            )
        )
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)
        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title
        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
922
923
async def test_create_file():
    """Non-image uploads come back as ``file`` attachments and are kept in the store."""
    store = InMemoryFileStore()
    name = "test-file-name"
    mime_type = "text/plain"
    create_request = AttachmentsCreateReq(
        params=AttachmentCreateParams(name=name, size=1024, mime_type=mime_type)
    )

    with make_server(file_store=store) as server:
        response = await server.process_non_streaming(create_request)
        created = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        # The response describes the new file.
        assert created.id is not None
        assert created.mime_type == mime_type
        assert created.name == name
        assert created.type == "file"
        expected_upload = AnyUrl(f"https://example.com/{created.id}/upload")
        assert created.upload_url == expected_upload
        assert created.id in store.files
948
949
async def test_create_image_file():
    """``image/*`` uploads come back as ``image`` attachments with preview and upload URLs."""
    store = InMemoryFileStore()
    name = "test-file-name"
    mime_type = "image/png"
    create_request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name=name,
            size=1024,
            mime_type=mime_type,
        )
    )

    with make_server(file_store=store) as server:
        response = await server.process_non_streaming(create_request)
        created = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert created.id is not None
        assert created.mime_type == mime_type
        assert created.name == name
        assert created.type == "image"

        base_url = f"https://example.com/{created.id}"
        assert created.preview_url == AnyUrl(f"{base_url}/preview")
        assert created.upload_url == AnyUrl(f"{base_url}/upload")

        assert created.id in store.files
979
980
async def test_create_file_without_filestore():
    """Creating an attachment raises ``RuntimeError`` when no file store is configured."""
    with make_server() as server:
        # Scope the expectation to the request itself, so that a RuntimeError
        # raised while building the server cannot make this test pass.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsCreateReq(
                    params=AttachmentCreateParams(
                        name="test-file-name",
                        size=1024,
                        mime_type="text/plain",
                    )
                )
            )
993
994
async def test_delete_file():
    """Deleting an attachment removes it from the configured file store."""
    attachment_id = "test-file-id"
    store = InMemoryFileStore()
    store.files[attachment_id] = {"id": attachment_id}

    delete_request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id=attachment_id)
    )
    with make_server(file_store=store) as server:
        await server.process_non_streaming(delete_request)
        assert store.files == {}
1005
1006
async def test_delete_file_without_filestore():
    """Deleting an attachment raises ``RuntimeError`` when no file store is configured."""
    with make_server() as server:
        # Only the delete request is expected to raise, not server setup.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsDeleteReq(
                    params=AttachmentDeleteParams(attachment_id="test-file-id")
                )
            )
1015
1016
async def test_list_threads_paginated_response():
    """A 20-thread first page reports more results; the next page holds the rest."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    def create_request(index: int) -> ThreadsCreateReq:
        return ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text=f"Thread {index}")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )

    with make_server() as server:
        for index in range(25):
            await server.process_streaming(create_request(index))

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        cursor = first_page.data[-1].id
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1045
1046
async def test_threads_update_request():
    """``threads.update`` renames a thread and returns the renamed thread."""
    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsUpdateReq(
                params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
            )
        )
        renamed = TypeAdapter[Thread](Thread).validate_json(response.json)
        assert renamed.title == "New Title"
1070
1071
async def test_returns_widget_item_generator():
    """Extending a Text value streams value deltas, ending with a done item."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    full_text = "Hello, world"

    async def widget_generator():
        # Reveal longer and longer prefixes of the same Text component.
        for prefix_length in (0, 3, 6, 12):
            yield Card(children=[Text(id="text", value=full_text[:prefix_length])])

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5
    added, *updates, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    for update_event, expected_delta in zip(updates, ["Hel", "lo,", " world"]):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(children=[Text(id="text", value=full_text)])
1113
1114
async def test_returns_widget_item_generator_full_replace():
    """Replacing a Text value outright (not by appending) sends a root update."""
    thread_meta = ThreadMetadata(
        id="test-thread-id",
        title="Test thread",
        metadata={},
        created_at=datetime.now(),
    )

    def card(value: str) -> Card:
        # Build a fresh card each call so no widget instance is shared.
        return Card(children=[Text(id="text", value=value)])

    async def two_cards():
        yield card("Hello")
        yield card("World")

    events = [e async for e in stream_widget(thread_meta, two_cards())]
    assert len(events) == 3
    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card("Hello")

    # "World" is not a continuation of "Hello", so the test expects the whole
    # root widget to be replaced rather than a text delta.
    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == card("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card("World")
1138
1139
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and makes it unfetchable."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Fetching the deleted thread must raise NotFoundError naming the thread.
        # pytest.raises fails the test if nothing is raised; unlike a bare
        # `assert False`, it is not stripped under `python -O`.
        with pytest.raises(NotFoundError) as excinfo:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(excinfo.value)
1171
1172
async def test_thread_item_removed_event_removes_item():
    """A ThreadItemRemovedEvent from the responder drops that item from the thread."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        # Emit the item as done, then immediately ask for it to be removed.
        yield ThreadItemDoneEvent(item=doomed)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger removal")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        removal_seen = any(
            e.type == "thread.item.removed" and e.item_id == removed_id
            for e in events
        )
        assert removal_seen

        # The removed item must no longer be listed for the thread.
        page_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(page_result.json)
        stored_ids = {item.id for item in page.data}
        assert removed_id not in stored_ids
1217
1218
async def test_raising_in_responder_yields_error_event():
    """An exception in the responder surfaces as a STREAM_ERROR event.

    After the error event the stream must terminate.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; makes this function an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named `stream` rather than `iter` to avoid shadowing the builtin.
        stream = result.__aiter__()

        # Consume events until the error event appears.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # The stream must end right after the error event.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1260
1261
async def test_thread_item_replaced_event_replaces_item():
    """A ThreadItemReplacedEvent swaps the stored item for one with the same id."""
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        original = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original)

        # Reuse the same id, but change both the item type and its content.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Trigger replacement")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The replacement must have been streamed back to the client.
        replaced_events = [
            e
            for e in events
            if e.type == "thread.item.replaced" and e.item.id == original_id
        ]
        assert replaced_events

        # ...and persisted in place of the original item.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        replaced_item = None
        for candidate in page.data:
            if candidate.id == original_id:
                replaced_item = candidate
                break

        assert replaced_item is not None
        assert isinstance(replaced_item, AssistantMessageItem)
        first_part = replaced_item.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1324
1325
async def test_thread_item_replaced_event_with_widget():
    """Replacing a widget item updates both the streamed event and the stored item."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def widget_item(thread_id: str, widget: Card) -> WidgetItem:
        # Both the original and the replacement share the same item id.
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread_id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(item=widget_item(thread.id, original_widget))
        yield ThreadItemReplacedEvent(item=widget_item(thread.id, replacement_widget))

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[
                            UserMessageTextContent(text="Trigger widget replacement")
                        ],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # The streamed replacement event carries the new widget.
        replacement_event = None
        for e in events:
            if e.type == "thread.item.replaced" and e.item.id == widget_id:
                replacement_event = e
                break
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The stored item now holds the replacement widget.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        replaced_item = next(
            (item for item in stored_items if item.id == widget_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, WidgetItem)
        stored_widget = replaced_item.widget
        assert isinstance(stored_widget, Card)
        first_child = stored_widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1399
1400
async def test_retry_after_item():
    """Retrying after a user message reruns the responder and replaces its reply.

    The responder must receive the same user message again, and the thread must
    still hold exactly one user message and one (regenerated) assistant message.
    """
    seen_inputs = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_inputs.append(input)
        # The first call answers normally; any later call is the retry.
        if len(seen_inputs) == 1:
            reply_id, reply_text = "assistant_msg_1", "First response"
        else:
            reply_id, reply_text = "assistant_msg_2", "Retried response"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=reply_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    def check_exchange(page: Page[ThreadItem], expected_reply: str) -> None:
        # Expected listing order: the assistant reply first, then the user message.
        assert len(page.data) == 2
        reply, prompt = page.data
        assert reply.type == "assistant_message"
        assert reply.content[0].type == "output_text"
        assert reply.content[0].text == expected_reply
        assert prompt.type == "user_message"
        assert prompt.content[0].text == "Hello, world!"

    with make_server(responder) as server:
        created = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in created if e.type == "thread.created")

        async def fetch_items() -> Page[ThreadItem]:
            listing = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(listing.json)

        before = await fetch_items()
        check_exchange(before, "First response")

        # Retry after the user message (second entry in the listing).
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=before.data[1].id
                )
            )
        )

        # Exactly one new event: the regenerated assistant reply.
        assert len(retry_events) == 1
        (retry_done,) = retry_events
        assert retry_done.type == "thread.item.done"
        assert retry_done.item.type == "assistant_message"
        assert retry_done.item.content[0].type == "output_text"
        assert retry_done.item.content[0].text == "Retried response"

        # The responder saw the very same user message on both calls.
        assert len(seen_inputs) == 2
        first_input, retried_input = seen_inputs
        assert first_input == retried_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        # Storage now holds the user message plus the new reply only.
        check_exchange(await fetch_items(), "Retried response")
1497
1498
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        created = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in created if e.type == "thread.created")

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        # The first listed item is the assistant reply emitted above.
        assistant_item_id = page.data[0].id

        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1545
1546
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the reply to it.

    Earlier turns of the conversation (the first user message and its reply)
    must be left untouched.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                # Unique per call so a regenerated reply never reuses an old id.
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Four items, listed newest first:
        # [assistant2, user2, assistant1, user1].
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Guard the index assumption so a different ordering fails here, rather
        # than silently retrying after the wrong item.
        second_user_message = items_before.data[1]
        assert second_user_message.type == "user_message"
        assert second_user_message.content[0].text == "Second message"

        # Retry after the second user message
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message.id
                )
            )
        )

        assert len(retry_events) == 1
        assert retry_events[0].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # First user message and its response remain unchanged.
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].id == items_before.data[2].id
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # Second user message remains, but its response is new.
        assert items_after.data[1].id == second_user_message.id
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1650
1651
async def test_threads_create_passes_tools_to_responder():
    """The tool choice given when creating a thread reaches the responder."""
    tool_choice = ToolChoice(id="web_search")
    assistant_item_id = "assistant_msg_tools_create"
    received_inputs = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Record the input and check it after streaming. An assertion failing
        # inside the responder could surface only as a stream error event
        # instead of failing the test.
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=assistant_item_id,
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        # The user message also yields a "thread.item.done" event, so check for
        # the assistant reply specifically to prove the responder ran.
        assert any(
            e.type == "thread.item.done" and e.item.id == assistant_item_id
            for e in events
        )

    assert len(received_inputs) == 1
    received = received_inputs[0]
    assert received is not None
    assert received.inference_options.tool_choice is not None
    assert received.inference_options.tool_choice.id == tool_choice.id
1685
1686
async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message hands its tool choice to the responder again."""
    tool_choice = ToolChoice(id="web_search")
    received_inputs = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Record inputs and check them after streaming, so a mismatch fails the
        # test instead of possibly surfacing only as a stream error event.
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                # Unique per call so the retried reply gets its own id.
                id=f"assistant_msg_tools_retry_{len(received_inputs)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        user_message = next(i for i in items if i.type == "user_message")

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message.id
                )
            )
        )
        # The regenerated reply proves the responder ran for the retry.
        assert any(
            e.type == "thread.item.done"
            and e.item.id == "assistant_msg_tools_retry_2"
            for e in retry_events
        )

    # One responder call for thread creation, one for the retry.
    assert len(received_inputs) == 2
    retried_input = received_inputs[1]
    assert retried_input is not None
    assert retried_input.inference_options.tool_choice is not None
    assert retried_input.inference_options.tool_choice.id == tool_choice.id
1689
1690
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice on an added user message reaches the responder."""
    tool_choice = ToolChoice(id="retrieval")
    received_inputs = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Record inputs and check them after streaming, so a mismatch fails the
        # test instead of possibly surfacing only as a stream error event.
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                # One id per call so the replies from the two turns don't collide.
                id=f"assistant_msg_tools_add_{len(received_inputs)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Add a message with tools and verify they reach the responder.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )
        # The user message also yields a "thread.item.done" event, so check for
        # the assistant reply to the added message specifically.
        assert any(
            e.type == "thread.item.done" and e.item.id == "assistant_msg_tools_add_2"
            for e in events
        )

    # One responder call for thread creation, one for the added message.
    assert len(received_inputs) == 2
    assert received_inputs[1] is not None
    assert received_inputs[1].content[0].text == "Second"
    for received in received_inputs:
        assert received is not None
        assert received.inference_options.tool_choice is not None
        assert received.inference_options.tool_choice.id == tool_choice.id
1740
1741
async def test_custom_id_generators_on_server():
    """The server takes thread and item ids from the store's id generators."""

    class UniqueIdStore(SQLiteStore):
        """SQLite store whose generated ids come from one shared sequential counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context: Any) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context: Any,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits no events: only the ids generated for the thread and the user
        # message are under test.
        return
        yield  # unreachable; makes this function an async generator

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        db_path = cast(SQLiteStore, server.store).db_path
        server.store = UniqueIdStore(db_path)

        # Create thread and verify thread id and user message id
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        thread_created = next(e for e in events if e.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        user_message = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        # Counter value 1 went to the thread id, so the user message gets 2.
        assert user_message.id == "message_custom_2_thr_custom_1"
        assert user_message.thread_id == "thr_custom_1"
1798