openai/chatkit-python

Public

mirrored fromhttps://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.6.4

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1254lines · modecode

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputImageParam,
29 ResponseInputMessageContentListParam,
30 ResponseInputTextParam,
31 ResponseOutputText,
32)
33from openai.types.responses.response_input_item_param import (
34 FunctionCallOutput,
35 Message,
36)
37from openai.types.responses.response_output_message import Content
38from openai.types.responses.response_output_text import (
39 Annotation as ResponsesAnnotation,
40)
41from openai.types.responses.response_output_text import (
42 AnnotationContainerFileCitation,
43 AnnotationFileCitation,
44 AnnotationURLCitation,
45)
46from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
47from typing_extensions import assert_never
48
49from .server import stream_widget
50from .store import Store, StoreItemType
51from .types import (
52 Annotation,
53 AssistantMessageContent,
54 AssistantMessageContentPartAdded,
55 AssistantMessageContentPartAnnotationAdded,
56 AssistantMessageContentPartDone,
57 AssistantMessageContentPartTextDelta,
58 AssistantMessageItem,
59 Attachment,
60 ClientToolCallItem,
61 DurationSummary,
62 EndOfTurnItem,
63 FileSource,
64 GeneratedImage,
65 GeneratedImageItem,
66 GeneratedImageUpdated,
67 HiddenContextItem,
68 SDKHiddenContextItem,
69 StructuredInputItem,
70 Task,
71 TaskItem,
72 ThoughtTask,
73 ThreadItem,
74 ThreadItemAddedEvent,
75 ThreadItemDoneEvent,
76 ThreadItemRemovedEvent,
77 ThreadItemUpdatedEvent,
78 ThreadMetadata,
79 ThreadStreamEvent,
80 URLSource,
81 UserMessageItem,
82 UserMessageTagContent,
83 UserMessageTextContent,
84 WidgetItem,
85 Workflow,
86 WorkflowItem,
87 WorkflowSummary,
88 WorkflowTaskAdded,
89 WorkflowTaskUpdated,
90)
91from .widgets import Markdown, Text, WidgetRoot
92
93
class ClientToolCall(BaseModel):
    """
    Returned from tool methods to indicate a client-side tool call.

    When `AgentContext.client_tool_call` holds one of these at the end of a
    run, `stream_agent_response` emits a `ClientToolCallItem` built from it.
    """

    # Copied to `ClientToolCallItem.name`.
    name: str
    # Copied to `ClientToolCallItem.arguments`.
    arguments: dict[str, Any]
101
102
# Marker type: `AgentContext._complete()` enqueues an instance and
# `_AsyncQueueIterator` stops iterating as soon as it dequeues one.
class _QueueCompleteSentinel: ...
104
105
TContext = TypeVar("TContext")


class AgentContext(BaseModel, Generic[TContext]):
    """Per-run state shared between agent tools and the ChatKit event stream.

    Carries the thread, the store and the request context, tracks the
    workflow / generated-image item that is currently in progress, and queues
    the `ThreadStreamEvent` objects produced by the helper methods below.
    `stream_agent_response` reads that queue through `_events`.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    client_tool_call: ClientToolCall | None = None
    workflow_item: WorkflowItem | None = None
    generated_image_item: GeneratedImageItem | None = None
    # NOTE(review): this default queue object is created once, when the class
    # body runs. Confirm that pydantic copies private-attribute defaults per
    # instance so contexts never share a queue.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """Generate a new store-backed id for the given item type.

        Args:
            type: The kind of id to generate. "thread" yields a thread id;
                anything else yields an item id.
            thread: Thread the item belongs to. Defaults to `self.thread`.

        Returns:
            The id produced by the store.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Stream a widget into the thread by enqueueing widget events.

        Args:
            widget: A widget, or an async generator of successive widget states.
            copy_text: Optional text forwarded to the module-level
                `stream_widget` helper.
        """
        # `stream_widget` here resolves to the module-level helper imported
        # from `.server`, not to this method.
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """Finalize the active workflow item, optionally attaching a summary.

        Does nothing when no workflow is active. When neither `summary` nor an
        existing summary is available, a `DurationSummary` measured from the
        item's `created_at` is used.

        Args:
            summary: Summary to store on the workflow, if any.
            expanded: Value stored on `workflow.expanded` before the item is
                emitted as done.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """Begin streaming a new workflow item.

        The "added" event is deferred for non-reasoning workflows that have no
        tasks yet; `add_workflow_task` sends it once the first task arrives.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """Update an existing workflow task and stream the delta.

        Raises:
            ValueError: If no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """Append a workflow task and stream the appropriate event.

        Creates a "custom" workflow item when none is active. The first task
        of a non-reasoning workflow is announced with the deferred "added"
        event for the whole item; every other task is sent as a
        `WorkflowTaskAdded` update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it sits at the last position.
        # `list.index()` matches by equality and would report the position of
        # an earlier task that compares equal to this one.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue a ThreadStreamEvent for downstream processing."""
        await self._events.put(event)

    def _complete(self):
        """Enqueue the completion sentinel without awaiting."""
        self._events.put_nowait(_QueueCompleteSentinel())
