openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
f3dca008d3511f1e14ee6e780833a82c5ffd0d71

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2300 lines · mode: code

1import asyncio
2import base64
3import json
4import sqlite3
5from contextlib import contextmanager
6from datetime import datetime
7from typing import Any, AsyncIterator, Callable, cast
8
9import pytest
10from helpers.mock_store import SQLiteStore
11from pydantic import AnyUrl, TypeAdapter
12
13from chatkit.actions import Action
14from chatkit.errors import ErrorCode
15from chatkit.server import (
16 ChatKitServer,
17 NonStreamingResult,
18 StreamingResult,
19 stream_widget,
20)
21from chatkit.store import (
22 AttachmentStore,
23 NotFoundError,
24 StoreItemType,
25 default_generate_id,
26)
27from chatkit.types import (
28 AssistantMessageContent,
29 AssistantMessageContentPartTextDelta,
30 AssistantMessageItem,
31 Attachment,
32 AttachmentCreateParams,
33 AttachmentDeleteParams,
34 AttachmentsCreateReq,
35 AttachmentsDeleteReq,
36 AttachmentUploadDescriptor,
37 AudioInput,
38 ClientToolCallItem,
39 CustomTask,
40 FeedbackKind,
41 FileAttachment,
42 ImageAttachment,
43 InferenceOptions,
44 InputTranscribeParams,
45 InputTranscribeReq,
46 ItemFeedbackParams,
47 ItemsFeedbackReq,
48 ItemsListParams,
49 ItemsListReq,
50 LockedStatus,
51 Page,
52 ProgressUpdateEvent,
53 SyncCustomActionResponse,
54 Thread,
55 ThreadAddClientToolOutputParams,
56 ThreadAddUserMessageParams,
57 ThreadCreatedEvent,
58 ThreadCreateParams,
59 ThreadCustomActionParams,
60 ThreadDeleteParams,
61 ThreadGetByIdParams,
62 ThreadItem,
63 ThreadItemAddedEvent,
64 ThreadItemDoneEvent,
65 ThreadItemRemovedEvent,
66 ThreadItemReplacedEvent,
67 ThreadItemUpdatedEvent,
68 ThreadListParams,
69 ThreadMetadata,
70 ThreadRetryAfterItemParams,
71 ThreadsAddClientToolOutputReq,
72 ThreadsAddUserMessageReq,
73 ThreadsCreateReq,
74 ThreadsCustomActionReq,
75 ThreadsDeleteReq,
76 ThreadsGetByIdReq,
77 ThreadsListReq,
78 ThreadsRetryAfterItemReq,
79 ThreadsSyncCustomActionReq,
80 ThreadStreamEvent,
81 ThreadsUpdateReq,
82 ThreadUpdatedEvent,
83 ThreadUpdateParams,
84 ToolChoice,
85 TranscriptionResult,
86 UserMessageInput,
87 UserMessageItem,
88 UserMessageTextContent,
89 WidgetItem,
90 WidgetRootUpdated,
91 Workflow,
92 WorkflowItem,
93 WorkflowTaskAdded,
94)
95from chatkit.widgets import Card, Text
96from tests._types import RequestContext
97from tests.test_store import make_thread_items
98
# Counter read and incremented by make_server() so that every test server
# gets its own shared-cache in-memory SQLite database name.
server_id = 0


# Request context used by the tests (and by the server helpers in
# make_server) whenever no explicit context is provided.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
103
104
class InMemoryFileStore(AttachmentStore):
    """In-memory ``AttachmentStore`` used by the server tests.

    Created attachments are kept in ``self.files`` keyed by attachment id.
    MIME types starting with ``image/`` produce an ``ImageAttachment`` (with a
    preview URL); everything else produces a ``FileAttachment``. Both carry a
    fixed ``PUT`` upload descriptor and ``{"source": "test"}`` metadata.
    """

    def __init__(self):
        # attachment id -> attachment returned by create_attachment
        self.files: dict[str, Attachment] = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from the store.

        Raises:
            KeyError: if the id is unknown (no NotFoundError translation here).
        """
        del self.files[attachment_id]

    def _next_attachment_id(self) -> str:
        """Return an unused ``atc_<n>`` id, starting at ``len(self.files) + 1``.

        Probing forward past ids that are still present prevents a deletion
        followed by a creation from reusing (and overwriting) a live entry.
        Without deletions the sequence is atc_1, atc_2, ... as before.
        """
        index = len(self.files) + 1
        while f"atc_{index}" in self.files:
            index += 1
        return f"atc_{index}"

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return an attachment described by ``input``."""
        attachment_id = self._next_attachment_id()
        metadata = {"source": "test"}
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        attachment: Attachment
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        self.files[attachment.id] = attachment
        return attachment
145
146
# Building a TypeAdapter compiles a validation schema; do it once and reuse it
# for every decoded event instead of once per call.
_THREAD_STREAM_EVENT_ADAPTER = TypeAdapter(ThreadStreamEvent)


