openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.6.5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

2624 lines · mode: preview

import asyncio
import base64
import json
import sqlite3
from contextlib import contextmanager
from datetime import datetime
from typing import Any, AsyncIterator, Callable, cast

import pytest
from helpers.mock_store import SQLiteStore
from pydantic import AnyUrl, TypeAdapter

from chatkit.actions import Action
from chatkit.errors import ErrorCode
from chatkit.server import (
    ChatKitServer,
    NonStreamingResult,
    StreamingResult,
    stream_widget,
)
from chatkit.store import (
    AttachmentStore,
    NotFoundError,
    StoreItemType,
    default_generate_id,
)
from chatkit.types import (
    AssistantMessageContent,
    AssistantMessageContentPartTextDelta,
    AssistantMessageItem,
    Attachment,
    AttachmentCreateParams,
    AttachmentDeleteParams,
    AttachmentsCreateReq,
    AttachmentsDeleteReq,
    AttachmentUploadDescriptor,
    AudioInput,
    ClientToolCallItem,
    CustomTask,
    FeedbackKind,
    FileAttachment,
    ImageAttachment,
    InferenceOptions,
    InputTranscribeParams,
    InputTranscribeReq,
    ItemFeedbackParams,
    ItemsFeedbackReq,
    ItemsListParams,
    ItemsListReq,
    LockedStatus,
    Page,
    ProgressUpdateEvent,
    StructuredInputAnswerSubmission,
    StructuredInputFreeform,
    StructuredInputItem,
    StructuredInputMultipleChoice,
    StructuredInputMultipleChoiceOption,
    StructuredInputSubmission,
    SyncCustomActionResponse,
    Thread,
    ThreadAddClientToolOutputParams,
    ThreadAddStructuredInputParams,
    ThreadAddUserMessageParams,
    ThreadCreatedEvent,
    ThreadCreateParams,
    ThreadCustomActionParams,
    ThreadDeleteParams,
    ThreadGetByIdParams,
    ThreadItem,
    ThreadItemAddedEvent,
    ThreadItemDoneEvent,
    ThreadItemRemovedEvent,
    ThreadItemReplacedEvent,
    ThreadItemUpdatedEvent,
    ThreadListParams,
    ThreadMetadata,
    ThreadRetryAfterItemParams,
    ThreadsAddClientToolOutputReq,
    ThreadsAddStructuredInputReq,
    ThreadsAddUserMessageReq,
    ThreadsCreateReq,
    ThreadsCustomActionReq,
    ThreadsDeleteReq,
    ThreadsGetByIdReq,
    ThreadsListReq,
    ThreadsRetryAfterItemReq,
    ThreadsSyncCustomActionReq,
    ThreadStreamEvent,
    ThreadsUpdateReq,
    ThreadUpdatedEvent,
    ThreadUpdateParams,
    ToolChoice,
    TranscriptionResult,
    UserMessageInput,
    UserMessageItem,
    UserMessageTextContent,
    WidgetItem,
    WidgetRootUpdated,
    Workflow,
    WorkflowItem,
    WorkflowTaskAdded,
)
from chatkit.widgets import Card, Text
from tests._types import RequestContext
from tests.test_store import make_thread_items

# Counter used by make_server to give every server its own shared-cache
# in-memory SQLite database name; make_server increments it on each call.
server_id = 0


# Default request context: passed explicitly by many tests, and substituted by
# the test server's process_streaming/process_non_streaming helpers when a test
# does not supply one.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")


class InMemoryFileStore(AttachmentStore):
    """Minimal in-memory AttachmentStore for tests.

    Attachments live in ``self.files`` keyed by id. ``image/*`` uploads become
    ``ImageAttachment`` objects with a preview URL; every other MIME type
    becomes a ``FileAttachment``. All URLs point at ``https://example.com``.
    """

    def __init__(self):
        self.files: dict[str, Attachment] = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id`` from the store.

        Raises:
            KeyError: if no attachment with that id is stored.
        """
        del self.files[attachment_id]

    def _next_attachment_id(self) -> str:
        """Return the first unused ``atc_<n>`` id, starting at len(files) + 1."""
        # Starting from len + 1 reproduces the historical ids (atc_1, atc_2, ...)
        # when nothing has been deleted; skipping taken ids stops a create that
        # follows a delete from overwriting a still-stored attachment.
        n = len(self.files) + 1
        while f"atc_{n}" in self.files:
            n += 1
        return f"atc_{n}"

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store, and return a new attachment described by ``input``.

        Every attachment gets a PUT upload descriptor and ``{"source": "test"}``
        metadata; image attachments additionally get a preview URL.
        """
        attachment_id = self._next_attachment_id()
        metadata = {"source": "test"}
        upload_descriptor = AttachmentUploadDescriptor(
            url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            method="PUT",
            headers={"X-My-Header": "my-value"},
        )
        attachment: Attachment
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_descriptor=upload_descriptor,
                metadata=metadata,
            )
        self.files[attachment.id] = attachment
        return attachment


# Built once at import: decode_event runs for every streamed event in every
# test, so re-creating the adapter on each call is wasted work.
_THREAD_STREAM_EVENT_ADAPTER = TypeAdapter(ThreadStreamEvent)


def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one raw chunk from ``StreamingResult.json_events``.

    The JSON payload is everything after the first ``data: `` marker.

    Raises:
        IndexError: if ``event`` contains no ``data: `` marker.
        pydantic.ValidationError: if the payload is not a valid ThreadStreamEvent.
    """
    # maxsplit=1: the JSON payload may itself contain the bytes "data: "
    # (e.g. inside message text); splitting on every occurrence would
    # truncate it.
    payload = event.split(b"data: ", 1)[1]
    return _THREAD_STREAM_EVENT_ADAPTER.validate_json(payload)


async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result`` and decode each chunk with ``decode_event``.

    Events are returned in the order the stream produced them.
    """
    decoded: list[ThreadStreamEvent] = []
    async for chunk in streaming_result.json_events:
        decoded.append(decode_event(chunk))
    return decoded


