openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
c47e77019578870e99d1b4873e9cddf7dc098f55

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/server.py

726 lines · mode: code

1import asyncio
2from abc import ABC, abstractmethod
3from collections.abc import AsyncIterator
4from contextlib import contextmanager
5from datetime import datetime
6from typing import (
7 Any,
8 AsyncGenerator,
9 AsyncIterable,
10 Callable,
11 Generic,
12 assert_never,
13)
14
15import agents
16from agents.models.chatcmpl_helpers import (
17 HEADERS_OVERRIDE as chat_completions_headers_override,
18)
19from agents.models.openai_responses import (
20 _HEADERS_OVERRIDE as responses_headers_override,
21)
22from pydantic import BaseModel, TypeAdapter
23from typing_extensions import TypeVar
24
25from chatkit.errors import CustomStreamError, StreamError
26
27from .logger import logger
28from .store import AttachmentStore, Store, StoreItemType, default_generate_id
29from .types import (
30 Action,
31 AttachmentsCreateReq,
32 AttachmentsDeleteReq,
33 ChatKitReq,
34 ClientToolCallItem,
35 ErrorCode,
36 ErrorEvent,
37 FeedbackKind,
38 HiddenContextItem,
39 ItemsFeedbackReq,
40 ItemsListReq,
41 NonStreamingReq,
42 Page,
43 StreamingReq,
44 Thread,
45 ThreadCreatedEvent,
46 ThreadItem,
47 ThreadItemAddedEvent,
48 ThreadItemDoneEvent,
49 ThreadItemRemovedEvent,
50 ThreadItemReplacedEvent,
51 ThreadItemUpdated,
52 ThreadMetadata,
53 ThreadsAddClientToolOutputReq,
54 ThreadsAddUserMessageReq,
55 ThreadsCreateReq,
56 ThreadsCustomActionReq,
57 ThreadsDeleteReq,
58 ThreadsGetByIdReq,
59 ThreadsListReq,
60 ThreadsRetryAfterItemReq,
61 ThreadStreamEvent,
62 ThreadsUpdateReq,
63 ThreadUpdatedEvent,
64 UserMessageInput,
65 UserMessageItem,
66 WidgetComponentUpdated,
67 WidgetItem,
68 WidgetRootUpdated,
69 WidgetStreamingTextValueDelta,
70 is_streaming_req,
71)
72from .version import __version__
73from .widgets import Markdown, Text, WidgetComponent, WidgetComponentBase, WidgetRoot
74
# Page size used when a list request omits `limit`, and for the server's own
# item queries (pending tool-call cleanup, full-thread loads, reverse paging).
DEFAULT_PAGE_SIZE = 20
# Generic user-facing error text. NOTE(review): not referenced in this module;
# presumably consumed by importers — verify before removing.
DEFAULT_ERROR_MESSAGE = "An error occurred when generating a response."
77
78
def diff_widget(
    before: WidgetRoot, after: WidgetRoot
) -> list[WidgetStreamingTextValueDelta | WidgetRootUpdated | WidgetComponentUpdated]:
    """
    Compare two WidgetRoots and return a list of deltas.

    When any component changed in a way that cannot be expressed as a text
    append on an id-bearing Markdown/Text node (different type/id/key, another
    field changed, a children list changed length, or a text value changed by
    something other than appending), the result is a single
    ``WidgetRootUpdated`` carrying ``after``. Otherwise the result contains one
    ``WidgetStreamingTextValueDelta`` per id-bearing Markdown/Text node whose
    value grew; its ``done`` flag is set when the node is no longer streaming.

    Raises:
        ValueError: If an id-bearing Markdown/Text node in ``after`` has no
            counterpart in ``before``, or its value does not extend the
            previous value.
    """

    def full_replace_value(before_value: Any, after_value: Any) -> bool:
        """Return True if the change between two field values needs a full replace."""
        if isinstance(before_value, list) and isinstance(after_value, list):
            if len(before_value) != len(after_value):
                return True
            return any(
                full_replace_value(nth_before, nth_after)
                for nth_before, nth_after in zip(before_value, after_value)
            )
        if before_value != after_value:
            if isinstance(before_value, WidgetComponentBase) and isinstance(
                after_value, WidgetComponentBase
            ):
                return full_replace(before_value, after_value)
            return True
        return False

    def full_replace(before: WidgetComponentBase, after: WidgetComponentBase) -> bool:
        """Return True if ``before`` -> ``after`` cannot be sent as text deltas."""
        if (
            before.type != after.type
            or before.id != after.id
            or before.key != after.key
        ):
            return True

        for field in before.model_fields_set.union(after.model_fields_set):
            if (
                field == "value"
                and isinstance(before, (Markdown, Text))
                and isinstance(after, (Markdown, Text))
                # Only nodes with an id can be addressed by a streaming delta
                # (see find_all_streaming_text_components); an append to an
                # id-less node must fall through to a full replace, otherwise
                # the change would never reach the client.
                and before.id
                and after.value.startswith(before.value)
            ):
                # Appends to the value prop of Markdown or Text do not trigger a full replace
                continue
            if full_replace_value(getattr(before, field), getattr(after, field)):
                return True

        return False

    if full_replace(before, after):
        return [WidgetRootUpdated(widget=after)]

    deltas: list[
        WidgetStreamingTextValueDelta | WidgetComponentUpdated | WidgetRootUpdated
    ] = []

    def find_all_streaming_text_components(
        component: WidgetComponent | WidgetRoot,
    ) -> dict[str, Markdown | Text]:
        """Collect every Markdown/Text node that has an id, keyed by id."""
        components: dict[str, Markdown | Text] = {}

        def recurse(node: WidgetComponent | WidgetRoot) -> None:
            if isinstance(node, (Markdown, Text)) and node.id:
                components[node.id] = node
            for child in getattr(node, "children", None) or []:
                recurse(child)

        recurse(component)
        return components

    before_nodes = find_all_streaming_text_components(before)
    after_nodes = find_all_streaming_text_components(after)

    for node_id, after_node in after_nodes.items():
        before_node = before_nodes.get(node_id)
        if before_node is None:
            raise ValueError(
                f"Node {node_id} was not present when the widget was initially rendered. All nodes with ID must persist across all widget updates."
            )

        if before_node.value != after_node.value:
            if not after_node.value.startswith(before_node.value):
                raise ValueError(
                    f"Node {node_id} was updated with a new value that does not start with its previous value. All widget updates must be cumulative."
                )
            deltas.append(
                WidgetStreamingTextValueDelta(
                    component_id=node_id,
                    delta=after_node.value[len(before_node.value) :],
                    done=not after_node.streaming,
                )
            )

    return deltas
