openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python (Available)

CodeCommitsIssuesPull requestsActionsInsightsSecurity
e0747eec143e784337b181f8fd651d38ddfcae8f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1893 lines · mode: code

1import asyncio
2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14 ChatKitServer,
15 NonStreamingResult,
16 StreamingResult,
17 stream_widget,
18)
19from chatkit.store import (
20 AttachmentStore,
21 NotFoundError,
22 StoreItemType,
23 default_generate_id,
24)
25from chatkit.types import (
26 AssistantMessageContent,
27 AssistantMessageContentPartTextDelta,
28 AssistantMessageItem,
29 Attachment,
30 AttachmentCreateParams,
31 AttachmentDeleteParams,
32 AttachmentsCreateReq,
33 AttachmentsDeleteReq,
34 AttachmentUploadDescriptor,
35 ClientToolCallItem,
36 FeedbackKind,
37 FileAttachment,
38 ImageAttachment,
39 InferenceOptions,
40 ItemFeedbackParams,
41 ItemsFeedbackReq,
42 ItemsListParams,
43 ItemsListReq,
44 LockedStatus,
45 Page,
46 ProgressUpdateEvent,
47 Thread,
48 ThreadAddClientToolOutputParams,
49 ThreadAddUserMessageParams,
50 ThreadCreatedEvent,
51 ThreadCreateParams,
52 ThreadCustomActionParams,
53 ThreadDeleteParams,
54 ThreadGetByIdParams,
55 ThreadItem,
56 ThreadItemAddedEvent,
57 ThreadItemDoneEvent,
58 ThreadItemRemovedEvent,
59 ThreadItemReplacedEvent,
60 ThreadItemUpdatedEvent,
61 ThreadListParams,
62 ThreadMetadata,
63 ThreadRetryAfterItemParams,
64 ThreadsAddClientToolOutputReq,
65 ThreadsAddUserMessageReq,
66 ThreadsCreateReq,
67 ThreadsCustomActionReq,
68 ThreadsDeleteReq,
69 ThreadsGetByIdReq,
70 ThreadsListReq,
71 ThreadsRetryAfterItemReq,
72 ThreadStreamEvent,
73 ThreadsUpdateReq,
74 ThreadUpdatedEvent,
75 ThreadUpdateParams,
76 ToolChoice,
77 UserMessageInput,
78 UserMessageItem,
79 UserMessageTextContent,
80 WidgetItem,
81 WidgetRootUpdated,
82)
83from chatkit.widgets import Card, Text
84from tests._types import RequestContext
85from tests.test_store import make_thread_items
86
# Counter read and incremented by make_server() so that every server gets a
# differently named in-memory SQLite database.
server_id: int = 0


# Context handed to the server by the process_* helpers when a test passes none.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
91
92
class InMemoryFileStore(AttachmentStore):
    """Attachment store for tests that keeps every attachment in a dict.

    ``files`` maps attachment id to the stored attachment. It is a plain
    public dict so tests can seed or inspect it directly.
    """

    def __init__(self):
        self.files = {}
        # Highest numeric suffix handed out so far. Ids are derived from this
        # counter (not only from ``len(self.files)``) so that an id is never
        # issued twice once an attachment has been deleted.
        self._last_issued = 0

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from ``files``.

        Raises:
            KeyError: if no attachment with that id is stored.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return an attachment described by ``input``.

        Mime types starting with ``image/`` produce an ``ImageAttachment``
        (with a preview URL); everything else produces a ``FileAttachment``.
        Both carry the same kind of ``PUT`` upload descriptor.
        """
        # Without deletions this yields the same "len + 1" numbering as
        # before; after a deletion it keeps counting up instead of reusing
        # (and silently overwriting) an id that is still in use.
        self._last_issued = max(self._last_issued, len(self.files)) + 1
        attachment_id = f"atc_{self._last_issued}"

        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
            )
        self.files[attachment.id] = attachment
        return attachment
130
131
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one serialized stream chunk into a ``ThreadStreamEvent``.

    The JSON payload is taken to be everything after the first ``b"data: "``
    marker in ``event``.

    Raises:
        IndexError: if ``event`` contains no ``b"data: "`` marker.
    """
    # Split only once: the JSON payload itself may legitimately contain the
    # text "data: " (e.g. inside a message), and must not be cut short there.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