def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one server-sent-event frame into a ``ThreadStreamEvent``.

    The JSON payload is everything after the first ``data: `` marker. Splitting
    only once keeps payloads intact when their own text (e.g. a message body)
    contains ``data: ``.

    Raises:
        IndexError: if the frame has no ``data: `` marker.
    """
    payload = event.split(b"data: ", 1)[1]
    return _THREAD_STREAM_EVENT_ADAPTER.validate_json(payload)
149
150
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result`` and return every decoded event in order."""
    decoded: list[ThreadStreamEvent] = []
    async for frame in streaming_result.json_events:
        decoded.append(decode_event(frame))
    return decoded
155
156
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    sync_action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        SyncCustomActionResponse,
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
    transcribe_callback: Callable[[AudioInput, Any], TranscriptionResult] | None = None,
):
    """Yield a ``ChatKitServer`` test double backed by an in-memory SQLite store.

    Every call builds a distinct shared-cache in-memory database URI from the
    module-level ``server_id`` counter and hands it to ``SQLiteStore``. The
    server's hooks delegate to the callbacks given here:

    - ``responder``: backs ``respond``; when omitted the server yields no events.
    - ``handle_feedback``: backs ``add_feedback``; ignored when omitted.
    - ``action_callback`` / ``sync_action_callback``: back ``action`` and
      ``sync_action``; invoking either hook without its callback raises
      ``ValueError``.
    - ``file_store``: attachment store passed to ``ChatKitServer``.
    - ``transcribe_callback``: backs ``transcribe``; when omitted the base
      class implementation is used.

    The yielded server also offers ``process_streaming`` / ``process_non_streaming``
    helpers that serialize a request model, run it through ``process`` and
    assert the expected result type.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Hold one connection open for the server's lifetime: a shared-cache
    # in-memory database is deleted when its last connection is closed.
    keepalive = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is None:
                raise ValueError("action_callback not wired up")
            return action_callback(thread, action, sender, context)

        async def sync_action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> SyncCustomActionResponse:
            if sync_action_callback is None:
                raise ValueError("sync_action_callback not wired up")
            return sync_action_callback(thread, action, sender, context)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is None:
                return self._empty_responder()
            return responder(thread, input_user_message, context)

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes immediately without producing any events.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is None:
                return
            handle_feedback(thread_id, item_ids, feedback, context)

        async def transcribe(
            self, audio_input: AudioInput, context: Any
        ) -> TranscriptionResult:
            if transcribe_callback is None:
                return await super().transcribe(audio_input, context)
            return transcribe_callback(audio_input, context)

        async def process_streaming(
            self, request_obj: Any, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Run a streaming request and return all of its decoded events."""
            # Explicit None check: a caller-supplied context that happens to
            # be falsy must not be silently replaced by DEFAULT_CONTEXT.
            effective_context = DEFAULT_CONTEXT if context is None else context
            result = await self.process(
                request_obj.model_dump_json(), effective_context
            )
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(
            self, request_obj: Any, context: Any | None = None
        ):
            """Run a non-streaming request and return its ``NonStreamingResult``."""
            effective_context = DEFAULT_CONTEXT if context is None else context
            result = await self.process(
                request_obj.model_dump_json(), effective_context
            )
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        keepalive.close()
261
262
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """Cancelling mid-stream keeps the partially streamed assistant message
    (with its applied deltas) and stores an ``sdk_hidden_context`` note."""
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id=pending_id,
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, ")],
                thread_id=thread.id,
            )
        )
        yield ThreadItemUpdatedEvent(
            item_id=pending_id,
            update=AssistantMessageContentPartTextDelta(
                content_index=0,
                delta="World!",
            ),
        )
        # Simulate the stream being cancelled before the item is marked done.
        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        # Route ``sdk_hidden_context`` id requests to default_generate_id;
        # every other item type keeps using the store's own generator.
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for frame in stream.json_events:
                received.append(decode_event(frame))

        thread = next(e.thread for e in received if e.type == "thread.created")
        newest = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        persisted = await server.store.load_item(
            thread.id, pending_id, DEFAULT_CONTEXT
        )
        assert persisted.type == "assistant_message"
        assert persisted.content[0].text == "Hello, World!"
332
333
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """Cancelling before any content arrives drops the empty assistant message
    but still stores the ``sdk_hidden_context`` cancellation note."""
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        raise asyncio.CancelledError()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        received: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for frame in stream.json_events:
                received.append(decode_event(frame))

        thread = next(e.thread for e in received if e.type == "thread.created")
        newest = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        # The empty assistant message must not have been persisted.
        with pytest.raises(NotFoundError):
            await server.store.load_item(thread.id, pending_id, DEFAULT_CONTEXT)