173
174
async def stream_widget(
    thread: ThreadMetadata,
    widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
    copy_text: str | None = None,
    generate_id: Callable[[StoreItemType], str] = default_generate_id,
) -> AsyncIterator[ThreadStreamEvent]:
    """Emit the thread events that render ``widget`` as a new thread item.

    A plain ``WidgetRoot`` produces a single ``ThreadItemDoneEvent``. An async
    generator of widget states produces a ``ThreadItemAddedEvent`` for the
    first state, a ``ThreadItemUpdated`` for every delta ``diff_widget``
    reports between consecutive states, and a final ``ThreadItemDoneEvent``
    holding the last state.

    Args:
        thread: Thread the widget item belongs to.
        widget: A finished widget, or an async generator yielding successive
            states of the same widget.
        copy_text: Optional text stored on the item as ``copy_text``.
        generate_id: Called with ``"message"`` to create the item id.

    Raises:
        ValueError: If the generator yields no state, or if ``diff_widget``
            rejects an update.
    """
    item_id = generate_id("message")

    if not isinstance(widget, AsyncGenerator):
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id=item_id,
                thread_id=thread.id,
                created_at=datetime.now(),
                widget=widget,
                copy_text=copy_text,
            ),
        )
        return

    try:
        initial_state = await widget.__anext__()
    except StopAsyncIteration:
        # Letting StopAsyncIteration escape an async generator turns it into an
        # opaque "async generator raised StopAsyncIteration" RuntimeError.
        raise ValueError(
            "The widget generator must yield at least one widget state."
        ) from None

    item = WidgetItem(
        id=item_id,
        created_at=datetime.now(),
        widget=initial_state,
        copy_text=copy_text,
        thread_id=thread.id,
    )

    yield ThreadItemAddedEvent(item=item)

    last_state = initial_state
    # An async generator is its own iterator, so this resumes right after the
    # initial state consumed above.
    async for new_state in widget:
        for update in diff_widget(last_state, new_state):
            yield ThreadItemUpdated(
                item_id=item_id,
                update=update,
            )
        last_state = new_state

    yield ThreadItemDoneEvent(
        item=item.model_copy(update={"widget": last_state}),
    )
