openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
45ae38b39b3b506ac24db075ec0faf8a6a930b5e

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1015lines · modecode

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputMessageContentListParam,
29 ResponseInputTextParam,
30 ResponseOutputText,
31)
32from openai.types.responses.response_input_item_param import (
33 FunctionCallOutput,
34 Message,
35)
36from openai.types.responses.response_output_message import Content
37from openai.types.responses.response_output_text import (
38 Annotation as ResponsesAnnotation,
39)
40from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
41from typing_extensions import assert_never
42
43from .server import stream_widget
44from .store import Store, StoreItemType
45from .types import (
46 Annotation,
47 AssistantMessageContent,
48 AssistantMessageContentPartAdded,
49 AssistantMessageContentPartAnnotationAdded,
50 AssistantMessageContentPartDone,
51 AssistantMessageContentPartTextDelta,
52 AssistantMessageItem,
53 Attachment,
54 ClientToolCallItem,
55 DurationSummary,
56 EndOfTurnItem,
57 FileSource,
58 HiddenContextItem,
59 SDKHiddenContextItem,
60 Task,
61 TaskItem,
62 ThoughtTask,
63 ThreadItem,
64 ThreadItemAddedEvent,
65 ThreadItemDoneEvent,
66 ThreadItemRemovedEvent,
67 ThreadItemUpdatedEvent,
68 ThreadMetadata,
69 ThreadStreamEvent,
70 URLSource,
71 UserMessageItem,
72 UserMessageTagContent,
73 UserMessageTextContent,
74 WidgetItem,
75 Workflow,
76 WorkflowItem,
77 WorkflowSummary,
78 WorkflowTaskAdded,
79 WorkflowTaskUpdated,
80)
81from .widgets import Markdown, Text, WidgetRoot
82
83
class ClientToolCall(BaseModel):
    """
    Returned from tool methods to indicate a client-side tool call.

    When an `AgentContext` has its `client_tool_call` set, `stream_agent_response`
    copies `name` and `arguments` into the `ClientToolCallItem` it emits at the
    end of the run.
    """

    # Name copied into ClientToolCallItem.name.
    name: str
    # Arguments copied into ClientToolCallItem.arguments.
    arguments: dict[str, Any]
91
92
# Marker object placed on an AgentContext event queue by `AgentContext._complete()`;
# `_AsyncQueueIterator` stops iterating when it dequeues an instance of this class.
class _QueueCompleteSentinel: ...
94
95
# Type of the caller-supplied request context that AgentContext carries and
# forwards to the Store methods it calls.
TContext = TypeVar("TContext")
97
98
class AgentContext(BaseModel, Generic[TContext]):
    """
    Per-run state shared between tool code and `stream_agent_response`.

    Carries the thread, the store and the caller's request context, tracks the
    workflow item that is currently being built, and buffers the
    `ThreadStreamEvent`s produced by the helper methods below in an internal
    queue that `stream_agent_response` consumes.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Thread the run belongs to; used for id generation and as `thread_id` of
    # the workflow items created here.
    thread: ThreadMetadata
    # Store used to generate thread and item ids (validation is skipped).
    store: Annotated[Store[TContext], SkipValidation]
    # Opaque caller context forwarded to every store call.
    request_context: TContext
    # Not read in this module. NOTE(review): presumably the id of the previous
    # Responses API response -- confirm against callers.
    previous_response_id: str | None = None
    # When set, `stream_agent_response` finishes the run with a
    # ClientToolCallItem built from it.
    client_tool_call: ClientToolCall | None = None
    # Workflow item currently being built, or None when no workflow is active.
    workflow_item: WorkflowItem | None = None
    # Events emitted through this context, terminated by a _QueueCompleteSentinel.
    # NOTE(review): the default is a single module-level Queue instance; this
    # relies on pydantic giving each model instance its own copy of a private
    # attribute default -- confirm, or switch to a default factory.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """Generate a new store-backed id for the given item type.

        `"thread"` ids come from `store.generate_thread_id`; every other type is
        generated for `thread`, falling back to this context's own thread.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Stream a widget into the thread by enqueueing widget events.

        Delegates to the module-level `stream_widget` helper and puts every
        event it produces on this context's event queue.
        """
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """Finalize the active workflow item, optionally attaching a summary.

        Does nothing when no workflow is active. Otherwise sets the summary
        (falling back to a `DurationSummary` when none was set or provided),
        applies `expanded`, streams a `ThreadItemDoneEvent` and clears
        `workflow_item`.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """Begin streaming a new workflow item.

        The new item replaces `workflow_item`. For non-reasoning workflows that
        have no tasks yet, the `ThreadItemAddedEvent` is deferred until the
        first task is added via `add_workflow_task`.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """Update an existing workflow task and stream the delta.

        Raises:
            ValueError: if no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """Append a workflow task and stream the appropriate event.

        Creates a `"custom"` workflow item on the fly when none is active. The
        first task of a non-reasoning workflow is announced with a
        `ThreadItemAddedEvent` for the whole item; every other task is sent as
        a `WorkflowTaskAdded` update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it sits at the last position. Looking
        # it up with `list.index` matches by equality and would report the
        # position of an earlier task that compares equal to this one.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue a ThreadStreamEvent for downstream processing."""
        await self._events.put(event)

    def _complete(self) -> None:
        """Enqueue the completion sentinel without awaiting."""
        self._events.put_nowait(_QueueCompleteSentinel())
