openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a202f603a00389ec4330c4faeb58dd41c64153a8

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

917 lines · mode: code

1import asyncio
2import base64
3from abc import ABC, abstractmethod
4from collections.abc import AsyncIterator
5from contextlib import contextmanager
6from datetime import datetime
7from typing import (
8 Any,
9 AsyncGenerator,
10 AsyncIterable,
11 Callable,
12 Generic,
13)
14
15import agents
16from agents.models.chatcmpl_helpers import (
17 HEADERS_OVERRIDE as chat_completions_headers_override,
18)
19from agents.models.openai_responses import (
20 _HEADERS_OVERRIDE as responses_headers_override,
21)
22from pydantic import BaseModel, TypeAdapter
23from typing_extensions import TypeVar, assert_never
24
25from chatkit.errors import CustomStreamError, StreamError
26
27from .logger import logger
28from .store import AttachmentStore, Store, StoreItemType, default_generate_id
29from .types import (
30 Action,
31 AssistantMessageContent,
32 AssistantMessageContentPartAdded,
33 AssistantMessageContentPartAnnotationAdded,
34 AssistantMessageContentPartDone,
35 AssistantMessageContentPartTextDelta,
36 AssistantMessageItem,
37 AttachmentsCreateReq,
38 AttachmentsDeleteReq,
39 ChatKitReq,
40 ClientToolCallItem,
41 ErrorCode,
42 ErrorEvent,
43 FeedbackKind,
44 HiddenContextItem,
45 InputTranscribeReq,
46 ItemsFeedbackReq,
47 ItemsListReq,
48 NonStreamingReq,
49 Page,
50 SDKHiddenContextItem,
51 StreamingReq,
52 StreamOptions,
53 StreamOptionsEvent,
54 Thread,
55 ThreadCreatedEvent,
56 ThreadItem,
57 ThreadItemAddedEvent,
58 ThreadItemDoneEvent,
59 ThreadItemRemovedEvent,
60 ThreadItemReplacedEvent,
61 ThreadItemUpdatedEvent,
62 ThreadMetadata,
63 ThreadsAddClientToolOutputReq,
64 ThreadsAddUserMessageReq,
65 ThreadsCreateReq,
66 ThreadsCustomActionReq,
67 ThreadsDeleteReq,
68 ThreadsGetByIdReq,
69 ThreadsListReq,
70 ThreadsRetryAfterItemReq,
71 ThreadStreamEvent,
72 ThreadsUpdateReq,
73 ThreadUpdatedEvent,
74 TranscriptionResult,
75 UserMessageInput,
76 UserMessageItem,
77 WidgetComponentUpdated,
78 WidgetItem,
79 WidgetRootUpdated,
80 WidgetStreamingTextValueDelta,
81 WorkflowItem,
82 WorkflowTaskAdded,
83 WorkflowTaskUpdated,
84 is_streaming_req,
85)
86from .version import __version__
87from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
88
# Page size used for thread/item listings when the request gives no limit.
DEFAULT_PAGE_SIZE: int = 20
# Generic user-facing message for failed response generation.
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
91
92
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    If any component's type, id or key changed, or any explicitly-set field
    changed (other than an append-only ``value`` on a Markdown/Text component),
    a single ``WidgetRootUpdated`` carrying ``after`` is returned. Otherwise one
    ``WidgetStreamingTextValueDelta`` is returned for each Markdown/Text
    component with an id whose value grew.

    Args:
        before: The previously emitted widget state.
        after: The new widget state.

    Returns:
        The updates that turn ``before`` into ``after``. The list is empty when
        no streaming text value changed.

    Raises:
        ValueError: If a streaming text component with an id in ``after`` has no
            counterpart in ``before``, or if its new value does not start with
            its previous value.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        # Only Markdown/Text components holding a string value can stream deltas.
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        # A change of identity always requires replacing the whole widget.
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    # Nested components are compared recursively.
                    return full_replace(before_value, after_value)
                else:
                    return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                is_streaming_text(before)
                and is_streaming_text(after)
                and field == "value"
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        # Map component id -> component for every streaming text node in the tree.
        components: dict[str, WidgetComponentBase] = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if is_streaming_text(component) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            # Only appends can be streamed: the previous value must be a prefix
            # of the new one.
            if not after_value.startswith(before_value):
                raise ValueError(
                    f"Node {component_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=component_id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas
195
196
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget root (or async sequence of roots) as ThreadStreamEvents.

    A static ``WidgetRoot`` is emitted as a single ``ThreadItemDoneEvent``.

    For an async generator:
      - the first yielded root is announced with a ``ThreadItemAddedEvent``;
      - each later root is diffed against the previous one with ``diff_widget``
        and emitted as ``ThreadItemUpdatedEvent``s;
      - a final ``ThreadItemDoneEvent`` carries the last state.

    Args:
        thread: The thread the widget item belongs to.
        widget: A static widget root, or an async generator of successive roots.
        copy_text: Optional value stored as ``copy_text`` on the ``WidgetItem``.
        generate_id: Factory used to create the widget item's id.

    Raises:
        ValueError: Propagated from ``diff_widget`` when an update is not cumulative.

    Note:
        If the generator yields nothing, the ``StopAsyncIteration`` from the
        first ``__anext__`` surfaces as a ``RuntimeError`` (PEP 525).
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    initial_state = await widget.__anext__()

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state

    # Iterating the same generator resumes right after `initial_state`.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
218 initial_state = await widget.__anext__()
219
220 item = WidgetItem(
221 id=item_id,
222 created_at=datetime.now(),
223 widget=initial_state,
224 copy_text=copy_text,
225 thread_id=thread.id,
226 )
227
228 yield ThreadItemAddedEvent(item=item)
229
230 last_state = initial_state
231
232 while widget:
233 try:
234 new_state = await widget.__anext__()
235 for update in diff_widget(last_state, new_state):
236 yield ThreadItemUpdatedEvent(
237 item_id=item_id,
238 update=update,
239 )
240 last_state = new_state
241 except StopAsyncIteration:
242 break
243
244 yield ThreadItemDoneEvent(
245 item=item.model_copy(update={"widget": last_state}),
246 )
247
248
249@contextmanager
250def agents_sdk_user_agent_override():
251 ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
252 chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
253 responses_token = responses_headers_override.set({"User-Agent": ua})
254
255 yield
256
257 chat_completions_headers_override.reset(chat_completions_token)
258 responses_headers_override.reset(responses_token)
259
260
261class StreamingResult(AsyncIterable[bytes]):
262 def __init__(self, stream: AsyncGenerator[bytes, None]):
263 self.json_events = stream
264
265 async def __aiter__(self):
266 async for event in self.json_events:
267 yield event
268
269
270class NonStreamingResult:
271 def __init__(self, result: bytes):
272 self.json = result
273
274
275TContext = TypeVar("TContext", default=Any)
276
277
278class ChatKitServer(ABC, Generic[TContext]):
279 def __init__(
280 self,
281 store: Store[TContext],
282 attachment_store: AttachmentStore[TContext] | None = None,
283 ):
284 """Create a ChatKitServer with the backing Store and optional AttachmentStore."""
285 self.store = store
286 self.attachment_store = attachment_store
287
288 def _get_attachment_store(self) -> AttachmentStore[TContext]:
289 """Return the configured AttachmentStore or raise if missing."""
290 if self.attachment_store is None:
291 raise RuntimeError(
292 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
293 )
294 return self.attachment_store
295
296 @abstractmethod
297 def respond(
298 self,
299 thread: ThreadMetadata,
300 input_user_message: UserMessageItem | None,
301 context: TContext,
302 ) -> AsyncIterator[ThreadStreamEvent]:
303 """Stream `ThreadStreamEvent` instances for a new user message.
304
305 Args:
306 thread: Metadata for the thread being processed.
307 input_user_message: The incoming message the server should respond to, if any.
308 context: Arbitrary per-request context provided by the caller.
309
310 Returns:
311 An async iterator that yields events representing the server's response.
312 """
313 pass
314
315 async def add_feedback( # noqa: B027
316 self,
317 thread_id: str,
318 item_ids: list[str],
319 feedback: FeedbackKind,
320 context: TContext,
321 ) -> None:
322 """Persist user feedback for one or more thread items."""
323 pass
324
325 async def transcribe( # noqa: B027
326 self, audio_bytes: bytes, mime_type: str, context: TContext
327 ) -> TranscriptionResult:
328 """Transcribe speech audio to text. Override this method to support dictation."""
329 raise NotImplementedError(
330 "transcribe() must be overridden to support the input.transcribe request."
331 )
332
333 def action(
334 self,
335 thread: ThreadMetadata,
336 action: Action[str, Any],
337 sender: WidgetItem | None,
338 context: TContext,
339 ) -> AsyncIterator[ThreadStreamEvent]:
340 """Handle a widget or client-dispatched action and yield response events."""
341 raise NotImplementedError(
342 "The action() method must be overridden to react to actions. "
343 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
344 )
345
346 def get_stream_options(
347 self, thread: ThreadMetadata, context: TContext
348 ) -> StreamOptions:
349 """
350 Return stream-level runtime options. Allows the user to cancel the stream by default.
351 Override this method to customize behavior.
352 """
353 return StreamOptions(allow_cancel=True)
354
355 async def handle_stream_cancelled(
356 self,
357 thread: ThreadMetadata,
358 pending_items: list[ThreadItem],
359 context: TContext,
360 ):
361 """Perform custom cleanup / stop inference when a stream is cancelled.
362 Updates you make here will not be reflected in the UI until a reload.
363
364 The default implementation persists any non-empty pending assistant messages
365 to the thread but does not auto-save pending widget items or workflow items.
366
367 Args:
368 thread: The thread that was being processed.
369 pending_items: Items that were not done streaming at cancellation time.
370 context: Arbitrary per-request context provided by the caller.
371 """
372 pending_assistant_message_items: list[AssistantMessageItem] = [
373 item for item in pending_items if isinstance(item, AssistantMessageItem)
374 ]
375 for item in pending_assistant_message_items:
376 is_empty = len(item.content) == 0 or all(
377 (not content.text.strip()) for content in item.content
378 )
379 if not is_empty:
380 await self.store.add_thread_item(thread.id, item, context=context)
381
382 # Add a hidden context item to the thread to indicate that the stream was cancelled.
383 # Otherwise, depending on the timing of the cancellation, subsequent responses may
384 # attempt to continue the cancelled response.
385 await self.store.add_thread_item(
386 thread.id,
387 SDKHiddenContextItem(
388 thread_id=thread.id,
389 created_at=datetime.now(),
390 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
391 content="The user cancelled the stream. Stop responding to the prior request.",
392 ),
393 context=context,
394 )
395
396 async def process(
397 self, request: str | bytes | bytearray, context: TContext
398 ) -> StreamingResult | NonStreamingResult:
399 """Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
400 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
401 logger.info(f"Received request op: {parsed_request.type}")
402
403 if is_streaming_req(parsed_request):
404 return StreamingResult(self._process_streaming(parsed_request, context))
405 else:
406 return NonStreamingResult(
407 await self._process_non_streaming(parsed_request, context)
408 )
409
410 async def _process_non_streaming(
411 self, request: NonStreamingReq, context: TContext
412 ) -> bytes:
413 match request:
414 case ThreadsGetByIdReq():
415 thread = await self._load_full_thread(
416 request.params.thread_id, context=context
417 )
418 return self._serialize(self._to_thread_response(thread))
419 case ThreadsListReq():
420 params = request.params
421 threads = await self.store.load_threads(
422 limit=params.limit or DEFAULT_PAGE_SIZE,
423 after=params.after,
424 order=params.order,
425 context=context,
426 )
427 return self._serialize(
428 Page(
429 has_more=threads.has_more,
430 after=threads.after,
431 data=[
432 self._to_thread_response(thread) for thread in threads.data
433 ],
434 )
435 )
436 case ItemsFeedbackReq():
437 await self.add_feedback(
438 request.params.thread_id,
439 request.params.item_ids,
440 request.params.kind,
441 context,
442 )
443 return b"{}"
444 case AttachmentsCreateReq():
445 attachment_store = self._get_attachment_store()
446 attachment = await attachment_store.create_attachment(
447 request.params, context
448 )
449 await self.store.save_attachment(attachment, context=context)
450 return self._serialize(attachment)
451 case AttachmentsDeleteReq():
452 attachment_store = self._get_attachment_store()
453 await attachment_store.delete_attachment(
454 request.params.attachment_id, context=context
455 )
456 await self.store.delete_attachment(
457 request.params.attachment_id, context=context
458 )
459 return b"{}"
460 case InputTranscribeReq():
461 audio_bytes = base64.b64decode(request.params.audio_base64)
462 transcription_result = await self.transcribe(
463 audio_bytes, request.params.mime_type, context=context
464 )
465 return self._serialize(transcription_result)
466 case ItemsListReq():
467 items_list_params = request.params
468 items = await self.store.load_thread_items(
469 items_list_params.thread_id,
470 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
471 order=items_list_params.order,
472 after=items_list_params.after,
473 context=context,
474 )
475 # filter out hidden context items
476 items.data = [
477 item
478 for item in items.data
479 if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
480 ]
481 return self._serialize(items)
482 case ThreadsUpdateReq():
483 thread = await self.store.load_thread(
484 request.params.thread_id, context=context
485 )
486 thread.title = request.params.title
487 await self.store.save_thread(thread, context=context)
488 return self._serialize(self._to_thread_response(thread))
489 case ThreadsDeleteReq():
490 await self.store.delete_thread(
491 request.params.thread_id, context=context
492 )
493 return b"{}"
494 case _:
495 assert_never(request)
496
497 async def _process_streaming(
498 self, request: StreamingReq, context: TContext
499 ) -> AsyncGenerator[bytes, None]:
500 try:
501 async for event in self._process_streaming_impl(request, context):
502 b = self._serialize(event)
503 yield b"data: " + b + b"\n\n"
504 except asyncio.CancelledError:
505 # Let cancellation bubble up without logging as an error.
506 raise
507 except Exception:
508 logger.exception("Error while generating streamed response")
509 raise
510
511 async def _process_streaming_impl(
512 self, request: StreamingReq, context: TContext
513 ) -> AsyncGenerator[ThreadStreamEvent, None]:
514 match request:
515 case ThreadsCreateReq():
516 thread = Thread(
517 id=self.store.generate_thread_id(context),
518 created_at=datetime.now(),
519 items=Page(),
520 )
521 await self.store.save_thread(
522 ThreadMetadata(**thread.model_dump()),
523 context=context,
524 )
525 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
526 user_message = await self._build_user_message_item(
527 request.params.input, thread, context
528 )
529 async for event in self._process_new_thread_item_respond(
530 thread,
531 user_message,
532 context,
533 ):
534 yield event
535
536 case ThreadsAddUserMessageReq():
537 thread = await self.store.load_thread(
538 request.params.thread_id, context=context
539 )
540 user_message = await self._build_user_message_item(
541 request.params.input, thread, context
542 )
543 async for event in self._process_new_thread_item_respond(
544 thread,
545 user_message,
546 context,
547 ):
548 yield event
549
550 case ThreadsAddClientToolOutputReq():
551 thread = await self.store.load_thread(
552 request.params.thread_id, context=context
553 )
554 items = await self.store.load_thread_items(
555 thread.id, None, 1, "desc", context
556 )
557 tool_call = next(
558 (
559 item
560 for item in items.data
561 if isinstance(item, ClientToolCallItem)
562 and item.status == "pending"
563 ),
564 None,
565 )
566 if not tool_call:
567 raise ValueError(
568 f"Last thread item in {thread.id} was not a ClientToolCallItem"
569 )
570
571 tool_call.output = request.params.result
572 tool_call.status = "completed"
573
574 await self.store.save_item(thread.id, tool_call, context=context)
575
576 # Safety against dangling pending tool calls if there are
577 # multiple in a row, which should be impossible, and
578 # integrations should ultimately filter out pending tool calls
579 # when creating input response messages.
580 await self._cleanup_pending_client_tool_call(thread, context)
581
582 async for event in self._process_events(
583 thread,
584 context,
585 lambda: self.respond(thread, None, context),
586 ):
587 yield event
588
589 case ThreadsRetryAfterItemReq():
590 thread_metadata = await self.store.load_thread(
591 request.params.thread_id, context=context
592 )
593
594 # Collect items to remove (all items after the user message)
595 items_to_remove: list[ThreadItem] = []
596 user_message_item = None
597
598 async for item in self._paginate_thread_items_reverse(
599 request.params.thread_id, context
600 ):
601 if item.id == request.params.item_id:
602 if not isinstance(item, UserMessageItem):
603 raise ValueError(
604 f"Item {request.params.item_id} is not a user message"
605 )
606 user_message_item = item
607 break
608 items_to_remove.append(item)
609
610 if user_message_item:
611 for item in items_to_remove:
612 await self.store.delete_thread_item(
613 request.params.thread_id, item.id, context=context
614 )
615 async for event in self._process_events(
616 thread_metadata,
617 context,
618 lambda: self.respond(
619 thread_metadata,
620 user_message_item,
621 context,
622 ),
623 ):
624 yield event
625 case ThreadsCustomActionReq():
626 thread_metadata = await self.store.load_thread(
627 request.params.thread_id, context=context
628 )
629
630 item: ThreadItem | None = None
631 if request.params.item_id:
632 item = await self.store.load_item(
633 request.params.thread_id,
634 request.params.item_id,
635 context=context,
636 )
637
638 if item and not isinstance(item, WidgetItem):
639 # shouldn't happen if the caller is using the API correctly.
640 yield ErrorEvent(
641 code=ErrorCode.STREAM_ERROR,
642 allow_retry=False,
643 )
644 return
645
646 async for event in self._process_events(
647 thread_metadata,
648 context,
649 lambda: self.action(
650 thread_metadata,
651 request.params.action,
652 item,
653 context,
654 ),
655 ):
656 yield event
657
658 case _:
659 assert_never(request)
660
661 async def _cleanup_pending_client_tool_call(
662 self, thread: ThreadMetadata, context: TContext
663 ) -> None:
664 items = await self.store.load_thread_items(
665 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
666 )
667 for tool_call in items.data:
668 if not isinstance(tool_call, ClientToolCallItem):
669 continue
670 if tool_call.status == "pending":
671 logger.warning(
672 f"Client tool call {tool_call.call_id} was not completed, ignoring"
673 )
674 await self.store.delete_thread_item(
675 thread.id, tool_call.id, context=context
676 )
677
678 async def _process_new_thread_item_respond(
679 self,
680 thread: ThreadMetadata,
681 item: UserMessageItem,
682 context: TContext,
683 ) -> AsyncIterator[ThreadStreamEvent]:
684 await self.store.add_thread_item(thread.id, item, context=context)
685 yield ThreadItemDoneEvent(item=item)
686
687 async for event in self._process_events(
688 thread,
689 context,
690 lambda: self.respond(thread, item, context),
691 ):
692 yield event
693
694 async def _process_events(
695 self,
696 thread: ThreadMetadata,
697 context: TContext,
698 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
699 ) -> AsyncIterator[ThreadStreamEvent]:
700 await asyncio.sleep(0) # allow the response to start streaming
701
702 # Send initial stream options
703 yield StreamOptionsEvent(
704 stream_options=self.get_stream_options(thread, context)
705 )
706
707 last_thread = thread.model_copy(deep=True)
708
709 # Keep track of items that were streamed but not yet saved
710 # so that we can persist them when the stream is cancelled.
711 pending_items: dict[str, ThreadItem] = {}
712
713 try:
714 with agents_sdk_user_agent_override():
715 async for event in stream():
716 if isinstance(event, ThreadItemAddedEvent):
717 # Stash an isolated copy in case we need to persist unfinished items
718 # on cancellation; downstream handlers keep using the original event.item.
719 pending_items[event.item.id] = event.item.model_copy(deep=True)
720
721 match event:
722 case ThreadItemDoneEvent():
723 await self.store.add_thread_item(
724 thread.id, event.item, context=context
725 )
726 pending_items.pop(event.item.id, None)
727 case ThreadItemRemovedEvent():
728 await self.store.delete_thread_item(
729 thread.id, event.item_id, context=context
730 )
731 pending_items.pop(event.item_id, None)
732 case ThreadItemReplacedEvent():
733 await self.store.save_item(
734 thread.id, event.item, context=context
735 )
736 pending_items.pop(event.item.id, None)
737 case ThreadItemUpdatedEvent():
738 # Keep pending assistant message and workflow items up to date
739 # so that we have a reference to the latest version of these pending items
740 # when the stream is cancelled.
741 self._update_pending_items(pending_items, event)
742
743 # special case - don't send hidden context items back to the client
744 should_swallow_event = isinstance(
745 event, ThreadItemDoneEvent
746 ) and isinstance(
747 event.item, (HiddenContextItem, SDKHiddenContextItem)
748 )
749
750 if not should_swallow_event:
751 yield event
752
753 # in case user updated the thread while streaming
754 if thread != last_thread:
755 last_thread = thread.model_copy(deep=True)
756 await self.store.save_thread(thread, context=context)
757 yield ThreadUpdatedEvent(
758 thread=self._to_thread_response(thread)
759 )
760 # in case user updated the thread while streaming
761 if thread != last_thread:
762 last_thread = thread.model_copy(deep=True)
763 await self.store.save_thread(thread, context=context)
764 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
765 except asyncio.CancelledError:
766 await self.handle_stream_cancelled(
767 thread, list(pending_items.values()), context
768 )
769 raise
770 except CustomStreamError as e:
771 yield ErrorEvent(
772 code="custom",
773 message=e.message,
774 allow_retry=e.allow_retry,
775 )
776 except StreamError as e:
777 yield ErrorEvent(
778 code=e.code,
779 allow_retry=e.allow_retry,
780 )
781 except Exception as e:
782 yield ErrorEvent(
783 code=ErrorCode.STREAM_ERROR,
784 allow_retry=True,
785 )
786 logger.exception(e)
787
788 if thread != last_thread:
789 # in case user updated the thread at the end of the stream
790 await self.store.save_thread(thread, context=context)
791 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
792
793 def _apply_assistant_message_update(
794 self,
795 item: AssistantMessageItem,
796 update: AssistantMessageContentPartAdded
797 | AssistantMessageContentPartTextDelta
798 | AssistantMessageContentPartAnnotationAdded
799 | AssistantMessageContentPartDone,
800 ) -> AssistantMessageItem:
801 # Pad the content list so the requested content_index exists before we write into it.
802 # (Streaming updates can arrive for an index that hasn’t been created yet)
803 while len(item.content) <= update.content_index:
804 item.content.append(AssistantMessageContent(text="", annotations=[]))
805
806 match update:
807 case AssistantMessageContentPartAdded():
808 item.content[update.content_index] = update.content
809 case AssistantMessageContentPartTextDelta():
810 item.content[update.content_index].text += update.delta
811 case AssistantMessageContentPartAnnotationAdded():
812 annotations = item.content[update.content_index].annotations
813 if update.annotation_index <= len(annotations):
814 annotations.insert(update.annotation_index, update.annotation)
815 else:
816 annotations.append(update.annotation)
817 case AssistantMessageContentPartDone():
818 item.content[update.content_index] = update.content
819 return item
820
821 def _update_pending_items(
822 self,
823 pending_items: dict[str, ThreadItem],
824 event: ThreadItemUpdatedEvent,
825 ):
826 updated_item = pending_items.get(event.item_id)
827 update = event.update
828 match updated_item:
829 case AssistantMessageItem():
830 if isinstance(
831 update,
832 (
833 AssistantMessageContentPartAdded,
834 AssistantMessageContentPartTextDelta,
835 AssistantMessageContentPartAnnotationAdded,
836 AssistantMessageContentPartDone,
837 ),
838 ):
839 pending_items[updated_item.id] = (
840 self._apply_assistant_message_update(updated_item, update)
841 )
842 case WorkflowItem():
843 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
844 match update:
845 case WorkflowTaskUpdated():
846 updated_item.workflow.tasks[update.task_index] = update.task
847 case WorkflowTaskAdded():
848 updated_item.workflow.tasks.append(update.task)
849
850 pending_items[updated_item.id] = updated_item
851 case _:
852 pass
853
854 async def _build_user_message_item(
855 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
856 ) -> UserMessageItem:
857 return UserMessageItem(
858 id=self.store.generate_item_id("message", thread, context),
859 content=input.content,
860 thread_id=thread.id,
861 attachments=[
862 await self.store.load_attachment(attachment_id, context)
863 for attachment_id in input.attachments
864 ],
865 quoted_text=input.quoted_text,
866 inference_options=input.inference_options,
867 created_at=datetime.now(),
868 )
869
870 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
871 thread_meta = await self.store.load_thread(thread_id, context=context)
872 thread_items = await self.store.load_thread_items(
873 thread_id,
874 after=None,
875 limit=DEFAULT_PAGE_SIZE,
876 order="asc",
877 context=context,
878 )
879 return Thread(**thread_meta.model_dump(), items=thread_items)
880
881 async def _paginate_thread_items_reverse(
882 self, thread_id: str, context: TContext
883 ) -> AsyncIterator[ThreadItem]:
884 """Paginate through thread items in reverse order (newest first)."""
885 after = None
886 while True:
887 items = await self.store.load_thread_items(
888 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
889 )
890 for item in items.data:
891 yield item
892
893 if not items.has_more:
894 break
895 after = items.after
896
897 def _serialize(self, obj: BaseModel) -> bytes:
898 return obj.model_dump_json(
899 by_alias=True,
900 exclude_none=True,
901 context={"exclude_metadata": True},
902 ).encode("utf-8")
903
904 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
905 def is_hidden(item: ThreadItem) -> bool:
906 return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
907
908 items = thread.items if isinstance(thread, Thread) else Page()
909 items.data = [item for item in items.data if not is_hidden(item)]
910
911 return Thread(
912 id=thread.id,
913 title=thread.title,
914 created_at=thread.created_at,
915 items=items,
916 status=thread.status,
917 )
918