openai/chatkit-python

Public

mirrored fromhttps://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
8b263facaeaba1c1035a4d92befabf8f2c75436c

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1249lines · modecode

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputImageParam,
29 ResponseInputMessageContentListParam,
30 ResponseInputTextParam,
31 ResponseOutputText,
32)
33from openai.types.responses.response_input_item_param import (
34 FunctionCallOutput,
35 Message,
36)
37from openai.types.responses.response_output_message import Content
38from openai.types.responses.response_output_text import (
39 Annotation as ResponsesAnnotation,
40)
41from openai.types.responses.response_output_text import (
42 AnnotationContainerFileCitation,
43 AnnotationFileCitation,
44 AnnotationURLCitation,
45)
46from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
47from typing_extensions import assert_never
48
49from .server import stream_widget
50from .store import Store, StoreItemType
51from .types import (
52 Annotation,
53 AssistantMessageContent,
54 AssistantMessageContentPartAdded,
55 AssistantMessageContentPartAnnotationAdded,
56 AssistantMessageContentPartDone,
57 AssistantMessageContentPartTextDelta,
58 AssistantMessageItem,
59 Attachment,
60 ClientToolCallItem,
61 DurationSummary,
62 EndOfTurnItem,
63 FileSource,
64 GeneratedImage,
65 GeneratedImageItem,
66 GeneratedImageUpdated,
67 HiddenContextItem,
68 SDKHiddenContextItem,
69 StructuredInputItem,
70 Task,
71 TaskItem,
72 ThoughtTask,
73 ThreadItem,
74 ThreadItemAddedEvent,
75 ThreadItemDoneEvent,
76 ThreadItemRemovedEvent,
77 ThreadItemUpdatedEvent,
78 ThreadMetadata,
79 ThreadStreamEvent,
80 URLSource,
81 UserMessageItem,
82 UserMessageTagContent,
83 UserMessageTextContent,
84 WidgetItem,
85 Workflow,
86 WorkflowItem,
87 WorkflowSummary,
88 WorkflowTaskAdded,
89 WorkflowTaskUpdated,
90)
91from .widgets import Markdown, Text, WidgetRoot
92
93
class ClientToolCall(BaseModel):
    """
    Value a tool method returns to signal that a tool call should happen
    on the client side.
    """

    # Name of the client-side tool.
    name: str
    # Arguments for the client-side tool, keyed by parameter name.
    arguments: dict[str, Any]
101
102
class _QueueCompleteSentinel:
    """Marker object put on the event queue to signal that it is complete."""
104
105
TContext = TypeVar("TContext")


class AgentContext(BaseModel, Generic[TContext]):
    """State shared between an agent run and the ChatKit response stream.

    The helper methods on this model enqueue ``ThreadStreamEvent`` objects on
    a private queue (``_events``); ``_complete`` enqueues a sentinel that
    marks the end of that queue.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Thread that streamed events and generated ids belong to.
    thread: ThreadMetadata
    # Store used for id generation; skipped by pydantic validation.
    store: Annotated[Store[TContext], SkipValidation]
    # Per-request value passed through to every store call made here.
    request_context: TContext
    previous_response_id: str | None = None
    client_tool_call: ClientToolCall | None = None
    # Workflow currently being built/streamed, if any.
    workflow_item: WorkflowItem | None = None
    generated_image_item: GeneratedImageItem | None = None
    # NOTE(review): relies on pydantic giving each instance its own copy of
    # this private-attribute default - confirm against the pydantic version in use.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """Generate a new store-backed id for the given item type.

        Args:
            type: Kind of id to generate; ``"thread"`` yields a thread id,
                anything else yields an item id.
            thread: Thread the item id is generated for; defaults to
                ``self.thread``. Ignored for thread ids.

        Returns:
            The id produced by the store.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Stream a widget into the thread by enqueueing widget events.

        Args:
            widget: A widget, or an async generator of widget states.
            copy_text: Optional text forwarded unchanged to the module-level
                ``stream_widget`` helper.
        """
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """Finalize the active workflow item, optionally attaching a summary.

        Does nothing when no workflow is active. Otherwise streams a
        ``ThreadItemDoneEvent`` for the workflow and clears ``workflow_item``.

        Args:
            summary: Summary to attach. When omitted and the workflow has no
                summary yet, a ``DurationSummary`` of the elapsed whole
                seconds since the item was created is used.
            expanded: Value stored on ``workflow.expanded`` before the done
                event is streamed.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """Begin streaming a new workflow item.

        The added event is streamed immediately for reasoning workflows and
        for workflows that already have tasks; otherwise it is deferred until
        the first task is added via ``add_workflow_task``.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """Update an existing workflow task and stream the delta.

        Raises:
            ValueError: If there is no active workflow.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """Append a workflow task and stream the appropriate event.

        A ``custom`` workflow item is created on the fly when none is active.
        The first task of a non-reasoning workflow triggers the (possibly
        deferred) added event for the whole item; every other task is
        streamed as a ``WorkflowTaskAdded`` update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so its position is the last slot.
        # list.index() compares by equality and would return the position of
        # an earlier task that happens to be equal to this one.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue a ThreadStreamEvent for downstream processing."""
        await self._events.put(event)

    def _complete(self):
        """Enqueue the completion sentinel without awaiting."""
        self._events.put_nowait(_QueueCompleteSentinel())
