openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.6.2

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

980lines · modecode

1import asyncio
2import base64
3from abc import ABC, abstractmethod
4from collections.abc import AsyncIterator
5from contextlib import contextmanager
6from datetime import datetime
7from typing import (
8 Any,
9 AsyncGenerator,
10 AsyncIterable,
11 Callable,
12 Generic,
13)
14
15import agents
16from agents.models.chatcmpl_helpers import (
17 HEADERS_OVERRIDE as chat_completions_headers_override,
18)
19from agents.models.openai_responses import (
20 _HEADERS_OVERRIDE as responses_headers_override,
21)
22from pydantic import BaseModel, TypeAdapter
23from typing_extensions import TypeVar, assert_never
24
25from chatkit.errors import CustomStreamError, StreamError
26
27from .logger import logger
28from .store import AttachmentStore, Store, StoreItemType, default_generate_id
29from .types import (
30 Action,
31 AssistantMessageContent,
32 AssistantMessageContentPartAdded,
33 AssistantMessageContentPartAnnotationAdded,
34 AssistantMessageContentPartDone,
35 AssistantMessageContentPartTextDelta,
36 AssistantMessageItem,
37 AttachmentsCreateReq,
38 AttachmentsDeleteReq,
39 AudioInput,
40 ChatKitReq,
41 ClientToolCallItem,
42 ErrorCode,
43 ErrorEvent,
44 FeedbackKind,
45 HiddenContextItem,
46 InputTranscribeReq,
47 ItemsFeedbackReq,
48 ItemsListReq,
49 NonStreamingReq,
50 Page,
51 SDKHiddenContextItem,
52 StreamingReq,
53 StreamOptions,
54 StreamOptionsEvent,
55 SyncCustomActionResponse,
56 Thread,
57 ThreadCreatedEvent,
58 ThreadItem,
59 ThreadItemAddedEvent,
60 ThreadItemDoneEvent,
61 ThreadItemRemovedEvent,
62 ThreadItemReplacedEvent,
63 ThreadItemUpdatedEvent,
64 ThreadMetadata,
65 ThreadsAddClientToolOutputReq,
66 ThreadsAddUserMessageReq,
67 ThreadsCreateReq,
68 ThreadsCustomActionReq,
69 ThreadsDeleteReq,
70 ThreadsGetByIdReq,
71 ThreadsListReq,
72 ThreadsRetryAfterItemReq,
73 ThreadsSyncCustomActionReq,
74 ThreadStreamEvent,
75 ThreadsUpdateReq,
76 ThreadUpdatedEvent,
77 TranscriptionResult,
78 UserMessageInput,
79 UserMessageItem,
80 WidgetComponentUpdated,
81 WidgetItem,
82 WidgetRootUpdated,
83 WidgetStreamingTextValueDelta,
84 WorkflowItem,
85 WorkflowTaskAdded,
86 WorkflowTaskUpdated,
87 is_streaming_req,
88)
89from .version import __version__
90from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
91
# Page size used for thread and thread-item listings when the request does not
# specify a limit (see the `... or DEFAULT_PAGE_SIZE` call sites below).
DEFAULT_PAGE_SIZE = 20
# Generic error text. NOTE(review): not referenced anywhere in the visible part
# of this module — presumably consumed by other modules; confirm before removing.
DEFAULT_ERROR_MESSAGE = "An error occurred when generating a response."
94
95
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Work out which update events turn the `before` widget tree into `after`.

    If anything changed other than text being appended to the `value` of a
    Markdown/Text component, the result is a single `WidgetRootUpdated`
    carrying `after`. Otherwise the result holds one
    `WidgetStreamingTextValueDelta` for every Markdown/Text component that has
    an id and whose value changed (an empty list when nothing changed).

    Raises:
        ValueError: if a Markdown/Text component with an id is present in
            `after` but not in `before`, or if its new value does not start
            with its previous value.
    """

    def _streams_text(node: WidgetComponentBase) -> bool:
        # Only Markdown/Text components holding a string value can stream.
        if getattr(node, "type", None) not in {"Markdown", "Text"}:
            return False
        return isinstance(getattr(node, "value", None), str)

    def _component_changed(old: WidgetComponentBase, new: WidgetComponentBase) -> bool:
        # A different type/id/key means this is not the same component any more.
        if any(
            getattr(old, attr) != getattr(new, attr) for attr in ("type", "id", "key")
        ):
            return True

        def _value_changed(old_value: Any, new_value: Any) -> bool:
            both_lists = isinstance(old_value, list) and isinstance(new_value, list)
            if both_lists:
                if len(old_value) != len(new_value):
                    return True
                return any(
                    _value_changed(old_entry, new_entry)
                    for old_entry, new_entry in zip(old_value, new_value)
                )
            if old_value != new_value:
                both_components = isinstance(
                    old_value, WidgetComponentBase
                ) and isinstance(new_value, WidgetComponentBase)
                # Differing nested components are compared structurally;
                # any other differing value forces a full replace.
                return _component_changed(old_value, new_value) if both_components else True
            return False

        for name in old.model_fields_set.union(new.model_fields_set):
            appended_text = (
                name == "value"
                and _streams_text(old)
                and _streams_text(new)
                and getattr(new, "value", "").startswith(getattr(old, "value", ""))
            )
            if appended_text:
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if _value_changed(getattr(old, name), getattr(new, name)):
                return True

        return False

    if _component_changed(before, after):
        return [WidgetRootUpdated(widget=after)]

    def _walk(node: WidgetComponent | WidgetRoot):
        # Pre-order traversal: the node itself, then each child subtree in order.
        yield node
        for child in getattr(node, "children", None) or []:
            yield from _walk(child)

    def _streaming_text_by_id(
        root: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        return {
            node.id: node for node in _walk(root) if _streams_text(node) and node.id
        }

    old_nodes = _streaming_text_by_id(before)
    new_nodes = _streaming_text_by_id(after)

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    for node_id, new_node in new_nodes.items():
        old_node = old_nodes.get(node_id)
        if old_node is None:
            raise ValueError(
                f"Node {node_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        old_text = str(getattr(old_node, "value", None))
        new_text = str(getattr(new_node, "value", None))
        if old_text == new_text:
            continue

        if not new_text.startswith(old_text):
            raise ValueError(
                f"Node {node_id} was updated with a new value that is not a prefix of the initial value. All widget updates must be cumulative."
            )

        deltas.append(
            WidgetStreamingTextValueDelta(
                component_id=node_id,
                delta=new_text[len(old_text) :],
                # The delta is final once the component no longer reports `streaming`.
                done=not getattr(new_node, "streaming", False),
            )
        )

    return deltas
198
199
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Emit the thread events needed to show a widget.

    A plain `WidgetRoot` produces a single `ThreadItemDoneEvent`. An async
    generator of roots produces a `ThreadItemAddedEvent` for the first state,
    a `ThreadItemUpdatedEvent` for every delta `diff_widget` reports between
    consecutive states, and a final `ThreadItemDoneEvent` with the last state.
    """
    widget_item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        # Static widget: nothing to stream, so it is done immediately.
        static_item = WidgetItem(
            id=widget_item_id,
            thread_id=thread.id,
            created_at=datetime.now(),
            widget=widget,
            copy_text=copy_text,
        )
        yield ThreadItemDoneEvent(item=static_item)
        return

    # The first state the generator produces is what the client renders initially.
    current_state = await widget.__anext__()

    streamed_item = WidgetItem(
        id=widget_item_id,
        created_at=datetime.now(),
        widget=current_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )
    yield ThreadItemAddedEvent(item=streamed_item)

    # Each further state is sent as deltas relative to the previous one.
    async for next_state in widget:
        for change in diff_widget(current_state, next_state):
            yield ThreadItemUpdatedEvent(item_id=widget_item_id, update=change)
        current_state = next_state

    finished_item = streamed_item.model_copy(update={"widget": current_state})
    yield ThreadItemDoneEvent(item=finished_item)
