openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.6.5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2624 lines · mode: code

1import asyncio
2import base64
3import json
4import sqlite3
5from contextlib import contextmanager
6from datetime import datetime
7from typing import Any, AsyncIterator, Callable, cast
8
9import pytest
10from helpers.mock_store import SQLiteStore
11from pydantic import AnyUrl, TypeAdapter
12
13from chatkit.actions import Action
14from chatkit.errors import ErrorCode
15from chatkit.server import (
16 ChatKitServer,
17 NonStreamingResult,
18 StreamingResult,
19 stream_widget,
20)
21from chatkit.store import (
22 AttachmentStore,
23 NotFoundError,
24 StoreItemType,
25 default_generate_id,
26)
27from chatkit.types import (
28 AssistantMessageContent,
29 AssistantMessageContentPartTextDelta,
30 AssistantMessageItem,
31 Attachment,
32 AttachmentCreateParams,
33 AttachmentDeleteParams,
34 AttachmentsCreateReq,
35 AttachmentsDeleteReq,
36 AttachmentUploadDescriptor,
37 AudioInput,
38 ClientToolCallItem,
39 CustomTask,
40 FeedbackKind,
41 FileAttachment,
42 ImageAttachment,
43 InferenceOptions,
44 InputTranscribeParams,
45 InputTranscribeReq,
46 ItemFeedbackParams,
47 ItemsFeedbackReq,
48 ItemsListParams,
49 ItemsListReq,
50 LockedStatus,
51 Page,
52 ProgressUpdateEvent,
53 StructuredInputAnswerSubmission,
54 StructuredInputFreeform,
55 StructuredInputItem,
56 StructuredInputMultipleChoice,
57 StructuredInputMultipleChoiceOption,
58 StructuredInputSubmission,
59 SyncCustomActionResponse,
60 Thread,
61 ThreadAddClientToolOutputParams,
62 ThreadAddStructuredInputParams,
63 ThreadAddUserMessageParams,
64 ThreadCreatedEvent,
65 ThreadCreateParams,
66 ThreadCustomActionParams,
67 ThreadDeleteParams,
68 ThreadGetByIdParams,
69 ThreadItem,
70 ThreadItemAddedEvent,
71 ThreadItemDoneEvent,
72 ThreadItemRemovedEvent,
73 ThreadItemReplacedEvent,
74 ThreadItemUpdatedEvent,
75 ThreadListParams,
76 ThreadMetadata,
77 ThreadRetryAfterItemParams,
78 ThreadsAddClientToolOutputReq,
79 ThreadsAddStructuredInputReq,
80 ThreadsAddUserMessageReq,
81 ThreadsCreateReq,
82 ThreadsCustomActionReq,
83 ThreadsDeleteReq,
84 ThreadsGetByIdReq,
85 ThreadsListReq,
86 ThreadsRetryAfterItemReq,
87 ThreadsSyncCustomActionReq,
88 ThreadStreamEvent,
89 ThreadsUpdateReq,
90 ThreadUpdatedEvent,
91 ThreadUpdateParams,
92 ToolChoice,
93 TranscriptionResult,
94 UserMessageInput,
95 UserMessageItem,
96 UserMessageTextContent,
97 WidgetItem,
98 WidgetRootUpdated,
99 Workflow,
100 WorkflowItem,
101 WorkflowTaskAdded,
102)
103from chatkit.widgets import Card, Text
104from tests._types import RequestContext
105from tests.test_store import make_thread_items
106
# Counter read and incremented by make_server so that each server it builds
# gets a differently named shared in-memory SQLite database.
server_id: int = 0


# Context handed to the server whenever a test does not pass its own.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
111
112
class InMemoryFileStore(AttachmentStore):
    """Attachment store for tests that keeps every attachment in a dict.

    ``self.files`` maps attachment id -> attachment object.
    """

    def __init__(self):
        self.files = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Forget the attachment with the given id.

        Raises:
            KeyError: if no attachment with that id is stored.
        """
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return an attachment described by ``input``.

        An ``ImageAttachment`` (with a preview URL) is built when the mime
        type starts with ``image/``; every other mime type produces a
        ``FileAttachment``. Both carry the same upload descriptor and
        metadata.
        """
        # Start from the historical "count + 1" id, but skip ids that are
        # still in use: after a deletion, len(self.files) + 1 can equal the
        # id of a live attachment, which would otherwise be overwritten.
        index = len(self.files) + 1
        while f"atc_{index}" in self.files:
            index += 1
        attachment_id = f"atc_{index}"

        metadata = {"source": "test"}
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )

        attachment: Attachment
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )

        self.files[attachment.id] = attachment
        return attachment