227
228
229T1 = TypeVar("T1")
230T2 = TypeVar("T2")
231
232
233async def _merge_generators(
234 a: AsyncIterator[T1],
235 b: AsyncIterator[T2],
236) -> AsyncIterator[T1 | T2]:
237 pending: list[AsyncIterator[T1 | T2]] = [a, b]
238 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
239 asyncio.ensure_future(g.__anext__()): g for g in pending
240 }
241 while len(pending_tasks) > 0:
242 done, _ = await asyncio.wait(
243 pending_tasks.keys(), return_when="FIRST_COMPLETED"
244 )
245 stop = False
246 for d in done:
247 try:
248 result = d.result()
249 yield result
250 dg = pending_tasks[d]
251 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
252 except StopAsyncIteration:
253 stop = True
254 finally:
255 del pending_tasks[d]
256 if stop:
257 for task in pending_tasks.keys():
258 if not task.cancel():
259 try:
260 yield task.result()
261 except asyncio.CancelledError:
262 pass
263 except asyncio.InvalidStateError:
264 pass
265 break
266
267
class _EventWrapper:
    """Holds a ThreadStreamEvent so that wrapped events can be recognised by type."""

    def __init__(self, event: ThreadStreamEvent):
        # Wrapped event, exposed unchanged.
        self.event = event
271
272
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """Async iterator over an event queue that stops at the completion sentinel.

    Each queued event is returned wrapped in an ``_EventWrapper``.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        # Becomes True once the sentinel was seen or the queue was drained.
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            queued = await self.queue.get()
            if not isinstance(queued, _QueueCompleteSentinel):
                return _EventWrapper(queued)
            # Sentinel reached: remember it so later calls stop immediately.
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Synchronously discard everything queued and mark the iterator completed.

        Meant for cleanup paths that must not await. Every queued item is
        thrown away, the completion sentinel included.
        """
        try:
            while True:
                self.queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        self.completed = True
306
307
class StreamingThoughtTracker(BaseModel):
    """Identifies the reasoning summary part that is currently streamed into a thought task."""

    # Id of the Responses item the summary text belongs to.
    item_id: str
    # Summary index of the part being streamed.
    index: int
    # Thought task that receives the streamed text.
    task: ThoughtTask
