openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
bd7af0dda574f4d124e8682a2d1a78faa1076240

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1181lines · modecode

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputImageParam,
29 ResponseInputMessageContentListParam,
30 ResponseInputTextParam,
31 ResponseOutputText,
32)
33from openai.types.responses.response_input_item_param import (
34 FunctionCallOutput,
35 Message,
36)
37from openai.types.responses.response_output_message import Content
38from openai.types.responses.response_output_text import (
39 Annotation as ResponsesAnnotation,
40)
41from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
42from typing_extensions import assert_never
43
44from .server import stream_widget
45from .store import Store, StoreItemType
46from .types import (
47 Annotation,
48 AssistantMessageContent,
49 AssistantMessageContentPartAdded,
50 AssistantMessageContentPartAnnotationAdded,
51 AssistantMessageContentPartDone,
52 AssistantMessageContentPartTextDelta,
53 AssistantMessageItem,
54 Attachment,
55 ClientToolCallItem,
56 DurationSummary,
57 EndOfTurnItem,
58 FileSource,
59 GeneratedImage,
60 GeneratedImageItem,
61 GeneratedImageUpdated,
62 HiddenContextItem,
63 SDKHiddenContextItem,
64 Task,
65 TaskItem,
66 ThoughtTask,
67 ThreadItem,
68 ThreadItemAddedEvent,
69 ThreadItemDoneEvent,
70 ThreadItemRemovedEvent,
71 ThreadItemUpdatedEvent,
72 ThreadMetadata,
73 ThreadStreamEvent,
74 URLSource,
75 UserMessageItem,
76 UserMessageTagContent,
77 UserMessageTextContent,
78 WidgetItem,
79 Workflow,
80 WorkflowItem,
81 WorkflowSummary,
82 WorkflowTaskAdded,
83 WorkflowTaskUpdated,
84)
85from .widgets import Markdown, Text, WidgetRoot
86
87
class ClientToolCall(BaseModel):
    """
    Value a tool method returns to signal that a tool call should run client-side.

    `name` identifies the client tool and `arguments` holds the arguments for
    that call.
    """

    name: str
    arguments: dict[str, Any]
95
96
class _QueueCompleteSentinel:
    """Marker object placed on an event queue to signal that no further events follow."""
98
99
TContext = TypeVar("TContext")


class AgentContext(BaseModel, Generic[TContext]):
    """
    State shared between agent tool code and `stream_agent_response` for one run.

    Holds the thread, the store and the caller's request context, and exposes
    helpers that enqueue `ThreadStreamEvent`s (widgets, workflow items and
    workflow tasks) on an internal queue.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    client_tool_call: ClientToolCall | None = None
    # Workflow item currently being built/streamed, if any.
    workflow_item: WorkflowItem | None = None
    generated_image_item: GeneratedImageItem | None = None
    # Events queued by the helper methods below; `_complete` appends a sentinel.
    # NOTE(review): the default is a single Queue created at import time --
    # confirm pydantic copies private-attribute defaults per instance.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """
        Generate a new store-backed id for the given item type.

        Args:
            type: The kind of id to generate; "thread" yields a thread id.
            thread: Thread the item id belongs to; defaults to `self.thread`.

        Returns:
            The id produced by the store.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Stream a widget into the thread by enqueueing widget events."""
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """
        Finalize the active workflow item, optionally attaching a summary.

        Does nothing when no workflow is active. When neither `summary` nor an
        existing summary is present, a `DurationSummary` measured from the
        item's `created_at` is used.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """
        Begin a new workflow item.

        The "added" event is streamed immediately for reasoning workflows and
        for workflows that already contain tasks; otherwise it is deferred
        until the first task is added.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """
        Replace the task at `task_index` and stream the update.

        Raises:
            ValueError: If no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """
        Append a workflow task and stream the appropriate event.

        A "custom" workflow item is created on demand when none is active. The
        first task of a non-reasoning workflow is announced with the item's
        "added" event; every other task is streamed as a task-added update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it is the last element. list.index()
        # would return the first *equal* task instead, which reports the wrong
        # position when an identical task is already part of the workflow.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue a ThreadStreamEvent for downstream processing."""
        await self._events.put(event)

    def _complete(self):
        """Enqueue the completion sentinel without awaiting."""
        self._events.put_nowait(_QueueCompleteSentinel())
