openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
4e1156dd81f931ed2f455197737892dd59437ab2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1212 lines · mode: code

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputImageParam,
29 ResponseInputMessageContentListParam,
30 ResponseInputTextParam,
31 ResponseOutputText,
32)
33from openai.types.responses.response_input_item_param import (
34 FunctionCallOutput,
35 Message,
36)
37from openai.types.responses.response_output_message import Content
38from openai.types.responses.response_output_text import (
39 Annotation as ResponsesAnnotation,
40)
41from openai.types.responses.response_output_text import (
42 AnnotationContainerFileCitation,
43 AnnotationFileCitation,
44 AnnotationURLCitation,
45)
46from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
47from typing_extensions import assert_never
48
49from .server import stream_widget
50from .store import Store, StoreItemType
51from .types import (
52 Annotation,
53 AssistantMessageContent,
54 AssistantMessageContentPartAdded,
55 AssistantMessageContentPartAnnotationAdded,
56 AssistantMessageContentPartDone,
57 AssistantMessageContentPartTextDelta,
58 AssistantMessageItem,
59 Attachment,
60 ClientToolCallItem,
61 DurationSummary,
62 EndOfTurnItem,
63 FileSource,
64 GeneratedImage,
65 GeneratedImageItem,
66 GeneratedImageUpdated,
67 HiddenContextItem,
68 SDKHiddenContextItem,
69 Task,
70 TaskItem,
71 ThoughtTask,
72 ThreadItem,
73 ThreadItemAddedEvent,
74 ThreadItemDoneEvent,
75 ThreadItemRemovedEvent,
76 ThreadItemUpdatedEvent,
77 ThreadMetadata,
78 ThreadStreamEvent,
79 URLSource,
80 UserMessageItem,
81 UserMessageTagContent,
82 UserMessageTextContent,
83 WidgetItem,
84 Workflow,
85 WorkflowItem,
86 WorkflowSummary,
87 WorkflowTaskAdded,
88 WorkflowTaskUpdated,
89)
90from .widgets import Markdown, Text, WidgetRoot
91
92
class ClientToolCall(BaseModel):
    """
    Value a tool method returns to ask for a tool call on the client side.

    Attributes:
        name: Name of the client-side tool to call.
        arguments: Arguments for that tool, copied onto the resulting
            ClientToolCallItem.
    """

    name: str
    arguments: dict[str, Any]
