openai/chatkit-python

Public

Mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
6df9ee60da4de3f063e752e60bd3ceef920f4e78

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

883 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16 HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19 _HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
22from typing_extensions import TypeVar, assert_never
23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29 Action,
30 AssistantMessageContent,
31 AssistantMessageContentPartAdded,
32 AssistantMessageContentPartAnnotationAdded,
33 AssistantMessageContentPartDone,
34 AssistantMessageContentPartTextDelta,
35 AssistantMessageItem,
36 AttachmentsCreateReq,
37 AttachmentsDeleteReq,
38 ChatKitReq,
39 ClientToolCallItem,
40 ErrorCode,
41 ErrorEvent,
42 FeedbackKind,
43 HiddenContextItem,
44 ItemsFeedbackReq,
45 ItemsListReq,
46 NonStreamingReq,
47 Page,
48 SDKHiddenContextItem,
49 StreamingReq,
50 StreamOptions,
51 StreamOptionsEvent,
52 Thread,
53 ThreadCreatedEvent,
54 ThreadItem,
55 ThreadItemAddedEvent,
56 ThreadItemDoneEvent,
57 ThreadItemRemovedEvent,
58 ThreadItemReplacedEvent,
59 ThreadItemUpdatedEvent,
60 ThreadMetadata,
61 ThreadsAddClientToolOutputReq,
62 ThreadsAddUserMessageReq,
63 ThreadsCreateReq,
64 ThreadsCustomActionReq,
65 ThreadsDeleteReq,
66 ThreadsGetByIdReq,
67 ThreadsListReq,
68 ThreadsRetryAfterItemReq,
69 ThreadStreamEvent,
70 ThreadsUpdateReq,
71 ThreadUpdatedEvent,
72 UserMessageInput,
73 UserMessageItem,
74 WidgetComponentUpdated,
75 WidgetItem,
76 WidgetRootUpdated,
77 WidgetStreamingTextValueDelta,
78 WorkflowItem,
79 WorkflowTaskAdded,
80 WorkflowTaskUpdated,
81 is_streaming_req,
82)
83from .version import __version__
84from .widgets import Markdown, Text, WidgetComponent, WidgetComponentBase, WidgetRoot
85
# Page size used when a list request omits `limit`, and for the server's own
# paginated thread-item lookups.
DEFAULT_PAGE_SIZE: int = 20