@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    sync_action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        SyncCustomActionResponse,
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
    transcribe_callback: Callable[[AudioInput, Any], TranscriptionResult] | None = None,
):
    """Yield a ChatKitServer test double backed by a fresh SQLiteStore.

    Each call uses a differently named shared-cache in-memory SQLite database
    (derived from the module-level ``server_id`` counter). ``file_store`` is
    passed to ``ChatKitServer`` alongside that store.

    Each optional callback backs one server hook:

    * ``responder`` -> ``respond``; when unset, a responder that yields nothing.
    * ``handle_feedback`` -> ``add_feedback``; when unset, a no-op.
    * ``action_callback`` -> ``action``; when unset, ``ValueError`` is raised.
    * ``sync_action_callback`` -> ``sync_action``; when unset, ``ValueError``.
    * ``transcribe_callback`` -> ``transcribe``; when unset, the base class.

    The yielded server also has ``process_streaming`` / ``process_non_streaming``
    helpers that serialize a request model, run it through ``process`` and
    assert the kind of result produced.
    """
    global server_id
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last connection is closed
    db = sqlite3.connect(db_path, uri=True)

    def resolve_context(context: Any | None) -> Any:
        """Return ``context``, or ``DEFAULT_CONTEXT`` when none was passed."""
        # Explicit ``is None``: ``context or DEFAULT_CONTEXT`` would also
        # replace a caller-supplied context object that happens to be falsy.
        return DEFAULT_CONTEXT if context is None else context

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is None:
                raise ValueError("action_callback not wired up")
            return action_callback(thread, action, sender, context)

        async def sync_action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> SyncCustomActionResponse:
            if sync_action_callback is None:
                raise ValueError("sync_action_callback not wired up")
            return sync_action_callback(thread, action, sender, context)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is None:
                return self._empty_responder()
            return responder(thread, input_user_message, context)

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any events.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is None:
                return
            handle_feedback(thread_id, item_ids, feedback, context)

        async def transcribe(
            self, audio_input: AudioInput, context: Any
        ) -> TranscriptionResult:
            if transcribe_callback is None:
                return await super().transcribe(audio_input, context)
            return transcribe_callback(audio_input, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process ``request_obj`` and return its decoded event stream.

            Fails with AssertionError if the result is not a StreamingResult.
            """
            result = await self.process(
                request_obj.model_dump_json(), resolve_context(context)
            )
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process ``request_obj`` and return the NonStreamingResult.

            Fails with AssertionError if the result is not a NonStreamingResult.
            """
            result = await self.process(
                request_obj.model_dump_json(), resolve_context(context)
            )
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        db.close()


async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """Cancelling mid-stream keeps the partial assistant message and notes why.

    The responder streams an assistant message plus one text delta, then raises
    CancelledError. Afterwards the store must hold the accumulated text, and
    the newest thread item must be the ``sdk_hidden_context`` cancellation note.
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        partial_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=partial_message)
        text_delta = AssistantMessageContentPartTextDelta(
            content_index=0,
            delta="World!",
        )
        yield ThreadItemUpdatedEvent(item_id=pending_id, update=text_delta)
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store: that item type
        # goes to the default generator, everything else to the store's own.
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread_id = next(e.thread.id for e in events if e.type == "thread.created")
        newest = await server.store.load_thread_items(
            thread_id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        persisted = await server.store.load_item(
            thread_id, pending_id, DEFAULT_CONTEXT
        )
        assert persisted.type == "assistant_message"
        assert persisted.content[0].text == "Hello, World!"


async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """Cancelling before any text arrives does not store an empty message.

    The responder opens an assistant message with no content and then raises
    CancelledError. The cancellation note is still recorded as the newest
    ``sdk_hidden_context`` item, but the empty assistant message is not found.
    """
    pending_id = "assistant-message-pending"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        empty_message = AssistantMessageItem(
            id=pending_id,
            created_at=datetime.now(),
            content=[],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=empty_message)
        raise asyncio.CancelledError()

    with make_server(responder) as server:
        # Allow hidden_context id generation in this test store: that item type
        # goes to the default generator, everything else to the store's own.
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        create_request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(
            create_request.model_dump_json(), DEFAULT_CONTEXT
        )
        assert isinstance(stream, StreamingResult)

        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for raw in stream.json_events:
                events.append(decode_event(raw))

        thread_id = next(e.thread.id for e in events if e.type == "thread.created")
        newest = await server.store.load_thread_items(
            thread_id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert (
            hidden_context_item.content
            == "The user cancelled the stream. Stop responding to the prior request."
        )

        with pytest.raises(NotFoundError):
            await server.store.load_item(thread_id, pending_id, DEFAULT_CONTEXT)


async def test_workflow_task_not_duplicated_on_done_event():
    """
    Regression test: applying pending-item updates must not modify the
    original item object the responder streamed, so the single task appears
    exactly once both in the done event and in the stored item.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        item = WorkflowItem(
            id="workflow-item",
            created_at=datetime.now(),
            thread_id=thread.id,
            workflow=Workflow(type="custom", tasks=[]),
        )
        yield ThreadItemAddedEvent(item=item)

        new_task = CustomTask(title="foo", content="bar")
        yield ThreadItemUpdatedEvent(
            item_id=item.id,
            update=WorkflowTaskAdded(task=new_task, task_index=0),
        )

        # Mirror the streamed update on the responder's own object before
        # finishing it; the server must not have appended the task here too.
        item.workflow.tasks.append(new_task)
        yield ThreadItemDoneEvent(item=item)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        thread_id = next(e.thread.id for e in events if e.type == "thread.created")
        done_event = next(
            e
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WorkflowItem)
        )
        done_item = cast(WorkflowItem, done_event.item)
        assert len(done_item.workflow.tasks) == 1
        assert done_item.workflow.tasks[0].title == "foo"

        stored = await server.store.load_item(thread_id, done_item.id, DEFAULT_CONTEXT)
        # isinstance narrows ``stored`` for type checkers; no cast needed.
        assert isinstance(stored, WorkflowItem)
        assert len(stored.workflow.tasks) == 1
        assert stored.workflow.tasks[0].title == "foo"


async def test_flows_context_to_responder():
    """The caller's request context reaches both ``respond`` and ``add_feedback``."""
    seen_contexts: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen_contexts["respond"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def add_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen_contexts["add_feedback"] = context

    custom_context = RequestContext(user_id="text_ctx")
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    with make_server(responder, add_feedback) as server:
        events = await server.process_streaming(create_request, custom_context)
        assert seen_contexts.get("respond") == custom_context

        thread = next(e.thread for e in events if e.type == "thread.created")
        assistant_item = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "assistant_message"
        )
        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, custom_context)
        assert seen_contexts.get("add_feedback") == custom_context