153
154
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Parse one raw ``data: <json>`` event frame into a ThreadStreamEvent.

    Raises:
        IndexError: if ``event`` contains no ``data: `` marker.
    """
    # Split on the first marker only: the JSON payload may itself contain the
    # bytes "data: " (for example inside message text), and an unlimited
    # split would cut the payload short at that point.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
157
158
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every frame in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_event in streaming_result.json_events:
        decoded.append(decode_event(raw_event))
    return decoded
163
164
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    sync_action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        SyncCustomActionResponse,
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
    transcribe_callback: Callable[[AudioInput, Any], TranscriptionResult] | None = None,
):
    """Yield a ChatKitServer subclass wired to the given test callbacks.

    Each call uses a fresh, uniquely named shared in-memory SQLite database.
    Every server hook delegates to the matching callback when one is supplied:

    * ``responder`` backs ``respond``; without it the response stream is empty.
    * ``handle_feedback`` backs ``add_feedback``; without it feedback is ignored.
    * ``action_callback`` / ``sync_action_callback`` back ``action`` /
      ``sync_action``; without them those hooks raise ``ValueError``.
    * ``transcribe_callback`` backs ``transcribe``; without it the base class
      implementation is used.
    * ``file_store`` is passed to the server as its attachment store.

    The yielded server also offers ``process_streaming`` and
    ``process_non_streaming`` helpers that serialize a request object, run it
    through ``process`` and assert the kind of result returned.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last connection is closed
    keepalive_connection = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any event.
            return
            yield

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        async def sync_action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> SyncCustomActionResponse:
            if sync_action_callback is not None:
                return sync_action_callback(thread, action, sender, context)
            raise ValueError("sync_action_callback not wired up")

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def transcribe(
            self, audio_input: AudioInput, context: Any
        ) -> TranscriptionResult:
            if transcribe_callback is not None:
                return transcribe_callback(audio_input, context)
            return await super().transcribe(audio_input, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process a request that must stream and return its decoded events."""
            payload = request_obj.model_dump_json()
            result = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process a request that must not stream and return the raw result."""
            payload = request_obj.model_dump_json()
            result = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        keepalive_connection.close()
269
270
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """A cancelled stream stores the partially streamed assistant message.

    The responder adds an assistant message, streams one text delta and then
    raises ``CancelledError``. Afterwards the store must contain the hidden
    context note about the cancellation and the pending message with the
    delta appended to its text.
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        pending_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending_message)

        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(item_id=pending_id, update=text_delta)
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        # Collect whatever was emitted before the cancellation propagates.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")
        page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        expected_note = (
            "The user cancelled the stream. Stop responding to the prior request."
        )
        assert hidden_context_item.content == expected_note

        stored_message = await server.store.load_item(
            thread.id, pending_id, DEFAULT_CONTEXT
        )
        assert stored_message.type == "assistant_message"
        assert stored_message.content[0].text == "Hello, World!"
340
341
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """A cancelled stream drops a pending assistant message with no content.

    The responder adds an assistant message whose content list is empty and
    then raises ``CancelledError``. The hidden context note must be stored,
    while loading the pending message must raise ``NotFoundError``.
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Route hidden-context ids through the default generator; everything
        # else keeps using the store's own generator.
        fallback_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return fallback_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread = next(e.thread for e in events if e.type == "thread.created")
        page = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = page.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        expected_note = (
            "The user cancelled the stream. Stop responding to the prior request."
        )
        assert hidden_context_item.content == expected_note

        with pytest.raises(NotFoundError):
            await server.store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
402
403
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying updates to a pending item must not modify the
    original item object that the responder streamed.

    The responder streams a task-added update, appends the same task to its
    own workflow item and then emits the item as done. Both the done event
    and the stored item must end up with exactly one task.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        streamed_workflow = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=streamed_workflow)

        new_task = CustomTask(title="foo", content="bar")
        task_added = WorkflowTaskAdded(task=new_task, task_index=0)
        yield ThreadItemUpdatedEvent(item_id=streamed_workflow.id, update=task_added)

        streamed_workflow.workflow.tasks.append(new_task)
        yield ThreadItemDoneEvent(item=streamed_workflow)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        thread = next(e.thread for e in events if e.type == "thread.created")
        workflow_done_events = (
            e
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WorkflowItem)
        )
        done_item = cast(WorkflowItem, next(workflow_done_events).item)
        done_tasks = done_item.workflow.tasks
        assert len(done_tasks) == 1
        assert done_tasks[0].title == "foo"

        stored = await server.store.load_item(thread.id, done_item.id, DEFAULT_CONTEXT)
        assert isinstance(stored, WorkflowItem)
        stored_tasks = cast(WorkflowItem, stored).workflow.tasks
        assert len(stored_tasks) == 1
        assert stored_tasks[0].title == "foo"
461
462
async def test_flows_context_to_responder():
    """The request context reaches both the responder and the feedback hook."""
    seen_by_responder = None
    seen_by_feedback = None

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        nonlocal seen_by_responder
        seen_by_responder = context
        reply = AssistantMessageItem(
            id="msg_1",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, world!")],
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    def record_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        nonlocal seen_by_feedback
        seen_by_feedback = context

    context = RequestContext(user_id="text_ctx")
    with make_server(responder, record_feedback) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request, context)
        assert seen_by_responder == context

        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        assistant_item = next(
            event.item
            for event in events
            if event.type == "thread.item.done"
            and event.item.type == "assistant_message"
        )

        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen_by_feedback == context
522
523
async def test_creates_thread():
    """Creating a thread emits an untitled thread and the done user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Async generator that produces no events.
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
            quoted_text="Quoted text",
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created_event = next(
            event for event in events if event.type == "thread.created"
        )
        assert created_event.thread.title is None

        first_done_item = next(
            event.item for event in events if event.type == "thread.item.done"
        )

        assert first_done_item.type == "user_message"
        assert first_done_item.content[0].text == "Hello, world!"
        assert first_done_item.quoted_text == "Quoted text"
557
558
async def test_saves_thread_title():
    """A title set by the responder is stored and sent as thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.title == "Updated title"

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.title == "Updated title"
589
590
async def test_saves_thread_metadata():
    """Metadata set by the responder is stored and triggers thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
618
619
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is stored and streamed back."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert stored_thread.status == LockedStatus(reason="Because")

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.status == LockedStatus(reason="Because")
649
650
async def test_saves_allowed_image_domains_and_streams_thread_updated():
    """Allowed image domains set by the responder are stored and streamed."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.allowed_image_domains = ["example.com", "images.example.com"]
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        stored_thread = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored_thread.allowed_image_domains == [
            "example.com",
            "images.example.com",
        ]

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.allowed_image_domains == [
            "example.com",
            "images.example.com",
        ]