250
251
@contextmanager
def agents_sdk_user_agent_override():
    """Temporarily set the Agents SDK header overrides to a ChatKit User-Agent.

    Both override variables (chat completions and responses) are set to a
    `User-Agent` naming the Agents SDK version and this package's version for
    the duration of the `with` body, then reset to their previous values.

    The resets run in a `finally` so they also happen when the body raises or
    is cancelled; without that, an exception thrown into the generator at the
    `yield` would leave the overrides in place.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        chat_completions_headers_override.reset(chat_completions_token)
        responses_headers_override.reset(responses_token)
262
263
class StreamingResult(AsyncIterable[bytes]):
    """Result of a streaming request: an async iterable of encoded event chunks."""

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # The underlying generator is kept on `json_events` and re-yielded as-is.
        self.json_events = stream

    async def __aiter__(self):
        """Yield every chunk produced by the wrapped stream, unchanged."""
        source = self.json_events
        async for chunk in source:
            yield chunk
271
272
class NonStreamingResult:
    """Result of a non-streaming request; the encoded body is exposed as `json`."""

    def __init__(self, result: bytes):
        # Stored verbatim; callers read it back through the `json` attribute.
        self.json = result
276
277
# Type of the per-request context object threaded through the server and store
# calls. Uses `typing_extensions.TypeVar` so it can default to `Any`.
TContext = TypeVar("TContext", default=Any)
279
280
class ChatKitServer(ABC, Generic[TContext]):
    """Abstract base for a ChatKit backend.

    `process()` parses an incoming request and routes it to the streaming or
    non-streaming handler. Subclasses must implement `respond()`; `action()`,
    `sync_action()` and `transcribe()` raise `NotImplementedError` until
    overridden, and `add_feedback()` is a no-op by default.
    """

    def __init__(
        self,
        store: Store[TContext],
        attachment_store: AttachmentStore[TContext] | None = None,
    ):
        """Create a ChatKitServer with the backing Store and optional AttachmentStore."""
        # Store used for threads, thread items and attachment records.
        self.store = store
        # Optional; attachment requests fail with RuntimeError when this is None
        # (see `_get_attachment_store`).
        self.attachment_store = attachment_store
290
291 def _get_attachment_store(self) -> AttachmentStore[TContext]:
292 """Return the configured AttachmentStore or raise if missing."""
293 if self.attachment_store is None:
294 raise RuntimeError(
295 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
296 )
297 return self.attachment_store
298
    @abstractmethod
    def respond(
        self,
        thread: ThreadMetadata,
        input_user_message: UserMessageItem | None,
        context: TContext,
    ) -> AsyncIterator[ThreadStreamEvent]:
        """Stream `ThreadStreamEvent` instances for a new user message.

        This is the one method every subclass must implement. The server calls
        it for new user messages, for retries after a user message, and after a
        client tool output has been recorded.

        Args:
            thread: Metadata for the thread being processed.
            input_user_message: The incoming message the server should respond to, if any.
                It is `None` when the call follows a client tool output.
            context: Arbitrary per-request context provided by the caller.

        Returns:
            An async iterator that yields events representing the server's response.
        """
        pass
317
    async def add_feedback(  # noqa: B027
        self,
        thread_id: str,
        item_ids: list[str],
        feedback: FeedbackKind,
        context: TContext,
    ) -> None:
        """Hook for persisting user feedback for one or more thread items.

        The default implementation does nothing; override it to store feedback.
        It is called for `ItemsFeedbackReq` requests.

        Args:
            thread_id: Id of the thread the items belong to.
            item_ids: Ids of the items the feedback applies to.
            feedback: The feedback kind sent by the client.
            context: Arbitrary per-request context provided by the caller.
        """
        pass
327
328 async def transcribe( # noqa: B027
329 self, audio_input: AudioInput, context: TContext
330 ) -> TranscriptionResult:
331 """Transcribe speech audio to text (dictation).
332
333 Audio bytes are in `audio_input.data`. The raw MIME type is `audio_input.mime_type`; use
334 `audio_input.media_type` for the base media type (one of: `audio/webm`, `audio/ogg`, `audio/mp4`).
335
336 Args:
337 audio_input: Audio bytes plus MIME type metadata for transcription.
338 context: Arbitrary per-request context provided by the caller.
339
340 Returns:
341 A `TranscriptionResult` whose `text` is what should appear in the composer.
342 """
343 raise NotImplementedError(
344 "transcribe() must be overridden to support the input.transcribe request."
345 )
346
347 def action(
348 self,
349 thread: ThreadMetadata,
350 action: Action[str, Any],
351 sender: WidgetItem | None,
352 context: TContext,
353 ) -> AsyncIterator[ThreadStreamEvent]:
354 """Handle a widget or client-dispatched action and yield response events."""
355 raise NotImplementedError(
356 "The action() method must be overridden to react to actions. "
357 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
358 )
359
360 async def sync_action(
361 self,
362 thread: ThreadMetadata,
363 action: Action[str, Any],
364 sender: WidgetItem | None,
365 context: TContext,
366 ) -> SyncCustomActionResponse:
367 """Handle a widget or client-dispatched action and return a SyncCustomActionResponse."""
368 raise NotImplementedError(
369 "The sync_action() method must be overridden to react to sync actions. "
370 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
371 )
372
373 def get_stream_options(
374 self, thread: ThreadMetadata, context: TContext
375 ) -> StreamOptions:
376 """
377 Return stream-level runtime options. Allows the user to cancel the stream by default.
378 Override this method to customize behavior.
379 """
380 return StreamOptions(allow_cancel=True)
381
382 async def handle_stream_cancelled(
383 self,
384 thread: ThreadMetadata,
385 pending_items: list[ThreadItem],
386 context: TContext,
387 ):
388 """Perform custom cleanup / stop inference when a stream is cancelled.
389 Updates you make here will not be reflected in the UI until a reload.
390
391 The default implementation persists any non-empty pending assistant messages
392 to the thread but does not auto-save pending widget items or workflow items.
393
394 Args:
395 thread: The thread that was being processed.
396 pending_items: Items that were not done streaming at cancellation time.
397 context: Arbitrary per-request context provided by the caller.
398 """
399 pending_assistant_message_items: list[AssistantMessageItem] = [
400 item for item in pending_items if isinstance(item, AssistantMessageItem)
401 ]
402 for item in pending_assistant_message_items:
403 is_empty = len(item.content) == 0 or all(
404 (not content.text.strip()) for content in item.content
405 )
406 if not is_empty:
407 await self.store.add_thread_item(thread.id, item, context=context)
408
409 # Add a hidden context item to the thread to indicate that the stream was cancelled.
410 # Otherwise, depending on the timing of the cancellation, subsequent responses may
411 # attempt to continue the cancelled response.
412 await self.store.add_thread_item(
413 thread.id,
414 SDKHiddenContextItem(
415 thread_id=thread.id,
416 created_at=datetime.now(),
417 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
418 content="The user cancelled the stream. Stop responding to the prior request.",
419 ),
420 context=context,
421 )
422
423 async def process(
424 self, request: str | bytes | bytearray, context: TContext
425 ) -> StreamingResult | NonStreamingResult:
426 """Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
427 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
428 logger.info(f"Received request op: {parsed_request.type}")
429
430 if is_streaming_req(parsed_request):
431 return StreamingResult(self._process_streaming(parsed_request, context))
432 else:
433 return NonStreamingResult(
434 await self._process_non_streaming(parsed_request, context)
435 )
436
    async def _process_non_streaming(
        self, request: NonStreamingReq, context: TContext
    ) -> bytes:
        """Handle one non-streaming request and return the encoded JSON response body.

        Each request type is handled in its own `case`. Operations without a
        payload return the literal `b"{}"`; everything else is encoded with
        `_serialize`.

        Raises:
            RuntimeError: for attachment requests when no AttachmentStore is configured.
        """
        match request:
            case ThreadsGetByIdReq():
                # Thread metadata plus the first page of items, hidden items removed.
                thread = await self._load_full_thread(
                    request.params.thread_id, context=context
                )
                return self._serialize(self._to_thread_response(thread))
            case ThreadsListReq():
                params = request.params
                threads = await self.store.load_threads(
                    limit=params.limit or DEFAULT_PAGE_SIZE,
                    after=params.after,
                    order=params.order,
                    context=context,
                )
                # Re-wrap the page so every thread goes through _to_thread_response.
                return self._serialize(
                    Page(
                        has_more=threads.has_more,
                        after=threads.after,
                        data=[
                            self._to_thread_response(thread) for thread in threads.data
                        ],
                    )
                )
            case ItemsFeedbackReq():
                await self.add_feedback(
                    request.params.thread_id,
                    request.params.item_ids,
                    request.params.kind,
                    context,
                )
                return b"{}"
            case AttachmentsCreateReq():
                attachment_store = self._get_attachment_store()
                attachment = await attachment_store.create_attachment(
                    request.params, context
                )
                # The attachment record is also saved in the main store.
                await self.store.save_attachment(attachment, context=context)
                return self._serialize(attachment)
            case AttachmentsDeleteReq():
                attachment_store = self._get_attachment_store()
                # Deleted from the attachment store first, then from the main store.
                await attachment_store.delete_attachment(
                    request.params.attachment_id, context=context
                )
                await self.store.delete_attachment(
                    request.params.attachment_id, context=context
                )
                return b"{}"
            case InputTranscribeReq():
                # Audio arrives base64-encoded; transcribe() receives the raw bytes.
                audio_bytes = base64.b64decode(request.params.audio_base64)
                transcription_result = await self.transcribe(
                    AudioInput(data=audio_bytes, mime_type=request.params.mime_type),
                    context=context,
                )
                return self._serialize(transcription_result)
            case ItemsListReq():
                items_list_params = request.params
                items = await self.store.load_thread_items(
                    items_list_params.thread_id,
                    limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
                    order=items_list_params.order,
                    after=items_list_params.after,
                    context=context,
                )
                # filter out hidden context items
                # (this rebinds `data` on the page object returned by the store)
                items.data = [
                    item
                    for item in items.data
                    if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
                ]
                return self._serialize(items)
            case ThreadsUpdateReq():
                # Only the title is updated.
                thread = await self.store.load_thread(
                    request.params.thread_id, context=context
                )
                thread.title = request.params.title
                await self.store.save_thread(thread, context=context)
                return self._serialize(self._to_thread_response(thread))
            case ThreadsDeleteReq():
                await self.store.delete_thread(
                    request.params.thread_id, context=context
                )
                return b"{}"
            case ThreadsSyncCustomActionReq():
                return await self._process_sync_custom_action(request, context)
            case _:
                # Exhaustiveness check for the type checker.
                assert_never(request)