394
395
async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying pending-item updates must not mutate the item
    object the responder streamed, otherwise the task would appear twice.
    """

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        workflow_item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=workflow_item)

        new_task = CustomTask(title="foo", content="bar")
        yield ThreadItemUpdatedEvent(
            item_id=workflow_item.id,
            update=WorkflowTaskAdded(task=new_task, task_index=0),
        )

        # The responder appends the task to its own copy before finishing.
        workflow_item.workflow.tasks.append(new_task)
        yield ThreadItemDoneEvent(item=workflow_item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        thread = next(e.thread for e in events if e.type == "thread.created")
        done_event = next(
            e
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WorkflowItem)
        )
        done_item = cast(WorkflowItem, done_event.item)
        assert len(done_item.workflow.tasks) == 1
        assert done_item.workflow.tasks[0].title == "foo"

        stored = await server.store.load_item(
            thread.id, done_item.id, DEFAULT_CONTEXT
        )
        assert isinstance(stored, WorkflowItem)
        assert len(stored.workflow.tasks) == 1
        assert stored.workflow.tasks[0].title == "foo"
453
454
async def test_flows_context_to_responder():
    """The caller's request context reaches both ``respond`` and ``add_feedback``."""
    seen_contexts: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_contexts["responder"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen_contexts["add_feedback"] = context

    request_context = RequestContext(user_id="text_ctx")
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, add_feedback) as server:
        events = await server.process_streaming(create_request, request_context)
        assert seen_contexts.get("responder") == request_context

        thread = next(e.thread for e in events if e.type == "thread.created")
        assistant_item = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "assistant_message"
        )
        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, request_context)
        assert seen_contexts.get("add_feedback") == request_context
514
515
async def test_creates_thread():
    """Creating a thread streams ``thread.created`` (untitled) and the stored
    user message, including its quoted text."""

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Produce no events of its own.
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
                quoted_text="Quoted text",
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        created = next(e for e in events if e.type == "thread.created")
        assert created.thread.title is None

        user_item = next(e.item for e in events if e.type == "thread.item.done")
        assert user_item.type == "user_message"
        assert user_item.content[0].text == "Hello, world!"
        assert user_item.quoted_text == "Quoted text"
549
550
async def test_saves_thread_title():
    """A title set by the responder is persisted and announced via a final
    ``thread.updated`` event."""

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        saved = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert saved.title == "Updated title"

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.title == "Updated title"
581
582
async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted and triggers ``thread.updated``."""

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
610
611
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and streamed."""
    expected_status = LockedStatus(reason="Because")

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        saved = await server.store.load_thread(
            thread.id, RequestContext(user_id="test_user")
        )
        assert saved.status == expected_status

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.status == expected_status
641
642
async def test_saves_allowed_image_domains_and_streams_thread_updated():
    """``allowed_image_domains`` set by the responder is persisted and streamed."""
    domains = ["example.com", "images.example.com"]

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.allowed_image_domains = ["example.com", "images.example.com"]
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.allowed_image_domains == domains

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.allowed_image_domains == domains
673
674
async def test_omits_unset_allowed_image_domains_in_created_and_updated_json_events():
    """When the responder never sets ``allowed_image_domains``, the raw JSON of
    both ``thread.created`` and ``thread.updated`` events must omit the key."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        stream = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        thread_created_event: dict[str, Any] | None = None
        thread_updated_event: dict[str, Any] | None = None
        async for raw in stream.json_events:
            # Split only at the first SSE ``data: `` marker so a payload whose
            # own text contains ``data: `` is not truncated.
            event = json.loads(raw.split(b"data: ", 1)[1])
            if event["type"] == "thread.created":
                thread_created_event = event
            elif event["type"] == "thread.updated":
                thread_updated_event = event

        assert thread_created_event is not None
        assert "allowed_image_domains" not in thread_created_event["thread"]

        assert thread_updated_event is not None
        assert thread_updated_event["thread"]["title"] == "Updated title"
        assert "allowed_image_domains" not in thread_updated_event["thread"]
713
714
async def test_emits_thread_updated_mid_stream_and_persists():
    """A thread change made mid-stream is emitted as ``thread.updated`` right
    after the next item event, streaming continues, and the change is saved."""

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        yield assistant_done("assistant_a", "A")
        # Change the thread between two item events.
        thread.title = "Mid-stream title"
        # thread.updated is expected right after this one.
        yield assistant_done("assistant_b", "B")
        # The stream must keep going after the update.
        yield assistant_done("assistant_c", "C")

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        b_index = next(
            position
            for position, e in enumerate(events)
            if isinstance(e, ThreadItemDoneEvent)
            and isinstance(e.item, AssistantMessageItem)
            and e.item.id == "assistant_b"
        )

        following = events[b_index + 1]
        assert isinstance(following, ThreadUpdatedEvent)
        updated = cast(ThreadUpdatedEvent, following)
        assert updated.thread.title == "Mid-stream title"

        after_update = events[b_index + 2]
        assert isinstance(after_update, ThreadItemDoneEvent)
        next_done = cast(ThreadItemDoneEvent, after_update)
        assert isinstance(next_done.item, AssistantMessageItem)
        assert next_done.item.id == "assistant_c"

        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"
789
790
async def test_respond_with_tool_call():
    """A client tool call streamed on thread creation can be answered with
    ``threads.add_client_tool_output``, which re-invokes the responder with no
    input message."""

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input_message is None:
            # Follow-up turn after the tool output was submitted.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(input_message, UserMessageItem):
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        assert [e.type for e in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "thread.item.done",
        ]
        thread = events[0].thread
        assert events[1].item.type == "user_message"

        tool_call = events[3].item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        tool_output_request = ThreadsAddClientToolOutputReq(
            params=ThreadAddClientToolOutputParams(
                thread_id=thread.id,
                result={"text": "Wow!"},
            )
        )
        events = await server.process_streaming(tool_output_request)

        assert [e.type for e in events] == ["stream_options", "thread.item.done"]
        assert events[1].item.type == "assistant_message"
860
861
async def test_respond_with_tool_status():
    """Progress updates from the responder are streamed in order, between
    ``stream_options`` and the assistant message."""

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        assert [e.type for e in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
895
896
async def test_list_threads_response():
    """Thread listing returns newest first and paginates with an ``after`` cursor."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {index}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        cursor = first_page.data[-1].id
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
929
930
async def test_list_items_response():
    """Item listing returns newest first and paginates with an ``after`` cursor."""
    page_adapter = TypeAdapter(Page[ThreadItem])

    async def responder(
        thread: ThreadMetadata, input_message: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for index in range(25):
            yield ThreadItemDoneEvent(
                item=UserMessageItem(
                    id=f"msg_{index}",
                    content=[UserMessageTextContent(text=f"Message {index}")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    thread_id=thread.id,
                    created_at=datetime.now(),
                ),
            )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Thread 1")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        cursor = first_page.data[-1].id
        second_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id, after=cursor))
        )
        second_page = page_adapter.validate_json(second_result.json)
        # 25 responder items + the initial user message = 26 = 20 + 6
        assert len(second_page.data) == 6
        assert second_page.has_more is False