681
682
async def test_omits_unset_allowed_image_domains_in_created_and_updated_json_events():
    """Raw JSON thread events leave out ``allowed_image_domains`` when unset.

    This inspects the serialized events directly (not the decoded models) so
    that the absence of the key in the JSON itself is what gets checked.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        # Keep the last event seen of each kind.
        created_json: dict[str, Any] | None = None
        updated_json: dict[str, Any] | None = None
        async for raw in stream.json_events:
            event = json.loads(raw.split(b"data: ")[1])
            event_type = event["type"]
            if event_type == "thread.created":
                created_json = event
            if event_type == "thread.updated":
                updated_json = event

        assert created_json is not None
        assert "allowed_image_domains" not in created_json["thread"]

        assert updated_json is not None
        assert updated_json["thread"]["title"] == "Updated title"
        assert "allowed_image_domains" not in updated_json["thread"]
721
722
async def test_emits_thread_updated_mid_stream_and_persists():
    """A mid-stream title change produces thread.updated and is stored.

    The responder changes the title between its first and second message. The
    update event must come directly after the second message, the third
    message must follow it, and the stored thread must carry the new title.
    """

    def done_message(item_id: str, text: str, thread_id: str) -> ThreadItemDoneEvent:
        return ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread_id,
            ),
        )

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # First assistant message
        yield done_message("assistant_a", "A", thread.id)

        # Update the thread mid-stream
        thread.title = "Mid-stream title"

        # Second assistant message, after which thread.updated should be emitted
        yield done_message("assistant_b", "B", thread.id)

        # Continue streaming after the update event
        yield done_message("assistant_c", "C", thread.id)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        # Find the created thread
        thread = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        # Ensure a thread.updated event occurs right after assistant_b
        position_of_b = next(
            position
            for position, e in enumerate(events)
            if isinstance(e, ThreadItemDoneEvent)
            and isinstance(e.item, AssistantMessageItem)
            and e.item.id == "assistant_b"
        )
        event_after_b = events[position_of_b + 1]
        assert isinstance(event_after_b, ThreadUpdatedEvent)
        updated_event = cast(ThreadUpdatedEvent, event_after_b)
        assert updated_event.thread.title == "Mid-stream title"

        # Streaming continues after the update event
        following_event = events[position_of_b + 2]
        assert isinstance(following_event, ThreadItemDoneEvent)
        done_event = cast(ThreadItemDoneEvent, following_event)
        assert isinstance(done_event.item, AssistantMessageItem)
        assert done_event.item.id == "assistant_c"

        # Persisted in the store
        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"
797
798
async def test_respond_with_tool_call():
    """A client tool call is streamed, and its output resumes the response.

    When given a user message the responder emits a client tool call; when
    called with no input (after the tool output is submitted) it emits an
    assistant message.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if isinstance(input, UserMessageItem):
            tool_call = ClientToolCallItem(
                id="msg_1",
                created_at=datetime.now(),
                name="tool_call_1",
                arguments={"arg1": "val1", "arg2": False},
                call_id="tool_call_1",
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=tool_call)
        elif input is None:
            follow_up = AssistantMessageItem(
                id="msg_2",
                content=[
                    AssistantMessageContent(text="Glad the tool call succeeded!")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
            yield ThreadItemDoneEvent(item=follow_up)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        assert len(events) == 4
        created, user_done, options, tool_done = events

        assert created.type == "thread.created"
        thread = created.thread

        assert user_done.type == "thread.item.done"
        assert user_done.item.type == "user_message"

        assert options.type == "stream_options"

        assert tool_done.type == "thread.item.done"
        assert tool_done.item.type == "client_tool_call"
        assert tool_done.item.id == "msg_1"
        assert tool_done.item.name == "tool_call_1"
        assert tool_done.item.arguments == {"arg1": "val1", "arg2": False}
        assert tool_done.item.call_id == "tool_call_1"

        tool_output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(tool_output_request)

        assert len(events) == 2
        assert events[0].type == "stream_options"
        assert events[1].type == "thread.item.done"
        assert events[1].item.type == "assistant_message"
868
869
async def test_respond_with_tool_status():
    """A progress update from the responder is streamed in emission order."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello, world!")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)

        assert len(events) == 5
        assert events[0].type == "thread.created"
        assert events[1].type == "thread.item.done"
        assert events[2].type == "stream_options"
        assert events[3].type == "progress_update"
        assert events[4].type == "thread.item.done"
903
904
async def test_list_threads_response():
    """Listing 25 threads with limit 20 yields a page of 20, then one of 5."""
    with make_server() as server:
        for thread_number in range(25):
            create_request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {thread_number}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(create_request)

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        threads = page_adapter.validate_json(first_result.json)
        assert len(threads.data) == 20

        # Newest first
        assert threads.data[0].created_at > threads.data[1].created_at
        assert threads.has_more is True
        assert threads.after is not None

        # Continue from the last thread of the first page.
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=threads.data[-1].id))
        )
        threads = page_adapter.validate_json(second_result.json)
        assert len(threads.data) == 5
        assert threads.has_more is False
        assert not threads.after
937
938
async def test_list_items_response():
    """Listing thread items pages through 26 items: 20 first, then 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for message_number in range(25):
            message = UserMessageItem(
                id=f"msg_{message_number}",
                content=[UserMessageTextContent(text=f"Message {message_number}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=message)

    with make_server(responder) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Thread 1")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        page_adapter = TypeAdapter(Page[ThreadItem])

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = page_adapter.validate_json(first_result.json)
        assert len(items.data) == 20
        # Newest first
        assert items.data[0].created_at > items.data[1].created_at
        assert items.has_more is True
        assert items.after is not None

        next_page_params = ItemsListParams(
            thread_id=thread.id, after=items.data[-1].id
        )
        second_result = await server.process_non_streaming(
            ItemsListReq(params=next_page_params)
        )
        items = page_adapter.validate_json(second_result.json)
        assert len(items.data) == 6  # +1 for user message
        assert items.has_more is False
988
989
async def test_calls_action():
    """A custom action request reaches the action hook with its sender widget.

    The responder first streams a widget item. A custom action is then sent
    for that widget; the hook records the call and replies with a widget root
    update, which must be the only event after ``stream_options``.
    """
    recorded_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(
                children=[
                    Text(
                        key="text_1",
                        value="test",
                    )
                ],
            ),
        )
        yield ThreadItemDoneEvent(item=widget)

    async def on_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        recorded_calls.append((action, sender))
        assert sender

        replacement_root = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement_root),
        )

    with make_server(responder, action_callback=on_action) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Show widget")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        widget_item = next(
            event.item
            for event in events
            if isinstance(event, ThreadItemDoneEvent)
            and isinstance(event.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert recorded_calls
        assert recorded_calls[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert len(events) == 2
        assert events[0].type == "stream_options"

        update_event = events[1]
        assert update_event.type == "thread.item.updated"
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(children=[Text(value="Email sent!")])
1079
1080
async def test_add_structured_input_replaces_item_with_custom_answer_and_resumes_response():
    """Answering a structured input replaces the item, persists it, and resumes.

    The responder emits a two-question structured input item when it receives
    user input, and a follow-up assistant message when called without input.
    """

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # No user input: emit the message the test expects after submission.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_after_structured_input",
                    content=[AssistantMessageContent(text="Thanks, continuing now.")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                )
            )
            return

        yield ThreadItemDoneEvent(
            item=StructuredInputItem(
                id="si_1",
                thread_id=thread.id,
                created_at=datetime.now(),
                inputs=[
                    StructuredInputMultipleChoice(
                        id="subject",
                        question="Which subject is this lesson for?",
                        options=[
                            StructuredInputMultipleChoiceOption(value="Math"),
                            StructuredInputMultipleChoiceOption(value="Science"),
                        ],
                    ),
                    StructuredInputFreeform(
                        id="details",
                        question="Anything else to know?",
                    ),
                ],
            )
        )

    with make_server(respond) as server:
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Plan a lesson")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        created_events = await server.process_streaming(create_request)
        thread = next(
            ev.thread for ev in created_events if isinstance(ev, ThreadCreatedEvent)
        )

        # "Private class" is not one of the offered options (a custom answer).
        submission = StructuredInputSubmission(
            status="answered",
            answers={
                "subject": StructuredInputAnswerSubmission(values=["Private class"]),
                "details": StructuredInputAnswerSubmission(
                    values=["Please include group work."]
                ),
            },
        )
        answer_events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=submission,
                )
            )
        )

        replaced = next(
            ev for ev in answer_events if isinstance(ev, ThreadItemReplacedEvent)
        )
        answered_item = replaced.item
        assert isinstance(answered_item, StructuredInputItem)
        assert answered_item.status == "answered"

        subject_answer = answered_item.inputs[0].answer
        assert subject_answer is not None
        assert subject_answer.values == ["Private class"]

        details_answer = answered_item.inputs[1].answer
        assert details_answer is not None
        assert details_answer.values == ["Please include group work."]

        # The stored item must match what was streamed to the client.
        stored_item = await server.store.load_item(thread.id, "si_1", DEFAULT_CONTEXT)
        assert stored_item == answered_item

        finished_assistant_ids = [
            ev.item.id
            for ev in answer_events
            if isinstance(ev, ThreadItemDoneEvent)
            and isinstance(ev.item, AssistantMessageItem)
        ]
        assert "msg_after_structured_input" in finished_assistant_ids