526
527 async def _process_streaming(
528 self, request: StreamingReq, context: TContext
529 ) -> AsyncGenerator[bytes, None]:
530 try:
531 async for event in self._process_streaming_impl(request, context):
532 b = self._serialize(event)
533 yield b"data: " + b + b"\n\n"
534 except asyncio.CancelledError:
535 # Let cancellation bubble up without logging as an error.
536 raise
537 except Exception:
538 logger.exception("Error while generating streamed response")
539 raise
540
    async def _process_streaming_impl(
        self, request: StreamingReq, context: TContext
    ) -> AsyncGenerator[ThreadStreamEvent, None]:
        """Handle one streaming request and yield the resulting thread events.

        Each request type is handled in its own `case`. Every branch ends by
        delegating to `_process_events` (directly or through
        `_process_new_thread_item_respond`), except the custom-action branch when
        the sender item is not a widget, which yields a single `ErrorEvent`.

        Raises:
            ValueError: when a client tool output arrives but the latest thread
                item is not a pending `ClientToolCallItem`, or when the retry
                target is not a user message.
        """
        match request:
            case ThreadsCreateReq():
                thread = Thread(
                    id=self.store.generate_thread_id(context),
                    created_at=datetime.now(),
                    items=Page(),
                )
                # The store receives a ThreadMetadata built from the thread's dumped fields.
                await self.store.save_thread(
                    ThreadMetadata(**thread.model_dump()),
                    context=context,
                )
                yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
                user_message = await self._build_user_message_item(
                    request.params.input, thread, context
                )
                async for event in self._process_new_thread_item_respond(
                    thread,
                    user_message,
                    context,
                ):
                    yield event

            case ThreadsAddUserMessageReq():
                thread = await self.store.load_thread(
                    request.params.thread_id, context=context
                )
                user_message = await self._build_user_message_item(
                    request.params.input, thread, context
                )
                async for event in self._process_new_thread_item_respond(
                    thread,
                    user_message,
                    context,
                ):
                    yield event

            case ThreadsAddClientToolOutputReq():
                thread = await self.store.load_thread(
                    request.params.thread_id, context=context
                )
                # Only the single most recent item is loaded (limit 1, "desc").
                items = await self.store.load_thread_items(
                    thread.id, None, 1, "desc", context
                )
                tool_call = next(
                    (
                        item
                        for item in items.data
                        if isinstance(item, ClientToolCallItem)
                        and item.status == "pending"
                    ),
                    None,
                )
                if not tool_call:
                    # Also raised when the last item is a tool call that is not pending.
                    raise ValueError(
                        f"Last thread item in {thread.id} was not a ClientToolCallItem"
                    )

                tool_call.output = request.params.result
                tool_call.status = "completed"

                await self.store.save_item(thread.id, tool_call, context=context)

                # Safety against dangling pending tool calls if there are
                # multiple in a row, which should be impossible, and
                # integrations should ultimately filter out pending tool calls
                # when creating input response messages.
                await self._cleanup_pending_client_tool_call(thread, context)

                # There is no new user message here, so respond() receives None.
                async for event in self._process_events(
                    thread,
                    context,
                    lambda: self.respond(thread, None, context),
                ):
                    yield event

            case ThreadsRetryAfterItemReq():
                thread_metadata = await self.store.load_thread(
                    request.params.thread_id, context=context
                )

                # Collect items to remove (all items after the user message)
                items_to_remove: list[ThreadItem] = []
                user_message_item = None

                # Walk newest-first until the requested item is reached.
                async for item in self._paginate_thread_items_reverse(
                    request.params.thread_id, context
                ):
                    if item.id == request.params.item_id:
                        if not isinstance(item, UserMessageItem):
                            raise ValueError(
                                f"Item {request.params.item_id} is not a user message"
                            )
                        user_message_item = item
                        break
                    items_to_remove.append(item)

                # If the requested item was never found, nothing is deleted and no
                # events are produced.
                if user_message_item:
                    for item in items_to_remove:
                        await self.store.delete_thread_item(
                            request.params.thread_id, item.id, context=context
                        )
                    async for event in self._process_events(
                        thread_metadata,
                        context,
                        lambda: self.respond(
                            thread_metadata,
                            user_message_item,
                            context,
                        ),
                    ):
                        yield event
            case ThreadsCustomActionReq():
                thread_metadata = await self.store.load_thread(
                    request.params.thread_id, context=context
                )

                # The sender item is optional; it is loaded only when an id was given.
                item: ThreadItem | None = None
                if request.params.item_id:
                    item = await self.store.load_item(
                        request.params.thread_id,
                        request.params.item_id,
                        context=context,
                    )

                if item and not isinstance(item, WidgetItem):
                    # shouldn't happen if the caller is using the API correctly.
                    yield ErrorEvent(
                        code=ErrorCode.STREAM_ERROR,
                        allow_retry=False,
                    )
                    return

                async for event in self._process_events(
                    thread_metadata,
                    context,
                    lambda: self.action(
                        thread_metadata,
                        request.params.action,
                        item,
                        context,
                    ),
                ):
                    yield event

            case _:
                # Exhaustiveness check for the type checker.
                assert_never(request)