# Generic fallback text for a failed response. Not referenced in this module's
# visible code; presumably imported elsewhere -- TODO confirm before removing.
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
88
89
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return the updates that turn ``before`` into ``after``.

    Only appends to the ``value`` of ``Markdown``/``Text`` components can be sent
    incrementally. If anything else differs between the two trees (component
    ``type``/``id``/``key``, any other prop, the number of items in a list prop),
    a single ``WidgetRootUpdated`` carrying ``after`` is returned. Otherwise the
    result holds one ``WidgetStreamingTextValueDelta`` per ``Markdown``/``Text``
    node with an ``id`` whose value grew, and is empty when nothing changed.

    Raises:
        ValueError: if a ``Markdown``/``Text`` node with an ``id`` in ``after`` has
            no counterpart in ``before``, or its new value does not extend its
            previous value. These are defensive checks on the incremental path.
    """

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True if ``after`` cannot be reached from ``before`` by text appends alone."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                # Lists must keep their length; elements are compared pairwise.
                if len(before_value) != len(after_value):
                    return True
                return any(
                    full_replace_value(nth_before, nth_after)
                    for nth_before, nth_after in zip(before_value, after_value)
                )
            if before_value == after_value:
                return False
            if isinstance(before_value, WidgetComponentBase) and isinstance(
                after_value, WidgetComponentBase
            ):
                # A changed nested component may still differ only by text appends.
                return full_replace(before_value, after_value)
            return True

        for field in before.model_fields_set | after.model_fields_set:
            if (
                isinstance(before, (Markdown, Text))
                and isinstance(after, (Markdown, Text))
                and field == "value"
                and after.value.startswith(before.value)
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, Markdown | Text]:
        """Index every Markdown/Text node that has an ``id`` by that id (depth-first)."""
        components: dict[str, Markdown | Text] = {}

        def recurse(node: WidgetComponent | WidgetRoot) -> None:
            if isinstance(node, (Markdown, Text)) and node.id:
                components[node.id] = node
            for child in getattr(node, "children", None) or []:
                recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        if before_node.value == after_node.value:
            continue
        if not after_node.value.startswith(before_node.value):
            raise ValueError(
                f"Node {component_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
            )
        deltas.append(
            WidgetStreamingTextValueDelta(
                component_id=component_id,
                # Only the appended suffix is sent to the client.
                delta=after_node.value[len(before_node.value) :],
                done=not after_node.streaming,
            )
        )

    return deltas
184
185
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget into ``thread`` as thread item events.

    A static ``WidgetRoot`` produces a single ``ThreadItemDoneEvent``. An async
    generator of ``WidgetRoot`` states produces:

    * a ``ThreadItemAddedEvent`` for its first state,
    * one ``ThreadItemUpdatedEvent`` per update returned by ``diff_widget``
      between each pair of consecutive states, and
    * a final ``ThreadItemDoneEvent`` carrying the last state.

    Args:
        thread: Thread the widget item belongs to.
        widget: A finished widget, or an async generator of successive states.
        copy_text: Optional text stored on the widget item as ``copy_text``.
        generate_id: Factory for the item id; called with ``"message"``.

    Raises:
        ValueError: if ``widget`` is an async generator that yields no state.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator body turns it into an
        # opaque RuntimeError (PEP 525), so report the actual problem instead.
        raise ValueError(
            "The widget generator finished without yielding an initial state."
        ) from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state
    # Continues the same generator after the initial state consumed above.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
235
236
@contextmanager
def agents_sdk_user_agent_override():
    """Set a ChatKit ``User-Agent`` header override for the duration of the block.

    Sets the header-override context variables imported from the Agents SDK's
    chat-completions and responses model modules to
    ``{"User-Agent": "Agents/Python <agents version> ChatKit/Python <chatkit version>"}``.
    On exit it restores their previous values, including when the body raises
    (for example on cancellation).
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})
    try:
        yield
    finally:
        # Without `finally`, an exception thrown into the `with` body would skip
        # these resets and leak the override into unrelated requests.
        responses_headers_override.reset(responses_token)
        chat_completions_headers_override.reset(chat_completions_token)
247
248
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable result of a streaming request.

    Wraps an async generator of byte chunks and re-yields them unchanged, in order.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # Kept on the public ``json_events`` attribute.
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
256
257
class NonStreamingResult:
    """Result of a non-streaming request: the complete serialized response body."""

    def __init__(self, result: bytes):
        # Exposed unchanged on the public ``json`` attribute.
        self.json = result
