openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
c9f5f09d50d3fa3029f1576bc8750f0f01286b16

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

729 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12 assert_never,
13)
14
15import agents
16from agents.models.chatcmpl_helpers import (
17 HEADERS_OVERRIDE as chat_completions_headers_override,
18)
19from agents.models.openai_responses import (
20 _HEADERS_OVERRIDE as responses_headers_override,
21)
22from pydantic import BaseModel, TypeAdapter
23from typing_extensions import TypeVar
24
25from chatkit.errors import CustomStreamError, StreamError
26
27from .logger import logger
28from .store import AttachmentStore, Store, StoreItemType, default_generate_id
29from .types import (
30 Action,
31 AttachmentsCreateReq,
32 AttachmentsDeleteReq,
33 ChatKitReq,
34 ClientToolCallItem,
35 ErrorCode,
36 ErrorEvent,
37 FeedbackKind,
38 HiddenContextItem,
39 ItemsFeedbackReq,
40 ItemsListReq,
41 NonStreamingReq,
42 Page,
43 StreamingReq,
44 Thread,
45 ThreadCreatedEvent,
46 ThreadItem,
47 ThreadItemAddedEvent,
48 ThreadItemDoneEvent,
49 ThreadItemRemovedEvent,
50 ThreadItemReplacedEvent,
51 ThreadItemUpdated,
52 ThreadMetadata,
53 ThreadsAddClientToolOutputReq,
54 ThreadsAddUserMessageReq,
55 ThreadsCreateReq,
56 ThreadsCustomActionReq,
57 ThreadsDeleteReq,
58 ThreadsGetByIdReq,
59 ThreadsListReq,
60 ThreadsRetryAfterItemReq,
61 ThreadStreamEvent,
62 ThreadsUpdateReq,
63 ThreadUpdatedEvent,
64 UserMessageInput,
65 UserMessageItem,
66 WidgetComponentUpdated,
67 WidgetItem,
68 WidgetRootUpdated,
69 WidgetStreamingTextValueDelta,
70 is_streaming_req,
71)
72from .version import __version__
73from .widgets import Markdown, Text, WidgetComponent, WidgetComponentBase, WidgetRoot
74
# Number of entries returned by list/load operations when the request
# does not specify a limit.
DEFAULT_PAGE_SIZE: int = 20
# Default text describing a failed response generation.
DEFAULT_ERROR_MESSAGE: str = "An error occurred when generating a response."
77
78
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    If anything changed other than text appended to the ``value`` of a
    ``Markdown``/``Text`` component that has an ``id``, a single
    ``WidgetRootUpdated`` carrying ``after`` is returned. Otherwise one
    ``WidgetStreamingTextValueDelta`` is returned for each identified
    ``Markdown``/``Text`` component whose value grew (an empty list when
    nothing changed).

    Raises:
        ValueError: If an identified text component exists in ``after`` but
            not in ``before``, or if its new value does not extend its
            previous value.
    """

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True if ``after`` cannot be reached from ``before`` via text deltas."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            """Compare one field value, recursing into lists and child components."""
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                return any(
                    full_replace_value(nth_before, nth_after)
                    for nth_before, nth_after in zip(before_value, after_value)
                )
            if before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                isinstance(before, (Markdown, Text))
                and isinstance(after, (Markdown, Text))
                and field == "value"
                # A streaming delta is addressed by component id. Without an
                # id no delta can be emitted below, so growing text must fall
                # through to a full replace instead of being silently dropped.
                and after.id
                and after.value.startswith(before.value)
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, Markdown | Text]:
        """Map id -> component for every Markdown/Text with an id, walking ``children``."""
        components: dict[str, Markdown | Text] = {}

        def recurse(component: WidgetComponent | WidgetRoot) -> None:
            if isinstance(component, (Markdown, Text)) and component.id:
                components[component.id] = component
            for child in getattr(component, "children", None) or []:
                recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    for component_id, after_node in after_nodes.items():
        before_node = before_nodes.get(component_id)
        if before_node is None:
            raise ValueError(
                f"Node {component_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        if before_node.value == after_node.value:
            continue
        if not after_node.value.startswith(before_node.value):
            raise ValueError(
                f"Node {component_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
            )
        deltas.append(
            WidgetStreamingTextValueDelta(
                component_id=component_id,
                # Only the newly appended suffix is sent to the client.
                delta=after_node.value[len(before_node.value) :],
                done=not after_node.streaming,
            )
        )

    return deltas
173
174
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Emit a widget as thread item events.

    A finished ``WidgetRoot`` is emitted as a single ``ThreadItemDoneEvent``.
    An async generator of widget states is emitted as:

    * a ``ThreadItemAddedEvent`` for its first state,
    * one ``ThreadItemUpdated`` per update computed by ``diff_widget`` between
      consecutive states, and
    * a final ``ThreadItemDoneEvent`` carrying the last state.

    Args:
        thread: Thread the widget item belongs to.
        widget: A finished widget, or an async generator yielding successive
            widget states.
        copy_text: Optional value stored on the item's ``copy_text``.
        generate_id: Factory for the item id; called with ``"message"``.

    Raises:
        RuntimeError: If ``widget`` is a generator that yields no state.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        try:
            initial_state = await widget.__anext__()
        except StopAsyncIteration:
            raise RuntimeError(
                "Widget generator finished without yielding an initial state"
            ) from None

        item = WidgetItem(
            id=item_id,
            created_at=datetime.now(),
            widget=initial_state,
            copy_text=copy_text,
            thread_id=thread.id,
        )

        yield ThreadItemAddedEvent(item=item)

        last_state = initial_state
        async for new_state in widget:
            for update in diff_widget(last_state, new_state):
                yield ThreadItemUpdated(
                    item_id=item_id,
                    update=update,
                )
            last_state = new_state

        yield ThreadItemDoneEvent(
            item=item.model_copy(update={"widget": last_state}),
        )
    finally:
        # If our consumer stops early, close the source generator now so its
        # cleanup runs deterministically rather than at garbage collection.
        # This is a no-op when the generator is already exhausted.
        await widget.aclose()
224
225
@contextmanager
def agents_sdk_user_agent_override():
    """Override the Agents SDK ``User-Agent`` header for the duration of the block.

    Sets the header override used by the Agents SDK's Chat Completions
    helpers and by its Responses model to a value identifying both the
    Agents SDK and ChatKit versions. The previous overrides are always
    restored on exit, including when the body raises or the enclosing
    generator is closed early.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Without ``finally`` an exception from the body would skip these
        # resets and leave the override in place. Reset in reverse order of
        # setting.
        responses_headers_override.reset(responses_token)
        chat_completions_headers_override.reset(chat_completions_token)
