openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.5.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

896 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16 HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19 _HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
22from typing_extensions import TypeVar, assert_never
23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29 Action,
30 AssistantMessageContent,
31 AssistantMessageContentPartAdded,
32 AssistantMessageContentPartAnnotationAdded,
33 AssistantMessageContentPartDone,
34 AssistantMessageContentPartTextDelta,
35 AssistantMessageItem,
36 AttachmentsCreateReq,
37 AttachmentsDeleteReq,
38 ChatKitReq,
39 ClientToolCallItem,
40 ErrorCode,
41 ErrorEvent,
42 FeedbackKind,
43 HiddenContextItem,
44 ItemsFeedbackReq,
45 ItemsListReq,
46 NonStreamingReq,
47 Page,
48 SDKHiddenContextItem,
49 StreamingReq,
50 StreamOptions,
51 StreamOptionsEvent,
52 Thread,
53 ThreadCreatedEvent,
54 ThreadItem,
55 ThreadItemAddedEvent,
56 ThreadItemDoneEvent,
57 ThreadItemRemovedEvent,
58 ThreadItemReplacedEvent,
59 ThreadItemUpdatedEvent,
60 ThreadMetadata,
61 ThreadsAddClientToolOutputReq,
62 ThreadsAddUserMessageReq,
63 ThreadsCreateReq,
64 ThreadsCustomActionReq,
65 ThreadsDeleteReq,
66 ThreadsGetByIdReq,
67 ThreadsListReq,
68 ThreadsRetryAfterItemReq,
69 ThreadStreamEvent,
70 ThreadsUpdateReq,
71 ThreadUpdatedEvent,
72 UserMessageInput,
73 UserMessageItem,
74 WidgetComponentUpdated,
75 WidgetItem,
76 WidgetRootUpdated,
77 WidgetStreamingTextValueDelta,
78 WorkflowItem,
79 WorkflowTaskAdded,
80 WorkflowTaskUpdated,
81 is_streaming_req,
82)
83from .version import __version__
84from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
85
# Page size used when a list request omits ``limit`` and for the fixed-size
# item loads performed by ChatKitServer.
DEFAULT_PAGE_SIZE: int = 20
# Fallback text describing a failed response generation.
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
88
89
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    The result is one of:

    * ``[WidgetRootUpdated(widget=after)]`` when the change cannot be expressed as
      incremental text. This covers a component whose type, id or key changed, a
      list prop whose length changed, or any other prop value that differs.
    * One ``WidgetStreamingTextValueDelta`` per ``Markdown``/``Text`` component
      (addressed by its ``id``) whose string ``value`` grew by appending text. The
      delta is marked ``done`` unless the new component has a truthy
      ``streaming`` attribute.

    An empty list means no change was detected.

    Appends are only streamed as deltas for components that have an ``id``,
    because deltas are addressed by component id. An append to a text
    component without an id falls back to a full root update, so the change is
    not lost.

    Raises:
        ValueError: If an id-bearing streaming text component in ``after`` has no
            counterpart in ``before``, or if its value changed without being an
            append.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        # Only Markdown/Text components holding a string value can be streamed.
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True when ``before`` -> ``after`` needs a full root replacement."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            # Lists are compared element-wise; nested components recurse into
            # full_replace so that text appends deep in the tree stay incremental.
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                else:
                    return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                field == "value"
                # Deltas are addressed by id (ids are equal here, checked above).
                # Without an id no delta can be emitted, so the append must be
                # delivered through a full replace instead of being skipped.
                and before.id
                and is_streaming_text(before)
                and is_streaming_text(after)
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        """Map id -> component for every id-bearing streaming text node in the tree."""
        components = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if is_streaming_text(component) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            if not after_value.startswith(before_value):
                raise ValueError(
                    f"Node {component_id} was updated with a new value that is not a prefix of the initial value. All widget updates must be cumulative."
                )
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=component_id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas
192
193
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget root (or async sequence of roots) as ThreadStreamEvents.

    A plain ``WidgetRoot`` yields a single ``ThreadItemDoneEvent``.

    An async generator of roots yields events in this order:

    * a ``ThreadItemAddedEvent`` built from its first root;
    * a ``ThreadItemUpdatedEvent`` for every delta that ``diff_widget`` reports
      between consecutive roots;
    * a final ``ThreadItemDoneEvent`` carrying the last root.

    Args:
        thread: Thread the widget item belongs to.
        widget: A static root, or an async generator producing successive roots.
        copy_text: Optional text stored on the WidgetItem as ``copy_text``.
        generate_id: Factory used to create the widget item's id.
    """
    widget_item_id = generate_id("message")

    if isinstance(widget, AsyncGenerator):
        first_root = await widget.__anext__()
        in_progress = WidgetItem(
            id=widget_item_id,
            created_at=datetime.now(),
            widget=first_root,
            copy_text=copy_text,
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=in_progress)

        previous_root = first_root
        # ``async for`` resumes the same generator; it stops at StopAsyncIteration.
        async for current_root in widget:
            for delta in diff_widget(previous_root, current_root):
                yield ThreadItemUpdatedEvent(item_id=widget_item_id, update=delta)
            previous_root = current_root

        yield ThreadItemDoneEvent(
            item=in_progress.model_copy(update={"widget": previous_root}),
        )
        return

    static_item = WidgetItem(
        id=widget_item_id,
        thread_id=thread.id,
        created_at=datetime.now(),
        widget=widget,
        copy_text=copy_text,
    )
    yield ThreadItemDoneEvent(item=static_item)
244
245
@contextmanager
def agents_sdk_user_agent_override():
    """Tag Agents SDK model requests made inside the block with a ChatKit User-Agent.

    Sets the Agents SDK's Chat Completions and Responses header-override
    variables to ``{"User-Agent": "Agents/Python <x> ChatKit/Python <y>"}`` for
    the duration of the ``with`` block. Their previous values are restored on
    exit, including when the block exits with an exception (stream errors,
    task cancellation). Otherwise the override would leak into the rest of
    the task.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        for override, token in (
            (chat_completions_headers_override, chat_completions_token),
            (responses_headers_override, responses_token),
        ):
            try:
                override.reset(token)
            except ValueError:
                # ContextVar.reset raises ValueError when invoked from a different
                # Context than the one that created the token. This can happen when
                # an abandoned async generator is finalised by another task. There
                # is nothing to restore in that case, and raising here would mask
                # the exception that is unwinding the block.
                logger.debug(
                    "Skipped resetting Agents SDK header override from a different context"
                )