221
222
def _convert_content(content: Content) -> AssistantMessageContent:
    """
    Translate a Responses output content part into ChatKit assistant content.

    "output_text" parts keep their text plus every annotation that
    `_convert_annotation` can represent; any other part is read as a refusal
    and carries no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    converted = (_convert_annotation(raw) for raw in content.annotations)
    kept = [annotation for annotation in converted if annotation is not None]
    return AssistantMessageContent(
        text=content.text,
        annotations=kept,
    )
238
239
def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """
    Convert a Responses annotation into a ChatKit `Annotation`.

    Returns None for annotation types other than file, URL and container-file
    citations, and for file citations that carry no filename.
    """
    # The OpenAI client sometimes parses the annotation delta event into the
    # wrong class, leaving a dict or untyped object instead of a
    # ResponsesAnnotation, so the value is re-validated before use.
    adapter = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)
    annotation = adapter.validate_python(raw_annotation)

    if annotation.type == "url_citation":
        return Annotation(
            source=URLSource(
                url=annotation.url,
                title=annotation.title,
            ),
            index=annotation.end_index,
        )

    if (
        annotation.type == "file_citation"
        or annotation.type == "container_file_citation"
    ):
        filename = annotation.filename
        if not filename:
            return None

        # File citations expose `index`; container file citations `end_index`.
        if annotation.type == "file_citation":
            position = annotation.index
        else:
            position = annotation.end_index
        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=position,
        )

    return None
277
278
279T1 = TypeVar("T1")
280T2 = TypeVar("T2")
281
282
283async def _merge_generators(
284 a: AsyncIterator[T1],
285 b: AsyncIterator[T2],
286) -> AsyncIterator[T1 | T2]:
287 pending: list[AsyncIterator[T1 | T2]] = [a, b]
288 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
289 asyncio.ensure_future(g.__anext__()): g for g in pending
290 }
291 while len(pending_tasks) > 0:
292 done, _ = await asyncio.wait(
293 pending_tasks.keys(), return_when="FIRST_COMPLETED"
294 )
295 stop = False
296 for d in done:
297 try:
298 result = d.result()
299 yield result
300 dg = pending_tasks[d]
301 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
302 except StopAsyncIteration:
303 stop = True
304 finally:
305 del pending_tasks[d]
306 if stop:
307 for task in pending_tasks.keys():
308 if not task.cancel():
309 try:
310 yield task.result()
311 except asyncio.CancelledError:
312 pass
313 except asyncio.InvalidStateError:
314 pass
315 break
316
317
class _EventWrapper:
    """Holds a ThreadStreamEvent taken from the context queue so it can be recognised with `isinstance`."""

    def __init__(self, event: ThreadStreamEvent):
        self.event = event
321
322
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """
    Async iterator over an event queue that stops at the completion sentinel.

    Each queued event is returned wrapped in an `_EventWrapper`.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            queued = await self.queue.get()
            if not isinstance(queued, _QueueCompleteSentinel):
                return _EventWrapper(queued)
            # Sentinel seen: remember it so later calls stop without waiting.
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Discard everything currently queued, without awaiting, and mark this iterator completed.

        Meant for cleanup paths that must not await. Every queued item is
        dropped, including a completion sentinel if one is present.
        """
        try:
            while True:
                self.queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        self.completed = True
356
357
class StreamingThoughtTracker(BaseModel):
    """
    Bookkeeping for a reasoning thought that is being streamed incrementally.

    Records the Responses item id and summary index the thought belongs to,
    together with the `ThoughtTask` whose content is being built.
    """

    item_id: str
    index: int
    task: ThoughtTask
362
363
364class ResponseStreamConverter:
365 """Used by `stream_agent_response` to convert streamed Agents SDK output
366 into values used by ChatKit thread items and thread stream events.
367
368 Defines overridable methods for adapting streamed data (such as image
369 generation results and partial updates) into the forms expected by ChatKit.
370 """
371
372 partial_images: int | None = None
373 """
374 The expected number of partial image updates for an image generation result.
375
376 When set, this value is used to normalize partial image indices into a
377 progress value in the range [0, 1]. If unset, all partial image updates are
378 assigned a progress value of 0.
379 """
380
381 def __init__(self, *, partial_images: int | None = None):
382 """
383 Args:
384 partial_images: The expected number of partial image updates for image
385 generation results, or None if no progress normalization should
386 be performed.
387 """
388 self.partial_images = partial_images
389
390 async def base64_image_to_url(
391 self,
392 image_id: str,
393 base64_image: str,
394 partial_image_index: int | None = None,
395 ) -> str:
396 """
397 Convert a base64-encoded image into a URL.
398
399 This method is used to produce the URL stored on thread items for image
400 generation results.
401
402 Args:
403 image_id: The ID of the image generation call. This stays stable across partial image updates.
404 base64_image: The base64-encoded image.
405 partial_image_index: The index of the partial image update, starting from 0.
406
407 Returns:
408 A URL string.
409 """
410 return f"data:image/png;base64,{base64_image}"
411
412 def partial_image_index_to_progress(self, partial_image_index: int) -> float:
413 """
414 Convert a partial image index into a normalized progress value.
415
416 Args:
417 partial_image_index: The index of the partial image update, starting from 0.
418
419 Returns:
420 A float between 0 and 1 representing progress for the image
421 generation result.
422 """
423 if self.partial_images is None or self.partial_images <= 0:
424 return 0.0
425
426 return min(1.0, partial_image_index / self.partial_images)
427
428
429_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
430
431
async def stream_agent_response(
    context: AgentContext,
    result: RunResultStreaming,
    *,
    converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Convert a streamed Agents SDK run into ChatKit thread stream events.

    This function consumes a streaming run result and yields `ThreadStreamEvent`
    objects as the run progresses. Events queued on the context by its helper
    methods are merged into the same stream.

    Args:
        context: The AgentContext to use for the stream.
        result: The RunResultStreaming to convert.
        converter: Defines overridable methods for adapting streamed data (such as image
            generation results and partial updates) into the forms expected by ChatKit.

    Returns:
        An async iterator that yields thread stream events representing the run result.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            Re-raised after a removal event has been yielded for every item
            produced so far and the context queue has been drained.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close `item` (duration summary if none is set, collapsed) and return its done event."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remembered for the ClientToolCallItem emitted at the end.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call":
                    ctx.generated_image_item = GeneratedImageItem(
                        id=ctx.generate_id("message"),
                        thread_id=thread.id,
                        created_at=datetime.now(),
                        image=None,
                    )
                    produced_items.add(ctx.generated_image_item.id)
                    yield ThreadItemAddedEvent(item=ctx.generated_image_item)
            elif event.type == "response.image_generation_call.partial_image":
                if not ctx.generated_image_item:
                    continue

                url = await converter.base64_image_to_url(
                    image_id=event.item_id,
                    base64_image=event.partial_image_b64,
                    partial_image_index=event.partial_image_index,
                )
                progress = converter.partial_image_index_to_progress(
                    event.partial_image_index
                )

                ctx.generated_image_item.image = GeneratedImage(
                    id=event.item_id, url=url
                )

                yield ThreadItemUpdatedEvent(
                    item_id=ctx.generated_image_item.id,
                    update=GeneratedImageUpdated(
                        image=ctx.generated_image_item.image, progress=progress
                    ),
                )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # Finalize the thought that was streamed incrementally.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        update = WorkflowTaskAdded(
                            task=task,
                            # The task was just appended, so it is the last
                            # element; list.index() would return the first
                            # *equal* thought (e.g. same text) instead.
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call" and item.result:
                    if not ctx.generated_image_item:
                        continue

                    url = await converter.base64_image_to_url(
                        image_id=item.id,
                        base64_image=item.result,
                    )
                    image = GeneratedImage(id=item.id, url=url)

                    ctx.generated_image_item.image = image
                    yield ThreadItemDoneEvent(item=ctx.generated_image_item)

                    ctx.generated_image_item = None

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Remove everything produced during this run before re-raising.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        # Reuse the ids of the last function call seen in the run when
        # available; otherwise ask the store for fresh ones.
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
784
785
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """
    Yield successive copies of `base_widget` as streamed text accumulates.

    The unmodified widget is yielded first. Every output-text delta then
    yields a copy whose `value` is the text received so far, and a final copy
    is yielded with `streaming` set to False once `events` is exhausted.
    """
    accumulated = ""
    yield base_widget
    async for event in events:
        if event.type != "raw_response_event":
            continue
        payload = event.data
        if payload.type != "response.output_text.delta":
            continue
        accumulated += payload.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
801
802
803class ThreadItemConverter:
804 """
805 Converts thread items to Agent SDK input items.
806 Widgets, Tasks, and Workflows have default conversions but can be customized.
807 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
808 Other item types are converted automatically.
809 """
810
811 async def attachment_to_message_content(
812 self, attachment: Attachment
813 ) -> ResponseInputContentParam:
814 """
815 Convert an attachment in a user message into a message content part to send to the model.
816 Required when attachments are enabled.
817 """
818 raise NotImplementedError(
819 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
820 )
821
822 async def tag_to_message_content(
823 self, tag: UserMessageTagContent
824 ) -> ResponseInputContentParam:
825 """
826 Convert a tag in a user message into a message content part to send to the model as context.
827 Required when tags are used.
828 """
829 raise NotImplementedError(
830 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
831 )
832
833 async def generated_image_to_input(
834 self, item: GeneratedImageItem
835 ) -> TResponseInputItem | list[TResponseInputItem] | None:
836 """
837 Convert a GeneratedImageItem into input item(s) to send to the model.
838 Override this method to customize the conversion of generated images, such as when your
839 generated image url is not publicly reachable.
840 """
841 if not item.image:
842 return None
843
844 return Message(
845 type="message",
846 content=[
847 ResponseInputTextParam(
848 type="input_text",
849 text="The following image was generated by the agent.",
850 ),
851 ResponseInputImageParam(
852 type="input_image",
853 detail="auto",
854 image_url=item.image.url,
855 ),
856 ],
857 role="user",
858 )
859
860 async def hidden_context_to_input(
861 self, item: HiddenContextItem
862 ) -> TResponseInputItem | list[TResponseInputItem] | None:
863 """
864 Convert a HiddenContextItem into input item(s) to send to the model.
865 Required to override when HiddenContextItems with non-string content are used.
866 """
867 if not isinstance(item.content, str):
868 raise NotImplementedError(
869 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
870 )
871
872 text = (
873 "Hidden context for the agent (not shown to the user):\n"
874 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
875 )
876 return Message(
877 type="message",
878 content=[
879 ResponseInputTextParam(
880 type="input_text",
881 text=text,
882 )
883 ],
884 role="user",
885 )
886
887 async def sdk_hidden_context_to_input(
888 self, item: SDKHiddenContextItem
889 ) -> TResponseInputItem | list[TResponseInputItem] | None:
890 """
891 Convert a SDKHiddenContextItem into input item to send to the model.
892 This is used by the ChatKit Python SDK for storing additional context
893 for internal operations.
894 Override if you want to wrap the content in a different format.
895 """
896 text = (
897 "Hidden context for the agent (not shown to the user):\n"
898 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
899 )
900 return Message(
901 type="message",
902 content=[
903 ResponseInputTextParam(
904 type="input_text",
905 text=text,
906 )
907 ],
908 role="user",
909 )
910
911 async def task_to_input(
912 self, item: TaskItem
913 ) -> TResponseInputItem | list[TResponseInputItem] | None:
914 """
915 Convert a TaskItem into input item(s) to send to the model.
916 """
917 if item.task.type != "custom" or (
918 not item.task.title and not item.task.content
919 ):
920 return None
921 title = f"{item.task.title}" if item.task.title else ""
922 content = f"{item.task.content}" if item.task.content else ""
923 task_text = f"{title}: {content}" if title and content else title or content
924 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
925 return Message(
926 type="message",
927 content=[
928 ResponseInputTextParam(
929 type="input_text",
930 text=text,
931 )
932 ],
933 role="user",
934 )
935
936 async def workflow_to_input(
937 self, item: WorkflowItem
938 ) -> TResponseInputItem | list[TResponseInputItem] | None:
939 """
940 Convert a TaskItem into input item(s) to send to the model.
941 Returns WorkflowItem.response_items by default.
942 """
943 messages = []
944 for task in item.workflow.tasks:
945 if task.type != "custom" or (not task.title and not task.content):
946 continue
947
948 title = f"{task.title}" if task.title else ""
949 content = f"{task.content}" if task.content else ""
950 task_text = f"{title}: {content}" if title and content else title or content
951 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
952 messages.append(
953 Message(
954 type="message",
955 content=[
956 ResponseInputTextParam(
957 type="input_text",
958 text=text,
959 )
960 ],
961 role="user",
962 )
963 )
964 return messages
965
966 async def widget_to_input(
967 self, item: WidgetItem
968 ) -> TResponseInputItem | list[TResponseInputItem] | None:
969 """
970 Convert a WidgetItem into input item(s) to send to the model.
971 By default, WidgetItems converted to a text description with a JSON representation of the widget.
972 """
973 return Message(
974 type="message",
975 content=[
976 ResponseInputTextParam(
977 type="input_text",
978 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
979 + item.widget.model_dump_json(
980 exclude_unset=True, exclude_none=True
981 ),
982 )
983 ],
984 role="user",
985 )
986
async def user_message_to_input(
    self, item: UserMessageItem, is_last_message: bool = True
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a user message into the input messages sent to the model.

    The first returned message holds the user's text, with every tag part rendered as
    "@" followed by the tag text, plus one content entry per attachment. It is followed
    by optional user-role context messages: the quoted text (only when is_last_message
    is true) and the resolved content of the distinct tags.
    """
    text_fragments: list[str] = []
    mentions: list[UserMessageTagContent] = []
    for part in item.content:
        if isinstance(part, UserMessageTextContent):
            text_fragments.append(part.text)
        elif isinstance(part, UserMessageTagContent):
            # Keep the mention visible in the text as typed; the tag itself is
            # resolved into separate context further down.
            text_fragments.append(f"@{part.text}")
            mentions.append(part)
        else:
            assert_never(part)

    typed_text = ResponseInputTextParam(
        type="input_text", text="".join(text_fragments)
    )
    attachment_parts = [
        await self.attachment_to_message_content(attachment)
        for attachment in item.attachments
    ]
    user_message = Message(
        role="user",
        type="message",
        content=[typed_text, *attachment_parts],
    )

    # Context messages are appended after the user's own message.
    extras: list[TResponseInputItem] = []

    if item.quoted_text and is_last_message:
        quote = ResponseInputTextParam(
            type="input_text",
            text=f"The user is referring to this in particular: \n{item.quoted_text}",
        )
        extras.append(Message(role="user", type="message", content=[quote]))

    # A dict keeps insertion order and setdefault keeps the first tag seen for
    # each text, so repeated mentions are resolved only once, in typed order.
    first_by_text = {}
    for tag in mentions:
        first_by_text.setdefault(tag.text, tag)

    resolved: ResponseInputMessageContentListParam = [
        # should return summarized text items
        await self.tag_to_message_content(tag)
        for tag in first_by_text.values()
    ]
    if resolved:
        guidance = ResponseInputTextParam(
            type="input_text",
            text=cleandoc("""
                # User-provided context for @-mentions
                - When referencing resolved entities, use their canonical names **without** '@'.
                - The '@' form appears only in user text and should not be echoed.
            """).strip(),
        )
        extras.append(
            Message(
                role="user",
                type="message",
                content=[guidance, *resolved],
            )
        )

    return [user_message, *extras]
