openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
51e0454a7b87e90ed0c9990d11993d9945fdb758

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1146 lines · mode: code

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputMessageContentListParam,
29 ResponseInputTextParam,
30 ResponseOutputText,
31)
32from openai.types.responses.response_input_item_param import (
33 FunctionCallOutput,
34 Message,
35)
36from openai.types.responses.response_output_message import Content
37from openai.types.responses.response_output_text import (
38 Annotation as ResponsesAnnotation,
39)
40from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
41from typing_extensions import assert_never
42
43from .server import stream_widget
44from .store import Store, StoreItemType
45from .types import (
46 Annotation,
47 AssistantMessageContent,
48 AssistantMessageContentPartAdded,
49 AssistantMessageContentPartAnnotationAdded,
50 AssistantMessageContentPartDone,
51 AssistantMessageContentPartTextDelta,
52 AssistantMessageItem,
53 Attachment,
54 ClientToolCallItem,
55 DurationSummary,
56 EndOfTurnItem,
57 FileSource,
58 GeneratedImage,
59 GeneratedImageItem,
60 GeneratedImageUpdated,
61 HiddenContextItem,
62 SDKHiddenContextItem,
63 Task,
64 TaskItem,
65 ThoughtTask,
66 ThreadItem,
67 ThreadItemAddedEvent,
68 ThreadItemDoneEvent,
69 ThreadItemRemovedEvent,
70 ThreadItemUpdatedEvent,
71 ThreadMetadata,
72 ThreadStreamEvent,
73 URLSource,
74 UserMessageItem,
75 UserMessageTagContent,
76 UserMessageTextContent,
77 WidgetItem,
78 Workflow,
79 WorkflowItem,
80 WorkflowSummary,
81 WorkflowTaskAdded,
82 WorkflowTaskUpdated,
83)
84from .widgets import Markdown, Text, WidgetRoot
85
86
class ClientToolCall(BaseModel):
    """
    Value a tool method returns to request a client-side tool call.

    `stream_agent_response` reads it from `AgentContext.client_tool_call` once
    the run has finished and turns it into a `ClientToolCallItem`.
    """

    # Name of the tool to call on the client.
    name: str
    # Arguments sent along with the call.
    arguments: dict[str, Any]
94
95
class _QueueCompleteSentinel:
    """Marker placed on an `AgentContext` event queue to signal completion.

    `AgentContext._complete` enqueues an instance, and `_AsyncQueueIterator`
    stops iterating when it dequeues one.
    """
97
98
TContext = TypeVar("TContext")


class AgentContext(BaseModel, Generic[TContext]):
    """State shared between an agent run and the ChatKit response stream.

    The helper methods enqueue `ThreadStreamEvent` objects on an internal
    queue; `stream_agent_response` reads that queue and forwards the events.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Thread the run belongs to; the default thread when generating item ids.
    thread: ThreadMetadata
    # Store used to generate thread and item ids (pydantic validation skipped).
    store: Annotated[Store[TContext], SkipValidation]
    # Request context forwarded to every store call made from this class.
    request_context: TContext
    previous_response_id: str | None = None
    # When set, `stream_agent_response` emits a ClientToolCallItem after the run.
    client_tool_call: ClientToolCall | None = None
    # Workflow item currently being streamed, if any.
    workflow_item: WorkflowItem | None = None
    # Generated image item currently being streamed, if any.
    generated_image_item: GeneratedImageItem | None = None
    # NOTE(review): this relies on pydantic giving every instance its own copy
    # of a private-attribute default — confirm for the pinned pydantic version.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """Generate a new store-backed id for the given item type.

        Args:
            type: Kind of id to generate. `"thread"` yields a thread id; any
                other value yields an item id of that type.
            thread: Thread the item id is generated for. Defaults to
                `self.thread`. Ignored when `type` is `"thread"`.

        Returns:
            The id produced by the store.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Stream a widget into the thread by enqueueing widget events.

        Args:
            widget: A widget, or an async generator of widget states.
            copy_text: Optional copy text forwarded to the widget streamer.
        """
        # `stream_widget` here resolves to the module-level helper imported
        # from `.server`, not to this method.
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """Finalize the active workflow item, optionally attaching a summary.

        Does nothing when no workflow is active.

        Args:
            summary: Summary to store on the workflow. When omitted and the
                workflow has no summary yet, a `DurationSummary` measured
                from the item's `created_at` is used.
            expanded: Value written to the workflow's `expanded` flag.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """Begin streaming a new workflow item.

        The "added" event is deferred for non-reasoning workflows that have no
        tasks yet; `add_workflow_task` sends it when the first task arrives.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """Update an existing workflow task and stream the delta.

        Raises:
            ValueError: If no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """Append a workflow task and stream the appropriate event.

        A "custom" workflow is created on the fly when none is active. The
        first task of a non-reasoning workflow is announced with the item's
        "added" event; every other task is sent as a `WorkflowTaskAdded` update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so its index is the last position.
        # `tasks.index(task)` compares by equality and would return the index
        # of an earlier task that happens to be equal to this one.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue a ThreadStreamEvent for downstream processing."""
        await self._events.put(event)

    def _complete(self):
        """Enqueue the completion sentinel without awaiting."""
        self._events.put_nowait(_QueueCompleteSentinel())