1180
1181
async def test_add_structured_input_treats_omitted_answers_as_skipped():
    """An input that receives no answer in the submission is marked skipped."""

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            return

        subject_options = [
            StructuredInputMultipleChoiceOption(value="Math"),
            StructuredInputMultipleChoiceOption(value="Science"),
        ]
        yield ThreadItemDoneEvent(
            item=StructuredInputItem(
                id="si_1",
                thread_id=thread.id,
                created_at=datetime.now(),
                inputs=[
                    StructuredInputMultipleChoice(
                        id="subject",
                        question="Which subject is this lesson for?",
                        options=subject_options,
                    ),
                    StructuredInputFreeform(
                        id="details",
                        question="Anything else to know?",
                    ),
                ],
            )
        )

    with make_server(respond) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Plan a lesson")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            ev.thread for ev in created_events if isinstance(ev, ThreadCreatedEvent)
        )

        # Only "subject" is answered; "details" is left out on purpose.
        submission = StructuredInputSubmission(
            status="answered",
            answers={
                "subject": StructuredInputAnswerSubmission(values=["Math"]),
            },
        )
        answer_events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=submission,
                )
            )
        )

        replaced = next(
            ev for ev in answer_events if isinstance(ev, ThreadItemReplacedEvent)
        )
        answered_item = replaced.item
        assert isinstance(answered_item, StructuredInputItem)

        answered = answered_item.inputs[0].answer
        assert answered is not None
        assert answered.values == ["Math"]

        omitted = answered_item.inputs[1].answer
        assert omitted is not None
        assert omitted.skipped
