openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
6646f83af681e14fcdcf69f2fa31dbdb0a9efa2d

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

896lines · modeblame

f688d870victor-openai8 months ago1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7Any,
8AsyncGenerator,
9AsyncIterable,
10Callable,
11Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19_HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
3de07454Dan Wang8 months ago22from typing_extensions import TypeVar, assert_never
f688d870victor-openai8 months ago23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29Action,
eee0a2f7Jiwon Kim7 months ago30AssistantMessageContent,
31AssistantMessageContentPartAdded,
32AssistantMessageContentPartAnnotationAdded,
33AssistantMessageContentPartDone,
34AssistantMessageContentPartTextDelta,
35AssistantMessageItem,
f688d870victor-openai8 months ago36AttachmentsCreateReq,
37AttachmentsDeleteReq,
38ChatKitReq,
39ClientToolCallItem,
40ErrorCode,
41ErrorEvent,
42FeedbackKind,
43HiddenContextItem,
44ItemsFeedbackReq,
45ItemsListReq,
46NonStreamingReq,
47Page,
5a7244baJiwon Kim7 months ago48SDKHiddenContextItem,
f688d870victor-openai8 months ago49StreamingReq,
5a7244baJiwon Kim7 months ago50StreamOptions,
51StreamOptionsEvent,
f688d870victor-openai8 months ago52Thread,
53ThreadCreatedEvent,
54ThreadItem,
55ThreadItemAddedEvent,
56ThreadItemDoneEvent,
57ThreadItemRemovedEvent,
58ThreadItemReplacedEvent,
643a02f7Jiwon Kim7 months ago59ThreadItemUpdatedEvent,
f688d870victor-openai8 months ago60ThreadMetadata,
61ThreadsAddClientToolOutputReq,
62ThreadsAddUserMessageReq,
63ThreadsCreateReq,
64ThreadsCustomActionReq,
65ThreadsDeleteReq,
66ThreadsGetByIdReq,
67ThreadsListReq,
68ThreadsRetryAfterItemReq,
69ThreadStreamEvent,
70ThreadsUpdateReq,
71ThreadUpdatedEvent,
72UserMessageInput,
73UserMessageItem,
74WidgetComponentUpdated,
75WidgetItem,
76WidgetRootUpdated,
77WidgetStreamingTextValueDelta,
eee0a2f7Jiwon Kim7 months ago78WorkflowItem,
5a7244baJiwon Kim7 months ago79WorkflowTaskAdded,
80WorkflowTaskUpdated,
f688d870victor-openai8 months ago81is_streaming_req,
82)
83from .version import __version__
91c93ee6Jiwon Kim7 months ago84from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
f688d870victor-openai8 months ago85
# Page size used when a list request does not carry its own `limit`, and when the
# server itself pages through thread items.
DEFAULT_PAGE_SIZE = 20
# Generic user-facing error text.
# NOTE(review): not referenced in this module; presumably consumed elsewhere — confirm.
DEFAULT_ERROR_MESSAGE = "An error occurred when generating a response."
88
89
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """Compare two WidgetRoots and return the deltas that turn `before` into `after`.

    Args:
        before: The previously rendered widget tree.
        after: The new widget tree.

    Returns:
        ``[WidgetRootUpdated(widget=after)]`` when anything other than an append to
        the ``value`` of a Markdown/Text component changed. Otherwise one
        ``WidgetStreamingTextValueDelta`` per Markdown/Text component (with an id)
        whose ``value`` grew; the list is empty when nothing changed.

    Raises:
        ValueError: If a streaming text node with an id exists in `after` but not
            in `before`, or if a node's new value does not start with its
            previous value.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        # Only Markdown/Text components holding a string value can be streamed.
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(old: WidgetComponentBase, new: WidgetComponentBase) -> bool:
        # A change of identity (type / id / key) can never be expressed as a delta.
        if old.type != new.type or old.id != new.id or old.key != new.key:
            return True

        def full_replace_value(old_value: Any, new_value: Any) -> bool:
            if isinstance(old_value, list) and isinstance(new_value, list):
                if len(old_value) != len(new_value):
                    return True
                # Same length: compare element-wise so nested components are
                # checked with the component-aware rules below.
                for nth_old, nth_new in zip(old_value, new_value):
                    if full_replace_value(nth_old, nth_new):
                        return True
            elif old_value != new_value:
                if isinstance(old_value, WidgetComponentBase) and isinstance(
                    new_value, WidgetComponentBase
                ):
                    return full_replace(old_value, new_value)
                return True
            return False

        for field in old.model_fields_set.union(new.model_fields_set):
            if (
                is_streaming_text(old)
                and is_streaming_text(new)
                and field == "value"
                and getattr(new, "value", "").startswith(getattr(old, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(old, field), getattr(new, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        # Map id -> component for every streamable text node reachable via `children`.
        components: dict[str, WidgetComponentBase] = {}

        def recurse(node: WidgetComponent | WidgetRoot):
            if is_streaming_text(node) and node.id:
                components[node.id] = node

            for child in getattr(node, "children", None) or []:
                recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for node_id, after_node in after_nodes.items():
        before_node = before_nodes.get(node_id)
        if before_node is None:
            raise ValueError(
                f"Node {node_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))

        if before_value != after_value:
            if not after_value.startswith(before_value):
                # The previous value must be a prefix of the new one; the message
                # now states the relation in the direction the check enforces.
                raise ValueError(
                    f"Node {node_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            # A node no longer flagged as `streaming` has received its final text.
            done = not getattr(after_node, "streaming", False)
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=node_id,
                    delta=after_value[len(before_value) :],
                    done=done,
                )
            )

    return deltas
192
193
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Emit the ThreadStreamEvents for a widget, either static or streamed.

    A plain ``WidgetRoot`` produces a single ``ThreadItemDoneEvent``. An async
    generator of roots produces a ``ThreadItemAddedEvent`` for its first root,
    ``ThreadItemUpdatedEvent``s for the differences between consecutive roots,
    and a closing ``ThreadItemDoneEvent`` carrying the last root.

    Args:
        thread: Thread the widget item belongs to.
        widget: The widget root, or an async generator yielding successive roots.
        copy_text: Optional ``copy_text`` stored on the widget item.
        generate_id: Factory used to create the item id (called with ``"message"``).
    """
    widget_item_id = generate_id("message")

    if isinstance(widget, AsyncGenerator):
        # The first root becomes the initial rendering of the item.
        first_root = await widget.__anext__()
        streamed_item = WidgetItem(
            id=widget_item_id,
            created_at=datetime.now(),
            widget=first_root,
            copy_text=copy_text,
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=streamed_item)

        current_root = first_root
        async for next_root in widget:
            for delta in diff_widget(current_root, next_root):
                yield ThreadItemUpdatedEvent(
                    item_id=widget_item_id,
                    update=delta,
                )
            current_root = next_root

        yield ThreadItemDoneEvent(
            item=streamed_item.model_copy(update={"widget": current_root}),
        )
        return

    # Static widget: nothing to stream, the item is complete immediately.
    yield ThreadItemDoneEvent(
        item=WidgetItem(
            id=widget_item_id,
            thread_id=thread.id,
            created_at=datetime.now(),
            widget=widget,
            copy_text=copy_text,
        ),
    )