227
228
229T1 = TypeVar("T1")
230T2 = TypeVar("T2")
231
232
233async def _merge_generators(
234 a: AsyncIterator[T1],
235 b: AsyncIterator[T2],
236) -> AsyncIterator[T1 | T2]:
237 pending: list[AsyncIterator[T1 | T2]] = [a, b]
238 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
239 asyncio.ensure_future(g.__anext__()): g for g in pending
240 }
241 while len(pending_tasks) > 0:
242 done, _ = await asyncio.wait(
243 pending_tasks.keys(), return_when="FIRST_COMPLETED"
244 )
245 stop = False
246 for d in done:
247 try:
248 result = d.result()
249 yield result
250 dg = pending_tasks[d]
251 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
252 except StopAsyncIteration:
253 stop = True
254 finally:
255 del pending_tasks[d]
256 if stop:
257 for task in pending_tasks.keys():
258 if not task.cancel():
259 try:
260 yield task.result()
261 except asyncio.CancelledError:
262 pass
263 except asyncio.InvalidStateError:
264 pass
265 break
266
267
class _EventWrapper:
    """Holds an event taken from the `AgentContext` queue.

    `stream_agent_response` uses an `isinstance` check on this wrapper to tell
    queued events apart from the Agents SDK stream events they are merged with.
    """

    def __init__(self, event: ThreadStreamEvent):
        self.event = event
271
272
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """Async iterator over an event queue that ends at the completion sentinel.

    Each queued event is returned wrapped in an `_EventWrapper`. Once a
    `_QueueCompleteSentinel` has been dequeued (or `drain_and_complete` has
    been called) the iterator stays exhausted.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            queued = await self.queue.get()
            if not isinstance(queued, _QueueCompleteSentinel):
                return _EventWrapper(queued)
            # Sentinel seen: remember it so later calls stop immediately.
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Discard everything currently queued and mark the iterator finished.

        Meant for cleanup paths that must not await. Every queued item,
        including a completion sentinel if one is present, is thrown away.
        """
        try:
            while True:
                self.queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        self.completed = True
306
307
class StreamingThoughtTracker(BaseModel):
    """Tracks the reasoning-summary thought that is being streamed.

    `stream_agent_response` creates one when the first summary delta of a new
    reasoning workflow arrives and matches later delta/done events against it.
    """

    # `item_id` of the reasoning summary event that started the thought.
    item_id: str
    # `summary_index` of that same event.
    index: int
    # The workflow task whose content is extended as deltas arrive.
    task: ThoughtTask
