openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.5.0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

896lines · modeblame

f688d870victor-openai8 months ago1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7Any,
8AsyncGenerator,
9AsyncIterable,
10Callable,
11Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19_HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
3de07454Dan Wang8 months ago22from typing_extensions import TypeVar, assert_never
f688d870victor-openai8 months ago23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29Action,
eee0a2f7Jiwon Kim7 months ago30AssistantMessageContent,
31AssistantMessageContentPartAdded,
32AssistantMessageContentPartAnnotationAdded,
33AssistantMessageContentPartDone,
34AssistantMessageContentPartTextDelta,
35AssistantMessageItem,
f688d870victor-openai8 months ago36AttachmentsCreateReq,
37AttachmentsDeleteReq,
38ChatKitReq,
39ClientToolCallItem,
40ErrorCode,
41ErrorEvent,
42FeedbackKind,
43HiddenContextItem,
44ItemsFeedbackReq,
45ItemsListReq,
46NonStreamingReq,
47Page,
5a7244baJiwon Kim7 months ago48SDKHiddenContextItem,
f688d870victor-openai8 months ago49StreamingReq,
5a7244baJiwon Kim7 months ago50StreamOptions,
51StreamOptionsEvent,
f688d870victor-openai8 months ago52Thread,
53ThreadCreatedEvent,
54ThreadItem,
55ThreadItemAddedEvent,
56ThreadItemDoneEvent,
57ThreadItemRemovedEvent,
58ThreadItemReplacedEvent,
643a02f7Jiwon Kim7 months ago59ThreadItemUpdatedEvent,
f688d870victor-openai8 months ago60ThreadMetadata,
61ThreadsAddClientToolOutputReq,
62ThreadsAddUserMessageReq,
63ThreadsCreateReq,
64ThreadsCustomActionReq,
65ThreadsDeleteReq,
66ThreadsGetByIdReq,
67ThreadsListReq,
68ThreadsRetryAfterItemReq,
69ThreadStreamEvent,
70ThreadsUpdateReq,
71ThreadUpdatedEvent,
72UserMessageInput,
73UserMessageItem,
74WidgetComponentUpdated,
75WidgetItem,
76WidgetRootUpdated,
77WidgetStreamingTextValueDelta,
eee0a2f7Jiwon Kim7 months ago78WorkflowItem,
5a7244baJiwon Kim7 months ago79WorkflowTaskAdded,
80WorkflowTaskUpdated,
f688d870victor-openai8 months ago81is_streaming_req,
82)
83from .version import __version__
91c93ee6Jiwon Kim7 months ago84from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
f688d870victor-openai8 months ago85
# Page size used when a list request omits `limit`, and for the server's own
# internal item scans (tool-call cleanup, full-thread loads, reverse pagination).
DEFAULT_PAGE_SIZE = 20
# Generic error text; not referenced within this module (presumably exported for
# callers that need a fallback message — confirm before relying on it).
DEFAULT_ERROR_MESSAGE = "An error occurred when generating a response."
88
89
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    If ``after`` differs from ``before`` only by appends to the ``value`` of
    streaming text components (``Markdown``/``Text`` nodes with a string
    ``value`` and a non-empty ``id``), one ``WidgetStreamingTextValueDelta`` is
    returned per component whose value grew. Any other difference produces a
    single ``WidgetRootUpdated`` carrying the whole ``after`` root. Identical
    roots produce an empty list.

    Raises:
        ValueError: if a streaming text component with an id in ``after`` has no
            counterpart in ``before``, or if its new value does not start with
            its previous value.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True if ``after`` cannot be reached from ``before`` via text deltas."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                else:
                    return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                field == "value"
                # Only components with an id can be addressed by a streaming text
                # delta (see find_all_streaming_text_components below). Without an
                # id, skipping the field here would drop the change entirely, so
                # fall through to the regular comparison / full replace.
                and after.id
                and is_streaming_text(before)
                and is_streaming_text(after)
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        """Map id -> component for every streaming text node that has an id."""
        components = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if is_streaming_text(component) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            if not after_value.startswith(before_value):
                raise ValueError(
                    f"Node {component_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            # A component is done once it is no longer flagged as streaming.
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=component_id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas
192
193
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget root (or async sequence of roots) as ThreadStreamEvents.

    A single ``WidgetRoot`` is emitted as one ``ThreadItemDoneEvent``. For an
    async generator, the first root is emitted in a ``ThreadItemAddedEvent``.
    Each later root is diffed against the previous one with ``diff_widget`` and
    emitted as ``ThreadItemUpdatedEvent``s. A final ``ThreadItemDoneEvent``
    carries the last root.

    Args:
        thread: Thread the widget item belongs to.
        widget: A static root, or an async generator of successive roots.
        copy_text: Optional text stored on the WidgetItem as ``copy_text``.
        generate_id: Id factory, called once with ``"message"``.

    Raises:
        ValueError: if ``widget`` is an async generator that yields no roots.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator surfaces as an
        # opaque RuntimeError (PEP 525); fail with an actionable message instead.
        raise ValueError(
            "stream_widget received a widget generator that yielded no WidgetRoot"
        ) from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state

    # Continues from where __anext__ left off; ends on generator exhaustion.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
244
245
@contextmanager
def agents_sdk_user_agent_override():
    """Override the User-Agent header used by the Agents SDK model clients.

    Sets both the chat-completions and responses header-override context
    variables to ``Agents/Python <agents version> ChatKit/Python <version>`` for
    the duration of the ``with`` block. It restores the previous values on exit,
    including when the block raises or is cancelled.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Without finally, an exception raised inside the ``with`` body skips the
        # resets and the override stays set for the rest of the current context.
        for context_var, token in (
            (chat_completions_headers_override, chat_completions_token),
            (responses_headers_override, responses_token),
        ):
            try:
                context_var.reset(token)
            except ValueError:
                # ContextVar.reset raises ValueError when the token was created in
                # a different Context, e.g. when an abandoned async generator is
                # finalized from another task. Nothing to restore in that context.
                pass
256
257
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around an already-serialized byte stream.

    Iterating the wrapper re-yields, in order, every chunk produced by the
    wrapped async generator.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # The generator is stored as-is and only consumed when iterated.
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
265
266
class NonStreamingResult:
    """Container for the complete serialized body of a non-streaming response."""

    def __init__(self, result: bytes):
        # Serialized JSON payload, exposed unchanged as ``json``.
        payload = result
        self.json = payload
270
271
# Type of the arbitrary per-request context object passed through Store,
# AttachmentStore and ChatKitServer calls; defaults to Any when unparameterized.
TContext = TypeVar("TContext", default=Any)
273
274
275class ChatKitServer(ABC, Generic[TContext]):
276def __init__(
277self,
278store: Store[TContext],
279attachment_store: AttachmentStore[TContext] | None = None,
280):
3b20c3d1Jiwon Kim6 months ago281"""Create a ChatKitServer with the backing Store and optional AttachmentStore."""
f688d870victor-openai8 months ago282self.store = store
283self.attachment_store = attachment_store
284
285def _get_attachment_store(self) -> AttachmentStore[TContext]:
286"""Return the configured AttachmentStore or raise if missing."""
287if self.attachment_store is None:
288raise RuntimeError(
289"AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
290)
291return self.attachment_store
292
293@abstractmethod
294def respond(
295self,
296thread: ThreadMetadata,
297input_user_message: UserMessageItem | None,
298context: TContext,
299) -> AsyncIterator[ThreadStreamEvent]:
6988ea0avictor-openai8 months ago300"""Stream `ThreadStreamEvent` instances for a new user message.
301
302Args:
303thread: Metadata for the thread being processed.
304input_user_message: The incoming message the server should respond to, if any.
305context: Arbitrary per-request context provided by the caller.
306
307Returns:
308An async iterator that yields events representing the server's response.
309"""
f688d870victor-openai8 months ago310pass
311
312async def add_feedback( # noqa: B027
313self,
314thread_id: str,
315item_ids: list[str],
316feedback: FeedbackKind,
317context: TContext,
318) -> None:
3b20c3d1Jiwon Kim6 months ago319"""Persist user feedback for one or more thread items."""
f688d870victor-openai8 months ago320pass
321
322def action(
323self,
324thread: ThreadMetadata,
325action: Action[str, Any],
326sender: WidgetItem | None,
327context: TContext,
328) -> AsyncIterator[ThreadStreamEvent]:
3b20c3d1Jiwon Kim6 months ago329"""Handle a widget or client-dispatched action and yield response events."""
f688d870victor-openai8 months ago330raise NotImplementedError(
331"The action() method must be overridden to react to actions. "
93a34923Vincent Lin8 months ago332"See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
f688d870victor-openai8 months ago333)
334
5a7244baJiwon Kim7 months ago335def get_stream_options(
336self, thread: ThreadMetadata, context: TContext
337) -> StreamOptions:
338"""
339Return stream-level runtime options. Allows the user to cancel the stream by default.
340Override this method to customize behavior.
341"""
342return StreamOptions(allow_cancel=True)
343
eee0a2f7Jiwon Kim7 months ago344async def handle_stream_cancelled(
345self,
346thread: ThreadMetadata,
347pending_items: list[ThreadItem],
348context: TContext,
349):
350"""Perform custom cleanup / stop inference when a stream is cancelled.
5a7244baJiwon Kim7 months ago351Updates you make here will not be reflected in the UI until a reload.
352
353The default implementation persists any non-empty pending assistant messages
354to the thread but does not auto-save pending widget items or workflow items.
eee0a2f7Jiwon Kim7 months ago355
356Args:
357thread: The thread that was being processed.
358pending_items: Items that were not done streaming at cancellation time.
359context: Arbitrary per-request context provided by the caller.
360"""
5a7244baJiwon Kim7 months ago361pending_assistant_message_items: list[AssistantMessageItem] = [
362item for item in pending_items if isinstance(item, AssistantMessageItem)
363]
364for item in pending_assistant_message_items:
365is_empty = len(item.content) == 0 or all(
366(not content.text.strip()) for content in item.content
367)
368if not is_empty:
369await self.store.add_thread_item(thread.id, item, context=context)
370
371# Add a hidden context item to the thread to indicate that the stream was cancelled.
372# Otherwise, depending on the timing of the cancellation, subsequent responses may
373# attempt to continue the cancelled response.
374await self.store.add_thread_item(
375thread.id,
376SDKHiddenContextItem(
377thread_id=thread.id,
378created_at=datetime.now(),
379id=self.store.generate_item_id("sdk_hidden_context", thread, context),
6df9ee60Jiwon Kim7 months ago380content="The user cancelled the stream. Stop responding to the prior request.",
5a7244baJiwon Kim7 months ago381),
382context=context,
383)
eee0a2f7Jiwon Kim7 months ago384
f688d870victor-openai8 months ago385async def process(
386self, request: str | bytes | bytearray, context: TContext
387) -> StreamingResult | NonStreamingResult:
3b20c3d1Jiwon Kim6 months ago388"""Parse an incoming ChatKit request and route it to streaming or non-streaming handlers."""
f688d870victor-openai8 months ago389parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
390logger.info(f"Received request op: {parsed_request.type}")
391
392if is_streaming_req(parsed_request):
393return StreamingResult(self._process_streaming(parsed_request, context))
394else:
395return NonStreamingResult(
396await self._process_non_streaming(parsed_request, context)
397)
398
399async def _process_non_streaming(
400self, request: NonStreamingReq, context: TContext
401) -> bytes:
402match request:
403case ThreadsGetByIdReq():
404thread = await self._load_full_thread(
405request.params.thread_id, context=context
406)
407return self._serialize(self._to_thread_response(thread))
408case ThreadsListReq():
409params = request.params
410threads = await self.store.load_threads(
411limit=params.limit or DEFAULT_PAGE_SIZE,
412after=params.after,
413order=params.order,
414context=context,
415)
416return self._serialize(
417Page(
418has_more=threads.has_more,
419after=threads.after,
420data=[
421self._to_thread_response(thread) for thread in threads.data
422],
423)
424)
425case ItemsFeedbackReq():
426await self.add_feedback(
427request.params.thread_id,
428request.params.item_ids,
429request.params.kind,
430context,
431)
432return b"{}"
433case AttachmentsCreateReq():
434attachment_store = self._get_attachment_store()
435attachment = await attachment_store.create_attachment(
436request.params, context
437)
7c90a733Jiwon Kim6 months ago438await self.store.save_attachment(attachment, context=context)
f688d870victor-openai8 months ago439return self._serialize(attachment)
440case AttachmentsDeleteReq():
441attachment_store = self._get_attachment_store()
442await attachment_store.delete_attachment(
443request.params.attachment_id, context=context
444)
445await self.store.delete_attachment(
446request.params.attachment_id, context=context
447)
448return b"{}"
449case ItemsListReq():
450items_list_params = request.params
451items = await self.store.load_thread_items(
452items_list_params.thread_id,
453limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
454order=items_list_params.order,
455after=items_list_params.after,
456context=context,
457)
1860286bJiwon Kim7 months ago458# filter out hidden context items
f688d870victor-openai8 months ago459items.data = [
460item
461for item in items.data
1860286bJiwon Kim7 months ago462if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
f688d870victor-openai8 months ago463]
464return self._serialize(items)
465case ThreadsUpdateReq():
466thread = await self.store.load_thread(
467request.params.thread_id, context=context
468)
469thread.title = request.params.title
470await self.store.save_thread(thread, context=context)
471return self._serialize(self._to_thread_response(thread))
472case ThreadsDeleteReq():
473await self.store.delete_thread(
474request.params.thread_id, context=context
475)
476return b"{}"
477case _:
478assert_never(request)
479
480async def _process_streaming(
481self, request: StreamingReq, context: TContext
482) -> AsyncGenerator[bytes, None]:
483try:
484async for event in self._process_streaming_impl(request, context):
485b = self._serialize(event)
486yield b"data: " + b + b"\n\n"
eee0a2f7Jiwon Kim7 months ago487except asyncio.CancelledError:
488# Let cancellation bubble up without logging as an error.
489raise
f688d870victor-openai8 months ago490except Exception:
491logger.exception("Error while generating streamed response")
492raise
493
494async def _process_streaming_impl(
495self, request: StreamingReq, context: TContext
496) -> AsyncGenerator[ThreadStreamEvent, None]:
497match request:
498case ThreadsCreateReq():
499thread = Thread(
500id=self.store.generate_thread_id(context),
501created_at=datetime.now(),
502items=Page(),
503)
8885a10bquettabit8 months ago504await self.store.save_thread(
505ThreadMetadata(**thread.model_dump()),
506context=context,
507)
f688d870victor-openai8 months ago508yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
509user_message = await self._build_user_message_item(
510request.params.input, thread, context
511)
512async for event in self._process_new_thread_item_respond(
513thread,
514user_message,
515context,
516):
517yield event
518
519case ThreadsAddUserMessageReq():
520thread = await self.store.load_thread(
521request.params.thread_id, context=context
522)
523user_message = await self._build_user_message_item(
524request.params.input, thread, context
525)
526async for event in self._process_new_thread_item_respond(
527thread,
528user_message,
529context,
530):
531yield event
532
533case ThreadsAddClientToolOutputReq():
534thread = await self.store.load_thread(
535request.params.thread_id, context=context
536)
537items = await self.store.load_thread_items(
538thread.id, None, 1, "desc", context
539)
540tool_call = next(
541(
542item
543for item in items.data
544if isinstance(item, ClientToolCallItem)
545and item.status == "pending"
546),
547None,
548)
549if not tool_call:
550raise ValueError(
551f"Last thread item in {thread.id} was not a ClientToolCallItem"
552)
553
554tool_call.output = request.params.result
555tool_call.status = "completed"
556
557await self.store.save_item(thread.id, tool_call, context=context)
558
559# Safety against dangling pending tool calls if there are
560# multiple in a row, which should be impossible, and
561# integrations should ultimately filter out pending tool calls
562# when creating input response messages.
563await self._cleanup_pending_client_tool_call(thread, context)
564
565async for event in self._process_events(
566thread,
567context,
568lambda: self.respond(thread, None, context),
569):
570yield event
571
572case ThreadsRetryAfterItemReq():
573thread_metadata = await self.store.load_thread(
574request.params.thread_id, context=context
575)
576
577# Collect items to remove (all items after the user message)
578items_to_remove: list[ThreadItem] = []
579user_message_item = None
580
581async for item in self._paginate_thread_items_reverse(
582request.params.thread_id, context
583):
584if item.id == request.params.item_id:
585if not isinstance(item, UserMessageItem):
586raise ValueError(
587f"Item {request.params.item_id} is not a user message"
588)
589user_message_item = item
590break
591items_to_remove.append(item)
592
593if user_message_item:
594for item in items_to_remove:
595await self.store.delete_thread_item(
596request.params.thread_id, item.id, context=context
597)
598async for event in self._process_events(
599thread_metadata,
600context,
601lambda: self.respond(
602thread_metadata,
603user_message_item,
604context,
605),
606):
607yield event
608case ThreadsCustomActionReq():
609thread_metadata = await self.store.load_thread(
610request.params.thread_id, context=context
611)
612
613item: ThreadItem | None = None
614if request.params.item_id:
615item = await self.store.load_item(
616request.params.thread_id,
617request.params.item_id,
618context=context,
619)
620
621if item and not isinstance(item, WidgetItem):
622# shouldn't happen if the caller is using the API correctly.
623yield ErrorEvent(
624code=ErrorCode.STREAM_ERROR,
625allow_retry=False,
626)
627return
628
629async for event in self._process_events(
630thread_metadata,
631context,
632lambda: self.action(
633thread_metadata,
634request.params.action,
635item,
636context,
637),
638):
639yield event
640
641case _:
642assert_never(request)
643
644async def _cleanup_pending_client_tool_call(
645self, thread: ThreadMetadata, context: TContext
646) -> None:
647items = await self.store.load_thread_items(
648thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
649)
650for tool_call in items.data:
651if not isinstance(tool_call, ClientToolCallItem):
652continue
653if tool_call.status == "pending":
654logger.warning(
655f"Client tool call {tool_call.call_id} was not completed, ignoring"
656)
657await self.store.delete_thread_item(
658thread.id, tool_call.id, context=context
659)
660
661async def _process_new_thread_item_respond(
662self,
663thread: ThreadMetadata,
664item: UserMessageItem,
665context: TContext,
666) -> AsyncIterator[ThreadStreamEvent]:
667await self.store.add_thread_item(thread.id, item, context=context)
668yield ThreadItemDoneEvent(item=item)
669
670async for event in self._process_events(
671thread,
672context,
673lambda: self.respond(thread, item, context),
674):
675yield event
676
677async def _process_events(
678self,
679thread: ThreadMetadata,
680context: TContext,
681stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
682) -> AsyncIterator[ThreadStreamEvent]:
683await asyncio.sleep(0) # allow the response to start streaming
684
5a7244baJiwon Kim7 months ago685# Send initial stream options
686yield StreamOptionsEvent(
687stream_options=self.get_stream_options(thread, context)
688)
689
f688d870victor-openai8 months ago690last_thread = thread.model_copy(deep=True)
691
eee0a2f7Jiwon Kim7 months ago692# Keep track of items that were streamed but not yet saved
693# so that we can persist them when the stream is cancelled.
694pending_items: dict[str, ThreadItem] = {}
695
f688d870victor-openai8 months ago696try:
697with agents_sdk_user_agent_override():
698async for event in stream():
eee0a2f7Jiwon Kim7 months ago699if isinstance(event, ThreadItemAddedEvent):
700pending_items[event.item.id] = event.item
701
f688d870victor-openai8 months ago702match event:
703case ThreadItemDoneEvent():
704await self.store.add_thread_item(
705thread.id, event.item, context=context
706)
eee0a2f7Jiwon Kim7 months ago707pending_items.pop(event.item.id, None)
f688d870victor-openai8 months ago708case ThreadItemRemovedEvent():
709await self.store.delete_thread_item(
710thread.id, event.item_id, context=context
711)
eee0a2f7Jiwon Kim7 months ago712pending_items.pop(event.item_id, None)
f688d870victor-openai8 months ago713case ThreadItemReplacedEvent():
714await self.store.save_item(
715thread.id, event.item, context=context
716)
eee0a2f7Jiwon Kim7 months ago717pending_items.pop(event.item.id, None)
718case ThreadItemUpdatedEvent():
5a7244baJiwon Kim7 months ago719# Keep pending assistant message and workflow items up to date
720# so that we have a reference to the latest version of these pending items
721# when the stream is cancelled.
722self._update_pending_items(pending_items, event)
f688d870victor-openai8 months ago723
724# special case - don't send hidden context items back to the client
725should_swallow_event = isinstance(
726event, ThreadItemDoneEvent
1860286bJiwon Kim7 months ago727) and isinstance(
728event.item, (HiddenContextItem, SDKHiddenContextItem)
729)
f688d870victor-openai8 months ago730
731if not should_swallow_event:
732yield event
733
734# in case user updated the thread while streaming
735if thread != last_thread:
736last_thread = thread.model_copy(deep=True)
737await self.store.save_thread(thread, context=context)
738yield ThreadUpdatedEvent(
739thread=self._to_thread_response(thread)
740)
741# in case user updated the thread while streaming
742if thread != last_thread:
743last_thread = thread.model_copy(deep=True)
744await self.store.save_thread(thread, context=context)
745yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
eee0a2f7Jiwon Kim7 months ago746except asyncio.CancelledError:
747await self.handle_stream_cancelled(
748thread, list(pending_items.values()), context
749)
750raise
f688d870victor-openai8 months ago751except CustomStreamError as e:
752yield ErrorEvent(
753code="custom",
754message=e.message,
755allow_retry=e.allow_retry,
756)
757except StreamError as e:
758yield ErrorEvent(
759code=e.code,
760allow_retry=e.allow_retry,
761)
762except Exception as e:
763yield ErrorEvent(
764code=ErrorCode.STREAM_ERROR,
765allow_retry=True,
766)
767logger.exception(e)
768
769if thread != last_thread:
770# in case user updated the thread at the end of the stream
771await self.store.save_thread(thread, context=context)
772yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
773
eee0a2f7Jiwon Kim7 months ago774def _apply_assistant_message_update(
775self,
776item: AssistantMessageItem,
777update: AssistantMessageContentPartAdded
778| AssistantMessageContentPartTextDelta
779| AssistantMessageContentPartAnnotationAdded
780| AssistantMessageContentPartDone,
781) -> AssistantMessageItem:
782updated = item.model_copy(deep=True)
783
784# Pad the content list so the requested content_index exists before we write into it.
785# (Streaming updates can arrive for an index that hasn’t been created yet)
786while len(updated.content) <= update.content_index:
787updated.content.append(AssistantMessageContent(text="", annotations=[]))
788
789match update:
790case AssistantMessageContentPartAdded():
791updated.content[update.content_index] = update.content
792case AssistantMessageContentPartTextDelta():
793updated.content[update.content_index].text += update.delta
794case AssistantMessageContentPartAnnotationAdded():
795annotations = updated.content[update.content_index].annotations
796if update.annotation_index <= len(annotations):
797annotations.insert(update.annotation_index, update.annotation)
798else:
799annotations.append(update.annotation)
800case AssistantMessageContentPartDone():
801updated.content[update.content_index] = update.content
802return updated
803
5a7244baJiwon Kim7 months ago804def _update_pending_items(
eee0a2f7Jiwon Kim7 months ago805self,
806pending_items: dict[str, ThreadItem],
807event: ThreadItemUpdatedEvent,
808):
809updated_item = pending_items.get(event.item_id)
5a7244baJiwon Kim7 months ago810update = event.update
811match updated_item:
812case AssistantMessageItem():
813if isinstance(
814update,
815(
816AssistantMessageContentPartAdded,
817AssistantMessageContentPartTextDelta,
818AssistantMessageContentPartAnnotationAdded,
819AssistantMessageContentPartDone,
820),
821):
822pending_items[updated_item.id] = (
823self._apply_assistant_message_update(updated_item, update)
824)
825case WorkflowItem():
826if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
827match update:
828case WorkflowTaskUpdated():
829updated_item.workflow.tasks[update.task_index] = update.task
830case WorkflowTaskAdded():
831updated_item.workflow.tasks.append(update.task)
832
833pending_items[updated_item.id] = updated_item
834case _:
835pass
eee0a2f7Jiwon Kim7 months ago836
f688d870victor-openai8 months ago837async def _build_user_message_item(
838self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
839) -> UserMessageItem:
840return UserMessageItem(
841id=self.store.generate_item_id("message", thread, context),
842content=input.content,
843thread_id=thread.id,
844attachments=[
845await self.store.load_attachment(attachment_id, context)
846for attachment_id in input.attachments
847],
848quoted_text=input.quoted_text,
849inference_options=input.inference_options,
850created_at=datetime.now(),
851)
852
853async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
854thread_meta = await self.store.load_thread(thread_id, context=context)
855thread_items = await self.store.load_thread_items(
856thread_id,
857after=None,
858limit=DEFAULT_PAGE_SIZE,
859order="asc",
860context=context,
861)
862return Thread(**thread_meta.model_dump(), items=thread_items)
863
864async def _paginate_thread_items_reverse(
865self, thread_id: str, context: TContext
866) -> AsyncIterator[ThreadItem]:
867"""Paginate through thread items in reverse order (newest first)."""
868after = None
869while True:
870items = await self.store.load_thread_items(
871thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
872)
873for item in items.data:
874yield item
875
876if not items.has_more:
877break
878after = items.after
879
880def _serialize(self, obj: BaseModel) -> bytes:
881return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
882
883def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
884def is_hidden(item: ThreadItem) -> bool:
1860286bJiwon Kim7 months ago885return isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
f688d870victor-openai8 months ago886
887items = thread.items if isinstance(thread, Thread) else Page()
888items.data = [item for item in items.data if not is_hidden(item)]
889
890return Thread(
891id=thread.id,
892title=thread.title,
893created_at=thread.created_at,
894items=items,
895status=thread.status,
896)