openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
3f7a2fa23d77277ce5213176a0267a72b7e215ec

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

728 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16 HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19 _HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
22from typing_extensions import TypeVar, assert_never
23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29 Action,
30 AttachmentsCreateReq,
31 AttachmentsDeleteReq,
32 ChatKitReq,
33 ClientToolCallItem,
34 ErrorCode,
35 ErrorEvent,
36 FeedbackKind,
37 HiddenContextItem,
38 ItemsFeedbackReq,
39 ItemsListReq,
40 NonStreamingReq,
41 Page,
42 StreamingReq,
43 Thread,
44 ThreadCreatedEvent,
45 ThreadItem,
46 ThreadItemAddedEvent,
47 ThreadItemDoneEvent,
48 ThreadItemRemovedEvent,
49 ThreadItemReplacedEvent,
50 ThreadItemUpdatedEvent,
51 ThreadMetadata,
52 ThreadsAddClientToolOutputReq,
53 ThreadsAddUserMessageReq,
54 ThreadsCreateReq,
55 ThreadsCustomActionReq,
56 ThreadsDeleteReq,
57 ThreadsGetByIdReq,
58 ThreadsListReq,
59 ThreadsRetryAfterItemReq,
60 ThreadStreamEvent,
61 ThreadsUpdateReq,
62 ThreadUpdatedEvent,
63 UserMessageInput,
64 UserMessageItem,
65 WidgetComponentUpdated,
66 WidgetItem,
67 WidgetRootUpdated,
68 WidgetStreamingTextValueDelta,
69 is_streaming_req,
70)
71from .version import __version__
72from .widgets import Markdown, Text, WidgetComponent, WidgetComponentBase, WidgetRoot
73
# Page size used when a list request omits `limit`, and for the fixed-size
# item loads performed internally by ChatKitServer.
DEFAULT_PAGE_SIZE: int = 20
# Generic fallback error text (not referenced elsewhere in this module).
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
76
77
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    If the trees differ in any way other than text being appended to the
    ``value`` of an id-bearing ``Markdown``/``Text`` node, a single
    ``WidgetRootUpdated`` carrying ``after`` is returned. Otherwise one
    ``WidgetStreamingTextValueDelta`` is returned per id-bearing text node whose
    value grew (an empty list when nothing changed).

    Raises:
        ValueError: if an id-bearing text node in ``after`` has no counterpart
            in ``before``, or its new value does not start with its previous
            value.
    """

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        # Identity changes can never be expressed as a streaming delta.
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    # Nested components are compared structurally.
                    return full_replace(before_value, after_value)
                else:
                    return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                isinstance(before, (Markdown, Text))
                and isinstance(after, (Markdown, Text))
                and field == "value"
                # Only nodes with an id can be addressed by a streaming delta
                # (see find_all_streaming_text_components below). An append on
                # an id-less node must fall through and force a full replace,
                # otherwise the change would be silently dropped.
                and before.id
                and after.value.startswith(before.value)
            ):
                # Appends to the value prop of an identified Markdown or Text
                # node do not trigger a full replace; they become deltas.
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, Markdown | Text]:
        """Map id -> node for every id-bearing Markdown/Text in the tree."""
        components: dict[str, Markdown | Text] = {}

        def recurse(component: WidgetComponent | WidgetRoot):
            if isinstance(component, (Markdown, Text)) and component.id:
                components[component.id] = component

            if hasattr(component, "children"):
                children = getattr(component, "children", None) or []
                for child in children:
                    recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for node_id, after_node in after_nodes.items():
        before_node = before_nodes.get(node_id)
        if before_node is None:
            raise ValueError(
                f"Node {node_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        if before_node.value != after_node.value:
            if not after_node.value.startswith(before_node.value):
                raise ValueError(
                    f"Node {node_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            done = not after_node.streaming
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=node_id,
                    # Only the newly appended suffix is sent.
                    delta=after_node.value[len(before_node.value) :],
                    done=done,
                )
            )

    return deltas
172
173
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream thread events that render ``widget`` as a WidgetItem.

    A static ``WidgetRoot`` produces a single ``ThreadItemDoneEvent``. An async
    generator of ``WidgetRoot`` states produces:

    - a ``ThreadItemAddedEvent`` for the first state;
    - a ``ThreadItemUpdatedEvent`` for each delta returned by ``diff_widget``
      between consecutive states;
    - a final ``ThreadItemDoneEvent`` carrying the last state.

    The generator is expected to yield at least one state. If it yields
    nothing, the ``StopAsyncIteration`` from the first ``__anext__()`` escapes
    this async generator, which Python reports as a ``RuntimeError``.

    Args:
        thread: Thread the widget item belongs to.
        widget: A static widget, or an async generator of successive states.
        copy_text: Optional text stored on the WidgetItem as ``copy_text``.
        generate_id: Factory used to create the item id.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    initial_state = await widget.__anext__()

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state

    # Consume the remaining states; each one is diffed against its predecessor.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
223
224
@contextmanager
def agents_sdk_user_agent_override():
    """Set a ChatKit User-Agent header override for Agents SDK model calls.

    For the duration of the ``with`` block, both the chat-completions and the
    responses header-override context variables are set to
    ``{"User-Agent": "Agents/Python <agents version> ChatKit/Python <version>"}``.
    They are restored on exit, including when the block raises.
    """
    from contextlib import suppress

    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Always restore the previous values; without try/finally an exception
        # inside the block would leave the override set for the rest of the
        # current context.
        # ContextVar.reset() raises ValueError when called from a different
        # Context than the one that created the token (e.g. an enclosing async
        # generator being closed from another task). There is nothing to
        # restore in that Context, so that case is deliberately ignored.
        with suppress(ValueError):
            chat_completions_headers_override.reset(chat_completions_token)
        with suppress(ValueError):
            responses_headers_override.reset(responses_token)
235
236
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around a generator of response byte chunks."""

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # Kept under the public attribute name `json_events`.
        self.json_events = stream

    async def __aiter__(self):
        """Re-yield every chunk produced by the wrapped generator, in order."""
        source = self.json_events
        async for chunk in source:
            yield chunk
