openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b77d8564687675d4dca66278bb553a05390c709d

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

891 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16 HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19 _HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
22from typing_extensions import TypeVar, assert_never
23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29 Action,
30 AssistantMessageContent,
31 AssistantMessageContentPartAdded,
32 AssistantMessageContentPartAnnotationAdded,
33 AssistantMessageContentPartDone,
34 AssistantMessageContentPartTextDelta,
35 AssistantMessageItem,
36 AttachmentsCreateReq,
37 AttachmentsDeleteReq,
38 ChatKitReq,
39 ClientToolCallItem,
40 ErrorCode,
41 ErrorEvent,
42 FeedbackKind,
43 HiddenContextItem,
44 ItemsFeedbackReq,
45 ItemsListReq,
46 NonStreamingReq,
47 Page,
48 SDKHiddenContextItem,
49 StreamingReq,
50 StreamOptions,
51 StreamOptionsEvent,
52 Thread,
53 ThreadCreatedEvent,
54 ThreadItem,
55 ThreadItemAddedEvent,
56 ThreadItemDoneEvent,
57 ThreadItemRemovedEvent,
58 ThreadItemReplacedEvent,
59 ThreadItemUpdatedEvent,
60 ThreadMetadata,
61 ThreadsAddClientToolOutputReq,
62 ThreadsAddUserMessageReq,
63 ThreadsCreateReq,
64 ThreadsCustomActionReq,
65 ThreadsDeleteReq,
66 ThreadsGetByIdReq,
67 ThreadsListReq,
68 ThreadsRetryAfterItemReq,
69 ThreadStreamEvent,
70 ThreadsUpdateReq,
71 ThreadUpdatedEvent,
72 UserMessageInput,
73 UserMessageItem,
74 WidgetComponentUpdated,
75 WidgetItem,
76 WidgetRootUpdated,
77 WidgetStreamingTextValueDelta,
78 WorkflowItem,
79 WorkflowTaskAdded,
80 WorkflowTaskUpdated,
81 is_streaming_req,
82)
83from .version import __version__
84from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
85
# Page size used when a list request does not supply its own ``limit``,
# and for the fixed-size pages the server loads internally.
DEFAULT_PAGE_SIZE: int = 20
# Generic error text; NOTE(review): not referenced in the code visible here,
# presumably used by callers elsewhere as a fallback message.
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
88
89
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    When the only differences between ``before`` and ``after`` are text appended
    to the ``value`` of ``Markdown``/``Text`` components that carry an ``id``, one
    ``WidgetStreamingTextValueDelta`` is returned per such component whose value
    grew. Any other difference (type/id/key change, different list lengths, any
    other field change, a non-append text edit, or an append to a text component
    without an ``id``) yields a single ``WidgetRootUpdated`` carrying ``after``.

    Raises:
        ValueError: if an identified text component in ``after`` has no
            counterpart in ``before``, or its new value does not extend the
            previous one.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True when ``after`` cannot be expressed as text deltas on ``before``."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                return any(
                    full_replace_value(nth_before, nth_after)
                    for nth_before, nth_after in zip(before_value, after_value)
                )
            if before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                field == "value"
                # Deltas are addressed by component id (see the loop below), so an
                # append to an id-less text node can only be delivered by replacing
                # the root; otherwise the update would be silently dropped.
                and before.id
                and is_streaming_text(before)
                and is_streaming_text(after)
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        """Collect every identified Markdown/Text node in the tree, keyed by id."""
        components = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if is_streaming_text(component) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            if not after_value.startswith(before_value):
                raise ValueError(
                    f"Node {component_id} was updated with a new value that is not a prefix of the initial value. All widget updates must be cumulative."
                )
            # The component reports whether it is still streaming; once it stops,
            # the delta is marked as the final one.
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=component_id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas
192
193
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget into ``thread`` as thread events.

    A static ``WidgetRoot`` is emitted as a single ``ThreadItemDoneEvent``.
    For an async generator, the first yielded state is announced with a
    ``ThreadItemAddedEvent``. Each later state is diffed against the previous
    one with ``diff_widget`` and sent as ``ThreadItemUpdatedEvent``s. A final
    ``ThreadItemDoneEvent`` carries the last state.

    Args:
        thread: Thread the widget item belongs to.
        widget: A finished widget, or an async generator of successive states.
        copy_text: Optional text stored on the item as ``copy_text``.
        generate_id: Factory used to create the item id (called with "message").

    Raises:
        ValueError: if ``widget`` is a generator that yields no state at all.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator turns it into an
        # opaque RuntimeError, so report the actual problem instead.
        raise ValueError(
            "Widget generator finished without yielding an initial state."
        ) from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state
    # Continue consuming the same generator after the initial state.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
