openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.6.3

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

981lines · modeblame

f688d870victor-openai8 months ago1import asyncio
a202f603Jiwon Kim5 months ago2import base64
f688d870victor-openai8 months ago3from abc import ABC, abstractmethod
4from collections.abc import AsyncIterator
5from contextlib import contextmanager
6from datetime import datetime
7from typing import (
8Any,
9AsyncGenerator,
10AsyncIterable,
11Callable,
12Generic,
13)
14
15import agents
16from agents.models.chatcmpl_helpers import (
17HEADERS_OVERRIDE as chat_completions_headers_override,
18)
19from agents.models.openai_responses import (
20_HEADERS_OVERRIDE as responses_headers_override,
21)
22from pydantic import BaseModel, TypeAdapter
3de07454Dan Wang7 months ago23from typing_extensions import TypeVar, assert_never
f688d870victor-openai8 months ago24
25from chatkit.errors import CustomStreamError, StreamError
26
27from .logger import logger
28from .store import AttachmentStore, Store, StoreItemType, default_generate_id
29from .types import (
30Action,
eee0a2f7Jiwon Kim7 months ago31AssistantMessageContent,
32AssistantMessageContentPartAdded,
33AssistantMessageContentPartAnnotationAdded,
34AssistantMessageContentPartDone,
35AssistantMessageContentPartTextDelta,
36AssistantMessageItem,
f688d870victor-openai8 months ago37AttachmentsCreateReq,
38AttachmentsDeleteReq,
9c6c0f90Jiwon Kim5 months ago39AudioInput,
f688d870victor-openai8 months ago40ChatKitReq,
41ClientToolCallItem,
42ErrorCode,
43ErrorEvent,
44FeedbackKind,
45HiddenContextItem,
a202f603Jiwon Kim5 months ago46InputTranscribeReq,
f688d870victor-openai8 months ago47ItemsFeedbackReq,
48ItemsListReq,
49NonStreamingReq,
50Page,
5a7244baJiwon Kim7 months ago51SDKHiddenContextItem,
f688d870victor-openai8 months ago52StreamingReq,
5a7244baJiwon Kim7 months ago53StreamOptions,
54StreamOptionsEvent,
8bb983f1David Weedon4 months ago55SyncCustomActionResponse,
f688d870victor-openai8 months ago56Thread,
57ThreadCreatedEvent,
58ThreadItem,
59ThreadItemAddedEvent,
60ThreadItemDoneEvent,
61ThreadItemRemovedEvent,
62ThreadItemReplacedEvent,
643a02f7Jiwon Kim7 months ago63ThreadItemUpdatedEvent,
f688d870victor-openai8 months ago64ThreadMetadata,
65ThreadsAddClientToolOutputReq,
66ThreadsAddUserMessageReq,
67ThreadsCreateReq,
68ThreadsCustomActionReq,
69ThreadsDeleteReq,
70ThreadsGetByIdReq,
71ThreadsListReq,
72ThreadsRetryAfterItemReq,
8bb983f1David Weedon4 months ago73ThreadsSyncCustomActionReq,
f688d870victor-openai8 months ago74ThreadStreamEvent,
75ThreadsUpdateReq,
76ThreadUpdatedEvent,
a202f603Jiwon Kim5 months ago77TranscriptionResult,
f688d870victor-openai8 months ago78UserMessageInput,
79UserMessageItem,
80WidgetComponentUpdated,
81WidgetItem,
82WidgetRootUpdated,
83WidgetStreamingTextValueDelta,
eee0a2f7Jiwon Kim7 months ago84WorkflowItem,
5a7244baJiwon Kim7 months ago85WorkflowTaskAdded,
86WorkflowTaskUpdated,
f688d870victor-openai8 months ago87is_streaming_req,
88)
89from .version import __version__
91c93ee6Jiwon Kim7 months ago90from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
f688d870victor-openai8 months ago91
# Page size used for thread / item listing when the request does not supply a `limit`.
DEFAULT_PAGE_SIZE: int = 20
# Generic error text for a failed response generation.
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
94
95
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    If anything changed other than an append to the ``value`` of a streaming
    text component (type ``Markdown`` or ``Text`` with a string ``value``), the
    result is a single ``WidgetRootUpdated`` carrying ``after``. Otherwise the
    result has one ``WidgetStreamingTextValueDelta`` per streaming text
    component (that has an ``id``) whose value grew. An empty list means that
    no such component changed.

    Raises:
        ValueError: If a streaming text component with an id in ``after`` has
            no counterpart in ``before``, or if its new value does not start
            with its previous value.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        # A change of identity (type, id or key) always forces a full replace.
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    # Recurse so appends to nested streaming text are still tolerated.
                    return full_replace(before_value, after_value)
                else:
                    return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                is_streaming_text(before)
                and is_streaming_text(after)
                and field == "value"
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        # Map component id -> component for every streaming text node reachable
        # through `children`. Components without an id are skipped.
        components: dict[str, WidgetComponentBase] = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if is_streaming_text(component) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            if not after_value.startswith(before_value):
                raise ValueError(
                    f"Node {component_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            # The delta is final once the component no longer reports `streaming`.
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=component_id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas
198
199
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None] | AsyncIterator[WidgetRoot],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget root (or async sequence of roots) as ThreadStreamEvents.

    A single ``WidgetRoot`` is emitted as one ``ThreadItemDoneEvent``. For an
    async generator / iterator of roots:

    - the first state is emitted as a ``ThreadItemAddedEvent``;
    - every later state is emitted as the ``ThreadItemUpdatedEvent`` deltas
      computed by ``diff_widget`` against the previous state;
    - the last state is emitted as a closing ``ThreadItemDoneEvent``.

    Args:
        thread: Thread the widget item belongs to.
        widget: A widget root, or an async generator / iterator of successive roots.
        copy_text: Optional text stored on the item as ``copy_text``.
        generate_id: Id factory, called once with ``"message"`` for the item id.

    Raises:
        ValueError: If ``widget`` is an async iterator that yields no state.
    """
    item_id = generate_id("message")

    # Any async iterator (async generators included) is treated as a stream of states.
    if not isinstance(widget, AsyncIterator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Without this, Python turns the StopAsyncIteration into an opaque RuntimeError.
        raise ValueError("The widget stream did not yield any widget state.") from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state

    # Iteration resumes after the already-consumed initial state.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
250
251
@contextmanager
def agents_sdk_user_agent_override():
    """Send a ChatKit-tagged User-Agent on Agents SDK model requests inside the block.

    For the duration of the ``with`` block, both Agents SDK header overrides
    (Chat Completions and Responses) are set to
    ``{"User-Agent": "Agents/Python <agents version> ChatKit/Python <chatkit version>"}``.
    They are restored afterwards, including when the block raises or the
    enclosing (async) generator is closed early.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Undo in reverse order of setting.
        for override, token in (
            (responses_headers_override, responses_token),
            (chat_completions_headers_override, chat_completions_token),
        ):
            try:
                override.reset(token)
            except ValueError:
                # `reset` raises ValueError when called from a different Context
                # than the one the token was created in. This happens when asyncio
                # finalizes an abandoned async generator in a fresh task. The
                # original Context is not reachable from there, so there is
                # nothing to restore.
                pass