134
135
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and return the decoded events in order."""
    decoded: list[ThreadStreamEvent] = []
    async for chunk in streaming_result.json_events:
        decoded.append(decode_event(chunk))
    return decoded
140
141
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` subclass instance wired to the given callbacks.

    Each call uses its own named shared-cache in-memory SQLite database.

    Args:
        responder: Called by ``respond``; when omitted, ``respond`` streams
            no events.
        handle_feedback: Called by ``add_feedback``; when omitted, feedback
            is ignored.
        action_callback: Called by ``action``; when omitted, ``action``
            raises ``ValueError``.
        file_store: Attachment store passed to the server constructor.
    """
    global server_id
    database_uri = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last connection is closed
    keepalive_connection = sqlite3.connect(database_uri, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(database_uri), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process ``request_obj`` and return every streamed event, decoded."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, StreamingResult)
            return await decode_streaming_result(outcome)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process ``request_obj`` and return the ``NonStreamingResult``."""
            payload = request_obj.model_dump_json()
            outcome = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(outcome, NonStreamingResult)
            return outcome

    try:
        yield TestChatKitServer()
    finally:
        keepalive_connection.close()
222
223
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream stores the partial assistant text plus a hidden-context note."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id="assistant-message-pending",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(
            item_id="assistant-message-pending",
            update=text_delta,
        )
        # Cancel while the assistant message is still pending.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for chunk in stream.json_events:
                received.append(decode_event(chunk))

        thread = next(evt.thread for evt in received if evt.type == "thread.created")

        # Ask the store for just the single newest item of the thread.
        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The added text and the streamed delta are stored as one message.
        stored_message = await server.store.load_item(
            thread.id, "assistant-message-pending", DEFAULT_CONTEXT
        )
        assert stored_message.type == "assistant_message"
        assert stored_message.content[0].text == "Hello, World!"
293
294
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream drops a pending assistant message that has no content."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id="assistant-message-pending",
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        # Cancel before any content has been streamed into the message.
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for chunk in stream.json_events:
                received.append(decode_event(chunk))

        thread = next(evt.thread for evt in received if evt.type == "thread.created")

        # Ask the store for just the single newest item of the thread.
        newest_page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest_page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The empty assistant message must not be found in the store.
        with pytest.raises(NotFoundError):
            await server.store.load_item(
                thread.id, "assistant-message-pending", DEFAULT_CONTEXT
            )
355
356
async def test_flows_context_to_responder():
    """The request context reaches both the responder and the feedback handler."""
    # Contexts observed by each callback; ``None`` until the callback runs.
    seen_contexts: dict[str, Any] = {"responder": None, "feedback": None}

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_contexts["responder"] = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen_contexts["feedback"] = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, add_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request, context)
        assert seen_contexts["responder"] == context

        thread = next(evt.thread for evt in events if evt.type == "thread.created")
        assistant_item = next(
            evt.item
            for evt in events
            if evt.type == "thread.item.done" and evt.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen_contexts["feedback"] == context
416
417
async def test_creates_thread():
    """Creating a thread emits an untitled thread and echoes the user message."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
            quoted_text="Quoted text",
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created_event = next(evt for evt in events if evt.type == "thread.created")
        assert created_event.thread.title is None

        # The first completed item is the user's own message.
        first_done_item = next(
            evt.item for evt in events if evt.type == "thread.item.done"
        )
        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
451
452
async def test_saves_thread_title():
    """A title set by the responder is stored and announced via ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.title == "Updated title"

        # The update is reported as the last event of the stream.
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.title == "Updated title"
483
484
async def test_saves_thread_metadata():
    """Metadata written by the responder is stored and triggers ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
512
513
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is stored and sent in ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.status == LockedStatus(reason="Because")

        # The update is reported as the last event of the stream.
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.status == LockedStatus(reason="Because")
543
544
async def test_emits_thread_updated_mid_stream_and_persists():
    """A mid-stream title change emits ``thread.updated`` after the next item and is stored."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            # Build a completed assistant message carrying ``text``.
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        # First assistant message
        yield assistant_done("assistant_a", "A")

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done("assistant_b", "B")

        # Continue streaming after the update event
        yield assistant_done("assistant_c", "C")

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        # Find the created thread
        thread = next(
            evt.thread for evt in events if isinstance(evt, ThreadCreatedEvent)
        )

        # Ensure a thread.updated event occurs right after assistant_b
        position_b = next(
            index
            for index, evt in enumerate(events)
            if isinstance(evt, ThreadItemDoneEvent)
            and isinstance(evt.item, AssistantMessageItem)
            and evt.item.id == "assistant_b"
        )
        following = events[position_b + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        updated_event = cast(ThreadUpdatedEvent, following)
        assert updated_event.thread.title == "Mid-stream title"

        # Streaming continues after the update event
        after_update = events[position_b + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        done_event = cast(ThreadItemDoneEvent, after_update)
        assert isinstance(done_event.item, AssistantMessageItem)
        assert done_event.item.id == "assistant_c"

        # Persisted in the store
        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.title == "Mid-stream title"
619
620
async def test_respond_with_tool_call():
    """A client tool call is streamed, and its output triggers a second response."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if isinstance(user_message, UserMessageItem):
            # A user message was supplied: answer with a client tool call.
            tool_call = ClientToolCallItem(
                id="msg_1",
                created_at=datetime.now(),
                name="tool_call_1",
                arguments={"arg1": "val1", "arg2": False},
                call_id="tool_call_1",
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=tool_call)
        elif user_message is None:
            # No user message was supplied: answer with an assistant message.
            follow_up = AssistantMessageItem(
                id="msg_2",
                content=[
                    AssistantMessageContent(text="Glad the tool call succeeded!")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=follow_up)

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert len(events) == 4
        assert events[0].type == "thread.created"
        thread = events[0].thread

        assert events[1].type == "thread.item.done"
        assert events[1].item.type == "user_message"

        assert events[2].type == "stream_options"

        assert events[3].type == "thread.item.done"
        assert events[3].item.type == "client_tool_call"
        assert events[3].item.id == "msg_1"
        assert events[3].item.name == "tool_call_1"
        assert events[3].item.arguments == {"arg1": "val1", "arg2": False}
        assert events[3].item.call_id == "tool_call_1"

        tool_output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(tool_output_request)

        assert len(events) == 2
        assert events[0].type == "stream_options"
        assert events[1].type == "thread.item.done"
        assert events[1].item.type == "assistant_message"
690
691
async def test_respond_with_tool_status():
    """A progress update from the responder is streamed before the final item."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert len(events) == 5
        assert events[0].type == "thread.created"
        assert events[1].type == "thread.item.done"
        assert events[2].type == "stream_options"
        assert events[3].type == "progress_update"
        assert events[4].type == "thread.item.done"
725
726
async def test_list_threads_response():
    """Listing 25 threads returns 20 newest-first, then the remaining 5."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            user_input = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
            )

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        threads = page_adapter.validate_json(first_result.json)
        assert len(threads.data) == 20

        # Newest first
        assert threads.data[0].created_at > threads.data[1].created_at
        assert threads.has_more is True
        assert threads.after is not None

        # Continue after the last thread of the first page.
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=threads.data[-1].id))
        )
        threads = page_adapter.validate_json(second_result.json)
        assert len(threads.data) == 5
        assert threads.has_more is False
        assert not threads.after