100
101
102class _QueueCompleteSentinel: ...
103
104
# Type of the application-specific request context carried by AgentContext.
TContext = TypeVar("TContext")
106
107
class AgentContext(BaseModel, Generic[TContext]):
    """
    Per-run state shared between ChatKit and Agents SDK tool code.

    Carries the thread being responded to, the store used to mint ids, and the
    caller's request context, plus any in-progress workflow or generated-image
    item. The helper methods enqueue ThreadStreamEvents on a private queue that
    `stream_agent_response` merges into the events it yields.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    # The store is not a pydantic type, so skip validating it.
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    # When set, stream_agent_response emits a ClientToolCallItem for it once
    # the run's events have been streamed.
    client_tool_call: ClientToolCall | None = None
    # Workflow item currently being streamed, if any.
    workflow_item: WorkflowItem | None = None
    # Image-generation item currently being streamed, if any.
    generated_image_item: GeneratedImageItem | None = None
    # Events enqueued by the helper methods below.
    # NOTE(review): pydantic treats underscore-prefixed fields as private
    # attributes and copies their default per instance — confirm on the pinned
    # pydantic version that contexts never share this queue.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """
        Generate a new store-backed id for the given item type.

        Args:
            type: Kind of item to generate an id for; "thread" asks the store
                for a thread id instead of an item id.
            thread: Thread the item belongs to; defaults to this context's thread.

        Returns:
            The id produced by the store.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """
        Stream a widget into the thread by enqueueing widget events.

        Args:
            widget: A complete widget, or an async generator of successive
                widget states.
            copy_text: Optional text passed through to `stream_widget`.
        """
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """
        Finalize the active workflow item, optionally attaching a summary.

        Does nothing when no workflow is active. Otherwise the workflow's
        summary is set to `summary` when given. If no summary exists yet, a
        DurationSummary of the whole seconds elapsed since the item's
        `created_at` is used instead. A ThreadItemDoneEvent is then streamed
        and the active workflow is cleared.

        Args:
            summary: Summary to attach; None keeps an existing summary.
            expanded: Whether the finished workflow should be shown expanded.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """
        Begin streaming a new workflow item.

        The new item becomes the active workflow. For non-reasoning workflows
        without tasks, the ThreadItemAddedEvent is deferred until the first
        task is added (see `add_workflow_task`).
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """
        Replace the workflow task at `task_index` and stream the update.

        Raises:
            ValueError: If no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """
        Append a workflow task and stream the appropriate event.

        If no workflow is active, a new "custom" workflow is created. The first
        task of a non-reasoning workflow streams the (deferred)
        ThreadItemAddedEvent for the whole workflow. Any other task streams a
        WorkflowTaskAdded update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        # The task was just appended, so it is last. Do not use
                        # tasks.index(task): models compare by value, and an
                        # earlier equal task would be matched instead.
                        task_index=len(workflow.tasks) - 1,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue a ThreadStreamEvent for downstream processing."""
        await self._events.put(event)

    def _complete(self):
        """Signal (without awaiting) that no further events will be enqueued."""
        self._events.put_nowait(_QueueCompleteSentinel())
226
227
# Item types of the two async iterators combined by _merge_generators.
T1 = TypeVar("T1")
T2 = TypeVar("T2")
230
231
async def _merge_generators(
    a: AsyncIterator[T1],
    b: AsyncIterator[T2],
) -> AsyncIterator[T1 | T2]:
    """
    Interleave two async iterators, yielding items in the order they arrive.

    Merging stops as soon as either iterator is exhausted. The other
    iterator's in-flight ``__anext__`` call is then cancelled; if it had
    already produced an item, that item is yielded before stopping.

    Any ``__anext__`` task still outstanding when the merged generator exits is
    cancelled, so it is not left pending on the event loop. This covers normal
    completion, an iterator raising, and the consumer closing the generator
    early.
    """
    pending: list[AsyncIterator[T1 | T2]] = [a, b]
    pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
        asyncio.ensure_future(g.__anext__()): g for g in pending
    }
    try:
        while len(pending_tasks) > 0:
            done, _ = await asyncio.wait(
                pending_tasks.keys(), return_when="FIRST_COMPLETED"
            )
            stop = False
            for d in done:
                try:
                    result = d.result()
                    yield result
                    dg = pending_tasks[d]
                    # Ask the same iterator for its next item.
                    pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
                except StopAsyncIteration:
                    stop = True
                finally:
                    del pending_tasks[d]
            if stop:
                for task in list(pending_tasks.keys()):
                    if not task.cancel():
                        # The task finished before it could be cancelled.
                        try:
                            yield task.result()
                        except (
                            asyncio.CancelledError,
                            asyncio.InvalidStateError,
                            # The other iterator ended too. Letting this escape
                            # an async generator turns it into a RuntimeError.
                            StopAsyncIteration,
                        ):
                            pass
                break
    finally:
        # No awaits here: this must be safe while the generator is closing.
        for task in pending_tasks:
            task.cancel()
265
266
class _EventWrapper:
    """
    Box for a ThreadStreamEvent taken from an AgentContext queue.

    Lets the merged stream tell context-helper events apart from Agents SDK
    stream events.
    """

    def __init__(self, event: ThreadStreamEvent):
        # The wrapped event, unchanged.
        self.event = event
270
271
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """
    Async iterator over an AgentContext event queue.

    Each queued ThreadStreamEvent is yielded wrapped in `_EventWrapper`.
    Iteration ends, permanently, once a `_QueueCompleteSentinel` is dequeued
    or `drain_and_complete` has been called.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        # Once True, every further __anext__ raises StopAsyncIteration.
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            next_item = await self.queue.get()
            if not isinstance(next_item, _QueueCompleteSentinel):
                return _EventWrapper(next_item)
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """
        Discard everything currently queued and mark this iterator completed.

        Never awaits, so it is safe on cleanup paths. Queued events and any
        completion sentinel are dropped without being yielded.
        """
        queue = self.queue
        while not queue.empty():
            queue.get_nowait()
        self.completed = True
