openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
dacc133c280b39b9334d06ea73f0f1c199e59927

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

1063 lines · mode: code

1import asyncio
2import base64
3from abc import ABC, abstractmethod
4from collections.abc import AsyncIterator
5from contextlib import contextmanager
6from datetime import datetime
7from typing import (
8 Any,
9 AsyncGenerator,
10 AsyncIterable,
11 Callable,
12 Generic,
13)
14
15import agents
16from agents.models.chatcmpl_helpers import (
17 HEADERS_OVERRIDE as chat_completions_headers_override,
18)
19from agents.models.openai_responses import (
20 _HEADERS_OVERRIDE as responses_headers_override,
21)
22from pydantic import BaseModel, TypeAdapter
23from typing_extensions import TypeVar, assert_never
24
25from chatkit.errors import CustomStreamError, StreamError
26
27from .logger import logger
28from .store import AttachmentStore, Store, StoreItemType, default_generate_id
29from .types import (
30 Action,
31 AssistantMessageContent,
32 AssistantMessageContentPartAdded,
33 AssistantMessageContentPartAnnotationAdded,
34 AssistantMessageContentPartDone,
35 AssistantMessageContentPartTextDelta,
36 AssistantMessageItem,
37 AttachmentsCreateReq,
38 AttachmentsDeleteReq,
39 AudioInput,
40 ChatKitReq,
41 ClientToolCallItem,
42 ErrorCode,
43 ErrorEvent,
44 FeedbackKind,
45 HiddenContextItem,
46 InputTranscribeReq,
47 ItemsFeedbackReq,
48 ItemsListReq,
49 NonStreamingReq,
50 Page,
51 SDKHiddenContextItem,
52 StreamingReq,
53 StreamOptions,
54 StreamOptionsEvent,
55 StructuredInputAnswer,
56 StructuredInputItem,
57 StructuredInputMultipleChoice,
58 StructuredInputSubmission,
59 SyncCustomActionResponse,
60 Thread,
61 ThreadCreatedEvent,
62 ThreadItem,
63 ThreadItemAddedEvent,
64 ThreadItemDoneEvent,
65 ThreadItemRemovedEvent,
66 ThreadItemReplacedEvent,
67 ThreadItemUpdatedEvent,
68 ThreadMetadata,
69 ThreadsAddClientToolOutputReq,
70 ThreadsAddStructuredInputReq,
71 ThreadsAddUserMessageReq,
72 ThreadsCreateReq,
73 ThreadsCustomActionReq,
74 ThreadsDeleteReq,
75 ThreadsGetByIdReq,
76 ThreadsListReq,
77 ThreadsRetryAfterItemReq,
78 ThreadsSyncCustomActionReq,
79 ThreadStreamEvent,
80 ThreadsUpdateReq,
81 ThreadUpdatedEvent,
82 TranscriptionResult,
83 UserMessageInput,
84 UserMessageItem,
85 WidgetComponentUpdated,
86 WidgetItem,
87 WidgetRootUpdated,
88 WidgetStreamingTextValueDelta,
89 WorkflowItem,
90 WorkflowTaskAdded,
91 WorkflowTaskUpdated,
92 is_streaming_req,
93)
94from .version import __version__
95from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
96
# Page size used when a list request omits `limit`, and for the server's own
# thread-item loads (full-thread load, reverse pagination, pending tool-call cleanup).
DEFAULT_PAGE_SIZE = 20
# Generic error text. NOTE(review): not referenced in this file — confirm where it is used.
DEFAULT_ERROR_MESSAGE = "An error occurred when generating a response."
99
100
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """Compute the update events that turn the ``before`` widget into ``after``.

    If anything changed other than text being appended to the ``value`` of a
    Markdown/Text component, the result is a single
    ``WidgetRootUpdated(widget=after)``. Examples include a differing
    ``type``/``id``/``key``, a list of different length, or any other field
    change. Otherwise the result holds one ``WidgetStreamingTextValueDelta``
    per Markdown/Text component with an ``id`` whose value grew. The list may
    be empty.

    Raises:
        ValueError: if a Markdown/Text node with an id in ``after`` has no
            counterpart in ``before``, or its new value does not start with
            the old one.
    """
    text_component_types = {"Markdown", "Text"}

    def streams_text(node: WidgetComponentBase) -> bool:
        # Only Markdown/Text components holding a string value can stream appends.
        if getattr(node, "type", None) not in text_component_types:
            return False
        return isinstance(getattr(node, "value", None), str)

    def values_require_replace(old: Any, new: Any) -> bool:
        # Lists are compared element-wise; nested components recurse into
        # needs_full_replace so text appends deep in the tree stay incremental.
        if isinstance(old, list) and isinstance(new, list):
            if len(old) != len(new):
                return True
            return any(
                values_require_replace(old_entry, new_entry)
                for old_entry, new_entry in zip(old, new)
            )
        if old != new:
            if isinstance(old, WidgetComponentBase) and isinstance(
                new, WidgetComponentBase
            ):
                return needs_full_replace(old, new)
            return True
        return False

    def needs_full_replace(
        old_node: WidgetComponentBase, new_node: WidgetComponentBase
    ) -> bool:
        for identity_attr in ("type", "id", "key"):
            if getattr(old_node, identity_attr) != getattr(new_node, identity_attr):
                return True

        both_stream_text = streams_text(old_node) and streams_text(new_node)
        for field_name in old_node.model_fields_set.union(new_node.model_fields_set):
            is_text_append = (
                both_stream_text
                and field_name == "value"
                and getattr(new_node, "value", "").startswith(
                    getattr(old_node, "value", "")
                )
            )
            if is_text_append:
                # Appends to a Markdown/Text value are sent as deltas instead.
                continue
            if values_require_replace(
                getattr(old_node, field_name), getattr(new_node, field_name)
            ):
                return True
        return False

    if needs_full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    def collect_text_nodes(
        root: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        # Pre-order walk over `children`, keyed by component id.
        found: dict[str, WidgetComponentBase] = {}
        to_visit = [root]
        while to_visit:
            node = to_visit.pop()
            if streams_text(node) and node.id:
                found[node.id] = node
            kids = getattr(node, "children", None) or []
            # Push in reverse so children are visited left-to-right.
            to_visit.extend(reversed(list(kids)))
        return found

    old_nodes = collect_text_nodes(before)
    new_nodes = collect_text_nodes(after)

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []
    for component_id, new_node in new_nodes.items():
        old_node = old_nodes.get(component_id)
        if old_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        old_text = str(getattr(old_node, "value", None))
        new_text = str(getattr(new_node, "value", None))
        if old_text == new_text:
            continue
        if not new_text.startswith(old_text):
            raise ValueError(
                f"Node {component_id} was updated with a new value that is not a prefix of the initial value. All widget updates must be cumulative."
            )

        deltas.append(
            WidgetStreamingTextValueDelta(
                component_id=component_id,
                delta=new_text[len(old_text) :],
                done=not getattr(new_node, "streaming", False),
            )
        )

    return deltas
203
204
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget root (or async sequence of roots) as ThreadStreamEvents.

    A plain widget root produces a single ``ThreadItemDoneEvent``.

    An async generator of successive roots produces:
    - a ``ThreadItemAddedEvent`` for its first root;
    - a ``ThreadItemUpdatedEvent`` for every delta ``diff_widget`` computes
      between consecutive roots;
    - a final ``ThreadItemDoneEvent`` carrying the last root.

    Args:
        thread: Thread the widget item belongs to.
        widget: A widget root, or an async generator yielding roots.
        copy_text: Optional copy text stored on the WidgetItem.
        generate_id: Id factory; called once with ``"message"``.

    Raises:
        ValueError: if the generator yields no roots, or if ``diff_widget``
            rejects an update.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        # A single, fully rendered widget: emit it as one finished item.
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator turns it into an
        # opaque RuntimeError; report the real problem instead.
        raise ValueError(
            "stream_widget() received a widget generator that yielded no widgets"
        ) from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
255
256
@contextmanager
def agents_sdk_user_agent_override():
    """Set a ChatKit User-Agent header override for Agents SDK model calls.

    Sets the header-override context variables used by the Agents SDK Chat
    Completions and Responses model wrappers for the duration of the block.
    They are restored on every exit path: normal completion, exceptions,
    cancellation, or the enclosing generator being closed. Previously they
    were only restored on normal completion, so an error left the override in
    place for the rest of the task.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        for header_var, token in (
            (chat_completions_headers_override, chat_completions_token),
            (responses_headers_override, responses_token),
        ):
            try:
                header_var.reset(token)
            except ValueError:
                # The token was created in a different Context, e.g. when an
                # abandoned async generator is finalized by the event loop in a
                # fresh task. There is nothing to restore in this Context.
                pass
267
268
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable result for streaming requests.

    Holds an async generator of byte chunks and re-yields each chunk, in order
    and unchanged, when iterated.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # The wrapped generator; the attribute name is part of the public surface.
        self.json_events = stream

    async def __aiter__(self):
        """Yield every chunk produced by the wrapped generator."""
        source = self.json_events
        async for chunk in source:
            yield chunk
276
277
class NonStreamingResult:
    """Result for non-streaming requests: a single serialized JSON payload."""

    def __init__(self, result: bytes):
        # Raw JSON bytes, exposed under the attribute name `json`.
        payload = result
        self.json = payload
281
282
# Type of the per-request `context` value that ChatKitServer passes to its
# Store/AttachmentStore calls and overridable hooks; defaults to Any when unparameterized.
TContext = TypeVar("TContext", default=Any)
284
285
286class ChatKitServer(ABC, Generic[TContext]):
287 def __init__(
288 self,
289 store: Store[TContext],
290 attachment_store: AttachmentStore[TContext] | None = None,
291 ):
292 """Create a ChatKitServer with the backing Store and optional AttachmentStore."""
293 self.store = store
294 self.attachment_store = attachment_store
295
296 def _get_attachment_store(self) -> AttachmentStore[TContext]:
297 """Return the configured AttachmentStore or raise if missing."""
298 if self.attachment_store is None:
299 raise RuntimeError(
300 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
301 )
302 return self.attachment_store
303
304 @abstractmethod
305 def respond(
306 self,
307 thread: ThreadMetadata,
308 input_user_message: UserMessageItem | None,
309 context: TContext,
310 ) -> AsyncIterator[ThreadStreamEvent]:
311 """Stream `ThreadStreamEvent` instances for a new user message.
312
313 Args:
314 thread: Metadata for the thread being processed.
315 input_user_message: The incoming message the server should respond to, if any.
316 context: Arbitrary per-request context provided by the caller.
317
318 Returns:
319 An async iterator that yields events representing the server's response.
320 """
321 pass
322
323 async def add_feedback( # noqa: B027
324 self,
325 thread_id: str,
326 item_ids: list[str],
327 feedback: FeedbackKind,
328 context: TContext,
329 ) -> None:
330 """Persist user feedback for one or more thread items."""
331 pass
332
333 async def transcribe( # noqa: B027
334 self, audio_input: AudioInput, context: TContext
335 ) -> TranscriptionResult:
336 """Transcribe speech audio to text (dictation).
337
338 Audio bytes are in `audio_input.data`. The raw MIME type is `audio_input.mime_type`; use
339 `audio_input.media_type` for the base media type (one of: `audio/webm`, `audio/ogg`, `audio/mp4`).
340
341 Args:
342 audio_input: Audio bytes plus MIME type metadata for transcription.
343 context: Arbitrary per-request context provided by the caller.
344
345 Returns:
346 A `TranscriptionResult` whose `text` is what should appear in the composer.
347 """
348 raise NotImplementedError(
349 "transcribe() must be overridden to support the input.transcribe request."
350 )
351
352 def action(
353 self,
354 thread: ThreadMetadata,
355 action: Action[str, Any],
356 sender: WidgetItem | None,
357 context: TContext,
358 ) -> AsyncIterator[ThreadStreamEvent]:
359 """Handle a widget or client-dispatched action and yield response events."""
360 raise NotImplementedError(
361 "The action() method must be overridden to react to actions. "
362 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
363 )
364
365 async def sync_action(
366 self,
367 thread: ThreadMetadata,
368 action: Action[str, Any],
369 sender: WidgetItem | None,
370 context: TContext,
371 ) -> SyncCustomActionResponse:
372 """Handle a widget or client-dispatched action and return a SyncCustomActionResponse."""
373 raise NotImplementedError(
374 "The sync_action() method must be overridden to react to sync actions. "
375 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
376 )
377
378 def get_stream_options(
379 self, thread: ThreadMetadata, context: TContext
380 ) -> StreamOptions:
381 """
382 Return stream-level runtime options. Allows the user to cancel the stream by default.
383 Override this method to customize behavior.
384 """
385 return StreamOptions(allow_cancel=True)
386
387 async def handle_stream_cancelled(
388 self,
389 thread: ThreadMetadata,
390 pending_items: list[ThreadItem],
391 context: TContext,
392 ):
393 """Perform custom cleanup / stop inference when a stream is cancelled.
394 Updates you make here will not be reflected in the UI until a reload.
395
396 The default implementation persists any non-empty pending assistant messages
397 to the thread but does not auto-save pending widget items or workflow items.
398
399 Args:
400 thread: The thread that was being processed.
401 pending_items: Items that were not done streaming at cancellation time.
402 context: Arbitrary per-request context provided by the caller.
403 """
404 pending_assistant_message_items: list[AssistantMessageItem] = [
405 item for item in pending_items if isinstance(item, AssistantMessageItem)
406 ]
407 for item in pending_assistant_message_items:
408 is_empty = len(item.content) == 0 or all(
409 (not content.text.strip()) for content in item.content
410 )
411 if not is_empty:
412 await self.store.add_thread_item(thread.id, item, context=context)
413
414 # Add a hidden context item to the thread to indicate that the stream was cancelled.
415 # Otherwise, depending on the timing of the cancellation, subsequent responses may
416 # attempt to continue the cancelled response.
417 await self.store.add_thread_item(
418 thread.id,
419 SDKHiddenContextItem(
420 thread_id=thread.id,
421 created_at=datetime.now(),
422 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
423 content="The user cancelled the stream. Stop responding to the prior request.",
424 ),
425 context=context,
426 )
427
428 async def process(
429 self, request: str | bytes | bytearray, context: TContext
430 ) -> StreamingResult | NonStreamingResult:
431 """Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
432 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
433 logger.info(f"Received request op: {parsed_request.type}")
434
435 if is_streaming_req(parsed_request):
436 return StreamingResult(self._process_streaming(parsed_request, context))
437 else:
438 return NonStreamingResult(
439 await self._process_non_streaming(parsed_request, context)
440 )
441
442 async def _process_non_streaming(
443 self, request: NonStreamingReq, context: TContext
444 ) -> bytes:
445 match request:
446 case ThreadsGetByIdReq():
447 thread = await self._load_full_thread(
448 request.params.thread_id, context=context
449 )
450 return self._serialize(self._to_thread_response(thread))
451 case ThreadsListReq():
452 params = request.params
453 threads = await self.store.load_threads(
454 limit=params.limit or DEFAULT_PAGE_SIZE,
455 after=params.after,
456 order=params.order,
457 context=context,
458 )
459 return self._serialize(
460 Page(
461 has_more=threads.has_more,
462 after=threads.after,
463 data=[
464 self._to_thread_response(thread) for thread in threads.data
465 ],
466 )
467 )
468 case ItemsFeedbackReq():
469 await self.add_feedback(
470 request.params.thread_id,
471 request.params.item_ids,
472 request.params.kind,
473 context,
474 )
475 return b"{}"
476 case AttachmentsCreateReq():
477 attachment_store = self._get_attachment_store()
478 attachment = await attachment_store.create_attachment(
479 request.params, context
480 )
481 await self.store.save_attachment(attachment, context=context)
482 return self._serialize(attachment)
483 case AttachmentsDeleteReq():
484 attachment_store = self._get_attachment_store()
485 await attachment_store.delete_attachment(
486 request.params.attachment_id, context=context
487 )
488 await self.store.delete_attachment(
489 request.params.attachment_id, context=context
490 )
491 return b"{}"
492 case InputTranscribeReq():
493 audio_bytes = base64.b64decode(request.params.audio_base64)
494 transcription_result = await self.transcribe(
495 AudioInput(data=audio_bytes, mime_type=request.params.mime_type),
496 context=context,
497 )
498 return self._serialize(transcription_result)
499 case ItemsListReq():
500 items_list_params = request.params
501 items = await self.store.load_thread_items(
502 items_list_params.thread_id,
503 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
504 order=items_list_params.order,
505 after=items_list_params.after,
506 context=context,
507 )
508 # filter out hidden context items
509 items.data = [
510 item
511 for item in items.data
512 if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
513 ]
514 return self._serialize(items)
515 case ThreadsUpdateReq():
516 thread = await self.store.load_thread(
517 request.params.thread_id, context=context
518 )
519 thread.title = request.params.title
520 await self.store.save_thread(thread, context=context)
521 return self._serialize(self._to_thread_response(thread))
522 case ThreadsDeleteReq():
523 await self.store.delete_thread(
524 request.params.thread_id, context=context
525 )
526 return b"{}"
527 case ThreadsSyncCustomActionReq():
528 return await self._process_sync_custom_action(request, context)
529 case _:
530 assert_never(request)
531
532 async def _process_streaming(
533 self, request: StreamingReq, context: TContext
534 ) -> AsyncGenerator[bytes, None]:
535 try:
536 async for event in self._process_streaming_impl(request, context):
537 b = self._serialize(event)
538 yield b"data: " + b + b"\n\n"
539 except asyncio.CancelledError:
540 # Let cancellation bubble up without logging as an error.
541 raise
542 except Exception:
543 logger.exception("Error while generating streamed response")
544 raise
545
546 async def _process_streaming_impl(
547 self, request: StreamingReq, context: TContext
548 ) -> AsyncGenerator[ThreadStreamEvent, None]:
549 match request:
550 case ThreadsCreateReq():
551 thread = Thread(
552 id=self.store.generate_thread_id(context),
553 created_at=datetime.now(),
554 items=Page(),
555 )
556 await self.store.save_thread(
557 ThreadMetadata(**thread.model_dump()),
558 context=context,
559 )
560 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
561 user_message = await self._build_user_message_item(
562 request.params.input, thread, context
563 )
564 async for event in self._process_new_thread_item_respond(
565 thread,
566 user_message,
567 context,
568 ):
569 yield event
570
571 case ThreadsAddUserMessageReq():
572 thread = await self.store.load_thread(
573 request.params.thread_id, context=context
574 )
575 user_message = await self._build_user_message_item(
576 request.params.input, thread, context
577 )
578 async for event in self._process_new_thread_item_respond(
579 thread,
580 user_message,
581 context,
582 ):
583 yield event
584
585 case ThreadsAddClientToolOutputReq():
586 thread = await self.store.load_thread(
587 request.params.thread_id, context=context
588 )
589 items = await self.store.load_thread_items(
590 thread.id, None, 1, "desc", context
591 )
592 tool_call = next(
593 (
594 item
595 for item in items.data
596 if isinstance(item, ClientToolCallItem)
597 and item.status == "pending"
598 ),
599 None,
600 )
601 if not tool_call:
602 raise ValueError(
603 f"Last thread item in {thread.id} was not a ClientToolCallItem"
604 )
605
606 tool_call.output = request.params.result
607 tool_call.status = "completed"
608
609 await self.store.save_item(thread.id, tool_call, context=context)
610
611 # Safety against dangling pending tool calls if there are
612 # multiple in a row, which should be impossible, and
613 # integrations should ultimately filter out pending tool calls
614 # when creating input response messages.
615 await self._cleanup_pending_client_tool_call(thread, context)
616
617 async for event in self._process_events(
618 thread,
619 context,
620 lambda: self.respond(thread, None, context),
621 ):
622 yield event
623
624 case ThreadsAddStructuredInputReq():
625 thread = await self.store.load_thread(
626 request.params.thread_id, context=context
627 )
628 item = await self.store.load_item(
629 request.params.thread_id,
630 request.params.item_id,
631 context=context,
632 )
633 if not isinstance(item, StructuredInputItem):
634 raise ValueError(
635 f"Item {request.params.item_id} is not a StructuredInputItem"
636 )
637
638 updated_item = await self._apply_structured_input_submission(
639 item,
640 request.params.input,
641 )
642
643 async def stream_structured_input_response() -> AsyncIterator[
644 ThreadStreamEvent
645 ]:
646 # Keep this replacement inside _process_events so it is persisted
647 # through the same event pipeline as other thread item replacements.
648 yield ThreadItemReplacedEvent(item=updated_item)
649 async for event in self.respond(thread, None, context):
650 yield event
651
652 async for event in self._process_events(
653 thread,
654 context,
655 stream_structured_input_response,
656 ):
657 yield event
658
659 case ThreadsRetryAfterItemReq():
660 thread_metadata = await self.store.load_thread(
661 request.params.thread_id, context=context
662 )
663
664 # Collect items to remove (all items after the user message)
665 items_to_remove: list[ThreadItem] = []
666 user_message_item = None
667
668 async for item in self._paginate_thread_items_reverse(
669 request.params.thread_id, context
670 ):
671 if item.id == request.params.item_id:
672 if not isinstance(item, UserMessageItem):
673 raise ValueError(
674 f"Item {request.params.item_id} is not a user message"
675 )
676 user_message_item = item
677 break
678 items_to_remove.append(item)
679
680 if user_message_item:
681 for item in items_to_remove:
682 await self.store.delete_thread_item(
683 request.params.thread_id, item.id, context=context
684 )
685 async for event in self._process_events(
686 thread_metadata,
687 context,
688 lambda: self.respond(
689 thread_metadata,
690 user_message_item,
691 context,
692 ),
693 ):
694 yield event
695 case ThreadsCustomActionReq():
696 thread_metadata = await self.store.load_thread(
697 request.params.thread_id, context=context
698 )
699
700 item: ThreadItem | None = None
701 if request.params.item_id:
702 item = await self.store.load_item(
703 request.params.thread_id,
704 request.params.item_id,
705 context=context,
706 )
707
708 if item and not isinstance(item, WidgetItem):
709 # shouldn't happen if the caller is using the API correctly.
710 yield ErrorEvent(
711 code=ErrorCode.STREAM_ERROR,
712 allow_retry=False,
713 )
714 return
715
716 async for event in self._process_events(
717 thread_metadata,
718 context,
719 lambda: self.action(
720 thread_metadata,
721 request.params.action,
722 item,
723 context,
724 ),
725 ):
726 yield event
727
728 case _:
729 assert_never(request)
730
731 async def _process_sync_custom_action(
732 self, request: ThreadsSyncCustomActionReq, context: TContext
733 ) -> bytes:
734 thread_metadata = await self.store.load_thread(
735 request.params.thread_id, context=context
736 )
737
738 item: ThreadItem | None = None
739 if request.params.item_id:
740 item = await self.store.load_item(
741 request.params.thread_id,
742 request.params.item_id,
743 context=context,
744 )
745
746 if item and not isinstance(item, WidgetItem):
747 raise ValueError("threads.sync_custom_action requires a widget sender item")
748
749 return self._serialize(
750 await self.sync_action(
751 thread_metadata,
752 request.params.action,
753 item,
754 context,
755 )
756 )
757
758 async def _apply_structured_input_submission(
759 self,
760 item: StructuredInputItem,
761 submission: StructuredInputSubmission,
762 ) -> StructuredInputItem:
763 if item.status != "pending":
764 raise ValueError(f"Structured input item {item.id} is not pending")
765
766 updated_item = item.model_copy(deep=True)
767 questions_by_id = {
768 question.id: question for question in updated_item.inputs
769 }
770 submitted_question_ids = set(submission.answers)
771 unknown_question_ids = submitted_question_ids.difference(questions_by_id)
772 if unknown_question_ids:
773 unknown = ", ".join(sorted(unknown_question_ids))
774 logger.warning(
775 f"Structured input item {item.id} received unknown question id(s), ignoring: {unknown}"
776 )
777
778 for question in updated_item.inputs:
779 answer = submission.answers.get(question.id)
780
781 if (
782 submission.status == "skipped" or
783 answer is None or
784 answer.skipped or
785 not answer.values
786 ):
787 question.answer = StructuredInputAnswer(skipped=True)
788 continue
789
790 values = answer.values
791 if isinstance(question, StructuredInputMultipleChoice):
792 if not question.multiple:
793 values = values[:1]
794
795 question.answer = StructuredInputAnswer(values=values)
796
797 updated_item.status = submission.status
798 return updated_item
799
800 async def _cleanup_pending_client_tool_call(
801 self, thread: ThreadMetadata, context: TContext
802 ) -> None:
803 items = await self.store.load_thread_items(
804 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
805 )
806 for tool_call in items.data:
807 if not isinstance(tool_call, ClientToolCallItem):
808 continue
809 if tool_call.status == "pending":
810 logger.warning(
811 f"Client tool call {tool_call.call_id} was not completed, ignoring"
812 )
813 await self.store.delete_thread_item(
814 thread.id, tool_call.id, context=context
815 )
816
817 async def _process_new_thread_item_respond(
818 self,
819 thread: ThreadMetadata,
820 item: UserMessageItem,
821 context: TContext,
822 ) -> AsyncIterator[ThreadStreamEvent]:
823 # Store the updated attachments (now that they have a thread id).
824 for attachment in item.attachments:
825 await self.store.save_attachment(attachment, context=context)
826
827 await self.store.add_thread_item(thread.id, item, context=context)
828 yield ThreadItemDoneEvent(item=item)
829
830 async for event in self._process_events(
831 thread,
832 context,
833 lambda: self.respond(thread, item, context),
834 ):
835 yield event
836
837 async def _process_events(
838 self,
839 thread: ThreadMetadata,
840 context: TContext,
841 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
842 ) -> AsyncIterator[ThreadStreamEvent]:
843 await asyncio.sleep(0) # allow the response to start streaming
844
845 # Send initial stream options
846 yield StreamOptionsEvent(
847 stream_options=self.get_stream_options(thread, context)
848 )
849
850 last_thread = thread.model_copy(deep=True)
851
852 # Keep track of items that were streamed but not yet saved
853 # so that we can persist them when the stream is cancelled.
854 pending_items: dict[str, ThreadItem] = {}
855
856 try:
857 with agents_sdk_user_agent_override():
858 async for event in stream():
859 if isinstance(event, ThreadItemAddedEvent):
860 # Stash an isolated copy in case we need to persist unfinished items
861 # on cancellation; downstream handlers keep using the original event.item.
862 pending_items[event.item.id] = event.item.model_copy(deep=True)
863
864 match event:
865 case ThreadItemDoneEvent():
866 await self.store.add_thread_item(
867 thread.id, event.item, context=context
868 )
869 pending_items.pop(event.item.id, None)
870 case ThreadItemRemovedEvent():
871 await self.store.delete_thread_item(
872 thread.id, event.item_id, context=context
873 )
874 pending_items.pop(event.item_id, None)
875 case ThreadItemReplacedEvent():
876 await self.store.save_item(
877 thread.id, event.item, context=context
878 )
879 pending_items.pop(event.item.id, None)
880 case ThreadItemUpdatedEvent():
881 # Keep pending assistant message and workflow items up to date
882 # so that we have a reference to the latest version of these pending items
883 # when the stream is cancelled.
884 self._update_pending_items(pending_items, event)
885
886 # special case - don't send hidden context items back to the client
887 should_swallow_event = isinstance(
888 event, ThreadItemDoneEvent
889 ) and isinstance(
890 event.item, (HiddenContextItem, SDKHiddenContextItem)
891 )
892
893 if not should_swallow_event:
894 yield event
895
896 # in case user updated the thread while streaming
897 if thread != last_thread:
898 last_thread = thread.model_copy(deep=True)
899 await self.store.save_thread(thread, context=context)
900 yield ThreadUpdatedEvent(
901 thread=self._to_thread_response(thread)
902 )
903 # in case user updated the thread while streaming
904 if thread != last_thread:
905 last_thread = thread.model_copy(deep=True)
906 await self.store.save_thread(thread, context=context)
907 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
908 except asyncio.CancelledError:
909 await self.handle_stream_cancelled(
910 thread, list(pending_items.values()), context
911 )
912 raise
913 except CustomStreamError as e:
914 yield ErrorEvent(
915 code="custom",
916 message=e.message,
917 allow_retry=e.allow_retry,
918 )
919 except StreamError as e:
920 yield ErrorEvent(
921 code=e.code,
922 allow_retry=e.allow_retry,
923 )
924 except Exception as e:
925 yield ErrorEvent(
926 code=ErrorCode.STREAM_ERROR,
927 allow_retry=True,
928 )
929 logger.exception(e)
930
931 if thread != last_thread:
932 # in case user updated the thread at the end of the stream
933 await self.store.save_thread(thread, context=context)
934 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
935
936 def _apply_assistant_message_update(
937 self,
938 item: AssistantMessageItem,
939 update: AssistantMessageContentPartAdded
940 | AssistantMessageContentPartTextDelta
941 | AssistantMessageContentPartAnnotationAdded
942 | AssistantMessageContentPartDone,
943 ) -> AssistantMessageItem:
944 # Pad the content list so the requested content_index exists before we write into it.
945 # (Streaming updates can arrive for an index that hasn’t been created yet)
946 while len(item.content) <= update.content_index:
947 item.content.append(AssistantMessageContent(text="", annotations=[]))
948
949 match update:
950 case AssistantMessageContentPartAdded():
951 item.content[update.content_index] = update.content
952 case AssistantMessageContentPartTextDelta():
953 item.content[update.content_index].text += update.delta
954 case AssistantMessageContentPartAnnotationAdded():
955 annotations = item.content[update.content_index].annotations
956 if update.annotation_index <= len(annotations):
957 annotations.insert(update.annotation_index, update.annotation)
958 else:
959 annotations.append(update.annotation)
960 case AssistantMessageContentPartDone():
961 item.content[update.content_index] = update.content
962 return item
963
964 def _update_pending_items(
965 self,
966 pending_items: dict[str, ThreadItem],
967 event: ThreadItemUpdatedEvent,
968 ):
969 updated_item = pending_items.get(event.item_id)
970 update = event.update
971 match updated_item:
972 case AssistantMessageItem():
973 if isinstance(
974 update,
975 (
976 AssistantMessageContentPartAdded,
977 AssistantMessageContentPartTextDelta,
978 AssistantMessageContentPartAnnotationAdded,
979 AssistantMessageContentPartDone,
980 ),
981 ):
982 pending_items[updated_item.id] = (
983 self._apply_assistant_message_update(updated_item, update)
984 )
985 case WorkflowItem():
986 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
987 match update:
988 case WorkflowTaskUpdated():
989 updated_item.workflow.tasks[update.task_index] = update.task
990 case WorkflowTaskAdded():
991 updated_item.workflow.tasks.append(update.task)
992
993 pending_items[updated_item.id] = updated_item
994 case _:
995 pass
996
997 async def _build_user_message_item(
998 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
999 ) -> UserMessageItem:
1000 return UserMessageItem(
1001 id=self.store.generate_item_id("message", thread, context),
1002 content=input.content,
1003 thread_id=thread.id,
1004 attachments=[
1005 (await self.store.load_attachment(attachment_id, context)).model_copy(
1006 update={"thread_id": thread.id}
1007 )
1008 for attachment_id in input.attachments
1009 ],
1010 quoted_text=input.quoted_text,
1011 inference_options=input.inference_options,
1012 created_at=datetime.now(),
1013 )
1014
1015 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
1016 thread_meta = await self.store.load_thread(thread_id, context=context)
1017 thread_items = await self.store.load_thread_items(
1018 thread_id,
1019 after=None,
1020 limit=DEFAULT_PAGE_SIZE,
1021 order="asc",
1022 context=context,
1023 )
1024 return Thread(**thread_meta.model_dump(), items=thread_items)
1025
1026 async def _paginate_thread_items_reverse(
1027 self, thread_id: str, context: TContext
1028 ) -> AsyncIterator[ThreadItem]:
1029 """Paginate through thread items in reverse order (newest first)."""
1030 after = None
1031 while True:
1032 items = await self.store.load_thread_items(
1033 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
1034 )
1035 for item in items.data:
1036 yield item
1037
1038 if not items.has_more:
1039 break
1040 after = items.after
1041
1042 def _serialize(self, obj: BaseModel) -> bytes:
1043 return obj.model_dump_json(
1044 by_alias=True,
1045 exclude_none=True,
1046 context={"exclude_metadata": True},
1047 ).encode("utf-8")
1048
1049 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
1050 def is_hidden(item: ThreadItem) -> bool:
1051 return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
1052
1053 items = thread.items if isinstance(thread, Thread) else Page()
1054 items.data = [item for item in items.data if not is_hidden(item)]
1055
1056 return Thread(
1057 id=thread.id,
1058 title=thread.title,
1059 created_at=thread.created_at,
1060 items=items,
1061 status=thread.status,
1062 allowed_image_domains=thread.allowed_image_domains,
1063 )
1064