759
760
async def test_list_items_response():
    """Listing a thread's 26 items returns 20 newest-first, then the remaining 6."""

    async def responder(
        thread: ThreadMetadata, user_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for index in range(25):
            message = UserMessageItem(
                id=f"msg_{index}",
                content=[UserMessageTextContent(text=f"Message {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    page_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Thread 1")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = page_adapter.validate_json(first_result.json)
        assert len(items.data) == 20
        # Newest first
        assert items.data[0].created_at > items.data[1].created_at
        assert items.has_more is True
        assert items.after is not None

        # Continue after the last item of the first page.
        next_page_params = ItemsListParams(
            thread_id=thread.id, after=items.data[-1].id
        )
        second_result = await server.process_non_streaming(
            ItemsListReq(params=next_page_params)
        )
        items = page_adapter.validate_json(second_result.json)
        assert len(items.data) == 6  # +1 for user message
        assert items.has_more is False
810
811
async def test_calls_action():
    """A custom action request reaches the action callback with its widget sender."""
    # (action, sender) pairs received by the action callback.
    recorded_calls = []

    async def responder(
        thread: ThreadMetadata,
        user_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget = Card(
            children=[
                Text(
                    key="text_1",
                    value="test",
                )
            ],
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=widget,
            ),
        )

    async def on_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        recorded_calls.append((action, sender))
        assert sender

        replacement = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    with make_server(responder, action_callback=on_action) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")
        widget_item = next(
            evt.item
            for evt in events
            if isinstance(evt, ThreadItemDoneEvent) and isinstance(evt.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert recorded_calls
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert recorded_calls[0] == expected_call

        assert len(events) == 2
        assert events[0].type == "stream_options"
        assert events[1].type == "thread.item.updated"
        assert isinstance(events[1], ThreadItemUpdatedEvent)
        assert events[1].update.type == "widget.root.updated"
        assert events[1].update.widget == Card(children=[Text(value="Email sent!")])
901
902
async def test_add_feedback():
    """Each feedback request is forwarded with its thread id, item ids and kind."""
    # One (thread_id, item_ids, feedback, context) tuple per handler call.
    feedback_calls = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        feedback_calls.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        user_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_message = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_message)

        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget_item)

    with make_server(responder, handle_feedback=handle_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test feedback")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        positive_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["assistant_msg_1"],
                kind="positive",
            )
        )
        await server.process_non_streaming(positive_request)

        negative_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=["widget_1"],
                kind="negative",
            )
        )
        await server.process_non_streaming(negative_request)

        assert len(feedback_calls) == 2

        first_thread_id, first_item_ids, first_kind, _ = feedback_calls[0]
        assert first_thread_id == thread.id
        assert first_item_ids == ["assistant_msg_1"]
        assert first_kind == "positive"

        second_thread_id, second_item_ids, second_kind, _ = feedback_calls[1]
        assert second_thread_id == thread.id
        assert second_item_ids == ["widget_1"]
        assert second_kind == "negative"