1254
1255
async def test_add_structured_input_ignores_unknown_answer_ids():
    """A submitted answer whose id matches no input does not affect known inputs."""

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            return

        details_question = StructuredInputFreeform(
            id="details",
            question="Anything else to know?",
        )
        yield ThreadItemDoneEvent(
            item=StructuredInputItem(
                id="si_1",
                thread_id=thread.id,
                created_at=datetime.now(),
                inputs=[details_question],
            )
        )

    with make_server(respond) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Plan a lesson")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            ev.thread for ev in created_events if isinstance(ev, ThreadCreatedEvent)
        )

        # "unknown" does not correspond to any input on the item.
        submission = StructuredInputSubmission(
            status="answered",
            answers={
                "details": StructuredInputAnswerSubmission(
                    values=["Please include group work."]
                ),
                "unknown": StructuredInputAnswerSubmission(values=["ignored"]),
            },
        )
        answer_events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=submission,
                )
            )
        )

        replaced = next(
            ev for ev in answer_events if isinstance(ev, ThreadItemReplacedEvent)
        )
        answered_item = replaced.item
        assert isinstance(answered_item, StructuredInputItem)

        details_answer = answered_item.inputs[0].answer
        assert details_answer is not None
        assert details_answer.values == ["Please include group work."]
1320
1321
async def test_add_structured_input_truncates_extra_single_choice_values():
    """Two values submitted for the multiple-choice input are reduced to the first."""

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        subject_question = StructuredInputMultipleChoice(
            id="subject",
            question="Which subject is this lesson for?",
            options=[
                StructuredInputMultipleChoiceOption(value="Math"),
                StructuredInputMultipleChoiceOption(value="Science"),
            ],
        )
        yield ThreadItemDoneEvent(
            item=StructuredInputItem(
                id="si_1",
                thread_id=thread.id,
                created_at=datetime.now(),
                inputs=[subject_question],
            )
        )

    with make_server(respond) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Plan a lesson")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            ev.thread for ev in created_events if isinstance(ev, ThreadCreatedEvent)
        )

        submission = StructuredInputSubmission(
            status="answered",
            answers={
                "subject": StructuredInputAnswerSubmission(
                    values=["Math", "Science"],
                ),
            },
        )
        answer_events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=submission,
                )
            )
        )

        replaced = next(
            ev for ev in answer_events if isinstance(ev, ThreadItemReplacedEvent)
        )
        assert isinstance(replaced.item, StructuredInputItem)
        streamed_answer = replaced.item.inputs[0].answer
        assert streamed_answer is not None
        assert streamed_answer.values == ["Math"]

        # The truncated answer must also be what ends up in the store.
        stored_item = await server.store.load_item(thread.id, "si_1", DEFAULT_CONTEXT)
        assert isinstance(stored_item, StructuredInputItem)
        assert stored_item.status == "answered"
        assert stored_item.inputs[0].answer == streamed_answer
1395
1396
async def test_calls_action_sync():
    """A sync custom action request invokes the sync callback and returns its item."""
    recorded_calls = []

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        initial_card = Card(
            children=[
                Text(
                    key="text_1",
                    value="test",
                )
            ],
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_card,
            ),
        )

    def handle_sync_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> SyncCustomActionResponse:
        recorded_calls.append((action, sender))
        assert sender

        sent_card = Card(
            children=[
                Text(value="Email sent!"),
            ]
        )
        updated = sender.model_copy(update={"widget": sent_card})
        return SyncCustomActionResponse(updated_item=updated)

    with make_server(respond, sync_action_callback=handle_sync_action) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        widget_item = next(
            ev.item
            for ev in events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        action_params = ThreadCustomActionParams(
            thread_id=thread.id,
            item_id=widget_item.id,
            action=Action(type="create_user", payload={"user_id": "123"}),
        )
        result = await server.process_non_streaming(
            ThreadsSyncCustomActionReq(params=action_params)
        )
        response = TypeAdapter(SyncCustomActionResponse).validate_json(result.json)

        # The callback received the submitted action and the originating widget.
        assert recorded_calls
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert recorded_calls[0] == expected_call

        expected_item = widget_item.model_copy(
            update={"widget": Card(children=[Text(value="Email sent!")])}
        )
        assert response.updated_item == expected_item
1484
1485
async def test_add_feedback():
    """Each feedback request reaches the handler with its thread, item ids and kind."""
    feedback_calls = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        feedback_calls.append((thread_id, item_ids, feedback, context))

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_item = AssistantMessageItem(
            id="assistant_msg_1",
            content=[
                AssistantMessageContent(text="Feedback received!"),
            ],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_item)

        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget_item)

    with make_server(respond, handle_feedback=handle_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test feedback")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        positive_params = ItemFeedbackParams(
            thread_id=thread.id,
            item_ids=["assistant_msg_1"],
            kind="positive",
        )
        await server.process_non_streaming(ItemsFeedbackReq(params=positive_params))

        negative_params = ItemFeedbackParams(
            thread_id=thread.id,
            item_ids=["widget_1"],
            kind="negative",
        )
        await server.process_non_streaming(ItemsFeedbackReq(params=negative_params))

        assert len(feedback_calls) == 2
        first_call, second_call = feedback_calls

        assert first_call[0] == thread.id
        assert first_call[1] == ["assistant_msg_1"]
        assert first_call[2] == "positive"

        assert second_call[0] == thread.id
        assert second_call[1] == ["widget_1"]
        assert second_call[2] == "negative"