312
313
class ResponseStreamConverter:
    """Adapter hooks used by `stream_agent_response` while translating streamed
    Agents SDK output into ChatKit thread items and thread stream events.

    Override individual methods to customise how streamed data (image
    generation results, partial image updates, citations) is mapped onto the
    values ChatKit expects.
    """

    partial_images: int | None = None
    """
    Number of partial image updates expected for one image generation result.

    Serves as the denominator when a partial image index is turned into a
    progress value in the range [0, 1]. While unset, every partial image update
    is reported with a progress of 0.
    """

    def __init__(self, *, partial_images: int | None = None):
        """
        Args:
            partial_images: How many partial image updates to expect per image
                generation result, or None to skip progress normalization.
        """
        self.partial_images = partial_images

    async def base64_image_to_url(
        self,
        image_id: str,
        base64_image: str,
        partial_image_index: int | None = None,
    ) -> str:
        """
        Turn a base64-encoded image into the URL stored on thread items.

        The default implementation returns a PNG data URL and ignores
        `image_id` and `partial_image_index`.

        Args:
            image_id: ID of the image generation call; identical for every partial update of one image.
            base64_image: The image, base64-encoded.
            partial_image_index: Zero-based index of the partial update, if any.

        Returns:
            A URL string.
        """
        return f"data:image/png;base64,{base64_image}"

    def partial_image_index_to_progress(self, partial_image_index: int) -> float:
        """
        Map a partial image index onto a normalized progress value.

        Args:
            partial_image_index: Zero-based index of the partial image update.

        Returns:
            `partial_image_index / partial_images`, capped at 1.0, or 0.0 when
            `partial_images` is unset or not positive.
        """
        expected = self.partial_images
        if expected is None or expected <= 0:
            return 0.0

        return min(1.0, partial_image_index / expected)

    async def file_citation_to_annotation(
        self, file_citation: AnnotationFileCitation
    ) -> Annotation | None:
        """Build an assistant message annotation from a Responses API file citation.

        Returns None when the citation has no filename.
        """
        name = file_citation.filename
        if name:
            return Annotation(
                source=FileSource(filename=name, title=name),
                index=file_citation.index,
            )
        return None

    async def container_file_citation_to_annotation(
        self, container_file_citation: AnnotationContainerFileCitation
    ) -> Annotation | None:
        """Build an assistant message annotation from a Responses API container file citation.

        Returns None when the citation has no filename.
        """
        name = container_file_citation.filename
        if name:
            return Annotation(
                source=FileSource(filename=name, title=name),
                index=container_file_citation.end_index,
            )
        return None

    async def url_citation_to_annotation(
        self, url_citation: AnnotationURLCitation
    ) -> Annotation | None:
        """Build an assistant message annotation from a Responses API URL citation."""
        link = URLSource(url=url_citation.url, title=url_citation.title)
        return Annotation(source=link, index=url_citation.end_index)


