openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
1b15f7de44765f3dd01a304292b08ffe12c2574b

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

736 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12)
13
14import agents
15from agents.models.chatcmpl_helpers import (
16 HEADERS_OVERRIDE as chat_completions_headers_override,
17)
18from agents.models.openai_responses import (
19 _HEADERS_OVERRIDE as responses_headers_override,
20)
21from pydantic import BaseModel, TypeAdapter
22from typing_extensions import TypeVar, assert_never
23
24from chatkit.errors import CustomStreamError, StreamError
25
26from .logger import logger
27from .store import AttachmentStore, Store, StoreItemType, default_generate_id
28from .types import (
29 Action,
30 AttachmentsCreateReq,
31 AttachmentsDeleteReq,
32 ChatKitReq,
33 ClientToolCallItem,
34 ErrorCode,
35 ErrorEvent,
36 FeedbackKind,
37 HiddenContextItem,
38 ItemsFeedbackReq,
39 ItemsListReq,
40 NonStreamingReq,
41 Page,
42 StreamingReq,
43 Thread,
44 ThreadCreatedEvent,
45 ThreadItem,
46 ThreadItemAddedEvent,
47 ThreadItemDoneEvent,
48 ThreadItemRemovedEvent,
49 ThreadItemReplacedEvent,
50 ThreadItemUpdatedEvent,
51 ThreadMetadata,
52 ThreadsAddClientToolOutputReq,
53 ThreadsAddUserMessageReq,
54 ThreadsCreateReq,
55 ThreadsCustomActionReq,
56 ThreadsDeleteReq,
57 ThreadsGetByIdReq,
58 ThreadsListReq,
59 ThreadsRetryAfterItemReq,
60 ThreadStreamEvent,
61 ThreadsUpdateReq,
62 ThreadUpdatedEvent,
63 UserMessageInput,
64 UserMessageItem,
65 WidgetComponentUpdated,
66 WidgetItem,
67 WidgetRootUpdated,
68 WidgetStreamingTextValueDelta,
69 is_streaming_req,
70)
71from .version import __version__
72from .widgets import WidgetComponent, WidgetComponentBase, WidgetRoot
73
# Page size used when a list request omits `limit`, and for the server's own
# item listings (full-thread loads, pending tool-call cleanup, reverse paging).
DEFAULT_PAGE_SIZE = 20
# Generic fallback error text.
# NOTE(review): not referenced in this module — confirm whether it is used elsewhere.
DEFAULT_ERROR_MESSAGE = "An error occurred when generating a response."
76
77
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    When the two trees differ in any way other than text appended to the
    ``value`` of an identified (``id`` set) ``Markdown``/``Text`` component, a
    single ``WidgetRootUpdated`` carrying ``after`` is returned. Otherwise one
    ``WidgetStreamingTextValueDelta`` is emitted for each identified streaming
    text component whose value grew; the list is empty when nothing changed.

    Args:
        before: The previously emitted widget tree.
        after: The new widget tree.

    Returns:
        Either ``[WidgetRootUpdated(widget=after)]`` or a (possibly empty) list
        of ``WidgetStreamingTextValueDelta``.

    Raises:
        ValueError: If an identified streaming text component exists in
            ``after`` but not in ``before``, or if its new value does not start
            with its previous value.
    """

    def is_streaming_text(component: WidgetComponentBase) -> bool:
        # Markdown/Text components with a string `value` can be streamed as
        # appended text deltas.
        return getattr(component, "type", None) in {"Markdown", "Text"} and isinstance(
            getattr(component, "value", None), str
        )

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True when `after` cannot be expressed as text deltas on `before`."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        def full_replace_value(before_value: Any, after_value: Any) -> bool:
            # Lists (e.g. children) are compared element-wise so that nested
            # components can still qualify for streaming deltas.
            if isinstance(before_value, list) and isinstance(after_value, list):
                if len(before_value) != len(after_value):
                    return True
                for nth_before_value, nth_after_value in zip(before_value, after_value):
                    if full_replace_value(nth_before_value, nth_after_value):
                        return True
            elif before_value != after_value:
                if isinstance(before_value, WidgetComponentBase) and isinstance(
                    after_value, WidgetComponentBase
                ):
                    return full_replace(before_value, after_value)
                return True
            return False

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                field == "value"
                and is_streaming_text(before)
                and is_streaming_text(after)
                # Deltas are only produced for components that have an id (see
                # find_all_streaming_text_components below). Without an id an
                # append must force a full replace, or the update is lost.
                and after.id
                and getattr(after, "value", "").startswith(getattr(before, "value", ""))
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, WidgetComponentBase]:
        """Map id -> component for every identified streaming text node in the tree."""
        components: dict[str, WidgetComponentBase] = {}

        def recurse(node: WidgetComponent | WidgetRoot) -> None:
            if is_streaming_text(node) and node.id:
                components[node.id] = node
            for child in getattr(node, "children", None) or []:
                recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []
    for node_id, after_node in after_nodes.items():
        before_node = before_nodes.get(node_id)
        if before_node is None:
            raise ValueError(
                f"Node {node_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        before_value = str(getattr(before_node, "value", None))
        after_value = str(getattr(after_node, "value", None))
        if before_value == after_value:
            continue
        if not after_value.startswith(before_value):
            raise ValueError(
                f"Node {node_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
            )
        deltas.append(
            WidgetStreamingTextValueDelta(
                component_id=node_id,
                delta=after_value[len(before_value) :],
                # A component no longer flagged as streaming is complete.
                done=not getattr(after_node, "streaming", False),
            )
        )

    return deltas
180
181
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Stream a widget into ``thread`` as thread-item events.

    A static ``WidgetRoot`` is emitted as a single ``ThreadItemDoneEvent``.
    For an async generator, the first yielded state is emitted as a
    ``ThreadItemAddedEvent``. Each later state produces one
    ``ThreadItemUpdatedEvent`` per delta computed by ``diff_widget``. The last
    state is emitted in a closing ``ThreadItemDoneEvent``.

    Args:
        thread: The thread the widget item belongs to.
        widget: A finished widget, or an async generator of successive states.
        copy_text: Optional text stored on the item as ``copy_text``.
        generate_id: Produces the item id (called with ``"message"``).

    Raises:
        ValueError: If ``widget`` is an async generator that yields nothing, or
            if ``diff_widget`` rejects an update.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator surfaces as an
        # opaque RuntimeError; report the actual problem instead.
        raise ValueError(
            "Widget generator must yield at least one WidgetRoot"
        ) from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state
    # Continues from the state after `initial_state`.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdatedEvent(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
231
232
@contextmanager
def agents_sdk_user_agent_override():
    """Override the User-Agent header the Agents SDK sends within the block.

    Sets ``{"User-Agent": "Agents/Python <agents version> ChatKit/Python
    <chatkit version>"}`` on both the Chat Completions and the Responses
    header-override variables imported from ``agents``. Their previous values
    are restored on exit, including when the body raises or the enclosing
    generator is closed early.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})
    try:
        yield
    finally:
        # Restore in reverse order of setting. Assuming these are
        # contextvars.ContextVar objects, a token can only be reset in the
        # Context that created it. If this exit runs elsewhere (presumably an
        # async generator finalized from another task), skip the reset rather
        # than replace the in-flight exception with a ValueError.
        for header_override, token in (
            (responses_headers_override, responses_token),
            (chat_completions_headers_override, chat_completions_token),
        ):
            try:
                header_override.reset(token)
            except ValueError:
                pass
