openai/chatkit-python

Public

mirrored fromhttps://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
8b263facaeaba1c1035a4d92befabf8f2c75436c

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

1063lines · modepreview

import asyncio
import base64
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from contextlib import contextmanager
from datetime import datetime
from typing import (
    Any,
    AsyncGenerator,
    AsyncIterable,
    Callable,
    Generic,
)

import agents
from agents.models.chatcmpl_helpers import (
    HEADERS_OVERRIDE as chat_completions_headers_override,
)
from agents.models.openai_responses import (
    _HEADERS_OVERRIDE as responses_headers_override,
)
from pydantic import BaseModel, TypeAdapter
from typing_extensions import TypeVar, assert_never

from chatkit.errors import CustomStreamError, StreamError

from .logger import logger
from .store import AttachmentStore, Store, StoreItemType, default_generate_id
from .types import (
    Action,
    AssistantMessageContent,
    AssistantMessageContentPartAdded,
    AssistantMessageContentPartAnnotationAdded,
    AssistantMessageContentPartDone,
    AssistantMessageContentPartTextDelta,
    AssistantMessageItem,
    AttachmentsCreateReq,
    AttachmentsDeleteReq,
    AudioInput,
    ChatKitReq,
    ClientToolCallItem,
    ErrorCode,
    ErrorEvent,
    FeedbackKind,
    HiddenContextItem,
    InputTranscribeReq,
    ItemsFeedbackReq,
    ItemsListReq,
    NonStreamingReq,
    Page,
    SDKHiddenContextItem,
    StreamingReq,
    StreamOptions,
    StreamOptionsEvent,
    StructuredInputAnswer,
    StructuredInputItem,
    StructuredInputMultipleChoice,
    StructuredInputSubmission,
    SyncCustomActionResponse,
    Thread,
    ThreadCreatedEvent,
    ThreadItem,
    ThreadItemAddedEvent,
    ThreadItemDoneEvent,
    ThreadItemRemovedEvent,
    ThreadItemReplacedEvent,
    ThreadItemUpdatedEvent,
    ThreadMetadata,
    ThreadsAddClientToolOutputReq,
    ThreadsAddStructuredInputReq,
    ThreadsAddUserMessageReq,
    ThreadsCreateReq,
    ThreadsCustomActionReq,
    ThreadsDeleteReq,
    ThreadsGetByIdReq,
    ThreadsListReq,
    ThreadsRetryAfterItemReq,
    ThreadsSyncCustomActionReq,
    ThreadStreamEvent,
    ThreadsUpdateReq,
    ThreadUpdatedEvent,
    TranscriptionResult,
    UserMessageInput,
    UserMessageItem,
    WidgetComponentUpdated,
    WidgetItem,
    WidgetRootUpdated,
    WidgetStreamingTextValueDelta,
    WorkflowItem,
    WorkflowTaskAdded,
    WorkflowTaskUpdated,
    is_streaming_req,
)
from .version import __version__
from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot

DEFAULT_PAGE_SIZE = 20
DEFAULT_ERROR_MESSAGE = "An error occurred when generating a response."