305
306
class StreamingThoughtTracker(BaseModel):
    """
    Tracks the reasoning-summary thought that is currently streaming.

    Attributes:
        item_id: ID of the Responses reasoning item the summary belongs to.
        index: Summary index within that reasoning item.
        task: ThoughtTask that accumulates the streamed summary text.
    """

    item_id: str
    index: int
    task: ThoughtTask
311
312
class ResponseStreamConverter:
    """
    Converts streamed Agents SDK output into ChatKit thread-item values.

    `stream_agent_response` calls these hooks while it translates a run.
    Subclass and override them to change, for example, how image-generation
    results become URLs or how citations become annotations.
    """

    # Expected number of partial image updates per image-generation result.
    # When set to a positive value, partial image indices are normalized into
    # a progress value in [0, 1]. Otherwise every partial update reports
    # progress 0.
    partial_images: int | None = None

    def __init__(self, *, partial_images: int | None = None):
        """
        Args:
            partial_images: Expected number of partial image updates for image
                generation results, or None to skip progress normalization.
        """
        self.partial_images = partial_images

    async def base64_image_to_url(
        self,
        image_id: str,
        base64_image: str,
        partial_image_index: int | None = None,
    ) -> str:
        """
        Turn a base64-encoded image into the URL stored on thread items.

        The default embeds the data directly in a PNG data URL.

        Args:
            image_id: ID of the image generation call; stable across partial
                image updates.
            base64_image: The base64-encoded image.
            partial_image_index: Index of the partial image update, starting
                at 0; None when converting the final image.

        Returns:
            A URL string.
        """
        return f"data:image/png;base64,{base64_image}"

    def partial_image_index_to_progress(self, partial_image_index: int) -> float:
        """
        Map a partial image index (starting at 0) onto a progress value.

        Returns:
            0.0 when `partial_images` is unset or not positive. Otherwise
            `partial_image_index / partial_images`, capped at 1.0.
        """
        expected = self.partial_images
        if expected is not None and expected > 0:
            return min(1.0, partial_image_index / expected)
        return 0.0

    async def file_citation_to_annotation(
        self, file_citation: AnnotationFileCitation
    ) -> Annotation | None:
        """
        Convert a Responses API file citation into an assistant message annotation.

        Returns None when the citation carries no filename.
        """
        if not file_citation.filename:
            return None
        name = file_citation.filename
        return Annotation(
            source=FileSource(filename=name, title=name),
            index=file_citation.index,
        )

    async def container_file_citation_to_annotation(
        self, container_file_citation: AnnotationContainerFileCitation
    ) -> Annotation | None:
        """
        Convert a Responses API container file citation into an assistant message annotation.

        The annotation is anchored at the citation's end index. Returns None
        when the citation carries no filename.
        """
        if not container_file_citation.filename:
            return None
        name = container_file_citation.filename
        return Annotation(
            source=FileSource(filename=name, title=name),
            index=container_file_citation.end_index,
        )

    async def url_citation_to_annotation(
        self, url_citation: AnnotationURLCitation
    ) -> Annotation | None:
        """
        Convert a Responses API URL citation into an assistant message annotation.

        The annotation is anchored at the citation's end index.
        """
        source = URLSource(url=url_citation.url, title=url_citation.title)
        return Annotation(source=source, index=url_citation.end_index)