977
978
async def test_get_thread_by_id():
    """Fetching a thread by id returns it together with its first user message."""
    # NOTE(review): the return value is discarded here — confirm whether this
    # call is still needed.
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test thread")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        # Retrieve the thread by id
        get_request = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        result = await server.process_non_streaming(get_request)
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)

        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title
        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1010
1011
async def test_create_file():
    """Creating a ``text/plain`` attachment returns and stores a file attachment."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=file_name, size=1024, mime_type=file_content_type
        )
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        # The attachment is also recorded in the store.
        assert attachment.id in store.files
1039
1040
async def test_create_image_file():
    """Creating an ``image/png`` attachment returns an image attachment with a preview URL."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "image/png"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=file_name,
            size=1024,
            mime_type=file_content_type,
        )
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "image"

        expected_preview_url = AnyUrl(f"https://example.com/{attachment.id}/preview")
        assert attachment.preview_url == expected_preview_url

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        # The attachment is also recorded in the store.
        assert attachment.id in store.files
1073
1074
async def test_create_file_without_filestore():
    """Creating an attachment on a server built without a file store raises ``RuntimeError``."""
    create_request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name="test-file-name",
            size=1024,
            mime_type="text/plain",
        )
    )
    with make_server() as server:
        # Only the request itself is expected to fail. Keeping server setup
        # outside pytest.raises prevents an unrelated RuntimeError raised
        # while building the server from making this test pass.
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(create_request)
1087
1088
async def test_delete_file():
    """Deleting an attachment removes it from the attachment store."""
    attachment_id = "test-file-id"
    store = InMemoryFileStore()
    # Seed the store directly so there is something to delete.
    store.files[attachment_id] = {"id": "test-file-id"}

    with make_server(file_store=store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=attachment_id)
        )
        await server.process_non_streaming(delete_request)
        assert store.files == {}