244
245
@contextmanager
def agents_sdk_user_agent_override():
    """Temporarily set the Agents SDK User-Agent header overrides.

    Both the chat-completions and the responses override are set to a
    User-Agent string naming the Agents SDK and ChatKit versions for the
    duration of the ``with`` block.

    The overrides are restored in a ``finally`` clause so they are reset even
    when the body raises or is cancelled; previously an exception inside the
    ``with`` block skipped both resets and left the override in place.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        chat_completions_headers_override.reset(chat_completions_token)
        responses_headers_override.reset(responses_token)
256
257
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around the byte chunks of a streaming response."""

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # The wrapped generator is exposed as `json_events`.
        self.json_events = stream

    async def __aiter__(self):
        """Re-yield every chunk produced by the wrapped generator, in order."""
        source = self.json_events
        async for chunk in source:
            yield chunk
265
266
class NonStreamingResult:
    """Holder for the complete serialized body of a non-streaming response."""

    def __init__(self, result: bytes):
        # The response bytes are exposed as `json`.
        self.json: bytes = result
270
271
# Type of the caller-supplied per-request context object that is passed through
# the server and store APIs; falls back to `Any` when left unparameterized.
TContext = TypeVar("TContext", default=Any)
273
274
275class ChatKitServer(ABC, Generic[TContext]):
def __init__(
    self,
    store: Store[TContext],
    attachment_store: AttachmentStore[TContext] | None = None,
):
    """Set up the server with its backing stores.

    Args:
        store: Store used for threads and thread items.
        attachment_store: Optional store for attachment operations; attachment
            requests fail with ``RuntimeError`` when it is not provided.
    """
    self.attachment_store: AttachmentStore[TContext] | None = attachment_store
    self.store: Store[TContext] = store