980
981
async def test_calls_action():
    """A custom action on a widget reaches ``action`` with the stored widget as
    sender, and the events the handler yields are streamed back."""
    recorded: list[tuple[Action[str, Any], WidgetItem | None]] = []

    async def responder(
        thread: ThreadMetadata,
        input_message: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        widget = Card(children=[Text(key="text_1", value="test")])
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=widget,
            ),
        )

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        recorded.append((action, sender))
        assert sender

        replacement = Card(children=[Text(value="Email sent!")])
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Show widget")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, action_callback=action) as server:
        events = await server.process_streaming(create_request)
        thread = next(e.thread for e in events if e.type == "thread.created")
        widget_item = next(
            e.item
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WidgetItem)
        )

        action_request = ThreadsCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=Action(type="create_user", payload={"user_id": "123"}),
            )
        )
        events = await server.process_streaming(action_request)

        assert recorded
        assert recorded[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert [e.type for e in events] == ["stream_options", "thread.item.updated"]
        update_event = events[1]
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(children=[Text(value="Email sent!")])
1071
1072
async def test_calls_action_sync():
    """A sync custom action reaches its callback and returns the updated item."""
    received: list[tuple[Action[str, Any], WidgetItem | None]] = []

    def create_user_action() -> Action[str, Any]:
        # Built fresh on each call so comparisons never rely on object identity.
        return Action(type="create_user", payload={"user_id": "123"})

    def email_sent_card() -> Card:
        return Card(children=[Text(value="Email sent!")])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a single widget that the action can later target.
        initial_widget = Card(children=[Text(key="text_1", value="test")])
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_widget,
            ),
        )

    def sync_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> SyncCustomActionResponse:
        received.append((action, sender))
        assert sender
        # Swap the sender's widget for a confirmation card.
        replacement = sender.model_copy(update={"widget": email_sent_card()})
        return SyncCustomActionResponse(updated_item=replacement)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Show widget")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder, sync_action_callback=sync_action) as server:
        created = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in created if ev.type == "thread.created")
        widget_item = next(
            ev.item
            for ev in created
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        action_request = ThreadsSyncCustomActionReq(
            params=ThreadCustomActionParams(
                thread_id=thread.id,
                item_id=widget_item.id,
                action=create_user_action(),
            )
        )
        result = await server.process_non_streaming(action_request)
        response = TypeAdapter(SyncCustomActionResponse).validate_json(result.json)

        assert received
        assert received[0] == (create_user_action(), widget_item)
        expected_item = widget_item.model_copy(update={"widget": email_sent_card()})
        assert response.updated_item == expected_item
1160
1161
async def test_add_feedback():
    """Feedback submitted for thread items is forwarded to ``handle_feedback``."""
    recorded: list[tuple[Any, Any, Any, Any]] = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        recorded.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # One assistant message and one widget, each a feedback target below.
        assistant_item = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Feedback received!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_item)

        widget_item = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget_item)

    # (item id, feedback kind) pairs, submitted in this order.
    submissions: list[tuple[str, FeedbackKind]] = [
        ("assistant_msg_1", "positive"),
        ("widget_1", "negative"),
    ]

    with make_server(responder, handle_feedback=handle_feedback) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test feedback")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        for item_id, kind in submissions:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=thread.id, item_ids=[item_id], kind=kind
                    )
                )
            )

        assert len(recorded) == 2
        for (thread_id, item_ids, feedback, _), (item_id, kind) in zip(
            recorded, submissions
        ):
            assert thread_id == thread.id
            assert item_ids == [item_id]
            assert feedback == kind