236
237
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around a generator of encoded response chunks."""

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # The wrapped byte stream; consumed lazily when iterated.
        self.json_events = stream

    async def __aiter__(self):
        """Re-yield each chunk from the wrapped generator, in order."""
        source = self.json_events
        async for chunk in source:
            yield chunk
245
246
class NonStreamingResult:
    """Holds the complete encoded body of a non-streaming response."""

    def __init__(self, result: bytes):
        # Serialized response payload, exposed as ``json``.
        self.json = result
250
251
# Type of the per-request context object passed through server and store
# calls; when left unparameterized it defaults to ``Any``.
TContext = TypeVar("TContext", default=Any)
253
254
255class ChatKitServer(ABC, Generic[TContext]):
256 def __init__(
257 self,
258 store: Store[TContext],
259 attachment_store: AttachmentStore[TContext] | None = None,
260 ):
261 self.store = store
262 self.attachment_store = attachment_store
263
264 def _get_attachment_store(self) -> AttachmentStore[TContext]:
265 """Return the configured AttachmentStore or raise if missing."""
266 if self.attachment_store is None:
267 raise RuntimeError(
268 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
269 )
270 return self.attachment_store
271
272 @abstractmethod
273 def respond(
274 self,
275 thread: ThreadMetadata,
276 input_user_message: UserMessageItem | None,
277 context: TContext,
278 ) -> AsyncIterator[ThreadStreamEvent]:
279 """Stream `ThreadStreamEvent` instances for a new user message.
280
281 Args:
282 thread: Metadata for the thread being processed.
283 input_user_message: The incoming message the server should respond to, if any.
284 context: Arbitrary per-request context provided by the caller.
285
286 Returns:
287 An async iterator that yields events representing the server's response.
288 """
289 pass
290
291 async def add_feedback( # noqa: B027
292 self,
293 thread_id: str,
294 item_ids: list[str],
295 feedback: FeedbackKind,
296 context: TContext,
297 ) -> None:
298 pass
299
300 def action(
301 self,
302 thread: ThreadMetadata,
303 action: Action[str, Any],
304 sender: WidgetItem | None,
305 context: TContext,
306 ) -> AsyncIterator[ThreadStreamEvent]:
307 raise NotImplementedError(
308 "The action() method must be overridden to react to actions. "
309 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
310 )
311
312 async def process(
313 self, request: str | bytes | bytearray, context: TContext
314 ) -> StreamingResult | NonStreamingResult:
315 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
316 logger.info(f"Received request op: {parsed_request.type}")
317
318 if is_streaming_req(parsed_request):
319 return StreamingResult(self._process_streaming(parsed_request, context))
320 else:
321 return NonStreamingResult(
322 await self._process_non_streaming(parsed_request, context)
323 )
324
325 async def _process_non_streaming(
326 self, request: NonStreamingReq, context: TContext
327 ) -> bytes:
328 match request:
329 case ThreadsGetByIdReq():
330 thread = await self._load_full_thread(
331 request.params.thread_id, context=context
332 )
333 return self._serialize(self._to_thread_response(thread))
334 case ThreadsListReq():
335 params = request.params
336 threads = await self.store.load_threads(
337 limit=params.limit or DEFAULT_PAGE_SIZE,
338 after=params.after,
339 order=params.order,
340 context=context,
341 )
342 return self._serialize(
343 Page(
344 has_more=threads.has_more,
345 after=threads.after,
346 data=[
347 self._to_thread_response(thread) for thread in threads.data
348 ],
349 )
350 )
351 case ItemsFeedbackReq():
352 await self.add_feedback(
353 request.params.thread_id,
354 request.params.item_ids,
355 request.params.kind,
356 context,
357 )
358 return b"{}"
359 case AttachmentsCreateReq():
360 attachment_store = self._get_attachment_store()
361 attachment = await attachment_store.create_attachment(
362 request.params, context
363 )
364 return self._serialize(attachment)
365 case AttachmentsDeleteReq():
366 attachment_store = self._get_attachment_store()
367 await attachment_store.delete_attachment(
368 request.params.attachment_id, context=context
369 )
370 await self.store.delete_attachment(
371 request.params.attachment_id, context=context
372 )
373 return b"{}"
374 case ItemsListReq():
375 items_list_params = request.params
376 items = await self.store.load_thread_items(
377 items_list_params.thread_id,
378 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
379 order=items_list_params.order,
380 after=items_list_params.after,
381 context=context,
382 )
383 # filter out HiddenContextItems
384 items.data = [
385 item
386 for item in items.data
387 if not isinstance(item, HiddenContextItem)
388 ]
389 return self._serialize(items)
390 case ThreadsUpdateReq():
391 thread = await self.store.load_thread(
392 request.params.thread_id, context=context
393 )
394 thread.title = request.params.title
395 await self.store.save_thread(thread, context=context)
396 return self._serialize(self._to_thread_response(thread))
397 case ThreadsDeleteReq():
398 await self.store.delete_thread(
399 request.params.thread_id, context=context
400 )
401 return b"{}"
402 case _:
403 assert_never(request)
404
405 async def _process_streaming(
406 self, request: StreamingReq, context: TContext
407 ) -> AsyncGenerator[bytes, None]:
408 try:
409 async for event in self._process_streaming_impl(request, context):
410 b = self._serialize(event)
411 yield b"data: " + b + b"\n\n"
412 except Exception:
413 logger.exception("Error while generating streamed response")
414 raise
415
416 async def _process_streaming_impl(
417 self, request: StreamingReq, context: TContext
418 ) -> AsyncGenerator[ThreadStreamEvent, None]:
419 match request:
420 case ThreadsCreateReq():
421 thread = Thread(
422 id=self.store.generate_thread_id(context),
423 created_at=datetime.now(),
424 items=Page(),
425 )
426 await self.store.save_thread(
427 ThreadMetadata(**thread.model_dump()),
428 context=context,
429 )
430 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
431 user_message = await self._build_user_message_item(
432 request.params.input, thread, context
433 )
434 async for event in self._process_new_thread_item_respond(
435 thread,
436 user_message,
437 context,
438 ):
439 yield event
440
441 case ThreadsAddUserMessageReq():
442 thread = await self.store.load_thread(
443 request.params.thread_id, context=context
444 )
445 user_message = await self._build_user_message_item(
446 request.params.input, thread, context
447 )
448 async for event in self._process_new_thread_item_respond(
449 thread,
450 user_message,
451 context,
452 ):
453 yield event
454
455 case ThreadsAddClientToolOutputReq():
456 thread = await self.store.load_thread(
457 request.params.thread_id, context=context
458 )
459 items = await self.store.load_thread_items(
460 thread.id, None, 1, "desc", context
461 )
462 tool_call = next(
463 (
464 item
465 for item in items.data
466 if isinstance(item, ClientToolCallItem)
467 and item.status == "pending"
468 ),
469 None,
470 )
471 if not tool_call:
472 raise ValueError(
473 f"Last thread item in {thread.id} was not a ClientToolCallItem"
474 )
475
476 tool_call.output = request.params.result
477 tool_call.status = "completed"
478
479 await self.store.save_item(thread.id, tool_call, context=context)
480
481 # Safety against dangling pending tool calls if there are
482 # multiple in a row, which should be impossible, and
483 # integrations should ultimately filter out pending tool calls
484 # when creating input response messages.
485 await self._cleanup_pending_client_tool_call(thread, context)
486
487 async for event in self._process_events(
488 thread,
489 context,
490 lambda: self.respond(thread, None, context),
491 ):
492 yield event
493
494 case ThreadsRetryAfterItemReq():
495 thread_metadata = await self.store.load_thread(
496 request.params.thread_id, context=context
497 )
498
499 # Collect items to remove (all items after the user message)
500 items_to_remove: list[ThreadItem] = []
501 user_message_item = None
502
503 async for item in self._paginate_thread_items_reverse(
504 request.params.thread_id, context
505 ):
506 if item.id == request.params.item_id:
507 if not isinstance(item, UserMessageItem):
508 raise ValueError(
509 f"Item {request.params.item_id} is not a user message"
510 )
511 user_message_item = item
512 break
513 items_to_remove.append(item)
514
515 if user_message_item:
516 for item in items_to_remove:
517 await self.store.delete_thread_item(
518 request.params.thread_id, item.id, context=context
519 )
520 async for event in self._process_events(
521 thread_metadata,
522 context,
523 lambda: self.respond(
524 thread_metadata,
525 user_message_item,
526 context,
527 ),
528 ):
529 yield event
530 case ThreadsCustomActionReq():
531 thread_metadata = await self.store.load_thread(
532 request.params.thread_id, context=context
533 )
534
535 item: ThreadItem | None = None
536 if request.params.item_id:
537 item = await self.store.load_item(
538 request.params.thread_id,
539 request.params.item_id,
540 context=context,
541 )
542
543 if item and not isinstance(item, WidgetItem):
544 # shouldn't happen if the caller is using the API correctly.
545 yield ErrorEvent(
546 code=ErrorCode.STREAM_ERROR,
547 allow_retry=False,
548 )
549 return
550
551 async for event in self._process_events(
552 thread_metadata,
553 context,
554 lambda: self.action(
555 thread_metadata,
556 request.params.action,
557 item,
558 context,
559 ),
560 ):
561 yield event
562
563 case _:
564 assert_never(request)
565
566 async def _cleanup_pending_client_tool_call(
567 self, thread: ThreadMetadata, context: TContext
568 ) -> None:
569 items = await self.store.load_thread_items(
570 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
571 )
572 for tool_call in items.data:
573 if not isinstance(tool_call, ClientToolCallItem):
574 continue
575 if tool_call.status == "pending":
576 logger.warning(
577 f"Client tool call {tool_call.call_id} was not completed, ignoring"
578 )
579 await self.store.delete_thread_item(
580 thread.id, tool_call.id, context=context
581 )
582
583 async def _process_new_thread_item_respond(
584 self,
585 thread: ThreadMetadata,
586 item: UserMessageItem,
587 context: TContext,
588 ) -> AsyncIterator[ThreadStreamEvent]:
589 await self.store.add_thread_item(thread.id, item, context=context)
590 await self._cleanup_pending_client_tool_call(thread, context)
591 yield ThreadItemDoneEvent(item=item)
592
593 async for event in self._process_events(
594 thread,
595 context,
596 lambda: self.respond(thread, item, context),
597 ):
598 yield event
599
600 async def _process_events(
601 self,
602 thread: ThreadMetadata,
603 context: TContext,
604 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
605 ) -> AsyncIterator[ThreadStreamEvent]:
606 await asyncio.sleep(0) # allow the response to start streaming
607
608 last_thread = thread.model_copy(deep=True)
609
610 try:
611 with agents_sdk_user_agent_override():
612 async for event in stream():
613 match event:
614 case ThreadItemDoneEvent():
615 await self.store.add_thread_item(
616 thread.id, event.item, context=context
617 )
618 case ThreadItemRemovedEvent():
619 await self.store.delete_thread_item(
620 thread.id, event.item_id, context=context
621 )
622 case ThreadItemReplacedEvent():
623 await self.store.save_item(
624 thread.id, event.item, context=context
625 )
626
627 # special case - don't send hidden context items back to the client
628 should_swallow_event = isinstance(
629 event, ThreadItemDoneEvent
630 ) and isinstance(event.item, HiddenContextItem)
631
632 if not should_swallow_event:
633 yield event
634
635 # in case user updated the thread while streaming
636 if thread != last_thread:
637 last_thread = thread.model_copy(deep=True)
638 await self.store.save_thread(thread, context=context)
639 yield ThreadUpdatedEvent(
640 thread=self._to_thread_response(thread)
641 )
642 # in case user updated the thread while streaming
643 if thread != last_thread:
644 last_thread = thread.model_copy(deep=True)
645 await self.store.save_thread(thread, context=context)
646 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
647 except CustomStreamError as e:
648 yield ErrorEvent(
649 code="custom",
650 message=e.message,
651 allow_retry=e.allow_retry,
652 )
653 except StreamError as e:
654 yield ErrorEvent(
655 code=e.code,
656 allow_retry=e.allow_retry,
657 )
658 except Exception as e:
659 yield ErrorEvent(
660 code=ErrorCode.STREAM_ERROR,
661 allow_retry=True,
662 )
663 logger.exception(e)
664
665 if thread != last_thread:
666 # in case user updated the thread at the end of the stream
667 await self.store.save_thread(thread, context=context)
668 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
669
670 async def _build_user_message_item(
671 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
672 ) -> UserMessageItem:
673 return UserMessageItem(
674 id=self.store.generate_item_id("message", thread, context),
675 content=input.content,
676 thread_id=thread.id,
677 attachments=[
678 await self.store.load_attachment(attachment_id, context)
679 for attachment_id in input.attachments
680 ],
681 quoted_text=input.quoted_text,
682 inference_options=input.inference_options,
683 created_at=datetime.now(),
684 )
685
686 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
687 thread_meta = await self.store.load_thread(thread_id, context=context)
688 thread_items = await self.store.load_thread_items(
689 thread_id,
690 after=None,
691 limit=DEFAULT_PAGE_SIZE,
692 order="asc",
693 context=context,
694 )
695 return Thread(**thread_meta.model_dump(), items=thread_items)
696
697 async def _paginate_thread_items_reverse(
698 self, thread_id: str, context: TContext
699 ) -> AsyncIterator[ThreadItem]:
700 """Paginate through thread items in reverse order (newest first)."""
701 after = None
702 while True:
703 items = await self.store.load_thread_items(
704 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
705 )
706 for item in items.data:
707 yield item
708
709 if not items.has_more:
710 break
711 after = items.after
712
713 def _serialize(self, obj: BaseModel) -> bytes:
714 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
715
716 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
717 def is_hidden(item: ThreadItem) -> bool:
718 return isinstance(item, HiddenContextItem)
719
720 items = thread.items if isinstance(thread, Thread) else Page()
721 items.data = [item for item in items.data if not is_hidden(item)]
722
723 return Thread(
724 id=thread.id,
725 title=thread.title,
726 created_at=thread.created_at,
727 items=items,
728 status=thread.status,
729 )
730