openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
5a7244ba1337ba55dcb1f9e4a96cc5527b00bc93

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

881 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16 HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19 _HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
22from typing_extensions import TypeVar, assert_never
23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29 Action,
30 AssistantMessageContent,
31 AssistantMessageContentPartAdded,
32 AssistantMessageContentPartAnnotationAdded,
33 AssistantMessageContentPartDone,
34 AssistantMessageContentPartTextDelta,
35 AssistantMessageItem,
36 AttachmentsCreateReq,
37 AttachmentsDeleteReq,
38 ChatKitReq,
39 ClientToolCallItem,
40 ErrorCode,
41 ErrorEvent,
42 FeedbackKind,
43 HiddenContextItem,
44 ItemsFeedbackReq,
45 ItemsListReq,
46 NonStreamingReq,
47 Page,
48 SDKHiddenContextItem,
49 StreamingReq,
50 StreamOptions,
51 StreamOptionsEvent,
52 Thread,
53 ThreadCreatedEvent,
54 ThreadItem,
55 ThreadItemAddedEvent,
56 ThreadItemDoneEvent,
57 ThreadItemRemovedEvent,
58 ThreadItemReplacedEvent,
59 ThreadItemUpdatedEvent,
60 ThreadMetadata,
61 ThreadsAddClientToolOutputReq,
62 ThreadsAddUserMessageReq,
63 ThreadsCreateReq,
64 ThreadsCustomActionReq,
65 ThreadsDeleteReq,
66 ThreadsGetByIdReq,
67 ThreadsListReq,
68 ThreadsRetryAfterItemReq,
69 ThreadStreamEvent,
70 ThreadsUpdateReq,
71 ThreadUpdatedEvent,
72 UserMessageInput,
73 UserMessageItem,
74 WidgetComponentUpdated,
75 WidgetItem,
76 WidgetRootUpdated,
77 WidgetStreamingTextValueDelta,
78 WorkflowItem,
79 WorkflowTaskAdded,
80 WorkflowTaskUpdated,
81 is_streaming_req,
82)
83from .version import __version__
84from .widgets import Markdown, Text, WidgetComponent, WidgetComponentBase, WidgetRoot
85
# Page size used when a list request omits ``limit``, and for the server's own
# internal pagination over thread items.
DEFAULT_PAGE_SIZE: int = 20
# Default user-facing error text (not referenced elsewhere in this module).
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
88
89
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    If ``after`` differs from ``before`` in any way other than text appended to
    the ``value`` of a Markdown/Text component that has an ``id``, a single
    ``WidgetRootUpdated`` carrying ``after`` is returned. Otherwise, one
    ``WidgetStreamingTextValueDelta`` is returned for every id-bearing
    Markdown/Text component (found by walking ``children``) whose value grew.
    The list is empty when nothing changed.

    Raises:
        ValueError: if an id-bearing text component in ``after`` has no
            counterpart in ``before``, or if its value does not extend the
            previous value.
    """

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True when ``before`` cannot be turned into ``after`` by text deltas."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                return any(
                    full_replace_value(nth_before, nth_after)
                    for nth_before, nth_after in zip(before_value, after_value)
                )
            if before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                field == "value"
                and isinstance(before, (Markdown, Text))
                and isinstance(after, (Markdown, Text))
                # Only id-bearing components are streamed as deltas below; an
                # append to an id-less component must fall back to a full
                # replace, otherwise the change would never reach the client.
                and before.id
                and after.value.startswith(before.value)
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, Markdown | Text]:
        """Map id -> Markdown/Text component for every id-bearing text node.

        Traversal follows the ``children`` attribute only.
        """
        components: dict[str, Markdown | Text] = {}

        def recurse(node: WidgetComponent | WidgetRoot) -> None:
            if isinstance(node, (Markdown, Text)) and node.id:
                components[node.id] = node
            for child in getattr(node, "children", None) or []:
                recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        if before_node.value != after_node.value:
            if not after_node.value.startswith(before_node.value):
                raise ValueError(
                    f"Node {component_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=component_id,
                    delta=after_node.value[len(before_node.value) :],
                    done=not after_node.streaming,
                )
            )

    return deltas
184
185
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Emit thread events that render ``widget`` as a WidgetItem in ``thread``.

    A static ``WidgetRoot`` produces a single ``ThreadItemDoneEvent``. An async
    generator of ``WidgetRoot`` states produces a ``ThreadItemAddedEvent`` for
    the first state, then ``ThreadItemUpdatedEvent``s computed with
    :func:`diff_widget` between consecutive states, and finally a
    ``ThreadItemDoneEvent`` carrying the last state.

    Args:
        thread: Thread the widget item belongs to.
        widget: A widget, or an async generator yielding successive widget states.
        copy_text: Optional text stored on the WidgetItem as ``copy_text``.
        generate_id: Factory used to create the item id (called with ``"message"``).

    Raises:
        ValueError: if ``widget`` is an async generator that yields no states,
            or if :func:`diff_widget` rejects an update.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator turns it into an
        # opaque RuntimeError; report the actual problem instead.
        raise ValueError(
            "stream_widget received a widget generator that yielded no widget states."
        ) from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state
    # Continue from where __anext__ left off; each new state is diffed against
    # the previous one so only incremental updates are sent.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
235
236
@contextmanager
def agents_sdk_user_agent_override():
    """Set a ChatKit-specific User-Agent on Agents SDK model requests.

    While the context is active, the Agents SDK's Chat Completions and
    Responses header-override context variables are set to
    ``{"User-Agent": ua}``. Both are reset on exit, including when the body
    raises or is cancelled.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Reset in reverse order of setting. Each reset is attempted even if the
        # other fails.
        for header_var, token in (
            (responses_headers_override, responses_token),
            (chat_completions_headers_override, chat_completions_token),
        ):
            try:
                header_var.reset(token)
            except ValueError:
                # ContextVar.reset raises ValueError when called from a different
                # Context than the one that created the token. This can happen
                # when an async generator holding this context manager is
                # finalized from another task. The original Context cannot be
                # restored from here, so this is deliberately best-effort.
                logger.debug(
                    "Could not reset Agents SDK header override from a different context"
                )