1099
1100
async def test_delete_file_without_filestore():
    """Deleting an attachment raises ``RuntimeError`` when ``make_server`` gets no file store."""
    with pytest.raises(RuntimeError), make_server() as server:
        delete_params = AttachmentDeleteParams(attachment_id="test-file-id")
        await server.process_non_streaming(AttachmentsDeleteReq(params=delete_params))
1109
1110
async def test_list_threads_paginated_response():
    """Listing 25 threads with ``limit=20`` yields a full page, then the remaining 5."""
    with make_server() as server:
        # Populate the store with 25 threads.
        for index in range(25):
            user_input = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
            )

        # First page: capped at the requested limit, with more to come.
        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            first_response.json
        )
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Second page: continue after the last thread of the first page.
        last_seen_id = first_page.data[-1].id
        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=last_seen_id))
        )
        second_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            second_response.json
        )
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1139
1140
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    new_title = "New Title"
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        creation_events = await server.process_streaming(create_request)
        thread = next(
            event.thread
            for event in creation_events
            if event.type == "thread.created"
        )

        update_response = await server.process_non_streaming(
            ThreadsUpdateReq(
                params=ThreadUpdateParams(thread_id=thread.id, title=new_title)
            )
        )
        updated_thread = TypeAdapter[Thread](Thread).validate_json(
            update_response.json
        )
        assert updated_thread.title == "New Title"
1164
1165
async def test_returns_widget_item_generator():
    """``stream_widget`` emits text deltas when only a text value grows.

    The generator yields the same card with progressively longer prefixes of
    ``"Hello, world"``; the stream is expected to be: item added, three
    streaming-text deltas, item done.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(i: int) -> Card:
        # Card whose single text child shows the first ``i`` characters.
        return Card(children=[Text(id="text", value="Hello, world"[:i])])

    async def widget_generator():
        for prefix_length in (0, 3, 6, 12):
            yield render_widget(prefix_length)

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 5

    # The first event adds the widget in its initial (empty text) state.
    added_event = events[0]
    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="")])

    # The middle events each carry only the newly appended characters.
    for update_event, expected_delta in zip(events[1:4], ("Hel", "lo,", " world")):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    # The last event carries the fully rendered widget.
    done_event = events[4]
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1207
1208
async def test_returns_widget_item_generator_full_replace():
    """``stream_widget`` emits a root update when the text is replaced rather than extended."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    async def widget_generator():
        for value in ("Hello", "World"):
            yield Card(children=[Text(id="text", value=value)])

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 3
    added_event, update_event, done_event = events

    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="Hello")])

    # "World" is not an extension of "Hello", so the whole root is re-sent.
    assert isinstance(update_event, ThreadItemUpdatedEvent)
    assert update_event.update.type == "widget.root.updated"
    assert update_event.update.widget == Card(
        children=[Text(id="text", value="World")]
    )

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(children=[Text(id="text", value="World")])
1232
1233
async def test_delete_thread():
    """Deleting a thread returns ``{}`` and makes later lookups fail.

    After the delete request succeeds, fetching the same thread by id must
    raise ``NotFoundError`` with a message naming the missing thread.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # pytest.raises fails the test by itself when nothing is raised, and
        # names the exception type that is actually expected.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        # Plain substring check so the thread id is not treated as a regex.
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1265
1266
async def test_thread_item_removed_event_removes_item():
    """An item the responder emits and then removes is absent from the thread's item list."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a completed item, then immediately ask for its removal.
        short_lived_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=short_lived_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger removal")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # The removal must be visible in the event stream...
        assert any(
            event.type == "thread.item.removed" and event.item_id == removed_id
            for event in events
        )

        # ...and the item should not be present in the store
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        )
        assert removed_id not in [item.id for item in stored_items]