1236
1237
async def test_get_thread_by_id():
    """A thread created through the server can be loaded back by its id."""
    # NOTE(review): the return value is discarded here; confirm
    # make_thread_items has no side effects this test relies on, then drop it.
    make_thread_items()

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test thread")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        # Create the thread through the regular streaming endpoint.
        events = await server.process_streaming(create_request)
        created = next(ev.thread for ev in events if ev.type == "thread.created")

        # Load it back through the get-by-id endpoint.
        lookup = ThreadsGetByIdReq(params=ThreadGetByIdParams(thread_id=created.id))
        result = await server.process_non_streaming(lookup)
        loaded = TypeAdapter(Thread).validate_json(result.json)

        assert loaded.id == created.id
        assert loaded.title == created.title
        first_item = loaded.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1269
1270
async def test_create_file():
    """Creating a non-image attachment returns a file with an upload descriptor."""
    file_store = InMemoryFileStore()
    name, mime = "test-file-name", "text/plain"

    with make_server(file_store=file_store) as server:
        create_params = AttachmentCreateParams(name=name, size=1024, mime_type=mime)
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)

        # The response describes the newly created file...
        assert attachment.id is not None
        assert attachment.mime_type == mime
        assert attachment.name == name
        assert attachment.type == "file"
        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        # ...and the store now holds it.
        assert attachment.id in file_store.files

        # `metadata` is excluded from serialization but still kept server-side.
        assert attachment.metadata is None
        assert file_store.files[attachment.id].metadata == {"source": "test"}
1302
1303
async def test_create_image_file():
    """Creating an image attachment also yields a preview URL."""
    file_store = InMemoryFileStore()
    name, mime = "test-file-name", "image/png"

    with make_server(file_store=file_store) as server:
        create_params = AttachmentCreateParams(name=name, size=1024, mime_type=mime)
        result = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(result.json)
        base_url = f"https://example.com/{attachment.id}"

        assert attachment.id is not None
        assert attachment.mime_type == mime
        assert attachment.name == name
        assert attachment.type == "image"
        assert attachment.preview_url == AnyUrl(f"{base_url}/preview")
        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(f"{base_url}/upload")
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        assert attachment.id in file_store.files
1336
1337
async def test_user_message_attachment_metadata_not_serialized():
    """Attachment metadata is stripped from serialized user message items."""
    stored_attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(stored_attachment, DEFAULT_CONTEXT)
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Message with file")],
                        attachments=[stored_attachment.id],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread_id = next(ev.thread.id for ev in events if ev.type == "thread.created")

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread_id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        first_user = next(
            (item for item in page.data if item.type == "user_message"), None
        )
        assert first_user is not None
        assert first_user.attachments
        # No serialized attachment may carry its server-side metadata.
        assert not any(att.metadata is not None for att in first_user.attachments)
1373
1374
async def test_user_message_attachment_thread_id_persisted_on_thread_create():
    """Creating a thread with an attachment records the thread on the attachment."""
    # Uploaded before any thread exists, so the attachment starts unowned.
    orphan = FileAttachment(
        id="file_thread_id_on_create",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        thread_id=None,
    )
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Message with file")],
                attachments=[orphan.id],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        await server.store.save_attachment(orphan, DEFAULT_CONTEXT)

        events = await server.process_streaming(create_request)
        new_thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Saving the user message should have linked the attachment to the thread.
        reloaded = await server.store.load_attachment(orphan.id, DEFAULT_CONTEXT)
        assert reloaded.thread_id == new_thread.id
1405
1406
async def test_user_message_attachment_thread_id_persisted_on_add_user_message():
    """Adding a message with an attachment to an existing thread links them."""

    def text_input(text: str, attachment_ids: list[str]) -> UserMessageInput:
        return UserMessageInput(
            content=[UserMessageTextContent(text=text)],
            attachments=attachment_ids,
            inference_options=InferenceOptions(),
        )

    with make_server() as server:
        # Start with a thread that has no attachments.
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=text_input("Start", [])))
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Upload an attachment that does not belong to any thread yet.
        pending = FileAttachment(
            id="file_thread_id_on_add",
            type="file",
            mime_type="text/plain",
            name="test.txt",
            thread_id=None,
        )
        await server.store.save_attachment(pending, DEFAULT_CONTEXT)

        # Reference it from a follow-up message in the existing thread.
        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=text_input("Second", [pending.id]),
                )
            )
        )

        reloaded = await server.store.load_attachment(pending.id, DEFAULT_CONTEXT)
        assert reloaded.thread_id == thread.id
1448
1449
async def test_create_file_without_filestore():
    """Creating an attachment without a configured file store raises."""
    request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name="test-file-name",
            size=1024,
            mime_type="text/plain",
        )
    )
    # pytest.raises is the outer context, so the error propagates through the server.
    with pytest.raises(RuntimeError), make_server() as server:
        await server.process_non_streaming(request)