224
225
@contextmanager
def agents_sdk_user_agent_override():
    """Send a ChatKit-identifying User-Agent on Agents SDK model requests.

    Sets both of the Agents SDK's header-override context variables (chat
    completions and responses) to ``{"User-Agent": ...}`` for the duration of
    the ``with`` block. It restores their previous values through the tokens
    returned by ``set``.
    """
    ua = f"Agents/Python {agents.__version__} ChatKit/Python {__version__}"
    chat_completions_token = chat_completions_headers_override.set({"User-Agent": ua})
    responses_token = responses_headers_override.set({"User-Agent": ua})

    try:
        yield
    finally:
        # Restore even when the body raises or the enclosing (async) generator
        # is closed early; reset in the reverse order of the set() calls.
        responses_headers_override.reset(responses_token)
        chat_completions_headers_override.reset(chat_completions_token)
236
237
class StreamingResult(AsyncIterable[bytes]):
    """Async-iterable wrapper around a stream of encoded response chunks.

    Iterating it re-yields, unchanged, each chunk produced by the wrapped
    ``stream`` (kept on the ``json_events`` attribute).
    """

    def __init__(self, stream: AsyncGenerator[bytes, None]):
        # The wrapped source of byte chunks.
        self.json_events = stream

    async def __aiter__(self):
        source = self.json_events
        async for chunk in source:
            yield chunk
245
246
class NonStreamingResult:
    """Holds the complete response body of a non-streaming request."""

    def __init__(self, result: bytes):
        # Encoded JSON response body.
        self.json: bytes = result