284
def _get_attachment_store(self) -> AttachmentStore[TContext]:
    """Return the AttachmentStore given at construction time.

    Raises:
        RuntimeError: If the server was created without an AttachmentStore.
    """
    configured = self.attachment_store
    if configured is not None:
        return configured
    raise RuntimeError(
        "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
    )
292
@abstractmethod
def respond(
    self,
    thread: ThreadMetadata,
    input_user_message: UserMessageItem | None,
    context: TContext,
) -> AsyncIterator[ThreadStreamEvent]:
    """Produce the server's response to a thread as a stream of events.

    Subclasses must implement this method.

    Args:
        thread: Metadata for the thread being processed.
        input_user_message: The user message to respond to, or ``None`` when
            there is no new user message (for example when the response is
            resumed after a client tool output).
        context: Arbitrary per-request context provided by the caller.

    Returns:
        An async iterator yielding the `ThreadStreamEvent` instances that make
        up the response.
    """
    ...
311
async def add_feedback(  # noqa: B027
    self,
    thread_id: str,
    item_ids: list[str],
    feedback: FeedbackKind,
    context: TContext,
) -> None:
    """Hook for storing user feedback on one or more thread items.

    The default implementation does nothing; override it to persist feedback.

    Args:
        thread_id: Id of the thread the items belong to.
        item_ids: Ids of the items the feedback applies to.
        feedback: The kind of feedback given.
        context: Arbitrary per-request context provided by the caller.
    """
    return None
321
def action(
    self,
    thread: ThreadMetadata,
    action: Action[str, Any],
    sender: WidgetItem | None,
    context: TContext,
) -> AsyncIterator[ThreadStreamEvent]:
    """React to a custom action and stream the resulting events.

    Args:
        thread: Metadata for the thread the action was dispatched on.
        action: The action to handle.
        sender: The widget item the action originated from, or ``None`` when the
            request did not reference an item.
        context: Arbitrary per-request context provided by the caller.

    Raises:
        NotImplementedError: Always, unless a subclass overrides this method.
    """
    message = (
        "The action() method must be overridden to react to actions. "
        "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
    )
    raise NotImplementedError(message)
334
def get_stream_options(
    self, thread: ThreadMetadata, context: TContext
) -> StreamOptions:
    """Return the runtime options sent at the start of every response stream.

    The default lets the user cancel the stream. Override this method to
    customize the options per thread or per request.
    """
    default_options = StreamOptions(allow_cancel=True)
    return default_options