312
313
class ResponseStreamConverter:
    """Adapter used by `stream_agent_response` to turn streamed Agents SDK output
    into the values ChatKit thread items and thread stream events carry.

    Subclasses can override the methods below to change how streamed data
    (image generation results, partial updates, citations) is adapted.
    """

    # Expected number of partial image updates for one image generation
    # result. When set, partial image indices are normalized to a progress
    # value in [0, 1]; when unset, every partial update gets progress 0.
    partial_images: int | None = None

    def __init__(self, *, partial_images: int | None = None):
        """
        Args:
            partial_images: Expected number of partial image updates per image
                generation result, or None to skip progress normalization.
        """
        self.partial_images = partial_images

    async def base64_image_to_url(
        self,
        image_id: str,
        base64_image: str,
        partial_image_index: int | None = None,
    ) -> str:
        """
        Turn a base64-encoded image into the URL stored on thread items for
        image generation results.

        Args:
            image_id: ID of the image generation call; stable across partial updates.
            base64_image: The base64-encoded image.
            partial_image_index: Index of the partial image update, starting from 0.

        Returns:
            A URL string; by default a PNG data URL embedding the image.
        """
        return f"data:image/png;base64,{base64_image}"

    def partial_image_index_to_progress(self, partial_image_index: int) -> float:
        """
        Map a partial image index to a normalized progress value.

        Args:
            partial_image_index: Index of the partial image update, starting from 0.

        Returns:
            A float capped at 1.0; 0.0 when `partial_images` is unset or not positive.
        """
        expected = self.partial_images
        if expected is not None and expected > 0:
            return min(1.0, partial_image_index / expected)
        return 0.0

    async def file_citation_to_annotation(
        self, file_citation: AnnotationFileCitation
    ) -> Annotation | None:
        """Build an assistant message annotation from a Responses API file citation.

        Returns None when the citation has no filename.
        """
        name = file_citation.filename
        if name:
            return Annotation(
                source=FileSource(filename=name, title=name),
                index=file_citation.index,
            )
        return None

    async def container_file_citation_to_annotation(
        self, container_file_citation: AnnotationContainerFileCitation
    ) -> Annotation | None:
        """Build an assistant message annotation from a Responses API container file citation.

        Returns None when the citation has no filename.
        """
        name = container_file_citation.filename
        if name:
            return Annotation(
                source=FileSource(filename=name, title=name),
                index=container_file_citation.end_index,
            )
        return None

    async def url_citation_to_annotation(
        self, url_citation: AnnotationURLCitation
    ) -> Annotation | None:
        """Build an assistant message annotation from a Responses API URL citation."""
        source = URLSource(url=url_citation.url, title=url_citation.title)
        return Annotation(source=source, index=url_citation.end_index)