690
691 async def _process_sync_custom_action(
692 self, request: ThreadsSyncCustomActionReq, context: TContext
693 ) -> bytes:
694 thread_metadata = await self.store.load_thread(
695 request.params.thread_id, context=context
696 )
697
698 item: ThreadItem | None = None
699 if request.params.item_id:
700 item = await self.store.load_item(
701 request.params.thread_id,
702 request.params.item_id,
703 context=context,
704 )
705
706 if item and not isinstance(item, WidgetItem):
707 raise ValueError("threads.sync_custom_action requires a widget sender item")
708
709 return self._serialize(
710 await self.sync_action(
711 thread_metadata,
712 request.params.action,
713 item,
714 context,
715 )
716 )
717
718 async def _cleanup_pending_client_tool_call(
719 self, thread: ThreadMetadata, context: TContext
720 ) -> None:
721 items = await self.store.load_thread_items(
722 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
723 )
724 for tool_call in items.data:
725 if not isinstance(tool_call, ClientToolCallItem):
726 continue
727 if tool_call.status == "pending":
728 logger.warning(
729 f"Client tool call {tool_call.call_id} was not completed, ignoring"
730 )
731 await self.store.delete_thread_item(
732 thread.id, tool_call.id, context=context
733 )
734
735 async def _process_new_thread_item_respond(
736 self,
737 thread: ThreadMetadata,
738 item: UserMessageItem,
739 context: TContext,
740 ) -> AsyncIterator[ThreadStreamEvent]:
741 # Store the updated attachments (now that they have a thread id).
742 for attachment in item.attachments:
743 await self.store.save_attachment(attachment, context=context)
744
745 await self.store.add_thread_item(thread.id, item, context=context)
746 yield ThreadItemDoneEvent(item=item)
747
748 async for event in self._process_events(
749 thread,
750 context,
751 lambda: self.respond(thread, item, context),
752 ):
753 yield event
754
    async def _process_events(
        self,
        thread: ThreadMetadata,
        context: TContext,
        stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
    ) -> AsyncIterator[ThreadStreamEvent]:
        """Drive a response stream: persist items, forward events, report errors.

        Yields a `StreamOptionsEvent` first, then forwards the events produced by
        `stream()` while mirroring them into the store (done, removed and
        replaced items). Done events for hidden context items are stored but not
        forwarded. Whenever `thread` differs from the last saved snapshot it is
        saved and a `ThreadUpdatedEvent` is emitted.

        On cancellation, `handle_stream_cancelled` is called with the items that
        were added but not finished, and the cancellation is re-raised. Other
        exceptions are turned into an `ErrorEvent` instead of propagating.

        Args:
            thread: Metadata of the thread being responded to.
            context: Arbitrary per-request context provided by the caller.
            stream: Zero-argument callable returning the event iterator; it is
                called inside the user-agent override context.
        """
        await asyncio.sleep(0)  # allow the response to start streaming

        # Send initial stream options
        yield StreamOptionsEvent(
            stream_options=self.get_stream_options(thread, context)
        )

        # Snapshot used to detect changes made to `thread` during streaming.
        last_thread = thread.model_copy(deep=True)

        # Keep track of items that were streamed but not yet saved
        # so that we can persist them when the stream is cancelled.
        pending_items: dict[str, ThreadItem] = {}

        try:
            with agents_sdk_user_agent_override():
                async for event in stream():
                    if isinstance(event, ThreadItemAddedEvent):
                        # Stash an isolated copy in case we need to persist unfinished items
                        # on cancellation; downstream handlers keep using the original event.item.
                        pending_items[event.item.id] = event.item.model_copy(deep=True)

                    match event:
                        case ThreadItemDoneEvent():
                            await self.store.add_thread_item(
                                thread.id, event.item, context=context
                            )
                            pending_items.pop(event.item.id, None)
                        case ThreadItemRemovedEvent():
                            await self.store.delete_thread_item(
                                thread.id, event.item_id, context=context
                            )
                            pending_items.pop(event.item_id, None)
                        case ThreadItemReplacedEvent():
                            await self.store.save_item(
                                thread.id, event.item, context=context
                            )
                            pending_items.pop(event.item.id, None)
                        case ThreadItemUpdatedEvent():
                            # Keep pending assistant message and workflow items up to date
                            # so that we have a reference to the latest version of these pending items
                            # when the stream is cancelled.
                            self._update_pending_items(pending_items, event)

                    # special case - don't send hidden context items back to the client
                    should_swallow_event = isinstance(
                        event, ThreadItemDoneEvent
                    ) and isinstance(
                        event.item, (HiddenContextItem, SDKHiddenContextItem)
                    )

                    if not should_swallow_event:
                        yield event

                    # in case user updated the thread while streaming
                    if thread != last_thread:
                        last_thread = thread.model_copy(deep=True)
                        await self.store.save_thread(thread, context=context)
                        yield ThreadUpdatedEvent(
                            thread=self._to_thread_response(thread)
                        )
                # NOTE(review): the nesting of this post-loop check (inside vs.
                # just after the `with`) could not be recovered from the
                # flattened source — confirm against the repository.
                # in case user updated the thread while streaming
                if thread != last_thread:
                    last_thread = thread.model_copy(deep=True)
                    await self.store.save_thread(thread, context=context)
                    yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
        except asyncio.CancelledError:
            # Give the subclass a chance to persist unfinished items, then propagate.
            await self.handle_stream_cancelled(
                thread, list(pending_items.values()), context
            )
            raise
        except CustomStreamError as e:
            # Custom errors carry their own message for the client.
            yield ErrorEvent(
                code="custom",
                message=e.message,
                allow_retry=e.allow_retry,
            )
        except StreamError as e:
            yield ErrorEvent(
                code=e.code,
                allow_retry=e.allow_retry,
            )
        except Exception as e:
            # Unexpected failure: report a generic retryable error to the client.
            # The exception is logged only after the error event has been consumed.
            yield ErrorEvent(
                code=ErrorCode.STREAM_ERROR,
                allow_retry=True,
            )
            logger.exception(e)

        if thread != last_thread:
            # in case user updated the thread at the end of the stream
            await self.store.save_thread(thread, context=context)
            yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