343
async def handle_stream_cancelled(
    self,
    thread: ThreadMetadata,
    pending_items: list[ThreadItem],
    context: TContext,
):
    """Clean up after a response stream was cancelled.

    Override to stop inference or perform custom cleanup. Changes made here are
    not reflected in the UI until a reload.

    The default implementation stores every pending assistant message that
    contains non-whitespace text; pending widget and workflow items are not
    saved. It then stores a hidden context item recording the cancellation.

    Args:
        thread: The thread that was being processed.
        pending_items: Items that were not done streaming at cancellation time.
        context: Arbitrary per-request context provided by the caller.
    """
    for pending in pending_items:
        if not isinstance(pending, AssistantMessageItem):
            continue
        # Skip messages with no content parts or only blank text.
        has_text = any(part.text.strip() for part in pending.content)
        if has_text:
            await self.store.add_thread_item(thread.id, pending, context=context)

    # Add a hidden context item to the thread to indicate that the stream was cancelled.
    # Otherwise, depending on the timing of the cancellation, subsequent responses may
    # attempt to continue the cancelled response.
    cancelled_at = datetime.now()
    marker_id = self.store.generate_item_id("sdk_hidden_context", thread, context)
    cancellation_marker = SDKHiddenContextItem(
        thread_id=thread.id,
        created_at=cancelled_at,
        id=marker_id,
        content="The user cancelled the stream. Stop responding to the prior request.",
    )
    await self.store.add_thread_item(
        thread.id,
        cancellation_marker,
        context=context,
    )
eee0a2f7Jiwon Kim7 months ago384
async def process(
    self, request: str | bytes | bytearray, context: TContext
) -> StreamingResult | NonStreamingResult:
    """Validate a raw ChatKit request and dispatch it.

    Args:
        request: The JSON-encoded request.
        context: Arbitrary per-request context provided by the caller.

    Returns:
        A ``StreamingResult`` for streaming requests (its events are produced
        lazily while iterating), otherwise a ``NonStreamingResult`` holding the
        already-computed response bytes.
    """
    adapter = TypeAdapter[ChatKitReq](ChatKitReq)
    parsed = adapter.validate_json(request)
    logger.info(f"Received request op: {parsed.type}")

    if is_streaming_req(parsed):
        return StreamingResult(self._process_streaming(parsed, context))

    payload = await self._process_non_streaming(parsed, context)
    return NonStreamingResult(payload)
398
399async def _process_non_streaming(
400self, request: NonStreamingReq, context: TContext
401) -> bytes:
402match request:
403case ThreadsGetByIdReq():
404thread = await self._load_full_thread(
405request.params.thread_id, context=context
406)
407return self._serialize(self._to_thread_response(thread))
408case ThreadsListReq():
409params = request.params
410threads = await self.store.load_threads(
411limit=params.limit or DEFAULT_PAGE_SIZE,
412after=params.after,
413order=params.order,
414context=context,
415)
416return self._serialize(
417Page(
418has_more=threads.has_more,
419after=threads.after,
420data=[
421self._to_thread_response(thread) for thread in threads.data
422],
423)
424)
425case ItemsFeedbackReq():
426await self.add_feedback(
427request.params.thread_id,
428request.params.item_ids,
429request.params.kind,
430context,
431)
432return b"{}"
433case AttachmentsCreateReq():
434attachment_store = self._get_attachment_store()
435attachment = await attachment_store.create_attachment(
436request.params, context
437)
438return self._serialize(attachment)
439case AttachmentsDeleteReq():
440attachment_store = self._get_attachment_store()
441await attachment_store.delete_attachment(
442request.params.attachment_id, context=context
443)
444await self.store.delete_attachment(
445request.params.attachment_id, context=context
446)
447return b"{}"
448case ItemsListReq():
449items_list_params = request.params
450items = await self.store.load_thread_items(
451items_list_params.thread_id,
452limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
453order=items_list_params.order,
454after=items_list_params.after,
455context=context,
456)
1860286bJiwon Kim7 months ago457# filter out hidden context items
f688d870victor-openai8 months ago458items.data = [
459item
460for item in items.data
1860286bJiwon Kim7 months ago461if not isinstance(item, (HiddenContextItem, SDKHiddenContextItem))
f688d870victor-openai8 months ago462]
463return self._serialize(items)
464case ThreadsUpdateReq():
465thread = await self.store.load_thread(
466request.params.thread_id, context=context
467)
468thread.title = request.params.title
469await self.store.save_thread(thread, context=context)
470return self._serialize(self._to_thread_response(thread))
471case ThreadsDeleteReq():
472await self.store.delete_thread(
473request.params.thread_id, context=context
474)
475return b"{}"
476case _:
477assert_never(request)
478
async def _process_streaming(
    self, request: StreamingReq, context: TContext
) -> AsyncGenerator[bytes, None]:
    """Run a streaming request and frame each event as a ``data: ...`` chunk.

    Every event is serialized and emitted as ``data: <json>`` followed by a
    blank line. Unexpected exceptions are logged and re-raised; cancellation is
    re-raised without logging.
    """
    try:
        async for event in self._process_streaming_impl(request, context):
            yield b"".join((b"data: ", self._serialize(event), b"\n\n"))
    except asyncio.CancelledError:
        # Let cancellation bubble up without logging as an error.
        raise
    except Exception:
        logger.exception("Error while generating streamed response")
        raise
