openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python (available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
6646f83af681e14fcdcf69f2fa31dbdb0a9efa2d

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1183 lines · mode: code

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputImageParam,
29 ResponseInputMessageContentListParam,
30 ResponseInputTextParam,
31 ResponseOutputText,
32)
33from openai.types.responses.response_input_item_param import (
34 FunctionCallOutput,
35 Message,
36)
37from openai.types.responses.response_output_message import Content
38from openai.types.responses.response_output_text import (
39 Annotation as ResponsesAnnotation,
40)
41from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
42from typing_extensions import assert_never
43
44from .server import stream_widget
45from .store import Store, StoreItemType
46from .types import (
47 Annotation,
48 AssistantMessageContent,
49 AssistantMessageContentPartAdded,
50 AssistantMessageContentPartAnnotationAdded,
51 AssistantMessageContentPartDone,
52 AssistantMessageContentPartTextDelta,
53 AssistantMessageItem,
54 Attachment,
55 ClientToolCallItem,
56 DurationSummary,
57 EndOfTurnItem,
58 FileSource,
59 GeneratedImage,
60 GeneratedImageItem,
61 GeneratedImageUpdated,
62 HiddenContextItem,
63 SDKHiddenContextItem,
64 Task,
65 TaskItem,
66 ThoughtTask,
67 ThreadItem,
68 ThreadItemAddedEvent,
69 ThreadItemDoneEvent,
70 ThreadItemRemovedEvent,
71 ThreadItemUpdatedEvent,
72 ThreadMetadata,
73 ThreadStreamEvent,
74 URLSource,
75 UserMessageItem,
76 UserMessageTagContent,
77 UserMessageTextContent,
78 WidgetItem,
79 Workflow,
80 WorkflowItem,
81 WorkflowSummary,
82 WorkflowTaskAdded,
83 WorkflowTaskUpdated,
84)
85from .widgets import Markdown, Text, WidgetRoot
86
87
88class ClientToolCall(BaseModel):
89 """
90 Returned from tool methods to indicate a client-side tool call.
91 """
92
93 name: str
94 arguments: dict[str, Any]
95
96
97class _QueueCompleteSentinel: ...
98
99
100TContext = TypeVar("TContext")
101
102
103class AgentContext(BaseModel, Generic[TContext]):
104 model_config = ConfigDict(arbitrary_types_allowed=True)
105
106 thread: ThreadMetadata
107 store: Annotated[Store[TContext], SkipValidation]
108 request_context: TContext
109 previous_response_id: str | None = None
110 client_tool_call: ClientToolCall | None = None
111 workflow_item: WorkflowItem | None = None
112 generated_image_item: GeneratedImageItem | None = None
113 _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()
114
115 def generate_id(
116 self, type: StoreItemType, thread: ThreadMetadata | None = None
117 ) -> str:
118 """Generate a new store-backed id for the given item type."""
119 if type == "thread":
120 return self.store.generate_thread_id(self.request_context)
121 return self.store.generate_item_id(
122 type, thread or self.thread, self.request_context
123 )
124
125 async def stream_widget(
126 self,
127 widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
128 copy_text: str | None = None,
129 ) -> None:
130 """Stream a widget into the thread by enqueueing widget events."""
131 async for event in stream_widget(
132 self.thread,
133 widget,
134 copy_text,
135 lambda item_type: self.store.generate_item_id(
136 item_type, self.thread, self.request_context
137 ),
138 ):
139 await self._events.put(event)
140
141 async def end_workflow(
142 self, summary: WorkflowSummary | None = None, expanded: bool = False
143 ) -> None:
144 """Finalize the active workflow item, optionally attaching a summary."""
145 if not self.workflow_item:
146 # No workflow to end
147 return
148
149 if summary is not None:
150 self.workflow_item.workflow.summary = summary
151 elif self.workflow_item.workflow.summary is None:
152 # If no summary was set or provided, set a basic work summary
153 delta = datetime.now() - self.workflow_item.created_at
154 duration = int(delta.total_seconds())
155 self.workflow_item.workflow.summary = DurationSummary(duration=duration)
156 self.workflow_item.workflow.expanded = expanded
157 await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
158 self.workflow_item = None
159
160 async def start_workflow(self, workflow: Workflow) -> None:
161 """Begin streaming a new workflow item."""
162 self.workflow_item = WorkflowItem(
163 id=self.generate_id("workflow"),
164 created_at=datetime.now(),
165 workflow=workflow,
166 thread_id=self.thread.id,
167 )
168
169 if workflow.type != "reasoning" and len(workflow.tasks) == 0:
170 # Defer sending added event until we have tasks
171 return
172
173 await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
174
175 async def update_workflow_task(self, task: Task, task_index: int) -> None:
176 """Update an existing workflow task and stream the delta."""
177 if self.workflow_item is None:
178 raise ValueError("Workflow is not set")
179 # ensure reference is updated in case task is a copy
180 self.workflow_item.workflow.tasks[task_index] = task
181 await self.stream(
182 ThreadItemUpdatedEvent(
183 item_id=self.workflow_item.id,
184 update=WorkflowTaskUpdated(
185 task=task,
186 task_index=task_index,
187 ),
188 )
189 )
190
191 async def add_workflow_task(self, task: Task) -> None:
192 """Append a workflow task and stream the appropriate event."""
193 self.workflow_item = self.workflow_item or WorkflowItem(
194 id=self.generate_id("workflow"),
195 created_at=datetime.now(),
196 workflow=Workflow(type="custom", tasks=[]),
197 thread_id=self.thread.id,
198 )
199 workflow = self.workflow_item.workflow
200 workflow.tasks.append(task)
201
202 if workflow.type != "reasoning" and len(workflow.tasks) == 1:
203 await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
204 else:
205 await self.stream(
206 ThreadItemUpdatedEvent(
207 item_id=self.workflow_item.id,
208 update=WorkflowTaskAdded(
209 task=task,
210 task_index=workflow.tasks.index(task),
211 ),
212 )
213 )
214
215 async def stream(self, event: ThreadStreamEvent) -> None:
216 """Enqueue a ThreadStreamEvent for downstream processing."""
217 await self._events.put(event)
218
219 def _complete(self):
220 self._events.put_nowait(_QueueCompleteSentinel())
221
222
223def _convert_content(content: Content) -> AssistantMessageContent:
224 if content.type == "output_text":
225 annotations = [
226 _convert_annotation(annotation) for annotation in content.annotations
227 ]
228 annotations = [a for a in annotations if a is not None]
229 return AssistantMessageContent(
230 text=content.text,
231 annotations=annotations,
232 )
233 else:
234 return AssistantMessageContent(
235 text=content.refusal,
236 annotations=[],
237 )
238
239
240def _convert_annotation(raw_annotation: object) -> Annotation | None:
241 # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
242 # resulting into annotation being a dict or untyped object instead instead of a ResponsesAnnotation
243 annotation = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation).validate_python(
244 raw_annotation
245 )
246
247 if annotation.type == "file_citation":
248 filename = annotation.filename
249 if not filename:
250 return None
251
252 return Annotation(
253 source=FileSource(filename=filename, title=filename),
254 index=annotation.index,
255 )
256
257 if annotation.type == "url_citation":
258 return Annotation(
259 source=URLSource(
260 url=annotation.url,
261 title=annotation.title,
262 ),
263 index=annotation.end_index,
264 )
265
266 if annotation.type == "container_file_citation":
267 filename = annotation.filename
268 if not filename:
269 return None
270
271 return Annotation(
272 source=FileSource(filename=filename, title=filename),
273 index=annotation.end_index,
274 )
275
276 return None
277
278
279T1 = TypeVar("T1")
280T2 = TypeVar("T2")
281
282
283async def _merge_generators(
284 a: AsyncIterator[T1],
285 b: AsyncIterator[T2],
286) -> AsyncIterator[T1 | T2]:
287 pending: list[AsyncIterator[T1 | T2]] = [a, b]
288 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
289 asyncio.ensure_future(g.__anext__()): g for g in pending
290 }
291 while len(pending_tasks) > 0:
292 done, _ = await asyncio.wait(
293 pending_tasks.keys(), return_when="FIRST_COMPLETED"
294 )
295 stop = False
296 for d in done:
297 try:
298 result = d.result()
299 yield result
300 dg = pending_tasks[d]
301 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
302 except StopAsyncIteration:
303 stop = True
304 finally:
305 del pending_tasks[d]
306 if stop:
307 for task in pending_tasks.keys():
308 if not task.cancel():
309 try:
310 yield task.result()
311 except asyncio.CancelledError:
312 pass
313 except asyncio.InvalidStateError:
314 pass
315 break
316
317
318class _EventWrapper:
319 def __init__(self, event: ThreadStreamEvent):
320 self.event = event
321
322
323class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
324 def __init__(
325 self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
326 ):
327 self.queue = queue
328 self.completed = False
329
330 def __aiter__(self):
331 return self
332
333 async def __anext__(self):
334 if self.completed:
335 raise StopAsyncIteration
336
337 item = await self.queue.get()
338 if isinstance(item, _QueueCompleteSentinel):
339 self.completed = True
340 raise StopAsyncIteration
341 return _EventWrapper(item)
342
343 def drain_and_complete(self) -> None:
344 """Empty the underlying queue without awaiting and mark this iterator completed.
345
346 This is intended for cleanup paths where we must guarantee no awaits
347 occur. All queued items, including any completion sentinel, are
348 discarded.
349 """
350 while True:
351 try:
352 self.queue.get_nowait()
353 except asyncio.QueueEmpty:
354 break
355 self.completed = True
356
357
358class StreamingThoughtTracker(BaseModel):
359 item_id: str
360 index: int
361 task: ThoughtTask
362
363
364class ResponseStreamConverter:
365 """Used by `stream_agent_response` to convert streamed Agents SDK output
366 into values used by ChatKit thread items and thread stream events.
367
368 Defines overridable methods for adapting streamed data (such as image
369 generation results and partial updates) into the forms expected by ChatKit.
370 """
371
372 partial_images: int | None = None
373 """
374 The expected number of partial image updates for an image generation result.
375
376 When set, this value is used to normalize partial image indices into a
377 progress value in the range [0, 1]. If unset, all partial image updates are
378 assigned a progress value of 0.
379 """
380
381 def __init__(self, *, partial_images: int | None = None):
382 """
383 Args:
384 partial_images: The expected number of partial image updates for image
385 generation results, or None if no progress normalization should
386 be performed.
387 """
388 self.partial_images = partial_images
389
390 async def base64_image_to_url(
391 self,
392 image_id: str,
393 base64_image: str,
394 partial_image_index: int | None = None,
395 ) -> str:
396 """
397 Convert a base64-encoded image into a URL.
398
399 This method is used to produce the URL stored on thread items for image
400 generation results.
401
402 Args:
403 image_id: The ID of the image generation call. This stays stable across partial image updates.
404 base64_image: The base64-encoded image.
405 partial_image_index: The index of the partial image update, starting from 0.
406
407 Returns:
408 A URL string.
409 """
410 return f"data:image/png;base64,{base64_image}"
411
412 def partial_image_index_to_progress(self, partial_image_index: int) -> float:
413 """
414 Convert a partial image index into a normalized progress value.
415
416 Args:
417 partial_image_index: The index of the partial image update, starting from 0.
418
419 Returns:
420 A float between 0 and 1 representing progress for the image
421 generation result.
422 """
423 if self.partial_images is None or self.partial_images <= 0:
424 return 0.0
425
426 return min(1.0, partial_image_index / self.partial_images)
427
428
429_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
430
431
432async def stream_agent_response(
433 context: AgentContext,
434 result: RunResultStreaming,
435 *,
436 converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
437) -> AsyncIterator[ThreadStreamEvent]:
438 """
439 Convert a streamed Agents SDK run into ChatKit thread stream events.
440
441 This function consumes a streaming run result and yields `ThreadStreamEvent`
442 objects as the run progresses.
443
444 Args:
445 context: The AgentContext to use for the stream.
446 result: The RunResultStreaming to convert.
447 image_generation_stream_converter: Controls how streamed image generation output
448 is converted into URLs and progress updates. The default converter stores the
449 generated base64 image and assigns a progress value of 0 to all partial image
450 updates.
451
452 Returns:
453 An async iterator that yields thread stream events representing the run result.
454 """
455 current_item_id = None
456 current_tool_call = None
457 ctx = context
458 thread = context.thread
459 queue_iterator = _AsyncQueueIterator(context._events)
460 produced_items = set()
461 streaming_thought: None | StreamingThoughtTracker = None
462 # item_id -> content_index -> annotation count
463 item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
464 lambda: defaultdict(int)
465 )
466
467 # check if the last item in the thread was a workflow or a client tool call
468 # if it was a client tool call, check if the second last item was a workflow
469 # if either was, continue the workflow
470 items = await context.store.load_thread_items(
471 thread.id, None, 2, "desc", context.request_context
472 )
473 last_item = items.data[0] if len(items.data) > 0 else None
474 second_last_item = items.data[1] if len(items.data) > 1 else None
475
476 if last_item and last_item.type == "workflow":
477 ctx.workflow_item = last_item
478 elif (
479 last_item
480 and last_item.type == "client_tool_call"
481 and second_last_item
482 and second_last_item.type == "workflow"
483 ):
484 ctx.workflow_item = second_last_item
485
486 def end_workflow(item: WorkflowItem):
487 if item == ctx.workflow_item:
488 ctx.workflow_item = None
489 delta = datetime.now() - item.created_at
490 duration = int(delta.total_seconds())
491 if item.workflow.summary is None:
492 item.workflow.summary = DurationSummary(duration=duration)
493 # Default to closing all workflows
494 # To keep a workflow open on completion, close it explicitly with
495 # AgentContext.end_workflow(expanded=True)
496 item.workflow.expanded = False
497 return ThreadItemDoneEvent(item=item)
498
499 try:
500 async for event in _merge_generators(result.stream_events(), queue_iterator):
501 # Events emitted from agent context helpers
502 if isinstance(event, _EventWrapper):
503 event = event.event
504 if (
505 event.type == "thread.item.added"
506 or event.type == "thread.item.done"
507 ):
508 # End the current workflow if visual item is added after it
509 if (
510 ctx.workflow_item
511 and ctx.workflow_item.id != event.item.id
512 and event.item.type != "client_tool_call"
513 and event.item.type != "hidden_context_item"
514 ):
515 yield end_workflow(ctx.workflow_item)
516
517 # track the current workflow if one is added
518 if (
519 event.type == "thread.item.added"
520 and event.item.type == "workflow"
521 ):
522 ctx.workflow_item = event.item
523
524 # track integration produced items so we can clean them up if
525 # there is a guardrail tripwire
526 produced_items.add(event.item.id)
527 yield event
528 continue
529
530 if event.type == "run_item_stream_event":
531 event = event.item
532 if (
533 event.type == "tool_call_item"
534 and event.raw_item.type == "function_call"
535 ):
536 current_tool_call = event.raw_item.call_id
537 current_item_id = event.raw_item.id
538 assert current_item_id
539 produced_items.add(current_item_id)
540 continue
541
542 if event.type != "raw_response_event":
543 # Ignore everything else that isn't a raw response event
544 continue
545
546 # Handle Responses events
547 event = event.data
548 if event.type == "response.content_part.added":
549 if event.part.type == "reasoning_text":
550 continue
551 content = _convert_content(event.part)
552 yield ThreadItemUpdatedEvent(
553 item_id=event.item_id,
554 update=AssistantMessageContentPartAdded(
555 content_index=event.content_index,
556 content=content,
557 ),
558 )
559 elif event.type == "response.output_text.delta":
560 yield ThreadItemUpdatedEvent(
561 item_id=event.item_id,
562 update=AssistantMessageContentPartTextDelta(
563 content_index=event.content_index,
564 delta=event.delta,
565 ),
566 )
567 elif event.type == "response.output_text.done":
568 yield ThreadItemUpdatedEvent(
569 item_id=event.item_id,
570 update=AssistantMessageContentPartDone(
571 content_index=event.content_index,
572 content=AssistantMessageContent(
573 text=event.text,
574 annotations=[],
575 ),
576 ),
577 )
578 elif event.type == "response.output_text.annotation.added":
579 annotation = _convert_annotation(event.annotation)
580 if annotation:
581 # Manually track annotation indices per content part in case we drop an annotation that
582 # we can't convert to our internal representation (e.g. missing filename).
583 annotation_index = item_annotation_count[event.item_id][
584 event.content_index
585 ]
586 item_annotation_count[event.item_id][event.content_index] = (
587 annotation_index + 1
588 )
589 yield ThreadItemUpdatedEvent(
590 item_id=event.item_id,
591 update=AssistantMessageContentPartAnnotationAdded(
592 content_index=event.content_index,
593 annotation_index=annotation_index,
594 annotation=annotation,
595 ),
596 )
597 continue
598 elif event.type == "response.output_item.added":
599 item = event.item
600 if item.type == "reasoning" and not ctx.workflow_item:
601 ctx.workflow_item = WorkflowItem(
602 id=ctx.generate_id("workflow"),
603 created_at=datetime.now(),
604 workflow=Workflow(type="reasoning", tasks=[]),
605 thread_id=thread.id,
606 )
607 produced_items.add(ctx.workflow_item.id)
608 yield ThreadItemAddedEvent(item=ctx.workflow_item)
609 if item.type == "message":
610 if ctx.workflow_item:
611 yield end_workflow(ctx.workflow_item)
612 produced_items.add(item.id)
613 yield ThreadItemAddedEvent(
614 item=AssistantMessageItem(
615 # Reusing the Responses message ID
616 id=item.id,
617 thread_id=thread.id,
618 content=[_convert_content(c) for c in item.content],
619 created_at=datetime.now(),
620 ),
621 )
622 elif item.type == "image_generation_call":
623 ctx.generated_image_item = GeneratedImageItem(
624 id=ctx.generate_id("message"),
625 thread_id=thread.id,
626 created_at=datetime.now(),
627 image=None,
628 )
629 produced_items.add(ctx.generated_image_item.id)
630 yield ThreadItemAddedEvent(item=ctx.generated_image_item)
631 elif event.type == "response.image_generation_call.partial_image":
632 if not ctx.generated_image_item:
633 continue
634
635 url = await converter.base64_image_to_url(
636 image_id=event.item_id,
637 base64_image=event.partial_image_b64,
638 partial_image_index=event.partial_image_index,
639 )
640 progress = converter.partial_image_index_to_progress(
641 event.partial_image_index
642 )
643
644 ctx.generated_image_item.image = GeneratedImage(
645 id=event.item_id, url=url
646 )
647
648 yield ThreadItemUpdatedEvent(
649 item_id=ctx.generated_image_item.id,
650 update=GeneratedImageUpdated(
651 image=ctx.generated_image_item.image, progress=progress
652 ),
653 )
654 elif event.type == "response.reasoning_summary_text.delta":
655 if not ctx.workflow_item:
656 continue
657
658 # stream the first thought in a new workflow so that we can show it earlier
659 if (
660 ctx.workflow_item.workflow.type == "reasoning"
661 and len(ctx.workflow_item.workflow.tasks) == 0
662 ):
663 streaming_thought = StreamingThoughtTracker(
664 item_id=event.item_id,
665 index=event.summary_index,
666 task=ThoughtTask(content=event.delta),
667 )
668 ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
669 yield ThreadItemUpdatedEvent(
670 item_id=ctx.workflow_item.id,
671 update=WorkflowTaskAdded(
672 task=streaming_thought.task,
673 task_index=0,
674 ),
675 )
676 elif (
677 streaming_thought
678 and streaming_thought.task in ctx.workflow_item.workflow.tasks
679 and event.item_id == streaming_thought.item_id
680 and event.summary_index == streaming_thought.index
681 ):
682 streaming_thought.task.content += event.delta
683 yield ThreadItemUpdatedEvent(
684 item_id=ctx.workflow_item.id,
685 update=WorkflowTaskUpdated(
686 task=streaming_thought.task,
687 task_index=ctx.workflow_item.workflow.tasks.index(
688 streaming_thought.task
689 ),
690 ),
691 )
692 elif event.type == "response.reasoning_summary_text.done":
693 if ctx.workflow_item:
694 if (
695 streaming_thought
696 and streaming_thought.task in ctx.workflow_item.workflow.tasks
697 and event.item_id == streaming_thought.item_id
698 and event.summary_index == streaming_thought.index
699 ):
700 task = streaming_thought.task
701 task.content = event.text
702 streaming_thought = None
703 update = WorkflowTaskUpdated(
704 task=task,
705 task_index=ctx.workflow_item.workflow.tasks.index(task),
706 )
707 else:
708 task = ThoughtTask(content=event.text)
709 ctx.workflow_item.workflow.tasks.append(task)
710 update = WorkflowTaskAdded(
711 task=task,
712 task_index=ctx.workflow_item.workflow.tasks.index(task),
713 )
714 yield ThreadItemUpdatedEvent(
715 item_id=ctx.workflow_item.id,
716 update=update,
717 )
718 elif event.type == "response.output_item.done":
719 item = event.item
720 if item.type == "message":
721 produced_items.add(item.id)
722 yield ThreadItemDoneEvent(
723 item=AssistantMessageItem(
724 # Reusing the Responses message ID
725 id=item.id,
726 thread_id=thread.id,
727 content=[_convert_content(c) for c in item.content],
728 created_at=datetime.now(),
729 ),
730 )
731 elif item.type == "image_generation_call" and item.result:
732 if not ctx.generated_image_item:
733 continue
734
735 url = await converter.base64_image_to_url(
736 image_id=item.id,
737 base64_image=item.result,
738 )
739 image = GeneratedImage(id=item.id, url=url)
740
741 ctx.generated_image_item.image = image
742 yield ThreadItemDoneEvent(item=ctx.generated_image_item)
743
744 ctx.generated_image_item = None
745
746 except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
747 for item_id in produced_items:
748 yield ThreadItemRemovedEvent(item_id=item_id)
749
750 # Drain remaining events without processing them
751 context._complete()
752 queue_iterator.drain_and_complete()
753
754 raise
755
756 context._complete()
757
758 # Drain remaining events
759 async for event in queue_iterator:
760 yield event.event
761
762 # If there is still an active workflow at the end of the run, store
763 # it's current state so that we can continue it in the next turn.
764 if ctx.workflow_item:
765 await ctx.store.add_thread_item(
766 thread.id, ctx.workflow_item, ctx.request_context
767 )
768
769 if context.client_tool_call:
770 yield ThreadItemDoneEvent(
771 item=ClientToolCallItem(
772 id=current_item_id
773 or context.store.generate_item_id(
774 "tool_call", thread, context.request_context
775 ),
776 thread_id=thread.id,
777 name=context.client_tool_call.name,
778 arguments=context.client_tool_call.arguments,
779 created_at=datetime.now(),
780 call_id=current_tool_call
781 or context.store.generate_item_id(
782 "tool_call", thread, context.request_context
783 ),
784 ),
785 )
786
787
788TWidget = TypeVar("TWidget", bound=Markdown | Text)
789
790
791async def accumulate_text(
792 events: AsyncIterator[StreamEvent],
793 base_widget: TWidget,
794) -> AsyncIterator[TWidget]:
795 text = ""
796 yield base_widget
797 async for event in events:
798 if event.type == "raw_response_event":
799 if event.data.type == "response.output_text.delta":
800 text += event.data.delta
801 yield base_widget.model_copy(update={"value": text})
802 yield base_widget.model_copy(update={"value": text, "streaming": False})
803
804
805class ThreadItemConverter:
806 """
807 Converts thread items to Agent SDK input items.
808 Widgets, Tasks, and Workflows have default conversions but can be customized.
809 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
810 Other item types are converted automatically.
811 """
812
813 async def attachment_to_message_content(
814 self, attachment: Attachment
815 ) -> ResponseInputContentParam:
816 """
817 Convert an attachment in a user message into a message content part to send to the model.
818 Required when attachments are enabled.
819 """
820 raise NotImplementedError(
821 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
822 )
823
824 async def tag_to_message_content(
825 self, tag: UserMessageTagContent
826 ) -> ResponseInputContentParam:
827 """
828 Convert a tag in a user message into a message content part to send to the model as context.
829 Required when tags are used.
830 """
831 raise NotImplementedError(
832 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
833 )
834
835 async def generated_image_to_input(
836 self, item: GeneratedImageItem
837 ) -> TResponseInputItem | list[TResponseInputItem] | None:
838 """
839 Convert a GeneratedImageItem into input item(s) to send to the model.
840 Override this method to customize the conversion of generated images, such as when your
841 generated image url is not publicly reachable.
842 """
843 if not item.image:
844 return None
845
846 return Message(
847 type="message",
848 content=[
849 ResponseInputTextParam(
850 type="input_text",
851 text="The following image was generated by the agent.",
852 ),
853 ResponseInputImageParam(
854 type="input_image",
855 detail="auto",
856 image_url=item.image.url,
857 ),
858 ],
859 role="user",
860 )
861
862 async def hidden_context_to_input(
863 self, item: HiddenContextItem
864 ) -> TResponseInputItem | list[TResponseInputItem] | None:
865 """
866 Convert a HiddenContextItem into input item(s) to send to the model.
867 Required to override when HiddenContextItems with non-string content are used.
868 """
869 if not isinstance(item.content, str):
870 raise NotImplementedError(
871 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
872 )
873
874 text = (
875 "Hidden context for the agent (not shown to the user):\n"
876 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
877 )
878 return Message(
879 type="message",
880 content=[
881 ResponseInputTextParam(
882 type="input_text",
883 text=text,
884 )
885 ],
886 role="user",
887 )
888
889 async def sdk_hidden_context_to_input(
890 self, item: SDKHiddenContextItem
891 ) -> TResponseInputItem | list[TResponseInputItem] | None:
892 """
893 Convert a SDKHiddenContextItem into input item to send to the model.
894 This is used by the ChatKit Python SDK for storing additional context
895 for internal operations.
896 Override if you want to wrap the content in a different format.
897 """
898 text = (
899 "Hidden context for the agent (not shown to the user):\n"
900 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
901 )
902 return Message(
903 type="message",
904 content=[
905 ResponseInputTextParam(
906 type="input_text",
907 text=text,
908 )
909 ],
910 role="user",
911 )
912
913 async def task_to_input(
914 self, item: TaskItem
915 ) -> TResponseInputItem | list[TResponseInputItem] | None:
916 """
917 Convert a TaskItem into input item(s) to send to the model.
918 """
919 if item.task.type != "custom" or (
920 not item.task.title and not item.task.content
921 ):
922 return None
923 title = f"{item.task.title}" if item.task.title else ""
924 content = f"{item.task.content}" if item.task.content else ""
925 task_text = f"{title}: {content}" if title and content else title or content
926 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
927 return Message(
928 type="message",
929 content=[
930 ResponseInputTextParam(
931 type="input_text",
932 text=text,
933 )
934 ],
935 role="user",
936 )
937
938 async def workflow_to_input(
939 self, item: WorkflowItem
940 ) -> TResponseInputItem | list[TResponseInputItem] | None:
941 """
942 Convert a TaskItem into input item(s) to send to the model.
943 Returns WorkflowItem.response_items by default.
944 """
945 messages = []
946 for task in item.workflow.tasks:
947 if task.type != "custom" or (not task.title and not task.content):
948 continue
949
950 title = f"{task.title}" if task.title else ""
951 content = f"{task.content}" if task.content else ""
952 task_text = f"{title}: {content}" if title and content else title or content
953 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
954 messages.append(
955 Message(
956 type="message",
957 content=[
958 ResponseInputTextParam(
959 type="input_text",
960 text=text,
961 )
962 ],
963 role="user",
964 )
965 )
966 return messages
967
968 async def widget_to_input(
969 self, item: WidgetItem
970 ) -> TResponseInputItem | list[TResponseInputItem] | None:
971 """
972 Convert a WidgetItem into input item(s) to send to the model.
973 By default, WidgetItems converted to a text description with a JSON representation of the widget.
974 """
975 return Message(
976 type="message",
977 content=[
978 ResponseInputTextParam(
979 type="input_text",
980 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
981 + item.widget.model_dump_json(
982 exclude_unset=True, exclude_none=True
983 ),
984 )
985 ],
986 role="user",
987 )
988
989 async def user_message_to_input(
990 self, item: UserMessageItem, is_last_message: bool = True
991 ) -> TResponseInputItem | list[TResponseInputItem] | None:
992 # Build the user text exactly as typed, rendering tags as @key
993 message_text_parts: list[str] = []
994 # Track tags separately to add system context
995 raw_tags: list[UserMessageTagContent] = []
996
997 for part in item.content:
998 if isinstance(part, UserMessageTextContent):
999 message_text_parts.append(part.text)
1000 elif isinstance(part, UserMessageTagContent):
1001 message_text_parts.append(f"@{part.text}")
1002 raw_tags.append(part)
1003 else:
1004 assert_never(part)
1005
1006 user_text_item = Message(
1007 role="user",
1008 type="message",
1009 content=[
1010 ResponseInputTextParam(
1011 type="input_text", text="".join(message_text_parts)
1012 ),
1013 *[
1014 await self.attachment_to_message_content(a)
1015 for a in item.attachments
1016 ],
1017 ],
1018 )
1019
1020 # Build system items (prepend later): quoted text and @-mention context
1021 context_items: list[TResponseInputItem] = []
1022
1023 if item.quoted_text and is_last_message:
1024 context_items.append(
1025 Message(
1026 role="user",
1027 type="message",
1028 content=[
1029 ResponseInputTextParam(
1030 type="input_text",
1031 text=f"The user is referring to this in particular: \n{item.quoted_text}",
1032 )
1033 ],
1034 )
1035 )
1036
1037 # Dedupe tags (preserve order) and resolve to message content
1038 if raw_tags:
1039 seen, uniq_tags = set(), []
1040 for t in raw_tags:
1041 if t.text not in seen:
1042 seen.add(t.text)
1043 uniq_tags.append(t)
1044
1045 tag_content: ResponseInputMessageContentListParam = [
1046 # should return summarized text items
1047 await self.tag_to_message_content(tag)
1048 for tag in uniq_tags
1049 ]
1050
1051 if tag_content:
1052 context_items.append(
1053 Message(
1054 role="user",
1055 type="message",
1056 content=[
1057 ResponseInputTextParam(
1058 type="input_text",
1059 text=cleandoc("""
1060 # User-provided context for @-mentions
1061 - When referencing resolved entities, use their canonical names **without** '@'.
1062 - The '@' form appears only in user text and should not be echoed.
1063 """).strip(),
1064 ),
1065 *tag_content,
1066 ],
1067 )
1068 )
1069
1070 return [user_text_item, *context_items]
1071
async def assistant_message_to_input(
    self, item: AssistantMessageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """Convert an AssistantMessageItem into a single assistant-role message.

    Each content part is re-serialized as an ``output_text`` part carrying
    only its text; annotations are currently not forwarded.
    """
    output_parts: list[ResponseInputContentParam] = []
    for part in item.content:
        serialized = ResponseOutputText(
            type="output_text",
            text=part.text,
            annotations=[],  # TODO: these should be sent back as well
        ).model_dump()
        # The input content param type does not model assistant output parts,
        # so the serialized dict is cast purely for the type checker.
        output_parts.append(cast(ResponseInputContentParam, serialized))

    return EasyInputMessageParam(
        type="message",
        content=output_parts,
        role="assistant",
    )
1091
async def client_tool_call_to_input(
    self, item: ClientToolCallItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """Convert a ClientToolCallItem into a function call plus its output.

    Returns None for calls still in "pending" status. Otherwise returns two
    items sharing ``item.call_id``: a ``function_call`` with the JSON-encoded
    arguments and a ``function_call_output`` with the JSON-encoded output.
    """
    # Pending tool calls are filtered out: they cannot be sent to the model.
    if item.status == "pending":
        return None

    call_id = item.call_id
    tool_call = ResponseFunctionToolCallParam(
        type="function_call",
        call_id=call_id,
        name=item.name,
        arguments=json.dumps(item.arguments),
    )
    tool_output = FunctionCallOutput(
        type="function_call_output",
        call_id=call_id,
        output=json.dumps(item.output),
    )
    return [tool_call, tool_output]
1112
async def end_of_turn_to_input(
    self, item: EndOfTurnItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """Contribute nothing to model input for an end-of-turn marker.

    End-of-turn items exist only as UI hints, so overriding this is rarely
    needed.
    """
    return None
1118
1119 async def _thread_item_to_input_item(
1120 self,
1121 item: ThreadItem,
1122 is_last_message: bool = True,
1123 ) -> list[TResponseInputItem]:
1124 match item:
1125 case UserMessageItem():
1126 out = await self.user_message_to_input(item, is_last_message) or []
1127 return out if isinstance(out, list) else [out]
1128 case AssistantMessageItem():
1129 out = await self.assistant_message_to_input(item) or []
1130 return out if isinstance(out, list) else [out]
1131 case ClientToolCallItem():
1132 out = await self.client_tool_call_to_input(item) or []
1133 return out if isinstance(out, list) else [out]
1134 case EndOfTurnItem():
1135 out = await self.end_of_turn_to_input(item) or []
1136 return out if isinstance(out, list) else [out]
1137 case WidgetItem():
1138 out = await self.widget_to_input(item) or []
1139 return out if isinstance(out, list) else [out]
1140 case WorkflowItem():
1141 out = await self.workflow_to_input(item) or []
1142 return out if isinstance(out, list) else [out]
1143 case TaskItem():
1144 out = await self.task_to_input(item) or []
1145 return out if isinstance(out, list) else [out]
1146 case HiddenContextItem():
1147 out = await self.hidden_context_to_input(item) or []
1148 return out if isinstance(out, list) else [out]
1149 case SDKHiddenContextItem():
1150 out = await self.sdk_hidden_context_to_input(item) or []
1151 return out if isinstance(out, list) else [out]
1152 case GeneratedImageItem():
1153 out = await self.generated_image_to_input(item) or []
1154 return out if isinstance(out, list) else [out]
1155 case _:
1156 assert_never(item)
1157
async def to_agent_input(
    self,
    thread_items: Sequence[ThreadItem] | ThreadItem,
) -> list[TResponseInputItem]:
    """Convert one thread item, or a sequence of them, into agent input items.

    Items are converted in order and their outputs concatenated into one flat
    list. Only the item in the final position is converted with
    ``is_last_message=True``.
    """
    # Snapshot into a new list so callers mutating their sequence while we
    # await don't affect iteration. list() works for any Sequence, whereas
    # slicing is not part of the Sequence ABC contract.
    if isinstance(thread_items, Sequence):
        items: list[ThreadItem] = list(thread_items)
    else:
        items = [thread_items]

    # Compare positions rather than identity: the same object appearing more
    # than once must not make earlier occurrences count as the last message.
    last_index = len(items) - 1
    output: list[TResponseInputItem] = []
    for index, item in enumerate(items):
        output.extend(
            await self._thread_item_to_input_item(
                item,
                is_last_message=index == last_index,
            )
        )
    return output
1176
1177
# Module-level converter instance used by simple_to_agent_input().
_DEFAULT_CONVERTER: ThreadItemConverter = ThreadItemConverter()
1179
1180
async def simple_to_agent_input(
    thread_items: Sequence[ThreadItem] | ThreadItem,
) -> list[TResponseInputItem]:
    """Convert thread items to agent input using the default ThreadItemConverter.

    Accepts a single ThreadItem or a sequence of them and returns the flat
    list of input items produced by ``ThreadItemConverter.to_agent_input``.
    Declared ``async`` with an explicit return type so that type checkers and
    ``inspect.iscoroutinefunction`` see it as awaitable; existing
    ``await simple_to_agent_input(...)`` call sites are unaffected.
    """
    return await _DEFAULT_CONVERTER.to_agent_input(thread_items)
1184