1560
1561
async def test_get_thread_by_id():
    """A created thread can be fetched by id together with its first user message."""
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test thread")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Retrieve the thread by id
        lookup = ThreadsGetByIdReq(
            params=ThreadGetByIdParams(thread_id=thread.id),
        )
        result = await server.process_non_streaming(lookup)
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)

        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title
        assert loaded_thread.items.data[0].type == "user_message"
        assert loaded_thread.items.data[0].content[0].text == "Test thread"
1593
1594
async def test_create_file():
    """Creating a text attachment returns a file attachment with an upload descriptor."""
    store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime_type = "text/plain"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=expected_name, size=1024, mime_type=expected_mime_type
        )
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == expected_mime_type
        assert attachment.name == expected_name
        assert attachment.type == "file"

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in store.files

        # The `metadata` field is excluded from serialization by default but still stored server-side
        assert attachment.metadata is None
        assert store.files[attachment.id].metadata == {"source": "test"}
1626
1627
async def test_create_image_file():
    """Creating a PNG attachment returns an image attachment with a preview URL."""
    store = InMemoryFileStore()
    expected_name = "test-file-name"
    expected_mime_type = "image/png"

    with make_server(file_store=store) as server:
        create_params = AttachmentCreateParams(
            name=expected_name,
            size=1024,
            mime_type=expected_mime_type,
        )
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        assert attachment.id is not None
        assert attachment.mime_type == expected_mime_type
        assert attachment.name == expected_name
        assert attachment.type == "image"

        expected_preview_url = AnyUrl(f"https://example.com/{attachment.id}/preview")
        assert attachment.preview_url == expected_preview_url

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        assert attachment.id in store.files
1660
1661
async def test_user_message_attachment_metadata_not_serialized():
    """Attachments on listed user messages come back with `metadata` unset."""
    attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(attachment, DEFAULT_CONTEXT)

        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Message with file")],
            attachments=[attachment.id],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)

        user_messages = [item for item in page.data if item.type == "user_message"]
        assert user_messages

        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        assert all(listed.metadata is None for listed in listed_attachments)
1697
1698
async def test_user_message_attachment_thread_id_persisted_on_thread_create():
    """An attachment saved without a thread id gets one when a thread is created with it."""
    attachment = FileAttachment(
        id="file_thread_id_on_create",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        thread_id=None,
    )

    with make_server() as server:
        # Attachment exists before a thread exists, so it may not have a thread_id yet.
        await server.store.save_attachment(attachment, DEFAULT_CONTEXT)

        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Message with file")],
            attachments=[attachment.id],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # After the user message is saved, the attachment record should be updated
        # to reflect its thread association.
        reloaded = await server.store.load_attachment(attachment.id, DEFAULT_CONTEXT)
        assert reloaded.thread_id == thread.id
1729
1730
async def test_user_message_attachment_thread_id_persisted_on_add_user_message():
    """An attachment saved without a thread id gets one when added to an existing thread."""
    with make_server() as server:
        # Create a thread first.
        opening_input = UserMessageInput(
            content=[UserMessageTextContent(text="Start")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        attachment = FileAttachment(
            id="file_thread_id_on_add",
            type="file",
            mime_type="text/plain",
            name="test.txt",
            thread_id=None,
        )
        await server.store.save_attachment(attachment, DEFAULT_CONTEXT)

        # Attach the pre-created attachment to a message in an existing thread.
        follow_up_params = ThreadAddUserMessageParams(
            thread_id=thread.id,
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Second")],
                attachments=[attachment.id],
                inference_options=InferenceOptions(),
            ),
        )
        await server.process_streaming(
            ThreadsAddUserMessageReq(params=follow_up_params)
        )

        reloaded = await server.store.load_attachment(attachment.id, DEFAULT_CONTEXT)
        assert reloaded.thread_id == thread.id
1772
1773
async def test_create_file_without_filestore():
    """Creating an attachment on a server built without `file_store` raises RuntimeError."""
    create_params = AttachmentCreateParams(
        name="test-file-name",
        size=1024,
        mime_type="text/plain",
    )
    with pytest.raises(RuntimeError), make_server() as server:
        await server.process_non_streaming(AttachmentsCreateReq(params=create_params))
1786
1787
async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    attachment_id = "test-file-id"
    store = InMemoryFileStore()
    store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=store) as server:
        delete_params = AttachmentDeleteParams(attachment_id=attachment_id)
        await server.process_non_streaming(AttachmentsDeleteReq(params=delete_params))
        assert store.files == {}
1798
1799
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server built without `file_store` raises RuntimeError."""
    delete_params = AttachmentDeleteParams(attachment_id="test-file-id")
    with pytest.raises(RuntimeError), make_server() as server:
        await server.process_non_streaming(AttachmentsDeleteReq(params=delete_params))
1808
1809
async def test_list_threads_paginated_response():
    """Listing 25 threads with limit=20 yields a full page, then the remaining 5."""
    with make_server() as server:
        for index in range(25):
            user_input = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
            )

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = TypeAdapter(Page[ThreadMetadata]).validate_json(first_result.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Continue from the last thread of the first page.
        cursor = first_page.data[-1].id
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = TypeAdapter(Page[ThreadMetadata]).validate_json(
            second_result.json
        )
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1838
1839
async def test_threads_update_request():
    """Updating a thread's title returns the thread carrying the new title."""
    with make_server() as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        update_params = ThreadUpdateParams(thread_id=thread.id, title="New Title")
        update_result = await server.process_non_streaming(
            ThreadsUpdateReq(params=update_params)
        )
        renamed_thread = TypeAdapter[Thread](Thread).validate_json(update_result.json)
        assert renamed_thread.title == "New Title"