_DEFAULT_RESPONSE_STREAM_CONVERTER = ResponseStreamConverter()
415
416
async def _convert_content(
    content: Content, converter: ResponseStreamConverter
) -> AssistantMessageContent:
    """Convert a Responses output content part into assistant message content.

    ``output_text`` parts keep their text and every annotation that converts
    to a non-None value; any other part contributes its ``refusal`` text with
    no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    kept = []
    for raw in content.annotations:
        converted = await _convert_annotation(raw, converter)
        if converted is not None:
            kept.append(converted)
    return AssistantMessageContent(
        text=content.text,
        annotations=kept,
    )
435
436
# Built once at import time: constructing a TypeAdapter builds its validator,
# which is wasteful to repeat for every streamed annotation.
_RESPONSES_ANNOTATION_ADAPTER: TypeAdapter[ResponsesAnnotation] = TypeAdapter(
    ResponsesAnnotation
)


async def _convert_annotation(
    raw_annotation: object, converter: ResponseStreamConverter
) -> Annotation | None:
    """Convert a raw Responses annotation into a ChatKit annotation.

    Args:
        raw_annotation: The annotation as delivered by the client; it may be
            a typed object, a dict or an untyped object.
        converter: Supplies the per-citation-type conversion methods.

    Returns:
        The converted annotation, or None when the converter drops it or the
        annotation type is not one of the handled citation types.
    """
    # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
    # resulting in the annotation being a dict or untyped object instead of a ResponsesAnnotation
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "file_citation":
        return await converter.file_citation_to_annotation(annotation)

    if annotation.type == "url_citation":
        return await converter.url_citation_to_annotation(annotation)

    if annotation.type == "container_file_citation":
        return await converter.container_file_citation_to_annotation(annotation)

    return None
456
457
async def stream_agent_response(
    context: AgentContext,
    result: RunResultStreaming,
    *,
    converter: ResponseStreamConverter = _DEFAULT_RESPONSE_STREAM_CONVERTER,
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Convert a streamed Agents SDK run into ChatKit thread stream events.

    This function consumes a streaming run result and yields `ThreadStreamEvent`
    objects as the run progresses. Events enqueued through the `AgentContext`
    helpers are merged into the same stream.

    Args:
        context: The AgentContext to use for the stream.
        result: The RunResultStreaming to convert.
        converter: Defines overridable methods for adapting streamed data (such as image
            generation results and partial updates) into the forms expected by ChatKit.

    Returns:
        An async iterator that yields thread stream events representing the run result.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            Re-raised after a `ThreadItemRemovedEvent` has been yielded for
            every item produced so far and the context queue was drained.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close `item` and return its done event; clears the active workflow if it is `item`."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remembered for the client tool call item emitted at the end.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = await _convert_content(event.part, converter)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = await _convert_annotation(event.annotation, converter)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[
                                await _convert_content(c, converter)
                                for c in item.content
                            ],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call":
                    ctx.generated_image_item = GeneratedImageItem(
                        id=ctx.generate_id("message"),
                        thread_id=thread.id,
                        created_at=datetime.now(),
                        image=None,
                    )
                    produced_items.add(ctx.generated_image_item.id)
                    yield ThreadItemAddedEvent(item=ctx.generated_image_item)
            elif event.type == "response.image_generation_call.partial_image":
                if not ctx.generated_image_item:
                    continue

                url = await converter.base64_image_to_url(
                    image_id=event.item_id,
                    base64_image=event.partial_image_b64,
                    partial_image_index=event.partial_image_index,
                )
                progress = converter.partial_image_index_to_progress(
                    event.partial_image_index
                )

                ctx.generated_image_item.image = GeneratedImage(
                    id=event.item_id, url=url
                )

                yield ThreadItemUpdatedEvent(
                    item_id=ctx.generated_image_item.id,
                    update=GeneratedImageUpdated(
                        image=ctx.generated_image_item.image, progress=progress
                    ),
                )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        # The task was just appended, so it sits in the last
                        # slot. list.index() compares by equality and would
                        # point at an earlier thought with the same text.
                        update = WorkflowTaskAdded(
                            task=task,
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[
                                await _convert_content(c, converter)
                                for c in item.content
                            ],
                            created_at=datetime.now(),
                        ),
                    )
                elif item.type == "image_generation_call" and item.result:
                    if not ctx.generated_image_item:
                        continue

                    url = await converter.base64_image_to_url(
                        image_id=item.id,
                        base64_image=item.result,
                    )
                    image = GeneratedImage(id=item.id, url=url)

                    ctx.generated_image_item.image = image
                    yield ThreadItemDoneEvent(item=ctx.generated_image_item)

                    ctx.generated_image_item = None

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Retract everything this run produced before surfacing the tripwire.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
816
817
818TWidget = TypeVar("TWidget", bound=Markdown | Text)
819
820
async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """Yield copies of `base_widget` whose value grows with the streamed output text.

    The unmodified base widget is yielded first. Every
    ``response.output_text.delta`` raw response event then yields a copy whose
    ``value`` is the text accumulated so far, and once the events are
    exhausted a final copy with ``streaming`` set to False is yielded.
    """
    accumulated = ""
    yield base_widget
    async for event in events:
        if event.type != "raw_response_event":
            continue
        data = event.data
        if data.type != "response.output_text.delta":
            continue
        accumulated += data.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
833
834
835class ThreadItemConverter:
836 """
837 Converts thread items to Agent SDK input items.
838 Widgets, Tasks, and Workflows have default conversions but can be customized.
839 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
840 Other item types are converted automatically.
841 """
842
async def attachment_to_message_content(
    self, attachment: Attachment
) -> ResponseInputContentParam:
    """
    Turn an attachment on a user message into a message content part for the model.
    Must be overridden when attachments are enabled; the default always raises.
    """
    reason = "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
    raise NotImplementedError(reason)
853
async def tag_to_message_content(
    self, tag: UserMessageTagContent
) -> ResponseInputContentParam:
    """
    Turn a tag on a user message into a message content part sent to the model as context.
    Must be overridden when tags are used; the default always raises.
    """
    reason = "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
    raise NotImplementedError(reason)
864
async def generated_image_to_input(
    self, item: GeneratedImageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a GeneratedImageItem into input item(s) for the model.
    Returns None when the item has no image. Override to customize the
    conversion, for example when the generated image url is not publicly reachable.
    """
    image = item.image
    if image:
        intro = ResponseInputTextParam(
            type="input_text",
            text="The following image was generated by the agent.",
        )
        picture = ResponseInputImageParam(
            type="input_image",
            detail="auto",
            image_url=image.url,
        )
        return Message(
            type="message",
            content=[intro, picture],
            role="user",
        )
    return None