async def test_creates_thread():
    """threads.create streams an untitled thread and the user message item."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Never yields; the unreachable ``yield`` only makes this an async generator.
        return
        yield

    user_input = UserMessageInput(
        content=[UserMessageTextContent(text="Hello, world!")],
        attachments=[],
        inference_options=InferenceOptions(),
        quoted_text="Quoted text",
    )
    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(e for e in events if e.type == "thread.created")
        assert created.thread.title is None

        user_item = next(e.item for e in events if e.type == "thread.item.done")
        assert user_item.type == "user_message"
        assert user_item.content[0].text == "Hello, world!"
        assert user_item.quoted_text == "Quoted text"


async def test_saves_thread_title():
    """A title set by the responder is persisted and streamed as thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        # Use the shared DEFAULT_CONTEXT (same user as the request) rather than
        # rebuilding an identical RequestContext inline.
        saved = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Updated title"
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.title == "Updated title"


async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted; the stream ends in thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread_id = next(e.thread.id for e in events if e.type == "thread.created")

        saved = await server.store.load_thread(thread_id, DEFAULT_CONTEXT)
        assert saved.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"


async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and streamed."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        # Use the shared DEFAULT_CONTEXT (same user as the request) rather than
        # rebuilding an identical RequestContext inline.
        loaded = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded.status == LockedStatus(reason="Because")
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.status == LockedStatus(reason="Because")


async def test_saves_allowed_image_domains_and_streams_thread_updated():
    """Image domains set by the responder are persisted and sent in thread.updated."""
    expected_domains = ["example.com", "images.example.com"]

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Hand the thread its own copy so the expectation list stays untouched.
        thread.allowed_image_domains = list(expected_domains)
        return
        yield

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread_id = next(e.thread.id for e in events if e.type == "thread.created")

        loaded = await server.store.load_thread(thread_id, DEFAULT_CONTEXT)
        assert loaded.allowed_image_domains == expected_domains

        last_event = events[-1]
        assert last_event.type == "thread.updated"
        assert last_event.thread.allowed_image_domains == expected_domains


async def test_omits_unset_allowed_image_domains_in_created_and_updated_json_events():
    """``allowed_image_domains`` is absent from the raw JSON when never set.

    Inspects the raw JSON dicts rather than decoded models, because a decoded
    model cannot distinguish an absent key from one present with a null value.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        stream = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(stream, StreamingResult)

        thread_created_event: dict[str, Any] | None = None
        thread_updated_event: dict[str, Any] | None = None
        async for raw in stream.json_events:
            # maxsplit=1: only the first "data: " is the chunk prefix; the JSON
            # payload could contain the same bytes inside a string value.
            event = json.loads(raw.split(b"data: ", 1)[1])
            if event["type"] == "thread.created":
                thread_created_event = event
            elif event["type"] == "thread.updated":
                thread_updated_event = event

        assert thread_created_event is not None
        assert "allowed_image_domains" not in thread_created_event["thread"]

        assert thread_updated_event is not None
        assert thread_updated_event["thread"]["title"] == "Updated title"
        assert "allowed_image_domains" not in thread_updated_event["thread"]


async def test_emits_thread_updated_mid_stream_and_persists():
    """A thread change made mid-stream is emitted before the stream ends.

    The responder retitles the thread between two assistant messages. The test
    expects a ``thread.updated`` event right after the next item-done event,
    further items after it, and the new title saved in the store.
    """

    def assistant_done(
        thread: ThreadMetadata, item_id: str, text: str
    ) -> ThreadItemDoneEvent:
        return ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield assistant_done(thread, "assistant_a", "A")
        # Update the thread mid-stream ...
        thread.title = "Mid-stream title"
        # ... thread.updated should be emitted after this next message ...
        yield assistant_done(thread, "assistant_b", "B")
        # ... and streaming continues after the update event.
        yield assistant_done(thread, "assistant_c", "C")

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        created_thread = next(
            e.thread for e in events if isinstance(e, ThreadCreatedEvent)
        )

        def is_done_message(event: ThreadStreamEvent, item_id: str) -> bool:
            return (
                isinstance(event, ThreadItemDoneEvent)
                and isinstance(event.item, AssistantMessageItem)
                and event.item.id == item_id
            )

        b_index = next(
            i for i, e in enumerate(events) if is_done_message(e, "assistant_b")
        )

        after_b = events[b_index + 1]
        assert isinstance(after_b, ThreadUpdatedEvent)
        assert cast(ThreadUpdatedEvent, after_b).thread.title == "Mid-stream title"

        following = events[b_index + 2]
        assert isinstance(following, ThreadItemDoneEvent)
        following_item = cast(ThreadItemDoneEvent, following).item
        assert isinstance(following_item, AssistantMessageItem)
        assert following_item.id == "assistant_c"

        saved = await server.store.load_thread(created_thread.id, DEFAULT_CONTEXT)
        assert saved.title == "Mid-stream title"


async def test_respond_with_tool_call():
    """A client tool call ends the first turn; its output starts a second turn.

    The responder emits a client tool call when given a user message, and an
    assistant message when called with no input message.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(input, UserMessageItem):
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    with make_server(responder) as server:
        first_turn = await server.process_streaming(create_request)

        assert [e.type for e in first_turn] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "thread.item.done",
        ]
        thread = first_turn[0].thread
        assert first_turn[1].item.type == "user_message"

        tool_call = first_turn[3].item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        second_turn = await server.process_streaming(
            ThreadsAddClientToolOutputReq(
                params=ThreadAddClientToolOutputParams(
                    thread_id=thread.id,
                    result={"text": "Wow!"},
                )
            )
        )

        assert [e.type for e in second_turn] == ["stream_options", "thread.item.done"]
        assert second_turn[1].item.type == "assistant_message"