243
244
@contextmanager
def agents_sdk_user_agent_override():
    """Temporarily set the Agents SDK ``User-Agent`` header override.

    Sets both the Chat Completions and Responses header-override ContextVars to
    a User-Agent naming the Agents SDK and ChatKit versions. Both are restored
    on exit, including when the wrapped code raises (e.g. a cancelled stream).
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Without finally, an exception thrown into the generator at `yield`
        # would skip the resets and leave the overrides set.
        chat_completions_headers_override.reset(chat_completions_token)
        responses_headers_override.reset(responses_token)
255
256
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around a stream of already-serialized byte chunks.

    The wrapped generator is exposed as ``json_events``. Iterating this object
    re-yields every chunk that generator produces, in order.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
264
265
class NonStreamingResult:
    """Holds the complete serialized response body of a non-streaming request."""

    def __init__(self, result: bytes):
        # The raw JSON bytes, exposed unchanged to the caller.
        self.json = result
270
# Type of the per-request context object passed through the server and Store APIs.
# Defaults to Any so ChatKitServer / Store can be used without a type argument.
TContext = TypeVar("TContext", default=Any)
272
273
274class ChatKitServer(ABC, Generic[TContext]):
275 def __init__(
276 self,
277 store: Store[TContext],
278 attachment_store: AttachmentStore[TContext] | None = None,
279 ):
280 self.store = store
281 self.attachment_store = attachment_store
282
283 def _get_attachment_store(self) -> AttachmentStore[TContext]:
284 """Return the configured AttachmentStore or raise if missing."""
285 if self.attachment_store is None:
286 raise RuntimeError(
287 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
288 )
289 return self.attachment_store
290
291 @abstractmethod
292 def respond(
293 self,
294 thread: ThreadMetadata,
295 input_user_message: UserMessageItem | None,
296 context: TContext,
297 ) -> AsyncIterator[ThreadStreamEvent]:
298 """Stream `ThreadStreamEvent` instances for a new user message.
299
300 Args:
301 thread: Metadata for the thread being processed.
302 input_user_message: The incoming message the server should respond to, if any.
303 context: Arbitrary per-request context provided by the caller.
304
305 Returns:
306 An async iterator that yields events representing the server's response.
307 """
308 pass
309
310 async def add_feedback( # noqa: B027
311 self,
312 thread_id: str,
313 item_ids: list[str],
314 feedback: FeedbackKind,
315 context: TContext,
316 ) -> None:
317 pass
318
319 def action(
320 self,
321 thread: ThreadMetadata,
322 action: Action[str, Any],
323 sender: WidgetItem | None,
324 context: TContext,
325 ) -> AsyncIterator[ThreadStreamEvent]:
326 raise NotImplementedError(
327 "The action() method must be overridden to react to actions. "
328 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
329 )
330
331 def get_stream_options(
332 self, thread: ThreadMetadata, context: TContext
333 ) -> StreamOptions:
334 """
335 Return stream-level runtime options. Allows the user to cancel the stream by default.
336 Override this method to customize behavior.
337 """
338 return StreamOptions(allow_cancel=True)
339
340 async def handle_stream_cancelled(
341 self,
342 thread: ThreadMetadata,
343 pending_items: list[ThreadItem],
344 context: TContext,
345 ):
346 """Perform custom cleanup / stop inference when a stream is cancelled.
347 Updates you make here will not be reflected in the UI until a reload.
348
349 The default implementation persists any non-empty pending assistant messages
350 to the thread but does not auto-save pending widget items or workflow items.
351
352 Args:
353 thread: The thread that was being processed.
354 pending_items: Items that were not done streaming at cancellation time.
355 context: Arbitrary per-request context provided by the caller.
356 """
357 pending_assistant_message_items: list[AssistantMessageItem] = [
358 item for item in pending_items if isinstance(item, AssistantMessageItem)
359 ]
360 for item in pending_assistant_message_items:
361 is_empty = len(item.content) == 0 or all(
362 (not content.text.strip()) for content in item.content
363 )
364 if not is_empty:
365 await self.store.add_thread_item(thread.id, item, context=context)
366
367 # Add a hidden context item to the thread to indicate that the stream was cancelled.
368 # Otherwise, depending on the timing of the cancellation, subsequent responses may
369 # attempt to continue the cancelled response.
370 await self.store.add_thread_item(
371 thread.id,
372 SDKHiddenContextItem(
373 thread_id=thread.id,
374 created_at=datetime.now(),
375 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
376 content="The user cancelled the stream. Stop responding to the prior request.",
377 ),
378 context=context,
379 )
380
381 async def process(
382 self, request: str | bytes | bytearray, context: TContext
383 ) -> StreamingResult | NonStreamingResult:
384 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
385 logger.info(f"Received request op: {parsed_request.type}")
386
387 if is_streaming_req(parsed_request):
388 return StreamingResult(self._process_streaming(parsed_request, context))
389 else:
390 return NonStreamingResult(
391 await self._process_non_streaming(parsed_request, context)
392 )
393
394 async def _process_non_streaming(
395 self, request: NonStreamingReq, context: TContext
396 ) -> bytes:
397 match request:
398 case ThreadsGetByIdReq():
399 thread = await self._load_full_thread(
400 request.params.thread_id, context=context
401 )
402 return self._serialize(self._to_thread_response(thread))
403 case ThreadsListReq():
404 params = request.params
405 threads = await self.store.load_threads(
406 limit=params.limit or DEFAULT_PAGE_SIZE,
407 after=params.after,
408 order=params.order,
409 context=context,
410 )
411 return self._serialize(
412 Page(
413 has_more=threads.has_more,
414 after=threads.after,
415 data=[
416 self._to_thread_response(thread) for thread in threads.data
417 ],
418 )
419 )
420 case ItemsFeedbackReq():
421 await self.add_feedback(
422 request.params.thread_id,
423 request.params.item_ids,
424 request.params.kind,
425 context,
426 )
427 return b"{}"
428 case AttachmentsCreateReq():
429 attachment_store = self._get_attachment_store()
430 attachment = await attachment_store.create_attachment(
431 request.params, context
432 )
433 return self._serialize(attachment)
434 case AttachmentsDeleteReq():
435 attachment_store = self._get_attachment_store()
436 await attachment_store.delete_attachment(
437 request.params.attachment_id, context=context
438 )
439 await self.store.delete_attachment(
440 request.params.attachment_id, context=context
441 )
442 return b"{}"
443 case ItemsListReq():
444 items_list_params = request.params
445 items = await self.store.load_thread_items(
446 items_list_params.thread_id,
447 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
448 order=items_list_params.order,
449 after=items_list_params.after,
450 context=context,
451 )
452 # filter out hidden context items
453 items.data = [
454 item
455 for item in items.data
456 if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
457 ]
458 return self._serialize(items)
459 case ThreadsUpdateReq():
460 thread = await self.store.load_thread(
461 request.params.thread_id, context=context
462 )
463 thread.title = request.params.title
464 await self.store.save_thread(thread, context=context)
465 return self._serialize(self._to_thread_response(thread))
466 case ThreadsDeleteReq():
467 await self.store.delete_thread(
468 request.params.thread_id, context=context
469 )
470 return b"{}"
471 case _:
472 assert_never(request)
473
474 async def _process_streaming(
475 self, request: StreamingReq, context: TContext
476 ) -> AsyncGenerator[bytes, None]:
477 try:
478 async for event in self._process_streaming_impl(request, context):
479 b = self._serialize(event)
480 yield b"data: " + b + b"\n\n"
481 except asyncio.CancelledError:
482 # Let cancellation bubble up without logging as an error.
483 raise
484 except Exception:
485 logger.exception("Error while generating streamed response")
486 raise
487
488 async def _process_streaming_impl(
489 self, request: StreamingReq, context: TContext
490 ) -> AsyncGenerator[ThreadStreamEvent, None]:
491 match request:
492 case ThreadsCreateReq():
493 thread = Thread(
494 id=self.store.generate_thread_id(context),
495 created_at=datetime.now(),
496 items=Page(),
497 )
498 await self.store.save_thread(
499 ThreadMetadata(**thread.model_dump()),
500 context=context,
501 )
502 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
503 user_message = await self._build_user_message_item(
504 request.params.input, thread, context
505 )
506 async for event in self._process_new_thread_item_respond(
507 thread,
508 user_message,
509 context,
510 ):
511 yield event
512
513 case ThreadsAddUserMessageReq():
514 thread = await self.store.load_thread(
515 request.params.thread_id, context=context
516 )
517 user_message = await self._build_user_message_item(
518 request.params.input, thread, context
519 )
520 async for event in self._process_new_thread_item_respond(
521 thread,
522 user_message,
523 context,
524 ):
525 yield event
526
527 case ThreadsAddClientToolOutputReq():
528 thread = await self.store.load_thread(
529 request.params.thread_id, context=context
530 )
531 items = await self.store.load_thread_items(
532 thread.id, None, 1, "desc", context
533 )
534 tool_call = next(
535 (
536 item
537 for item in items.data
538 if isinstance(item, ClientToolCallItem)
539 and item.status == "pending"
540 ),
541 None,
542 )
543 if not tool_call:
544 raise ValueError(
545 f"Last thread item in {thread.id} was not a ClientToolCallItem"
546 )
547
548 tool_call.output = request.params.result
549 tool_call.status = "completed"
550
551 await self.store.save_item(thread.id, tool_call, context=context)
552
553 # Safety against dangling pending tool calls if there are
554 # multiple in a row, which should be impossible, and
555 # integrations should ultimately filter out pending tool calls
556 # when creating input response messages.
557 await self._cleanup_pending_client_tool_call(thread, context)
558
559 async for event in self._process_events(
560 thread,
561 context,
562 lambda: self.respond(thread, None, context),
563 ):
564 yield event
565
566 case ThreadsRetryAfterItemReq():
567 thread_metadata = await self.store.load_thread(
568 request.params.thread_id, context=context
569 )
570
571 # Collect items to remove (all items after the user message)
572 items_to_remove: list[ThreadItem] = []
573 user_message_item = None
574
575 async for item in self._paginate_thread_items_reverse(
576 request.params.thread_id, context
577 ):
578 if item.id == request.params.item_id:
579 if not isinstance(item, UserMessageItem):
580 raise ValueError(
581 f"Item {request.params.item_id} is not a user message"
582 )
583 user_message_item = item
584 break
585 items_to_remove.append(item)
586
587 if user_message_item:
588 for item in items_to_remove:
589 await self.store.delete_thread_item(
590 request.params.thread_id, item.id, context=context
591 )
592 async for event in self._process_events(
593 thread_metadata,
594 context,
595 lambda: self.respond(
596 thread_metadata,
597 user_message_item,
598 context,
599 ),
600 ):
601 yield event
602 case ThreadsCustomActionReq():
603 thread_metadata = await self.store.load_thread(
604 request.params.thread_id, context=context
605 )
606
607 item: ThreadItem | None = None
608 if request.params.item_id:
609 item = await self.store.load_item(
610 request.params.thread_id,
611 request.params.item_id,
612 context=context,
613 )
614
615 if item and not isinstance(item, WidgetItem):
616 # shouldn't happen if the caller is using the API correctly.
617 yield ErrorEvent(
618 code=ErrorCode.STREAM_ERROR,
619 allow_retry=False,
620 )
621 return
622
623 async for event in self._process_events(
624 thread_metadata,
625 context,
626 lambda: self.action(
627 thread_metadata,
628 request.params.action,
629 item,
630 context,
631 ),
632 ):
633 yield event
634
635 case _:
636 assert_never(request)
637
638 async def _cleanup_pending_client_tool_call(
639 self, thread: ThreadMetadata, context: TContext
640 ) -> None:
641 items = await self.store.load_thread_items(
642 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
643 )
644 for tool_call in items.data:
645 if not isinstance(tool_call, ClientToolCallItem):
646 continue
647 if tool_call.status == "pending":
648 logger.warning(
649 f"Client tool call {tool_call.call_id} was not completed, ignoring"
650 )
651 await self.store.delete_thread_item(
652 thread.id, tool_call.id, context=context
653 )
654
655 async def _process_new_thread_item_respond(
656 self,
657 thread: ThreadMetadata,
658 item: UserMessageItem,
659 context: TContext,
660 ) -> AsyncIterator[ThreadStreamEvent]:
661 await self.store.add_thread_item(thread.id, item, context=context)
662 await self._cleanup_pending_client_tool_call(thread, context)
663 yield ThreadItemDoneEvent(item=item)
664
665 async for event in self._process_events(
666 thread,
667 context,
668 lambda: self.respond(thread, item, context),
669 ):
670 yield event
671
672 async def _process_events(
673 self,
674 thread: ThreadMetadata,
675 context: TContext,
676 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
677 ) -> AsyncIterator[ThreadStreamEvent]:
678 await asyncio.sleep(0) # allow the response to start streaming
679
680 # Send initial stream options
681 yield StreamOptionsEvent(
682 stream_options=self.get_stream_options(thread, context)
683 )
684
685 last_thread = thread.model_copy(deep=True)
686
687 # Keep track of items that were streamed but not yet saved
688 # so that we can persist them when the stream is cancelled.
689 pending_items: dict[str, ThreadItem] = {}
690
691 try:
692 with agents_sdk_user_agent_override():
693 async for event in stream():
694 if isinstance(event, ThreadItemAddedEvent):
695 pending_items[event.item.id] = event.item
696
697 match event:
698 case ThreadItemDoneEvent():
699 await self.store.add_thread_item(
700 thread.id, event.item, context=context
701 )
702 pending_items.pop(event.item.id, None)
703 case ThreadItemRemovedEvent():
704 await self.store.delete_thread_item(
705 thread.id, event.item_id, context=context
706 )
707 pending_items.pop(event.item_id, None)
708 case ThreadItemReplacedEvent():
709 await self.store.save_item(
710 thread.id, event.item, context=context
711 )
712 pending_items.pop(event.item.id, None)
713 case ThreadItemUpdatedEvent():
714 # Keep pending assistant message and workflow items up to date
715 # so that we have a reference to the latest version of these pending items
716 # when the stream is cancelled.
717 self._update_pending_items(pending_items, event)
718
719 # special case - don't send hidden context items back to the client
720 should_swallow_event = isinstance(
721 event, ThreadItemDoneEvent
722 ) and isinstance(
723 event.item, (HiddenContextItem, SDKHiddenContextItem)
724 )
725
726 if not should_swallow_event:
727 yield event
728
729 # in case user updated the thread while streaming
730 if thread != last_thread:
731 last_thread = thread.model_copy(deep=True)
732 await self.store.save_thread(thread, context=context)
733 yield ThreadUpdatedEvent(
734 thread=self._to_thread_response(thread)
735 )
736 # in case user updated the thread while streaming
737 if thread != last_thread:
738 last_thread = thread.model_copy(deep=True)
739 await self.store.save_thread(thread, context=context)
740 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
741 except asyncio.CancelledError:
742 await self.handle_stream_cancelled(
743 thread, list(pending_items.values()), context
744 )
745 raise
746 except CustomStreamError as e:
747 yield ErrorEvent(
748 code="custom",
749 message=e.message,
750 allow_retry=e.allow_retry,
751 )
752 except StreamError as e:
753 yield ErrorEvent(
754 code=e.code,
755 allow_retry=e.allow_retry,
756 )
757 except Exception as e:
758 yield ErrorEvent(
759 code=ErrorCode.STREAM_ERROR,
760 allow_retry=True,
761 )
762 logger.exception(e)
763
764 if thread != last_thread:
765 # in case user updated the thread at the end of the stream
766 await self.store.save_thread(thread, context=context)
767 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
768
769 def _apply_assistant_message_update(
770 self,
771 item: AssistantMessageItem,
772 update: AssistantMessageContentPartAdded
773 | AssistantMessageContentPartTextDelta
774 | AssistantMessageContentPartAnnotationAdded
775 | AssistantMessageContentPartDone,
776 ) -> AssistantMessageItem:
777 updated = item.model_copy(deep=True)
778
779 # Pad the content list so the requested content_index exists before we write into it.
780 # (Streaming updates can arrive for an index that hasn’t been created yet)
781 while len(updated.content) <= update.content_index:
782 updated.content.append(AssistantMessageContent(text="", annotations=[]))
783
784 match update:
785 case AssistantMessageContentPartAdded():
786 updated.content[update.content_index] = update.content
787 case AssistantMessageContentPartTextDelta():
788 updated.content[update.content_index].text += update.delta
789 case AssistantMessageContentPartAnnotationAdded():
790 annotations = updated.content[update.content_index].annotations
791 if update.annotation_index <= len(annotations):
792 annotations.insert(update.annotation_index, update.annotation)
793 else:
794 annotations.append(update.annotation)
795 case AssistantMessageContentPartDone():
796 updated.content[update.content_index] = update.content
797 return updated
798
799 def _update_pending_items(
800 self,
801 pending_items: dict[str, ThreadItem],
802 event: ThreadItemUpdatedEvent,
803 ):
804 updated_item = pending_items.get(event.item_id)
805 update = event.update
806 match updated_item:
807 case AssistantMessageItem():
808 if isinstance(
809 update,
810 (
811 AssistantMessageContentPartAdded,
812 AssistantMessageContentPartTextDelta,
813 AssistantMessageContentPartAnnotationAdded,
814 AssistantMessageContentPartDone,
815 ),
816 ):
817 pending_items[updated_item.id] = (
818 self._apply_assistant_message_update(updated_item, update)
819 )
820 case WorkflowItem():
821 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
822 match update:
823 case WorkflowTaskUpdated():
824 updated_item.workflow.tasks[update.task_index] = update.task
825 case WorkflowTaskAdded():
826 updated_item.workflow.tasks.append(update.task)
827
828 pending_items[updated_item.id] = updated_item
829 case _:
830 pass
831
832 async def _build_user_message_item(
833 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
834 ) -> UserMessageItem:
835 return UserMessageItem(
836 id=self.store.generate_item_id("message", thread, context),
837 content=input.content,
838 thread_id=thread.id,
839 attachments=[
840 await self.store.load_attachment(attachment_id, context)
841 for attachment_id in input.attachments
842 ],
843 quoted_text=input.quoted_text,
844 inference_options=input.inference_options,
845 created_at=datetime.now(),
846 )
847
848 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
849 thread_meta = await self.store.load_thread(thread_id, context=context)
850 thread_items = await self.store.load_thread_items(
851 thread_id,
852 after=None,
853 limit=DEFAULT_PAGE_SIZE,
854 order="asc",
855 context=context,
856 )
857 return Thread(**thread_meta.model_dump(), items=thread_items)
858
859 async def _paginate_thread_items_reverse(
860 self, thread_id: str, context: TContext
861 ) -> AsyncIterator[ThreadItem]:
862 """Paginate through thread items in reverse order (newest first)."""
863 after = None
864 while True:
865 items = await self.store.load_thread_items(
866 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
867 )
868 for item in items.data:
869 yield item
870
871 if not items.has_more:
872 break
873 after = items.after
874
875 def _serialize(self, obj: BaseModel) -> bytes:
876 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
877
878 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
879 def is_hidden(item: ThreadItem) -> bool:
880 return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
881
882 items = thread.items if isinstance(thread, Thread) else Page()
883 items.data = [item for item in items.data if not is_hidden(item)]
884
885 return Thread(
886 id=thread.id,
887 title=thread.title,
888 created_at=thread.created_at,
889 items=items,
890 status=thread.status,
891 )
892