411
412
# Shared converter used by stream_agent_response when none is passed in;
# it performs no partial-image progress normalization.
_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter(partial_images=None)
414
415
async def _convert_content(
    content: Content, converter: ResponseStreamConverter
) -> AssistantMessageContent:
    """
    Convert a Responses output content part into ChatKit assistant content.

    An "output_text" part keeps its text, plus every annotation the converter
    maps to a non-None value. Any other part is treated as a refusal: its
    refusal text becomes the content, with no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(text=content.refusal, annotations=[])

    kept: list[Annotation] = []
    for raw in content.annotations:
        converted = await _convert_annotation(raw, converter)
        if converted is not None:
            kept.append(converted)
    return AssistantMessageContent(text=content.text, annotations=kept)
434
435
# Building a TypeAdapter constructs a validation schema, which is costly.
# Build it once at import time rather than on every annotation event.
_RESPONSES_ANNOTATION_ADAPTER = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)


async def _convert_annotation(
    raw_annotation: object, converter: ResponseStreamConverter
) -> Annotation | None:
    """
    Convert a Responses API annotation into a ChatKit annotation.

    Args:
        raw_annotation: Annotation from a Responses event. It may be a typed
            annotation, a dict, or an untyped object (see the note below).
        converter: Supplies the per-type conversion hooks.

    Returns:
        The converted annotation. Returns None for annotation types without a
        hook, or when the converter's hook returns None.

    Raises:
        pydantic.ValidationError: If ``raw_annotation`` does not validate as a
            Responses annotation.
    """
    # The OpenAI client sometimes parses annotation delta events into the wrong
    # class, so the annotation arrives as a dict or untyped object instead of a
    # ResponsesAnnotation. Re-validate it to get a properly typed annotation.
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "file_citation":
        return await converter.file_citation_to_annotation(annotation)

    if annotation.type == "url_citation":
        return await converter.url_citation_to_annotation(annotation)

    if annotation.type == "container_file_citation":
        return await converter.container_file_citation_to_annotation(annotation)

    return None
455
456
async def stream_agent_response(
    context: AgentContext,
    result: RunResultStreaming,
    *,
    converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Convert a streamed Agents SDK run into ChatKit thread stream events.

    This function consumes a streaming run result and yields `ThreadStreamEvent`
    objects as the run progresses. Events enqueued on the AgentContext (by its
    helper methods) are merged into the same stream.

    Before streaming, the two most recent thread items are loaded. This lets a
    workflow left open by the previous turn, possibly followed by a client tool
    call, be continued. If a workflow is still active when the run ends, it is
    saved to the store. If `context.client_tool_call` is set, a
    ClientToolCallItem done event is emitted last.

    Args:
        context: The AgentContext to use for the stream.
        result: The RunResultStreaming to convert.
        converter: Defines overridable methods for adapting streamed data (such as image
            generation results and partial updates) into the forms expected by ChatKit.

    Returns:
        An async iterator that yields thread stream events representing the run result.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            Re-raised after a ThreadItemRemovedEvent has been yielded for every
            item this stream produced.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        # Close `item`, add a duration summary if it has none, and build its
        # done event. Clears ctx.workflow_item when `item` is the active one.
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    def index_of_task(tasks: list[Task], task: Task) -> int:
        # Prefer identity: tasks compare by value, so an earlier task with
        # equal fields must not be mistaken for this one. Fall back to
        # equality in case the list holds an equal copy instead.
        for position, candidate in enumerate(tasks):
            if candidate is task:
                return position
        return tasks.index(task)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remember the latest function call so a client tool call
                    # requested by it can reuse its ids at the end of the run.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = await _convert_content(event.part, converter)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = await _convert_annotation(event.annotation, converter)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    # Reasoning without an active workflow opens a new
                    # reasoning workflow to hold the streamed thoughts.
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[
                                await _convert_content(c, converter)
                                for c in item.content
                            ],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call":
                    ctx.generated_image_item = GeneratedImageItem(
                        id=ctx.generate_id("message"),
                        thread_id=thread.id,
                        created_at=datetime.now(),
                        image=None,
                    )
                    produced_items.add(ctx.generated_image_item.id)
                    yield ThreadItemAddedEvent(item=ctx.generated_image_item)
            elif event.type == "response.image_generation_call.partial_image":
                if not ctx.generated_image_item:
                    continue

                url = await converter.base64_image_to_url(
                    image_id=event.item_id,
                    base64_image=event.partial_image_b64,
                    partial_image_index=event.partial_image_index,
                )
                progress = converter.partial_image_index_to_progress(
                    event.partial_image_index
                )

                ctx.generated_image_item.image = GeneratedImage(
                    id=event.item_id, url=url
                )

                yield ThreadItemUpdatedEvent(
                    item_id=ctx.generated_image_item.id,
                    update=GeneratedImageUpdated(
                        image=ctx.generated_image_item.image, progress=progress
                    ),
                )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=index_of_task(
                                ctx.workflow_item.workflow.tasks,
                                streaming_thought.task,
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # Finalize the thought that has been streaming.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=index_of_task(
                                ctx.workflow_item.workflow.tasks, task
                            ),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        update = WorkflowTaskAdded(
                            task=task,
                            # Just appended, so it is the last task.
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[
                                await _convert_content(c, converter)
                                for c in item.content
                            ],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call" and item.result:
                    if not ctx.generated_image_item:
                        continue

                    url = await converter.base64_image_to_url(
                        image_id=item.id,
                        base64_image=item.result,
                    )
                    image = GeneratedImage(id=item.id, url=url)

                    ctx.generated_image_item.image = image
                    yield ThreadItemDoneEvent(item=ctx.generated_image_item)

                    ctx.generated_image_item = None

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
815
816
# Widget types accumulate_text can stream into. It updates their `value`
# and `streaming` fields via model_copy.
TWidget = TypeVar("TWidget", bound=Markdown | Text)
818
819
async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """
    Stream a text widget whose `value` grows with the run's output text.

    Yields `base_widget` unchanged first. After every
    `response.output_text.delta` raw response event, it yields a copy whose
    `value` is all text received so far. A final copy is yielded with the
    full text and `streaming=False`.
    """
    accumulated = ""
    yield base_widget
    async for stream_event in events:
        if stream_event.type != "raw_response_event":
            continue
        payload = stream_event.data
        if payload.type != "response.output_text.delta":
            continue
        accumulated = accumulated + payload.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
832
833
834class ThreadItemConverter:
835 """
836 Converts thread items to Agent SDK input items.
837 Widgets, Tasks, and Workflows have default conversions but can be customized.
838 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
839 Other item types are converted automatically.
840 """
841
842 async def attachment_to_message_content(
843 self, attachment: Attachment
844 ) -> ResponseInputContentParam:
845 """
846 Convert an attachment in a user message into a message content part to send to the model.
847 Required when attachments are enabled.
848 """
849 raise NotImplementedError(
850 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
851 )
852
853 async def tag_to_message_content(
854 self, tag: UserMessageTagContent
855 ) -> ResponseInputContentParam:
856 """
857 Convert a tag in a user message into a message content part to send to the model as context.
858 Required when tags are used.
859 """
860 raise NotImplementedError(
861 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
862 )
863
864 async def generated_image_to_input(
865 self, item: GeneratedImageItem
866 ) -> TResponseInputItem | list[TResponseInputItem] | None:
867 """
868 Convert a GeneratedImageItem into input item(s) to send to the model.
869 Override this method to customize the conversion of generated images, such as when your
870 generated image url is not publicly reachable.
871 """
872 if not item.image:
873 return None
874
875 return Message(
876 type="message",
877 content=[
878 ResponseInputTextParam(
879 type="input_text",
880 text="The following image was generated by the agent.",
881 ),
882 ResponseInputImageParam(
883 type="input_image",
884 detail="auto",
885 image_url=item.image.url,
886 ),
887 ],
888 role="user",
889 )
890
891 async def hidden_context_to_input(
892 self, item: HiddenContextItem
893 ) -> TResponseInputItem | list[TResponseInputItem] | None:
894 """
895 Convert a HiddenContextItem into input item(s) to send to the model.
896 Required to override when HiddenContextItems with non-string content are used.
897 """
898 if not isinstance(item.content, str):
899 raise NotImplementedError(
900 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
901 )
902
903 text = (
904 "Hidden context for the agent (not shown to the user):\n"
905 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
906 )
907 return Message(
908 type="message",
909 content=[
910 ResponseInputTextParam(
911 type="input_text",
912 text=text,
913 )
914 ],
915 role="user",
916 )
917
918 async def sdk_hidden_context_to_input(
919 self, item: SDKHiddenContextItem
920 ) -> TResponseInputItem | list[TResponseInputItem] | None:
921 """
922 Convert a SDKHiddenContextItem into input item to send to the model.
923 This is used by the ChatKit Python SDK for storing additional context
924 for internal operations.
925 Override if you want to wrap the content in a different format.
926 """
927 text = (
928 "Hidden context for the agent (not shown to the user):\n"
929 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
930 )
931 return Message(
932 type="message",
933 content=[
934 ResponseInputTextParam(
935 type="input_text",
936 text=text,
937 )
938 ],
939 role="user",
940 )
941
942 async def task_to_input(
943 self, item: TaskItem
944 ) -> TResponseInputItem | list[TResponseInputItem] | None:
945 """
946 Convert a TaskItem into input item(s) to send to the model.
947 """
948 if item.task.type != "custom" or (
949 not item.task.title and not item.task.content
950 ):
951 return None
952 title = f"{item.task.title}" if item.task.title else ""
953 content = f"{item.task.content}" if item.task.content else ""
954 task_text = f"{title}: {content}" if title and content else title or content
955 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
956 return Message(
957 type="message",
958 content=[
959 ResponseInputTextParam(
960 type="input_text",
961 text=text,
962 )
963 ],
964 role="user",
965 )
966
967 async def workflow_to_input(
968 self, item: WorkflowItem
969 ) -> TResponseInputItem | list[TResponseInputItem] | None:
970 """
971 Convert a TaskItem into input item(s) to send to the model.
972 Returns WorkflowItem.response_items by default.
973 """
974 messages = []
975 for task in item.workflow.tasks:
976 if task.type != "custom" or (not task.title and not task.content):
977 continue
978
979 title = f"{task.title}" if task.title else ""
980 content = f"{task.content}" if task.content else ""
981 task_text = f"{title}: {content}" if title and content else title or content
982 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
983 messages.append(
984 Message(
985 type="message",
986 content=[
987 ResponseInputTextParam(
988 type="input_text",
989 text=text,
990 )
991 ],
992 role="user",
993 )
994 )
995 return messages
996
997 async def widget_to_input(
998 self, item: WidgetItem
999 ) -> TResponseInputItem | list[TResponseInputItem] | None:
1000 """
1001 Convert a WidgetItem into input item(s) to send to the model.
1002 By default, WidgetItems converted to a text description with a JSON representation of the widget.
1003 """
1004 return Message(
1005 type="message",
1006 content=[
1007 ResponseInputTextParam(
1008 type="input_text",
1009 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
1010 + item.widget.model_dump_json(
1011 exclude_unset=True, exclude_none=True
1012 ),
1013 )
1014 ],
1015 role="user",
1016 )
1017
async def user_message_to_input(
    self, item: UserMessageItem, is_last_message: bool = True
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a UserMessageItem into input item(s) to send to the model.

    The first returned item is always the user's own message: the text as
    typed (tags rendered as ``@<text>``) followed by converted attachments.
    It may be followed by extra user-role context messages:

    - the quoted text, when ``item.quoted_text`` is set and
      ``is_last_message`` is true;
    - the resolved @-mention tags, deduplicated by tag text (first
      occurrence wins, original order preserved).

    Args:
        item: The user message to convert.
        is_last_message: Whether this item is the last one being converted;
            quoted text is only forwarded for the last message.

    Returns:
        A list ``[user_message, *context_messages]``.
    """
    # Build the user text exactly as typed, rendering tags as @key
    message_text_parts: list[str] = []
    # Tags are also collected so they can be resolved into a separate
    # context message below.
    raw_tags: list[UserMessageTagContent] = []

    for part in item.content:
        if isinstance(part, UserMessageTextContent):
            message_text_parts.append(part.text)
        elif isinstance(part, UserMessageTagContent):
            message_text_parts.append(f"@{part.text}")
            raw_tags.append(part)
        else:
            assert_never(part)

    user_text_item = Message(
        role="user",
        type="message",
        content=[
            ResponseInputTextParam(
                type="input_text", text="".join(message_text_parts)
            ),
            *[
                await self.attachment_to_message_content(a)
                for a in item.attachments
            ],
        ],
    )

    # Extra user-role context messages, returned AFTER the user's message.
    context_items: list[TResponseInputItem] = []

    if item.quoted_text and is_last_message:
        context_items.append(
            Message(
                role="user",
                type="message",
                content=[
                    ResponseInputTextParam(
                        type="input_text",
                        text=f"The user is referring to this in particular: \n{item.quoted_text}",
                    )
                ],
            )
        )

    if raw_tags:
        # Dedupe by tag text, keeping the first occurrence in original order:
        # dicts preserve insertion order and setdefault never overwrites.
        uniq_tags: dict[str, UserMessageTagContent] = {}
        for tag in raw_tags:
            uniq_tags.setdefault(tag.text, tag)

        tag_content: ResponseInputMessageContentListParam = [
            # expected to return summarized text content for each tag
            await self.tag_to_message_content(tag)
            for tag in uniq_tags.values()
        ]

        if tag_content:
            context_items.append(
                Message(
                    role="user",
                    type="message",
                    content=[
                        ResponseInputTextParam(
                            type="input_text",
                            text=cleandoc("""
                                # User-provided context for @-mentions
                                - When referencing resolved entities, use their canonical names **without** '@'.
                                - The '@' form appears only in user text and should not be echoed.
                            """).strip(),
                        ),
                        *tag_content,
                    ],
                )
            )

    return [user_text_item, *context_items]