244
245
class NonStreamingResult:
    """Container for a complete, already-serialized response body."""

    def __init__(self, result: bytes):
        body = result
        # Exposed as `json`: the encoded bytes of the response payload.
        self.json = body
249
250
# Type of the caller-supplied per-request context threaded through
# ChatKitServer and its Store/AttachmentStore; falls back to Any when the
# server class is used unparameterized.
TContext = TypeVar("TContext", default=Any)
252
253
254class ChatKitServer(ABC, Generic[TContext]):
255 def __init__(
256 self,
257 store: Store[TContext],
258 attachment_store: AttachmentStore[TContext] | None = None,
259 ):
260 self.store = store
261 self.attachment_store = attachment_store
262
263 def _get_attachment_store(self) -> AttachmentStore[TContext]:
264 """Return the configured AttachmentStore or raise if missing."""
265 if self.attachment_store is None:
266 raise RuntimeError(
267 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
268 )
269 return self.attachment_store
270
271 @abstractmethod
272 def respond(
273 self,
274 thread: ThreadMetadata,
275 input_user_message: UserMessageItem | None,
276 context: TContext,
277 ) -> AsyncIterator[ThreadStreamEvent]:
278 """Stream `ThreadStreamEvent` instances for a new user message.
279
280 Args:
281 thread: Metadata for the thread being processed.
282 input_user_message: The incoming message the server should respond to, if any.
283 context: Arbitrary per-request context provided by the caller.
284
285 Returns:
286 An async iterator that yields events representing the server's response.
287 """
288 pass
289
290 async def add_feedback( # noqa: B027
291 self,
292 thread_id: str,
293 item_ids: list[str],
294 feedback: FeedbackKind,
295 context: TContext,
296 ) -> None:
297 pass
298
299 def action(
300 self,
301 thread: ThreadMetadata,
302 action: Action[str, Any],
303 sender: WidgetItem | None,
304 context: TContext,
305 ) -> AsyncIterator[ThreadStreamEvent]:
306 raise NotImplementedError(
307 "The action() method must be overridden to react to actions. "
308 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
309 )
310
311 async def process(
312 self, request: str | bytes | bytearray, context: TContext
313 ) -> StreamingResult | NonStreamingResult:
314 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
315 logger.info(f"Received request op: {parsed_request.type}")
316
317 if is_streaming_req(parsed_request):
318 return StreamingResult(self._process_streaming(parsed_request, context))
319 else:
320 return NonStreamingResult(
321 await self._process_non_streaming(parsed_request, context)
322 )
323
324 async def _process_non_streaming(
325 self, request: NonStreamingReq, context: TContext
326 ) -> bytes:
327 match request:
328 case ThreadsGetByIdReq():
329 thread = await self._load_full_thread(
330 request.params.thread_id, context=context
331 )
332 return self._serialize(self._to_thread_response(thread))
333 case ThreadsListReq():
334 params = request.params
335 threads = await self.store.load_threads(
336 limit=params.limit or DEFAULT_PAGE_SIZE,
337 after=params.after,
338 order=params.order,
339 context=context,
340 )
341 return self._serialize(
342 Page(
343 has_more=threads.has_more,
344 after=threads.after,
345 data=[
346 self._to_thread_response(thread) for thread in threads.data
347 ],
348 )
349 )
350 case ItemsFeedbackReq():
351 await self.add_feedback(
352 request.params.thread_id,
353 request.params.item_ids,
354 request.params.kind,
355 context,
356 )
357 return b"{}"
358 case AttachmentsCreateReq():
359 attachment_store = self._get_attachment_store()
360 attachment = await attachment_store.create_attachment(
361 request.params, context
362 )
363 return self._serialize(attachment)
364 case AttachmentsDeleteReq():
365 attachment_store = self._get_attachment_store()
366 await attachment_store.delete_attachment(
367 request.params.attachment_id, context=context
368 )
369 await self.store.delete_attachment(
370 request.params.attachment_id, context=context
371 )
372 return b"{}"
373 case ItemsListReq():
374 items_list_params = request.params
375 items = await self.store.load_thread_items(
376 items_list_params.thread_id,
377 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
378 order=items_list_params.order,
379 after=items_list_params.after,
380 context=context,
381 )
382 # filter out HiddenContextItems
383 items.data = [
384 item
385 for item in items.data
386 if not isinstance(item, HiddenContextItem)
387 ]
388 return self._serialize(items)
389 case ThreadsUpdateReq():
390 thread = await self.store.load_thread(
391 request.params.thread_id, context=context
392 )
393 thread.title = request.params.title
394 await self.store.save_thread(thread, context=context)
395 return self._serialize(self._to_thread_response(thread))
396 case ThreadsDeleteReq():
397 await self.store.delete_thread(
398 request.params.thread_id, context=context
399 )
400 return b"{}"
401 case _:
402 assert_never(request)
403
404 async def _process_streaming(
405 self, request: StreamingReq, context: TContext
406 ) -> AsyncGenerator[bytes, None]:
407 try:
408 async for event in self._process_streaming_impl(request, context):
409 b = self._serialize(event)
410 yield b"data: " + b + b"\n\n"
411 except Exception:
412 logger.exception("Error while generating streamed response")
413 raise
414
415 async def _process_streaming_impl(
416 self, request: StreamingReq, context: TContext
417 ) -> AsyncGenerator[ThreadStreamEvent, None]:
418 match request:
419 case ThreadsCreateReq():
420 thread = Thread(
421 id=self.store.generate_thread_id(context),
422 created_at=datetime.now(),
423 items=Page(),
424 )
425 await self.store.save_thread(
426 ThreadMetadata(**thread.model_dump()),
427 context=context,
428 )
429 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
430 user_message = await self._build_user_message_item(
431 request.params.input, thread, context
432 )
433 async for event in self._process_new_thread_item_respond(
434 thread,
435 user_message,
436 context,
437 ):
438 yield event
439
440 case ThreadsAddUserMessageReq():
441 thread = await self.store.load_thread(
442 request.params.thread_id, context=context
443 )
444 user_message = await self._build_user_message_item(
445 request.params.input, thread, context
446 )
447 async for event in self._process_new_thread_item_respond(
448 thread,
449 user_message,
450 context,
451 ):
452 yield event
453
454 case ThreadsAddClientToolOutputReq():
455 thread = await self.store.load_thread(
456 request.params.thread_id, context=context
457 )
458 items = await self.store.load_thread_items(
459 thread.id, None, 1, "desc", context
460 )
461 tool_call = next(
462 (
463 item
464 for item in items.data
465 if isinstance(item, ClientToolCallItem)
466 and item.status == "pending"
467 ),
468 None,
469 )
470 if not tool_call:
471 raise ValueError(
472 f"Last thread item in {thread.id} was not a ClientToolCallItem"
473 )
474
475 tool_call.output = request.params.result
476 tool_call.status = "completed"
477
478 await self.store.save_item(thread.id, tool_call, context=context)
479
480 # Safety against dangling pending tool calls if there are
481 # multiple in a row, which should be impossible, and
482 # integrations should ultimately filter out pending tool calls
483 # when creating input response messages.
484 await self._cleanup_pending_client_tool_call(thread, context)
485
486 async for event in self._process_events(
487 thread,
488 context,
489 lambda: self.respond(thread, None, context),
490 ):
491 yield event
492
493 case ThreadsRetryAfterItemReq():
494 thread_metadata = await self.store.load_thread(
495 request.params.thread_id, context=context
496 )
497
498 # Collect items to remove (all items after the user message)
499 items_to_remove: list[ThreadItem] = []
500 user_message_item = None
501
502 async for item in self._paginate_thread_items_reverse(
503 request.params.thread_id, context
504 ):
505 if item.id == request.params.item_id:
506 if not isinstance(item, UserMessageItem):
507 raise ValueError(
508 f"Item {request.params.item_id} is not a user message"
509 )
510 user_message_item = item
511 break
512 items_to_remove.append(item)
513
514 if user_message_item:
515 for item in items_to_remove:
516 await self.store.delete_thread_item(
517 request.params.thread_id, item.id, context=context
518 )
519 async for event in self._process_events(
520 thread_metadata,
521 context,
522 lambda: self.respond(
523 thread_metadata,
524 user_message_item,
525 context,
526 ),
527 ):
528 yield event
529 case ThreadsCustomActionReq():
530 thread_metadata = await self.store.load_thread(
531 request.params.thread_id, context=context
532 )
533
534 item: ThreadItem | None = None
535 if request.params.item_id:
536 item = await self.store.load_item(
537 request.params.thread_id,
538 request.params.item_id,
539 context=context,
540 )
541
542 if item and not isinstance(item, WidgetItem):
543 # shouldn't happen if the caller is using the API correctly.
544 yield ErrorEvent(
545 code=ErrorCode.STREAM_ERROR,
546 allow_retry=False,
547 )
548 return
549
550 async for event in self._process_events(
551 thread_metadata,
552 context,
553 lambda: self.action(
554 thread_metadata,
555 request.params.action,
556 item,
557 context,
558 ),
559 ):
560 yield event
561
562 case _:
563 assert_never(request)
564
565 async def _cleanup_pending_client_tool_call(
566 self, thread: ThreadMetadata, context: TContext
567 ) -> None:
568 items = await self.store.load_thread_items(
569 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
570 )
571 for tool_call in items.data:
572 if not isinstance(tool_call, ClientToolCallItem):
573 continue
574 if tool_call.status == "pending":
575 logger.warning(
576 f"Client tool call {tool_call.call_id} was not completed, ignoring"
577 )
578 await self.store.delete_thread_item(
579 thread.id, tool_call.id, context=context
580 )
581
582 async def _process_new_thread_item_respond(
583 self,
584 thread: ThreadMetadata,
585 item: UserMessageItem,
586 context: TContext,
587 ) -> AsyncIterator[ThreadStreamEvent]:
588 await self.store.add_thread_item(thread.id, item, context=context)
589 await self._cleanup_pending_client_tool_call(thread, context)
590 yield ThreadItemDoneEvent(item=item)
591
592 async for event in self._process_events(
593 thread,
594 context,
595 lambda: self.respond(thread, item, context),
596 ):
597 yield event
598
599 async def _process_events(
600 self,
601 thread: ThreadMetadata,
602 context: TContext,
603 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
604 ) -> AsyncIterator[ThreadStreamEvent]:
605 await asyncio.sleep(0) # allow the response to start streaming
606
607 last_thread = thread.model_copy(deep=True)
608
609 try:
610 with agents_sdk_user_agent_override():
611 async for event in stream():
612 match event:
613 case ThreadItemDoneEvent():
614 await self.store.add_thread_item(
615 thread.id, event.item, context=context
616 )
617 case ThreadItemRemovedEvent():
618 await self.store.delete_thread_item(
619 thread.id, event.item_id, context=context
620 )
621 case ThreadItemReplacedEvent():
622 await self.store.save_item(
623 thread.id, event.item, context=context
624 )
625
626 # special case - don't send hidden context items back to the client
627 should_swallow_event = isinstance(
628 event, ThreadItemDoneEvent
629 ) and isinstance(event.item, HiddenContextItem)
630
631 if not should_swallow_event:
632 yield event
633
634 # in case user updated the thread while streaming
635 if thread != last_thread:
636 last_thread = thread.model_copy(deep=True)
637 await self.store.save_thread(thread, context=context)
638 yield ThreadUpdatedEvent(
639 thread=self._to_thread_response(thread)
640 )
641 # in case user updated the thread while streaming
642 if thread != last_thread:
643 last_thread = thread.model_copy(deep=True)
644 await self.store.save_thread(thread, context=context)
645 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
646 except CustomStreamError as e:
647 yield ErrorEvent(
648 code="custom",
649 message=e.message,
650 allow_retry=e.allow_retry,
651 )
652 except StreamError as e:
653 yield ErrorEvent(
654 code=e.code,
655 allow_retry=e.allow_retry,
656 )
657 except Exception as e:
658 yield ErrorEvent(
659 code=ErrorCode.STREAM_ERROR,
660 allow_retry=True,
661 )
662 logger.exception(e)
663
664 if thread != last_thread:
665 # in case user updated the thread at the end of the stream
666 await self.store.save_thread(thread, context=context)
667 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
668
669 async def _build_user_message_item(
670 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
671 ) -> UserMessageItem:
672 return UserMessageItem(
673 id=self.store.generate_item_id("message", thread, context),
674 content=input.content,
675 thread_id=thread.id,
676 attachments=[
677 await self.store.load_attachment(attachment_id, context)
678 for attachment_id in input.attachments
679 ],
680 quoted_text=input.quoted_text,
681 inference_options=input.inference_options,
682 created_at=datetime.now(),
683 )
684
685 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
686 thread_meta = await self.store.load_thread(thread_id, context=context)
687 thread_items = await self.store.load_thread_items(
688 thread_id,
689 after=None,
690 limit=DEFAULT_PAGE_SIZE,
691 order="asc",
692 context=context,
693 )
694 return Thread(**thread_meta.model_dump(), items=thread_items)
695
696 async def _paginate_thread_items_reverse(
697 self, thread_id: str, context: TContext
698 ) -> AsyncIterator[ThreadItem]:
699 """Paginate through thread items in reverse order (newest first)."""
700 after = None
701 while True:
702 items = await self.store.load_thread_items(
703 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
704 )
705 for item in items.data:
706 yield item
707
708 if not items.has_more:
709 break
710 after = items.after
711
712 def _serialize(self, obj: BaseModel) -> bytes:
713 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
714
715 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
716 def is_hidden(item: ThreadItem) -> bool:
717 return isinstance(item, HiddenContextItem)
718
719 items = thread.items if isinstance(thread, Thread) else Page()
720 items.data = [item for item in items.data if not is_hidden(item)]
721
722 return Thread(
723 id=thread.id,
724 title=thread.title,
725 created_at=thread.created_at,
726 items=items,
727 status=thread.status,
728 )
729