216
217
def _convert_content(content: Content) -> AssistantMessageContent:
    """Translate a Responses output content part into assistant message content.

    `output_text` parts keep their text together with every annotation that
    `_convert_annotation` can represent (unconvertible ones are dropped). Any
    other part is treated as a refusal: its refusal message becomes the text
    and the annotation list is empty.
    """
    if content.type != "output_text":
        return AssistantMessageContent(text=content.refusal, annotations=[])

    converted = (_convert_annotation(raw) for raw in content.annotations)
    kept = [annotation for annotation in converted if annotation is not None]
    return AssistantMessageContent(text=content.text, annotations=kept)
233
234
# Built once and shared by every call: constructing a pydantic TypeAdapter on
# each annotation event repeats the schema construction for no benefit.
_RESPONSES_ANNOTATION_ADAPTER = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)


def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """Convert a Responses annotation into a ChatKit `Annotation`.

    Args:
        raw_annotation: The annotation as delivered by the client; it may be a
            typed annotation object, a dict, or an untyped object.

    Returns:
        An `Annotation` for file, URL and container-file citations, or `None`
        when the annotation type is not handled or a file citation has no
        filename.
    """
    # There is a bug in the OpenAI client that sometimes parses the annotation
    # delta event into the wrong class, resulting in the annotation being a dict
    # or untyped object instead of a ResponsesAnnotation, so validate it first.
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.index,
        )

    if annotation.type == "url_citation":
        return Annotation(
            source=URLSource(
                url=annotation.url,
                title=annotation.title,
            ),
            index=annotation.end_index,
        )

    if annotation.type == "container_file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.end_index,
        )

    # Any other annotation type has no ChatKit representation.
    return None
272
273
274T1 = TypeVar("T1")
275T2 = TypeVar("T2")
276
277
278async def _merge_generators(
279 a: AsyncIterator[T1],
280 b: AsyncIterator[T2],
281) -> AsyncIterator[T1 | T2]:
282 pending: list[AsyncIterator[T1 | T2]] = [a, b]
283 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
284 asyncio.ensure_future(g.__anext__()): g for g in pending
285 }
286 while len(pending_tasks) > 0:
287 done, _ = await asyncio.wait(
288 pending_tasks.keys(), return_when="FIRST_COMPLETED"
289 )
290 stop = False
291 for d in done:
292 try:
293 result = d.result()
294 yield result
295 dg = pending_tasks[d]
296 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
297 except StopAsyncIteration:
298 stop = True
299 finally:
300 del pending_tasks[d]
301 if stop:
302 for task in pending_tasks.keys():
303 if not task.cancel():
304 try:
305 yield task.result()
306 except asyncio.CancelledError:
307 pass
308 except asyncio.InvalidStateError:
309 pass
310 break
311
312
class _EventWrapper:
    """Wraps an event taken from an AgentContext event queue.

    `stream_agent_response` uses `isinstance(..., _EventWrapper)` to tell events
    emitted through the agent context apart from the Agents SDK stream events
    they are merged with.
    """

    def __init__(self, event: ThreadStreamEvent) -> None:
        self.event = event
