openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
b77d8564687675d4dca66278bb553a05390c709d

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

891 lines · mode: blame

f688d870victor-openai8 months ago1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7Any,
8AsyncGenerator,
9AsyncIterable,
10Callable,
11Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19_HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
3de07454Dan Wang7 months ago22from typing_extensions import TypeVar, assert_never
f688d870victor-openai8 months ago23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29Action,
eee0a2f7Jiwon Kim7 months ago30AssistantMessageContent,
31AssistantMessageContentPartAdded,
32AssistantMessageContentPartAnnotationAdded,
33AssistantMessageContentPartDone,
34AssistantMessageContentPartTextDelta,
35AssistantMessageItem,
f688d870victor-openai8 months ago36AttachmentsCreateReq,
37AttachmentsDeleteReq,
38ChatKitReq,
39ClientToolCallItem,
40ErrorCode,
41ErrorEvent,
42FeedbackKind,
43HiddenContextItem,
44ItemsFeedbackReq,
45ItemsListReq,
46NonStreamingReq,
47Page,
5a7244baJiwon Kim7 months ago48SDKHiddenContextItem,
f688d870victor-openai8 months ago49StreamingReq,
5a7244baJiwon Kim7 months ago50StreamOptions,
51StreamOptionsEvent,
f688d870victor-openai8 months ago52Thread,
53ThreadCreatedEvent,
54ThreadItem,
55ThreadItemAddedEvent,
56ThreadItemDoneEvent,
57ThreadItemRemovedEvent,
58ThreadItemReplacedEvent,
643a02f7Jiwon Kim7 months ago59ThreadItemUpdatedEvent,
f688d870victor-openai8 months ago60ThreadMetadata,
61ThreadsAddClientToolOutputReq,
62ThreadsAddUserMessageReq,
63ThreadsCreateReq,
64ThreadsCustomActionReq,
65ThreadsDeleteReq,
66ThreadsGetByIdReq,
67ThreadsListReq,
68ThreadsRetryAfterItemReq,
69ThreadStreamEvent,
70ThreadsUpdateReq,
71ThreadUpdatedEvent,
72UserMessageInput,
73UserMessageItem,
74WidgetComponentUpdated,
75WidgetItem,
76WidgetRootUpdated,
77WidgetStreamingTextValueDelta,
eee0a2f7Jiwon Kim7 months ago78WorkflowItem,
5a7244baJiwon Kim7 months ago79WorkflowTaskAdded,
80WorkflowTaskUpdated,
f688d870victor-openai8 months ago81is_streaming_req,
82)
83from .version import __version__
91c93ee6Jiwon Kim7 months ago84from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
f688d870victor-openai8 months ago85
# Page size used when a list/load request does not specify its own limit.
DEFAULT_PAGE_SIZE: int = 20
# Generic message text for a failed response generation.
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
88
89
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """Compare two WidgetRoots and return the updates that turn ``before`` into ``after``.

    If the trees differ in anything other than text appended to the ``value`` of an
    identified (non-empty ``id``) ``Markdown``/``Text`` component, a single
    ``WidgetRootUpdated`` carrying ``after`` is returned. Otherwise the result has one
    ``WidgetStreamingTextValueDelta`` per identified streaming-text component whose
    value grew. The result is empty when nothing changed.

    Raises:
        ValueError: if an identified streaming-text component in ``after`` has no
            counterpart in ``before``, or its new value does not start with its
            previous value.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        # Markdown/Text components whose value is a string can be streamed as deltas.
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def is_text_append(old: WidgetComponentBase, new: WidgetComponentBase) -> bool:
        # True when `new.value` only extends `old.value` on a streaming-text component
        # that has an id. Components without an id are never collected by
        # find_all_streaming_text_components, so no delta could be emitted for them.
        # They must fall back to a full replace instead of silently losing the change.
        return (
            is_streaming_text(old)
            and is_streaming_text(new)
            and bool(old.id)
            and getattr(new, "value", "").startswith(getattr(old, "value", ""))
        )

    def full_replace(old: WidgetComponentBase, new: WidgetComponentBase) -> bool:
        # A change of type, id or key can never be expressed as a delta.
        if old.type != new.type or old.id != new.id or old.key != new.key:
            return True

        def full_replace_value(old_value: Any, new_value: Any) -> bool:
            if isinstance(old_value, list) and isinstance(new_value, list):
                if len(old_value) != len(new_value):
                    return True
                return any(
                    full_replace_value(old_item, new_item)
                    for old_item, new_item in zip(old_value, new_value)
                )
            if old_value != new_value:
                # Nested components are compared recursively; any other differing
                # value requires a full replace.
                if isinstance(old_value, WidgetComponentBase) and isinstance(
                    new_value, WidgetComponentBase
                ):
                    return full_replace(old_value, new_value)
                return True
            return False

        for field in old.model_fields_set.union(new.model_fields_set):
            if field == "value" and is_text_append(old, new):
                # Appends to an identified Markdown/Text value are sent as deltas.
                continue
            if full_replace_value(getattr(old, field), getattr(new, field)):
                return True
        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        # Map id -> component for every identified streaming-text node, walking
        # `children` recursively.
        components: dict[str, WidgetComponentBase] = {}

        def recurse(node: WidgetComponent | WidgetRoot) -> None:
            if is_streaming_text(node) and node.id:
                components[node.id] = node
            for child in getattr(node, "children", None) or []:
                recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []
    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))
        if before_value == after_value:
            continue
        if not after_value.startswith(before_value):
            raise ValueError(
                f"Node {component_id} was updated with a new value that does not start with the initial value. All widget updates must be cumulative."
            )
        deltas.append(
            WidgetStreamingTextValueDelta(
                component_id=component_id,
                delta=after_value[len(before_value) :],
                # The delta is marked done once the component no longer reports
                # `streaming=True`.
                done=not getattr(after_node, "streaming", False),
            )
        )

    return deltas
192
193
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget into a thread as thread events.

    A static ``WidgetRoot`` is emitted as a single ``ThreadItemDoneEvent``. For an
    async generator of successive widget states, the output is:

    - a ``ThreadItemAddedEvent`` for the first state;
    - a ``ThreadItemUpdatedEvent`` for every delta ``diff_widget`` computes between
      consecutive states;
    - a final ``ThreadItemDoneEvent`` carrying the last state.

    Args:
        thread: Thread the resulting ``WidgetItem`` belongs to.
        widget: A widget, or an async generator yielding cumulative widget states.
        copy_text: Optional copy text stored on the ``WidgetItem``.
        generate_id: Factory used to create the widget item's id.

    Raises:
        ValueError: if ``widget`` is an async generator that yields no state.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator turns it into an
        # opaque RuntimeError; report the actual problem instead.
        raise ValueError(
            "Widget generator finished without yielding an initial widget state."
        ) from None

    try:
        item = WidgetItem(
            id=item_id,
            created_at=datetime.now(),
            widget=initial_state,
            copy_text=copy_text,
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=item)

        last_state = initial_state
        async for new_state in widget:
            for update in diff_widget(last_state, new_state):
                yield ThreadItemUpdatedEvent(
                    item_id=item_id,
                    update=update,
                )
            last_state = new_state

        yield ThreadItemDoneEvent(
            item=item.model_copy(update={"widget": last_state}),
        )
    finally:
        # If this stream is closed or fails early, finalize the producer now rather
        # than leaving it suspended until garbage collection. This is a no-op once
        # the generator is exhausted.
        await widget.aclose()
243
244
@contextmanager
def agents_sdk_user_agent_override():
    """Override the User-Agent header used by Agents SDK model calls.

    While the context is active, the Agents SDK chat-completions and responses
    header overrides are set to a User-Agent naming the Agents SDK and ChatKit
    versions. On exit, both are reset with the tokens returned by ``set``. This
    happens even when the body raises, including on ``asyncio.CancelledError``.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})
    try:
        yield
    finally:
        # Without `finally`, an exception raised inside the `with` body would skip
        # these resets and leave the override installed afterwards.
        chat_completions_headers_override.reset(chat_completions_token)
        responses_headers_override.reset(responses_token)