1863
1864
async def test_returns_widget_item_generator():
    """Growing a Text value across yields streams as added, text deltas, then done."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(i: int) -> Card:
        visible_text = "Hello, world"[:i]
        return Card(children=[Text(id="text", value=visible_text)])

    async def widget_generator():
        # Each step reveals a longer prefix of the final text.
        for prefix_length in (0, 3, 6, 12):
            yield render_widget(prefix_length)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5
    added_event = events[0]
    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="")])

    expected_deltas = ("Hel", "lo,", " world")
    for update_event, expected_delta in zip(events[1:4], expected_deltas):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        update = update_event.update
        assert update.type == "widget.streaming_text.value_delta"
        assert update.component_id == "text"
        assert update.delta == expected_delta

    done_event = events[4]
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1906
1907
async def test_returns_widget_item_generator_full_replace():
    """Changing "Hello" to "World" streams as a root update rather than a text delta."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    async def widget_generator():
        for text_value in ("Hello", "World"):
            yield Card(children=[Text(id="text", value=text_value)])

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added_event, updated_event, done_event = events

    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, WidgetItem)
    assert added_event.item.widget == Card(children=[Text(id="text", value="Hello")])

    assert isinstance(updated_event, ThreadItemUpdatedEvent)
    assert updated_event.update.type == "widget.root.updated"
    assert updated_event.update.widget == Card(
        children=[Text(id="text", value="World")]
    )

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(children=[Text(id="text", value="World")])
1931
1932
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and later lookups fail.

    After the delete request, fetching the same thread id must raise
    ``NotFoundError`` whose message names the missing thread.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # pytest.raises names the expected exception type in its own failure
        # report, so the message can no longer disagree with the type being
        # caught (the previous hand-written message said "ValueError").
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1964
1965
async def test_thread_item_removed_event_removes_item():
    """An item that is added and then removed by the responder is not listed afterwards."""
    removed_id = "msg_to_remove"

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    with make_server(respond) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger removal")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        removed_item_ids = [
            ev.item_id for ev in events if ev.type == "thread.item.removed"
        ]
        assert removed_id in removed_item_ids

        # The item should not be present in the store
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert all(item.id != removed_id for item in page.data)
2010
2011
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder becomes a stream error event.

    The stream must contain an ``error`` event with ``ErrorCode.STREAM_ERROR``
    and must be exhausted immediately after it.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; its presence makes this an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named `stream` rather than `iter` so the builtin is not shadowed.
        stream = result.__aiter__()

        # Yield an error
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            # Raised explicitly: `assert False` is stripped under `python -O`.
            raise AssertionError("Expected error event")

        # Then finish
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
2053
2054
async def test_thread_item_replaced_event_replaces_item():
    """A replaced event swaps the stored item for the replacement with the same id."""
    original_id = "msg_to_replace"

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # First add an item
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Then replace it
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[
                    AssistantMessageContent(text="This is the replacement content")
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(respond) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger replacement")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Verify the replacement event was generated
        replaced_item_ids = [
            ev.item.id for ev in events if ev.type == "thread.item.replaced"
        ]
        assert original_id in replaced_item_ids

        # Verify the item was replaced in the store
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        )

        # Find the replaced item
        matching = [item for item in stored_items if item.id == original_id]
        stored_replacement = matching[0] if matching else None
        assert stored_replacement is not None
        assert isinstance(stored_replacement, AssistantMessageItem)

        first_part = stored_replacement.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