247
248
class StreamingResult(AsyncIterable[bytes]):
    """Result of a streaming request: an async iterable of encoded byte chunks.

    Iterating the result re-yields every chunk of the wrapped generator,
    unchanged and in order.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # Source generator of already-encoded byte chunks.
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
256
257
class NonStreamingResult:
    """Result of a non-streaming request: a single, already-serialized body."""

    def __init__(self, result: bytes):
        # Complete response payload as bytes.
        self.json = result
261
262
# Type of the per-request context object passed through the server and its
# stores. It defaults to ``Any``, so an unparameterized ``ChatKitServer`` accepts
# any context.
TContext = TypeVar("TContext", default=Any)
264
265
266class ChatKitServer(ABC, Generic[TContext]):
267 def __init__(
268 self,
269 store: Store[TContext],
270 attachment_store: AttachmentStore[TContext] | None = None,
271 ):
272 self.store = store
273 self.attachment_store = attachment_store
274
275 def _get_attachment_store(self) -> AttachmentStore[TContext]:
276 """Return the configured AttachmentStore or raise if missing."""
277 if self.attachment_store is None:
278 raise RuntimeError(
279 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
280 )
281 return self.attachment_store
282
283 @abstractmethod
284 def respond(
285 self,
286 thread: ThreadMetadata,
287 input_user_message: UserMessageItem | None,
288 context: TContext,
289 ) -> AsyncIterator[ThreadStreamEvent]:
290 """Stream `ThreadStreamEvent` instances for a new user message.
291
292 Args:
293 thread: Metadata for the thread being processed.
294 input_user_message: The incoming message the server should respond to, if any.
295 context: Arbitrary per-request context provided by the caller.
296
297 Returns:
298 An async iterator that yields events representing the server's response.
299 """
300 pass
301
302 async def add_feedback( # noqa: B027
303 self,
304 thread_id: str,
305 item_ids: list[str],
306 feedback: FeedbackKind,
307 context: TContext,
308 ) -> None:
309 pass
310
311 def action(
312 self,
313 thread: ThreadMetadata,
314 action: Action[str, Any],
315 sender: WidgetItem | None,
316 context: TContext,
317 ) -> AsyncIterator[ThreadStreamEvent]:
318 raise NotImplementedError(
319 "The action() method must be overridden to react to actions. "
320 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
321 )
322
323 def get_stream_options(
324 self, thread: ThreadMetadata, context: TContext
325 ) -> StreamOptions:
326 """
327 Return stream-level runtime options. Allows the user to cancel the stream by default.
328 Override this method to customize behavior.
329 """
330 return StreamOptions(allow_cancel=True)
331
332 async def handle_stream_cancelled(
333 self,
334 thread: ThreadMetadata,
335 pending_items: list[ThreadItem],
336 context: TContext,
337 ):
338 """Perform custom cleanup / stop inference when a stream is cancelled.
339 Updates you make here will not be reflected in the UI until a reload.
340
341 The default implementation persists any non-empty pending assistant messages
342 to the thread but does not auto-save pending widget items or workflow items.
343
344 Args:
345 thread: The thread that was being processed.
346 pending_items: Items that were not done streaming at cancellation time.
347 context: Arbitrary per-request context provided by the caller.
348 """
349 pending_assistant_message_items: list[AssistantMessageItem] = [
350 item for item in pending_items if isinstance(item, AssistantMessageItem)
351 ]
352 for item in pending_assistant_message_items:
353 is_empty = len(item.content) == 0 or all(
354 (not content.text.strip()) for content in item.content
355 )
356 if not is_empty:
357 await self.store.add_thread_item(thread.id, item, context=context)
358
359 # Add a hidden context item to the thread to indicate that the stream was cancelled.
360 # Otherwise, depending on the timing of the cancellation, subsequent responses may
361 # attempt to continue the cancelled response.
362 await self.store.add_thread_item(
363 thread.id,
364 SDKHiddenContextItem(
365 thread_id=thread.id,
366 created_at=datetime.now(),
367 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
368 content="The user cancelled the stream.",
369 ),
370 context=context,
371 )
372
373 async def process(
374 self, request: str | bytes | bytearray, context: TContext
375 ) -> StreamingResult | NonStreamingResult:
376 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
377 logger.info(f"Received request op: {parsed_request.type}")
378
379 if is_streaming_req(parsed_request):
380 return StreamingResult(self._process_streaming(parsed_request, context))
381 else:
382 return NonStreamingResult(
383 await self._process_non_streaming(parsed_request, context)
384 )
385
386 async def _process_non_streaming(
387 self, request: NonStreamingReq, context: TContext
388 ) -> bytes:
389 match request:
390 case ThreadsGetByIdReq():
391 thread = await self._load_full_thread(
392 request.params.thread_id, context=context
393 )
394 return self._serialize(self._to_thread_response(thread))
395 case ThreadsListReq():
396 params = request.params
397 threads = await self.store.load_threads(
398 limit=params.limit or DEFAULT_PAGE_SIZE,
399 after=params.after,
400 order=params.order,
401 context=context,
402 )
403 return self._serialize(
404 Page(
405 has_more=threads.has_more,
406 after=threads.after,
407 data=[
408 self._to_thread_response(thread) for thread in threads.data
409 ],
410 )
411 )
412 case ItemsFeedbackReq():
413 await self.add_feedback(
414 request.params.thread_id,
415 request.params.item_ids,
416 request.params.kind,
417 context,
418 )
419 return b"{}"
420 case AttachmentsCreateReq():
421 attachment_store = self._get_attachment_store()
422 attachment = await attachment_store.create_attachment(
423 request.params, context
424 )
425 return self._serialize(attachment)
426 case AttachmentsDeleteReq():
427 attachment_store = self._get_attachment_store()
428 await attachment_store.delete_attachment(
429 request.params.attachment_id, context=context
430 )
431 await self.store.delete_attachment(
432 request.params.attachment_id, context=context
433 )
434 return b"{}"
435 case ItemsListReq():
436 items_list_params = request.params
437 items = await self.store.load_thread_items(
438 items_list_params.thread_id,
439 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
440 order=items_list_params.order,
441 after=items_list_params.after,
442 context=context,
443 )
444 # filter out HiddenContextItems
445 items.data = [
446 item
447 for item in items.data
448 if not isinstance(item, HiddenContextItem)
449 ]
450 return self._serialize(items)
451 case ThreadsUpdateReq():
452 thread = await self.store.load_thread(
453 request.params.thread_id, context=context
454 )
455 thread.title = request.params.title
456 await self.store.save_thread(thread, context=context)
457 return self._serialize(self._to_thread_response(thread))
458 case ThreadsDeleteReq():
459 await self.store.delete_thread(
460 request.params.thread_id, context=context
461 )
462 return b"{}"
463 case _:
464 assert_never(request)
465
466 async def _process_streaming(
467 self, request: StreamingReq, context: TContext
468 ) -> AsyncGenerator[bytes, None]:
469 try:
470 async for event in self._process_streaming_impl(request, context):
471 b = self._serialize(event)
472 yield b"data: " + b + b"\n\n"
473 except asyncio.CancelledError:
474 # Let cancellation bubble up without logging as an error.
475 raise
476 except Exception:
477 logger.exception("Error while generating streamed response")
478 raise
479
480 async def _process_streaming_impl(
481 self, request: StreamingReq, context: TContext
482 ) -> AsyncGenerator[ThreadStreamEvent, None]:
483 match request:
484 case ThreadsCreateReq():
485 thread = Thread(
486 id=self.store.generate_thread_id(context),
487 created_at=datetime.now(),
488 items=Page(),
489 )
490 await self.store.save_thread(
491 ThreadMetadata(**thread.model_dump()),
492 context=context,
493 )
494 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
495 user_message = await self._build_user_message_item(
496 request.params.input, thread, context
497 )
498 async for event in self._process_new_thread_item_respond(
499 thread,
500 user_message,
501 context,
502 ):
503 yield event
504
505 case ThreadsAddUserMessageReq():
506 thread = await self.store.load_thread(
507 request.params.thread_id, context=context
508 )
509 user_message = await self._build_user_message_item(
510 request.params.input, thread, context
511 )
512 async for event in self._process_new_thread_item_respond(
513 thread,
514 user_message,
515 context,
516 ):
517 yield event
518
519 case ThreadsAddClientToolOutputReq():
520 thread = await self.store.load_thread(
521 request.params.thread_id, context=context
522 )
523 items = await self.store.load_thread_items(
524 thread.id, None, 1, "desc", context
525 )
526 tool_call = next(
527 (
528 item
529 for item in items.data
530 if isinstance(item, ClientToolCallItem)
531 and item.status == "pending"
532 ),
533 None,
534 )
535 if not tool_call:
536 raise ValueError(
537 f"Last thread item in {thread.id} was not a ClientToolCallItem"
538 )
539
540 tool_call.output = request.params.result
541 tool_call.status = "completed"
542
543 await self.store.save_item(thread.id, tool_call, context=context)
544
545 # Safety against dangling pending tool calls if there are
546 # multiple in a row, which should be impossible, and
547 # integrations should ultimately filter out pending tool calls
548 # when creating input response messages.
549 await self._cleanup_pending_client_tool_call(thread, context)
550
551 async for event in self._process_events(
552 thread,
553 context,
554 lambda: self.respond(thread, None, context),
555 ):
556 yield event
557
558 case ThreadsRetryAfterItemReq():
559 thread_metadata = await self.store.load_thread(
560 request.params.thread_id, context=context
561 )
562
563 # Collect items to remove (all items after the user message)
564 items_to_remove: list[ThreadItem] = []
565 user_message_item = None
566
567 async for item in self._paginate_thread_items_reverse(
568 request.params.thread_id, context
569 ):
570 if item.id == request.params.item_id:
571 if not isinstance(item, UserMessageItem):
572 raise ValueError(
573 f"Item {request.params.item_id} is not a user message"
574 )
575 user_message_item = item
576 break
577 items_to_remove.append(item)
578
579 if user_message_item:
580 for item in items_to_remove:
581 await self.store.delete_thread_item(
582 request.params.thread_id, item.id, context=context
583 )
584 async for event in self._process_events(
585 thread_metadata,
586 context,
587 lambda: self.respond(
588 thread_metadata,
589 user_message_item,
590 context,
591 ),
592 ):
593 yield event
594 case ThreadsCustomActionReq():
595 thread_metadata = await self.store.load_thread(
596 request.params.thread_id, context=context
597 )
598
599 item: ThreadItem | None = None
600 if request.params.item_id:
601 item = await self.store.load_item(
602 request.params.thread_id,
603 request.params.item_id,
604 context=context,
605 )
606
607 if item and not isinstance(item, WidgetItem):
608 # shouldn't happen if the caller is using the API correctly.
609 yield ErrorEvent(
610 code=ErrorCode.STREAM_ERROR,
611 allow_retry=False,
612 )
613 return
614
615 async for event in self._process_events(
616 thread_metadata,
617 context,
618 lambda: self.action(
619 thread_metadata,
620 request.params.action,
621 item,
622 context,
623 ),
624 ):
625 yield event
626
627 case _:
628 assert_never(request)
629
630 async def _cleanup_pending_client_tool_call(
631 self, thread: ThreadMetadata, context: TContext
632 ) -> None:
633 items = await self.store.load_thread_items(
634 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
635 )
636 for tool_call in items.data:
637 if not isinstance(tool_call, ClientToolCallItem):
638 continue
639 if tool_call.status == "pending":
640 logger.warning(
641 f"Client tool call {tool_call.call_id} was not completed, ignoring"
642 )
643 await self.store.delete_thread_item(
644 thread.id, tool_call.id, context=context
645 )
646
647 async def _process_new_thread_item_respond(
648 self,
649 thread: ThreadMetadata,
650 item: UserMessageItem,
651 context: TContext,
652 ) -> AsyncIterator[ThreadStreamEvent]:
653 await self.store.add_thread_item(thread.id, item, context=context)
654 await self._cleanup_pending_client_tool_call(thread, context)
655 yield ThreadItemDoneEvent(item=item)
656
657 async for event in self._process_events(
658 thread,
659 context,
660 lambda: self.respond(thread, item, context),
661 ):
662 yield event
663
664 async def _process_events(
665 self,
666 thread: ThreadMetadata,
667 context: TContext,
668 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
669 ) -> AsyncIterator[ThreadStreamEvent]:
670 await asyncio.sleep(0) # allow the response to start streaming
671
672 # Send initial stream options
673 yield StreamOptionsEvent(
674 stream_options=self.get_stream_options(thread, context)
675 )
676
677 last_thread = thread.model_copy(deep=True)
678
679 # Keep track of items that were streamed but not yet saved
680 # so that we can persist them when the stream is cancelled.
681 pending_items: dict[str, ThreadItem] = {}
682
683 try:
684 with agents_sdk_user_agent_override():
685 async for event in stream():
686 if isinstance(event, ThreadItemAddedEvent):
687 pending_items[event.item.id] = event.item
688
689 match event:
690 case ThreadItemDoneEvent():
691 await self.store.add_thread_item(
692 thread.id, event.item, context=context
693 )
694 pending_items.pop(event.item.id, None)
695 case ThreadItemRemovedEvent():
696 await self.store.delete_thread_item(
697 thread.id, event.item_id, context=context
698 )
699 pending_items.pop(event.item_id, None)
700 case ThreadItemReplacedEvent():
701 await self.store.save_item(
702 thread.id, event.item, context=context
703 )
704 pending_items.pop(event.item.id, None)
705 case ThreadItemUpdatedEvent():
706 # Keep pending assistant message and workflow items up to date
707 # so that we have a reference to the latest version of these pending items
708 # when the stream is cancelled.
709 self._update_pending_items(pending_items, event)
710
711 # special case - don't send hidden context items back to the client
712 should_swallow_event = isinstance(
713 event, ThreadItemDoneEvent
714 ) and isinstance(event.item, HiddenContextItem)
715
716 if not should_swallow_event:
717 yield event
718
719 # in case user updated the thread while streaming
720 if thread != last_thread:
721 last_thread = thread.model_copy(deep=True)
722 await self.store.save_thread(thread, context=context)
723 yield ThreadUpdatedEvent(
724 thread=self._to_thread_response(thread)
725 )
726 # in case user updated the thread while streaming
727 if thread != last_thread:
728 last_thread = thread.model_copy(deep=True)
729 await self.store.save_thread(thread, context=context)
730 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
731 except asyncio.CancelledError:
732 await self.handle_stream_cancelled(
733 thread, list(pending_items.values()), context
734 )
735 raise
736 except CustomStreamError as e:
737 yield ErrorEvent(
738 code="custom",
739 message=e.message,
740 allow_retry=e.allow_retry,
741 )
742 except StreamError as e:
743 yield ErrorEvent(
744 code=e.code,
745 allow_retry=e.allow_retry,
746 )
747 except Exception as e:
748 yield ErrorEvent(
749 code=ErrorCode.STREAM_ERROR,
750 allow_retry=True,
751 )
752 logger.exception(e)
753
754 if thread != last_thread:
755 # in case user updated the thread at the end of the stream
756 await self.store.save_thread(thread, context=context)
757 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
758
759 def _apply_assistant_message_update(
760 self,
761 item: AssistantMessageItem,
762 update: AssistantMessageContentPartAdded
763 | AssistantMessageContentPartTextDelta
764 | AssistantMessageContentPartAnnotationAdded
765 | AssistantMessageContentPartDone,
766 ) -> AssistantMessageItem:
767 updated = item.model_copy(deep=True)
768
769 # Pad the content list so the requested content_index exists before we write into it.
770 # (Streaming updates can arrive for an index that hasn’t been created yet)
771 while len(updated.content) <= update.content_index:
772 updated.content.append(AssistantMessageContent(text="", annotations=[]))
773
774 match update:
775 case AssistantMessageContentPartAdded():
776 updated.content[update.content_index] = update.content
777 case AssistantMessageContentPartTextDelta():
778 updated.content[update.content_index].text += update.delta
779 case AssistantMessageContentPartAnnotationAdded():
780 annotations = updated.content[update.content_index].annotations
781 if update.annotation_index <= len(annotations):
782 annotations.insert(update.annotation_index, update.annotation)
783 else:
784 annotations.append(update.annotation)
785 case AssistantMessageContentPartDone():
786 updated.content[update.content_index] = update.content
787 return updated
788
789 def _update_pending_items(
790 self,
791 pending_items: dict[str, ThreadItem],
792 event: ThreadItemUpdatedEvent,
793 ):
794 updated_item = pending_items.get(event.item_id)
795 update = event.update
796 match updated_item:
797 case AssistantMessageItem():
798 if isinstance(
799 update,
800 (
801 AssistantMessageContentPartAdded,
802 AssistantMessageContentPartTextDelta,
803 AssistantMessageContentPartAnnotationAdded,
804 AssistantMessageContentPartDone,
805 ),
806 ):
807 pending_items[updated_item.id] = (
808 self._apply_assistant_message_update(updated_item, update)
809 )
810 case WorkflowItem():
811 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
812 match update:
813 case WorkflowTaskUpdated():
814 updated_item.workflow.tasks[update.task_index] = update.task
815 case WorkflowTaskAdded():
816 updated_item.workflow.tasks.append(update.task)
817
818 pending_items[updated_item.id] = updated_item
819 case _:
820 pass
821
822 async def _build_user_message_item(
823 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
824 ) -> UserMessageItem:
825 return UserMessageItem(
826 id=self.store.generate_item_id("message", thread, context),
827 content=input.content,
828 thread_id=thread.id,
829 attachments=[
830 await self.store.load_attachment(attachment_id, context)
831 for attachment_id in input.attachments
832 ],
833 quoted_text=input.quoted_text,
834 inference_options=input.inference_options,
835 created_at=datetime.now(),
836 )
837
838 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
839 thread_meta = await self.store.load_thread(thread_id, context=context)
840 thread_items = await self.store.load_thread_items(
841 thread_id,
842 after=None,
843 limit=DEFAULT_PAGE_SIZE,
844 order="asc",
845 context=context,
846 )
847 return Thread(**thread_meta.model_dump(), items=thread_items)
848
849 async def _paginate_thread_items_reverse(
850 self, thread_id: str, context: TContext
851 ) -> AsyncIterator[ThreadItem]:
852 """Paginate through thread items in reverse order (newest first)."""
853 after = None
854 while True:
855 items = await self.store.load_thread_items(
856 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
857 )
858 for item in items.data:
859 yield item
860
861 if not items.has_more:
862 break
863 after = items.after
864
865 def _serialize(self, obj: BaseModel) -> bytes:
866 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
867
868 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
869 def is_hidden(item: ThreadItem) -> bool:
870 return isinstance(item, HiddenContextItem)
871
872 items = thread.items if isinstance(thread, Thread) else Page()
873 items.data = [item for item in items.data if not is_hidden(item)]
874
875 return Thread(
876 id=thread.id,
877 title=thread.title,
878 created_at=thread.created_at,
879 items=items,
880 status=thread.status,
881 )
882