1311
1312
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder surfaces as an ``error`` stream event.

    The stream must contain an event of type ``"error"`` whose code is
    ``ErrorCode.STREAM_ERROR``, and must end right after that event.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        # Never reached: the bare ``yield`` only exists to make this function
        # an async generator, as its return annotation requires.
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test error")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        result = await server.process(
            create_request.model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named ``stream`` so the ``iter`` builtin is not shadowed.
        stream = result.__aiter__()

        # Consume events until the error event shows up.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Nothing may follow the error event: the stream has to be exhausted.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1354
1355
async def test_thread_item_replaced_event_replaces_item():
    """A ``ThreadItemReplacedEvent`` overwrites the stored item that shares its id.

    The responder first emits a user message and then replaces it with an
    assistant message using the same id; the stored item must be the
    assistant message.
    """
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Step 1: emit the item that will later be replaced.
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Step 2: emit a different item type under the same id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # The replacement has to show up in the event stream.
        assert any(
            event.type == "thread.item.replaced" and event.item.id == original_id
            for event in events
        )

        # The store has to hold the replacement under the original id.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        )
        matches = [item for item in stored_items if item.id == original_id]
        assert matches
        replaced_item = matches[0]
        assert isinstance(replaced_item, AssistantMessageItem)

        first_part = replaced_item.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1418