1462
1463
async def test_delete_file():
    """Deleting an attachment removes it from the file store."""
    file_id = "test-file-id"
    file_store = InMemoryFileStore()
    file_store.files[file_id] = {"id": file_id}

    delete_request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id=file_id)
    )
    with make_server(file_store=file_store) as server:
        await server.process_non_streaming(delete_request)
        assert file_store.files == {}
1474
1475
async def test_delete_file_without_filestore():
    """Deleting an attachment without a configured file store raises."""
    request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id="test-file-id")
    )
    # pytest.raises is the outer context, so the error propagates through the server.
    with pytest.raises(RuntimeError), make_server() as server:
        await server.process_non_streaming(request)
1484
1485
async def test_list_threads_paginated_response():
    """Listing 25 threads with limit=20 yields a full page, then a 5-item page."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {index}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Continue listing from the last thread of the first page.
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
1514
1515
async def test_threads_update_request():
    """Updating a thread's title returns that same thread with the new title."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        result = await server.process_non_streaming(
            ThreadsUpdateReq(
                params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
            )
        )
        # Keep the raw server result and the decoded Thread under separate names
        # so each name has a single type.
        updated_thread = TypeAdapter[Thread](Thread).validate_json(result.json)
        assert updated_thread.id == thread.id
        assert updated_thread.title == "New Title"