1069
async def assistant_message_to_input(
    self, item: AssistantMessageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an assistant message into a single assistant-role input message.

    Each content part contributes its text as an "output_text" entry. Annotations
    are not carried over; every entry is built with an empty annotations list.
    """
    output_parts = []
    for part in item.content:
        dumped = ResponseOutputText(
            type="output_text",
            text=part.text,
            annotations=[],  # TODO: these should be sent back as well
        ).model_dump()
        # content param doesn't support the assistant message content types,
        # so the dumped dict is cast to the accepted input content type.
        output_parts.append(cast(ResponseInputContentParam, dumped))

    return EasyInputMessageParam(
        type="message",
        content=output_parts,
        role="assistant",
    )
1089
async def client_tool_call_to_input(
    self, item: ClientToolCallItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a client tool call into a function call / function call output pair.

    Returns None when the call's status is "pending". Otherwise returns two items
    sharing the same call_id, with the arguments and the output each serialized
    through json.dumps.
    """
    if item.status == "pending":
        # Filter out pending tool calls - they cannot be sent to the model
        return None

    invocation = ResponseFunctionToolCallParam(
        type="function_call",
        call_id=item.call_id,
        name=item.name,
        arguments=json.dumps(item.arguments),
    )
    result = FunctionCallOutput(
        type="function_call_output",
        call_id=item.call_id,
        output=json.dumps(item.output),
    )
    return [invocation, result]
1110
async def end_of_turn_to_input(
    self, item: EndOfTurnItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an end-of-turn item; always yields no model input.

    Only used for UI hints - you shouldn't need to override this.
    """
    return None
1116
1117 async def _thread_item_to_input_item(
1118 self,
1119 item: ThreadItem,
1120 is_last_message: bool = True,
1121 ) -> list[TResponseInputItem]:
1122 match item:
1123 case UserMessageItem():
1124 out = await self.user_message_to_input(item, is_last_message) or []
1125 return out if isinstance(out, list) else [out]
1126 case AssistantMessageItem():
1127 out = await self.assistant_message_to_input(item) or []
1128 return out if isinstance(out, list) else [out]
1129 case ClientToolCallItem():
1130 out = await self.client_tool_call_to_input(item) or []
1131 return out if isinstance(out, list) else [out]
1132 case EndOfTurnItem():
1133 out = await self.end_of_turn_to_input(item) or []
1134 return out if isinstance(out, list) else [out]
1135 case WidgetItem():
1136 out = await self.widget_to_input(item) or []
1137 return out if isinstance(out, list) else [out]
1138 case WorkflowItem():
1139 out = await self.workflow_to_input(item) or []
1140 return out if isinstance(out, list) else [out]
1141 case TaskItem():
1142 out = await self.task_to_input(item) or []
1143 return out if isinstance(out, list) else [out]
1144 case HiddenContextItem():
1145 out = await self.hidden_context_to_input(item) or []
1146 return out if isinstance(out, list) else [out]
1147 case SDKHiddenContextItem():
1148 out = await self.sdk_hidden_context_to_input(item) or []
1149 return out if isinstance(out, list) else [out]
1150 case GeneratedImageItem():
1151 out = await self.generated_image_to_input(item) or []
1152 return out if isinstance(out, list) else [out]
1153 case _:
1154 assert_never(item)
1155
async def to_agent_input(
    self,
    thread_items: Sequence[ThreadItem] | ThreadItem,
) -> list[TResponseInputItem]:
    """
    Convert one thread item, or a sequence of them, into a flat list of input items.

    Items are converted in order through _thread_item_to_input_item and the
    results are concatenated. Only the item in the final position is converted
    with is_last_message=True; a lone (non-sequence) item counts as the last one.
    """
    if isinstance(thread_items, Sequence):
        # Snapshot in case the caller mutates the sequence while we're iterating.
        items = list(thread_items)
    else:
        items = [thread_items]

    final_index = len(items) - 1
    output: list[TResponseInputItem] = []
    for index, item in enumerate(items):
        # Compare positions rather than identities: an object that appears both
        # earlier and in the last slot must only be flagged on its final occurrence.
        output.extend(
            await self._thread_item_to_input_item(
                item,
                is_last_message=index == final_index,
            )
        )
    return output
1174
1175
# Module-level converter instance, created once at import time and reused by
# simple_to_agent_input so callers do not have to construct their own.
_DEFAULT_CONVERTER = ThreadItemConverter()
1177
1178
def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """
    Helper that converts thread items using the default ThreadItemConverter.

    The call to to_agent_input is not awaited here; its result is returned
    directly for the caller to await.
    """
    pending_conversion = _DEFAULT_CONVERTER.to_agent_input(thread_items)
    return pending_conversion
1182