# Shared instance used as the default `converter` of `stream_agent_response`.
_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
415
416
async def _convert_content(
    content: Content, converter: ResponseStreamConverter
) -> AssistantMessageContent:
    """Translate one Responses message content part into ChatKit content.

    "output_text" parts keep their text and every annotation that
    `_convert_annotation` could convert. Any other part is treated as a
    refusal and carries the refusal text with no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    kept = []
    for raw in content.annotations:
        converted = await _convert_annotation(raw, converter)
        # Annotations without a ChatKit equivalent come back as None.
        if converted is not None:
            kept.append(converted)
    return AssistantMessageContent(
        text=content.text,
        annotations=kept,
    )
435
436
# Built once at import time. Constructing a pydantic TypeAdapter builds its
# validator, which is wasted work when repeated for every annotation event.
_RESPONSES_ANNOTATION_ADAPTER: TypeAdapter[ResponsesAnnotation] = TypeAdapter(
    ResponsesAnnotation
)


async def _convert_annotation(
    raw_annotation: object, converter: ResponseStreamConverter
) -> Annotation | None:
    """Convert a Responses API annotation into a ChatKit annotation.

    Args:
        raw_annotation: The annotation as delivered by the client; it may be a
            typed object, a dict or an untyped object.
        converter: Supplies the per-citation-type conversion hooks.

    Returns:
        The converted annotation, or None when the annotation type is not
        handled or the converter hook returns None.
    """
    # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
    # resulting in the annotation being a dict or untyped object instead of a ResponsesAnnotation
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "file_citation":
        return await converter.file_citation_to_annotation(annotation)

    if annotation.type == "url_citation":
        return await converter.url_citation_to_annotation(annotation)

    if annotation.type == "container_file_citation":
        return await converter.container_file_citation_to_annotation(annotation)

    return None
456
457
async def stream_agent_response(
    context: AgentContext,
    result: RunResultStreaming,
    *,
    converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Convert a streamed Agents SDK run into ChatKit thread stream events.

    This function consumes a streaming run result and yields `ThreadStreamEvent`
    objects as the run progresses. Events that tools enqueue on the
    `AgentContext` are merged into the same stream.

    Args:
        context: The AgentContext to use for the stream.
        result: The RunResultStreaming to convert.
        converter: Defines overridable methods for adapting streamed data (such as image
            generation results and partial updates) into the forms expected by ChatKit.

    Returns:
        An async iterator that yields thread stream events representing the run result.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            Re-raised after a `ThreadItemRemovedEvent` has been yielded for
            every item produced so far and the context queue has been drained.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close `item` and return the done event for it."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if event.type == "tool_call_item":
                    raw_item = event.raw_item
                    # The raw item is handled both as a dict and as an object.
                    if isinstance(raw_item, dict):
                        if raw_item.get("type") == "function_call":
                            current_tool_call = event.call_id
                            current_item_id = raw_item.get("id")
                            assert current_item_id
                            produced_items.add(current_item_id)
                    elif raw_item.type == "function_call":
                        current_tool_call = event.call_id
                        current_item_id = raw_item.id
                        assert current_item_id
                        produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = await _convert_content(event.part, converter)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = await _convert_annotation(event.annotation, converter)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[
                                await _convert_content(c, converter)
                                for c in item.content
                            ],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call":
                    ctx.generated_image_item = GeneratedImageItem(
                        id=ctx.generate_id("message"),
                        thread_id=thread.id,
                        created_at=datetime.now(),
                        image=None,
                    )
                    produced_items.add(ctx.generated_image_item.id)
                    yield ThreadItemAddedEvent(item=ctx.generated_image_item)
            elif event.type == "response.image_generation_call.partial_image":
                if not ctx.generated_image_item:
                    continue

                url = await converter.base64_image_to_url(
                    image_id=event.item_id,
                    base64_image=event.partial_image_b64,
                    partial_image_index=event.partial_image_index,
                )
                progress = converter.partial_image_index_to_progress(
                    event.partial_image_index
                )

                ctx.generated_image_item.image = GeneratedImage(
                    id=event.item_id, url=url
                )

                yield ThreadItemUpdatedEvent(
                    item_id=ctx.generated_image_item.id,
                    update=GeneratedImageUpdated(
                        image=ctx.generated_image_item.image, progress=progress
                    ),
                )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # Finish the thought that was streamed via deltas.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        update = WorkflowTaskAdded(
                            task=task,
                            # Just appended, so the task is at the last
                            # position. `list.index()` matches by equality and
                            # would report an earlier task that compares equal.
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[
                                await _convert_content(c, converter)
                                for c in item.content
                            ],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call" and item.result:
                    if not ctx.generated_image_item:
                        continue

                    url = await converter.base64_image_to_url(
                        image_id=item.id,
                        base64_image=item.result,
                    )
                    image = GeneratedImage(id=item.id, url=url)

                    ctx.generated_image_item.image = image
                    yield ThreadItemDoneEvent(item=ctx.generated_image_item)

                    ctx.generated_image_item = None

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Tell the client to remove everything this run produced so far.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        # Fall back to freshly generated ids when the run did not report a
        # function call item.
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
821
822
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """Yield copies of `base_widget` whose `value` grows with streamed text.

    Yields `base_widget` unchanged first, then one copy per
    "response.output_text.delta" event carrying all text received so far, and
    finally a copy with the complete text and `streaming` set to False.
    """
    yield base_widget
    accumulated = ""
    async for stream_event in events:
        if stream_event.type != "raw_response_event":
            continue
        payload = stream_event.data
        if payload.type == "response.output_text.delta":
            accumulated += payload.delta
            yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
838
839
840class ThreadItemConverter:
841 """
842 Converts thread items to Agent SDK input items.
843 Widgets, Tasks, and Workflows have default conversions but can be customized.
844 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
845 Other item types are converted automatically.
846 """
847
848 async def attachment_to_message_content(
849 self, attachment: Attachment
850 ) -> ResponseInputContentParam:
851 """
852 Convert an attachment in a user message into a message content part to send to the model.
853 Required when attachments are enabled.
854 """
855 raise NotImplementedError(
856 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
857 )
858
859 async def tag_to_message_content(
860 self, tag: UserMessageTagContent
861 ) -> ResponseInputContentParam:
862 """
863 Convert a tag in a user message into a message content part to send to the model as context.
864 Required when tags are used.
865 """
866 raise NotImplementedError(
867 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
868 )
869
870 async def generated_image_to_input(
871 self, item: GeneratedImageItem
872 ) -> TResponseInputItem | list[TResponseInputItem] | None:
873 """
874 Convert a GeneratedImageItem into input item(s) to send to the model.
875 Override this method to customize the conversion of generated images, such as when your
876 generated image url is not publicly reachable.
877 """
878 if not item.image:
879 return None
880
881 return Message(
882 type="message",
883 content=[
884 ResponseInputTextParam(
885 type="input_text",
886 text="The following image was generated by the agent.",
887 ),
888 ResponseInputImageParam(
889 type="input_image",
890 detail="auto",
891 image_url=item.image.url,
892 ),
893 ],
894 role="user",
895 )
896
897 async def hidden_context_to_input(
898 self, item: HiddenContextItem
899 ) -> TResponseInputItem | list[TResponseInputItem] | None:
900 """
901 Convert a HiddenContextItem into input item(s) to send to the model.
902 Required to override when HiddenContextItems with non-string content are used.
903 """
904 if not isinstance(item.content, str):
905 raise NotImplementedError(
906 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
907 )
908
909 text = (
910 "Hidden context for the agent (not shown to the user):\n"
911 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
912 )
913 return Message(
914 type="message",
915 content=[
916 ResponseInputTextParam(
917 type="input_text",
918 text=text,
919 )
920 ],
921 role="user",
922 )
923
924 async def sdk_hidden_context_to_input(
925 self, item: SDKHiddenContextItem
926 ) -> TResponseInputItem | list[TResponseInputItem] | None:
927 """
928 Convert a SDKHiddenContextItem into input item to send to the model.
929 This is used by the ChatKit Python SDK for storing additional context
930 for internal operations.
931 Override if you want to wrap the content in a different format.
932 """
933 text = (
934 "Hidden context for the agent (not shown to the user):\n"
935 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
936 )
937 return Message(
938 type="message",
939 content=[
940 ResponseInputTextParam(
941 type="input_text",
942 text=text,
943 )
944 ],
945 role="user",
946 )
947
948 async def task_to_input(
949 self, item: TaskItem
950 ) -> TResponseInputItem | list[TResponseInputItem] | None:
951 """
952 Convert a TaskItem into input item(s) to send to the model.
953 """
954 if item.task.type != "custom" or (
955 not item.task.title and not item.task.content
956 ):
957 return None
958 title = f"{item.task.title}" if item.task.title else ""
959 content = f"{item.task.content}" if item.task.content else ""
960 task_text = f"{title}: {content}" if title and content else title or content
961 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
962 return Message(
963 type="message",
964 content=[
965 ResponseInputTextParam(
966 type="input_text",
967 text=text,
968 )
969 ],
970 role="user",
971 )
972
973 async def structured_input_to_input(
974 self, item: StructuredInputItem
975 ) -> TResponseInputItem | list[TResponseInputItem] | None:
976 """
977 Convert a StructuredInputItem into input item(s) to send to the model.
978 """
979 lines = []
980 for question in item.inputs:
981 answer = question.answer
982 if answer is None:
983 lines.append(f"- {question.question}: unanswered")
984 elif answer.skipped:
985 lines.append(f"- {question.question}: skipped")
986 else:
987 lines.append(f"- {question.question}: {', '.join(answer.values)}")
988
989 text = (
990 "A structured input request was displayed to the user with the following "
991 f"status: {item.status}\n<StructuredInput>\n"
992 + "\n".join(lines)
993 + "\n</StructuredInput>"
994 )
995 return Message(
996 type="message",
997 content=[
998 ResponseInputTextParam(
999 type="input_text",
1000 text=text,
1001 ),
1002 ],
1003 role="user",
1004 )
1005
async def workflow_to_input(
    self, item: WorkflowItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a WorkflowItem into input item(s) to send to the model.

    By default, every task in `item.workflow.tasks` whose type is "custom" and
    that has a title or content becomes one user message describing the task.
    All other tasks are skipped, so the returned list may be empty.
    """
    messages: list[TResponseInputItem] = []
    for task in item.workflow.tasks:
        # Only custom tasks that carry some text are worth describing to the model.
        if task.type != "custom" or (not task.title and not task.content):
            continue

        title = f"{task.title}" if task.title else ""
        content = f"{task.content}" if task.content else ""
        # Use "title: content" when both are present, otherwise whichever exists.
        task_text = f"{title}: {content}" if title and content else title or content
        text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
        messages.append(
            Message(
                type="message",
                content=[
                    ResponseInputTextParam(
                        type="input_text",
                        text=text,
                    )
                ],
                role="user",
            )
        )
    return messages
