openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
6df9ee60da4de3f063e752e60bd3ceef920f4e78

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1005lines · modecode

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputMessageContentListParam,
29 ResponseInputTextParam,
30 ResponseOutputText,
31)
32from openai.types.responses.response_input_item_param import (
33 FunctionCallOutput,
34 Message,
35)
36from openai.types.responses.response_output_message import Content
37from openai.types.responses.response_output_text import (
38 Annotation as ResponsesAnnotation,
39)
40from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
41from typing_extensions import assert_never
42
43from .server import stream_widget
44from .store import Store, StoreItemType
45from .types import (
46 Annotation,
47 AssistantMessageContent,
48 AssistantMessageContentPartAdded,
49 AssistantMessageContentPartAnnotationAdded,
50 AssistantMessageContentPartDone,
51 AssistantMessageContentPartTextDelta,
52 AssistantMessageItem,
53 Attachment,
54 ClientToolCallItem,
55 DurationSummary,
56 EndOfTurnItem,
57 FileSource,
58 HiddenContextItem,
59 SDKHiddenContextItem,
60 Task,
61 TaskItem,
62 ThoughtTask,
63 ThreadItem,
64 ThreadItemAddedEvent,
65 ThreadItemDoneEvent,
66 ThreadItemRemovedEvent,
67 ThreadItemUpdatedEvent,
68 ThreadMetadata,
69 ThreadStreamEvent,
70 URLSource,
71 UserMessageItem,
72 UserMessageTagContent,
73 UserMessageTextContent,
74 WidgetItem,
75 Workflow,
76 WorkflowItem,
77 WorkflowSummary,
78 WorkflowTaskAdded,
79 WorkflowTaskUpdated,
80)
81from .widgets import Markdown, Text, WidgetRoot
82
83
class ClientToolCall(BaseModel):
    """
    Returned from tool methods to indicate a client-side tool call.
    """

    # Name of the client-side tool; copied onto the ClientToolCallItem that
    # stream_agent_response emits at the end of the run.
    name: str
    # Arguments for that tool; copied onto the same ClientToolCallItem.
    arguments: dict[str, Any]
91
92
93class _QueueCompleteSentinel: ...
94
95
TContext = TypeVar("TContext")