492
async def _process_streaming_impl(
    self, request: StreamingReq, context: TContext
) -> AsyncGenerator[ThreadStreamEvent, None]:
    """Dispatch a streaming request and yield the resulting thread events.

    Handles thread creation, adding a user message, adding a client tool
    output, retrying after a user message, and custom actions. Store writes and
    yields are interleaved, so statement order matters in every branch.

    Raises:
        ValueError: If a client tool output arrives but the newest thread item
            is not a pending ClientToolCallItem, or if the item named in a
            retry request is not a user message.
    """
    match request:
        case ThreadsCreateReq():
            # Create and persist the (still empty) thread before anything is streamed.
            thread = Thread(
                id=self.store.generate_thread_id(context),
                created_at=datetime.now(),
                items=Page(),
            )
            await self.store.save_thread(
                ThreadMetadata(**thread.model_dump()),
                context=context,
            )
            yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
            user_message = await self._build_user_message_item(
                request.params.input, thread, context
            )
            async for event in self._process_new_thread_item_respond(
                thread,
                user_message,
                context,
            ):
                yield event

        case ThreadsAddUserMessageReq():
            thread = await self.store.load_thread(
                request.params.thread_id, context=context
            )
            user_message = await self._build_user_message_item(
                request.params.input, thread, context
            )
            async for event in self._process_new_thread_item_respond(
                thread,
                user_message,
                context,
            ):
                yield event

        case ThreadsAddClientToolOutputReq():
            thread = await self.store.load_thread(
                request.params.thread_id, context=context
            )
            # Only the single newest item is loaded (limit 1, descending order).
            items = await self.store.load_thread_items(
                thread.id, None, 1, "desc", context
            )
            tool_call = next(
                (
                    item
                    for item in items.data
                    if isinstance(item, ClientToolCallItem)
                    and item.status == "pending"
                ),
                None,
            )
            if not tool_call:
                raise ValueError(
                    f"Last thread item in {thread.id} was not a ClientToolCallItem"
                )

            # Record the client's result on the tool call and persist it.
            tool_call.output = request.params.result
            tool_call.status = "completed"

            await self.store.save_item(thread.id, tool_call, context=context)

            # Safety against dangling pending tool calls if there are
            # multiple in a row, which should be impossible, and
            # integrations should ultimately filter out pending tool calls
            # when creating input response messages.
            await self._cleanup_pending_client_tool_call(thread, context)

            # Continue the response without a new user message.
            async for event in self._process_events(
                thread,
                context,
                lambda: self.respond(thread, None, context),
            ):
                yield event

        case ThreadsRetryAfterItemReq():
            thread_metadata = await self.store.load_thread(
                request.params.thread_id, context=context
            )

            # Collect items to remove (all items after the user message)
            items_to_remove: list[ThreadItem] = []
            user_message_item = None

            # Walk newest-first until the requested item is reached.
            async for item in self._paginate_thread_items_reverse(
                request.params.thread_id, context
            ):
                if item.id == request.params.item_id:
                    if not isinstance(item, UserMessageItem):
                        raise ValueError(
                            f"Item {request.params.item_id} is not a user message"
                        )
                    user_message_item = item
                    break
                items_to_remove.append(item)

            # If the item was never found, nothing is deleted and nothing is streamed.
            if user_message_item:
                for item in items_to_remove:
                    await self.store.delete_thread_item(
                        request.params.thread_id, item.id, context=context
                    )
                async for event in self._process_events(
                    thread_metadata,
                    context,
                    lambda: self.respond(
                        thread_metadata,
                        user_message_item,
                        context,
                    ),
                ):
                    yield event
        case ThreadsCustomActionReq():
            thread_metadata = await self.store.load_thread(
                request.params.thread_id, context=context
            )

            # The sender item is optional; it is only loaded when an id was sent.
            item: ThreadItem | None = None
            if request.params.item_id:
                item = await self.store.load_item(
                    request.params.thread_id,
                    request.params.item_id,
                    context=context,
                )

            if item and not isinstance(item, WidgetItem):
                # shouldn't happen if the caller is using the API correctly.
                yield ErrorEvent(
                    code=ErrorCode.STREAM_ERROR,
                    allow_retry=False,
                )
                return

            async for event in self._process_events(
                thread_metadata,
                context,
                lambda: self.action(
                    thread_metadata,
                    request.params.action,
                    item,
                    context,
                ),
            ):
                yield event

        case _:
            assert_never(request)