316
317
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """Async iterator over an event queue that ends at the completion sentinel.

    Every queued event is handed out wrapped in an `_EventWrapper`. Once a
    `_QueueCompleteSentinel` has been dequeued, or `drain_and_complete` has been
    called, the iterator is finished and never reads the queue again.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.completed = False
        self.queue = queue

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            queued = await self.queue.get()
            if not isinstance(queued, _QueueCompleteSentinel):
                return _EventWrapper(queued)
            # Remember the sentinel so later calls stop without awaiting.
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Discard everything queued, without awaiting, and finish the iterator.

        Meant for cleanup paths in which no await may happen. Whatever is
        currently queued -- a completion sentinel included -- is thrown away.
        """
        backlog = self.queue
        while not backlog.empty():
            backlog.get_nowait()
        self.completed = True
351
352
class StreamingThoughtTracker(BaseModel):
    """Tracks the reasoning summary that `stream_agent_response` is streaming
    into a workflow as a `ThoughtTask`, so later delta/done events can be
    matched to it."""

    # `item_id` of the reasoning summary event that started the thought.
    item_id: str
    # `summary_index` of that same event.
    index: int
    # The task appended to the workflow; its content is updated in place.
    task: ThoughtTask
357
358
async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """Convert a streamed Agents SDK run into ChatKit ThreadStreamEvents.

    Two sources are merged: the events of `result.stream_events()` and the
    events that helpers on `context` put on its internal queue. Responses
    events are translated into assistant-message and workflow events; queued
    events are passed through while keeping `context.workflow_item` in sync.

    After the run, remaining queued events are flushed, a still-open workflow
    item is persisted through the store so the next turn can continue it, and a
    `ClientToolCallItem` is emitted when `context.client_tool_call` is set.

    Args:
        context: The agent context whose queue, store and thread are used.
        result: The streaming run to convert.

    Yields:
        ThreadStreamEvent instances in the order they become available.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            re-raised after a `ThreadItemRemovedEvent` has been yielded for
            every item produced during the run and queued events were dropped.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    # Ids of items announced during this run; removed again on a guardrail trip.
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close `item` and return the `ThreadItemDoneEvent` announcing it."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remembered for the ClientToolCallItem emitted after the run.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    # A message ends whatever workflow was in progress.
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # The thought was streamed already: finalize its text.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        # The task was just appended, so it is the last entry.
                        # `list.index` matches by equality and would report the
                        # position of an earlier thought that compares equal.
                        update = WorkflowTaskAdded(
                            task=task,
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Retract everything that was announced during this run.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        # Reuse the ids of the function call seen in the stream when there was
        # one; otherwise generate fresh ids.
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
648
649
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """Yield copies of `base_widget` whose `value` grows with streamed text.

    The unmodified `base_widget` is yielded first. Each
    `response.output_text.delta` carried by a raw response event then yields a
    copy holding all text received so far; every other event is ignored. Once
    `events` is exhausted, a final copy with `streaming` set to False is
    yielded.
    """
    collected = ""
    yield base_widget
    async for event in events:
        if event.type != "raw_response_event":
            continue
        if event.data.type != "response.output_text.delta":
            continue
        collected += event.data.delta
        yield base_widget.model_copy(update={"value": collected})
    yield base_widget.model_copy(update={"value": collected, "streaming": False})
665
666
667class ThreadItemConverter:
668 """
669 Converts thread items to Agent SDK input items.
670 Widgets, Tasks, and Workflows have default conversions but can be customized.
671 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
672 Other item types are converted automatically.
673 """
674
675 async def attachment_to_message_content(
676 self, attachment: Attachment
677 ) -> ResponseInputContentParam:
678 """
679 Convert an attachment in a user message into a message content part to send to the model.
680 Required when attachments are enabled.
681 """
682 raise NotImplementedError(
683 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
684 )
685
686 async def tag_to_message_content(
687 self, tag: UserMessageTagContent
688 ) -> ResponseInputContentParam:
689 """
690 Convert a tag in a user message into a message content part to send to the model as context.
691 Required when tags are used.
692 """
693 raise NotImplementedError(
694 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
695 )
696
697 async def hidden_context_to_input(
698 self, item: HiddenContextItem
699 ) -> TResponseInputItem | list[TResponseInputItem] | None:
700 """
701 Convert a HiddenContextItem into input item(s) to send to the model.
702 Required to override when HiddenContextItems with non-string content are used.
703 """
704 if not isinstance(item.content, str):
705 raise NotImplementedError(
706 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
707 )
708
709 text = (
710 "Hidden context for the agent (not shown to the user):\n"
711 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
712 )
713 return Message(
714 type="message",
715 content=[
716 ResponseInputTextParam(
717 type="input_text",
718 text=text,
719 )
720 ],
721 role="user",
722 )
723
724 async def sdk_hidden_context_to_input(
725 self, item: SDKHiddenContextItem
726 ) -> TResponseInputItem | list[TResponseInputItem] | None:
727 """
728 Convert a SDKHiddenContextItem into input item to send to the model.
729 This is used by the ChatKit Python SDK for storing additional context
730 for internal operations.
731 Override if you want to wrap the content in a different format.
732 """
733 text = (
734 "Hidden context for the agent (not shown to the user):\n"
735 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
736 )
737 return Message(
738 type="message",
739 content=[
740 ResponseInputTextParam(
741 type="input_text",
742 text=text,
743 )
744 ],
745 role="user",
746 )
747
748 async def task_to_input(
749 self, item: TaskItem
750 ) -> TResponseInputItem | list[TResponseInputItem] | None:
751 """
752 Convert a TaskItem into input item(s) to send to the model.
753 """
754 if item.task.type != "custom" or (
755 not item.task.title and not item.task.content
756 ):
757 return None
758 title = f"{item.task.title}" if item.task.title else ""
759 content = f"{item.task.content}" if item.task.content else ""
760 task_text = f"{title}: {content}" if title and content else title or content
761 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
762 return Message(
763 type="message",
764 content=[
765 ResponseInputTextParam(
766 type="input_text",
767 text=text,
768 )
769 ],
770 role="user",
771 )
772
773 async def workflow_to_input(
774 self, item: WorkflowItem
775 ) -> TResponseInputItem | list[TResponseInputItem] | None:
776 """
777 Convert a TaskItem into input item(s) to send to the model.
778 Returns WorkflowItem.response_items by default.
779 """
780 messages = []
781 for task in item.workflow.tasks:
782 if task.type != "custom" or (not task.title and not task.content):
783 continue
784
785 title = f"{task.title}" if task.title else ""
786 content = f"{task.content}" if task.content else ""
787 task_text = f"{title}: {content}" if title and content else title or content
788 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
789 messages.append(
790 Message(
791 type="message",
792 content=[
793 ResponseInputTextParam(
794 type="input_text",
795 text=text,
796 )
797 ],
798 role="user",
799 )
800 )
801 return messages
802
803 async def widget_to_input(
804 self, item: WidgetItem
805 ) -> TResponseInputItem | list[TResponseInputItem] | None:
806 """
807 Convert a WidgetItem into input item(s) to send to the model.
808 By default, WidgetItems converted to a text description with a JSON representation of the widget.
809 """
810 return Message(
811 type="message",
812 content=[
813 ResponseInputTextParam(
814 type="input_text",
815 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
816 + item.widget.model_dump_json(
817 exclude_unset=True, exclude_none=True
818 ),
819 )
820 ],
821 role="user",
822 )
823
824 async def user_message_to_input(
825 self, item: UserMessageItem, is_last_message: bool = True
826 ) -> TResponseInputItem | list[TResponseInputItem] | None:
827 # Build the user text exactly as typed, rendering tags as @key
828 message_text_parts: list[str] = []
829 # Track tags separately to add system context
830 raw_tags: list[UserMessageTagContent] = []
831
832 for part in item.content:
833 if isinstance(part, UserMessageTextContent):
834 message_text_parts.append(part.text)
835 elif isinstance(part, UserMessageTagContent):
836 message_text_parts.append(f"@{part.text}")
837 raw_tags.append(part)
838 else:
839 assert_never(part)
840
841 user_text_item = Message(
842 role="user",
843 type="message",
844 content=[
845 ResponseInputTextParam(
846 type="input_text", text="".join(message_text_parts)
847 ),
848 *[
849 await self.attachment_to_message_content(a)
850 for a in item.attachments
851 ],
852 ],
853 )
854
855 # Build system items (prepend later): quoted text and @-mention context
856 context_items: list[TResponseInputItem] = []
857
858 if item.quoted_text and is_last_message:
859 context_items.append(
860 Message(
861 role="user",
862 type="message",
863 content=[
864 ResponseInputTextParam(
865 type="input_text",
866 text=f"The user is referring to this in particular: \n{item.quoted_text}",
867 )
868 ],
869 )
870 )
871
872 # Dedupe tags (preserve order) and resolve to message content
873 if raw_tags:
874 seen, uniq_tags = set(), []
875 for t in raw_tags:
876 if t.text not in seen:
877 seen.add(t.text)
878 uniq_tags.append(t)
879
880 tag_content: ResponseInputMessageContentListParam = [
881 # should return summarized text items
882 await self.tag_to_message_content(tag)
883 for tag in uniq_tags
884 ]
885
886 if tag_content:
887 context_items.append(
888 Message(
889 role="user",
890 type="message",
891 content=[
892 ResponseInputTextParam(
893 type="input_text",
894 text=cleandoc("""
895 # User-provided context for @-mentions
896 - When referencing resolved entities, use their canonical names **without** '@'.
897 - The '@' form appears only in user text and should not be echoed.
898 """).strip(),
899 ),
900 *tag_content,
901 ],
902 )
903 )
904
905 return [user_text_item, *context_items]
906
907 async def assistant_message_to_input(
908 self, item: AssistantMessageItem
909 ) -> TResponseInputItem | list[TResponseInputItem] | None:
910 return EasyInputMessageParam(
911 type="message",
912 content=[
913 # content param doesn't support the assistant message content types
914 cast(
915 ResponseInputContentParam,
916 ResponseOutputText(
917 type="output_text",
918 text=c.text,
919 annotations=[], # TODO: these should be sent back as well
920 ).model_dump(),
921 )
922 for c in item.content
923 ],
924 role="assistant",
925 )
926
927 async def client_tool_call_to_input(
928 self, item: ClientToolCallItem
929 ) -> TResponseInputItem | list[TResponseInputItem] | None:
930 if item.status == "pending":
931 # Filter out pending tool calls - they cannot be sent to the model
932 return None
933
934 return [
935 ResponseFunctionToolCallParam(
936 type="function_call",
937 call_id=item.call_id,
938 name=item.name,
939 arguments=json.dumps(item.arguments),
940 ),
941 FunctionCallOutput(
942 type="function_call_output",
943 call_id=item.call_id,
944 output=json.dumps(item.output),
945 ),
946 ]
947
948 async def end_of_turn_to_input(
949 self, item: EndOfTurnItem
950 ) -> TResponseInputItem | list[TResponseInputItem] | None:
951 # Only used for UI hints - you shouldn't need to override this
952 return None
953
954 async def _thread_item_to_input_item(
955 self,
956 item: ThreadItem,
957 is_last_message: bool = True,
958 ) -> list[TResponseInputItem]:
959 match item:
960 case UserMessageItem():
961 out = await self.user_message_to_input(item, is_last_message) or []
962 return out if isinstance(out, list) else [out]
963 case AssistantMessageItem():
964 out = await self.assistant_message_to_input(item) or []
965 return out if isinstance(out, list) else [out]
966 case ClientToolCallItem():
967 out = await self.client_tool_call_to_input(item) or []
968 return out if isinstance(out, list) else [out]
969 case EndOfTurnItem():
970 out = await self.end_of_turn_to_input(item) or []
971 return out if isinstance(out, list) else [out]
972 case WidgetItem():
973 out = await self.widget_to_input(item) or []
974 return out if isinstance(out, list) else [out]
975 case WorkflowItem():
976 out = await self.workflow_to_input(item) or []
977 return out if isinstance(out, list) else [out]
978 case TaskItem():
979 out = await self.task_to_input(item) or []
980 return out if isinstance(out, list) else [out]
981 case HiddenContextItem():
982 out = await self.hidden_context_to_input(item) or []
983 return out if isinstance(out, list) else [out]
984 case SDKHiddenContextItem():
985 out = await self.sdk_hidden_context_to_input(item) or []
986 return out if isinstance(out, list) else [out]
987 case _:
988 assert_never(item)
989
990 async def to_agent_input(
991 self,
992 thread_items: Sequence[ThreadItem] | ThreadItem,
993 ) -> list[TResponseInputItem]:
994 if isinstance(thread_items, Sequence):
995 # shallow copy in case caller mutates the list while we're iterating
996 thread_items = thread_items[:]
997 else:
998 thread_items = [thread_items]
999 output: list[TResponseInputItem] = []
1000 for item in thread_items:
1001 output.extend(
1002 await self._thread_item_to_input_item(
1003 item,
1004 is_last_message=item is thread_items[-1],
1005 )
1006 )
1007 return output
1008
1009
# Shared converter instance that uses the default conversions only.
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """Helper that converts thread items using the default ThreadItemConverter.

    This function is not itself `async`: it returns the awaitable produced by
    `ThreadItemConverter.to_agent_input`, so the caller has to await the result
    to obtain the list of input items.
    """
    return _DEFAULT_CONVERTER.to_agent_input(thread_items)
1016