def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                else:
                    return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                is_streaming_text(before)
                and is_streaming_text(after)
                and field == "value"
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        components = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if is_streaming_text(component) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for id, after_node in after_nodes.items():
        before_node = before_nodes.get(id)
        if before_node is None:
            raise ValueError(
                f"Node {id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            if not after_value.startswith(before_value):
                raise ValueError(
                    f"Node {id} was updated with a new value that is not a prefix of the initial value. All widget updates must be cumulative."
                )
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas


async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget root (or async sequence of roots) as ThreadStreamEvents."""
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    initial_state = await widget.__anext__()

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state

    while widget:
        try:
            new_state = await widget.__anext__()
            for update in diff_widget(last_state, new_state):
                yield ThreadItemUpdatedEvent(
                    item_id=item_id,
                    update=update,
                )
            last_state = new_state
        except StopAsyncIteration:
            break

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )


@contextmanager
def agents_sdk_user_agent_override():
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    yield

    chat_completions_headers_override.reset(chat_completions_token)
    responses_headers_override.reset(responses_token)


class StreamingResult(AsyncIterable[bytes]):
    def __init__(self, stream: AsyncGenerator[bytes, None]):
        self.json_events = stream

    async def __aiter__(self):
        async for event in self.json_events:
            yield event


class NonStreamingResult:
    def __init__(self, result: bytes):
        self.json = result


TContext = TypeVar("TContext", default=Any)


class ChatKitServer(ABC, Generic[TContext]):
    def __init__(
        self,
        store: Store[TContext],
        attachment_store: AttachmentStore[TContext] | None = None,
    ):
        """Create a ChatKitServer with the backing Store and optional AttachmentStore."""
        self.store = store
        self.attachment_store = attachment_store

    def _get_attachment_store(self) -> AttachmentStore[TContext]:
        """Return the configured AttachmentStore or raise if missing."""
        if self.attachment_store is None:
            raise RuntimeError(
                "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
            )
        return self.attachment_store

    @abstractmethod
    def respond(
        self,
        thread: ThreadMetadata,
        input_user_message: UserMessageItem | None,
        context: TContext,
    ) -> AsyncIterator[ThreadStreamEvent]:
        """Stream `ThreadStreamEvent` instances for a new user message.

        Args:
            thread: Metadata for the thread being processed.
            input_user_message: The incoming message the server should respond to, if any.
            context: Arbitrary per-request context provided by the caller.

        Returns:
            An async iterator that yields events representing the server's response.
        """
        pass

    async def add_feedback(  # noqa: B027
        self,
        thread_id: str,
        item_ids: list[str],
        feedback: FeedbackKind,
        context: TContext,
    ) -> None:
        """Persist user feedback for one or more thread items."""
        pass

    async def transcribe(  # noqa: B027
        self, audio_input: AudioInput, context: TContext
    ) -> TranscriptionResult:
        """Transcribe speech audio to text (dictation).

        Audio bytes are in `audio_input.data`. The raw MIME type is `audio_input.mime_type`; use
        `audio_input.media_type` for the base media type (one of: `audio/webm`, `audio/ogg`, `audio/mp4`).

        Args:
            audio_input: Audio bytes plus MIME type metadata for transcription.
            context: Arbitrary per-request context provided by the caller.

        Returns:
            A `TranscriptionResult` whose `text` is what should appear in the composer.
        """
        raise NotImplementedError(
            "transcribe() must be overridden to support the input.transcribe request."
        )

    def action(
        self,
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: TContext,
    ) -> AsyncIterator[ThreadStreamEvent]:
        """Handle a widget or client-dispatched action and yield response events."""
        raise NotImplementedError(
            "The action() method must be overridden to react to actions. "
            "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
        )

    async def sync_action(
        self,
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: TContext,
    ) -> SyncCustomActionResponse:
        """Handle a widget or client-dispatched action and return a SyncCustomActionResponse."""
        raise NotImplementedError(
            "The sync_action() method must be overridden to react to sync actions. "
            "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
        )

    def get_stream_options(
        self, thread: ThreadMetadata, context: TContext
    ) -> StreamOptions:
        """
        Return stream-level runtime options. Allows the user to cancel the stream by default.
        Override this method to customize behavior.
        """
        return StreamOptions(allow_cancel=True)

    async def handle_stream_cancelled(
        self,
        thread: ThreadMetadata,
        pending_items: list[ThreadItem],
        context: TContext,
    ):
        """Perform custom cleanup / stop inference when a stream is cancelled.
        Updates you make here will not be reflected in the UI until a reload.

        The default implementation persists any non-empty pending assistant messages
        to the thread but does not auto-save pending widget items or workflow items.

        Args:
            thread: The thread that was being processed.
            pending_items: Items that were not done streaming at cancellation time.
            context: Arbitrary per-request context provided by the caller.
        """
        pending_assistant_message_items: list[AssistantMessageItem] = [
            item for item in pending_items if isinstance(item, AssistantMessageItem)
        ]
        for item in pending_assistant_message_items:
            is_empty = len(item.content) == 0 or all(
                (not content.text.strip()) for content in item.content
            )
            if not is_empty:
                await self.store.add_thread_item(thread.id, item, context=context)

        # Add a hidden context item to the thread to indicate that the stream was cancelled.
        # Otherwise, depending on the timing of the cancellation, subsequent responses may
        # attempt to continue the cancelled response.
        await self.store.add_thread_item(
            thread.id,
            SDKHiddenContextItem(
                thread_id=thread.id,
                created_at=datetime.now(),
                id=self.store.generate_item_id("sdk_hidden_context", thread, context),
                content="The user cancelled the stream. Stop responding to the prior request.",
            ),
            context=context,
        )

    async def process(
        self, request: str | bytes | bytearray, context: TContext
    ) -> StreamingResult | NonStreamingResult:
        """Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
        parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
        logger.info(f"Received request op: {parsed_request.type}")

        if is_streaming_req(parsed_request):
            return StreamingResult(self._process_streaming(parsed_request, context))
        else:
            return NonStreamingResult(
                await self._process_non_streaming(parsed_request, context)
            )

    async def _process_non_streaming(
        self, request: NonStreamingReq, context: TContext
    ) -> bytes:
        match request:
            case ThreadsGetByIdReq():
                thread = await self._load_full_thread(
                    request.params.thread_id, context=context
                )
                return self._serialize(self._to_thread_response(thread))
            case ThreadsListReq():
                params = request.params
                threads = await self.store.load_threads(
                    limit=params.limit or DEFAULT_PAGE_SIZE,
                    after=params.after,
                    order=params.order,
                    context=context,
                )
                return self._serialize(
                    Page(
                        has_more=threads.has_more,
                        after=threads.after,
                        data=[
                            self._to_thread_response(thread) for thread in threads.data
                        ],
                    )
                )
            case ItemsFeedbackReq():
                await self.add_feedback(
                    request.params.thread_id,
                    request.params.item_ids,
                    request.params.kind,
                    context,
                )
                return b"{}"
            case AttachmentsCreateReq():
                attachment_store = self._get_attachment_store()
                attachment = await attachment_store.create_attachment(
                    request.params, context
                )
                await self.store.save_attachment(attachment, context=context)
                return self._serialize(attachment)
            case AttachmentsDeleteReq():
                attachment_store = self._get_attachment_store()
                await attachment_store.delete_attachment(
                    request.params.attachment_id, context=context
                )
                await self.store.delete_attachment(
                    request.params.attachment_id, context=context
                )
                return b"{}"
            case InputTranscribeReq():
                audio_bytes = base64.b64decode(request.params.audio_base64)
                transcription_result = await self.transcribe(
                    AudioInput(data=audio_bytes, mime_type=request.params.mime_type),
                    context=context,
                )
                return self._serialize(transcription_result)
            case ItemsListReq():
                items_list_params = request.params
                items = await self.store.load_thread_items(
                    items_list_params.thread_id,
                    limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
                    order=items_list_params.order,
                    after=items_list_params.after,
                    context=context,
                )
                # filter out hidden context items
                items.data = [
                    item
                    for item in items.data
                    if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
                ]
                return self._serialize(items)
            case ThreadsUpdateReq():
                thread = await self.store.load_thread(
                    request.params.thread_id, context=context
                )
                thread.title = request.params.title
                await self.store.save_thread(thread, context=context)
                return self._serialize(self._to_thread_response(thread))
            case ThreadsDeleteReq():
                await self.store.delete_thread(
                    request.params.thread_id, context=context
                )
                return b"{}"
            case ThreadsSyncCustomActionReq():
                return await self._process_sync_custom_action(request, context)
            case _:
                assert_never(request)

    async def _process_streaming(
        self, request: StreamingReq, context: TContext
    ) -> AsyncGenerator[bytes, None]:
        try:
            async for event in self._process_streaming_impl(request, context):
                b = self._serialize(event)
                yield b"data: " + b + b"\n\n"
        except asyncio.CancelledError:
            # Let cancellation bubble up without logging as an error.
            raise
        except Exception:
            logger.exception("Error while generating streamed response")
            raise

    async def _process_streaming_impl(
        self, request: StreamingReq, context: TContext
    ) -> AsyncGenerator[ThreadStreamEvent, None]:
        match request:
            case ThreadsCreateReq():
                thread = Thread(
                    id=self.store.generate_thread_id(context),
                    created_at=datetime.now(),
                    items=Page(),
                )
                await self.store.save_thread(
                    ThreadMetadata(**thread.model_dump()),
                    context=context,
                )
                yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
                user_message = await self._build_user_message_item(
                    request.params.input, thread, context
                )
                async for event in self._process_new_thread_item_respond(
                    thread,
                    user_message,
                    context,
                ):
                    yield event

            case ThreadsAddUserMessageReq():
                thread = await self.store.load_thread(
                    request.params.thread_id, context=context
                )
                user_message = await self._build_user_message_item(
                    request.params.input, thread, context
                )
                async for event in self._process_new_thread_item_respond(
                    thread,
                    user_message,
                    context,
                ):
                    yield event

            case ThreadsAddClientToolOutputReq():
                thread = await self.store.load_thread(
                    request.params.thread_id, context=context
                )
                items = await self.store.load_thread_items(
                    thread.id, None, 1, "desc", context
                )
                tool_call = next(
                    (
                        item
                        for item in items.data
                        if isinstance(item, ClientToolCallItem)
                        and item.status == "pending"
                    ),
                    None,
                )
                if not tool_call:
                    raise ValueError(
                        f"Last thread item in {thread.id} was not a ClientToolCallItem"
                    )

                tool_call.output = request.params.result
                tool_call.status = "completed"

                await self.store.save_item(thread.id, tool_call, context=context)

                # Safety against dangling pending tool calls if there are
                # multiple in a row, which should be impossible, and
                # integrations should ultimately filter out pending tool calls
                # when creating input response messages.
                await self._cleanup_pending_client_tool_call(thread, context)

                async for event in self._process_events(
                    thread,
                    context,
                    lambda: self.respond(thread, None, context),
                ):
                    yield event

            case ThreadsAddStructuredInputReq():
                thread = await self.store.load_thread(
                    request.params.thread_id, context=context
                )
                item = await self.store.load_item(
                    request.params.thread_id,
                    request.params.item_id,
                    context=context,
                )
                if not isinstance(item, StructuredInputItem):
                    raise ValueError(
                        f"Item {request.params.item_id} is not a StructuredInputItem"
                    )

                updated_item = await self._apply_structured_input_submission(
                    item,
                    request.params.input,
                )

                async def stream_structured_input_response() -> AsyncIterator[
                    ThreadStreamEvent
                ]:
                    # Keep this replacement inside _process_events so it is persisted
                    # through the same event pipeline as other thread item replacements.
                    yield ThreadItemReplacedEvent(item=updated_item)
                    async for event in self.respond(thread, None, context):
                        yield event

                async for event in self._process_events(
                    thread,
                    context,
                    stream_structured_input_response,
                ):
                    yield event

            case ThreadsRetryAfterItemReq():
                thread_metadata = await self.store.load_thread(
                    request.params.thread_id, context=context
                )

                # Collect items to remove (all items after the user message)
                items_to_remove: list[ThreadItem] = []
                user_message_item = None

                async for item in self._paginate_thread_items_reverse(
                    request.params.thread_id, context
                ):
                    if item.id == request.params.item_id:
                        if not isinstance(item, UserMessageItem):
                            raise ValueError(
                                f"Item {request.params.item_id} is not a user message"
                            )
                        user_message_item = item
                        break
                    items_to_remove.append(item)

                if user_message_item:
                    for item in items_to_remove:
                        await self.store.delete_thread_item(
                            request.params.thread_id, item.id, context=context
                        )
                    async for event in self._process_events(
                        thread_metadata,
                        context,
                        lambda: self.respond(
                            thread_metadata,
                            user_message_item,
                            context,
                        ),
                    ):
                        yield event
            case ThreadsCustomActionReq():
                thread_metadata = await self.store.load_thread(
                    request.params.thread_id, context=context
                )

                item: ThreadItem | None = None
                if request.params.item_id:
                    item = await self.store.load_item(
                        request.params.thread_id,
                        request.params.item_id,
                        context=context,
                    )

                if item and not isinstance(item, WidgetItem):
                    # shouldn't happen if the caller is using the API correctly.
                    yield ErrorEvent(
                        code=ErrorCode.STREAM_ERROR,
                        allow_retry=False,
                    )
                    return

                async for event in self._process_events(
                    thread_metadata,
                    context,
                    lambda: self.action(
                        thread_metadata,
                        request.params.action,
                        item,
                        context,
                    ),
                ):
                    yield event

            case _:
                assert_never(request)

    async def _process_sync_custom_action(
        self, request: ThreadsSyncCustomActionReq, context: TContext
    ) -> bytes:
        thread_metadata = await self.store.load_thread(
            request.params.thread_id, context=context
        )

        item: ThreadItem | None = None
        if request.params.item_id:
            item = await self.store.load_item(
                request.params.thread_id,
                request.params.item_id,
                context=context,
            )

        if item and not isinstance(item, WidgetItem):
            raise ValueError("threads.sync_custom_action requires a widget sender item")

        return self._serialize(
            await self.sync_action(
                thread_metadata,
                request.params.action,
                item,
                context,
            )
        )

    async def _apply_structured_input_submission(
        self,
        item: StructuredInputItem,
        submission: StructuredInputSubmission,
    ) -> StructuredInputItem:
        if item.status != "pending":
            raise ValueError(f"Structured input item {item.id} is not pending")

        updated_item = item.model_copy(deep=True)
        questions_by_id = {
            question.id: question for question in updated_item.inputs
        }
        submitted_question_ids = set(submission.answers)
        unknown_question_ids = submitted_question_ids.difference(questions_by_id)
        if unknown_question_ids:
            unknown = ", ".join(sorted(unknown_question_ids))
            logger.warning(
                f"Structured input item {item.id} received unknown question id(s), ignoring: {unknown}"
            )

        for question in updated_item.inputs:
            answer = submission.answers.get(question.id)

            if (
                submission.status == "skipped" or
                answer is None or
                answer.skipped or
                not answer.values
            ):
                question.answer = StructuredInputAnswer(skipped=True)
                continue

            values = answer.values
            if isinstance(question, StructuredInputMultipleChoice):
                if not question.multiple:
                    values = values[:1]

            question.answer = StructuredInputAnswer(values=values)

        updated_item.status = submission.status
        return updated_item

    async def _cleanup_pending_client_tool_call(
        self, thread: ThreadMetadata, context: TContext
    ) -> None:
        items = await self.store.load_thread_items(
            thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
        )
        for tool_call in items.data:
            if not isinstance(tool_call, ClientToolCallItem):
                continue
            if tool_call.status == "pending":
                logger.warning(
                    f"Client tool call {tool_call.call_id} was not completed, ignoring"
                )
                await self.store.delete_thread_item(
                    thread.id, tool_call.id, context=context
                )

    async def _process_new_thread_item_respond(
        self,
        thread: ThreadMetadata,
        item: UserMessageItem,
        context: TContext,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Store the updated attachments (now that they have a thread id).
        for attachment in item.attachments:
            await self.store.save_attachment(attachment, context=context)

        await self.store.add_thread_item(thread.id, item, context=context)
        yield ThreadItemDoneEvent(item=item)

        async for event in self._process_events(
            thread,
            context,
            lambda: self.respond(thread, item, context),
        ):
            yield event

    async def _process_events(
        self,
        thread: ThreadMetadata,
        context: TContext,
        stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
    ) -> AsyncIterator[ThreadStreamEvent]:
        await asyncio.sleep(0)  # allow the response to start streaming

        # Send initial stream options
        yield StreamOptionsEvent(
            stream_options=self.get_stream_options(thread, context)
        )

        last_thread = thread.model_copy(deep=True)

        # Keep track of items that were streamed but not yet saved
        # so that we can persist them when the stream is cancelled.
        pending_items: dict[str, ThreadItem] = {}

        try:
            with agents_sdk_user_agent_override():
                async for event in stream():
                    if isinstance(event, ThreadItemAddedEvent):
                        # Stash an isolated copy in case we need to persist unfinished items
                        # on cancellation; downstream handlers keep using the original event.item.
                        pending_items[event.item.id] = event.item.model_copy(deep=True)

                    match event:
                        case ThreadItemDoneEvent():
                            await self.store.add_thread_item(
                                thread.id, event.item, context=context
                            )
                            pending_items.pop(event.item.id, None)
                        case ThreadItemRemovedEvent():
                            await self.store.delete_thread_item(
                                thread.id, event.item_id, context=context
                            )
                            pending_items.pop(event.item_id, None)
                        case ThreadItemReplacedEvent():
                            await self.store.save_item(
                                thread.id, event.item, context=context
                            )
                            pending_items.pop(event.item.id, None)
                        case ThreadItemUpdatedEvent():
                            # Keep pending assistant message and workflow items up to date
                            # so that we have a reference to the latest version of these pending items
                            # when the stream is cancelled.
                            self._update_pending_items(pending_items, event)

                    # special case - don't send hidden context items back to the client
                    should_swallow_event = isinstance(
                        event, ThreadItemDoneEvent
                    ) and isinstance(
                        event.item, (HiddenContextItem, SDKHiddenContextItem)
                    )

                    if not should_swallow_event:
                        yield event

                    # in case user updated the thread while streaming
                    if thread != last_thread:
                        last_thread = thread.model_copy(deep=True)
                        await self.store.save_thread(thread, context=context)
                        yield ThreadUpdatedEvent(
                            thread=self._to_thread_response(thread)
                        )
                # in case user updated the thread while streaming
                if thread != last_thread:
                    last_thread = thread.model_copy(deep=True)
                    await self.store.save_thread(thread, context=context)
                    yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
        except asyncio.CancelledError:
            await self.handle_stream_cancelled(
                thread, list(pending_items.values()), context
            )
            raise
        except CustomStreamError as e:
            yield ErrorEvent(
                code="custom",
                message=e.message,
                allow_retry=e.allow_retry,
            )
        except StreamError as e:
            yield ErrorEvent(
                code=e.code,
                allow_retry=e.allow_retry,
            )
        except Exception as e:
            yield ErrorEvent(
                code=ErrorCode.STREAM_ERROR,
                allow_retry=True,
            )
            logger.exception(e)

        if thread != last_thread:
            # in case user updated the thread at the end of the stream
            await self.store.save_thread(thread, context=context)
            yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))

    def _apply_assistant_message_update(
        self,
        item: AssistantMessageItem,
        update: AssistantMessageContentPartAdded
        | AssistantMessageContentPartTextDelta
        | AssistantMessageContentPartAnnotationAdded
        | AssistantMessageContentPartDone,
    ) -> AssistantMessageItem:
        # Pad the content list so the requested content_index exists before we write into it.
        # (Streaming updates can arrive for an index that hasn’t been created yet)
        while len(item.content) <= update.content_index:
            item.content.append(AssistantMessageContent(text="", annotations=[]))

        match update:
            case AssistantMessageContentPartAdded():
                item.content[update.content_index] = update.content
            case AssistantMessageContentPartTextDelta():
                item.content[update.content_index].text += update.delta
            case AssistantMessageContentPartAnnotationAdded():
                annotations = item.content[update.content_index].annotations
                if update.annotation_index <= len(annotations):
                    annotations.insert(update.annotation_index, update.annotation)
                else:
                    annotations.append(update.annotation)
            case AssistantMessageContentPartDone():
                item.content[update.content_index] = update.content
        return item

    def _update_pending_items(
        self,
        pending_items: dict[str, ThreadItem],
        event: ThreadItemUpdatedEvent,
    ):
        updated_item = pending_items.get(event.item_id)
        update = event.update
        match updated_item:
            case AssistantMessageItem():
                if isinstance(
                    update,
                    (
                        AssistantMessageContentPartAdded,
                        AssistantMessageContentPartTextDelta,
                        AssistantMessageContentPartAnnotationAdded,
                        AssistantMessageContentPartDone,
                    ),
                ):
                    pending_items[updated_item.id] = (
                        self._apply_assistant_message_update(updated_item, update)
                    )
            case WorkflowItem():
                if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
                    match update:
                        case WorkflowTaskUpdated():
                            updated_item.workflow.tasks[update.task_index] = update.task
                        case WorkflowTaskAdded():
                            updated_item.workflow.tasks.append(update.task)

                    pending_items[updated_item.id] = updated_item
            case _:
                pass

    async def _build_user_message_item(
        self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
    ) -> UserMessageItem:
        return UserMessageItem(
            id=self.store.generate_item_id("message", thread, context),
            content=input.content,
            thread_id=thread.id,
            attachments=[
                (await self.store.load_attachment(attachment_id, context)).model_copy(
                    update={"thread_id": thread.id}
                )
                for attachment_id in input.attachments
            ],
            quoted_text=input.quoted_text,
            inference_options=input.inference_options,
            created_at=datetime.now(),
        )

    async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
        thread_meta = await self.store.load_thread(thread_id, context=context)
        thread_items = await self.store.load_thread_items(
            thread_id,
            after=None,
            limit=DEFAULT_PAGE_SIZE,
            order="asc",
            context=context,
        )
        return Thread(**thread_meta.model_dump(), items=thread_items)

    async def _paginate_thread_items_reverse(
        self, thread_id: str, context: TContext
    ) -> AsyncIterator[ThreadItem]:
        """Paginate through thread items in reverse order (newest first)."""
        after = None
        while True:
            items = await self.store.load_thread_items(
                thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
            )
            for item in items.data:
                yield item

            if not items.has_more:
                break
            after = items.after

    def _serialize(self, obj: BaseModel) -> bytes:
        return obj.model_dump_json(
            by_alias=True,
            exclude_none=True,
            context={"exclude_metadata": True},
        ).encode("utf-8")

    def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
        def is_hidden(item: ThreadItem) -> bool:
            return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))

        items = thread.items if isinstance(thread, Thread) else Page()
        items.data = [item for item in items.data if not is_hidden(item)]

        return Thread(
            id=thread.id,
            title=thread.title,
            created_at=thread.created_at,
            items=items,
            status=thread.status,
            allowed_image_domains=thread.allowed_image_domains,
        )