642
async def _cleanup_pending_client_tool_call(
    self, thread: ThreadMetadata, context: TContext
) -> None:
    """Delete client tool calls that are still pending among the newest items.

    Only the most recent ``DEFAULT_PAGE_SIZE`` items of the thread are
    inspected. Each removed tool call is logged as a warning.
    """
    recent = await self.store.load_thread_items(
        thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
    )
    for entry in recent.data:
        if isinstance(entry, ClientToolCallItem) and entry.status == "pending":
            logger.warning(
                f"Client tool call {entry.call_id} was not completed, ignoring"
            )
            await self.store.delete_thread_item(thread.id, entry.id, context=context)
659
async def _process_new_thread_item_respond(
    self,
    thread: ThreadMetadata,
    item: UserMessageItem,
    context: TContext,
) -> AsyncIterator[ThreadStreamEvent]:
    """Store a new user message, announce it, then stream the response to it.

    The message is persisted and pending client tool calls are cleaned up
    before the ``ThreadItemDoneEvent`` for the message is yielded; the events of
    ``respond()`` follow.
    """
    await self.store.add_thread_item(thread.id, item, context=context)
    await self._cleanup_pending_client_tool_call(thread, context)
    yield ThreadItemDoneEvent(item=item)

    def start_response() -> AsyncIterator[ThreadStreamEvent]:
        # Deferred so that respond() is only invoked from within _process_events.
        return self.respond(thread, item, context)

    async for response_event in self._process_events(
        thread, context, start_response
    ):
        yield response_event
