openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.6.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

936 lines · mode: code

1import asyncio
2import base64
3from abc import ABC, abstractmethod
4from collections.abc import AsyncIterator
5from contextlib import contextmanager
6from datetime import datetime
7from typing import (
8 Any,
9 AsyncGenerator,
10 AsyncIterable,
11 Callable,
12 Generic,
13)
14
15import agents
16from agents.models.chatcmpl_helpers import (
17 HEADERS_OVERRIDE as chat_completions_headers_override,
18)
19from agents.models.openai_responses import (
20 _HEADERS_OVERRIDE as responses_headers_override,
21)
22from pydantic import BaseModel, TypeAdapter
23from typing_extensions import TypeVar, assert_never
24
25from chatkit.errors import CustomStreamError, StreamError
26
27from .logger import logger
28from .store import AttachmentStore, Store, StoreItemType, default_generate_id
29from .types import (
30 Action,
31 AssistantMessageContent,
32 AssistantMessageContentPartAdded,
33 AssistantMessageContentPartAnnotationAdded,
34 AssistantMessageContentPartDone,
35 AssistantMessageContentPartTextDelta,
36 AssistantMessageItem,
37 AttachmentsCreateReq,
38 AttachmentsDeleteReq,
39 AudioInput,
40 ChatKitReq,
41 ClientToolCallItem,
42 ErrorCode,
43 ErrorEvent,
44 FeedbackKind,
45 HiddenContextItem,
46 InputTranscribeReq,
47 ItemsFeedbackReq,
48 ItemsListReq,
49 NonStreamingReq,
50 Page,
51 SDKHiddenContextItem,
52 StreamingReq,
53 StreamOptions,
54 StreamOptionsEvent,
55 Thread,
56 ThreadCreatedEvent,
57 ThreadItem,
58 ThreadItemAddedEvent,
59 ThreadItemDoneEvent,
60 ThreadItemRemovedEvent,
61 ThreadItemReplacedEvent,
62 ThreadItemUpdatedEvent,
63 ThreadMetadata,
64 ThreadsAddClientToolOutputReq,
65 ThreadsAddUserMessageReq,
66 ThreadsCreateReq,
67 ThreadsCustomActionReq,
68 ThreadsDeleteReq,
69 ThreadsGetByIdReq,
70 ThreadsListReq,
71 ThreadsRetryAfterItemReq,
72 ThreadStreamEvent,
73 ThreadsUpdateReq,
74 ThreadUpdatedEvent,
75 TranscriptionResult,
76 UserMessageInput,
77 UserMessageItem,
78 WidgetComponentUpdated,
79 WidgetItem,
80 WidgetRootUpdated,
81 WidgetStreamingTextValueDelta,
82 WorkflowItem,
83 WorkflowTaskAdded,
84 WorkflowTaskUpdated,
85 is_streaming_req,
86)
87from .version import __version__
88from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
89
# Page size used when a list request omits `limit`, and for the server's own
# thread-item pagination.
DEFAULT_PAGE_SIZE: int = 20
# Generic message text for response-generation failures (not referenced by the
# code visible in this module; presumably used by callers/integrations).
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
92
93
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    If the trees differ in anything other than text appended to the ``value`` of
    ``Markdown``/``Text`` components, a single ``WidgetRootUpdated`` carrying
    ``after`` is returned. Otherwise, one ``WidgetStreamingTextValueDelta`` is
    returned for every streaming-text component (matched by ``id``) whose value
    grew; the delta's ``done`` flag is set when the new node is not ``streaming``.

    Args:
        before: The previously emitted widget tree.
        after: The new widget tree.

    Returns:
        The list of deltas; empty when nothing changed.

    Raises:
        ValueError: If a streaming-text component with an id in ``after`` has no
            counterpart in ``before``, or its new value does not start with the
            previous value.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        # Only Markdown/Text components with a string value can receive text deltas.
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        # A change of identity (type/id/key) always forces a full re-render.
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            # Lists are compared element-wise so nested components are checked
            # recursively instead of by plain equality.
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                else:
                    return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                is_streaming_text(before)
                and is_streaming_text(after)
                and field == "value"
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        # Collect every streaming-text component that has an id, keyed by that id.
        components = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if is_streaming_text(component) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            # Streaming text may only grow: the previous value must be a prefix
            # of the new one so that the difference is a pure append.
            if not after_value.startswith(before_value):
                raise ValueError(
                    f"Node {component_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=component_id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas
196
197
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget root (or async sequence of roots) as ThreadStreamEvents.

    A single ``WidgetRoot`` is emitted as one ``ThreadItemDoneEvent``. For an
    async generator, the first root is emitted in a ``ThreadItemAddedEvent``,
    each subsequent root as the ``ThreadItemUpdatedEvent``s computed by
    ``diff_widget`` against the previous root, and the final root in a
    ``ThreadItemDoneEvent``. A generator that yields no roots produces no events.

    Args:
        thread: The thread the widget item belongs to.
        widget: A widget root, or an async generator of successive roots.
        copy_text: Optional text stored on the widget item as ``copy_text``.
        generate_id: Factory used to create the widget item's id.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Nothing to render. Letting StopAsyncIteration escape from this async
        # generator would surface to the consumer as a RuntimeError.
        return

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state

    # Consume the remaining roots; `async for` stops cleanly on exhaustion.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
248
249
@contextmanager
def agents_sdk_user_agent_override():
    """Set a ChatKit ``User-Agent`` header override for the enclosed scope.

    Sets the header-override context variables imported from the Agents SDK's
    Chat Completions and Responses model modules to
    ``{"User-Agent": "Agents/Python <agents version> ChatKit/Python <version>"}``.
    The previous values are restored on exit, including when the body raises or
    is cancelled.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Without `finally`, an exception or cancellation raised inside the
        # `with` body would skip these resets and leak the override into the
        # rest of the current context. Reset in reverse order of setting.
        responses_headers_override.reset(responses_token)
        chat_completions_headers_override.reset(chat_completions_token)
260
261
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around a stream of encoded response chunks.

    The wrapped generator is exposed as ``json_events``; iterating this object
    re-yields each ``bytes`` chunk that generator produces.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # Stored unconsumed; chunks are only pulled when the result is iterated.
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
269
270
class NonStreamingResult:
    """Holds the complete serialized body of a non-streaming response."""

    def __init__(self, result: bytes):
        # The payload is kept verbatim under the `json` attribute.
        payload = result
        self.json = payload
274
275
# Type of the per-request context object passed through ChatKitServer and its
# Store / AttachmentStore. `default=Any` makes an unparameterized ChatKitServer
# behave as ChatKitServer[Any].
TContext = TypeVar("TContext", default=Any)
277
278
279class ChatKitServer(ABC, Generic[TContext]):
280 def __init__(
281 self,
282 store: Store[TContext],
283 attachment_store: AttachmentStore[TContext] | None = None,
284 ):
285 """Create a ChatKitServer with the backing Store and optional AttachmentStore."""
286 self.store = store
287 self.attachment_store = attachment_store
288
289 def _get_attachment_store(self) -> AttachmentStore[TContext]:
290 """Return the configured AttachmentStore or raise if missing."""
291 if self.attachment_store is None:
292 raise RuntimeError(
293 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
294 )
295 return self.attachment_store
296
297 @abstractmethod
298 def respond(
299 self,
300 thread: ThreadMetadata,
301 input_user_message: UserMessageItem | None,
302 context: TContext,
303 ) -> AsyncIterator[ThreadStreamEvent]:
304 """Stream `ThreadStreamEvent` instances for a new user message.
305
306 Args:
307 thread: Metadata for the thread being processed.
308 input_user_message: The incoming message the server should respond to, if any.
309 context: Arbitrary per-request context provided by the caller.
310
311 Returns:
312 An async iterator that yields events representing the server's response.
313 """
314 pass
315
316 async def add_feedback( # noqa: B027
317 self,
318 thread_id: str,
319 item_ids: list[str],
320 feedback: FeedbackKind,
321 context: TContext,
322 ) -> None:
323 """Persist user feedback for one or more thread items."""
324 pass
325
326 async def transcribe( # noqa: B027
327 self, audio_input: AudioInput, context: TContext
328 ) -> TranscriptionResult:
329 """Transcribe speech audio to text (dictation).
330
331 Audio bytes are in `audio_input.data`. The raw MIME type is `audio_input.mime_type`; use
332 `audio_input.media_type` for the base media type (one of: `audio/webm`, `audio/ogg`, `audio/mp4`).
333
334 Args:
335 audio_input: Audio bytes plus MIME type metadata for transcription.
336 context: Arbitrary per-request context provided by the caller.
337
338 Returns:
339 A `TranscriptionResult` whose `text` is what should appear in the composer.
340 """
341 raise NotImplementedError(
342 "transcribe() must be overridden to support the input.transcribe request."
343 )
344
345 def action(
346 self,
347 thread: ThreadMetadata,
348 action: Action[str, Any],
349 sender: WidgetItem | None,
350 context: TContext,
351 ) -> AsyncIterator[ThreadStreamEvent]:
352 """Handle a widget or client-dispatched action and yield response events."""
353 raise NotImplementedError(
354 "The action() method must be overridden to react to actions. "
355 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
356 )
357
358 def get_stream_options(
359 self, thread: ThreadMetadata, context: TContext
360 ) -> StreamOptions:
361 """
362 Return stream-level runtime options. Allows the user to cancel the stream by default.
363 Override this method to customize behavior.
364 """
365 return StreamOptions(allow_cancel=True)
366
367 async def handle_stream_cancelled(
368 self,
369 thread: ThreadMetadata,
370 pending_items: list[ThreadItem],
371 context: TContext,
372 ):
373 """Perform custom cleanup / stop inference when a stream is cancelled.
374 Updates you make here will not be reflected in the UI until a reload.
375
376 The default implementation persists any non-empty pending assistant messages
377 to the thread but does not auto-save pending widget items or workflow items.
378
379 Args:
380 thread: The thread that was being processed.
381 pending_items: Items that were not done streaming at cancellation time.
382 context: Arbitrary per-request context provided by the caller.
383 """
384 pending_assistant_message_items: list[AssistantMessageItem] = [
385 item for item in pending_items if isinstance(item, AssistantMessageItem)
386 ]
387 for item in pending_assistant_message_items:
388 is_empty = len(item.content) == 0 or all(
389 (not content.text.strip()) for content in item.content
390 )
391 if not is_empty:
392 await self.store.add_thread_item(thread.id, item, context=context)
393
394 # Add a hidden context item to the thread to indicate that the stream was cancelled.
395 # Otherwise, depending on the timing of the cancellation, subsequent responses may
396 # attempt to continue the cancelled response.
397 await self.store.add_thread_item(
398 thread.id,
399 SDKHiddenContextItem(
400 thread_id=thread.id,
401 created_at=datetime.now(),
402 id=self.store.generate_item_id("sdk_hidden_context", thread, context),
403 content="The user cancelled the stream. Stop responding to the prior request.",
404 ),
405 context=context,
406 )
407
408 async def process(
409 self, request: str | bytes | bytearray, context: TContext
410 ) -> StreamingResult | NonStreamingResult:
411 """Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
412 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
413 logger.info(f"Received request op: {parsed_request.type}")
414
415 if is_streaming_req(parsed_request):
416 return StreamingResult(self._process_streaming(parsed_request, context))
417 else:
418 return NonStreamingResult(
419 await self._process_non_streaming(parsed_request, context)
420 )
421
422 async def _process_non_streaming(
423 self, request: NonStreamingReq, context: TContext
424 ) -> bytes:
425 match request:
426 case ThreadsGetByIdReq():
427 thread = await self._load_full_thread(
428 request.params.thread_id, context=context
429 )
430 return self._serialize(self._to_thread_response(thread))
431 case ThreadsListReq():
432 params = request.params
433 threads = await self.store.load_threads(
434 limit=params.limit or DEFAULT_PAGE_SIZE,
435 after=params.after,
436 order=params.order,
437 context=context,
438 )
439 return self._serialize(
440 Page(
441 has_more=threads.has_more,
442 after=threads.after,
443 data=[
444 self._to_thread_response(thread) for thread in threads.data
445 ],
446 )
447 )
448 case ItemsFeedbackReq():
449 await self.add_feedback(
450 request.params.thread_id,
451 request.params.item_ids,
452 request.params.kind,
453 context,
454 )
455 return b"{}"
456 case AttachmentsCreateReq():
457 attachment_store = self._get_attachment_store()
458 attachment = await attachment_store.create_attachment(
459 request.params, context
460 )
461 await self.store.save_attachment(attachment, context=context)
462 return self._serialize(attachment)
463 case AttachmentsDeleteReq():
464 attachment_store = self._get_attachment_store()
465 await attachment_store.delete_attachment(
466 request.params.attachment_id, context=context
467 )
468 await self.store.delete_attachment(
469 request.params.attachment_id, context=context
470 )
471 return b"{}"
472 case InputTranscribeReq():
473 audio_bytes = base64.b64decode(request.params.audio_base64)
474 transcription_result = await self.transcribe(
475 AudioInput(data=audio_bytes, mime_type=request.params.mime_type),
476 context=context,
477 )
478 return self._serialize(transcription_result)
479 case ItemsListReq():
480 items_list_params = request.params
481 items = await self.store.load_thread_items(
482 items_list_params.thread_id,
483 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
484 order=items_list_params.order,
485 after=items_list_params.after,
486 context=context,
487 )
488 # filter out hidden context items
489 items.data = [
490 item
491 for item in items.data
492 if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
493 ]
494 return self._serialize(items)
495 case ThreadsUpdateReq():
496 thread = await self.store.load_thread(
497 request.params.thread_id, context=context
498 )
499 thread.title = request.params.title
500 await self.store.save_thread(thread, context=context)
501 return self._serialize(self._to_thread_response(thread))
502 case ThreadsDeleteReq():
503 await self.store.delete_thread(
504 request.params.thread_id, context=context
505 )
506 return b"{}"
507 case _:
508 assert_never(request)
509
510 async def _process_streaming(
511 self, request: StreamingReq, context: TContext
512 ) -> AsyncGenerator[bytes, None]:
513 try:
514 async for event in self._process_streaming_impl(request, context):
515 b = self._serialize(event)
516 yield b"data: " + b + b"\n\n"
517 except asyncio.CancelledError:
518 # Let cancellation bubble up without logging as an error.
519 raise
520 except Exception:
521 logger.exception("Error while generating streamed response")
522 raise
523
524 async def _process_streaming_impl(
525 self, request: StreamingReq, context: TContext
526 ) -> AsyncGenerator[ThreadStreamEvent, None]:
527 match request:
528 case ThreadsCreateReq():
529 thread = Thread(
530 id=self.store.generate_thread_id(context),
531 created_at=datetime.now(),
532 items=Page(),
533 )
534 await self.store.save_thread(
535 ThreadMetadata(**thread.model_dump()),
536 context=context,
537 )
538 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
539 user_message = await self._build_user_message_item(
540 request.params.input, thread, context
541 )
542 async for event in self._process_new_thread_item_respond(
543 thread,
544 user_message,
545 context,
546 ):
547 yield event
548
549 case ThreadsAddUserMessageReq():
550 thread = await self.store.load_thread(
551 request.params.thread_id, context=context
552 )
553 user_message = await self._build_user_message_item(
554 request.params.input, thread, context
555 )
556 async for event in self._process_new_thread_item_respond(
557 thread,
558 user_message,
559 context,
560 ):
561 yield event
562
563 case ThreadsAddClientToolOutputReq():
564 thread = await self.store.load_thread(
565 request.params.thread_id, context=context
566 )
567 items = await self.store.load_thread_items(
568 thread.id, None, 1, "desc", context
569 )
570 tool_call = next(
571 (
572 item
573 for item in items.data
574 if isinstance(item, ClientToolCallItem)
575 and item.status == "pending"
576 ),
577 None,
578 )
579 if not tool_call:
580 raise ValueError(
581 f"Last thread item in {thread.id} was not a ClientToolCallItem"
582 )
583
584 tool_call.output = request.params.result
585 tool_call.status = "completed"
586
587 await self.store.save_item(thread.id, tool_call, context=context)
588
589 # Safety against dangling pending tool calls if there are
590 # multiple in a row, which should be impossible, and
591 # integrations should ultimately filter out pending tool calls
592 # when creating input response messages.
593 await self._cleanup_pending_client_tool_call(thread, context)
594
595 async for event in self._process_events(
596 thread,
597 context,
598 lambda: self.respond(thread, None, context),
599 ):
600 yield event
601
602 case ThreadsRetryAfterItemReq():
603 thread_metadata = await self.store.load_thread(
604 request.params.thread_id, context=context
605 )
606
607 # Collect items to remove (all items after the user message)
608 items_to_remove: list[ThreadItem] = []
609 user_message_item = None
610
611 async for item in self._paginate_thread_items_reverse(
612 request.params.thread_id, context
613 ):
614 if item.id == request.params.item_id:
615 if not isinstance(item, UserMessageItem):
616 raise ValueError(
617 f"Item {request.params.item_id} is not a user message"
618 )
619 user_message_item = item
620 break
621 items_to_remove.append(item)
622
623 if user_message_item:
624 for item in items_to_remove:
625 await self.store.delete_thread_item(
626 request.params.thread_id, item.id, context=context
627 )
628 async for event in self._process_events(
629 thread_metadata,
630 context,
631 lambda: self.respond(
632 thread_metadata,
633 user_message_item,
634 context,
635 ),
636 ):
637 yield event
638 case ThreadsCustomActionReq():
639 thread_metadata = await self.store.load_thread(
640 request.params.thread_id, context=context
641 )
642
643 item: ThreadItem | None = None
644 if request.params.item_id:
645 item = await self.store.load_item(
646 request.params.thread_id,
647 request.params.item_id,
648 context=context,
649 )
650
651 if item and not isinstance(item, WidgetItem):
652 # shouldn't happen if the caller is using the API correctly.
653 yield ErrorEvent(
654 code=ErrorCode.STREAM_ERROR,
655 allow_retry=False,
656 )
657 return
658
659 async for event in self._process_events(
660 thread_metadata,
661 context,
662 lambda: self.action(
663 thread_metadata,
664 request.params.action,
665 item,
666 context,
667 ),
668 ):
669 yield event
670
671 case _:
672 assert_never(request)
673
674 async def _cleanup_pending_client_tool_call(
675 self, thread: ThreadMetadata, context: TContext
676 ) -> None:
677 items = await self.store.load_thread_items(
678 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
679 )
680 for tool_call in items.data:
681 if not isinstance(tool_call, ClientToolCallItem):
682 continue
683 if tool_call.status == "pending":
684 logger.warning(
685 f"Client tool call {tool_call.call_id} was not completed, ignoring"
686 )
687 await self.store.delete_thread_item(
688 thread.id, tool_call.id, context=context
689 )
690
691 async def _process_new_thread_item_respond(
692 self,
693 thread: ThreadMetadata,
694 item: UserMessageItem,
695 context: TContext,
696 ) -> AsyncIterator[ThreadStreamEvent]:
697 # Store the updated attachments (now that they have a thread id).
698 for attachment in item.attachments:
699 await self.store.save_attachment(attachment, context=context)
700
701 await self.store.add_thread_item(thread.id, item, context=context)
702 yield ThreadItemDoneEvent(item=item)
703
704 async for event in self._process_events(
705 thread,
706 context,
707 lambda: self.respond(thread, item, context),
708 ):
709 yield event
710
711 async def _process_events(
712 self,
713 thread: ThreadMetadata,
714 context: TContext,
715 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
716 ) -> AsyncIterator[ThreadStreamEvent]:
717 await asyncio.sleep(0) # allow the response to start streaming
718
719 # Send initial stream options
720 yield StreamOptionsEvent(
721 stream_options=self.get_stream_options(thread, context)
722 )
723
724 last_thread = thread.model_copy(deep=True)
725
726 # Keep track of items that were streamed but not yet saved
727 # so that we can persist them when the stream is cancelled.
728 pending_items: dict[str, ThreadItem] = {}
729
730 try:
731 with agents_sdk_user_agent_override():
732 async for event in stream():
733 if isinstance(event, ThreadItemAddedEvent):
734 # Stash an isolated copy in case we need to persist unfinished items
735 # on cancellation; downstream handlers keep using the original event.item.
736 pending_items[event.item.id] = event.item.model_copy(deep=True)
737
738 match event:
739 case ThreadItemDoneEvent():
740 await self.store.add_thread_item(
741 thread.id, event.item, context=context
742 )
743 pending_items.pop(event.item.id, None)
744 case ThreadItemRemovedEvent():
745 await self.store.delete_thread_item(
746 thread.id, event.item_id, context=context
747 )
748 pending_items.pop(event.item_id, None)
749 case ThreadItemReplacedEvent():
750 await self.store.save_item(
751 thread.id, event.item, context=context
752 )
753 pending_items.pop(event.item.id, None)
754 case ThreadItemUpdatedEvent():
755 # Keep pending assistant message and workflow items up to date
756 # so that we have a reference to the latest version of these pending items
757 # when the stream is cancelled.
758 self._update_pending_items(pending_items, event)
759
760 # special case - don't send hidden context items back to the client
761 should_swallow_event = isinstance(
762 event, ThreadItemDoneEvent
763 ) and isinstance(
764 event.item, (HiddenContextItem, SDKHiddenContextItem)
765 )
766
767 if not should_swallow_event:
768 yield event
769
770 # in case user updated the thread while streaming
771 if thread != last_thread:
772 last_thread = thread.model_copy(deep=True)
773 await self.store.save_thread(thread, context=context)
774 yield ThreadUpdatedEvent(
775 thread=self._to_thread_response(thread)
776 )
777 # in case user updated the thread while streaming
778 if thread != last_thread:
779 last_thread = thread.model_copy(deep=True)
780 await self.store.save_thread(thread, context=context)
781 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
782 except asyncio.CancelledError:
783 await self.handle_stream_cancelled(
784 thread, list(pending_items.values()), context
785 )
786 raise
787 except CustomStreamError as e:
788 yield ErrorEvent(
789 code="custom",
790 message=e.message,
791 allow_retry=e.allow_retry,
792 )
793 except StreamError as e:
794 yield ErrorEvent(
795 code=e.code,
796 allow_retry=e.allow_retry,
797 )
798 except Exception as e:
799 yield ErrorEvent(
800 code=ErrorCode.STREAM_ERROR,
801 allow_retry=True,
802 )
803 logger.exception(e)
804
805 if thread != last_thread:
806 # in case user updated the thread at the end of the stream
807 await self.store.save_thread(thread, context=context)
808 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
809
810 def _apply_assistant_message_update(
811 self,
812 item: AssistantMessageItem,
813 update: AssistantMessageContentPartAdded
814 | AssistantMessageContentPartTextDelta
815 | AssistantMessageContentPartAnnotationAdded
816 | AssistantMessageContentPartDone,
817 ) -> AssistantMessageItem:
818 # Pad the content list so the requested content_index exists before we write into it.
819 # (Streaming updates can arrive for an index that hasn’t been created yet)
820 while len(item.content) <= update.content_index:
821 item.content.append(AssistantMessageContent(text="", annotations=[]))
822
823 match update:
824 case AssistantMessageContentPartAdded():
825 item.content[update.content_index] = update.content
826 case AssistantMessageContentPartTextDelta():
827 item.content[update.content_index].text += update.delta
828 case AssistantMessageContentPartAnnotationAdded():
829 annotations = item.content[update.content_index].annotations
830 if update.annotation_index <= len(annotations):
831 annotations.insert(update.annotation_index, update.annotation)
832 else:
833 annotations.append(update.annotation)
834 case AssistantMessageContentPartDone():
835 item.content[update.content_index] = update.content
836 return item
837
838 def _update_pending_items(
839 self,
840 pending_items: dict[str, ThreadItem],
841 event: ThreadItemUpdatedEvent,
842 ):
843 updated_item = pending_items.get(event.item_id)
844 update = event.update
845 match updated_item:
846 case AssistantMessageItem():
847 if isinstance(
848 update,
849 (
850 AssistantMessageContentPartAdded,
851 AssistantMessageContentPartTextDelta,
852 AssistantMessageContentPartAnnotationAdded,
853 AssistantMessageContentPartDone,
854 ),
855 ):
856 pending_items[updated_item.id] = (
857 self._apply_assistant_message_update(updated_item, update)
858 )
859 case WorkflowItem():
860 if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
861 match update:
862 case WorkflowTaskUpdated():
863 updated_item.workflow.tasks[update.task_index] = update.task
864 case WorkflowTaskAdded():
865 updated_item.workflow.tasks.append(update.task)
866
867 pending_items[updated_item.id] = updated_item
868 case _:
869 pass
870
871 async def _build_user_message_item(
872 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
873 ) -> UserMessageItem:
874 return UserMessageItem(
875 id=self.store.generate_item_id("message", thread, context),
876 content=input.content,
877 thread_id=thread.id,
878 attachments=[
879 (await self.store.load_attachment(attachment_id, context)).model_copy(
880 update={"thread_id": thread.id}
881 )
882 for attachment_id in input.attachments
883 ],
884 quoted_text=input.quoted_text,
885 inference_options=input.inference_options,
886 created_at=datetime.now(),
887 )
888
889 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
890 thread_meta = await self.store.load_thread(thread_id, context=context)
891 thread_items = await self.store.load_thread_items(
892 thread_id,
893 after=None,
894 limit=DEFAULT_PAGE_SIZE,
895 order="asc",
896 context=context,
897 )
898 return Thread(**thread_meta.model_dump(), items=thread_items)
899
900 async def _paginate_thread_items_reverse(
901 self, thread_id: str, context: TContext
902 ) -> AsyncIterator[ThreadItem]:
903 """Paginate through thread items in reverse order (newest first)."""
904 after = None
905 while True:
906 items = await self.store.load_thread_items(
907 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
908 )
909 for item in items.data:
910 yield item
911
912 if not items.has_more:
913 break
914 after = items.after
915
916 def _serialize(self, obj: BaseModel) -> bytes:
917 return obj.model_dump_json(
918 by_alias=True,
919 exclude_none=True,
920 context={"exclude_metadata": True},
921 ).encode("utf-8")
922
923 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
924 def is_hidden(item: ThreadItem) -> bool:
925 return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
926
927 items = thread.items if isinstance(thread, Thread) else Page()
928 items.data = [item for item in items.data if not is_hidden(item)]
929
930 return Thread(
931 id=thread.id,
932 title=thread.title,
933 created_at=thread.created_at,
934 items=items,
935 status=thread.status,
936 )
937