262
263
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around a stream of response byte chunks.

    ``ChatKitServer.process`` returns one of these for streaming requests. Each
    chunk is then an already-framed ``data: ...`` event.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # Public attribute: the wrapped byte-chunk generator.
        self.json_events = stream

    async def __aiter__(self):
        """Re-yield every chunk of the wrapped stream, unchanged and in order."""
        source = self.json_events
        async for chunk in source:
            yield chunk
271
272
class NonStreamingResult:
    """Complete serialized response body for a non-streaming request."""

    def __init__(self, result: bytes):
        # Stored as-is under `json`; callers read the raw bytes from here.
        self.json = result
276
277
# Type of the per-request context object passed through the server and store
# methods; falls back to Any when a subclass leaves it unparametrized.
TContext = TypeVar("TContext", default=Any)
279
280
281class ChatKitServer(ABC, Generic[TContext]):
282def __init__(
283self,
284store: Store[TContext],
285attachment_store: AttachmentStore[TContext] | None = None,
286):
3b20c3d1Jiwon Kim6 months ago287"""Create a ChatKitServer with the backing Store and optional AttachmentStore."""
f688d870victor-openai8 months ago288self.store = store
289self.attachment_store = attachment_store
290
291def _get_attachment_store(self) -> AttachmentStore[TContext]:
292"""Return the configured AttachmentStore or raise if missing."""
293if self.attachment_store is None:
294raise RuntimeError(
295"AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
296)
297return self.attachment_store
298
299@abstractmethod
300def respond(
301self,
302thread: ThreadMetadata,
303input_user_message: UserMessageItem | None,
304context: TContext,
305) -> AsyncIterator[ThreadStreamEvent]:
6988ea0avictor-openai8 months ago306"""Stream `ThreadStreamEvent` instances for a new user message.
307
308Args:
309thread: Metadata for the thread being processed.
310input_user_message: The incoming message the server should respond to, if any.
311context: Arbitrary per-request context provided by the caller.
312
313Returns:
314An async iterator that yields events representing the server's response.
315"""
f688d870victor-openai8 months ago316pass
317
318async def add_feedback( # noqa: B027
319self,
320thread_id: str,
321item_ids: list[str],
322feedback: FeedbackKind,
323context: TContext,
324) -> None:
3b20c3d1Jiwon Kim6 months ago325"""Persist user feedback for one or more thread items."""
f688d870victor-openai8 months ago326pass
327
a202f603Jiwon Kim5 months ago328async def transcribe( # noqa: B027
9c6c0f90Jiwon Kim5 months ago329self, audio_input: AudioInput, context: TContext
a202f603Jiwon Kim5 months ago330) -> TranscriptionResult:
9c6c0f90Jiwon Kim5 months ago331"""Transcribe speech audio to text (dictation).
55114842Jiwon Kim5 months ago332
333Audio bytes are in `audio_input.data`. The raw MIME type is `audio_input.mime_type`; use
4e1156ddJiwon Kim5 months ago334`audio_input.media_type` for the base media type (one of: `audio/webm`, `audio/ogg`, `audio/mp4`).
9c6c0f90Jiwon Kim5 months ago335
336Args:
337audio_input: Audio bytes plus MIME type metadata for transcription.
338context: Arbitrary per-request context provided by the caller.
339
340Returns:
341A `TranscriptionResult` whose `text` is what should appear in the composer.
342"""
a202f603Jiwon Kim5 months ago343raise NotImplementedError(
344"transcribe() must be overridden to support the input.transcribe request."
345)
346
f688d870victor-openai8 months ago347def action(
348self,
349thread: ThreadMetadata,
350action: Action[str, Any],
351sender: WidgetItem | None,
352context: TContext,
353) -> AsyncIterator[ThreadStreamEvent]:
3b20c3d1Jiwon Kim6 months ago354"""Handle a widget or client-dispatched action and yield response events."""
f688d870victor-openai8 months ago355raise NotImplementedError(
356"The action() method must be overridden to react to actions. "
93a34923Vincent Lin8 months ago357"See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
f688d870victor-openai8 months ago358)
359
d14ce460David Weedon4 months ago360async def sync_action(
8bb983f1David Weedon4 months ago361self,
362thread: ThreadMetadata,
363action: Action[str, Any],
364sender: WidgetItem | None,
365context: TContext,
366) -> SyncCustomActionResponse:
367"""Handle a widget or client-dispatched action and return a SyncCustomActionResponse."""
368raise NotImplementedError(
369"The sync_action() method must be overridden to react to sync actions. "
370"See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
371)
372
5a7244baJiwon Kim7 months ago373def get_stream_options(
374self, thread: ThreadMetadata, context: TContext
375) -> StreamOptions:
376"""
377Return stream-level runtime options. Allows the user to cancel the stream by default.
378Override this method to customize behavior.
379"""
380return StreamOptions(allow_cancel=True)
381
eee0a2f7Jiwon Kim7 months ago382async def handle_stream_cancelled(
383self,
384thread: ThreadMetadata,
385pending_items: list[ThreadItem],
386context: TContext,
387):
388"""Perform custom cleanup / stop inference when a stream is cancelled.
5a7244baJiwon Kim7 months ago389Updates you make here will not be reflected in the UI until a reload.
390
391The default implementation persists any non-empty pending assistant messages
392to the thread but does not auto-save pending widget items or workflow items.
eee0a2f7Jiwon Kim7 months ago393
394Args:
395thread: The thread that was being processed.
396pending_items: Items that were not done streaming at cancellation time.
397context: Arbitrary per-request context provided by the caller.
398"""
5a7244baJiwon Kim7 months ago399pending_assistant_message_items: list[AssistantMessageItem] = [
400item for item in pending_items if isinstance(item, AssistantMessageItem)
401]
402for item in pending_assistant_message_items:
403is_empty = len(item.content) == 0 or all(
404(not content.text.strip()) for content in item.content
405)
406if not is_empty:
407await self.store.add_thread_item(thread.id, item, context=context)
408
409# Add a hidden context item to the thread to indicate that the stream was cancelled.
410# Otherwise, depending on the timing of the cancellation, subsequent responses may
411# attempt to continue the cancelled response.
412await self.store.add_thread_item(
413thread.id,
414SDKHiddenContextItem(
415thread_id=thread.id,
416created_at=datetime.now(),
417id=self.store.generate_item_id("sdk_hidden_context", thread, context),
6df9ee60Jiwon Kim7 months ago418content="The user cancelled the stream. Stop responding to the prior request.",
5a7244baJiwon Kim7 months ago419),
420context=context,
421)
eee0a2f7Jiwon Kim7 months ago422
f688d870victor-openai8 months ago423async def process(
424self, request: str | bytes | bytearray, context: TContext
425) -> StreamingResult | NonStreamingResult:
3b20c3d1Jiwon Kim6 months ago426"""Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
f688d870victor-openai8 months ago427parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
428logger.info(f"Received request op: {parsed_request.type}")
429
430if is_streaming_req(parsed_request):
431return StreamingResult(self._process_streaming(parsed_request, context))
432else:
433return NonStreamingResult(
434await self._process_non_streaming(parsed_request, context)
435)
436
437async def _process_non_streaming(
438self, request: NonStreamingReq, context: TContext
439) -> bytes:
440match request:
441case ThreadsGetByIdReq():
442thread = await self._load_full_thread(
443request.params.thread_id, context=context
444)
445return self._serialize(self._to_thread_response(thread))
446case ThreadsListReq():
447params = request.params
448threads = await self.store.load_threads(
449limit=params.limit or DEFAULT_PAGE_SIZE,
450after=params.after,
451order=params.order,
452context=context,
453)
454return self._serialize(
455Page(
456has_more=threads.has_more,
457after=threads.after,
458data=[
459self._to_thread_response(thread) for thread in threads.data
460],
461)
462)
463case ItemsFeedbackReq():
464await self.add_feedback(
465request.params.thread_id,
466request.params.item_ids,
467request.params.kind,
468context,
469)
470return b"{}"
471case AttachmentsCreateReq():
472attachment_store = self._get_attachment_store()
473attachment = await attachment_store.create_attachment(
474request.params, context
475)
7c90a733Jiwon Kim6 months ago476await self.store.save_attachment(attachment, context=context)
f688d870victor-openai8 months ago477return self._serialize(attachment)
478case AttachmentsDeleteReq():
479attachment_store = self._get_attachment_store()
480await attachment_store.delete_attachment(
481request.params.attachment_id, context=context
482)
483await self.store.delete_attachment(
484request.params.attachment_id, context=context
485)
486return b"{}"
a202f603Jiwon Kim5 months ago487case InputTranscribeReq():
488audio_bytes = base64.b64decode(request.params.audio_base64)
489transcription_result = await self.transcribe(
9c6c0f90Jiwon Kim5 months ago490AudioInput(data=audio_bytes, mime_type=request.params.mime_type),
491context=context,
a202f603Jiwon Kim5 months ago492)
493return self._serialize(transcription_result)
f688d870victor-openai8 months ago494case ItemsListReq():
495items_list_params = request.params
496items = await self.store.load_thread_items(
497items_list_params.thread_id,
498limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
499order=items_list_params.order,
500after=items_list_params.after,
501context=context,
502)
1860286bJiwon Kim7 months ago503# filter out hidden context items
f688d870victor-openai8 months ago504items.data = [
505item
506for item in items.data
1860286bJiwon Kim7 months ago507if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
f688d870victor-openai8 months ago508]
509return self._serialize(items)
510case ThreadsUpdateReq():
511thread = await self.store.load_thread(
512request.params.thread_id, context=context
513)
514thread.title = request.params.title
515await self.store.save_thread(thread, context=context)
516return self._serialize(self._to_thread_response(thread))
517case ThreadsDeleteReq():
518await self.store.delete_thread(
519request.params.thread_id, context=context
520)
521return b"{}"
8bb983f1David Weedon4 months ago522case ThreadsSyncCustomActionReq():
523return await self._process_sync_custom_action(request, context)
f688d870victor-openai8 months ago524case _:
525assert_never(request)
526
527async def _process_streaming(
528self, request: StreamingReq, context: TContext
529) -> AsyncGenerator[bytes, None]:
530try:
531async for event in self._process_streaming_impl(request, context):
532b = self._serialize(event)
533yield b"data: " + b + b"\n\n"
eee0a2f7Jiwon Kim7 months ago534except asyncio.CancelledError:
535# Let cancellation bubble up without logging as an error.
536raise
f688d870victor-openai8 months ago537except Exception:
538logger.exception("Error while generating streamed response")
539raise
540
541async def _process_streaming_impl(
542self, request: StreamingReq, context: TContext
543) -> AsyncGenerator[ThreadStreamEvent, None]:
544match request:
545case ThreadsCreateReq():
546thread = Thread(
547id=self.store.generate_thread_id(context),
548created_at=datetime.now(),
549items=Page(),
550)
8885a10bquettabit8 months ago551await self.store.save_thread(
552ThreadMetadata(**thread.model_dump()),
553context=context,
554)
f688d870victor-openai8 months ago555yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
556user_message = await self._build_user_message_item(
557request.params.input, thread, context
558)
559async for event in self._process_new_thread_item_respond(
560thread,
561user_message,
562context,
563):
564yield event
565
566case ThreadsAddUserMessageReq():
567thread = await self.store.load_thread(
568request.params.thread_id, context=context
569)
570user_message = await self._build_user_message_item(
571request.params.input, thread, context
572)
573async for event in self._process_new_thread_item_respond(
574thread,
575user_message,
576context,
577):
578yield event
579
580case ThreadsAddClientToolOutputReq():
581thread = await self.store.load_thread(
582request.params.thread_id, context=context
583)
584items = await self.store.load_thread_items(
585thread.id, None, 1, "desc", context
586)
587tool_call = next(
588(
589item
590for item in items.data
591if isinstance(item, ClientToolCallItem)
592and item.status == "pending"
593),
594None,
595)
596if not tool_call:
597raise ValueError(
598f"Last thread item in {thread.id} was not a ClientToolCallItem"
599)
600
601tool_call.output = request.params.result
602tool_call.status = "completed"
603
604await self.store.save_item(thread.id, tool_call, context=context)
605
606# Safety against dangling pending tool calls if there are
607# multiple in a row, which should be impossible, and
608# integrations should ultimately filter out pending tool calls
609# when creating input response messages.
610await self._cleanup_pending_client_tool_call(thread, context)
611
612async for event in self._process_events(
613thread,
614context,
615lambda: self.respond(thread, None, context),
616):
617yield event
618
619case ThreadsRetryAfterItemReq():
620thread_metadata = await self.store.load_thread(
621request.params.thread_id, context=context
622)
623
624# Collect items to remove (all items after the user message)
625items_to_remove: list[ThreadItem] = []
626user_message_item = None
627
628async for item in self._paginate_thread_items_reverse(
629request.params.thread_id, context
630):
631if item.id == request.params.item_id:
632if not isinstance(item, UserMessageItem):
633raise ValueError(
634f"Item {request.params.item_id} is not a user message"
635)
636user_message_item = item
637break
638items_to_remove.append(item)
639
640if user_message_item:
641for item in items_to_remove:
642await self.store.delete_thread_item(
643request.params.thread_id, item.id, context=context
644)
645async for event in self._process_events(
646thread_metadata,
647context,
648lambda: self.respond(
649thread_metadata,
650user_message_item,
651context,
652),
653):
654yield event
655case ThreadsCustomActionReq():
656thread_metadata = await self.store.load_thread(
657request.params.thread_id, context=context
658)
659
660item: ThreadItem | None = None
661if request.params.item_id:
662item = await self.store.load_item(
663request.params.thread_id,
664request.params.item_id,
665context=context,
666)
667
668if item and not isinstance(item, WidgetItem):
669# shouldn't happen if the caller is using the API correctly.
670yield ErrorEvent(
671code=ErrorCode.STREAM_ERROR,
672allow_retry=False,
673)
674return
675
676async for event in self._process_events(
677thread_metadata,
678context,
679lambda: self.action(
680thread_metadata,
681request.params.action,
682item,
683context,
684),
685):
686yield event
687
688case _:
689assert_never(request)
690
8bb983f1David Weedon4 months ago691async def _process_sync_custom_action(
692self, request: ThreadsSyncCustomActionReq, context: TContext
693) -> bytes:
694thread_metadata = await self.store.load_thread(
695request.params.thread_id, context=context
696)
697
698item: ThreadItem | None = None
699if request.params.item_id:
700item = await self.store.load_item(
701request.params.thread_id,
702request.params.item_id,
703context=context,
704)
705
706if item and not isinstance(item, WidgetItem):
707raise ValueError("threads.sync_custom_action requires a widget sender item")
708
709return self._serialize(
d14ce460David Weedon4 months ago710await self.sync_action(
8bb983f1David Weedon4 months ago711thread_metadata,
712request.params.action,
713item,
714context,
715)
716)
717
f688d870victor-openai8 months ago718async def _cleanup_pending_client_tool_call(
719self, thread: ThreadMetadata, context: TContext
720) -> None:
721items = await self.store.load_thread_items(
722thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
723)
724for tool_call in items.data:
725if not isinstance(tool_call, ClientToolCallItem):
726continue
727if tool_call.status == "pending":
728logger.warning(
729f"Client tool call {tool_call.call_id} was not completed, ignoring"
730)
731await self.store.delete_thread_item(
732thread.id, tool_call.id, context=context
733)
734
735async def _process_new_thread_item_respond(
736self,
737thread: ThreadMetadata,
738item: UserMessageItem,
739context: TContext,
740) -> AsyncIterator[ThreadStreamEvent]:
4a0b6794Jiwon Kim5 months ago741# Store the updated attachments (now that they have a thread id).
742for attachment in item.attachments:
743await self.store.save_attachment(attachment, context=context)
744
f688d870victor-openai8 months ago745await self.store.add_thread_item(thread.id, item, context=context)
746yield ThreadItemDoneEvent(item=item)
747
748async for event in self._process_events(
749thread,
750context,
751lambda: self.respond(thread, item, context),
752):
753yield event
754
755async def _process_events(
756self,
757thread: ThreadMetadata,
758context: TContext,
759stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
760) -> AsyncIterator[ThreadStreamEvent]:
761await asyncio.sleep(0) # allow the response to start streaming
762
5a7244baJiwon Kim7 months ago763# Send initial stream options
764yield StreamOptionsEvent(
765stream_options=self.get_stream_options(thread, context)
766)
767
f688d870victor-openai8 months ago768last_thread = thread.model_copy(deep=True)
769
eee0a2f7Jiwon Kim7 months ago770# Keep track of items that were streamed but not yet saved
771# so that we can persist them when the stream is cancelled.
772pending_items: dict[str, ThreadItem] = {}
773
f688d870victor-openai8 months ago774try:
775with agents_sdk_user_agent_override():
776async for event in stream():
eee0a2f7Jiwon Kim7 months ago777if isinstance(event, ThreadItemAddedEvent):
72483746Jiwon Kim5 months ago778# Stash an isolated copy in case we need to persist unfinished items
779# on cancellation; downstream handlers keep using the original event.item.
780pending_items[event.item.id] = event.item.model_copy(deep=True)
eee0a2f7Jiwon Kim7 months ago781
f688d870victor-openai8 months ago782match event:
783case ThreadItemDoneEvent():
784await self.store.add_thread_item(
785thread.id, event.item, context=context
786)
eee0a2f7Jiwon Kim7 months ago787pending_items.pop(event.item.id, None)
f688d870victor-openai8 months ago788case ThreadItemRemovedEvent():
789await self.store.delete_thread_item(
790thread.id, event.item_id, context=context
791)
eee0a2f7Jiwon Kim7 months ago792pending_items.pop(event.item_id, None)
f688d870victor-openai8 months ago793case ThreadItemReplacedEvent():
794await self.store.save_item(
795thread.id, event.item, context=context
796)
eee0a2f7Jiwon Kim7 months ago797pending_items.pop(event.item.id, None)
798case ThreadItemUpdatedEvent():
5a7244baJiwon Kim7 months ago799# Keep pending assistant message and workflow items up to date
800# so that we have a reference to the latest version of these pending items
801# when the stream is cancelled.
802self._update_pending_items(pending_items, event)
f688d870victor-openai8 months ago803
804# special case - don't send hidden context items back to the client
805should_swallow_event = isinstance(
806event, ThreadItemDoneEvent
1860286bJiwon Kim7 months ago807) and isinstance(
808event.item, (HiddenContextItem, SDKHiddenContextItem)
809)
f688d870victor-openai8 months ago810
811if not should_swallow_event:
812yield event
813
814# in case user updated the thread while streaming
815if thread != last_thread:
816last_thread = thread.model_copy(deep=True)
817await self.store.save_thread(thread, context=context)
818yield ThreadUpdatedEvent(
819thread=self._to_thread_response(thread)
820)
821# in case user updated the thread while streaming
822if thread != last_thread:
823last_thread = thread.model_copy(deep=True)
824await self.store.save_thread(thread, context=context)
825yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
eee0a2f7Jiwon Kim7 months ago826except asyncio.CancelledError:
827await self.handle_stream_cancelled(
828thread, list(pending_items.values()), context
829)
830raise
f688d870victor-openai8 months ago831except CustomStreamError as e:
832yield ErrorEvent(
833code="custom",
834message=e.message,
835allow_retry=e.allow_retry,
836)
837except StreamError as e:
838yield ErrorEvent(
839code=e.code,
840allow_retry=e.allow_retry,
841)
842except Exception as e:
843yield ErrorEvent(
844code=ErrorCode.STREAM_ERROR,
845allow_retry=True,
846)
847logger.exception(e)
848
849if thread != last_thread:
850# in case user updated the thread at the end of the stream
851await self.store.save_thread(thread, context=context)
852yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
853
eee0a2f7Jiwon Kim7 months ago854def _apply_assistant_message_update(
855self,
856item: AssistantMessageItem,
857update: AssistantMessageContentPartAdded
858| AssistantMessageContentPartTextDelta
859| AssistantMessageContentPartAnnotationAdded
860| AssistantMessageContentPartDone,
861) -> AssistantMessageItem:
862# Pad the content list so the requested content_index exists before we write into it.
863# (Streaming updates can arrive for an index that hasn’t been created yet)
811575b4Theophane Gregoir5 months ago864while len(item.content) <= update.content_index:
865item.content.append(AssistantMessageContent(text="", annotations=[]))
eee0a2f7Jiwon Kim7 months ago866
867match update:
868case AssistantMessageContentPartAdded():
811575b4Theophane Gregoir5 months ago869item.content[update.content_index] = update.content
eee0a2f7Jiwon Kim7 months ago870case AssistantMessageContentPartTextDelta():
811575b4Theophane Gregoir5 months ago871item.content[update.content_index].text += update.delta
eee0a2f7Jiwon Kim7 months ago872case AssistantMessageContentPartAnnotationAdded():
811575b4Theophane Gregoir5 months ago873annotations = item.content[update.content_index].annotations
eee0a2f7Jiwon Kim7 months ago874if update.annotation_index <= len(annotations):
875annotations.insert(update.annotation_index, update.annotation)
876else:
877annotations.append(update.annotation)
878case AssistantMessageContentPartDone():
811575b4Theophane Gregoir5 months ago879item.content[update.content_index] = update.content
880return item
eee0a2f7Jiwon Kim7 months ago881
5a7244baJiwon Kim7 months ago882def _update_pending_items(
eee0a2f7Jiwon Kim7 months ago883self,
884pending_items: dict[str, ThreadItem],
885event: ThreadItemUpdatedEvent,
886):
887updated_item = pending_items.get(event.item_id)
5a7244baJiwon Kim7 months ago888update = event.update
889match updated_item:
890case AssistantMessageItem():
891if isinstance(
892update,
893(
894AssistantMessageContentPartAdded,
895AssistantMessageContentPartTextDelta,
896AssistantMessageContentPartAnnotationAdded,
897AssistantMessageContentPartDone,
898),
899):
900pending_items[updated_item.id] = (
901self._apply_assistant_message_update(updated_item, update)
902)
903case WorkflowItem():
904if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
905match update:
906case WorkflowTaskUpdated():
907updated_item.workflow.tasks[update.task_index] = update.task
908case WorkflowTaskAdded():
909updated_item.workflow.tasks.append(update.task)
910
911pending_items[updated_item.id] = updated_item
912case _:
913pass
eee0a2f7Jiwon Kim7 months ago914
f688d870victor-openai8 months ago915async def _build_user_message_item(
916self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
917) -> UserMessageItem:
918return UserMessageItem(
919id=self.store.generate_item_id("message", thread, context),
920content=input.content,
921thread_id=thread.id,
922attachments=[
4a0b6794Jiwon Kim5 months ago923(await self.store.load_attachment(attachment_id, context)).model_copy(
924update={"thread_id": thread.id}
925)
f688d870victor-openai8 months ago926for attachment_id in input.attachments
927],
928quoted_text=input.quoted_text,
929inference_options=input.inference_options,
930created_at=datetime.now(),
931)
932
933async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
934thread_meta = await self.store.load_thread(thread_id, context=context)
935thread_items = await self.store.load_thread_items(
936thread_id,
937after=None,
938limit=DEFAULT_PAGE_SIZE,
939order="asc",
940context=context,
941)
942return Thread(**thread_meta.model_dump(), items=thread_items)
943
944async def _paginate_thread_items_reverse(
945self, thread_id: str, context: TContext
946) -> AsyncIterator[ThreadItem]:
947"""Paginate through thread items in reverse order (newest first)."""
948after = None
949while True:
950items = await self.store.load_thread_items(
951thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
952)
953for item in items.data:
954yield item
955
956if not items.has_more:
957break
958after = items.after
959
960def _serialize(self, obj: BaseModel) -> bytes:
b676003bJiwon Kim5 months ago961return obj.model_dump_json(
962by_alias=True,
963exclude_none=True,
964context={"exclude_metadata": True},
965).encode("utf-8")
f688d870victor-openai8 months ago966
967def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
968def is_hidden(item: ThreadItem) -> bool:
1860286bJiwon Kim7 months ago969return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
f688d870victor-openai8 months ago970
971items = thread.items if isinstance(thread, Thread) else Page()
972items.data = [item for item in items.data if not is_hidden(item)]
973
974return Thread(
975id=thread.id,
976title=thread.title,
977created_at=thread.created_at,
978items=items,
979status=thread.status,
10381acdJiwon Kim3 months ago980allowed_image_domains=thread.allowed_image_domains,
f688d870victor-openai8 months ago981)