676
677async def _process_events(
678self,
679thread: ThreadMetadata,
680context: TContext,
681stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
682) -> AsyncIterator[ThreadStreamEvent]:
683await asyncio.sleep(0) # allow the response to start streaming
684
5a7244baJiwon Kim7 months ago685# Send initial stream options
686yield StreamOptionsEvent(
687stream_options=self.get_stream_options(thread, context)
688)
689
f688d870victor-openai8 months ago690last_thread = thread.model_copy(deep=True)
691
eee0a2f7Jiwon Kim7 months ago692# Keep track of items that were streamed but not yet saved
693# so that we can persist them when the stream is cancelled.
694pending_items: dict[str, ThreadItem] = {}
695
f688d870victor-openai8 months ago696try:
697with agents_sdk_user_agent_override():
698async for event in stream():
eee0a2f7Jiwon Kim7 months ago699if isinstance(event, ThreadItemAddedEvent):
700pending_items[event.item.id] = event.item
701
f688d870victor-openai8 months ago702match event:
703case ThreadItemDoneEvent():
704await self.store.add_thread_item(
705thread.id, event.item, context=context
706)
eee0a2f7Jiwon Kim7 months ago707pending_items.pop(event.item.id, None)
f688d870victor-openai8 months ago708case ThreadItemRemovedEvent():
709await self.store.delete_thread_item(
710thread.id, event.item_id, context=context
711)
eee0a2f7Jiwon Kim7 months ago712pending_items.pop(event.item_id, None)
f688d870victor-openai8 months ago713case ThreadItemReplacedEvent():
714await self.store.save_item(
715thread.id, event.item, context=context
716)
eee0a2f7Jiwon Kim7 months ago717pending_items.pop(event.item.id, None)
718case ThreadItemUpdatedEvent():
5a7244baJiwon Kim7 months ago719# Keep pending assistant message and workflow items up to date
720# so that we have a reference to the latest version of these pending items
721# when the stream is cancelled.
722self._update_pending_items(pending_items, event)
f688d870victor-openai8 months ago723
724# special case - don't send hidden context items back to the client
725should_swallow_event = isinstance(
726event, ThreadItemDoneEvent
1860286bJiwon Kim7 months ago727) and isinstance(
728event.item, (HiddenContextItem, SDKHiddenContextItem)
729)
f688d870victor-openai8 months ago730
731if not should_swallow_event:
732yield event
733
734# in case user updated the thread while streaming
735if thread != last_thread:
736last_thread = thread.model_copy(deep=True)
737await self.store.save_thread(thread, context=context)
738yield ThreadUpdatedEvent(
739thread=self._to_thread_response(thread)
740)
741# in case user updated the thread while streaming
742if thread != last_thread:
743last_thread = thread.model_copy(deep=True)
744await self.store.save_thread(thread, context=context)
745yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
eee0a2f7Jiwon Kim7 months ago746except asyncio.CancelledError:
747await self.handle_stream_cancelled(
748thread, list(pending_items.values()), context
749)
750raise
f688d870victor-openai8 months ago751except CustomStreamError as e:
752yield ErrorEvent(
753code="custom",
754message=e.message,
755allow_retry=e.allow_retry,
756)
757except StreamError as e:
758yield ErrorEvent(
759code=e.code,
760allow_retry=e.allow_retry,
761)
762except Exception as e:
763yield ErrorEvent(
764code=ErrorCode.STREAM_ERROR,
765allow_retry=True,
766)
767logger.exception(e)
768
769if thread != last_thread:
770# in case user updated the thread at the end of the stream
771await self.store.save_thread(thread, context=context)
772yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
773
eee0a2f7Jiwon Kim7 months ago774def _apply_assistant_message_update(
775self,
776item: AssistantMessageItem,
777update: AssistantMessageContentPartAdded
778| AssistantMessageContentPartTextDelta
779| AssistantMessageContentPartAnnotationAdded
780| AssistantMessageContentPartDone,
781) -> AssistantMessageItem:
782updated = item.model_copy(deep=True)
783
784# Pad the content list so the requested content_index exists before we write into it.
785# (Streaming updates can arrive for an index that hasn’t been created yet)
786while len(updated.content) <= update.content_index:
787updated.content.append(AssistantMessageContent(text="", annotations=[]))
788
789match update:
790case AssistantMessageContentPartAdded():
791updated.content[update.content_index] = update.content
792case AssistantMessageContentPartTextDelta():
793updated.content[update.content_index].text += update.delta
794case AssistantMessageContentPartAnnotationAdded():
795annotations = updated.content[update.content_index].annotations
796if update.annotation_index <= len(annotations):
797annotations.insert(update.annotation_index, update.annotation)
798else:
799annotations.append(update.annotation)
800case AssistantMessageContentPartDone():
801updated.content[update.content_index] = update.content
802return updated
803
5a7244baJiwon Kim7 months ago804def _update_pending_items(
eee0a2f7Jiwon Kim7 months ago805self,
806pending_items: dict[str, ThreadItem],
807event: ThreadItemUpdatedEvent,
808):
809updated_item = pending_items.get(event.item_id)
5a7244baJiwon Kim7 months ago810update = event.update
811match updated_item:
812case AssistantMessageItem():
813if isinstance(
814update,
815(
816AssistantMessageContentPartAdded,
817AssistantMessageContentPartTextDelta,
818AssistantMessageContentPartAnnotationAdded,
819AssistantMessageContentPartDone,
820),
821):
822pending_items[updated_item.id] = (
823self._apply_assistant_message_update(updated_item, update)
824)
825case WorkflowItem():
826if isinstance(update, (WorkflowTaskUpdated, WorkflowTaskAdded)):
827match update:
828case WorkflowTaskUpdated():
829updated_item.workflow.tasks[update.task_index] = update.task
830case WorkflowTaskAdded():
831updated_item.workflow.tasks.append(update.task)
832
833pending_items[updated_item.id] = updated_item
834case _:
835pass
eee0a2f7Jiwon Kim7 months ago836
async def _build_user_message_item(
    self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
) -> UserMessageItem:
    """Create a UserMessageItem for `thread` from the client's message input.

    A new item id is generated, every referenced attachment is loaded from the
    store, and the creation time is set to the current time. The item is not
    persisted here.
    """
    item_id = self.store.generate_item_id("message", thread, context)

    # Resolve attachment ids one at a time, preserving their order.
    loaded_attachments = []
    for attachment_id in input.attachments:
        loaded_attachments.append(
            await self.store.load_attachment(attachment_id, context)
        )

    return UserMessageItem(
        id=item_id,
        content=input.content,
        thread_id=thread.id,
        attachments=loaded_attachments,
        quoted_text=input.quoted_text,
        inference_options=input.inference_options,
        created_at=datetime.now(),
    )