255
256
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around the byte chunks of a streaming response."""

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # Underlying chunk generator; consumed lazily by __aiter__.
        self.json_events = stream

    async def __aiter__(self):
        """Yield each chunk of the wrapped generator, unchanged and in order."""
        source = self.json_events
        async for chunk in source:
            yield chunk
264
265
class NonStreamingResult:
    """Container for the complete serialized body of a non-streaming response."""

    def __init__(self, result: bytes):
        # Serialized JSON response bytes.
        self.json = result
269
270
# Type of the per-request context object passed through the server to the store
# and hooks. Defaults to Any when ChatKitServer is used without a type argument.
TContext = TypeVar("TContext", default=Any)
272
273
274class ChatKitServer(ABC, Generic[TContext]):
275def __init__(
276self,
277store: Store[TContext],
278attachment_store: AttachmentStore[TContext] | None = None,
279):
280self.store = store
281self.attachment_store = attachment_store
282
283def _get_attachment_store(self) -> AttachmentStore[TContext]:
284"""Return the configured AttachmentStore or raise if missing."""
285if self.attachment_store is None:
286raise RuntimeError(
287"AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
288)
289return self.attachment_store
290
291@abstractmethod
292def respond(
293self,
294thread: ThreadMetadata,
295input_user_message: UserMessageItem | None,
296context: TContext,
297) -> AsyncIterator[ThreadStreamEvent]:
6988ea0avictor-openai8 months ago298"""Stream `ThreadStreamEvent` instances for a new user message.
299
300Args:
301thread: Metadata for the thread being processed.
302input_user_message: The incoming message the server should respond to, if any.
303context: Arbitrary per-request context provided by the caller.
304
305Returns:
306An async iterator that yields events representing the server's response.
307"""
f688d870victor-openai8 months ago308pass
309
310async def add_feedback( # noqa: B027
311self,
312thread_id: str,
313item_ids: list[str],
314feedback: FeedbackKind,
315context: TContext,
316) -> None:
317pass
318
319def action(
320self,
321thread: ThreadMetadata,
322action: Action[str, Any],
323sender: WidgetItem | None,
324context: TContext,
325) -> AsyncIterator[ThreadStreamEvent]:
326raise NotImplementedError(
327"The action() method must be overridden to react to actions. "
93a34923Vincent Lin8 months ago328"See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
f688d870victor-openai8 months ago329)
330
5a7244baJiwon Kim7 months ago331def get_stream_options(
332self, thread: ThreadMetadata, context: TContext
333) -> StreamOptions:
334"""
335Return stream-level runtime options. Allows the user to cancel the stream by default.
336Override this method to customize behavior.
337"""
338return StreamOptions(allow_cancel=True)
339
eee0a2f7Jiwon Kim7 months ago340async def handle_stream_cancelled(
341self,
342thread: ThreadMetadata,
343pending_items: list[ThreadItem],
344context: TContext,
345):
346"""Perform custom cleanup / stop inference when a stream is cancelled.
5a7244baJiwon Kim7 months ago347Updates you make here will not be reflected in the UI until a reload.
348
349The default implementation persists any non-empty pending assistant messages
350to the thread but does not auto-save pending widget items or workflow items.
eee0a2f7Jiwon Kim7 months ago351
352Args:
353thread: The thread that was being processed.
354pending_items: Items that were not done streaming at cancellation time.
355context: Arbitrary per-request context provided by the caller.
356"""
5a7244baJiwon Kim7 months ago357pending_assistant_message_items: list[AssistantMessageItem] = [
358item for item in pending_items if isinstance(item, AssistantMessageItem)
359]
360for item in pending_assistant_message_items:
361is_empty = len(item.content) == 0 or all(
362(not content.text.strip()) for content in item.content
363)
364if not is_empty:
365await self.store.add_thread_item(thread.id, item, context=context)
366
367# Add a hidden context item to the thread to indicate that the stream was cancelled.
368# Otherwise, depending on the timing of the cancellation, subsequent responses may
369# attempt to continue the cancelled response.
370await self.store.add_thread_item(
371thread.id,
372SDKHiddenContextItem(
373thread_id=thread.id,
374created_at=datetime.now(),
375id=self.store.generate_item_id("sdk_hidden_context", thread, context),
6df9ee60Jiwon Kim7 months ago376content="The user cancelled the stream. Stop responding to the prior request.",
5a7244baJiwon Kim7 months ago377),
378context=context,
379)
eee0a2f7Jiwon Kim7 months ago380
f688d870victor-openai8 months ago381async def process(
382self, request: str | bytes | bytearray, context: TContext
383) -> StreamingResult | NonStreamingResult:
384parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
385logger.info(f"Received request op: {parsed_request.type}")
386
387if is_streaming_req(parsed_request):
388return StreamingResult(self._process_streaming(parsed_request, context))
389else:
390return NonStreamingResult(
391await self._process_non_streaming(parsed_request, context)
392)
393
394async def _process_non_streaming(
395self, request: NonStreamingReq, context: TContext
396) -> bytes:
397match request:
398case ThreadsGetByIdReq():
399thread = await self._load_full_thread(
400request.params.thread_id, context=context
401)
402return self._serialize(self._to_thread_response(thread))
403case ThreadsListReq():
404params = request.params
405threads = await self.store.load_threads(
406limit=params.limit or DEFAULT_PAGE_SIZE,
407after=params.after,
408order=params.order,
409context=context,
410)
411return self._serialize(
412Page(
413has_more=threads.has_more,
414after=threads.after,
415data=[
416self._to_thread_response(thread) for thread in threads.data
417],
418)
419)
420case ItemsFeedbackReq():
421await self.add_feedback(
422request.params.thread_id,
423request.params.item_ids,
424request.params.kind,
425context,
426)
427return b"{}"
428case AttachmentsCreateReq():
429attachment_store = self._get_attachment_store()
430attachment = await attachment_store.create_attachment(
431request.params, context
432)
433return self._serialize(attachment)
434case AttachmentsDeleteReq():
435attachment_store = self._get_attachment_store()
436await attachment_store.delete_attachment(
437request.params.attachment_id, context=context
438)
439await self.store.delete_attachment(
440request.params.attachment_id, context=context
441)
442return b"{}"
443case ItemsListReq():
444items_list_params = request.params
445items = await self.store.load_thread_items(
446items_list_params.thread_id,
447limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
448order=items_list_params.order,
449after=items_list_params.after,
450context=context,
451)
1860286bJiwon Kim7 months ago452# filter out hidden context items
f688d870victor-openai8 months ago453items.data = [
454item
455for item in items.data
1860286bJiwon Kim7 months ago456if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
f688d870victor-openai8 months ago457]
458return self._serialize(items)
459case ThreadsUpdateReq():
460thread = await self.store.load_thread(
461request.params.thread_id, context=context
462)
463thread.title = request.params.title
464await self.store.save_thread(thread, context=context)
465return self._serialize(self._to_thread_response(thread))
466case ThreadsDeleteReq():
467await self.store.delete_thread(
468request.params.thread_id, context=context
469)
470return b"{}"
471case _:
472assert_never(request)
473
474async def _process_streaming(
475self, request: StreamingReq, context: TContext
476) -> AsyncGenerator[bytes, None]:
477try:
478async for event in self._process_streaming_impl(request, context):
479b = self._serialize(event)
480yield b"data: " + b + b"\n\n"
eee0a2f7Jiwon Kim7 months ago481except asyncio.CancelledError:
482# Let cancellation bubble up without logging as an error.
483raise
f688d870victor-openai8 months ago484except Exception:
485logger.exception("Error while generating streamed response")
486raise
487
488async def _process_streaming_impl(
489self, request: StreamingReq, context: TContext
490) -> AsyncGenerator[ThreadStreamEvent, None]:
491match request:
492case ThreadsCreateReq():
493thread = Thread(
494id=self.store.generate_thread_id(context),
495created_at=datetime.now(),
496items=Page(),
497)
8885a10bquettabit8 months ago498await self.store.save_thread(
499ThreadMetadata(**thread.model_dump()),
500context=context,
501)
f688d870victor-openai8 months ago502yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
503user_message = await self._build_user_message_item(
504request.params.input, thread, context
505)
506async for event in self._process_new_thread_item_respond(
507thread,
508user_message,
509context,
510):
511yield event
512
513case ThreadsAddUserMessageReq():
514thread = await self.store.load_thread(
515request.params.thread_id, context=context
516)
517user_message = await self._build_user_message_item(
518request.params.input, thread, context
519)
520async for event in self._process_new_thread_item_respond(
521thread,
522user_message,
523context,
524):
525yield event
526
527case ThreadsAddClientToolOutputReq():
528thread = await self.store.load_thread(
529request.params.thread_id, context=context
530)
531items = await self.store.load_thread_items(
532thread.id, None, 1, "desc", context
533)
534tool_call = next(
535(
536item
537for item in items.data
538if isinstance(item, ClientToolCallItem)
539and item.status == "pending"
540),
541None,
542)
543if not tool_call:
544raise ValueError(
545f"Last thread item in {thread.id} was not a ClientToolCallItem"
546)
547
548tool_call.output = request.params.result
549tool_call.status = "completed"
550
551await self.store.save_item(thread.id, tool_call, context=context)
552
553# Safety against dangling pending tool calls if there are
554# multiple in a row, which should be impossible, and
555# integrations should ultimately filter out pending tool calls
556# when creating input response messages.
557await self._cleanup_pending_client_tool_call(thread, context)
558
559async for event in self._process_events(
560thread,
561context,
562lambda: self.respond(thread, None, context),
563):
564yield event
565
566case ThreadsRetryAfterItemReq():
567thread_metadata = await self.store.load_thread(
568request.params.thread_id, context=context
569)
570
571# Collect items to remove (all items after the user message)
572items_to_remove: list[ThreadItem] = []
573user_message_item = None
574
575async for item in self._paginate_thread_items_reverse(
576request.params.thread_id, context
577):
578if item.id == request.params.item_id:
579if not isinstance(item, UserMessageItem):
580raise ValueError(
581f"Item {request.params.item_id} is not a user message"
582)
583user_message_item = item
584break
585items_to_remove.append(item)
586
587if user_message_item:
588for item in items_to_remove:
589await self.store.delete_thread_item(
590request.params.thread_id, item.id, context=context
591)
592async for event in self._process_events(
593thread_metadata,
594context,
595lambda: self.respond(
596thread_metadata,
597user_message_item,
598context,
599),
600):
601yield event
602case ThreadsCustomActionReq():
603thread_metadata = await self.store.load_thread(
604request.params.thread_id, context=context
605)
606
607item: ThreadItem | None = None
608if request.params.item_id:
609item = await self.store.load_item(
610request.params.thread_id,
611request.params.item_id,
612context=context,
613)
614
615if item and not isinstance(item, WidgetItem):
616# shouldn't happen if the caller is using the API correctly.
617yield ErrorEvent(
618code=ErrorCode.STREAM_ERROR,
619allow_retry=False,
620)
621return
622
623async for event in self._process_events(
624thread_metadata,
625context,
626lambda: self.action(
627thread_metadata,
628request.params.action,
629item,
630context,
631),
632):
633yield event
634
635case _:
636assert_never(request)
637
638async def _cleanup_pending_client_tool_call(
639self, thread: ThreadMetadata, context: TContext
640) -> None:
641items = await self.store.load_thread_items(
642thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
643)
644for tool_call in items.data:
645if not isinstance(tool_call, ClientToolCallItem):
646continue
647if tool_call.status == "pending":
648logger.warning(
649f"Client tool call {tool_call.call_id} was not completed, ignoring"
650)
651await self.store.delete_thread_item(
652thread.id, tool_call.id, context=context
653)
654
655async def _process_new_thread_item_respond(
656self,
657thread: ThreadMetadata,
658item: UserMessageItem,
659context: TContext,
660) -> AsyncIterator[ThreadStreamEvent]:
661await self.store.add_thread_item(thread.id, item, context=context)
662await self._cleanup_pending_client_tool_call(thread, context)
663yield ThreadItemDoneEvent(item=item)
664
665async for event in self._process_events(
666thread,
667context,
668lambda: self.respond(thread, item, context),
669):
670yield event
671
672async def _process_events(
673self,
674thread: ThreadMetadata,
675context: TContext,
676stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
677) -> AsyncIterator[ThreadStreamEvent]:
678await asyncio.sleep(0) # allow the response to start streaming
679
5a7244baJiwon Kim7 months ago680# Send initial stream options
681yield StreamOptionsEvent(
682stream_options=self.get_stream_options(thread, context)
683)
684
f688d870victor-openai8 months ago685last_thread = thread.model_copy(deep=True)
686
eee0a2f7Jiwon Kim7 months ago687# Keep track of items that were streamed but not yet saved
688# so that we can persist them when the stream is cancelled.
689pending_items: dict[str, ThreadItem] = {}
690
f688d870victor-openai8 months ago691try:
692with agents_sdk_user_agent_override():
693async for event in stream():
eee0a2f7Jiwon Kim7 months ago694if isinstance(event, ThreadItemAddedEvent):
695pending_items[event.item.id] = event.item
696
f688d870victor-openai8 months ago697match event:
698case ThreadItemDoneEvent():
699await self.store.add_thread_item(
700thread.id, event.item, context=context
701)
eee0a2f7Jiwon Kim7 months ago702pending_items.pop(event.item.id, None)
f688d870victor-openai8 months ago703case ThreadItemRemovedEvent():
704await self.store.delete_thread_item(
705thread.id, event.item_id, context=context
706)
eee0a2f7Jiwon Kim7 months ago707pending_items.pop(event.item_id, None)
f688d870victor-openai8 months ago708case ThreadItemReplacedEvent():
709await self.store.save_item(
710thread.id, event.item, context=context
711)
eee0a2f7Jiwon Kim7 months ago712pending_items.pop(event.item.id, None)
713case ThreadItemUpdatedEvent():
5a7244baJiwon Kim7 months ago714# Keep pending assistant message and workflow items up to date
715# so that we have a reference to the latest version of these pending items
716# when the stream is cancelled.
717self._update_pending_items(pending_items, event)
f688d870victor-openai8 months ago718
719# special case - don't send hidden context items back to the client
720should_swallow_event = isinstance(
721event, ThreadItemDoneEvent
1860286bJiwon Kim7 months ago722) and isinstance(
723event.item, (HiddenContextItem, SDKHiddenContextItem)
724)
f688d870victor-openai8 months ago725
726if not should_swallow_event:
727yield event
728
729# in case user updated the thread while streaming
730if thread != last_thread:
731last_thread = thread.model_copy(deep=True)
732await self.store.save_thread(thread, context=context)
733yield ThreadUpdatedEvent(
734thread=self._to_thread_response(thread)
735)
736# in case user updated the thread while streaming
737if thread != last_thread:
738last_thread = thread.model_copy(deep=True)
739await self.store.save_thread(thread, context=context)
740yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
eee0a2f7Jiwon Kim7 months ago741except asyncio.CancelledError:
742await self.handle_stream_cancelled(
743thread, list(pending_items.values()), context
744)
745raise
f688d870victor-openai8 months ago746except CustomStreamError as e:
747yield ErrorEvent(
748code="custom",
749message=e.message,
750allow_retry=e.allow_retry,
751)
752except StreamError as e:
753yield ErrorEvent(
754code=e.code,
755allow_retry=e.allow_retry,
756)
757except Exception as e:
758yield ErrorEvent(
759code=ErrorCode.STREAM_ERROR,
760allow_retry=True,
761)
762logger.exception(e)
763
764if thread != last_thread:
765# in case user updated the thread at the end of the stream
766await self.store.save_thread(thread, context=context)
767yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
768
eee0a2f7Jiwon Kim7 months ago769def _apply_assistant_message_update(
770self,
771item: AssistantMessageItem,
772update: AssistantMessageContentPartAdded
773| AssistantMessageContentPartTextDelta
774| AssistantMessageContentPartAnnotationAdded
775| AssistantMessageContentPartDone,
776) -> AssistantMessageItem:
777updated = item.model_copy(deep=True)
778
779# Pad the content list so the requested content_index exists before we write into it.
780# (Streaming updates can arrive for an index that hasn’t been created yet)
781while len(updated.content) <= update.content_index:
782updated.content.append(AssistantMessageContent(text="", annotations=[]))
783
784match update:
785case AssistantMessageContentPartAdded():
786updated.content[update.content_index] = update.content
787case AssistantMessageContentPartTextDelta():
788updated.content[update.content_index].text += update.delta
789case AssistantMessageContentPartAnnotationAdded():
790annotations = updated.content[update.content_index].annotations
791if update.annotation_index <= len(annotations):
792annotations.insert(update.annotation_index, update.annotation)
793else:
794annotations.append(update.annotation)
795case AssistantMessageContentPartDone():
796updated.content[update.content_index] = update.content
797return updated
798
5a7244baJiwon Kim7 months ago799def _update_pending_items(
eee0a2f7Jiwon Kim7 months ago800self,
801pending_items: dict[str, ThreadItem],
802event: ThreadItemUpdatedEvent,
803):
804updated_item = pending_items.get(event.item_id)
5a7244baJiwon Kim7 months ago805update = event.update
806match updated_item:
807case AssistantMessageItem():
808if isinstance(
809update,
810(
811AssistantMessageContentPartAdded,
812AssistantMessageContentPartTextDelta,
813AssistantMessageContentPartAnnotationAdded,
814AssistantMessageContentPartDone,
815),
816):
817pending_items[updated_item.id] = (
818self._apply_assistant_message_update(updated_item, update)
819)
820case WorkflowItem():
821if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
822match update:
823case WorkflowTaskUpdated():
824updated_item.workflow.tasks[update.task_index] = update.task
825case WorkflowTaskAdded():
826updated_item.workflow.tasks.append(update.task)
827
828pending_items[updated_item.id] = updated_item
829case _:
830pass
eee0a2f7Jiwon Kim7 months ago831
f688d870victor-openai8 months ago832async def _build_user_message_item(
833self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
834) -> UserMessageItem:
835return UserMessageItem(
836id=self.store.generate_item_id("message", thread, context),
837content=input.content,
838thread_id=thread.id,
839attachments=[
840await self.store.load_attachment(attachment_id, context)
841for attachment_id in input.attachments
842],
843quoted_text=input.quoted_text,
844inference_options=input.inference_options,
845created_at=datetime.now(),
846)
847
848async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
849thread_meta = await self.store.load_thread(thread_id, context=context)
850thread_items = await self.store.load_thread_items(
851thread_id,
852after=None,
853limit=DEFAULT_PAGE_SIZE,
854order="asc",
855context=context,
856)
857return Thread(**thread_meta.model_dump(), items=thread_items)
858
859async def _paginate_thread_items_reverse(
860self, thread_id: str, context: TContext
861) -> AsyncIterator[ThreadItem]:
862"""Paginate through thread items in reverse order (newest first)."""
863after = None
864while True:
865items = await self.store.load_thread_items(
866thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
867)
868for item in items.data:
869yield item
870
871if not items.has_more:
872break
873after = items.after
874
875def _serialize(self, obj: BaseModel) -> bytes:
876return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
877
878def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
879def is_hidden(item: ThreadItem) -> bool:
1860286bJiwon Kim7 months ago880return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
f688d870victor-openai8 months ago881
882items = thread.items if isinstance(thread, Thread) else Page()
883items.data = [item for item in items.data if not is_hidden(item)]
884
885return Thread(
886id=thread.id,
887title=thread.title,
888created_at=thread.created_at,
889items=items,
890status=thread.status,
891)