853
854 def _apply_assistant_message_update(
855 self,
856 item: AssistantMessageItem,
857 update: AssistantMessageContentPartAdded
858 | AssistantMessageContentPartTextDelta
859 | AssistantMessageContentPartAnnotationAdded
860 | AssistantMessageContentPartDone,
861 ) -> AssistantMessageItem:
862 # Pad the content list so the requested content_index exists before we write into it.
863 # (Streaming updates can arrive for an index that hasn’t been created yet)
864 while len(item.content) <= update.content_index:
865 item.content.append(AssistantMessageContent(text="", annotations=[]))
866
867 match update:
868 case AssistantMessageContentPartAdded():
869 item.content[update.content_index] = update.content
870 case AssistantMessageContentPartTextDelta():
871 item.content[update.content_index].text += update.delta
872 case AssistantMessageContentPartAnnotationAdded():
873 annotations = item.content[update.content_index].annotations
874 if update.annotation_index <= len(annotations):
875 annotations.insert(update.annotation_index, update.annotation)
876 else:
877 annotations.append(update.annotation)
878 case AssistantMessageContentPartDone():
879 item.content[update.content_index] = update.content
880 return item
881
882 def _update_pending_items(
883 self,
884 pending_items: dict[str, ThreadItem],
885 event: ThreadItemUpdatedEvent,
886 ):
887 updated_item = pending_items.get(event.item_id)
888 update = event.update
889 match updated_item:
890 case AssistantMessageItem():
891 if isinstance(
892 update,
893 (
894 AssistantMessageContentPartAdded,
895 AssistantMessageContentPartTextDelta,
896 AssistantMessageContentPartAnnotationAdded,
897 AssistantMessageContentPartDone,
898 ),
899 ):
900 pending_items[updated_item.id] = (
901 self._apply_assistant_message_update(updated_item, update)
902 )
903 case WorkflowItem():
904 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
905 match update:
906 case WorkflowTaskUpdated():
907 updated_item.workflow.tasks[update.task_index] = update.task
908 case WorkflowTaskAdded():
909 updated_item.workflow.tasks.append(update.task)
910
911 pending_items[updated_item.id] = updated_item
912 case _:
913 pass
914
915 async def _build_user_message_item(
916 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
917 ) -> UserMessageItem:
918 return UserMessageItem(
919 id=self.store.generate_item_id("message", thread, context),
920 content=input.content,
921 thread_id=thread.id,
922 attachments=[
923 (await self.store.load_attachment(attachment_id, context)).model_copy(
924 update={"thread_id": thread.id}
925 )
926 for attachment_id in input.attachments
927 ],
928 quoted_text=input.quoted_text,
929 inference_options=input.inference_options,
930 created_at=datetime.now(),
931 )
932
933 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
934 thread_meta = await self.store.load_thread(thread_id, context=context)
935 thread_items = await self.store.load_thread_items(
936 thread_id,
937 after=None,
938 limit=DEFAULT_PAGE_SIZE,
939 order="asc",
940 context=context,
941 )
942 return Thread(**thread_meta.model_dump(), items=thread_items)
943
944 async def _paginate_thread_items_reverse(
945 self, thread_id: str, context: TContext
946 ) -> AsyncIterator[ThreadItem]:
947 """Paginate through thread items in reverse order (newest first)."""
948 after = None
949 while True:
950 items = await self.store.load_thread_items(
951 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
952 )
953 for item in items.data:
954 yield item
955
956 if not items.has_more:
957 break
958 after = items.after
959
960 def _serialize(self, obj: BaseModel) -> bytes:
961 return obj.model_dump_json(
962 by_alias=True,
963 exclude_none=True,
964 context={"exclude_metadata": True},
965 ).encode("utf-8")
966
967 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
968 def is_hidden(item: ThreadItem) -> bool:
969 return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
970
971 items = thread.items if isinstance(thread, Thread) else Page()
972 items.data = [item for item in items.data if not is_hidden(item)]
973
974 return Thread(
975 id=thread.id,
976 title=thread.title,
977 created_at=thread.created_at,
978 items=items,
979 status=thread.status,
980 )
981