852
async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
    """Load a thread's metadata together with its first page of items.

    Items are loaded oldest-first and limited to ``DEFAULT_PAGE_SIZE``.
    """
    metadata = await self.store.load_thread(thread_id, context=context)
    first_page = await self.store.load_thread_items(
        thread_id,
        after=None,
        limit=DEFAULT_PAGE_SIZE,
        order="asc",
        context=context,
    )
    metadata_fields = metadata.model_dump()
    return Thread(**metadata_fields, items=first_page)
863
async def _paginate_thread_items_reverse(
    self, thread_id: str, context: TContext
) -> AsyncIterator[ThreadItem]:
    """Yield every item of a thread from newest to oldest.

    Pages of ``DEFAULT_PAGE_SIZE`` items are fetched on demand until the store
    reports that no more items are available.
    """
    cursor = None
    has_more = True
    while has_more:
        page = await self.store.load_thread_items(
            thread_id, cursor, DEFAULT_PAGE_SIZE, "desc", context
        )
        for entry in page.data:
            yield entry

        # Continue after this page only if the store reports more items.
        has_more = page.has_more
        cursor = page.after
879
def _serialize(self, obj: BaseModel) -> bytes:
    """Encode a model as UTF-8 JSON, using field aliases and omitting ``None`` values."""
    as_json = obj.model_dump_json(by_alias=True, exclude_none=True)
    return as_json.encode("utf-8")
882
def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
    """Build the client-facing Thread for `thread`, without hidden context items.

    A ``ThreadMetadata`` (which carries no items) yields a Thread with an empty
    page. The filtered page is a copy: the items page of the thread passed in
    is no longer modified in place.

    Args:
        thread: The thread, or just its metadata, to convert.

    Returns:
        A new Thread with the same id, title, creation time and status, whose
        items exclude HiddenContextItem and SDKHiddenContextItem entries.
    """
    hidden_types = (HiddenContextItem, SDKHiddenContextItem)

    source_page = thread.items if isinstance(thread, Thread) else Page()
    visible_items = [
        item for item in source_page.data if not isinstance(item, hidden_types)
    ]
    # Copy the page (keeping its paging fields) instead of assigning to
    # `source_page.data`, which would alter the caller's Thread.
    visible_page = source_page.model_copy(update={"data": visible_items})

    return Thread(
        id=thread.id,
        title=thread.title,
        created_at=thread.created_at,
        items=visible_page,
        status=thread.status,
    )