243
244
class StreamingResult(AsyncIterable[bytes]):
    """Result of a streaming request: an async iterable of byte chunks.

    Wraps the async generator produced for a streaming request (for
    ``ChatKitServer.process`` these are ``data: ...`` event frames). The
    generator is exposed as ``json_events``, and iteration re-yields its chunks
    one by one.
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
252
253
class NonStreamingResult:
    """Result of a non-streaming request.

    Carries the complete serialized JSON response body in ``json``.
    """

    json: bytes

    def __init__(self, result: bytes):
        payload = result
        self.json = payload
257
258
# Type of the caller-supplied per-request context threaded through
# ChatKitServer, Store and AttachmentStore; defaults to Any when a class is
# used unparameterized.
TContext = TypeVar("TContext", default=Any)
260
261
262class ChatKitServer(ABC, Generic[TContext]):
263 def __init__(
264 self,
265 store: Store[TContext],
266 attachment_store: AttachmentStore[TContext] | None = None,
267 ):
268 self.store = store
269 self.attachment_store = attachment_store
270
271 def _get_attachment_store(self) -> AttachmentStore[TContext]:
272 """Return the configured AttachmentStore or raise if missing."""
273 if self.attachment_store is None:
274 raise RuntimeError(
275 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
276 )
277 return self.attachment_store
278
279 @abstractmethod
280 def respond(
281 self,
282 thread: ThreadMetadata,
283 input_user_message: UserMessageItem | None,
284 context: TContext,
285 ) -> AsyncIterator[ThreadStreamEvent]:
286 """Stream `ThreadStreamEvent` instances for a new user message.
287
288 Args:
289 thread: Metadata for the thread being processed.
290 input_user_message: The incoming message the server should respond to, if any.
291 context: Arbitrary per-request context provided by the caller.
292
293 Returns:
294 An async iterator that yields events representing the server's response.
295 """
296 pass
297
298 async def add_feedback( # noqa: B027
299 self,
300 thread_id: str,
301 item_ids: list[str],
302 feedback: FeedbackKind,
303 context: TContext,
304 ) -> None:
305 pass
306
307 def action(
308 self,
309 thread: ThreadMetadata,
310 action: Action[str, Any],
311 sender: WidgetItem | None,
312 context: TContext,
313 ) -> AsyncIterator[ThreadStreamEvent]:
314 raise NotImplementedError(
315 "The action() method must be overridden to react to actions. "
316 "See https://github.com/openai/chatkit-python/blob/main/docs/widgets.md#widget-actions"
317 )
318
319 async def process(
320 self, request: str | bytes | bytearray, context: TContext
321 ) -> StreamingResult | NonStreamingResult:
322 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
323 logger.info(f"Received request op: {parsed_request.type}")
324
325 if is_streaming_req(parsed_request):
326 return StreamingResult(self._process_streaming(parsed_request, context))
327 else:
328 return NonStreamingResult(
329 await self._process_non_streaming(parsed_request, context)
330 )
331
332 async def _process_non_streaming(
333 self, request: NonStreamingReq, context: TContext
334 ) -> bytes:
335 match request:
336 case ThreadsGetByIdReq():
337 thread = await self._load_full_thread(
338 request.params.thread_id, context=context
339 )
340 return self._serialize(self._to_thread_response(thread))
341 case ThreadsListReq():
342 params = request.params
343 threads = await self.store.load_threads(
344 limit=params.limit or DEFAULT_PAGE_SIZE,
345 after=params.after,
346 order=params.order,
347 context=context,
348 )
349 return self._serialize(
350 Page(
351 has_more=threads.has_more,
352 after=threads.after,
353 data=[
354 self._to_thread_response(thread) for thread in threads.data
355 ],
356 )
357 )
358 case ItemsFeedbackReq():
359 await self.add_feedback(
360 request.params.thread_id,
361 request.params.item_ids,
362 request.params.kind,
363 context,
364 )
365 return b"{}"
366 case AttachmentsCreateReq():
367 attachment_store = self._get_attachment_store()
368 attachment = await attachment_store.create_attachment(
369 request.params, context
370 )
371 return self._serialize(attachment)
372 case AttachmentsDeleteReq():
373 attachment_store = self._get_attachment_store()
374 await attachment_store.delete_attachment(
375 request.params.attachment_id, context=context
376 )
377 await self.store.delete_attachment(
378 request.params.attachment_id, context=context
379 )
380 return b"{}"
381 case ItemsListReq():
382 items_list_params = request.params
383 items = await self.store.load_thread_items(
384 items_list_params.thread_id,
385 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
386 order=items_list_params.order,
387 after=items_list_params.after,
388 context=context,
389 )
390 # filter out HiddenContextItems
391 items.data = [
392 item
393 for item in items.data
394 if not isinstance(item, HiddenContextItem)
395 ]
396 return self._serialize(items)
397 case ThreadsUpdateReq():
398 thread = await self.store.load_thread(
399 request.params.thread_id, context=context
400 )
401 thread.title = request.params.title
402 await self.store.save_thread(thread, context=context)
403 return self._serialize(self._to_thread_response(thread))
404 case ThreadsDeleteReq():
405 await self.store.delete_thread(
406 request.params.thread_id, context=context
407 )
408 return b"{}"
409 case _:
410 assert_never(request)
411
412 async def _process_streaming(
413 self, request: StreamingReq, context: TContext
414 ) -> AsyncGenerator[bytes, None]:
415 try:
416 async for event in self._process_streaming_impl(request, context):
417 b = self._serialize(event)
418 yield b"data: " + b + b"\n\n"
419 except Exception:
420 logger.exception("Error while generating streamed response")
421 raise
422
423 async def _process_streaming_impl(
424 self, request: StreamingReq, context: TContext
425 ) -> AsyncGenerator[ThreadStreamEvent, None]:
426 match request:
427 case ThreadsCreateReq():
428 thread = Thread(
429 id=self.store.generate_thread_id(context),
430 created_at=datetime.now(),
431 items=Page(),
432 )
433 await self.store.save_thread(
434 ThreadMetadata(**thread.model_dump()),
435 context=context,
436 )
437 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
438 user_message = await self._build_user_message_item(
439 request.params.input, thread, context
440 )
441 async for event in self._process_new_thread_item_respond(
442 thread,
443 user_message,
444 context,
445 ):
446 yield event
447
448 case ThreadsAddUserMessageReq():
449 thread = await self.store.load_thread(
450 request.params.thread_id, context=context
451 )
452 user_message = await self._build_user_message_item(
453 request.params.input, thread, context
454 )
455 async for event in self._process_new_thread_item_respond(
456 thread,
457 user_message,
458 context,
459 ):
460 yield event
461
462 case ThreadsAddClientToolOutputReq():
463 thread = await self.store.load_thread(
464 request.params.thread_id, context=context
465 )
466 items = await self.store.load_thread_items(
467 thread.id, None, 1, "desc", context
468 )
469 tool_call = next(
470 (
471 item
472 for item in items.data
473 if isinstance(item, ClientToolCallItem)
474 and item.status == "pending"
475 ),
476 None,
477 )
478 if not tool_call:
479 raise ValueError(
480 f"Last thread item in {thread.id} was not a ClientToolCallItem"
481 )
482
483 tool_call.output = request.params.result
484 tool_call.status = "completed"
485
486 await self.store.save_item(thread.id, tool_call, context=context)
487
488 # Safety against dangling pending tool calls if there are
489 # multiple in a row, which should be impossible, and
490 # integrations should ultimately filter out pending tool calls
491 # when creating input response messages.
492 await self._cleanup_pending_client_tool_call(thread, context)
493
494 async for event in self._process_events(
495 thread,
496 context,
497 lambda: self.respond(thread, None, context),
498 ):
499 yield event
500
501 case ThreadsRetryAfterItemReq():
502 thread_metadata = await self.store.load_thread(
503 request.params.thread_id, context=context
504 )
505
506 # Collect items to remove (all items after the user message)
507 items_to_remove: list[ThreadItem] = []
508 user_message_item = None
509
510 async for item in self._paginate_thread_items_reverse(
511 request.params.thread_id, context
512 ):
513 if item.id == request.params.item_id:
514 if not isinstance(item, UserMessageItem):
515 raise ValueError(
516 f"Item {request.params.item_id} is not a user message"
517 )
518 user_message_item = item
519 break
520 items_to_remove.append(item)
521
522 if user_message_item:
523 for item in items_to_remove:
524 await self.store.delete_thread_item(
525 request.params.thread_id, item.id, context=context
526 )
527 async for event in self._process_events(
528 thread_metadata,
529 context,
530 lambda: self.respond(
531 thread_metadata,
532 user_message_item,
533 context,
534 ),
535 ):
536 yield event
537 case ThreadsCustomActionReq():
538 thread_metadata = await self.store.load_thread(
539 request.params.thread_id, context=context
540 )
541
542 item: ThreadItem | None = None
543 if request.params.item_id:
544 item = await self.store.load_item(
545 request.params.thread_id,
546 request.params.item_id,
547 context=context,
548 )
549
550 if item and not isinstance(item, WidgetItem):
551 # shouldn't happen if the caller is using the API correctly.
552 yield ErrorEvent(
553 code=ErrorCode.STREAM_ERROR,
554 allow_retry=False,
555 )
556 return
557
558 async for event in self._process_events(
559 thread_metadata,
560 context,
561 lambda: self.action(
562 thread_metadata,
563 request.params.action,
564 item,
565 context,
566 ),
567 ):
568 yield event
569
570 case _:
571 assert_never(request)
572
573 async def _cleanup_pending_client_tool_call(
574 self, thread: ThreadMetadata, context: TContext
575 ) -> None:
576 items = await self.store.load_thread_items(
577 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
578 )
579 for tool_call in items.data:
580 if not isinstance(tool_call, ClientToolCallItem):
581 continue
582 if tool_call.status == "pending":
583 logger.warning(
584 f"Client tool call {tool_call.call_id} was not completed, ignoring"
585 )
586 await self.store.delete_thread_item(
587 thread.id, tool_call.id, context=context
588 )
589
590 async def _process_new_thread_item_respond(
591 self,
592 thread: ThreadMetadata,
593 item: UserMessageItem,
594 context: TContext,
595 ) -> AsyncIterator[ThreadStreamEvent]:
596 await self.store.add_thread_item(thread.id, item, context=context)
597 await self._cleanup_pending_client_tool_call(thread, context)
598 yield ThreadItemDoneEvent(item=item)
599
600 async for event in self._process_events(
601 thread,
602 context,
603 lambda: self.respond(thread, item, context),
604 ):
605 yield event
606
607 async def _process_events(
608 self,
609 thread: ThreadMetadata,
610 context: TContext,
611 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
612 ) -> AsyncIterator[ThreadStreamEvent]:
613 await asyncio.sleep(0) # allow the response to start streaming
614
615 last_thread = thread.model_copy(deep=True)
616
617 try:
618 with agents_sdk_user_agent_override():
619 async for event in stream():
620 match event:
621 case ThreadItemDoneEvent():
622 await self.store.add_thread_item(
623 thread.id, event.item, context=context
624 )
625 case ThreadItemRemovedEvent():
626 await self.store.delete_thread_item(
627 thread.id, event.item_id, context=context
628 )
629 case ThreadItemReplacedEvent():
630 await self.store.save_item(
631 thread.id, event.item, context=context
632 )
633
634 # special case - don't send hidden context items back to the client
635 should_swallow_event = isinstance(
636 event, ThreadItemDoneEvent
637 ) and isinstance(event.item, HiddenContextItem)
638
639 if not should_swallow_event:
640 yield event
641
642 # in case user updated the thread while streaming
643 if thread != last_thread:
644 last_thread = thread.model_copy(deep=True)
645 await self.store.save_thread(thread, context=context)
646 yield ThreadUpdatedEvent(
647 thread=self._to_thread_response(thread)
648 )
649 # in case user updated the thread while streaming
650 if thread != last_thread:
651 last_thread = thread.model_copy(deep=True)
652 await self.store.save_thread(thread, context=context)
653 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
654 except CustomStreamError as e:
655 yield ErrorEvent(
656 code="custom",
657 message=e.message,
658 allow_retry=e.allow_retry,
659 )
660 except StreamError as e:
661 yield ErrorEvent(
662 code=e.code,
663 allow_retry=e.allow_retry,
664 )
665 except Exception as e:
666 yield ErrorEvent(
667 code=ErrorCode.STREAM_ERROR,
668 allow_retry=True,
669 )
670 logger.exception(e)
671
672 if thread != last_thread:
673 # in case user updated the thread at the end of the stream
674 await self.store.save_thread(thread, context=context)
675 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
676
677 async def _build_user_message_item(
678 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
679 ) -> UserMessageItem:
680 return UserMessageItem(
681 id=self.store.generate_item_id("message", thread, context),
682 content=input.content,
683 thread_id=thread.id,
684 attachments=[
685 await self.store.load_attachment(attachment_id, context)
686 for attachment_id in input.attachments
687 ],
688 quoted_text=input.quoted_text,
689 inference_options=input.inference_options,
690 created_at=datetime.now(),
691 )
692
693 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
694 thread_meta = await self.store.load_thread(thread_id, context=context)
695 thread_items = await self.store.load_thread_items(
696 thread_id,
697 after=None,
698 limit=DEFAULT_PAGE_SIZE,
699 order="asc",
700 context=context,
701 )
702 return Thread(**thread_meta.model_dump(), items=thread_items)
703
704 async def _paginate_thread_items_reverse(
705 self, thread_id: str, context: TContext
706 ) -> AsyncIterator[ThreadItem]:
707 """Paginate through thread items in reverse order (newest first)."""
708 after = None
709 while True:
710 items = await self.store.load_thread_items(
711 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
712 )
713 for item in items.data:
714 yield item
715
716 if not items.has_more:
717 break
718 after = items.after
719
720 def _serialize(self, obj: BaseModel) -> bytes:
721 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
722
723 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
724 def is_hidden(item: ThreadItem) -> bool:
725 return isinstance(item, HiddenContextItem)
726
727 items = thread.items if isinstance(thread, Thread) else Page()
728 items.data = [item for item in items.data if not is_hidden(item)]
729
730 return Thread(
731 id=thread.id,
732 title=thread.title,
733 created_at=thread.created_at,
734 items=items,
735 status=thread.status,
736 )
737