class AgentContext(BaseModel, Generic[TContext]):
    """Per-run state handed to agent tools, plus helpers that emit thread events.

    Events produced through the helpers below are placed on a private queue;
    ``stream_agent_response`` reads that queue and interleaves the events with
    the ones coming from the model run.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Thread the current run belongs to.
    thread: ThreadMetadata
    # Store used for id generation (validation is skipped for this field).
    store: Annotated[Store[TContext], SkipValidation]
    # Caller-supplied context forwarded to every store call.
    request_context: TContext
    previous_response_id: str | None = None
    # When set, stream_agent_response emits a ClientToolCallItem at the end.
    client_tool_call: ClientToolCall | None = None
    # Workflow currently being built/streamed, if any.
    workflow_item: WorkflowItem | None = None
    # NOTE(review): this default queue object is created once, at class
    # definition time; per-instance isolation relies on how pydantic copies
    # private-attribute defaults - confirm before changing.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """Generate an id of the given store item type.

        ``"thread"`` ids come from ``store.generate_thread_id``; every other
        type comes from ``store.generate_item_id`` using ``thread`` (or the
        context's own thread when ``thread`` is omitted).
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Queue every event that the module-level ``stream_widget`` yields for ``widget``."""
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """Finish the current workflow, if any, and queue its done event.

        An explicit ``summary`` wins; otherwise, when the workflow has no
        summary yet, a ``DurationSummary`` measured from the item's
        ``created_at`` is used.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """Make ``workflow`` the current workflow and queue its added event.

        For a non-reasoning workflow that has no tasks yet, the added event is
        deferred; ``add_workflow_task`` sends it with the first task.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """Replace the task at ``task_index`` and queue a task-updated event.

        Raises:
            ValueError: if there is no current workflow.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """Append ``task`` to the current workflow, creating a custom one if needed.

        The first task of a non-reasoning workflow is announced with the
        (possibly deferred) item-added event; any other task is announced with
        a task-added update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it sits in the last slot. list.index()
        # would return the first task that compares *equal*, which is the wrong
        # slot when an identical task was added earlier.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Queue ``event`` for delivery by ``stream_agent_response``."""
        await self._events.put(event)

    def _complete(self):
        """Put the completion sentinel on the event queue without awaiting."""
        self._events.put_nowait(_QueueCompleteSentinel())
209
210
def _convert_content(content: Content) -> AssistantMessageContent:
    """Translate a Responses output content part into AssistantMessageContent.

    ``output_text`` parts keep their text and every annotation that
    ``_convert_annotation`` can convert; any other part is read as a refusal
    and carries no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    converted = (_convert_annotation(raw) for raw in content.annotations)
    return AssistantMessageContent(
        text=content.text,
        # Annotations that could not be converted come back as None; drop them.
        annotations=[entry for entry in converted if entry is not None],
    )
226
227
# Built once and reused: constructing a TypeAdapter builds a validator for the
# annotation union, which is wasteful to repeat for every streamed annotation.
_RESPONSES_ANNOTATION_ADAPTER = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)


def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """Convert a Responses annotation into a ChatKit ``Annotation``.

    Args:
        raw_annotation: A ``ResponsesAnnotation`` instance, or a dict/untyped
            object that validates as one.

    Returns:
        The converted annotation, or ``None`` when the annotation type is not
        handled or a file citation has no filename.
    """
    # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
    # resulting in the annotation being a dict or untyped object instead of a ResponsesAnnotation
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.index,
        )

    if annotation.type == "url_citation":
        return Annotation(
            source=URLSource(
                url=annotation.url,
                title=annotation.title,
            ),
            index=annotation.end_index,
        )

    if annotation.type == "container_file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.end_index,
        )

    return None
265
266
267T1 = TypeVar("T1")
268T2 = TypeVar("T2")
269
270
271async def _merge_generators(
272 a: AsyncIterator[T1],
273 b: AsyncIterator[T2],
274) -> AsyncIterator[T1 | T2]:
275 pending: list[AsyncIterator[T1 | T2]] = [a, b]
276 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
277 asyncio.ensure_future(g.__anext__()): g for g in pending
278 }
279 while len(pending_tasks) > 0:
280 done, _ = await asyncio.wait(
281 pending_tasks.keys(), return_when="FIRST_COMPLETED"
282 )
283 stop = False
284 for d in done:
285 try:
286 result = d.result()
287 yield result
288 dg = pending_tasks[d]
289 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
290 except StopAsyncIteration:
291 stop = True
292 finally:
293 del pending_tasks[d]
294 if stop:
295 for task in pending_tasks.keys():
296 if not task.cancel():
297 try:
298 yield task.result()
299 except asyncio.CancelledError:
300 pass
301 except asyncio.InvalidStateError:
302 pass
303 break
304
305
306class _EventWrapper:
307 def __init__(self, event: ThreadStreamEvent):
308 self.event = event
309
310
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """Async iterator over an event queue that ends at the completion sentinel."""

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        # Set once the sentinel has been seen or the queue has been drained.
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next queued event wrapped in ``_EventWrapper``.

        Raises ``StopAsyncIteration`` when the sentinel is received, and on
        every call after the iterator has completed.
        """
        if not self.completed:
            entry = await self.queue.get()
            if not isinstance(entry, _QueueCompleteSentinel):
                return _EventWrapper(entry)
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Empty the underlying queue without awaiting and mark this iterator completed.

        This is intended for cleanup paths where we must guarantee no awaits
        occur. All queued items, including any completion sentinel, are
        discarded.
        """
        backlog = self.queue
        try:
            while True:
                backlog.get_nowait()
        except asyncio.QueueEmpty:
            pass
        self.completed = True
344
345
# Bookkeeping for the reasoning-summary thought that stream_agent_response is
# currently streaming into the workflow.
class StreamingThoughtTracker(BaseModel):
    # item_id of the reasoning-summary delta event that started the thought.
    item_id: str
    # summary_index of that same event.
    index: int
    # The ThoughtTask whose content grows as further deltas arrive.
    task: ThoughtTask
350
351
async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """Translate a streaming agent run into ChatKit thread events.

    Events from ``result.stream_events()`` are merged with the events that
    tools queued through ``context``. Responses message, text, annotation and
    reasoning-summary events are converted to thread item events; a workflow
    left open by the previous turn is continued, and an open workflow is
    closed when a visual item follows it.

    Args:
        context: Run context; its private event queue is consumed here and its
            ``workflow_item`` is updated as workflows start and end.
        result: The streaming run to read events from.

    Yields:
        Thread stream events in the order they should reach the client.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            re-raised after a ``ThreadItemRemovedEvent`` has been yielded for
            every item produced so far.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    # Ids of items emitted during this run; used to retract them on a tripwire.
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close ``item`` (collapsed, with a duration summary if it has none) and return its done event."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remembered so a client tool call emitted at the end of
                    # the run can reuse the model's ids.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # The thought was streamed already: finalize it in place.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        # The task was just appended, so it is in the last
                        # slot. list.index() would return the first task that
                        # compares equal, i.e. the wrong slot when an earlier
                        # thought has the same text.
                        update = WorkflowTaskAdded(
                            task=task,
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Retract everything this run has shown so far.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
640
641
# Widget type accepted and yielded by accumulate_text; limited to Markdown or Text.
TWidget = TypeVar("TWidget", bound=Markdown | Text)
643
644
async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """Stream copies of ``base_widget`` whose ``value`` grows with the output text.

    Yields ``base_widget`` itself first, then one copy per
    ``response.output_text.delta`` event carrying all text received so far,
    and finally a copy with the full text and ``streaming`` set to ``False``.
    Every other event is ignored.
    """
    collected = ""
    yield base_widget
    async for stream_event in events:
        if stream_event.type != "raw_response_event":
            continue
        if stream_event.data.type != "response.output_text.delta":
            continue
        collected += stream_event.data.delta
        yield base_widget.model_copy(update={"value": collected})
    yield base_widget.model_copy(update={"value": collected, "streaming": False})
657
658
659class ThreadItemConverter:
660 """
661 Converts thread items to Agent SDK input items.
662 Widgets, Tasks, and Workflows have default conversions but can be customized.
663 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
664 Other item types are converted automatically.
665 """
666
667 async def attachment_to_message_content(
668 self, attachment: Attachment
669 ) -> ResponseInputContentParam:
670 """
671 Convert an attachment in a user message into a message content part to send to the model.
672 Required when attachments are enabled.
673 """
674 raise NotImplementedError(
675 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
676 )
677
678 async def tag_to_message_content(
679 self, tag: UserMessageTagContent
680 ) -> ResponseInputContentParam:
681 """
682 Convert a tag in a user message into a message content part to send to the model as context.
683 Required when tags are used.
684 """
685 raise NotImplementedError(
686 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
687 )
688
689 async def hidden_context_to_input(
690 self, item: HiddenContextItem
691 ) -> TResponseInputItem | list[TResponseInputItem] | None:
692 """
693 Convert a HiddenContextItem into input item(s) to send to the model.
694 Required to override when HiddenContextItems with non-string content are used.
695 """
696 if not isinstance(item.content, str):
697 raise NotImplementedError(
698 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
699 )
700
701 text = (
702 "Hidden context for the agent (not shown to the user):\n"
703 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
704 )
705 return Message(
706 type="message",
707 content=[
708 ResponseInputTextParam(
709 type="input_text",
710 text=text,
711 )
712 ],
713 role="user",
714 )
715
716 async def sdk_hidden_context_to_input(
717 self, item: SDKHiddenContextItem
718 ) -> TResponseInputItem | list[TResponseInputItem] | None:
719 """
720 Convert a SDKHiddenContextItem into input item to send to the model.
721 This is used by the ChatKit Python SDK for storing additional context
722 for internal operations; you shouldn't need to override this.
723 """
724 text = (
725 "Hidden ChatKit SDK context for the agent (not shown to the user):\n"
726 f"<SDKHiddenContext>\n{item.content}\n</SDKHiddenContext>"
727 )
728 return Message(
729 type="message",
730 content=[
731 ResponseInputTextParam(
732 type="input_text",
733 text=text,
734 )
735 ],
736 role="user",
737 )
738
739 async def task_to_input(
740 self, item: TaskItem
741 ) -> TResponseInputItem | list[TResponseInputItem] | None:
742 """
743 Convert a TaskItem into input item(s) to send to the model.
744 """
745 if item.task.type != "custom" or (
746 not item.task.title and not item.task.content
747 ):
748 return None
749 title = f"{item.task.title}" if item.task.title else ""
750 content = f"{item.task.content}" if item.task.content else ""
751 task_text = f"{title}: {content}" if title and content else title or content
752 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
753 return Message(
754 type="message",
755 content=[
756 ResponseInputTextParam(
757 type="input_text",
758 text=text,
759 )
760 ],
761 role="user",
762 )
763
764 async def workflow_to_input(
765 self, item: WorkflowItem
766 ) -> TResponseInputItem | list[TResponseInputItem] | None:
767 """
768 Convert a TaskItem into input item(s) to send to the model.
769 Returns WorkflowItem.response_items by default.
770 """
771 messages = []
772 for task in item.workflow.tasks:
773 if task.type != "custom" or (not task.title and not task.content):
774 continue
775
776 title = f"{task.title}" if task.title else ""
777 content = f"{task.content}" if task.content else ""
778 task_text = f"{title}: {content}" if title and content else title or content
779 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
780 messages.append(
781 Message(
782 type="message",
783 content=[
784 ResponseInputTextParam(
785 type="input_text",
786 text=text,
787 )
788 ],
789 role="user",
790 )
791 )
792 return messages
793
794 async def widget_to_input(
795 self, item: WidgetItem
796 ) -> TResponseInputItem | list[TResponseInputItem] | None:
797 """
798 Convert a WidgetItem into input item(s) to send to the model.
799 By default, WidgetItems converted to a text description with a JSON representation of the widget.
800 """
801 return Message(
802 type="message",
803 content=[
804 ResponseInputTextParam(
805 type="input_text",
806 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
807 + item.widget.model_dump_json(
808 exclude_unset=True, exclude_none=True
809 ),
810 )
811 ],
812 role="user",
813 )
814
815 async def user_message_to_input(
816 self, item: UserMessageItem, is_last_message: bool = True
817 ) -> TResponseInputItem | list[TResponseInputItem] | None:
818 # Build the user text exactly as typed, rendering tags as @key
819 message_text_parts: list[str] = []
820 # Track tags separately to add system context
821 raw_tags: list[UserMessageTagContent] = []
822
823 for part in item.content:
824 if isinstance(part, UserMessageTextContent):
825 message_text_parts.append(part.text)
826 elif isinstance(part, UserMessageTagContent):
827 message_text_parts.append(f"@{part.text}")
828 raw_tags.append(part)
829 else:
830 assert_never(part)
831
832 user_text_item = Message(
833 role="user",
834 type="message",
835 content=[
836 ResponseInputTextParam(
837 type="input_text", text="".join(message_text_parts)
838 ),
839 *[
840 await self.attachment_to_message_content(a)
841 for a in item.attachments
842 ],
843 ],
844 )
845
846 # Build system items (prepend later): quoted text and @-mention context
847 context_items: list[TResponseInputItem] = []
848
849 if item.quoted_text and is_last_message:
850 context_items.append(
851 Message(
852 role="user",
853 type="message",
854 content=[
855 ResponseInputTextParam(
856 type="input_text",
857 text=f"The user is referring to this in particular: \n{item.quoted_text}",
858 )
859 ],
860 )
861 )
862
863 # Dedupe tags (preserve order) and resolve to message content
864 if raw_tags:
865 seen, uniq_tags = set(), []
866 for t in raw_tags:
867 if t.text not in seen:
868 seen.add(t.text)
869 uniq_tags.append(t)
870
871 tag_content: ResponseInputMessageContentListParam = [
872 # should return summarized text items
873 await self.tag_to_message_content(tag)
874 for tag in uniq_tags
875 ]
876
877 if tag_content:
878 context_items.append(
879 Message(
880 role="user",
881 type="message",
882 content=[
883 ResponseInputTextParam(
884 type="input_text",
885 text=cleandoc("""
886 # User-provided context for @-mentions
887 - When referencing resolved entities, use their canonical names **without** '@'.
888 - The '@' form appears only in user text and should not be echoed.
889 """).strip(),
890 ),
891 *tag_content,
892 ],
893 )
894 )
895
896 return [user_text_item, *context_items]
897
898 async def assistant_message_to_input(
899 self, item: AssistantMessageItem
900 ) -> TResponseInputItem | list[TResponseInputItem] | None:
901 return EasyInputMessageParam(
902 type="message",
903 content=[
904 # content param doesn't support the assistant message content types
905 cast(
906 ResponseInputContentParam,
907 ResponseOutputText(
908 type="output_text",
909 text=c.text,
910 annotations=[], # TODO: these should be sent back as well
911 ).model_dump(),
912 )
913 for c in item.content
914 ],
915 role="assistant",
916 )
917
918 async def client_tool_call_to_input(
919 self, item: ClientToolCallItem
920 ) -> TResponseInputItem | list[TResponseInputItem] | None:
921 if item.status == "pending":
922 # Filter out pending tool calls - they cannot be sent to the model
923 return None
924
925 return [
926 ResponseFunctionToolCallParam(
927 type="function_call",
928 call_id=item.call_id,
929 name=item.name,
930 arguments=json.dumps(item.arguments),
931 ),
932 FunctionCallOutput(
933 type="function_call_output",
934 call_id=item.call_id,
935 output=json.dumps(item.output),
936 ),
937 ]
938
939 async def end_of_turn_to_input(
940 self, item: EndOfTurnItem
941 ) -> TResponseInputItem | list[TResponseInputItem] | None:
942 # Only used for UI hints - you shouldn't need to override this
943 return None
944
945 async def _thread_item_to_input_item(
946 self,
947 item: ThreadItem,
948 is_last_message: bool = True,
949 ) -> list[TResponseInputItem]:
950 match item:
951 case UserMessageItem():
952 out = await self.user_message_to_input(item, is_last_message) or []
953 return out if isinstance(out, list) else [out]
954 case AssistantMessageItem():
955 out = await self.assistant_message_to_input(item) or []
956 return out if isinstance(out, list) else [out]
957 case ClientToolCallItem():
958 out = await self.client_tool_call_to_input(item) or []
959 return out if isinstance(out, list) else [out]
960 case EndOfTurnItem():
961 out = await self.end_of_turn_to_input(item) or []
962 return out if isinstance(out, list) else [out]
963 case WidgetItem():
964 out = await self.widget_to_input(item) or []
965 return out if isinstance(out, list) else [out]
966 case WorkflowItem():
967 out = await self.workflow_to_input(item) or []
968 return out if isinstance(out, list) else [out]
969 case TaskItem():
970 out = await self.task_to_input(item) or []
971 return out if isinstance(out, list) else [out]
972 case HiddenContextItem():
973 out = await self.hidden_context_to_input(item) or []
974 return out if isinstance(out, list) else [out]
975 case SDKHiddenContextItem():
976 out = await self.sdk_hidden_context_to_input(item) or []
977 return out if isinstance(out, list) else [out]
978 case _:
979 assert_never(item)
980
981 async def to_agent_input(
982 self,
983 thread_items: Sequence[ThreadItem] | ThreadItem,
984 ) -> list[TResponseInputItem]:
985 if isinstance(thread_items, Sequence):
986 # shallow copy in case caller mutates the list while we're iterating
987 thread_items = thread_items[:]
988 else:
989 thread_items = [thread_items]
990 output: list[TResponseInputItem] = []
991 for item in thread_items:
992 output.extend(
993 await self._thread_item_to_input_item(
994 item,
995 is_last_message=item is thread_items[-1],
996 )
997 )
998 return output
999
1000
# Module-wide converter with the default conversions, shared by simple_to_agent_input.
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """Convert thread items using the default ThreadItemConverter.

    Returns the awaitable produced by ``ThreadItemConverter.to_agent_input``;
    the caller is expected to await it.
    """
    converter = _DEFAULT_CONVERTER
    return converter.to_agent_input(thread_items)
1006