async def test_respond_with_tool_status():
    """A progress update yielded by the responder is streamed before its message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        reply = AssistantMessageItem(
            id="msg_4",
            content=[AssistantMessageContent(text="Hello, world!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )
    with make_server(responder) as server:
        events = await server.process_streaming(create_request)

        assert [e.type for e in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]


async def test_list_threads_response():
    """Threads are listed newest-first in pages linked by an ``after`` cursor.

    Creates 25 threads, reads a first page capped at 20, then follows the
    cursor from that page's last entry to fetch the remaining 5.
    """
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for index in range(25):
            message = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {index}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=message))
            )

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20

        # Newest first
        newest, runner_up = first_page.data[0], first_page.data[1]
        assert newest.created_at > runner_up.created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        cursor = first_page.data[-1].id
        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=cursor))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after


async def test_list_items_response():
    """Thread items are listed newest-first and paginate via ``after``.

    The responder emits 25 user-message items on top of the user message that
    created the thread, so after a first page of 20 the second page holds 6.
    """

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        for idx in range(25):
            generated = UserMessageItem(
                id=f"msg_{idx}",
                content=[UserMessageTextContent(text=f"Message {idx}")],
                attachments=[],
                inference_options=InferenceOptions(),
                thread_id=thread.id,
                created_at=datetime.now(),
            )
            yield ThreadItemDoneEvent(item=generated)

    page_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Thread 1")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )

        async def fetch_page(params: ItemsListParams) -> Page[ThreadItem]:
            response = await server.process_non_streaming(ItemsListReq(params=params))
            return page_adapter.validate_json(response.json)

        head = await fetch_page(ItemsListParams(thread_id=thread.id))
        assert len(head.data) == 20
        # Newest first
        assert head.data[0].created_at > head.data[1].created_at
        assert head.has_more is True
        assert head.after is not None

        tail = await fetch_page(
            ItemsListParams(thread_id=thread.id, after=head.data[-1].id)
        )
        assert len(tail.data) == 6  # 5 remaining generated items + the user message
        assert tail.has_more is False


async def test_calls_action():
    """A streaming custom action receives its sender widget and can update it."""
    recorded_actions: list[tuple[Action[str, Any], WidgetItem | None]] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        initial_card = Card(children=[Text(key="text_1", value="test")])
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=initial_card,
            ),
        )

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        recorded_actions.append((action, sender))
        assert sender

        replacement_card = Card(children=[Text(value="Email sent!")])
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement_card),
        )

    with make_server(responder, action_callback=action) as server:
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )
        widget_item = next(
            ev.item
            for ev in creation_events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        action_events = await server.process_streaming(
            ThreadsCustomActionReq(
                params=ThreadCustomActionParams(
                    thread_id=thread.id,
                    item_id=widget_item.id,
                    action=Action(type="create_user", payload={"user_id": "123"}),
                )
            )
        )

        expected_action = Action(type="create_user", payload={"user_id": "123"})
        assert recorded_actions
        assert recorded_actions[0] == (expected_action, widget_item)

        assert len(action_events) == 2
        options_event, update_event = action_events
        assert options_event.type == "stream_options"
        assert update_event.type == "thread.item.updated"
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(children=[Text(value="Email sent!")])


async def test_add_structured_input_replaces_item_with_custom_answer_and_resumes_response():
    """Answering a structured input replaces and persists the item, then resumes.

    The "subject" answer is free text outside the offered options. The
    responder run without a user message must emit the assistant follow-up.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # No new user message: this is the run that follows the answer.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_after_structured_input",
                    content=[AssistantMessageContent(text="Thanks, continuing now.")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                )
            )
            return

        subject_question = StructuredInputMultipleChoice(
            id="subject",
            question="Which subject is this lesson for?",
            options=[
                StructuredInputMultipleChoiceOption(value="Math"),
                StructuredInputMultipleChoiceOption(value="Science"),
            ],
        )
        details_question = StructuredInputFreeform(
            id="details",
            question="Anything else to know?",
        )
        yield ThreadItemDoneEvent(
            item=StructuredInputItem(
                id="si_1",
                thread_id=thread.id,
                created_at=datetime.now(),
                inputs=[subject_question, details_question],
            )
        )

    with make_server(responder) as server:
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Plan a lesson")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in creation_events if isinstance(ev, ThreadCreatedEvent)
        )

        submission = StructuredInputSubmission(
            status="answered",
            answers={
                "subject": StructuredInputAnswerSubmission(values=["Private class"]),
                "details": StructuredInputAnswerSubmission(
                    values=["Please include group work."]
                ),
            },
        )
        answer_events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=submission,
                )
            )
        )

        replaced_event = next(
            ev for ev in answer_events if isinstance(ev, ThreadItemReplacedEvent)
        )
        answered_item = replaced_event.item
        assert isinstance(answered_item, StructuredInputItem)
        assert answered_item.status == "answered"

        subject_answer = answered_item.inputs[0].answer
        assert subject_answer is not None
        assert subject_answer.values == ["Private class"]
        details_answer = answered_item.inputs[1].answer
        assert details_answer is not None
        assert details_answer.values == ["Please include group work."]

        persisted_item = await server.store.load_item(
            thread.id, "si_1", DEFAULT_CONTEXT
        )
        assert persisted_item == answered_item

        resumed_ids = [
            ev.item.id
            for ev in answer_events
            if isinstance(ev, ThreadItemDoneEvent)
            and isinstance(ev.item, AssistantMessageItem)
        ]
        assert "msg_after_structured_input" in resumed_ids


async def test_add_structured_input_treats_omitted_answers_as_skipped():
    """Questions missing from the submitted answers are recorded as skipped."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # Nothing to emit when run without a new user message.
            return
        yield ThreadItemDoneEvent(
            item=StructuredInputItem(
                id="si_1",
                thread_id=thread.id,
                created_at=datetime.now(),
                inputs=[
                    StructuredInputMultipleChoice(
                        id="subject",
                        question="Which subject is this lesson for?",
                        options=[
                            StructuredInputMultipleChoiceOption(value="Math"),
                            StructuredInputMultipleChoiceOption(value="Science"),
                        ],
                    ),
                    StructuredInputFreeform(
                        id="details",
                        question="Anything else to know?",
                    ),
                ],
            )
        )

    with make_server(responder) as server:
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Plan a lesson")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in creation_events if isinstance(ev, ThreadCreatedEvent)
        )

        # Only "subject" is answered; "details" is deliberately left out.
        partial_submission = StructuredInputSubmission(
            status="answered",
            answers={"subject": StructuredInputAnswerSubmission(values=["Math"])},
        )
        answer_events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=partial_submission,
                )
            )
        )

        replaced = next(
            ev for ev in answer_events if isinstance(ev, ThreadItemReplacedEvent)
        )
        assert isinstance(replaced.item, StructuredInputItem)
        subject_input = replaced.item.inputs[0]
        assert subject_input.answer is not None
        assert subject_input.answer.values == ["Math"]
        details_input = replaced.item.inputs[1]
        assert details_input.answer is not None
        assert details_input.answer.skipped


async def test_add_structured_input_ignores_unknown_answer_ids():
    """Answers keyed by ids that match no question are dropped without error."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # Nothing to emit when run without a new user message.
            return
        only_question = StructuredInputFreeform(
            id="details",
            question="Anything else to know?",
        )
        yield ThreadItemDoneEvent(
            item=StructuredInputItem(
                id="si_1",
                thread_id=thread.id,
                created_at=datetime.now(),
                inputs=[only_question],
            )
        )

    with make_server(responder) as server:
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Plan a lesson")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in creation_events if isinstance(ev, ThreadCreatedEvent)
        )

        answers = {
            "details": StructuredInputAnswerSubmission(
                values=["Please include group work."]
            ),
            # No question has this id; the server should ignore it.
            "unknown": StructuredInputAnswerSubmission(values=["ignored"]),
        }
        answer_events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=StructuredInputSubmission(status="answered", answers=answers),
                )
            )
        )

        replaced = next(
            ev for ev in answer_events if isinstance(ev, ThreadItemReplacedEvent)
        )
        assert isinstance(replaced.item, StructuredInputItem)
        details_answer = replaced.item.inputs[0].answer
        assert details_answer is not None
        assert details_answer.values == ["Please include group work."]