891
async def hidden_context_to_input(
    self, item: HiddenContextItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a HiddenContextItem into input item(s) for the model.
    String content is wrapped in a <HiddenContext> block inside a user message;
    any other content raises, so override this when non-string content is used.
    """
    if isinstance(item.content, str):
        wrapped = (
            "Hidden context for the agent (not shown to the user):\n"
            f"<HiddenContext>\n{item.content}\n</HiddenContext>"
        )
        part = ResponseInputTextParam(
            type="input_text",
            text=wrapped,
        )
        return Message(
            type="message",
            content=[part],
            role="user",
        )

    raise NotImplementedError(
        "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
    )
918
async def sdk_hidden_context_to_input(
    self, item: SDKHiddenContextItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a SDKHiddenContextItem into an input item for the model.
    The ChatKit Python SDK uses these items to store additional context for
    internal operations. Override to wrap the content in a different format.
    """
    wrapped = (
        "Hidden context for the agent (not shown to the user):\n"
        f"<HiddenContext>\n{item.content}\n</HiddenContext>"
    )
    part = ResponseInputTextParam(
        type="input_text",
        text=wrapped,
    )
    return Message(
        type="message",
        content=[part],
        role="user",
    )
942
async def task_to_input(
    self, item: TaskItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a TaskItem into input item(s) for the model.
    Only custom tasks with a title or content are converted; others yield None.
    """
    task = item.task
    if task.type != "custom":
        return None
    if not task.title and not task.content:
        return None

    heading = f"{task.title}" if task.title else ""
    body = f"{task.content}" if task.content else ""
    if heading and body:
        task_text = f"{heading}: {body}"
    else:
        task_text = heading or body
    text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
    part = ResponseInputTextParam(
        type="input_text",
        text=text,
    )
    return Message(
        type="message",
        content=[part],
        role="user",
    )
967
async def structured_input_to_input(
    self, item: StructuredInputItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a StructuredInputItem into input item(s) for the model.
    Each question becomes one bullet line reporting it as unanswered, skipped,
    or answered with its comma-separated values.
    """

    def describe(question) -> str:
        # One bullet per question, labelled by the state of its answer.
        answer = question.answer
        if answer is None:
            return f"- {question.question}: unanswered"
        if answer.skipped:
            return f"- {question.question}: skipped"
        return f"- {question.question}: {', '.join(answer.values)}"

    bullet_lines = [describe(question) for question in item.inputs]
    header = (
        "A structured input request was displayed to the user with the following "
        f"status: {item.status}\n<StructuredInput>\n"
    )
    text = header + "\n".join(bullet_lines) + "\n</StructuredInput>"
    part = ResponseInputTextParam(
        type="input_text",
        text=text,
    )
    return Message(
        type="message",
        content=[
            part,
        ],
        role="user",
    )
1000
async def workflow_to_input(
    self, item: WorkflowItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a WorkflowItem into input item(s) for the model.
    Returns one user message per custom task that has a title or content;
    all other tasks are skipped, so the result may be an empty list.
    """

    def describe(task) -> str:
        # Same "<title>: <content>" layout, falling back to whichever part exists.
        heading = f"{task.title}" if task.title else ""
        body = f"{task.content}" if task.content else ""
        task_text = f"{heading}: {body}" if heading and body else heading or body
        return f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"

    return [
        Message(
            type="message",
            content=[
                ResponseInputTextParam(
                    type="input_text",
                    text=describe(task),
                )
            ],
            role="user",
        )
        for task in item.workflow.tasks
        if task.type == "custom" and (task.title or task.content)
    ]
1030
async def widget_to_input(
    self, item: WidgetItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Turn a WidgetItem into input item(s) for the model.
    By default the widget becomes a text description followed by its JSON
    representation (unset and None fields are left out of the JSON).
    """
    intro = f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
    widget_json = item.widget.model_dump_json(
        exclude_unset=True, exclude_none=True
    )
    part = ResponseInputTextParam(
        type="input_text",
        text=intro + widget_json,
    )
    return Message(
        type="message",
        content=[part],
        role="user",
    )
1051
async def user_message_to_input(
    self, item: UserMessageItem, is_last_message: bool = True
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a UserMessageItem into input item(s) to send to the model.

    The first returned message carries the user's text as typed (tags are
    rendered as ``@<text>``) followed by one content part per attachment.
    Up to two additional user-role context messages are appended after it:

    - the quoted text, only when ``item.quoted_text`` is set and
      ``is_last_message`` is true;
    - the @-mention context: a fixed instruction header followed by the
      result of ``tag_to_message_content`` for each distinct tag (first
      occurrence wins, original order kept).
    """
    typed_chunks: list[str] = []
    # Distinct tags keyed by their text; dicts preserve first-insertion order.
    distinct_tags: dict[str, UserMessageTagContent] = {}

    for part in item.content:
        if isinstance(part, UserMessageTextContent):
            typed_chunks.append(part.text)
        elif isinstance(part, UserMessageTagContent):
            typed_chunks.append(f"@{part.text}")
            distinct_tags.setdefault(part.text, part)
        else:
            assert_never(part)

    # Main message: typed text first, then every attachment in order.
    primary_content: ResponseInputMessageContentListParam = [
        ResponseInputTextParam(type="input_text", text="".join(typed_chunks))
    ]
    for attachment in item.attachments:
        primary_content.append(await self.attachment_to_message_content(attachment))

    result: list[TResponseInputItem] = [
        Message(role="user", type="message", content=primary_content)
    ]

    # Quoted text is only surfaced for the item flagged as the last message.
    if is_last_message and item.quoted_text:
        quote_part = ResponseInputTextParam(
            type="input_text",
            text=f"The user is referring to this in particular: \n{item.quoted_text}",
        )
        result.append(Message(role="user", type="message", content=[quote_part]))

    if distinct_tags:
        header = cleandoc("""
            # User-provided context for @-mentions
            - When referencing resolved entities, use their canonical names **without** '@'.
            - The '@' form appears only in user text and should not be echoed.
        """).strip()
        mention_content: ResponseInputMessageContentListParam = [
            ResponseInputTextParam(type="input_text", text=header)
        ]
        for tag in distinct_tags.values():
            # should return summarized text items
            mention_content.append(await self.tag_to_message_content(tag))
        result.append(Message(role="user", type="message", content=mention_content))

    return result
1134
async def assistant_message_to_input(
    self, item: AssistantMessageItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an AssistantMessageItem into an assistant-role input message.

    Every content part is re-emitted as an ``output_text`` entry holding only
    the part's text; annotations are currently not forwarded.
    """
    replayed_parts = []
    for part in item.content:
        output_text = ResponseOutputText(
            type="output_text",
            text=part.text,
            annotations=[],  # TODO: these should be sent back as well
        )
        # content param doesn't support the assistant message content types,
        # so the dumped dict is cast to the expected param type.
        replayed_parts.append(
            cast(ResponseInputContentParam, output_text.model_dump())
        )
    return EasyInputMessageParam(
        type="message",
        content=replayed_parts,
        role="assistant",
    )
1154
async def client_tool_call_to_input(
    self, item: ClientToolCallItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert a ClientToolCallItem into input item(s) to send to the model.

    A non-pending call is replayed as a pair: the ``function_call`` with
    JSON-encoded arguments followed by its ``function_call_output`` with the
    JSON-encoded output, both sharing the item's ``call_id``. Pending calls
    yield ``None``.
    """
    if item.status != "pending":
        invocation = ResponseFunctionToolCallParam(
            type="function_call",
            call_id=item.call_id,
            name=item.name,
            arguments=json.dumps(item.arguments),
        )
        outcome = FunctionCallOutput(
            type="function_call_output",
            call_id=item.call_id,
            output=json.dumps(item.output),
        )
        return [invocation, outcome]

    # Filter out pending tool calls - they cannot be sent to the model
    return None
1175
async def end_of_turn_to_input(
    self, item: EndOfTurnItem
) -> TResponseInputItem | list[TResponseInputItem] | None:
    """
    Convert an EndOfTurnItem into input item(s) to send to the model.

    End-of-turn items are only used for UI hints, so nothing is produced for
    the model. You shouldn't need to override this.
    """
    return None
1181
1182 async def _thread_item_to_input_item(
1183 self,
1184 item: ThreadItem,
1185 is_last_message: bool = True,
1186 ) -> list[TResponseInputItem]:
1187 match item:
1188 case UserMessageItem():
1189 out = await self.user_message_to_input(item, is_last_message) or []
1190 return out if isinstance(out, list) else [out]
1191 case AssistantMessageItem():
1192 out = await self.assistant_message_to_input(item) or []
1193 return out if isinstance(out, list) else [out]
1194 case ClientToolCallItem():
1195 out = await self.client_tool_call_to_input(item) or []
1196 return out if isinstance(out, list) else [out]
1197 case EndOfTurnItem():
1198 out = await self.end_of_turn_to_input(item) or []
1199 return out if isinstance(out, list) else [out]
1200 case WidgetItem():
1201 out = await self.widget_to_input(item) or []
1202 return out if isinstance(out, list) else [out]
1203 case WorkflowItem():
1204 out = await self.workflow_to_input(item) or []
1205 return out if isinstance(out, list) else [out]
1206 case TaskItem():
1207 out = await self.task_to_input(item) or []
1208 return out if isinstance(out, list) else [out]
1209 case StructuredInputItem():
1210 out = await self.structured_input_to_input(item) or []
1211 return out if isinstance(out, list) else [out]
1212 case HiddenContextItem():
1213 out = await self.hidden_context_to_input(item) or []
1214 return out if isinstance(out, list) else [out]
1215 case SDKHiddenContextItem():
1216 out = await self.sdk_hidden_context_to_input(item) or []
1217 return out if isinstance(out, list) else [out]
1218 case GeneratedImageItem():
1219 out = await self.generated_image_to_input(item) or []
1220 return out if isinstance(out, list) else [out]
1221 case _:
1222 assert_never(item)
1223
async def to_agent_input(
    self,
    thread_items: Sequence[ThreadItem] | ThreadItem,
) -> list[TResponseInputItem]:
    """
    Convert one thread item, or a sequence of them, into model input items.

    Items are converted in order via ``_thread_item_to_input_item`` and the
    resulting lists are concatenated. Only the item at the final position is
    converted with ``is_last_message=True``. The flag is derived from the
    position rather than from object identity, so an object that occurs more
    than once in the sequence is not treated as last at its earlier
    occurrences.

    Args:
        thread_items: A single thread item or a sequence of thread items.

    Returns:
        The flattened list of input items; empty for an empty sequence.
    """
    if isinstance(thread_items, Sequence):
        # Snapshot the sequence in case the caller mutates it while we are
        # awaiting conversions.
        items = list(thread_items)
    else:
        items = [thread_items]

    last_index = len(items) - 1
    output: list[TResponseInputItem] = []
    for index, item in enumerate(items):
        output.extend(
            await self._thread_item_to_input_item(
                item,
                is_last_message=index == last_index,
            )
        )
    return output
1242
1243
# Module-level converter instance used by simple_to_agent_input.
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """Helper that converts thread items using the default ThreadItemConverter.

    This is a plain (non-async) function: it hands back the awaitable created
    by ``ThreadItemConverter.to_agent_input``, so the caller must await the
    result to obtain the list of input items.
    """
    conversion = _DEFAULT_CONVERTER.to_agent_input(thread_items)
    return conversion
1250