1100
async def assistant_message_to_input(
    self, item: AssistantMessageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an AssistantMessageItem into an assistant-role input message.

    Each content part is re-encoded as an ``output_text`` part carrying only
    its text; annotations are currently not forwarded.
    """
    parts: list[ResponseInputContentParam] = []
    for content_part in item.content:
        output_text = ResponseOutputText(
            type="output_text",
            text=content_part.text,
            annotations=[],  # TODO: these should be sent back as well
        )
        # The input content param type has no assistant output-text variant,
        # so the dumped dict is cast to satisfy the type checker.
        parts.append(cast(ResponseInputContentParam, output_text.model_dump()))
    return EasyInputMessageParam(
        type="message",
        content=parts,
        role="assistant",
    )
1120
async def client_tool_call_to_input(
    self, item: ClientToolCallItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a ClientToolCallItem into a function call / call output pair.

    Pending tool calls cannot be sent to the model, so they yield ``None``.
    Any other call becomes two items that share ``item.call_id``: the
    ``function_call`` (JSON-encoded arguments) and its
    ``function_call_output`` (JSON-encoded output).
    """
    # Filter out pending tool calls - they cannot be sent to the model
    if item.status == "pending":
        return None

    call = ResponseFunctionToolCallParam(
        type="function_call",
        call_id=item.call_id,
        name=item.name,
        arguments=json.dumps(item.arguments),
    )
    result = FunctionCallOutput(
        type="function_call_output",
        call_id=item.call_id,
        output=json.dumps(item.output),
    )
    return [call, result]
1141
async def end_of_turn_to_input(
    self, item: EndOfTurnItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an EndOfTurnItem into model input.

    End-of-turn markers are only UI hints, so nothing is sent to the model.
    Overriding this should not normally be necessary.
    """
    return None
1147
1148 async def _thread_item_to_input_item(
1149 self,
1150 item: ThreadItem,
1151 is_last_message: bool = True,
1152 ) -> list[TResponseInputItem]:
1153 match item:
1154 case UserMessageItem():
1155 out = await self.user_message_to_input(item, is_last_message) or []
1156 return out if isinstance(out, list) else [out]
1157 case AssistantMessageItem():
1158 out = await self.assistant_message_to_input(item) or []
1159 return out if isinstance(out, list) else [out]
1160 case ClientToolCallItem():
1161 out = await self.client_tool_call_to_input(item) or []
1162 return out if isinstance(out, list) else [out]
1163 case EndOfTurnItem():
1164 out = await self.end_of_turn_to_input(item) or []
1165 return out if isinstance(out, list) else [out]
1166 case WidgetItem():
1167 out = await self.widget_to_input(item) or []
1168 return out if isinstance(out, list) else [out]
1169 case WorkflowItem():
1170 out = await self.workflow_to_input(item) or []
1171 return out if isinstance(out, list) else [out]
1172 case TaskItem():
1173 out = await self.task_to_input(item) or []
1174 return out if isinstance(out, list) else [out]
1175 case HiddenContextItem():
1176 out = await self.hidden_context_to_input(item) or []
1177 return out if isinstance(out, list) else [out]
1178 case SDKHiddenContextItem():
1179 out = await self.sdk_hidden_context_to_input(item) or []
1180 return out if isinstance(out, list) else [out]
1181 case GeneratedImageItem():
1182 out = await self.generated_image_to_input(item) or []
1183 return out if isinstance(out, list) else [out]
1184 case _:
1185 assert_never(item)
1186
async def to_agent_input(
    self,
    thread_items: Sequence[ThreadItem] | ThreadItem,
) -> list[TResponseInputItem]:
    """
    Convert one or more thread items into a flat list of model input items.

    Args:
        thread_items: A single thread item, or a sequence of them in
            conversation order.

    Returns:
        The input items produced for every thread item, concatenated in
        order. Only the final thread item is converted with
        ``is_last_message=True``.
    """
    if isinstance(thread_items, Sequence):
        # shallow copy in case caller mutates the list while we're iterating
        items = list(thread_items)
    else:
        items = [thread_items]

    last_index = len(items) - 1
    output: list[TResponseInputItem] = []
    for index, item in enumerate(items):
        # Compare positions rather than identities: if the same object
        # appears more than once, only its final position counts as "last".
        output.extend(
            await self._thread_item_to_input_item(
                item,
                is_last_message=index == last_index,
            )
        )
    return output
1205
1206
# Module-level converter instance reused by simple_to_agent_input().
_DEFAULT_CONVERTER = ThreadItemConverter()


async def simple_to_agent_input(
    thread_items: Sequence[ThreadItem] | ThreadItem,
) -> list[TResponseInputItem]:
    """Convert thread items using the default ThreadItemConverter.

    This is a coroutine and must be awaited, exactly like
    ``ThreadItemConverter.to_agent_input``, to which it delegates.

    Args:
        thread_items: A single thread item, or a sequence of them in
            conversation order.

    Returns:
        The flat list of model input items.
    """
    return await _DEFAULT_CONVERTER.to_agent_input(thread_items)
1213