async def test_add_structured_input_truncates_extra_single_choice_values():
    """A single-choice question keeps only the first of several submitted values.

    The truncated answer must appear both in the replaced item streamed back
    to the client and in the item persisted by the store.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Only ask the question in reply to a user message. After the answer is
        # submitted, the server resumes the response by calling the responder
        # with ``input=None``. Re-emitting "si_1" on that run would push a
        # fresh, unanswered copy under the same id. The persistence assertions
        # below would then depend on how the store handles duplicate ids.
        if input is not None:
            yield ThreadItemDoneEvent(
                item=StructuredInputItem(
                    id="si_1",
                    thread_id=thread.id,
                    created_at=datetime.now(),
                    inputs=[
                        StructuredInputMultipleChoice(
                            id="subject",
                            question="Which subject is this lesson for?",
                            options=[
                                StructuredInputMultipleChoiceOption(value="Math"),
                                StructuredInputMultipleChoiceOption(value="Science"),
                            ],
                        ),
                    ],
                )
            )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Plan a lesson")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if isinstance(event, ThreadCreatedEvent)
        )

        events = await server.process_streaming(
            ThreadsAddStructuredInputReq(
                params=ThreadAddStructuredInputParams(
                    thread_id=thread.id,
                    item_id="si_1",
                    input=StructuredInputSubmission(
                        status="answered",
                        answers={
                            # Two values for a single-choice question; only the
                            # first should be kept.
                            "subject": StructuredInputAnswerSubmission(
                                values=["Math", "Science"],
                            ),
                        },
                    ),
                )
            )
        )

        replaced_event = next(
            event for event in events if isinstance(event, ThreadItemReplacedEvent)
        )
        assert isinstance(replaced_event.item, StructuredInputItem)
        answer = replaced_event.item.inputs[0].answer
        assert answer is not None
        assert answer.values == ["Math"]

        persisted_item = await server.store.load_item(
            thread.id,
            "si_1",
            DEFAULT_CONTEXT,
        )
        assert isinstance(persisted_item, StructuredInputItem)
        assert persisted_item.status == "answered"
        assert persisted_item.inputs[0].answer == answer


async def test_calls_action_sync():
    """A synchronous custom action receives its sender widget and returns it updated."""
    sync_actions: list[tuple[Action[str, Any], WidgetItem | None]] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        starting_card = Card(children=[Text(key="text_1", value="test")])
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=starting_card,
            ),
        )

    def sync_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> SyncCustomActionResponse:
        sync_actions.append((action, sender))
        assert sender

        confirmation_card = Card(children=[Text(value="Email sent!")])
        updated = sender.model_copy(update={"widget": confirmation_card})
        return SyncCustomActionResponse(updated_item=updated)

    with make_server(responder, sync_action_callback=sync_action) as server:
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )
        widget_item = next(
            ev.item
            for ev in creation_events
            if isinstance(ev, ThreadItemDoneEvent) and isinstance(ev.item, WidgetItem)
        )

        raw_result = await server.process_non_streaming(
            ThreadsSyncCustomActionReq(
                params=ThreadCustomActionParams(
                    thread_id=thread.id,
                    item_id=widget_item.id,
                    action=Action(type="create_user", payload={"user_id": "123"}),
                )
            )
        )
        response = TypeAdapter(SyncCustomActionResponse).validate_json(raw_result.json)

        expected_action = Action(type="create_user", payload={"user_id": "123"})
        assert sync_actions
        assert sync_actions[0] == (expected_action, widget_item)

        expected_item = widget_item.model_copy(
            update={"widget": Card(children=[Text(value="Email sent!")])}
        )
        assert response.updated_item == expected_item


async def test_add_feedback():
    """Each feedback request is forwarded once to the feedback handler."""
    called = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        called.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        assistant_reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Feedback received!")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=assistant_reply)

        widget = WidgetItem(
            id="widget_1",
            type="widget",
            created_at=datetime.now(),
            widget=Card(children=[Text(key="text", value="Widget content")]),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=widget)

    with make_server(responder, handle_feedback=handle_feedback) as server:
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Test feedback")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        thread = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )

        feedback_plan: list[tuple[str, FeedbackKind]] = [
            ("assistant_msg_1", "positive"),
            ("widget_1", "negative"),
        ]
        for item_id, kind in feedback_plan:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=thread.id,
                        item_ids=[item_id],
                        kind=kind,
                    )
                )
            )

        # Compare (thread_id, item_ids, feedback); the context is not checked.
        assert [entry[:3] for entry in called] == [
            (thread.id, ["assistant_msg_1"], "positive"),
            (thread.id, ["widget_1"], "negative"),
        ]


async def test_get_thread_by_id():
    """A thread fetched by id carries its metadata and its first user message."""
    # NOTE: a former ``make_thread_items()`` call here discarded its result and
    # fed nothing into the test, so it was removed as dead code.
    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test thread")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # Retrieve the thread by id
        result = await server.process_non_streaming(
            ThreadsGetByIdReq(
                params=ThreadGetByIdParams(thread_id=thread.id),
            )
        )
        loaded_thread = TypeAdapter(Thread).validate_json(result.json)
        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title
        assert loaded_thread.items.data[0].type == "user_message"
        assert loaded_thread.items.data[0].content[0].text == "Test thread"


async def test_create_file():
    """Creating a plain-file attachment returns an upload descriptor and stores it.

    The store keeps the server-side ``metadata`` even though it is left out of
    the serialized response.
    """
    file_store = InMemoryFileStore()
    name = "test-file-name"
    mime_type = "text/plain"

    with make_server(file_store=file_store) as server:
        create_params = AttachmentCreateParams(name=name, size=1024, mime_type=mime_type)
        response = await server.process_non_streaming(
            AttachmentsCreateReq(params=create_params)
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)
        expected_upload_url = AnyUrl(f"https://example.com/{attachment.id}/upload")

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == mime_type
        assert attachment.name == name
        assert attachment.type == "file"
        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == expected_upload_url
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}
        assert attachment.id in file_store.files

        # `metadata` is excluded from serialization but still stored server-side.
        assert attachment.metadata is None
        assert file_store.files[attachment.id].metadata == {"source": "test"}


async def test_create_image_file():
    """Creating an image attachment returns a preview URL plus an upload descriptor."""
    file_store = InMemoryFileStore()
    name = "test-file-name"
    mime_type = "image/png"

    with make_server(file_store=file_store) as server:
        response = await server.process_non_streaming(
            AttachmentsCreateReq(
                params=AttachmentCreateParams(name=name, size=1024, mime_type=mime_type)
            )
        )
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)
        base_url = f"https://example.com/{attachment.id}"

        assert attachment.id is not None
        assert attachment.mime_type == mime_type
        assert attachment.name == name
        assert attachment.type == "image"
        assert attachment.preview_url == AnyUrl(f"{base_url}/preview")

        descriptor = attachment.upload_descriptor
        assert descriptor is not None
        assert descriptor.url == AnyUrl(f"{base_url}/upload")
        assert descriptor.method == "PUT"
        assert descriptor.headers == {"X-My-Header": "my-value"}

        assert attachment.id in file_store.files


async def test_user_message_attachment_metadata_not_serialized():
    """Attachment ``metadata`` saved in the store is not exposed when listing items."""
    attachment = FileAttachment(
        id="file_with_metadata",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        metadata={"source": "test"},
    )

    with make_server() as server:
        await server.store.save_attachment(attachment, DEFAULT_CONTEXT)

        message = UserMessageInput(
            content=[UserMessageTextContent(text="Message with file")],
            attachments=[attachment.id],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message))
        )
        thread = next(
            ev.thread for ev in creation_events if ev.type == "thread.created"
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        user_messages = [it for it in items_page.data if it.type == "user_message"]
        assert user_messages
        listed_attachments = user_messages[0].attachments
        assert listed_attachments
        assert not any(att.metadata is not None for att in listed_attachments)


async def test_user_message_attachment_thread_id_persisted_on_thread_create():
    """Creating a thread with an attachment records the new thread id on it."""
    orphan_attachment = FileAttachment(
        id="file_thread_id_on_create",
        type="file",
        mime_type="text/plain",
        name="test.txt",
        thread_id=None,
    )

    with make_server() as server:
        # Saved before any thread exists, so it starts without a thread_id.
        await server.store.save_attachment(orphan_attachment, DEFAULT_CONTEXT)

        first_message = UserMessageInput(
            content=[UserMessageTextContent(text="Message with file")],
            attachments=[orphan_attachment.id],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=first_message))
        )
        created_thread = next(
            e.thread for e in creation_events if e.type == "thread.created"
        )

        # Saving the user message should have linked the attachment to the thread.
        reloaded = await server.store.load_attachment(
            orphan_attachment.id, DEFAULT_CONTEXT
        )
        assert reloaded.thread_id == created_thread.id


async def test_user_message_attachment_thread_id_persisted_on_add_user_message():
    """Adding a message with an attachment to an existing thread links the two."""
    with make_server() as server:
        # Start with a thread that has no attachments.
        opening_message = UserMessageInput(
            content=[UserMessageTextContent(text="Start")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        creation_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=opening_message))
        )
        existing_thread = next(
            e.thread for e in creation_events if e.type == "thread.created"
        )

        pending_attachment = FileAttachment(
            id="file_thread_id_on_add",
            type="file",
            mime_type="text/plain",
            name="test.txt",
            thread_id=None,
        )
        await server.store.save_attachment(pending_attachment, DEFAULT_CONTEXT)

        # Reference the pre-created attachment from a follow-up message.
        follow_up = UserMessageInput(
            content=[UserMessageTextContent(text="Second")],
            attachments=[pending_attachment.id],
            inference_options=InferenceOptions(),
        )
        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=existing_thread.id,
                    input=follow_up,
                )
            )
        )

        reloaded = await server.store.load_attachment(
            pending_attachment.id, DEFAULT_CONTEXT
        )
        assert reloaded.thread_id == existing_thread.id


async def test_create_file_without_filestore():
    """Creating an attachment on a server with no file store raises RuntimeError.

    ``pytest.raises`` wraps only the request. Otherwise a ``RuntimeError`` raised
    while building the server would make this test pass spuriously.
    """
    with make_server() as server:
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsCreateReq(
                    params=AttachmentCreateParams(
                        name="test-file-name",
                        size=1024,
                        mime_type="text/plain",
                    )
                )
            )


async def test_delete_file():
    """Deleting an attachment removes its entry from the file store."""
    file_id = "test-file-id"
    file_store = InMemoryFileStore()
    file_store.files[file_id] = {"id": file_id}

    with make_server(file_store=file_store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=file_id)
        )
        await server.process_non_streaming(delete_request)
        assert file_store.files == {}


async def test_delete_file_without_filestore():
    """Deleting an attachment on a server built without a file store fails.

    ``make_server`` is called with no ``file_store`` argument; processing an
    ``attachments.delete`` request is expected to raise ``RuntimeError``.
    """
    delete_request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id="test-file-id")
    )
    with pytest.raises(RuntimeError), make_server() as server:
        await server.process_non_streaming(delete_request)


async def test_list_threads_paginated_response():
    """Thread listing paginates via the ``after`` cursor returned by the server.

    Creates 25 threads, fetches a first page with ``limit=20``, then follows
    the cursor from that page (as a client would) to fetch the remaining 5.
    The two pages must not share any thread.
    """
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for i in range(25):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {i}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        list_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(list_result.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Use the cursor the server handed back instead of deriving one from the
        # page contents, so the cursor itself is what gets exercised.
        list_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.after))
        )
        second_page = page_adapter.validate_json(list_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False

        # Pages must not overlap.
        first_ids = {thread.id for thread in first_page.data}
        assert first_ids.isdisjoint(thread.id for thread in second_page.data)


async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        created_events = await server.process_streaming(create_request)
        thread = next(
            event.thread
            for event in created_events
            if event.type == "thread.created"
        )

        update_result = await server.process_non_streaming(
            ThreadsUpdateReq(
                params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
            )
        )
        # Keep the raw result and the parsed thread under separate names.
        updated = TypeAdapter(Thread).validate_json(update_result.json)
        assert updated.title == "New Title"


async def test_returns_widget_item_generator():
    """``stream_widget`` turns a growing ``Text`` value into text deltas.

    Each yielded card extends the same ``Text`` component's value. The
    expected stream is:

    - an added event carrying the first card;
    - one ``widget.streaming_text.value_delta`` update per extension;
    - a done event carrying the final card.
    """
    full_text = "Hello, world"
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card_with_prefix(length: int) -> Card:
        return Card(children=[Text(id="text", value=full_text[:length])])

    async def growing_cards():
        for length in (0, 3, 6, 12):
            yield card_with_prefix(length)

    events = [event async for event in stream_widget(thread, growing_cards())]

    assert len(events) == 5
    first, *updates, last = events

    assert isinstance(first, ThreadItemAddedEvent)
    assert isinstance(first.item, WidgetItem)
    assert first.item.widget == Card(children=[Text(id="text", value="")])

    # Each update carries only the newly appended slice of text.
    for update_event, expected_delta in zip(updates, ["Hel", "lo,", " world"]):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    assert isinstance(last, ThreadItemDoneEvent)
    assert isinstance(last.item, WidgetItem)
    assert last.item.widget == Card(children=[Text(id="text", value=full_text)])


async def test_returns_widget_item_generator_full_replace():
    """A non-prefix text change makes ``stream_widget`` replace the whole root.

    "World" does not extend "Hello", so the update is expected to be a
    ``widget.root.updated`` event carrying the new card rather than a text
    delta.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def text_card(value: str) -> Card:
        # Fresh instance each call so comparisons never share objects.
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        for value in ("Hello", "World"):
            yield text_card(value)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == text_card("Hello")

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == text_card("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == text_card("World")


async def test_delete_thread():
    """Deleting a thread returns ``{}`` and makes the thread unfetchable.

    After deletion, ``threads.get_by_id`` must raise ``NotFoundError`` whose
    message names the thread.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # pytest.raises fails with an accurate message if nothing is raised,
        # unlike the former hand-written "Expected ValueError" assertion.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)


async def test_thread_item_removed_event_removes_item():
    """A ``thread.item.removed`` event deletes the item from the store.

    The responder completes a user message with a known id and then removes
    it. The removal must appear in the stream, and the item must be absent
    from a subsequent item listing.
    """
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger removal")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )
        removal_seen = any(
            event.type == "thread.item.removed" and event.item_id == removed_id
            for event in events
        )
        assert removal_seen

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_ids = {
            item.id
            for item in TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        }
        assert removed_id not in stored_ids


async def test_raising_in_responder_yields_error_event():
    """An exception from the responder surfaces as a stream error event.

    The streaming result must contain an ``error`` event with code
    ``ErrorCode.STREAM_ERROR`` and must then be exhausted.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; its presence makes this an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Named `stream` rather than `iter` to avoid shadowing the builtin.
        stream = result.__aiter__()

        # Consume events until the error event is emitted.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Nothing may follow the error event.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()


async def test_thread_item_replaced_event_replaces_item():
    """A ``thread.item.replaced`` event overwrites the stored item in place.

    The responder completes a user message and then replaces it, under the
    same id, with an assistant message. The stored item must end up being
    the assistant message with the replacement text.
    """
    original_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # Same id, different item type: this is a replacement, not an addition.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[AssistantMessageContent(text=replacement_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        assert any(
            event.type == "thread.item.replaced" and event.item.id == original_id
            for event in events
        )

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data

        replaced_item = next(
            (item for item in stored_items if item.id == original_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, AssistantMessageItem)
        first_part = replaced_item.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == replacement_text


async def test_thread_item_replaced_event_with_widget():
    """Replacing a widget item swaps in the new widget in the stream and store."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def widget_item(widget: Card, thread_id: str) -> WidgetItem:
        # Both items share `widget_id`, so the second one is a replacement.
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread_id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(item=widget_item(original_widget, thread.id))
        yield ThreadItemReplacedEvent(item=widget_item(replacement_widget, thread.id))

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger widget replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        # The stream must announce the replacement with the new widget.
        replacement_event = next(
            (
                event
                for event in events
                if event.type == "thread.item.replaced" and event.item.id == widget_id
            ),
            None,
        )
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The store must hold the replacement widget under the same id.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        replaced_item = next(
            (item for item in stored_items if item.id == widget_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, WidgetItem)
        stored_widget = replaced_item.widget
        assert isinstance(stored_widget, Card)
        first_child = stored_widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"


async def test_retry_after_item():
    """Retrying after a user message regenerates the assistant reply.

    The responder answers "First response" on its first call and
    "Retried response" on every later call. After the retry:

    - the responder has been called twice with the same user message;
    - the thread still holds exactly two items;
    - the assistant reply has been replaced by the retried one.

    Item listings are newest first.
    """
    responder_calls: list[Any] = []
    page_adapter = TypeAdapter(Page[ThreadItem])

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        is_first_call = len(responder_calls) == 1
        # The first call is the original turn; any later call is the retry.
        message_id = "assistant_msg_1" if is_first_call else "assistant_msg_2"
        reply_text = "First response" if is_first_call else "Retried response"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=message_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        create_events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in create_events if event.type == "thread.created"
        )

        async def list_items():
            listing = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread.id))
            )
            return page_adapter.validate_json(listing.json).data

        before = await list_items()
        assert len(before) == 2
        latest_reply, user_message = before
        assert latest_reply.type == "assistant_message"
        assert latest_reply.content[0].type == "output_text"
        assert latest_reply.content[0].text == "First response"
        assert user_message.type == "user_message"
        assert user_message.content[0].text == "Hello, world!"

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message.id
                )
            )
        )

        # The retry stream holds stream options plus the regenerated reply.
        assert len(retry_events) == 2
        options_event, done_event = retry_events
        assert options_event.type == "stream_options"
        assert done_event.type == "thread.item.done"
        assert done_event.item.type == "assistant_message"
        assert done_event.item.content[0].type == "output_text"
        assert done_event.item.content[0].text == "Retried response"

        # Both responder calls received the same stored user message.
        assert len(responder_calls) == 2
        first_input, retry_input = responder_calls
        assert first_input == retry_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        after = await list_items()
        assert len(after) == 2
        new_reply, kept_user_message = after
        assert new_reply.type == "assistant_message"
        assert new_reply.content[0].type == "output_text"
        assert new_reply.content[0].text == "Retried response"
        assert kept_user_message.type == "user_message"
        assert kept_user_message.content[0].text == "Hello, world!"


async def test_retry_after_item_invalid_item():
    """Retrying after an assistant message is rejected with ``ValueError``."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        created_events = await server.process_streaming(create_request)
        thread = next(
            event.thread for event in created_events if event.type == "thread.created"
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        # The first listed item is the assistant reply (listings are newest first).
        assistant_item_id = (
            TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data[0].id
        )

        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(
                ThreadsRetryAfterItemReq(
                    params=ThreadRetryAfterItemParams(
                        thread_id=thread.id, item_id=assistant_item_id
                    )
                )
            )


async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the reply after it.

    Builds the thread user1, assistant1, user2, assistant2 and retries after
    user2. Items up to and including user2 must be untouched (same ids and
    content); assistant2 must be replaced by a freshly generated reply.
    Item listings are newest first.
    """
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        # The id encodes the call number, so a regenerated reply gets a new id.
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Newest first: [assistant2, user2, assistant1, user1].
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Make the index assumption explicit before retrying after it.
        second_user_message = items_before.data[1]
        assert second_user_message.type == "user_message"

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message.id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # Original 2 calls + 1 retry; the retry used the second user message.
        assert len(responder_calls) == 3
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4

        # First user message and its reply are unchanged (same ids and content).
        assert items_after.data[3].id == items_before.data[3].id
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].id == items_before.data[2].id
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # Second user message is kept; its reply is newly generated.
        assert items_after.data[1].id == second_user_message.id
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id


async def test_threads_create_passes_tools_to_responder():
    """The tool choice given at thread creation reaches the responder.

    The responder records the tool-choice id it sees, and the test body
    asserts on that record. A generic ``thread.item.done`` check is not
    enough, because the server also emits one for the user message itself.
    """
    tool_choice = ToolChoice(id="web_search")
    assistant_id = "assistant_msg_tools_create"
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received = input.inference_options.tool_choice
        # Record instead of asserting here, so the check below runs in the
        # test body regardless of how the server reports responder errors.
        seen_tool_choice_ids.append(received.id if received is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=assistant_id,
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        # The responder's own item must have been streamed.
        assert any(
            e.type == "thread.item.done" and e.item.id == assistant_id for e in events
        )

    assert seen_tool_choice_ids == [tool_choice.id]


async def test_input_transcribe_decodes_base64_and_normalizes_mime_type():
    """``input.transcribe`` decodes base64 audio and derives a bare media type.

    The transcription callback must receive:

    - the raw decoded bytes;
    - the original MIME type, including codec parameters;
    - a ``media_type`` without parameters;
    - the server context (``DEFAULT_CONTEXT`` here).
    """
    audio_bytes = b"hello audio"
    seen: dict[str, Any] = {}

    def fake_transcribe(audio_input: AudioInput, context: Any) -> TranscriptionResult:
        seen.update(
            audio=audio_input.data,
            mime=audio_input.mime_type,
            media_type=audio_input.media_type,
            context=context,
        )
        return TranscriptionResult(text="ok")

    request = InputTranscribeReq(
        params=InputTranscribeParams(
            audio_base64=base64.b64encode(audio_bytes).decode("ascii"),
            mime_type="audio/webm;codecs=opus",
        )
    )

    with make_server(transcribe_callback=fake_transcribe) as server:
        result = await server.process_non_streaming(request)
        transcription = TypeAdapter(TranscriptionResult).validate_json(result.json)
        assert transcription.text == "ok"

    assert seen == {
        "audio": audio_bytes,
        "mime": "audio/webm;codecs=opus",
        "media_type": "audio/webm",
        "context": DEFAULT_CONTEXT,
    }


async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message passes its tool choice to the responder.

    The tool choice is supplied only when the thread is created. The retry
    request here carries just the thread and item ids, so whatever tool
    choice the responder sees on the retry must come from the stored user
    message.
    """
    tool_choice = ToolChoice(id="web_search")
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received = input.inference_options.tool_choice
        # Record instead of asserting here, so the check below runs in the
        # test body regardless of how the server reports responder errors.
        seen_tool_choice_ids.append(received.id if received is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_tools_retry_{len(seen_tool_choice_ids)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        user_message_id = next(i.id for i in items if i.type == "user_message")

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message_id
                )
            )
        )
        # The regenerated reply (second responder call) must have been streamed.
        assert any(
            e.type == "thread.item.done"
            and e.item.id == "assistant_msg_tools_retry_2"
            for e in retry_events
        )

    # Both the original turn and the retry saw the tool choice.
    assert seen_tool_choice_ids == [tool_choice.id, tool_choice.id]


async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice given when adding a message reaches the responder.

    The responder records every tool-choice id it sees, and the test body
    asserts on that record. A generic ``thread.item.done`` check is not
    enough, because the server also emits one for the user message itself.
    """
    tool_choice = ToolChoice(id="retrieval")
    assistant_id = "assistant_msg_tools_add"
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert input is not None
        received = input.inference_options.tool_choice
        # Record instead of asserting here, so the check below runs in the
        # test body regardless of how the server reports responder errors.
        seen_tool_choice_ids.append(received.id if received is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=assistant_id,
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Add a message with tools and verify they reach the responder.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )
        # The responder's own item must have been streamed for the added message.
        assert any(
            e.type == "thread.item.done" and e.item.id == assistant_id for e in events
        )

    # One responder call for thread creation, one for the added message.
    assert seen_tool_choice_ids == [tool_choice.id, tool_choice.id]


async def test_custom_id_generators_on_server():
    class UniqueIdStore(SQLiteStore):
        def __init__(self, path):
            super().__init__(path)
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit a tool call to later verify tool_call_output id generation
        return
        yield

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        db_path = cast(SQLiteStore, server.store).db_path
        server.store = UniqueIdStore(db_path)

        # Create thread and verify thread id and user message id
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        thread_created = next(e for e in events if e.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        user_message = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"