1539
1540
async def test_returns_widget_item_generator():
    """Growing text in a streamed widget is emitted as incremental deltas."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    full_text = "Hello, world"

    def card_with_prefix(length: int) -> Card:
        return Card(children=[Text(id="text", value=full_text[:length])])

    async def widget_generator():
        for length in (0, 3, 6, 12):
            yield card_with_prefix(length)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5
    added, *deltas, done = events

    # The first frame adds the widget with an empty text value.
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    # Each following frame streams only the newly appended characters.
    for delta_event, expected in zip(deltas, ["Hel", "lo,", " world"]):
        assert isinstance(delta_event, ThreadItemUpdatedEvent)
        assert delta_event.update.type == "widget.streaming_text.value_delta"
        assert delta_event.update.component_id == "text"
        assert delta_event.update.delta == expected

    # The final frame carries the fully rendered widget.
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(children=[Text(id="text", value=full_text)])
1582
1583
async def test_returns_widget_item_generator_full_replace():
    """Non-incremental widget changes are streamed as full root replacements."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card(value: str) -> Card:
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        for value in ("Hello", "World"):
            yield card(value)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added, replaced, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card("Hello")

    assert isinstance(replaced, ThreadItemUpdatedEvent)
    assert replaced.update.type == "widget.root.updated"
    assert replaced.update.widget == card("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card("World")
1607
1608
async def test_delete_thread():
    """Deleting a thread returns ``{}`` and makes later lookups raise NotFoundError."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Once deleted, loading the thread must fail with NotFoundError.
        # pytest.raises replaces a hand-rolled try/`assert False` block.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1640
1641
async def test_thread_item_removed_event_removes_item():
    """A ThreadItemRemovedEvent deletes a previously saved item from the store."""
    doomed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Persist an item, then immediately ask for it to be removed.
        doomed = UserMessageItem(
            id=doomed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed)
        yield ThreadItemRemovedEvent(item_id=doomed_id)

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Trigger removal")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")
        removal_streamed = any(
            ev.type == "thread.item.removed" and ev.item_id == doomed_id
            for ev in events
        )
        assert removal_streamed

        # The removed item must no longer be present in the store.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        remaining_ids = {item.id for item in page.data}
        assert doomed_id not in remaining_ids
1686
1687
async def test_raising_in_responder_yields_error_event():
    """An exception inside the responder surfaces as an error event on the stream.

    The stream must emit an ``error`` event whose code is ``STREAM_ERROR`` and
    then finish cleanly rather than propagating the exception.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; its presence makes this an async generator

    request_json = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Test error")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    ).model_dump_json()

    with make_server(responder) as server:
        result = await server.process(
            request_json,
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named `stream` rather than `iter` to avoid shadowing the builtin.
        stream = result.__aiter__()

        # First, the stream yields an error event...
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            # pytest.fail, unlike `assert False`, is not stripped under -O.
            pytest.fail("Expected error event")

        # ...and then it finishes.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1729
1730
async def test_thread_item_replaced_event_replaces_item():
    """A ThreadItemReplacedEvent overwrites the stored item with the same id."""
    target_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Save a user message under target_id...
        yield ThreadItemDoneEvent(
            item=UserMessageItem(
                id=target_id,
                content=[UserMessageTextContent(text="Original content")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
        # ...then swap it for an assistant message that reuses the id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=target_id,
                content=[AssistantMessageContent(text=replacement_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Trigger replacement")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The replacement must be announced on the stream...
        assert any(
            ev.type == "thread.item.replaced" and ev.item.id == target_id
            for ev in events
        )

        # ...and reflected in storage.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        stored = next((item for item in stored_items if item.id == target_id), None)
        assert stored is not None
        assert isinstance(stored, AssistantMessageItem)
        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == replacement_text
1793
1794
async def test_thread_item_replaced_event_with_widget():
    """Replacing a widget item streams and stores the new widget."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def widget_item(thread_id: str, widget: Card) -> WidgetItem:
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread_id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Add the widget, then replace it wholesale under the same id.
        yield ThreadItemDoneEvent(item=widget_item(thread.id, original_widget))
        yield ThreadItemReplacedEvent(item=widget_item(thread.id, replacement_widget))

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[
                            UserMessageTextContent(text="Trigger widget replacement")
                        ],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The streamed replacement carries the new widget.
        replaced_events = [
            ev
            for ev in events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        ]
        assert replaced_events
        streamed = replaced_events[0]
        assert isinstance(streamed.item, WidgetItem)
        assert streamed.item.widget == replacement_widget

        # Storage holds the replacement too.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        stored = next((item for item in stored_items if item.id == widget_id), None)
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        assert isinstance(stored.widget, Card)
        first_child = stored.widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1868
1869
async def test_retry_after_item():
    """Retrying after a user message replaces the assistant reply that followed."""
    seen_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_inputs.append(input)
        # The first call answers normally; every later call is a retry.
        item_id, text = (
            ("assistant_msg_1", "First response")
            if len(seen_inputs) == 1
            else ("assistant_msg_2", "Retried response")
        )
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    def assert_conversation(page: Page[ThreadItem], assistant_text: str) -> None:
        # Items are listed newest first: the reply, then the user message.
        assert len(page.data) == 2
        reply, question = page.data
        assert reply.type == "assistant_message"
        assert reply.content[0].type == "output_text"
        assert reply.content[0].text == assistant_text
        assert question.type == "user_message"
        assert question.content[0].text == "Hello, world!"

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        async def list_items() -> Page[ThreadItem]:
            listing = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(listing.json)

        before = await list_items()
        assert_conversation(before, "First response")

        # Retry after the user message (the older of the two items).
        user_message_id = before.data[1].id
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message_id
                )
            )
        )

        # The retry streams exactly one fresh assistant message.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder invocations received the same user message.
        assert len(seen_inputs) == 2
        first_input, second_input = seen_inputs
        assert first_input == second_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        # Storage still holds two items, with the reply swapped for the new one.
        after = await list_items()
        assert_conversation(after, "Retried response")
1967
1968
async def test_retry_after_item_invalid_item():
    """Retrying after a non-user item (here an assistant message) raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1",
                content=[AssistantMessageContent(text="Response")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        items_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(items_result.json)
        # Select the assistant message by type instead of relying on list order,
        # so the test always exercises the "not a user message" path.
        assistant_item_id = next(
            item.id for item in items.data if item.type == "assistant_message"
        )

        # Retrying after an assistant message must be rejected.
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(
                ThreadsRetryAfterItemReq(
                    params=ThreadRetryAfterItemParams(
                        thread_id=thread.id, item_id=assistant_item_id
                    )
                )
            )
2015
2016
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only its reply.

    Earlier turns must be left untouched. Only the assistant reply that follows
    the retried user message is replaced with a newly generated item.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # Each reply gets a unique id derived from the call count.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Expect 4 items, listed newest first:
        # [assistant2, user2, assistant1, user1].
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Confirm the retry target really is the second user message.
        second_user_message = items_before.data[1]
        assert second_user_message.type == "user_message"
        assert second_user_message.content[0].text == "Second message"
        second_user_message_id = second_user_message.id

        # Retry after the second user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # The retry must have been driven by the second user message.
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Only the last assistant response was removed and regenerated.
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items
        # The first user message and its response are unchanged.
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"
        # The second user message remains, but its response is a new item.
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
2121
2122
2123async def test_threads_create_passes_tools_to_responder():
2124 tool_choice = ToolChoice(id="web_search")
2125
2126 async def responder(
2127 thread: ThreadMetadata, input: UserMessageItem | None, context: Any
2128 ) -> AsyncIterator[ThreadStreamEvent]:
2129 assert input is not None
2130 assert input.inference_options.tool_choice is not None
2131 assert input.inference_options.tool_choice.id == tool_choice.id
2132
2133 yield ThreadItemDoneEvent(
2134 item=AssistantMessageItem(
2135 id="assistant_msg_tools_create",
2136 content=[AssistantMessageContent(text="ok")],
2137 created_at=datetime.now(),
2138 thread_id=thread.id,
2139 ),
2140 )
2141
2142 with make_server(responder) as server:
2143 events = await server.process_streaming(
2144 ThreadsCreateReq(
2145 params=ThreadCreateParams(
2146 input=UserMessageInput(
2147 content=[UserMessageTextContent(text="Hello with tools")],
2148 attachments=[],
2149 inference_options=InferenceOptions(tool_choice=tool_choice),
2150 )
2151 )
2152 )
2153 )
2154 # Sanity check that the request flowed through.
2155 assert any(e.type == "thread.item.done" for e in events)
2156
2157
2158async def test_input_transcribe_decodes_base64_and_normalizes_mime_type():
2159 audio_bytes = b"hello audio"
2160 audio_b64 = base64.b64encode(audio_bytes).decode("ascii")
2161 seen: dict[str, Any] = {}
2162
2163 def transcribe_callback(
2164 audio_input: AudioInput, context: Any
2165 ) -> TranscriptionResult:
2166 seen["audio"] = audio_input.data
2167 seen["mime"] = audio_input.mime_type
2168 seen["media_type"] = audio_input.media_type
2169 seen["context"] = context
2170 return TranscriptionResult(text="ok")
2171
2172 with make_server(transcribe_callback=transcribe_callback) as server:
2173 result = await server.process_non_streaming(
2174 InputTranscribeReq(
2175 params=InputTranscribeParams(
2176 audio_base64=audio_b64,
2177 mime_type="audio/webm;codecs=opus",
2178 )
2179 )
2180 )
2181 parsed = TypeAdapter(TranscriptionResult).validate_json(result.json)
2182 assert parsed.text == "ok"
2183
2184 assert seen["audio"] == audio_bytes
2185 assert seen["mime"] == "audio/webm;codecs=opus"
2186 assert seen["media_type"] == "audio/webm"
2187 assert seen["context"] == DEFAULT_CONTEXT
2188
2189
2190async def test_retry_after_item_passes_tools_to_responder():
2191 pass
2192
2193
2194async def test_threads_add_message_passes_tools_to_responder():
2195 tool_choice = ToolChoice(id="retrieval")
2196
2197 async def responder(
2198 thread: ThreadMetadata, input: UserMessageItem | None, context: Any
2199 ) -> AsyncIterator[ThreadStreamEvent]:
2200 assert input is not None
2201 assert input.inference_options.tool_choice is not None
2202 assert input.inference_options.tool_choice.id == tool_choice.id
2203
2204 yield ThreadItemDoneEvent(
2205 item=AssistantMessageItem(
2206 id="assistant_msg_tools_add",
2207 content=[AssistantMessageContent(text="ok")],
2208 created_at=datetime.now(),
2209 thread_id=thread.id,
2210 ),
2211 )
2212
2213 with make_server(responder) as server:
2214 # Create a thread first.
2215 events = await server.process_streaming(
2216 ThreadsCreateReq(
2217 params=ThreadCreateParams(
2218 input=UserMessageInput(
2219 content=[UserMessageTextContent(text="Start")],
2220 attachments=[],
2221 inference_options=InferenceOptions(tool_choice=tool_choice),
2222 )
2223 )
2224 )
2225 )
2226 thread = next(e.thread for e in events if e.type == "thread.created")
2227
2228 # Add a message with tools and verify they reach the responder.
2229 events = await server.process_streaming(
2230 ThreadsAddUserMessageReq(
2231 params=ThreadAddUserMessageParams(
2232 thread_id=thread.id,
2233 input=UserMessageInput(
2234 content=[UserMessageTextContent(text="Second")],
2235 attachments=[],
2236 inference_options=InferenceOptions(tool_choice=tool_choice),
2237 ),
2238 )
2239 )
2240 )
2241 # Sanity check that the request flowed through.
2242 assert any(e.type == "thread.item.done" for e in events)
2243
2244
2245async def test_custom_id_generators_on_server():
2246 class UniqueIdStore(SQLiteStore):
2247 def __init__(self, path):
2248 super().__init__(path)
2249 self.counter = 0
2250
2251 def generate_thread_id(self, context) -> str:
2252 self.counter += 1
2253 return f"thr_custom_{self.counter}"
2254
2255 def generate_item_id(
2256 self,
2257 item_type: str,
2258 thread: ThreadMetadata,
2259 context,
2260 ) -> str:
2261 self.counter += 1
2262 return f"{item_type}_custom_{self.counter}_{thread.id}"
2263
2264 async def responder(
2265 thread: ThreadMetadata, input: UserMessageItem | None, context: Any
2266 ) -> AsyncIterator[ThreadStreamEvent]:
2267 # Emit a tool call to later verify tool_call_output id generation
2268 return
2269 yield
2270
2271 with make_server(responder) as server:
2272 # Swap in our custom ID-generating store, reusing the same DB path.
2273 db_path = cast(SQLiteStore, server.store).db_path
2274 server.store = UniqueIdStore(db_path)
2275
2276 # Create thread and verify thread id and user message id
2277 result = await server.process(
2278 ThreadsCreateReq(
2279 params=ThreadCreateParams(
2280 input=UserMessageInput(
2281 content=[UserMessageTextContent(text="Hello")],
2282 attachments=[],
2283 inference_options=InferenceOptions(),
2284 )
2285 )
2286 ).model_dump_json(),
2287 DEFAULT_CONTEXT,
2288 )
2289 assert isinstance(result, StreamingResult)
2290 events = await decode_streaming_result(result)
2291
2292 thread_created = next(e for e in events if e.type == "thread.created")
2293 assert thread_created.thread.id == "thr_custom_1"
2294
2295 user_message = next(
2296 e.item
2297 for e in events
2298 if e.type == "thread.item.done" and e.item.type == "user_message"
2299 )
2300 assert user_message.id == "message_custom_2_thr_custom_1"
2301