1419
async def test_thread_item_replaced_event_with_widget():
    """Replacing a widget item updates both the emitted event and the stored widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit the original widget, then a replacement under the same id.
        original_item = WidgetItem(
            id=widget_id,
            widget=original_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=original_item)

        replacement_item = WidgetItem(
            id=widget_id,
            widget=replacement_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=replacement_item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger widget replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # The stream has to contain the replacement with the new widget.
        replacement_events = [
            event
            for event in events
            if event.type == "thread.item.replaced" and event.item.id == widget_id
        ]
        assert replacement_events
        replacement_event = replacement_events[0]
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The store has to hold the replacement widget as well.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        )
        matches = [item for item in stored_items if item.id == widget_id]
        assert matches
        replaced_item = matches[0]
        assert isinstance(replaced_item, WidgetItem)
        assert isinstance(replaced_item.widget, Card)

        first_child = replaced_item.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1493
1494
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the answer.

    The responder is expected to be called twice with the same user message,
    and the stored assistant reply must be the one from the second call.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first call answers normally; any later call is the retry.
        if len(responder_calls) == 1:
            message_id, message_text = "assistant_msg_1", "First response"
        else:
            message_id, message_text = "assistant_msg_2", "Retried response"

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=message_id,
                content=[AssistantMessageContent(text=message_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # State before the retry: assistant reply listed first, then the user message.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 2
        assistant_before, user_before = items_before.data
        assert assistant_before.type == "assistant_message"
        assert assistant_before.content[0].type == "output_text"
        assert assistant_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry after the user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_before.id
                )
            )
        )

        # The retry stream holds the stream options plus the new answer.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder invocations received the same user message.
        assert len(responder_calls) == 2
        first_call, second_call = responder_calls
        assert first_call == second_call
        assert first_call.type == "user_message"
        assert first_call.content[0].text == "Hello, world!"

        # Storage now holds the retried answer instead of the first one.
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 2  # Still user message + assistant message
        assistant_after, user_after = items_after.data
        assert assistant_after.type == "assistant_message"
        assert assistant_after.content[0].type == "output_text"
        assert assistant_after.content[0].text == "Retried response"  # New response
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1592
1593
async def test_retry_after_item_invalid_item():
    """Retrying after an assistant message is rejected with ``ValueError``."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        items_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(items_result.json)
        # The first listed item is used as the (invalid) assistant-message target.
        assistant_item_id = items.data[0].id

        # Try to retry after an assistant message (should fail)
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=assistant_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1640
1641
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the last reply.

    The thread holds two user/assistant exchanges. After retrying the second
    user message, the first exchange must be untouched and the second
    assistant reply must be a newly generated item.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # Each call gets a distinct id, so a regenerated reply is detectable.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Verify we have 4 items: user1, assistant1, user2, assistant2
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Items are listed newest first, so index 1 is the second user message.
        # Check that explicitly so a changed ordering fails here, not later.
        second_user_message = items_before.data[1]
        assert second_user_message.type == "user_message"
        second_user_message_id = second_user_message.id

        # Retry after the second user message
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # First user message and response should remain unchanged
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # Second user message should remain, but assistant response should be new
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1746
1747
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent with a thread-create request is visible to the responder."""
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # The actual check: the incoming message carries the requested tool.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_create",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello with tools")],
                attachments=[],
                inference_options=InferenceOptions(tool_choice=tool_choice),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        # Sanity check that the request flowed through.
        assert any(event.type == "thread.item.done" for event in events)
1781
1782
async def test_retry_after_item_passes_tools_to_responder():
    """Placeholder with no assertions; it currently always passes.

    TODO: cover tool-choice propagation on retry, mirroring the create and
    add-message variants of this test.
    """
    return None
1785
1786
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with an add-user-message request is visible to the responder."""
    tool_choice = ToolChoice(id="retrieval")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Runs for both the create and the add-message request below.
        assert input is not None
        received_choice = input.inference_options.tool_choice
        assert received_choice is not None
        assert received_choice.id == tool_choice.id

        reply = AssistantMessageItem(
            id="assistant_msg_tools_add",
            content=[AssistantMessageContent(text="ok")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        # Create a thread first.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Start")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                )
            )
        )
        creation_events = await server.process_streaming(create_request)
        thread = next(
            event.thread
            for event in creation_events
            if event.type == "thread.created"
        )

        # Add a message with tools and verify they reach the responder.
        add_request = ThreadsAddUserMessageReq(
            params=ThreadAddUserMessageParams(
                thread_id=thread.id,
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Second")],
                    attachments=[],
                    inference_options=InferenceOptions(tool_choice=tool_choice),
                ),
            )
        )
        add_events = await server.process_streaming(add_request)
        # Sanity check that the request flowed through.
        assert any(event.type == "thread.item.done" for event in add_events)
1836
1837
async def test_custom_id_generators_on_server():
    """Ids produced by a store's custom generators appear in the emitted events.

    A ``SQLiteStore`` subclass overrides both id generators with a shared
    counter; the created thread and its user message must carry those ids.
    """

    class UniqueIdStore(SQLiteStore):
        """Store whose thread and item ids come from one shared counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits nothing: returns before the ``yield``, which is only there to
        # make this function an async generator.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        server.store = UniqueIdStore(cast(SQLiteStore, server.store).db_path)

        # Create a thread through the JSON entry point.
        result = await server.process(
            create_request.model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is the first value drawn from the counter.
        thread_created = next(
            event for event in events if event.type == "thread.created"
        )
        assert thread_created.thread.id == "thr_custom_1"

        # The user message id is the second value and embeds the thread id.
        user_message = next(
            event.item
            for event in events
            if event.type == "thread.item.done" and event.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"
1894