220
221
def _convert_content(content: Content) -> AssistantMessageContent:
    """Convert a Responses output content part into assistant message content.

    `output_text` parts keep their text and every annotation that
    `_convert_annotation` can convert. Any other part is handled as a refusal:
    its `refusal` text is used and the annotation list is empty.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    converted = (_convert_annotation(raw) for raw in content.annotations)
    return AssistantMessageContent(
        text=content.text,
        # Annotations that could not be converted come back as None; drop them.
        annotations=[entry for entry in converted if entry is not None],
    )
237
238
def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """Convert a Responses annotation into a ChatKit `Annotation`.

    Returns None for annotation types that are not handled here and for file
    citations that carry no filename.
    """
    # There is a bug in the OpenAPI client that sometimes parses the annotation
    # delta event into the wrong class, resulting in the annotation being a dict
    # or untyped object instead of a ResponsesAnnotation. Validating it again
    # always yields the typed model.
    adapter = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)
    annotation = adapter.validate_python(raw_annotation)

    if annotation.type == "url_citation":
        return Annotation(
            source=URLSource(
                url=annotation.url,
                title=annotation.title,
            ),
            index=annotation.end_index,
        )

    # The two file citation flavours differ only in which attribute holds the
    # position reported as the annotation index.
    if annotation.type == "file_citation":
        position = annotation.index
    elif annotation.type == "container_file_citation":
        position = annotation.end_index
    else:
        return None

    filename = annotation.filename
    if not filename:
        return None

    return Annotation(
        source=FileSource(filename=filename, title=filename),
        index=position,
    )
276
277
278T1 = TypeVar("T1")
279T2 = TypeVar("T2")
280
281
282async def _merge_generators(
283 a: AsyncIterator[T1],
284 b: AsyncIterator[T2],
285) -> AsyncIterator[T1 | T2]:
286 pending: list[AsyncIterator[T1 | T2]] = [a, b]
287 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
288 asyncio.ensure_future(g.__anext__()): g for g in pending
289 }
290 while len(pending_tasks) > 0:
291 done, _ = await asyncio.wait(
292 pending_tasks.keys(), return_when="FIRST_COMPLETED"
293 )
294 stop = False
295 for d in done:
296 try:
297 result = d.result()
298 yield result
299 dg = pending_tasks[d]
300 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
301 except StopAsyncIteration:
302 stop = True
303 finally:
304 del pending_tasks[d]
305 if stop:
306 for task in pending_tasks.keys():
307 if not task.cancel():
308 try:
309 yield task.result()
310 except asyncio.CancelledError:
311 pass
312 except asyncio.InvalidStateError:
313 pass
314 break
315
316
class _EventWrapper:
    """Wraps an event taken from an `AgentContext` queue.

    `stream_agent_response` uses an `isinstance` check on this type to tell
    queued events apart from the Agents SDK events they are merged with.
    """

    def __init__(self, event: ThreadStreamEvent):
        # The wrapped event, exposed unchanged.
        self.event = event
320
321
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """Async iterator over an event queue that ends at the completion sentinel."""

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        # Set once the sentinel has been seen or the queue has been drained.
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next queued event wrapped in `_EventWrapper`.

        Raises:
            StopAsyncIteration: When the iterator is already completed, or when
                the dequeued item is the completion sentinel.
        """
        if not self.completed:
            item = await self.queue.get()
            if not isinstance(item, _QueueCompleteSentinel):
                return _EventWrapper(item)
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Empty the underlying queue without awaiting and mark this iterator completed.

        This is intended for cleanup paths where we must guarantee no awaits
        occur. All queued items, including any completion sentinel, are
        discarded.
        """
        try:
            while True:
                self.queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        self.completed = True
355
356
class StreamingThoughtTracker(BaseModel):
    """Tracks the reasoning summary currently being streamed as a thought task.

    `stream_agent_response` matches later reasoning summary events against
    `item_id` and `index` to decide whether they continue this thought.
    """

    # `item_id` of the reasoning summary event that started the thought.
    item_id: str
    # `summary_index` of that same event.
    index: int
    # Task whose content is extended as further deltas arrive.
    task: ThoughtTask
361
362
363class ResponseStreamConverter:
364 """Used by `stream_agent_response` to convert streamed Agents SDK output
365 into values used by ChatKit thread items and thread stream events.
366
367 Defines overridable methods for adapting streamed data (such as image
368 generation results and partial updates) into the forms expected by ChatKit.
369 """
370
371 partial_images: int | None = None
372 """
373 The expected number of partial image updates for an image generation result.
374
375 When set, this value is used to normalize partial image indices into a
376 progress value in the range [0, 1]. If unset, all partial image updates are
377 assigned a progress value of 0.
378 """
379
380 def __init__(self, partial_images: int | None = None):
381 """
382 Args:
383 partial_images: The expected number of partial image updates for image
384 generation results, or None if no progress normalization should
385 be performed.
386 """
387 self.partial_images = partial_images
388
389 async def base64_image_to_url(self, base64_image: str) -> str:
390 """
391 Convert a base64-encoded image into a URL.
392
393 This method is used to produce the URL stored on thread items for image
394 generation results.
395 """
396 return f"data:image/png;base64,{base64_image}"
397
398 def partial_image_index_to_progress(self, partial_image_index: int) -> float:
399 """
400 Convert a partial image index into a normalized progress value.
401
402 Args:
403 partial_image_index: The index of the partial image update, starting from 0.
404
405 Returns:
406 A float between 0 and 1 representing progress for the image
407 generation result.
408 """
409 if self.partial_images is None:
410 return 0
411
412 return partial_image_index / self.partial_images
413
414
415_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
416
417
async def stream_agent_response(
    context: AgentContext,
    result: RunResultStreaming,
    *,
    converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Convert a streamed Agents SDK run into ChatKit thread stream events.

    This function consumes a streaming run result and yields `ThreadStreamEvent`
    objects as the run progresses. Events enqueued through the `AgentContext`
    helpers are merged into the same stream.

    Args:
        context: The AgentContext to use for the stream.
        result: The RunResultStreaming to convert.
        converter: Controls how streamed image generation output is converted
            into URLs and progress updates. The default converter embeds the
            generated base64 image in a data URL and assigns a progress value
            of 0 to all partial image updates.

    Returns:
        An async iterator that yields thread stream events representing the run result.

    Raises:
        InputGuardrailTripwireTriggered: Re-raised after a
            `ThreadItemRemovedEvent` has been yielded for every item produced
            so far.
        OutputGuardrailTripwireTriggered: Same handling as above.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    # Ids of every item announced so far, so they can be removed again if a
    # guardrail tripwire fires.
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close `item` and build its done event, clearing it from the context."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remember the latest function call so a client tool call
                    # emitted at the end of the run can reuse its ids.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call":
                    ctx.generated_image_item = GeneratedImageItem(
                        id=ctx.generate_id("message"),
                        thread_id=thread.id,
                        created_at=datetime.now(),
                        image=None,
                    )
                    produced_items.add(ctx.generated_image_item.id)
                    yield ThreadItemAddedEvent(item=ctx.generated_image_item)
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # The thought was already streamed: replace its content
                        # with the final text.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        update = WorkflowTaskAdded(
                            task=task,
                            # The task was just appended, so it sits at the last
                            # position. `tasks.index(task)` compares by equality
                            # and would return an earlier thought with the same
                            # content.
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call" and item.result:
                    if not ctx.generated_image_item:
                        continue

                    url = await converter.base64_image_to_url(item.result)
                    image = GeneratedImage(id=item.id, url=url)

                    ctx.generated_image_item.image = image
                    yield ThreadItemDoneEvent(item=ctx.generated_image_item)

                    ctx.generated_image_item = None
            elif event.type == "response.image_generation_call.partial_image":
                if not ctx.generated_image_item:
                    continue

                url = await converter.base64_image_to_url(event.partial_image_b64)
                progress = converter.partial_image_index_to_progress(
                    event.partial_image_index
                )

                ctx.generated_image_item.image = GeneratedImage(
                    id=event.item_id, url=url
                )

                yield ThreadItemUpdatedEvent(
                    item_id=ctx.generated_image_item.id,
                    update=GeneratedImageUpdated(
                        image=ctx.generated_image_item.image, progress=progress
                    ),
                )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Retract everything announced so far before surfacing the tripwire.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                # Reuse the ids of the last function call seen in the stream,
                # falling back to freshly generated ones.
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
765
766
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """Yield copies of `base_widget` whose value grows with streamed text.

    `base_widget` itself is yielded first. Every output text delta then yields
    a copy whose `value` is all text received so far. After the events are
    exhausted, a final copy is yielded with `streaming` set to False.
    """
    yield base_widget

    accumulated = ""
    async for event in events:
        if event.type != "raw_response_event":
            continue
        data = event.data
        if data.type != "response.output_text.delta":
            continue
        accumulated += data.delta
        yield base_widget.model_copy(update={"value": accumulated})

    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
782
783
784class ThreadItemConverter:
785 """
786 Converts thread items to Agent SDK input items.
787 Widgets, Tasks, and Workflows have default conversions but can be customized.
788 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
789 Other item types are converted automatically.
790 """
791
792 async def attachment_to_message_content(
793 self, attachment: Attachment
794 ) -> ResponseInputContentParam:
795 """
796 Convert an attachment in a user message into a message content part to send to the model.
797 Required when attachments are enabled.
798 """
799 raise NotImplementedError(
800 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
801 )
802
803 async def tag_to_message_content(
804 self, tag: UserMessageTagContent
805 ) -> ResponseInputContentParam:
806 """
807 Convert a tag in a user message into a message content part to send to the model as context.
808 Required when tags are used.
809 """
810 raise NotImplementedError(
811 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
812 )
813
814 async def generated_image_to_input(
815 self, item: GeneratedImageItem
816 ) -> TResponseInputItem | list[TResponseInputItem] | None:
817 """
818 Convert a GeneratedImageItem into input item(s) to send to the model.
819 Required when generated images are enabled.
820 """
821 raise NotImplementedError(
822 "A GeneratedImageItem was included in a UserMessageItem but Converter.generated_image_to_message_content was not implemented"
823 )
824
825 async def hidden_context_to_input(
826 self, item: HiddenContextItem
827 ) -> TResponseInputItem | list[TResponseInputItem] | None:
828 """
829 Convert a HiddenContextItem into input item(s) to send to the model.
830 Required to override when HiddenContextItems with non-string content are used.
831 """
832 if not isinstance(item.content, str):
833 raise NotImplementedError(
834 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
835 )
836
837 text = (
838 "Hidden context for the agent (not shown to the user):\n"
839 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
840 )
841 return Message(
842 type="message",
843 content=[
844 ResponseInputTextParam(
845 type="input_text",
846 text=text,
847 )
848 ],
849 role="user",
850 )
851
852 async def sdk_hidden_context_to_input(
853 self, item: SDKHiddenContextItem
854 ) -> TResponseInputItem | list[TResponseInputItem] | None:
855 """
856 Convert a SDKHiddenContextItem into input item to send to the model.
857 This is used by the ChatKit Python SDK for storing additional context
858 for internal operations.
859 Override if you want to wrap the content in a different format.
860 """
861 text = (
862 "Hidden context for the agent (not shown to the user):\n"
863 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
864 )
865 return Message(
866 type="message",
867 content=[
868 ResponseInputTextParam(
869 type="input_text",
870 text=text,
871 )
872 ],
873 role="user",
874 )
875
876 async def task_to_input(
877 self, item: TaskItem
878 ) -> TResponseInputItem | list[TResponseInputItem] | None:
879 """
880 Convert a TaskItem into input item(s) to send to the model.
881 """
882 if item.task.type != "custom" or (
883 not item.task.title and not item.task.content
884 ):
885 return None
886 title = f"{item.task.title}" if item.task.title else ""
887 content = f"{item.task.content}" if item.task.content else ""
888 task_text = f"{title}: {content}" if title and content else title or content
889 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
890 return Message(
891 type="message",
892 content=[
893 ResponseInputTextParam(
894 type="input_text",
895 text=text,
896 )
897 ],
898 role="user",
899 )
900
901 async def workflow_to_input(
902 self, item: WorkflowItem
903 ) -> TResponseInputItem | list[TResponseInputItem] | None:
904 """
905 Convert a TaskItem into input item(s) to send to the model.
906 Returns WorkflowItem.response_items by default.
907 """
908 messages = []
909 for task in item.workflow.tasks:
910 if task.type != "custom" or (not task.title and not task.content):
911 continue
912
913 title = f"{task.title}" if task.title else ""
914 content = f"{task.content}" if task.content else ""
915 task_text = f"{title}: {content}" if title and content else title or content
916 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
917 messages.append(
918 Message(
919 type="message",
920 content=[
921 ResponseInputTextParam(
922 type="input_text",
923 text=text,
924 )
925 ],
926 role="user",
927 )
928 )
929 return messages
930
931 async def widget_to_input(
932 self, item: WidgetItem
933 ) -> TResponseInputItem | list[TResponseInputItem] | None:
934 """
935 Convert a WidgetItem into input item(s) to send to the model.
936 By default, WidgetItems converted to a text description with a JSON representation of the widget.
937 """
938 return Message(
939 type="message",
940 content=[
941 ResponseInputTextParam(
942 type="input_text",
943 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
944 + item.widget.model_dump_json(
945 exclude_unset=True, exclude_none=True
946 ),
947 )
948 ],
949 role="user",
950 )
951
952 async def user_message_to_input(
953 self, item: UserMessageItem, is_last_message: bool = True
954 ) -> TResponseInputItem | list[TResponseInputItem] | None:
955 # Build the user text exactly as typed, rendering tags as @key
956 message_text_parts: list[str] = []
957 # Track tags separately to add system context
958 raw_tags: list[UserMessageTagContent] = []
959
960 for part in item.content:
961 if isinstance(part, UserMessageTextContent):
962 message_text_parts.append(part.text)
963 elif isinstance(part, UserMessageTagContent):
964 message_text_parts.append(f"@{part.text}")
965 raw_tags.append(part)
966 else:
967 assert_never(part)
968
969 user_text_item = Message(
970 role="user",
971 type="message",
972 content=[
973 ResponseInputTextParam(
974 type="input_text", text="".join(message_text_parts)
975 ),
976 *[
977 await self.attachment_to_message_content(a)
978 for a in item.attachments
979 ],
980 ],
981 )
982
983 # Build system items (prepend later): quoted text and @-mention context
984 context_items: list[TResponseInputItem] = []
985
986 if item.quoted_text and is_last_message:
987 context_items.append(
988 Message(
989 role="user",
990 type="message",
991 content=[
992 ResponseInputTextParam(
993 type="input_text",
994 text=f"The user is referring to this in particular: \n{item.quoted_text}",
995 )
996 ],
997 )
998 )
999
1000 # Dedupe tags (preserve order) and resolve to message content
1001 if raw_tags:
1002 seen, uniq_tags = set(), []
1003 for t in raw_tags:
1004 if t.text not in seen:
1005 seen.add(t.text)
1006 uniq_tags.append(t)
1007
1008 tag_content: ResponseInputMessageContentListParam = [
1009 # should return summarized text items
1010 await self.tag_to_message_content(tag)
1011 for tag in uniq_tags
1012 ]
1013
1014 if tag_content:
1015 context_items.append(
1016 Message(
1017 role="user",
1018 type="message",
1019 content=[
1020 ResponseInputTextParam(
1021 type="input_text",
1022 text=cleandoc("""
1023 # User-provided context for @-mentions
1024 - When referencing resolved entities, use their canonical names **without** '@'.
1025 - The '@' form appears only in user text and should not be echoed.
1026 """).strip(),
1027 ),
1028 *tag_content,
1029 ],
1030 )
1031 )
1032
1033 return [user_text_item, *context_items]
1034
1035 async def assistant_message_to_input(
1036 self, item: AssistantMessageItem
1037 ) -> TResponseInputItem | list[TResponseInputItem] | None:
1038 return EasyInputMessageParam(
1039 type="message",
1040 content=[
1041 # content param doesn't support the assistant message content types
1042 cast(
1043 ResponseInputContentParam,
1044 ResponseOutputText(
1045 type="output_text",
1046 text=c.text,
1047 annotations=[], # TODO: these should be sent back as well
1048 ).model_dump(),
1049 )
1050 for c in item.content
1051 ],
1052 role="assistant",
1053 )
1054
1055 async def client_tool_call_to_input(
1056 self, item: ClientToolCallItem
1057 ) -> TResponseInputItem | list[TResponseInputItem] | None:
1058 if item.status == "pending":
1059 # Filter out pending tool calls - they cannot be sent to the model
1060 return None
1061
1062 return [
1063 ResponseFunctionToolCallParam(
1064 type="function_call",
1065 call_id=item.call_id,
1066 name=item.name,
1067 arguments=json.dumps(item.arguments),
1068 ),
1069 FunctionCallOutput(
1070 type="function_call_output",
1071 call_id=item.call_id,
1072 output=json.dumps(item.output),
1073 ),
1074 ]
1075
1076 async def end_of_turn_to_input(
1077 self, item: EndOfTurnItem
1078 ) -> TResponseInputItem | list[TResponseInputItem] | None:
1079 # Only used for UI hints - you shouldn't need to override this
1080 return None
1081
1082 async def _thread_item_to_input_item(
1083 self,
1084 item: ThreadItem,
1085 is_last_message: bool = True,
1086 ) -> list[TResponseInputItem]:
1087 match item:
1088 case UserMessageItem():
1089 out = await self.user_message_to_input(item, is_last_message) or []
1090 return out if isinstance(out, list) else [out]
1091 case AssistantMessageItem():
1092 out = await self.assistant_message_to_input(item) or []
1093 return out if isinstance(out, list) else [out]
1094 case ClientToolCallItem():
1095 out = await self.client_tool_call_to_input(item) or []
1096 return out if isinstance(out, list) else [out]
1097 case EndOfTurnItem():
1098 out = await self.end_of_turn_to_input(item) or []
1099 return out if isinstance(out, list) else [out]
1100 case WidgetItem():
1101 out = await self.widget_to_input(item) or []
1102 return out if isinstance(out, list) else [out]
1103 case WorkflowItem():
1104 out = await self.workflow_to_input(item) or []
1105 return out if isinstance(out, list) else [out]
1106 case TaskItem():
1107 out = await self.task_to_input(item) or []
1108 return out if isinstance(out, list) else [out]
1109 case HiddenContextItem():
1110 out = await self.hidden_context_to_input(item) or []
1111 return out if isinstance(out, list) else [out]
1112 case SDKHiddenContextItem():
1113 out = await self.sdk_hidden_context_to_input(item) or []
1114 return out if isinstance(out, list) else [out]
1115 case GeneratedImageItem():
1116 out = await self.generated_image_to_input(item) or []
1117 return out if isinstance(out, list) else [out]
1118 case _:
1119 assert_never(item)
1120
1121 async def to_agent_input(
1122 self,
1123 thread_items: Sequence[ThreadItem] | ThreadItem,
1124 ) -> list[TResponseInputItem]:
1125 if isinstance(thread_items, Sequence):
1126 # shallow copy in case caller mutates the list while we're iterating
1127 thread_items = thread_items[:]
1128 else:
1129 thread_items = [thread_items]
1130 output: list[TResponseInputItem] = []
1131 for item in thread_items:
1132 output.extend(
1133 await self._thread_item_to_input_item(
1134 item,
1135 is_last_message=item is thread_items[-1],
1136 )
1137 )
1138 return output
1139
1140
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """Convert thread items with the module's default `ThreadItemConverter`.

    This is a plain function that returns whatever the converter's async
    `to_agent_input` returns, so the result must be awaited.
    """
    converter = _DEFAULT_CONVERTER
    return converter.to_agent_input(thread_items)
1147