261
262
# Type of the per-request context object threaded through ChatKitServer and its
# Store; falls back to ``Any`` when a server class is not parameterized.
TContext = TypeVar("TContext", default=Any)
264
265
266class ChatKitServer(ABC, Generic[TContext]):
267 def __init__(
268 self,
269 store: Store[TContext],
270 attachment_store: AttachmentStore[TContext] | None = None,
271 ):
272 self.store = store
273 self.attachment_store = attachment_store
274
275 def _get_attachment_store(self) -> AttachmentStore[TContext]:
276 """Return the configured AttachmentStore or raise if missing."""
277 if self.attachment_store is None:
278 raise RuntimeError(
279 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
280 )
281 return self.attachment_store
282
283 @abstractmethod
284 def respond(
285 self,
286 thread: ThreadMetadata,
287 input_user_message: UserMessageItem | None,
288 context: TContext,
289 ) -> AsyncIterator[ThreadStreamEvent]:
290 """Stream `ThreadStreamEvent` instances for a new user message.
291
292 Args:
293 thread: Metadata for the thread being processed.
294 input_user_message: The incoming message the server should respond to, if any.
295 context: Arbitrary per-request context provided by the caller.
296
297 Returns:
298 An async iterator that yields events representing the server's response.
299 """
300 pass
301
302 async def add_feedback( # noqa: B027
303 self,
304 thread_id: str,
305 item_ids: list[str],
306 feedback: FeedbackKind,
307 context: TContext,
308 ) -> None:
309 pass
310
311 def action(
312 self,
313 thread: ThreadMetadata,
314 action: Action[str, Any],
315 sender: WidgetItem | None,
316 context: TContext,
317 ) -> AsyncIterator[ThreadStreamEvent]:
318 raise NotImplementedError(
319 "The action() method must be overridden to react to actions. "
320 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
321 )
322
323 def get_stream_options(
324 self, thread: ThreadMetadata, context: TContext
325 ) -> StreamOptions:
326 """
327 Return stream-level runtime options. Allows the user to cancel the stream by default.
328 Override this method to customize behavior.
329 """
330 return StreamOptions(allow_cancel=True)
331
332 async def handle_stream_cancelled(
333 self,
334 thread: ThreadMetadata,
335 pending_items: list[ThreadItem],
336 context: TContext,
337 ):
338 """Perform custom cleanup / stop inference when a stream is cancelled.
339 Updates you make here will not be reflected in the UI until a reload.
340
341 The default implementation persists any non-empty pending assistant messages
342 to the thread but does not auto-save pending widget items or workflow items.
343
344 Args:
345 thread: The thread that was being processed.
346 pending_items: Items that were not done streaming at cancellation time.
347 context: Arbitrary per-request context provided by the caller.
348 """
349 pending_assistant_message_items: list[AssistantMessageItem] = [
350 item for item in pending_items if isinstance(item, AssistantMessageItem)
351 ]
352 for item in pending_assistant_message_items:
353 is_empty = len(item.content) == 0 or all(
354 (not content.text.strip()) for content in item.content
355 )
356 if not is_empty:
357 await self.store.add_thread_item(thread.id, item, context=context)
358
359 # Add a hidden context item to the thread to indicate that the stream was cancelled.
360 # Otherwise, depending on the timing of the cancellation, subsequent responses may
361 # attempt to continue the cancelled response.
362 await self.store.add_thread_item(
363 thread.id,
364 SDKHiddenContextItem(
365 thread_id=thread.id,
366 created_at=datetime.now(),
367 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
368 content="The user cancelled the stream. Stop responding to the prior request.",
369 ),
370 context=context,
371 )
372
373 async def process(
374 self, request: str | bytes | bytearray, context: TContext
375 ) -> StreamingResult | NonStreamingResult:
376 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
377 logger.info(f"Received request op: {parsed_request.type}")
378
379 if is_streaming_req(parsed_request):
380 return StreamingResult(self._process_streaming(parsed_request, context))
381 else:
382 return NonStreamingResult(
383 await self._process_non_streaming(parsed_request, context)
384 )
385
386 async def _process_non_streaming(
387 self, request: NonStreamingReq, context: TContext
388 ) -> bytes:
389 match request:
390 case ThreadsGetByIdReq():
391 thread = await self._load_full_thread(
392 request.params.thread_id, context=context
393 )
394 return self._serialize(self._to_thread_response(thread))
395 case ThreadsListReq():
396 params = request.params
397 threads = await self.store.load_threads(
398 limit=params.limit or DEFAULT_PAGE_SIZE,
399 after=params.after,
400 order=params.order,
401 context=context,
402 )
403 return self._serialize(
404 Page(
405 has_more=threads.has_more,
406 after=threads.after,
407 data=[
408 self._to_thread_response(thread) for thread in threads.data
409 ],
410 )
411 )
412 case ItemsFeedbackReq():
413 await self.add_feedback(
414 request.params.thread_id,
415 request.params.item_ids,
416 request.params.kind,
417 context,
418 )
419 return b"{}"
420 case AttachmentsCreateReq():
421 attachment_store = self._get_attachment_store()
422 attachment = await attachment_store.create_attachment(
423 request.params, context
424 )
425 return self._serialize(attachment)
426 case AttachmentsDeleteReq():
427 attachment_store = self._get_attachment_store()
428 await attachment_store.delete_attachment(
429 request.params.attachment_id, context=context
430 )
431 await self.store.delete_attachment(
432 request.params.attachment_id, context=context
433 )
434 return b"{}"
435 case ItemsListReq():
436 items_list_params = request.params
437 items = await self.store.load_thread_items(
438 items_list_params.thread_id,
439 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
440 order=items_list_params.order,
441 after=items_list_params.after,
442 context=context,
443 )
444 # filter out hidden context items
445 items.data = [
446 item
447 for item in items.data
448 if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
449 ]
450 return self._serialize(items)
451 case ThreadsUpdateReq():
452 thread = await self.store.load_thread(
453 request.params.thread_id, context=context
454 )
455 thread.title = request.params.title
456 await self.store.save_thread(thread, context=context)
457 return self._serialize(self._to_thread_response(thread))
458 case ThreadsDeleteReq():
459 await self.store.delete_thread(
460 request.params.thread_id, context=context
461 )
462 return b"{}"
463 case _:
464 assert_never(request)
465
466 async def _process_streaming(
467 self, request: StreamingReq, context: TContext
468 ) -> AsyncGenerator[bytes, None]:
469 try:
470 async for event in self._process_streaming_impl(request, context):
471 b = self._serialize(event)
472 yield b"data: " + b + b"\n\n"
473 except asyncio.CancelledError:
474 # Let cancellation bubble up without logging as an error.
475 raise
476 except Exception:
477 logger.exception("Error while generating streamed response")
478 raise
479
480 async def _process_streaming_impl(
481 self, request: StreamingReq, context: TContext
482 ) -> AsyncGenerator[ThreadStreamEvent, None]:
483 match request:
484 case ThreadsCreateReq():
485 thread = Thread(
486 id=self.store.generate_thread_id(context),
487 created_at=datetime.now(),
488 items=Page(),
489 )
490 await self.store.save_thread(
491 ThreadMetadata(**thread.model_dump()),
492 context=context,
493 )
494 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
495 user_message = await self._build_user_message_item(
496 request.params.input, thread, context
497 )
498 async for event in self._process_new_thread_item_respond(
499 thread,
500 user_message,
501 context,
502 ):
503 yield event
504
505 case ThreadsAddUserMessageReq():
506 thread = await self.store.load_thread(
507 request.params.thread_id, context=context
508 )
509 user_message = await self._build_user_message_item(
510 request.params.input, thread, context
511 )
512 async for event in self._process_new_thread_item_respond(
513 thread,
514 user_message,
515 context,
516 ):
517 yield event
518
519 case ThreadsAddClientToolOutputReq():
520 thread = await self.store.load_thread(
521 request.params.thread_id, context=context
522 )
523 items = await self.store.load_thread_items(
524 thread.id, None, 1, "desc", context
525 )
526 tool_call = next(
527 (
528 item
529 for item in items.data
530 if isinstance(item, ClientToolCallItem)
531 and item.status == "pending"
532 ),
533 None,
534 )
535 if not tool_call:
536 raise ValueError(
537 f"Last thread item in {thread.id} was not a ClientToolCallItem"
538 )
539
540 tool_call.output = request.params.result
541 tool_call.status = "completed"
542
543 await self.store.save_item(thread.id, tool_call, context=context)
544
545 # Safety against dangling pending tool calls if there are
546 # multiple in a row, which should be impossible, and
547 # integrations should ultimately filter out pending tool calls
548 # when creating input response messages.
549 await self._cleanup_pending_client_tool_call(thread, context)
550
551 async for event in self._process_events(
552 thread,
553 context,
554 lambda: self.respond(thread, None, context),
555 ):
556 yield event
557
558 case ThreadsRetryAfterItemReq():
559 thread_metadata = await self.store.load_thread(
560 request.params.thread_id, context=context
561 )
562
563 # Collect items to remove (all items after the user message)
564 items_to_remove: list[ThreadItem] = []
565 user_message_item = None
566
567 async for item in self._paginate_thread_items_reverse(
568 request.params.thread_id, context
569 ):
570 if item.id == request.params.item_id:
571 if not isinstance(item, UserMessageItem):
572 raise ValueError(
573 f"Item {request.params.item_id} is not a user message"
574 )
575 user_message_item = item
576 break
577 items_to_remove.append(item)
578
579 if user_message_item:
580 for item in items_to_remove:
581 await self.store.delete_thread_item(
582 request.params.thread_id, item.id, context=context
583 )
584 async for event in self._process_events(
585 thread_metadata,
586 context,
587 lambda: self.respond(
588 thread_metadata,
589 user_message_item,
590 context,
591 ),
592 ):
593 yield event
594 case ThreadsCustomActionReq():
595 thread_metadata = await self.store.load_thread(
596 request.params.thread_id, context=context
597 )
598
599 item: ThreadItem | None = None
600 if request.params.item_id:
601 item = await self.store.load_item(
602 request.params.thread_id,
603 request.params.item_id,
604 context=context,
605 )
606
607 if item and not isinstance(item, WidgetItem):
608 # shouldn't happen if the caller is using the API correctly.
609 yield ErrorEvent(
610 code=ErrorCode.STREAM_ERROR,
611 allow_retry=False,
612 )
613 return
614
615 async for event in self._process_events(
616 thread_metadata,
617 context,
618 lambda: self.action(
619 thread_metadata,
620 request.params.action,
621 item,
622 context,
623 ),
624 ):
625 yield event
626
627 case _:
628 assert_never(request)
629
630 async def _cleanup_pending_client_tool_call(
631 self, thread: ThreadMetadata, context: TContext
632 ) -> None:
633 items = await self.store.load_thread_items(
634 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
635 )
636 for tool_call in items.data:
637 if not isinstance(tool_call, ClientToolCallItem):
638 continue
639 if tool_call.status == "pending":
640 logger.warning(
641 f"Client tool call {tool_call.call_id} was not completed, ignoring"
642 )
643 await self.store.delete_thread_item(
644 thread.id, tool_call.id, context=context
645 )
646
647 async def _process_new_thread_item_respond(
648 self,
649 thread: ThreadMetadata,
650 item: UserMessageItem,
651 context: TContext,
652 ) -> AsyncIterator[ThreadStreamEvent]:
653 await self.store.add_thread_item(thread.id, item, context=context)
654 await self._cleanup_pending_client_tool_call(thread, context)
655 yield ThreadItemDoneEvent(item=item)
656
657 async for event in self._process_events(
658 thread,
659 context,
660 lambda: self.respond(thread, item, context),
661 ):
662 yield event
663
664 async def _process_events(
665 self,
666 thread: ThreadMetadata,
667 context: TContext,
668 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
669 ) -> AsyncIterator[ThreadStreamEvent]:
670 await asyncio.sleep(0) # allow the response to start streaming
671
672 # Send initial stream options
673 yield StreamOptionsEvent(
674 stream_options=self.get_stream_options(thread, context)
675 )
676
677 last_thread = thread.model_copy(deep=True)
678
679 # Keep track of items that were streamed but not yet saved
680 # so that we can persist them when the stream is cancelled.
681 pending_items: dict[str, ThreadItem] = {}
682
683 try:
684 with agents_sdk_user_agent_override():
685 async for event in stream():
686 if isinstance(event, ThreadItemAddedEvent):
687 pending_items[event.item.id] = event.item
688
689 match event:
690 case ThreadItemDoneEvent():
691 await self.store.add_thread_item(
692 thread.id, event.item, context=context
693 )
694 pending_items.pop(event.item.id, None)
695 case ThreadItemRemovedEvent():
696 await self.store.delete_thread_item(
697 thread.id, event.item_id, context=context
698 )
699 pending_items.pop(event.item_id, None)
700 case ThreadItemReplacedEvent():
701 await self.store.save_item(
702 thread.id, event.item, context=context
703 )
704 pending_items.pop(event.item.id, None)
705 case ThreadItemUpdatedEvent():
706 # Keep pending assistant message and workflow items up to date
707 # so that we have a reference to the latest version of these pending items
708 # when the stream is cancelled.
709 self._update_pending_items(pending_items, event)
710
711 # special case - don't send hidden context items back to the client
712 should_swallow_event = isinstance(
713 event, ThreadItemDoneEvent
714 ) and isinstance(
715 event.item, (HiddenContextItem, SDKHiddenContextItem)
716 )
717
718 if not should_swallow_event:
719 yield event
720
721 # in case user updated the thread while streaming
722 if thread != last_thread:
723 last_thread = thread.model_copy(deep=True)
724 await self.store.save_thread(thread, context=context)
725 yield ThreadUpdatedEvent(
726 thread=self._to_thread_response(thread)
727 )
728 # in case user updated the thread while streaming
729 if thread != last_thread:
730 last_thread = thread.model_copy(deep=True)
731 await self.store.save_thread(thread, context=context)
732 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
733 except asyncio.CancelledError:
734 await self.handle_stream_cancelled(
735 thread, list(pending_items.values()), context
736 )
737 raise
738 except CustomStreamError as e:
739 yield ErrorEvent(
740 code="custom",
741 message=e.message,
742 allow_retry=e.allow_retry,
743 )
744 except StreamError as e:
745 yield ErrorEvent(
746 code=e.code,
747 allow_retry=e.allow_retry,
748 )
749 except Exception as e:
750 yield ErrorEvent(
751 code=ErrorCode.STREAM_ERROR,
752 allow_retry=True,
753 )
754 logger.exception(e)
755
756 if thread != last_thread:
757 # in case user updated the thread at the end of the stream
758 await self.store.save_thread(thread, context=context)
759 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
760
761 def _apply_assistant_message_update(
762 self,
763 item: AssistantMessageItem,
764 update: AssistantMessageContentPartAdded
765 | AssistantMessageContentPartTextDelta
766 | AssistantMessageContentPartAnnotationAdded
767 | AssistantMessageContentPartDone,
768 ) -> AssistantMessageItem:
769 updated = item.model_copy(deep=True)
770
771 # Pad the content list so the requested content_index exists before we write into it.
772 # (Streaming updates can arrive for an index that hasn’t been created yet)
773 while len(updated.content) <= update.content_index:
774 updated.content.append(AssistantMessageContent(text="", annotations=[]))
775
776 match update:
777 case AssistantMessageContentPartAdded():
778 updated.content[update.content_index] = update.content
779 case AssistantMessageContentPartTextDelta():
780 updated.content[update.content_index].text += update.delta
781 case AssistantMessageContentPartAnnotationAdded():
782 annotations = updated.content[update.content_index].annotations
783 if update.annotation_index <= len(annotations):
784 annotations.insert(update.annotation_index, update.annotation)
785 else:
786 annotations.append(update.annotation)
787 case AssistantMessageContentPartDone():
788 updated.content[update.content_index] = update.content
789 return updated
790
791 def _update_pending_items(
792 self,
793 pending_items: dict[str, ThreadItem],
794 event: ThreadItemUpdatedEvent,
795 ):
796 updated_item = pending_items.get(event.item_id)
797 update = event.update
798 match updated_item:
799 case AssistantMessageItem():
800 if isinstance(
801 update,
802 (
803 AssistantMessageContentPartAdded,
804 AssistantMessageContentPartTextDelta,
805 AssistantMessageContentPartAnnotationAdded,
806 AssistantMessageContentPartDone,
807 ),
808 ):
809 pending_items[updated_item.id] = (
810 self._apply_assistant_message_update(updated_item, update)
811 )
812 case WorkflowItem():
813 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
814 match update:
815 case WorkflowTaskUpdated():
816 updated_item.workflow.tasks[update.task_index] = update.task
817 case WorkflowTaskAdded():
818 updated_item.workflow.tasks.append(update.task)
819
820 pending_items[updated_item.id] = updated_item
821 case _:
822 pass
823
824 async def _build_user_message_item(
825 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
826 ) -> UserMessageItem:
827 return UserMessageItem(
828 id=self.store.generate_item_id("message", thread, context),
829 content=input.content,
830 thread_id=thread.id,
831 attachments=[
832 await self.store.load_attachment(attachment_id, context)
833 for attachment_id in input.attachments
834 ],
835 quoted_text=input.quoted_text,
836 inference_options=input.inference_options,
837 created_at=datetime.now(),
838 )
839
840 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
841 thread_meta = await self.store.load_thread(thread_id, context=context)
842 thread_items = await self.store.load_thread_items(
843 thread_id,
844 after=None,
845 limit=DEFAULT_PAGE_SIZE,
846 order="asc",
847 context=context,
848 )
849 return Thread(**thread_meta.model_dump(), items=thread_items)
850
851 async def _paginate_thread_items_reverse(
852 self, thread_id: str, context: TContext
853 ) -> AsyncIterator[ThreadItem]:
854 """Paginate through thread items in reverse order (newest first)."""
855 after = None
856 while True:
857 items = await self.store.load_thread_items(
858 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
859 )
860 for item in items.data:
861 yield item
862
863 if not items.has_more:
864 break
865 after = items.after
866
867 def _serialize(self, obj: BaseModel) -> bytes:
868 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
869
870 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
871 def is_hidden(item: ThreadItem) -> bool:
872 return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
873
874 items = thread.items if isinstance(thread, Thread) else Page()
875 items.data = [item for item in items.data if not is_hidden(item)]
876
877 return Thread(
878 id=thread.id,
879 title=thread.title,
880 created_at=thread.created_at,
881 items=items,
882 status=thread.status,
883 )
884