2117
2118
async def test_thread_item_replaced_event_with_widget():
    """Replacing a widget item updates both the streamed event and the stored widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    async def respond(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # First add a widget item
        first_version = WidgetItem(
            id=widget_id,
            widget=original_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=first_version)

        # Then replace it
        second_version = WidgetItem(
            id=widget_id,
            widget=replacement_widget,
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemReplacedEvent(item=second_version)

    with make_server(respond) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Trigger widget replacement")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Verify the replacement event was generated
        replacement_events = [
            ev
            for ev in events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        ]
        replacement_event = replacement_events[0] if replacement_events else None
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # Verify the item was replaced in the store
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = (
            TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        )

        # Find the replaced widget item
        matching = [item for item in stored_items if item.id == widget_id]
        stored_widget_item = matching[0] if matching else None
        assert stored_widget_item is not None
        assert isinstance(stored_widget_item, WidgetItem)
        assert isinstance(stored_widget_item.widget, Card)

        first_child = stored_widget_item.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
2192
2193
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder and swaps the reply."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first invocation answers normally; any later one is the retry.
        if len(responder_calls) == 1:
            message_id, text = "assistant_msg_1", "First response"
        else:
            message_id, text = "assistant_msg_2", "Retried response"

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=message_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        async def list_items() -> Page[ThreadItem]:
            # Fetch and parse the thread's current items.
            response = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(response.json)

        items_before = await list_items()
        assert len(items_before.data) == 2
        assistant_before, user_before = items_before.data
        assert assistant_before.type == "assistant_message"
        assert assistant_before.content[0].type == "output_text"
        assert assistant_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry from the user message so its reply is regenerated.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_before.id
                )
            )
        )

        # The retry stream carries the stream options and the new reply only.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder invocations received the same user message.
        assert len(responder_calls) == 2
        first_input, second_input = responder_calls
        assert first_input == second_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        # Storage still holds two items, with the reply swapped for the retry.
        items_after = await list_items()
        assert len(items_after.data) == 2
        assistant_after, user_after = items_after.data
        assert assistant_after.type == "assistant_message"
        assert assistant_after.content[0].type == "output_text"
        assert assistant_after.content[0].text == "Retried response"
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
2291
2292
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello")],
        attachments=[],
        inference_options=InferenceOptions(),
    )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        # The first listed item is used as the (invalid) retry target.
        assistant_item_id = page.data[0].id

        # Retrying from an assistant message must be rejected.
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(
                ThreadsRetryAfterItemReq(
                    params=ThreadRetryAfterItemParams(
                        thread_id=thread.id, item_id=assistant_item_id
                    )
                )
            )
2339
2340
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only its reply.

    Two user messages are exchanged, then the second one is retried. The first
    exchange must stay untouched while the second reply is replaced by a new
    assistant item.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The call count makes every reply id unique, so a regenerated reply
        # can be told apart from the one it replaces.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Verify we have 4 items. The assertions below rely on the listing
        # being newest first: assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Index 1 is the second user message; retry after it.
        second_user_message_id = items_before.data[1].id

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Verify retry used the second user message
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Verify only the last assistant response was removed and regenerated
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items
        # First user message and response should remain unchanged
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"
        # Second user message should remain, but assistant response should be new
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
2445
2446
async def test_threads_create_passes_tools_to_responder():
    """The tool choice sent when creating a thread reaches the responder."""
    tool_choice = ToolChoice(id="web_search")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        assert input.inference_options.tool_choice is not None
        assert input.inference_options.tool_choice.id == tool_choice.id

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_create",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        # Sanity check that the responder actually ran to its yield. Thread
        # creation also emits a `thread.item.done` for the user message, so a
        # bare type check would pass even if the responder produced nothing.
        assert any(
            e.type == "thread.item.done" and e.item.type == "assistant_message"
            for e in events
        )
2480
2481
async def test_input_transcribe_decodes_base64_and_normalizes_mime_type():
    """The transcribe callback gets raw bytes, the full MIME type and a bare media type."""
    audio_bytes = b"hello audio"
    encoded_audio = base64.b64encode(audio_bytes).decode("ascii")
    captured: dict[str, Any] = {}

    def transcribe_callback(
        audio_input: AudioInput, context: Any
    ) -> TranscriptionResult:
        # Record everything the server handed over for inspection afterwards.
        captured.update(
            audio=audio_input.data,
            mime=audio_input.mime_type,
            media_type=audio_input.media_type,
            context=context,
        )
        return TranscriptionResult(text="ok")

    request = InputTranscribeReq(
        params=InputTranscribeParams(
            audio_base64=encoded_audio,
            mime_type="audio/webm;codecs=opus",
        )
    )

    with make_server(transcribe_callback=transcribe_callback) as server:
        result = await server.process_non_streaming(request)
        parsed = TypeAdapter(TranscriptionResult).validate_json(result.json)
        assert parsed.text == "ok"

    # Audio arrives decoded; `mime_type` keeps the codec parameter while
    # `media_type` has it stripped.
    assert captured == {
        "audio": audio_bytes,
        "mime": "audio/webm;codecs=opus",
        "media_type": "audio/webm",
        "context": DEFAULT_CONTEXT,
    }
2512
2513
async def test_retry_after_item_passes_tools_to_responder():
    # Placeholder: the body contains no assertions, so this test always passes
    # and currently verifies nothing.
    # TODO(review): presumably intended to mirror the other
    # `*_passes_tools_to_responder` tests for the retry-after-item request;
    # confirm the expected behaviour and implement, or remove this stub.
    pass
2516
2517
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice sent with an added user message reaches the responder."""
    tool_choice = ToolChoice(id="retrieval")

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Both requests below carry the same tool choice, so these checks
        # apply to every invocation.
        assert input is not None
        assert input.inference_options.tool_choice is not None
        assert input.inference_options.tool_choice.id == tool_choice.id

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_add",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Add a message with tools and verify they reach the responder.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )
        # Sanity check that the responder actually ran to its yield. A bare
        # `thread.item.done` check could be satisfied by an item the responder
        # did not produce, so require its assistant message specifically.
        assert any(
            e.type == "thread.item.done" and e.item.type == "assistant_message"
            for e in events
        )
2567
2568
async def test_custom_id_generators_on_server():
    """Thread and item ids come from the store's `generate_*_id` overrides."""

    class UniqueIdStore(SQLiteStore):
        """Store whose thread and item ids are drawn from one shared counter."""

        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits no events; the unreachable `yield` only makes this function
        # an async generator.
        return
        yield

    with make_server(responder) as server:
        # Replace the default store with the custom one on the same DB path.
        default_store = cast(SQLiteStore, server.store)
        server.store = UniqueIdStore(default_store.db_path)

        # Create a thread through the raw JSON entry point.
        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        result = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is the first value taken from the counter.
        created_event = next(e for e in events if e.type == "thread.created")
        assert created_event.thread.id == "thr_custom_1"

        # The user message id is the second value and embeds the thread id.
        user_message_items = (
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        user_message = next(user_message_items)
        assert user_message.id == "message_custom_2_thr_custom_1"
2625