256
257
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around a stream of encoded response chunks.

    ``ChatKitServer.process`` returns this for streaming requests. Iterate it
    with ``async for`` to receive each chunk of bytes in order.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # Underlying chunk generator, kept public as ``json_events``.
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
265
266
class NonStreamingResult:
    """Complete JSON-encoded response body for a non-streaming request."""

    def __init__(self, result: bytes):
        # Serialized JSON payload, exposed unchanged as ``json``.
        self.json = result
270
271
# Type of the arbitrary per-request context object that callers pass through
# ChatKitServer and Store methods; defaults to Any when not parameterized.
TContext = TypeVar(
    "TContext",
    default=Any,
)
273
274
275class ChatKitServer(ABC, Generic[TContext]):
276 def __init__(
277 self,
278 store: Store[TContext],
279 attachment_store: AttachmentStore[TContext] | None = None,
280 ):
281 """Create a ChatKitServer with the backing Store and optional AttachmentStore."""
282 self.store = store
283 self.attachment_store = attachment_store
284
285 def _get_attachment_store(self) -> AttachmentStore[TContext]:
286 """Return the configured AttachmentStore or raise if missing."""
287 if self.attachment_store is None:
288 raise RuntimeError(
289 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
290 )
291 return self.attachment_store
292
293 @abstractmethod
294 def respond(
295 self,
296 thread: ThreadMetadata,
297 input_user_message: UserMessageItem | None,
298 context: TContext,
299 ) -> AsyncIterator[ThreadStreamEvent]:
300 """Stream `ThreadStreamEvent` instances for a new user message.
301
302 Args:
303 thread: Metadata for the thread being processed.
304 input_user_message: The incoming message the server should respond to, if any.
305 context: Arbitrary per-request context provided by the caller.
306
307 Returns:
308 An async iterator that yields events representing the server's response.
309 """
310 pass
311
312 async def add_feedback( # noqa: B027
313 self,
314 thread_id: str,
315 item_ids: list[str],
316 feedback: FeedbackKind,
317 context: TContext,
318 ) -> None:
319 """Persist user feedback for one or more thread items."""
320 pass
321
322 def action(
323 self,
324 thread: ThreadMetadata,
325 action: Action[str, Any],
326 sender: WidgetItem | None,
327 context: TContext,
328 ) -> AsyncIterator[ThreadStreamEvent]:
329 """Handle a widget or client-dispatched action and yield response events."""
330 raise NotImplementedError(
331 "The action() method must be overridden to react to actions. "
332 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
333 )
334
335 def get_stream_options(
336 self, thread: ThreadMetadata, context: TContext
337 ) -> StreamOptions:
338 """
339 Return stream-level runtime options. Allows the user to cancel the stream by default.
340 Override this method to customize behavior.
341 """
342 return StreamOptions(allow_cancel=True)
343
344 async def handle_stream_cancelled(
345 self,
346 thread: ThreadMetadata,
347 pending_items: list[ThreadItem],
348 context: TContext,
349 ):
350 """Perform custom cleanup / stop inference when a stream is cancelled.
351 Updates you make here will not be reflected in the UI until a reload.
352
353 The default implementation persists any non-empty pending assistant messages
354 to the thread but does not auto-save pending widget items or workflow items.
355
356 Args:
357 thread: The thread that was being processed.
358 pending_items: Items that were not done streaming at cancellation time.
359 context: Arbitrary per-request context provided by the caller.
360 """
361 pending_assistant_message_items: list[AssistantMessageItem] = [
362 item for item in pending_items if isinstance(item, AssistantMessageItem)
363 ]
364 for item in pending_assistant_message_items:
365 is_empty = len(item.content) == 0 or all(
366 (not content.text.strip()) for content in item.content
367 )
368 if not is_empty:
369 await self.store.add_thread_item(thread.id, item, context=context)
370
371 # Add a hidden context item to the thread to indicate that the stream was cancelled.
372 # Otherwise, depending on the timing of the cancellation, subsequent responses may
373 # attempt to continue the cancelled response.
374 await self.store.add_thread_item(
375 thread.id,
376 SDKHiddenContextItem(
377 thread_id=thread.id,
378 created_at=datetime.now(),
379 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
380 content="The user cancelled the stream. Stop responding to the prior request.",
381 ),
382 context=context,
383 )
384
385 async def process(
386 self, request: str | bytes | bytearray, context: TContext
387 ) -> StreamingResult | NonStreamingResult:
388 """Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
389 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
390 logger.info(f"Received request op: {parsed_request.type}")
391
392 if is_streaming_req(parsed_request):
393 return StreamingResult(self._process_streaming(parsed_request, context))
394 else:
395 return NonStreamingResult(
396 await self._process_non_streaming(parsed_request, context)
397 )
398
399 async def _process_non_streaming(
400 self, request: NonStreamingReq, context: TContext
401 ) -> bytes:
402 match request:
403 case ThreadsGetByIdReq():
404 thread = await self._load_full_thread(
405 request.params.thread_id, context=context
406 )
407 return self._serialize(self._to_thread_response(thread))
408 case ThreadsListReq():
409 params = request.params
410 threads = await self.store.load_threads(
411 limit=params.limit or DEFAULT_PAGE_SIZE,
412 after=params.after,
413 order=params.order,
414 context=context,
415 )
416 return self._serialize(
417 Page(
418 has_more=threads.has_more,
419 after=threads.after,
420 data=[
421 self._to_thread_response(thread) for thread in threads.data
422 ],
423 )
424 )
425 case ItemsFeedbackReq():
426 await self.add_feedback(
427 request.params.thread_id,
428 request.params.item_ids,
429 request.params.kind,
430 context,
431 )
432 return b"{}"
433 case AttachmentsCreateReq():
434 attachment_store = self._get_attachment_store()
435 attachment = await attachment_store.create_attachment(
436 request.params, context
437 )
438 await self.store.save_attachment(attachment, context=context)
439 return self._serialize(attachment)
440 case AttachmentsDeleteReq():
441 attachment_store = self._get_attachment_store()
442 await attachment_store.delete_attachment(
443 request.params.attachment_id, context=context
444 )
445 await self.store.delete_attachment(
446 request.params.attachment_id, context=context
447 )
448 return b"{}"
449 case ItemsListReq():
450 items_list_params = request.params
451 items = await self.store.load_thread_items(
452 items_list_params.thread_id,
453 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
454 order=items_list_params.order,
455 after=items_list_params.after,
456 context=context,
457 )
458 # filter out hidden context items
459 items.data = [
460 item
461 for item in items.data
462 if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
463 ]
464 return self._serialize(items)
465 case ThreadsUpdateReq():
466 thread = await self.store.load_thread(
467 request.params.thread_id, context=context
468 )
469 thread.title = request.params.title
470 await self.store.save_thread(thread, context=context)
471 return self._serialize(self._to_thread_response(thread))
472 case ThreadsDeleteReq():
473 await self.store.delete_thread(
474 request.params.thread_id, context=context
475 )
476 return b"{}"
477 case _:
478 assert_never(request)
479
480 async def _process_streaming(
481 self, request: StreamingReq, context: TContext
482 ) -> AsyncGenerator[bytes, None]:
483 try:
484 async for event in self._process_streaming_impl(request, context):
485 b = self._serialize(event)
486 yield b"data: " + b + b"\n\n"
487 except asyncio.CancelledError:
488 # Let cancellation bubble up without logging as an error.
489 raise
490 except Exception:
491 logger.exception("Error while generating streamed response")
492 raise
493
494 async def _process_streaming_impl(
495 self, request: StreamingReq, context: TContext
496 ) -> AsyncGenerator[ThreadStreamEvent, None]:
497 match request:
498 case ThreadsCreateReq():
499 thread = Thread(
500 id=self.store.generate_thread_id(context),
501 created_at=datetime.now(),
502 items=Page(),
503 )
504 await self.store.save_thread(
505 ThreadMetadata(**thread.model_dump()),
506 context=context,
507 )
508 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
509 user_message = await self._build_user_message_item(
510 request.params.input, thread, context
511 )
512 async for event in self._process_new_thread_item_respond(
513 thread,
514 user_message,
515 context,
516 ):
517 yield event
518
519 case ThreadsAddUserMessageReq():
520 thread = await self.store.load_thread(
521 request.params.thread_id, context=context
522 )
523 user_message = await self._build_user_message_item(
524 request.params.input, thread, context
525 )
526 async for event in self._process_new_thread_item_respond(
527 thread,
528 user_message,
529 context,
530 ):
531 yield event
532
533 case ThreadsAddClientToolOutputReq():
534 thread = await self.store.load_thread(
535 request.params.thread_id, context=context
536 )
537 items = await self.store.load_thread_items(
538 thread.id, None, 1, "desc", context
539 )
540 tool_call = next(
541 (
542 item
543 for item in items.data
544 if isinstance(item, ClientToolCallItem)
545 and item.status == "pending"
546 ),
547 None,
548 )
549 if not tool_call:
550 raise ValueError(
551 f"Last thread item in {thread.id} was not a ClientToolCallItem"
552 )
553
554 tool_call.output = request.params.result
555 tool_call.status = "completed"
556
557 await self.store.save_item(thread.id, tool_call, context=context)
558
559 # Safety against dangling pending tool calls if there are
560 # multiple in a row, which should be impossible, and
561 # integrations should ultimately filter out pending tool calls
562 # when creating input response messages.
563 await self._cleanup_pending_client_tool_call(thread, context)
564
565 async for event in self._process_events(
566 thread,
567 context,
568 lambda: self.respond(thread, None, context),
569 ):
570 yield event
571
572 case ThreadsRetryAfterItemReq():
573 thread_metadata = await self.store.load_thread(
574 request.params.thread_id, context=context
575 )
576
577 # Collect items to remove (all items after the user message)
578 items_to_remove: list[ThreadItem] = []
579 user_message_item = None
580
581 async for item in self._paginate_thread_items_reverse(
582 request.params.thread_id, context
583 ):
584 if item.id == request.params.item_id:
585 if not isinstance(item, UserMessageItem):
586 raise ValueError(
587 f"Item {request.params.item_id} is not a user message"
588 )
589 user_message_item = item
590 break
591 items_to_remove.append(item)
592
593 if user_message_item:
594 for item in items_to_remove:
595 await self.store.delete_thread_item(
596 request.params.thread_id, item.id, context=context
597 )
598 async for event in self._process_events(
599 thread_metadata,
600 context,
601 lambda: self.respond(
602 thread_metadata,
603 user_message_item,
604 context,
605 ),
606 ):
607 yield event
608 case ThreadsCustomActionReq():
609 thread_metadata = await self.store.load_thread(
610 request.params.thread_id, context=context
611 )
612
613 item: ThreadItem | None = None
614 if request.params.item_id:
615 item = await self.store.load_item(
616 request.params.thread_id,
617 request.params.item_id,
618 context=context,
619 )
620
621 if item and not isinstance(item, WidgetItem):
622 # shouldn't happen if the caller is using the API correctly.
623 yield ErrorEvent(
624 code=ErrorCode.STREAM_ERROR,
625 allow_retry=False,
626 )
627 return
628
629 async for event in self._process_events(
630 thread_metadata,
631 context,
632 lambda: self.action(
633 thread_metadata,
634 request.params.action,
635 item,
636 context,
637 ),
638 ):
639 yield event
640
641 case _:
642 assert_never(request)
643
644 async def _cleanup_pending_client_tool_call(
645 self, thread: ThreadMetadata, context: TContext
646 ) -> None:
647 items = await self.store.load_thread_items(
648 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
649 )
650 for tool_call in items.data:
651 if not isinstance(tool_call, ClientToolCallItem):
652 continue
653 if tool_call.status == "pending":
654 logger.warning(
655 f"Client tool call {tool_call.call_id} was not completed, ignoring"
656 )
657 await self.store.delete_thread_item(
658 thread.id, tool_call.id, context=context
659 )
660
661 async def _process_new_thread_item_respond(
662 self,
663 thread: ThreadMetadata,
664 item: UserMessageItem,
665 context: TContext,
666 ) -> AsyncIterator[ThreadStreamEvent]:
667 await self.store.add_thread_item(thread.id, item, context=context)
668 yield ThreadItemDoneEvent(item=item)
669
670 async for event in self._process_events(
671 thread,
672 context,
673 lambda: self.respond(thread, item, context),
674 ):
675 yield event
676
677 async def _process_events(
678 self,
679 thread: ThreadMetadata,
680 context: TContext,
681 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
682 ) -> AsyncIterator[ThreadStreamEvent]:
683 await asyncio.sleep(0) # allow the response to start streaming
684
685 # Send initial stream options
686 yield StreamOptionsEvent(
687 stream_options=self.get_stream_options(thread, context)
688 )
689
690 last_thread = thread.model_copy(deep=True)
691
692 # Keep track of items that were streamed but not yet saved
693 # so that we can persist them when the stream is cancelled.
694 pending_items: dict[str, ThreadItem] = {}
695
696 try:
697 with agents_sdk_user_agent_override():
698 async for event in stream():
699 if isinstance(event, ThreadItemAddedEvent):
700 # Stash an isolated copy in case we need to persist unfinished items
701 # on cancellation; downstream handlers keep using the original event.item.
702 pending_items[event.item.id] = event.item.model_copy(deep=True)
703
704 match event:
705 case ThreadItemDoneEvent():
706 await self.store.add_thread_item(
707 thread.id, event.item, context=context
708 )
709 pending_items.pop(event.item.id, None)
710 case ThreadItemRemovedEvent():
711 await self.store.delete_thread_item(
712 thread.id, event.item_id, context=context
713 )
714 pending_items.pop(event.item_id, None)
715 case ThreadItemReplacedEvent():
716 await self.store.save_item(
717 thread.id, event.item, context=context
718 )
719 pending_items.pop(event.item.id, None)
720 case ThreadItemUpdatedEvent():
721 # Keep pending assistant message and workflow items up to date
722 # so that we have a reference to the latest version of these pending items
723 # when the stream is cancelled.
724 self._update_pending_items(pending_items, event)
725
726 # special case - don't send hidden context items back to the client
727 should_swallow_event = isinstance(
728 event, ThreadItemDoneEvent
729 ) and isinstance(
730 event.item, (HiddenContextItem, SDKHiddenContextItem)
731 )
732
733 if not should_swallow_event:
734 yield event
735
736 # in case user updated the thread while streaming
737 if thread != last_thread:
738 last_thread = thread.model_copy(deep=True)
739 await self.store.save_thread(thread, context=context)
740 yield ThreadUpdatedEvent(
741 thread=self._to_thread_response(thread)
742 )
743 # in case user updated the thread while streaming
744 if thread != last_thread:
745 last_thread = thread.model_copy(deep=True)
746 await self.store.save_thread(thread, context=context)
747 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
748 except asyncio.CancelledError:
749 await self.handle_stream_cancelled(
750 thread, list(pending_items.values()), context
751 )
752 raise
753 except CustomStreamError as e:
754 yield ErrorEvent(
755 code="custom",
756 message=e.message,
757 allow_retry=e.allow_retry,
758 )
759 except StreamError as e:
760 yield ErrorEvent(
761 code=e.code,
762 allow_retry=e.allow_retry,
763 )
764 except Exception as e:
765 yield ErrorEvent(
766 code=ErrorCode.STREAM_ERROR,
767 allow_retry=True,
768 )
769 logger.exception(e)
770
771 if thread != last_thread:
772 # in case user updated the thread at the end of the stream
773 await self.store.save_thread(thread, context=context)
774 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
775
776 def _apply_assistant_message_update(
777 self,
778 item: AssistantMessageItem,
779 update: AssistantMessageContentPartAdded
780 | AssistantMessageContentPartTextDelta
781 | AssistantMessageContentPartAnnotationAdded
782 | AssistantMessageContentPartDone,
783 ) -> AssistantMessageItem:
784 # Pad the content list so the requested content_index exists before we write into it.
785 # (Streaming updates can arrive for an index that hasn’t been created yet)
786 while len(item.content) <= update.content_index:
787 item.content.append(AssistantMessageContent(text="", annotations=[]))
788
789 match update:
790 case AssistantMessageContentPartAdded():
791 item.content[update.content_index] = update.content
792 case AssistantMessageContentPartTextDelta():
793 item.content[update.content_index].text += update.delta
794 case AssistantMessageContentPartAnnotationAdded():
795 annotations = item.content[update.content_index].annotations
796 if update.annotation_index <= len(annotations):
797 annotations.insert(update.annotation_index, update.annotation)
798 else:
799 annotations.append(update.annotation)
800 case AssistantMessageContentPartDone():
801 item.content[update.content_index] = update.content
802 return item
803
804 def _update_pending_items(
805 self,
806 pending_items: dict[str, ThreadItem],
807 event: ThreadItemUpdatedEvent,
808 ):
809 updated_item = pending_items.get(event.item_id)
810 update = event.update
811 match updated_item:
812 case AssistantMessageItem():
813 if isinstance(
814 update,
815 (
816 AssistantMessageContentPartAdded,
817 AssistantMessageContentPartTextDelta,
818 AssistantMessageContentPartAnnotationAdded,
819 AssistantMessageContentPartDone,
820 ),
821 ):
822 pending_items[updated_item.id] = (
823 self._apply_assistant_message_update(updated_item, update)
824 )
825 case WorkflowItem():
826 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
827 match update:
828 case WorkflowTaskUpdated():
829 updated_item.workflow.tasks[update.task_index] = update.task
830 case WorkflowTaskAdded():
831 updated_item.workflow.tasks.append(update.task)
832
833 pending_items[updated_item.id] = updated_item
834 case _:
835 pass
836
837 async def _build_user_message_item(
838 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
839 ) -> UserMessageItem:
840 return UserMessageItem(
841 id=self.store.generate_item_id("message", thread, context),
842 content=input.content,
843 thread_id=thread.id,
844 attachments=[
845 await self.store.load_attachment(attachment_id, context)
846 for attachment_id in input.attachments
847 ],
848 quoted_text=input.quoted_text,
849 inference_options=input.inference_options,
850 created_at=datetime.now(),
851 )
852
853 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
854 thread_meta = await self.store.load_thread(thread_id, context=context)
855 thread_items = await self.store.load_thread_items(
856 thread_id,
857 after=None,
858 limit=DEFAULT_PAGE_SIZE,
859 order="asc",
860 context=context,
861 )
862 return Thread(**thread_meta.model_dump(), items=thread_items)
863
864 async def _paginate_thread_items_reverse(
865 self, thread_id: str, context: TContext
866 ) -> AsyncIterator[ThreadItem]:
867 """Paginate through thread items in reverse order (newest first)."""
868 after = None
869 while True:
870 items = await self.store.load_thread_items(
871 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
872 )
873 for item in items.data:
874 yield item
875
876 if not items.has_more:
877 break
878 after = items.after
879
880 def _serialize(self, obj: BaseModel) -> bytes:
881 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
882
883 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
884 def is_hidden(item: ThreadItem) -> bool:
885 return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
886
887 items = thread.items if isinstance(thread, Thread) else Page()
888 items.data = [item for item in items.data if not is_hidden(item)]
889
890 return Thread(
891 id=thread.id,
892 title=thread.title,
893 created_at=thread.created_at,
894 items=items,
895 status=thread.status,
896 )
897