1035
async def widget_to_input(
    self, item: WidgetItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a WidgetItem into input item(s) to send to the model.

    By default, a WidgetItem is converted to a single user message containing a
    short text description followed by the widget's JSON dump (unset and None
    fields are excluded).
    """
    intro = f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
    widget_json = item.widget.model_dump_json(exclude_unset=True, exclude_none=True)
    text_part = ResponseInputTextParam(type="input_text", text=intro + widget_json)
    return Message(type="message", content=[text_part], role="user")
1056
async def user_message_to_input(
    self, item: UserMessageItem, is_last_message: bool = True
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a UserMessageItem into input item(s) to send to the model.

    The first returned message holds the user's text exactly as typed (tags are
    rendered as `@<text>`) followed by the converted attachments. It may be
    followed by extra user messages: one with the quoted text (only when
    `is_last_message` is true) and one with the context resolved from the
    distinct tags found in the message.
    """
    text_fragments: list[str] = []
    # First occurrence of each tag, keyed by its text; dicts keep insertion order.
    tags_by_text: dict[str, UserMessageTagContent] = {}

    for part in item.content:
        if isinstance(part, UserMessageTextContent):
            text_fragments.append(part.text)
        elif isinstance(part, UserMessageTagContent):
            text_fragments.append(f"@{part.text}")
            tags_by_text.setdefault(part.text, part)
        else:
            assert_never(part)

    typed_text = ResponseInputTextParam(
        type="input_text", text="".join(text_fragments)
    )
    attachment_parts = [
        await self.attachment_to_message_content(attachment)
        for attachment in item.attachments
    ]
    result: list[TResponseInputItem] = [
        Message(
            role="user",
            type="message",
            content=[typed_text, *attachment_parts],
        )
    ]

    # Context messages are appended after the user's own message.
    if item.quoted_text and is_last_message:
        quote_part = ResponseInputTextParam(
            type="input_text",
            text=f"The user is referring to this in particular: \n{item.quoted_text}",
        )
        result.append(Message(role="user", type="message", content=[quote_part]))

    # Resolve each distinct tag to message content (should be summarized text items).
    tag_content: ResponseInputMessageContentListParam = [
        await self.tag_to_message_content(tag) for tag in tags_by_text.values()
    ]
    if tag_content:
        guidance = ResponseInputTextParam(
            type="input_text",
            text=cleandoc("""
                # User-provided context for @-mentions
                - When referencing resolved entities, use their canonical names **without** '@'.
                - The '@' form appears only in user text and should not be echoed.
            """).strip(),
        )
        result.append(
            Message(
                role="user",
                type="message",
                content=[guidance, *tag_content],
            )
        )

    return result