250
251
# Type of the per-request context object that ChatKitServer passes through to
# its store calls and overridable hooks. The `default=Any` (typing_extensions /
# PEP 696) is what a type checker uses when a generic is left unparameterized.
TContext = TypeVar("TContext", default=Any)
253
254
255class ChatKitServer(ABC, Generic[TContext]):
256 def __init__(
257 self,
258 store: Store[TContext],
259 attachment_store: AttachmentStore[TContext] | None = None,
260 ):
261 self.store = store
262 self.attachment_store = attachment_store
263
264 def _get_attachment_store(self) -> AttachmentStore[TContext]:
265 """Return the configured AttachmentStore or raise if missing."""
266 if self.attachment_store is None:
267 raise RuntimeError(
268 "AttachmentStore is not configured. Provide a AttachmentStore to ChatKitServer to handle file operations."
269 )
270 return self.attachment_store
271
272 @abstractmethod
273 def respond(
274 self,
275 thread: ThreadMetadata,
276 input_user_message: UserMessageItem | None,
277 context: TContext,
278 ) -> AsyncIterator[ThreadStreamEvent]:
279 """Stream `ThreadStreamEvent` instances for a new user message.
280
281 Args:
282 thread: Metadata for the thread being processed.
283 input_user_message: The incoming message the server should respond to, if any.
284 context: Arbitrary per-request context provided by the caller.
285
286 Returns:
287 An async iterator that yields events representing the server's response.
288 """
289 pass
290
291 async def add_feedback( # noqa: B027
292 self,
293 thread_id: str,
294 item_ids: list[str],
295 feedback: FeedbackKind,
296 context: TContext,
297 ) -> None:
298 pass
299
300 def action(
301 self,
302 thread: ThreadMetadata,
303 action: Action[str, Any],
304 sender: WidgetItem | None,
305 context: TContext,
306 ) -> AsyncIterator[ThreadStreamEvent]:
307 raise NotImplementedError(
308 "The action() method must be overridden to react to actions. "
309 "See https://github.com/OpenAI-Early-Access/chatkit/blob/main/docs/widgets.md#widget-actions"
310 )
311
312 async def process(
313 self, request: str | bytes | bytearray, context: TContext
314 ) -> StreamingResult | NonStreamingResult:
315 parsed_request = TypeAdapter[ChatKitReq](ChatKitReq).validate_json(request)
316 logger.info(f"Received request op: {parsed_request.type}")
317
318 if is_streaming_req(parsed_request):
319 return StreamingResult(self._process_streaming(parsed_request, context))
320 else:
321 return NonStreamingResult(
322 await self._process_non_streaming(parsed_request, context)
323 )
324
325 async def _process_non_streaming(
326 self, request: NonStreamingReq, context: TContext
327 ) -> bytes:
328 match request:
329 case ThreadsGetByIdReq():
330 thread = await self._load_full_thread(
331 request.params.thread_id, context=context
332 )
333 return self._serialize(self._to_thread_response(thread))
334 case ThreadsListReq():
335 params = request.params
336 threads = await self.store.load_threads(
337 limit=params.limit or DEFAULT_PAGE_SIZE,
338 after=params.after,
339 order=params.order,
340 context=context,
341 )
342 return self._serialize(
343 Page(
344 has_more=threads.has_more,
345 after=threads.after,
346 data=[
347 self._to_thread_response(thread) for thread in threads.data
348 ],
349 )
350 )
351 case ItemsFeedbackReq():
352 await self.add_feedback(
353 request.params.thread_id,
354 request.params.item_ids,
355 request.params.kind,
356 context,
357 )
358 return b"{}"
359 case AttachmentsCreateReq():
360 attachment_store = self._get_attachment_store()
361 attachment = await attachment_store.create_attachment(
362 request.params, context
363 )
364 return self._serialize(attachment)
365 case AttachmentsDeleteReq():
366 attachment_store = self._get_attachment_store()
367 await attachment_store.delete_attachment(
368 request.params.attachment_id, context=context
369 )
370 await self.store.delete_attachment(
371 request.params.attachment_id, context=context
372 )
373 return b"{}"
374 case ItemsListReq():
375 items_list_params = request.params
376 items = await self.store.load_thread_items(
377 items_list_params.thread_id,
378 limit=items_list_params.limit or DEFAULT_PAGE_SIZE,
379 order=items_list_params.order,
380 after=items_list_params.after,
381 context=context,
382 )
383 # filter out HiddenContextItems
384 items.data = [
385 item
386 for item in items.data
387 if not isinstance(item, HiddenContextItem)
388 ]
389 return self._serialize(items)
390 case ThreadsUpdateReq():
391 thread = await self.store.load_thread(
392 request.params.thread_id, context=context
393 )
394 thread.title = request.params.title
395 await self.store.save_thread(thread, context=context)
396 return self._serialize(self._to_thread_response(thread))
397 case ThreadsDeleteReq():
398 await self.store.delete_thread(
399 request.params.thread_id, context=context
400 )
401 return b"{}"
402 case _:
403 assert_never(request)
404
405 async def _process_streaming(
406 self, request: StreamingReq, context: TContext
407 ) -> AsyncGenerator[bytes, None]:
408 try:
409 async for event in self._process_streaming_impl(request, context):
410 b = self._serialize(event)
411 yield b"data: " + b + b"\n\n"
412 except Exception:
413 logger.exception("Error while generating streamed response")
414 raise
415
416 async def _process_streaming_impl(
417 self, request: StreamingReq, context: TContext
418 ) -> AsyncGenerator[ThreadStreamEvent, None]:
419 match request:
420 case ThreadsCreateReq():
421 thread = Thread(
422 id=self.store.generate_thread_id(context),
423 created_at=datetime.now(),
424 items=Page(),
425 )
426 await self.store.save_thread(thread, context=context)
427 yield ThreadCreatedEvent(thread=self._to_thread_response(thread))
428 user_message = await self._build_user_message_item(
429 request.params.input, thread, context
430 )
431 async for event in self._process_new_thread_item_respond(
432 thread,
433 user_message,
434 context,
435 ):
436 yield event
437
438 case ThreadsAddUserMessageReq():
439 thread = await self.store.load_thread(
440 request.params.thread_id, context=context
441 )
442 user_message = await self._build_user_message_item(
443 request.params.input, thread, context
444 )
445 async for event in self._process_new_thread_item_respond(
446 thread,
447 user_message,
448 context,
449 ):
450 yield event
451
452 case ThreadsAddClientToolOutputReq():
453 thread = await self.store.load_thread(
454 request.params.thread_id, context=context
455 )
456 items = await self.store.load_thread_items(
457 thread.id, None, 1, "desc", context
458 )
459 tool_call = next(
460 (
461 item
462 for item in items.data
463 if isinstance(item, ClientToolCallItem)
464 and item.status == "pending"
465 ),
466 None,
467 )
468 if not tool_call:
469 raise ValueError(
470 f"Last thread item in {thread.id} was not a ClientToolCallItem"
471 )
472
473 tool_call.output = request.params.result
474 tool_call.status = "completed"
475
476 await self.store.save_item(thread.id, tool_call, context=context)
477
478 # Safety against dangling pending tool calls if there are
479 # multiple in a row, which should be impossible, and
480 # integrations should ultimately filter out pending tool calls
481 # when creating input response messages.
482 await self._cleanup_pending_client_tool_call(thread, context)
483
484 async for event in self._process_events(
485 thread,
486 context,
487 lambda: self.respond(thread, None, context),
488 ):
489 yield event
490
491 case ThreadsRetryAfterItemReq():
492 thread_metadata = await self.store.load_thread(
493 request.params.thread_id, context=context
494 )
495
496 # Collect items to remove (all items after the user message)
497 items_to_remove: list[ThreadItem] = []
498 user_message_item = None
499
500 async for item in self._paginate_thread_items_reverse(
501 request.params.thread_id, context
502 ):
503 if item.id == request.params.item_id:
504 if not isinstance(item, UserMessageItem):
505 raise ValueError(
506 f"Item {request.params.item_id} is not a user message"
507 )
508 user_message_item = item
509 break
510 items_to_remove.append(item)
511
512 if user_message_item:
513 for item in items_to_remove:
514 await self.store.delete_thread_item(
515 request.params.thread_id, item.id, context=context
516 )
517 async for event in self._process_events(
518 thread_metadata,
519 context,
520 lambda: self.respond(
521 thread_metadata,
522 user_message_item,
523 context,
524 ),
525 ):
526 yield event
527 case ThreadsCustomActionReq():
528 thread_metadata = await self.store.load_thread(
529 request.params.thread_id, context=context
530 )
531
532 item: ThreadItem | None = None
533 if request.params.item_id:
534 item = await self.store.load_item(
535 request.params.thread_id,
536 request.params.item_id,
537 context=context,
538 )
539
540 if item and not isinstance(item, WidgetItem):
541 # shouldn't happen if the caller is using the API correctly.
542 yield ErrorEvent(
543 code=ErrorCode.STREAM_ERROR,
544 allow_retry=False,
545 )
546 return
547
548 async for event in self._process_events(
549 thread_metadata,
550 context,
551 lambda: self.action(
552 thread_metadata,
553 request.params.action,
554 item,
555 context,
556 ),
557 ):
558 yield event
559
560 case _:
561 assert_never(request)
562
563 async def _cleanup_pending_client_tool_call(
564 self, thread: ThreadMetadata, context: TContext
565 ) -> None:
566 items = await self.store.load_thread_items(
567 thread.id, None, DEFAULT_PAGE_SIZE, "desc", context
568 )
569 for tool_call in items.data:
570 if not isinstance(tool_call, ClientToolCallItem):
571 continue
572 if tool_call.status == "pending":
573 logger.warning(
574 f"Client tool call {tool_call.call_id} was not completed, ignoring"
575 )
576 await self.store.delete_thread_item(
577 thread.id, tool_call.id, context=context
578 )
579
580 async def _process_new_thread_item_respond(
581 self,
582 thread: ThreadMetadata,
583 item: UserMessageItem,
584 context: TContext,
585 ) -> AsyncIterator[ThreadStreamEvent]:
586 await self.store.add_thread_item(thread.id, item, context=context)
587 await self._cleanup_pending_client_tool_call(thread, context)
588 yield ThreadItemDoneEvent(item=item)
589
590 async for event in self._process_events(
591 thread,
592 context,
593 lambda: self.respond(thread, item, context),
594 ):
595 yield event
596
597 async def _process_events(
598 self,
599 thread: ThreadMetadata,
600 context: TContext,
601 stream: Callable[[], AsyncIterator[ThreadStreamEvent]],
602 ) -> AsyncIterator[ThreadStreamEvent]:
603 await asyncio.sleep(0) # allow the response to start streaming
604
605 last_thread = thread.model_copy(deep=True)
606
607 try:
608 with agents_sdk_user_agent_override():
609 async for event in stream():
610 match event:
611 case ThreadItemDoneEvent():
612 await self.store.add_thread_item(
613 thread.id, event.item, context=context
614 )
615 case ThreadItemRemovedEvent():
616 await self.store.delete_thread_item(
617 thread.id, event.item_id, context=context
618 )
619 case ThreadItemReplacedEvent():
620 await self.store.save_item(
621 thread.id, event.item, context=context
622 )
623
624 # special case - don't send hidden context items back to the client
625 should_swallow_event = isinstance(
626 event, ThreadItemDoneEvent
627 ) and isinstance(event.item, HiddenContextItem)
628
629 if not should_swallow_event:
630 yield event
631
632 # in case user updated the thread while streaming
633 if thread != last_thread:
634 last_thread = thread.model_copy(deep=True)
635 await self.store.save_thread(thread, context=context)
636 yield ThreadUpdatedEvent(
637 thread=self._to_thread_response(thread)
638 )
639 # in case user updated the thread while streaming
640 if thread != last_thread:
641 last_thread = thread.model_copy(deep=True)
642 await self.store.save_thread(thread, context=context)
643 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
644 except CustomStreamError as e:
645 yield ErrorEvent(
646 code="custom",
647 message=e.message,
648 allow_retry=e.allow_retry,
649 )
650 except StreamError as e:
651 yield ErrorEvent(
652 code=e.code,
653 allow_retry=e.allow_retry,
654 )
655 except Exception as e:
656 yield ErrorEvent(
657 code=ErrorCode.STREAM_ERROR,
658 allow_retry=True,
659 )
660 logger.exception(e)
661
662 if thread != last_thread:
663 # in case user updated the thread at the end of the stream
664 await self.store.save_thread(thread, context=context)
665 yield ThreadUpdatedEvent(thread=self._to_thread_response(thread))
666
667 async def _build_user_message_item(
668 self, input: UserMessageInput, thread: ThreadMetadata, context: TContext
669 ) -> UserMessageItem:
670 return UserMessageItem(
671 id=self.store.generate_item_id("message", thread, context),
672 content=input.content,
673 thread_id=thread.id,
674 attachments=[
675 await self.store.load_attachment(attachment_id, context)
676 for attachment_id in input.attachments
677 ],
678 quoted_text=input.quoted_text,
679 inference_options=input.inference_options,
680 created_at=datetime.now(),
681 )
682
683 async def _load_full_thread(self, thread_id: str, context: TContext) -> Thread:
684 thread_meta = await self.store.load_thread(thread_id, context=context)
685 thread_items = await self.store.load_thread_items(
686 thread_id,
687 after=None,
688 limit=DEFAULT_PAGE_SIZE,
689 order="asc",
690 context=context,
691 )
692 return Thread(**thread_meta.model_dump(), items=thread_items)
693
694 async def _paginate_thread_items_reverse(
695 self, thread_id: str, context: TContext
696 ) -> AsyncIterator[ThreadItem]:
697 """Paginate through thread items in reverse order (newest first)."""
698 after = None
699 while True:
700 items = await self.store.load_thread_items(
701 thread_id, after, DEFAULT_PAGE_SIZE, "desc", context
702 )
703 for item in items.data:
704 yield item
705
706 if not items.has_more:
707 break
708 after = items.after
709
710 def _serialize(self, obj: BaseModel) -> bytes:
711 return obj.model_dump_json(by_alias=True, exclude_none=True).encode("utf-8")
712
713 def _to_thread_response(self, thread: ThreadMetadata | Thread) -> Thread:
714 def is_hidden(item: ThreadItem) -> bool:
715 return isinstance(item, HiddenContextItem)
716
717 items = thread.items if isinstance(thread, Thread) else Page()
718 items.data = [item for item in items.data if not is_hidden(item)]
719
720 return Thread(
721 id=thread.id,
722 title=thread.title,
723 created_at=thread.created_at,
724 items=items,
725 status=thread.status,
726 )