1139
async def assistant_message_to_input(
    self, item: AssistantMessageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an AssistantMessageItem into input item(s) to send to the model.

    Each content part becomes an `output_text` entry carrying only its text, and
    the entries are returned as one assistant-role message.
    """
    output_parts = []
    for part in item.content:
        dumped = ResponseOutputText(
            type="output_text",
            text=part.text,
            annotations=[],  # TODO: these should be sent back as well
        ).model_dump()
        # content param doesn't support the assistant message content types
        output_parts.append(cast(ResponseInputContentParam, dumped))

    return EasyInputMessageParam(
        type="message",
        content=output_parts,
        role="assistant",
    )
1159
async def client_tool_call_to_input(
    self, item: ClientToolCallItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a ClientToolCallItem into input item(s) to send to the model.

    Returns a function call entry followed by its output entry, both sharing the
    item's `call_id`, with arguments and output serialized as JSON. Returns None
    for pending tool calls.
    """
    if item.status != "pending":
        call = ResponseFunctionToolCallParam(
            type="function_call",
            call_id=item.call_id,
            name=item.name,
            arguments=json.dumps(item.arguments),
        )
        call_output = FunctionCallOutput(
            type="function_call_output",
            call_id=item.call_id,
            output=json.dumps(item.output),
        )
        return [call, call_output]

    # Filter out pending tool calls - they cannot be sent to the model
    return None
1180
async def end_of_turn_to_input(
    self, item: EndOfTurnItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an EndOfTurnItem into input item(s) to send to the model.

    Always returns None, so nothing is sent to the model for this item. It is
    only used for UI hints - you shouldn't need to override this.
    """
    return None
1186
1187 async def _thread_item_to_input_item(
1188 self,
1189 item: ThreadItem,
1190 is_last_message: bool = True,
1191 ) -> list[TResponseInputItem]:
1192 match item:
1193 case UserMessageItem():
1194 out = await self.user_message_to_input(item, is_last_message) or []
1195 return out if isinstance(out, list) else [out]
1196 case AssistantMessageItem():
1197 out = await self.assistant_message_to_input(item) or []
1198 return out if isinstance(out, list) else [out]
1199 case ClientToolCallItem():
1200 out = await self.client_tool_call_to_input(item) or []
1201 return out if isinstance(out, list) else [out]
1202 case EndOfTurnItem():
1203 out = await self.end_of_turn_to_input(item) or []
1204 return out if isinstance(out, list) else [out]
1205 case WidgetItem():
1206 out = await self.widget_to_input(item) or []
1207 return out if isinstance(out, list) else [out]
1208 case WorkflowItem():
1209 out = await self.workflow_to_input(item) or []
1210 return out if isinstance(out, list) else [out]
1211 case TaskItem():
1212 out = await self.task_to_input(item) or []
1213 return out if isinstance(out, list) else [out]
1214 case StructuredInputItem():
1215 out = await self.structured_input_to_input(item) or []
1216 return out if isinstance(out, list) else [out]
1217 case HiddenContextItem():
1218 out = await self.hidden_context_to_input(item) or []
1219 return out if isinstance(out, list) else [out]
1220 case SDKHiddenContextItem():
1221 out = await self.sdk_hidden_context_to_input(item) or []
1222 return out if isinstance(out, list) else [out]
1223 case GeneratedImageItem():
1224 out = await self.generated_image_to_input(item) or []
1225 return out if isinstance(out, list) else [out]
1226 case _:
1227 assert_never(item)
1228
async def to_agent_input(
    self,
    thread_items: Sequence[ThreadItem] | ThreadItem,
) -> list[TResponseInputItem]:
    """
    Convert one thread item, or a sequence of them, into model input items.

    Items are converted in order and their results concatenated. Only the item
    in the final position is converted with `is_last_message=True`.

    Args:
        thread_items: A single thread item or a sequence of thread items.

    Returns:
        A flat list of the input items produced for every thread item.
    """
    if isinstance(thread_items, Sequence):
        # Snapshot in case the caller mutates the sequence while we're awaiting.
        items = list(thread_items)
    else:
        items = [thread_items]

    # Compare positions rather than identity: if the same object appears earlier
    # and again at the end, only the final occurrence is the last message.
    last_index = len(items) - 1
    output: list[TResponseInputItem] = []
    for index, item in enumerate(items):
        output.extend(
            await self._thread_item_to_input_item(
                item,
                is_last_message=index == last_index,
            )
        )
    return output
1247
1248
# Module-level converter instance, created once at import time and used by
# `simple_to_agent_input` below.
_DEFAULT_CONVERTER = ThreadItemConverter()
1250
1251
def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """
    Helper that converts thread items using the default ThreadItemConverter.

    Passes `thread_items` straight to the module-level converter's
    `to_agent_input` and returns whatever that call returns without awaiting it.
    """
    converter = _DEFAULT_CONVERTER
    return converter.to_agent_input(thread_items)
1255