openai/chatkit-python

Public

Mirrored from https://github.com/openai/chatkit-python (Available)

CodeCommitsIssuesPull requestsActionsInsightsSecurity
b77d8564687675d4dca66278bb553a05390c709d

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

1006lines · modecode

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputMessageContentListParam,
29 ResponseInputTextParam,
30 ResponseOutputText,
31)
32from openai.types.responses.response_input_item_param import (
33 FunctionCallOutput,
34 Message,
35)
36from openai.types.responses.response_output_message import Content
37from openai.types.responses.response_output_text import (
38 Annotation as ResponsesAnnotation,
39)
40from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
41from typing_extensions import assert_never
42
43from .server import stream_widget
44from .store import Store, StoreItemType
45from .types import (
46 Annotation,
47 AssistantMessageContent,
48 AssistantMessageContentPartAdded,
49 AssistantMessageContentPartAnnotationAdded,
50 AssistantMessageContentPartDone,
51 AssistantMessageContentPartTextDelta,
52 AssistantMessageItem,
53 Attachment,
54 ClientToolCallItem,
55 DurationSummary,
56 EndOfTurnItem,
57 FileSource,
58 HiddenContextItem,
59 SDKHiddenContextItem,
60 Task,
61 TaskItem,
62 ThoughtTask,
63 ThreadItem,
64 ThreadItemAddedEvent,
65 ThreadItemDoneEvent,
66 ThreadItemRemovedEvent,
67 ThreadItemUpdatedEvent,
68 ThreadMetadata,
69 ThreadStreamEvent,
70 URLSource,
71 UserMessageItem,
72 UserMessageTagContent,
73 UserMessageTextContent,
74 WidgetItem,
75 Workflow,
76 WorkflowItem,
77 WorkflowSummary,
78 WorkflowTaskAdded,
79 WorkflowTaskUpdated,
80)
81from .widgets import Markdown, Text, WidgetRoot
82
83
class ClientToolCall(BaseModel):
    """
    Returned from tool methods to indicate a client-side tool call.

    `stream_agent_response` copies `name` and `arguments` into the
    `ClientToolCallItem` it emits when `AgentContext.client_tool_call` is set.
    """

    # Name of the tool the client should run.
    name: str
    # Arguments for the call, keyed by argument name.
    arguments: dict[str, Any]
91
92
class _QueueCompleteSentinel:
    """Marker placed on the `AgentContext` event queue to signal that no more events follow."""
94
95
# Type of the application-defined request context that AgentContext carries and
# hands to the Store when generating ids.
TContext = TypeVar("TContext")
97
98
class AgentContext(BaseModel, Generic[TContext]):
    """
    Per-run state shared between tool code and `stream_agent_response`.

    Carries the thread, the store and the request context, plus the workflow
    that is currently being built. The helper methods enqueue
    `ThreadStreamEvent`s on a private queue; `stream_agent_response` reads that
    queue and forwards the events to the caller.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    # Set by tool code to request a client-side tool call at the end of the run.
    client_tool_call: ClientToolCall | None = None
    # The workflow currently open on this thread, if any.
    workflow_item: WorkflowItem | None = None
    # NOTE(review): the default is a single Queue created at class-definition
    # time; this presumably relies on pydantic copying private-attribute
    # defaults per instance so contexts do not share a queue -- confirm.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """
        Generate a new id of the given type through the store.

        `"thread"` yields a thread id; any other type yields an item id scoped
        to `thread`, falling back to this context's own thread.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """
        Enqueue the events produced by `stream_widget` for `widget`.

        `widget` is either a single widget root or an async generator of
        successive widget roots; `copy_text` is passed through unchanged.
        """
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """
        Close the current workflow and emit its done event.

        Does nothing when no workflow is open. An explicit `summary` replaces
        any existing one; if the workflow has no summary at all, a
        `DurationSummary` measured from the item's `created_at` is used.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """
        Start tracking `workflow` as the current workflow item.

        The added event is sent immediately for reasoning workflows and for
        workflows that already have tasks; otherwise it is deferred until the
        first task is added via `add_workflow_task`.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """
        Replace the task at `task_index` and emit a task-updated event.

        Raises:
            ValueError: if no workflow is currently open.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """
        Append `task` to the current workflow, creating a custom workflow if none is open.

        The first task of a non-reasoning workflow sends the (possibly
        deferred) item-added event; every other task sends a task-added update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it occupies the last slot. list.index()
        # matches by equality and would return the slot of an earlier task that
        # compares equal to this one.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue `event` for delivery by `stream_agent_response`."""
        await self._events.put(event)

    def _complete(self):
        """Mark the event queue as finished by enqueueing the completion sentinel."""
        self._events.put_nowait(_QueueCompleteSentinel())
209
210
def _convert_content(content: Content) -> AssistantMessageContent:
    """
    Map a Responses output content part onto `AssistantMessageContent`.

    An `output_text` part keeps its text and every annotation that
    `_convert_annotation` can represent. Any other part is treated as a
    refusal: its refusal string becomes the text, with no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    converted = [_convert_annotation(raw) for raw in content.annotations]
    return AssistantMessageContent(
        text=content.text,
        # Annotations without an internal representation come back as None.
        annotations=[entry for entry in converted if entry is not None],
    )
226
227
def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """
    Convert a Responses annotation into a ChatKit `Annotation`.

    Returns None for annotation types with no ChatKit equivalent and for file
    citations that carry no filename.
    """
    # The OpenAI client sometimes parses the annotation delta event into the
    # wrong class, leaving a dict or untyped object instead of a
    # ResponsesAnnotation, so the value is always re-validated first.
    adapter = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)
    parsed = adapter.validate_python(raw_annotation)

    if parsed.type == "url_citation":
        source = URLSource(
            url=parsed.url,
            title=parsed.title,
        )
        return Annotation(source=source, index=parsed.end_index)

    if parsed.type == "file_citation":
        name = parsed.filename
        if name:
            return Annotation(
                source=FileSource(filename=name, title=name),
                index=parsed.index,
            )
        return None

    if parsed.type == "container_file_citation":
        name = parsed.filename
        if name:
            return Annotation(
                source=FileSource(filename=name, title=name),
                index=parsed.end_index,
            )
        return None

    return None
265
266
# Element types of the two iterators combined by _merge_generators.
T1 = TypeVar("T1")
T2 = TypeVar("T2")
269
270
async def _merge_generators(
    a: AsyncIterator[T1],
    b: AsyncIterator[T2],
) -> AsyncIterator[T1 | T2]:
    """
    Interleave two async iterators, yielding items as soon as either produces one.

    The merged stream ends as soon as either source is exhausted. At that
    point any item the other source has already produced is still yielded,
    and its in-flight `__anext__` call is cancelled otherwise.

    An exception raised by a source propagates to the consumer. Whenever this
    generator exits -- normally, through an exception, or because the consumer
    closed it early -- every outstanding `__anext__` task is cancelled so that
    none is left running in the background.
    """
    pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
        asyncio.ensure_future(source.__anext__()): source for source in (a, b)
    }
    try:
        while pending_tasks:
            done, _ = await asyncio.wait(
                pending_tasks.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            stop = False
            for finished in done:
                source = pending_tasks.pop(finished)
                try:
                    result = finished.result()
                except StopAsyncIteration:
                    stop = True
                    continue
                yield result
                # Ask the same source for its next item.
                pending_tasks[asyncio.ensure_future(source.__anext__())] = source
            if stop:
                for task in list(pending_tasks):
                    if task.cancel():
                        # Still pending: the cancellation request was accepted.
                        continue
                    # cancel() returned False, so the task already finished
                    # while this generator was suspended; do not lose its item.
                    try:
                        yield task.result()
                    except (
                        StopAsyncIteration,
                        asyncio.CancelledError,
                        asyncio.InvalidStateError,
                    ):
                        # The other source ended too (or was already cancelled):
                        # there is nothing to deliver.
                        pass
                break
    finally:
        # No awaits here: this may run while the generator is being closed.
        for task in pending_tasks:
            task.cancel()
304
305
class _EventWrapper:
    """
    Tags an event that came from the `AgentContext` queue.

    `stream_agent_response` uses the wrapper type to tell queue events apart
    from the agent run's own stream events.
    """

    def __init__(self, event: ThreadStreamEvent):
        # The wrapped event, exposed unchanged.
        self.event = event
309
310
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """
    Async iterator over an event queue that stops at the completion sentinel.

    Each queued event is returned wrapped in an `_EventWrapper`. Once the
    sentinel has been seen (or `drain_and_complete` was called) the iterator
    stays exhausted.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            queued = await self.queue.get()
            if not isinstance(queued, _QueueCompleteSentinel):
                return _EventWrapper(queued)
            # Sentinel reached: remember it so later calls do not block on get().
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Discard everything queued, without awaiting, and mark the iterator completed.

        Meant for cleanup paths where no await may occur. Every queued item is
        thrown away, the completion sentinel included.
        """
        try:
            while True:
                self.queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        self.completed = True
344
345
class StreamingThoughtTracker(BaseModel):
    """
    Tracks the reasoning summary that `stream_agent_response` is streaming as a thought task.
    """

    # `item_id` of the reasoning summary events being streamed.
    item_id: str
    # `summary_index` of those events.
    index: int
    # The workflow task whose content receives the streamed text.
    task: ThoughtTask
350
351
async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Translate a streaming agent run into ChatKit thread stream events.

    Events from `result.stream_events()` are merged with events that tool code
    enqueued on `context`. Responses events are converted into thread item
    added/updated/done events; reasoning summaries are surfaced as thought
    tasks of a reasoning workflow.

    If the last stored thread item is a workflow (or a client tool call that
    directly follows one), that workflow is continued instead of starting a
    new one. A workflow still open when the run ends is saved to the store.

    When `context.client_tool_call` is set, a final `ClientToolCallItem` done
    event is emitted.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            re-raised after a `ThreadItemRemovedEvent` has been yielded for
            every item produced so far and the context queue was discarded.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    # ids of every item emitted during this run, for guardrail cleanup
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        # Builds the done event for `item`, detaching it from the context if
        # it is the current workflow.
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remembered so the final client tool call item can reuse
                    # the model's call and item ids.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # The thought was streamed already: finalize it in place.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        update = WorkflowTaskAdded(
                            task=task,
                            # Just appended, so it is the last slot; list.index()
                            # would return the slot of an earlier task that
                            # compares equal to this one.
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Retract everything emitted during this run.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
640
641
# Widget type accepted and yielded by accumulate_text; bound to Markdown or Text.
TWidget = TypeVar("TWidget", bound=Markdown | Text)
643
644
async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """
    Yield successive copies of `base_widget` as streamed output text arrives.

    `base_widget` itself is yielded first. Each `response.output_text.delta`
    raw response event then yields a copy whose `value` is all text received
    so far; other events are ignored. Once `events` is exhausted, a final copy
    with the complete text and `streaming` set to False is yielded.
    """
    yield base_widget

    accumulated = ""
    async for event in events:
        if event.type != "raw_response_event":
            continue
        payload = event.data
        if payload.type != "response.output_text.delta":
            continue
        accumulated = accumulated + payload.delta
        yield base_widget.model_copy(update={"value": accumulated})

    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
657
658
659class ThreadItemConverter:
660 """
661 Converts thread items to Agent SDK input items.
662 Widgets, Tasks, and Workflows have default conversions but can be customized.
663 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
664 Other item types are converted automatically.
665 """
666
667 async def attachment_to_message_content(
668 self, attachment: Attachment
669 ) -> ResponseInputContentParam:
670 """
671 Convert an attachment in a user message into a message content part to send to the model.
672 Required when attachments are enabled.
673 """
674 raise NotImplementedError(
675 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
676 )
677
678 async def tag_to_message_content(
679 self, tag: UserMessageTagContent
680 ) -> ResponseInputContentParam:
681 """
682 Convert a tag in a user message into a message content part to send to the model as context.
683 Required when tags are used.
684 """
685 raise NotImplementedError(
686 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
687 )
688
689 async def hidden_context_to_input(
690 self, item: HiddenContextItem
691 ) -> TResponseInputItem | list[TResponseInputItem] | None:
692 """
693 Convert a HiddenContextItem into input item(s) to send to the model.
694 Required to override when HiddenContextItems with non-string content are used.
695 """
696 if not isinstance(item.content, str):
697 raise NotImplementedError(
698 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
699 )
700
701 text = (
702 "Hidden context for the agent (not shown to the user):\n"
703 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
704 )
705 return Message(
706 type="message",
707 content=[
708 ResponseInputTextParam(
709 type="input_text",
710 text=text,
711 )
712 ],
713 role="user",
714 )
715
716 async def sdk_hidden_context_to_input(
717 self, item: SDKHiddenContextItem
718 ) -> TResponseInputItem | list[TResponseInputItem] | None:
719 """
720 Convert a SDKHiddenContextItem into input item to send to the model.
721 This is used by the ChatKit Python SDK for storing additional context
722 for internal operations.
723 Override if you want to wrap the content in a different format.
724 """
725 text = (
726 "Hidden context for the agent (not shown to the user):\n"
727 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
728 )
729 return Message(
730 type="message",
731 content=[
732 ResponseInputTextParam(
733 type="input_text",
734 text=text,
735 )
736 ],
737 role="user",
738 )
739
740 async def task_to_input(
741 self, item: TaskItem
742 ) -> TResponseInputItem | list[TResponseInputItem] | None:
743 """
744 Convert a TaskItem into input item(s) to send to the model.
745 """
746 if item.task.type != "custom" or (
747 not item.task.title and not item.task.content
748 ):
749 return None
750 title = f"{item.task.title}" if item.task.title else ""
751 content = f"{item.task.content}" if item.task.content else ""
752 task_text = f"{title}: {content}" if title and content else title or content
753 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
754 return Message(
755 type="message",
756 content=[
757 ResponseInputTextParam(
758 type="input_text",
759 text=text,
760 )
761 ],
762 role="user",
763 )
764
765 async def workflow_to_input(
766 self, item: WorkflowItem
767 ) -> TResponseInputItem | list[TResponseInputItem] | None:
768 """
769 Convert a TaskItem into input item(s) to send to the model.
770 Returns WorkflowItem.response_items by default.
771 """
772 messages = []
773 for task in item.workflow.tasks:
774 if task.type != "custom" or (not task.title and not task.content):
775 continue
776
777 title = f"{task.title}" if task.title else ""
778 content = f"{task.content}" if task.content else ""
779 task_text = f"{title}: {content}" if title and content else title or content
780 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
781 messages.append(
782 Message(
783 type="message",
784 content=[
785 ResponseInputTextParam(
786 type="input_text",
787 text=text,
788 )
789 ],
790 role="user",
791 )
792 )
793 return messages
794
795 async def widget_to_input(
796 self, item: WidgetItem
797 ) -> TResponseInputItem | list[TResponseInputItem] | None:
798 """
799 Convert a WidgetItem into input item(s) to send to the model.
800 By default, WidgetItems converted to a text description with a JSON representation of the widget.
801 """
802 return Message(
803 type="message",
804 content=[
805 ResponseInputTextParam(
806 type="input_text",
807 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
808 + item.widget.model_dump_json(
809 exclude_unset=True, exclude_none=True
810 ),
811 )
812 ],
813 role="user",
814 )
815
816 async def user_message_to_input(
817 self, item: UserMessageItem, is_last_message: bool = True
818 ) -> TResponseInputItem | list[TResponseInputItem] | None:
819 # Build the user text exactly as typed, rendering tags as @key
820 message_text_parts: list[str] = []
821 # Track tags separately to add system context
822 raw_tags: list[UserMessageTagContent] = []
823
824 for part in item.content:
825 if isinstance(part, UserMessageTextContent):
826 message_text_parts.append(part.text)
827 elif isinstance(part, UserMessageTagContent):
828 message_text_parts.append(f"@{part.text}")
829 raw_tags.append(part)
830 else:
831 assert_never(part)
832
833 user_text_item = Message(
834 role="user",
835 type="message",
836 content=[
837 ResponseInputTextParam(
838 type="input_text", text="".join(message_text_parts)
839 ),
840 *[
841 await self.attachment_to_message_content(a)
842 for a in item.attachments
843 ],
844 ],
845 )
846
847 # Build system items (prepend later): quoted text and @-mention context
848 context_items: list[TResponseInputItem] = []
849
850 if item.quoted_text and is_last_message:
851 context_items.append(
852 Message(
853 role="user",
854 type="message",
855 content=[
856 ResponseInputTextParam(
857 type="input_text",
858 text=f"The user is referring to this in particular: \n{item.quoted_text}",
859 )
860 ],
861 )
862 )
863
864 # Dedupe tags (preserve order) and resolve to message content
865 if raw_tags:
866 seen, uniq_tags = set(), []
867 for t in raw_tags:
868 if t.text not in seen:
869 seen.add(t.text)
870 uniq_tags.append(t)
871
872 tag_content: ResponseInputMessageContentListParam = [
873 # should return summarized text items
874 await self.tag_to_message_content(tag)
875 for tag in uniq_tags
876 ]
877
878 if tag_content:
879 context_items.append(
880 Message(
881 role="user",
882 type="message",
883 content=[
884 ResponseInputTextParam(
885 type="input_text",
886 text=cleandoc("""
887 # User-provided context for @-mentions
888 - When referencing resolved entities, use their canonical names **without** '@'.
889 - The '@' form appears only in user text and should not be echoed.
890 """).strip(),
891 ),
892 *tag_content,
893 ],
894 )
895 )
896
897 return [user_text_item, *context_items]
898
899 async def assistant_message_to_input(
900 self, item: AssistantMessageItem
901 ) -> TResponseInputItem | list[TResponseInputItem] | None:
902 return EasyInputMessageParam(
903 type="message",
904 content=[
905 # content param doesn't support the assistant message content types
906 cast(
907 ResponseInputContentParam,
908 ResponseOutputText(
909 type="output_text",
910 text=c.text,
911 annotations=[], # TODO: these should be sent back as well
912 ).model_dump(),
913 )
914 for c in item.content
915 ],
916 role="assistant",
917 )
918
919 async def client_tool_call_to_input(
920 self, item: ClientToolCallItem
921 ) -> TResponseInputItem | list[TResponseInputItem] | None:
922 if item.status == "pending":
923 # Filter out pending tool calls - they cannot be sent to the model
924 return None
925
926 return [
927 ResponseFunctionToolCallParam(
928 type="function_call",
929 call_id=item.call_id,
930 name=item.name,
931 arguments=json.dumps(item.arguments),
932 ),
933 FunctionCallOutput(
934 type="function_call_output",
935 call_id=item.call_id,
936 output=json.dumps(item.output),
937 ),
938 ]
939
940 async def end_of_turn_to_input(
941 self, item: EndOfTurnItem
942 ) -> TResponseInputItem | list[TResponseInputItem] | None:
943 # Only used for UI hints - you shouldn't need to override this
944 return None
945
946 async def _thread_item_to_input_item(
947 self,
948 item: ThreadItem,
949 is_last_message: bool = True,
950 ) -> list[TResponseInputItem]:
951 match item:
952 case UserMessageItem():
953 out = await self.user_message_to_input(item, is_last_message) or []
954 return out if isinstance(out, list) else [out]
955 case AssistantMessageItem():
956 out = await self.assistant_message_to_input(item) or []
957 return out if isinstance(out, list) else [out]
958 case ClientToolCallItem():
959 out = await self.client_tool_call_to_input(item) or []
960 return out if isinstance(out, list) else [out]
961 case EndOfTurnItem():
962 out = await self.end_of_turn_to_input(item) or []
963 return out if isinstance(out, list) else [out]
964 case WidgetItem():
965 out = await self.widget_to_input(item) or []
966 return out if isinstance(out, list) else [out]
967 case WorkflowItem():
968 out = await self.workflow_to_input(item) or []
969 return out if isinstance(out, list) else [out]
970 case TaskItem():
971 out = await self.task_to_input(item) or []
972 return out if isinstance(out, list) else [out]
973 case HiddenContextItem():
974 out = await self.hidden_context_to_input(item) or []
975 return out if isinstance(out, list) else [out]
976 case SDKHiddenContextItem():
977 out = await self.sdk_hidden_context_to_input(item) or []
978 return out if isinstance(out, list) else [out]
979 case _:
980 assert_never(item)
981
982 async def to_agent_input(
983 self,
984 thread_items: Sequence[ThreadItem] | ThreadItem,
985 ) -> list[TResponseInputItem]:
986 if isinstance(thread_items, Sequence):
987 # shallow copy in case caller mutates the list while we're iterating
988 thread_items = thread_items[:]
989 else:
990 thread_items = [thread_items]
991 output: list[TResponseInputItem] = []
992 for item in thread_items:
993 output.extend(
994 await self._thread_item_to_input_item(
995 item,
996 is_last_message=item is thread_items[-1],
997 )
998 )
999 return output
1000
1001
# Shared converter with the default conversions, used by simple_to_agent_input.
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """
    Convert thread items to agent input using the default `ThreadItemConverter`.

    This function is not declared `async`: it returns the awaitable produced by
    `ThreadItemConverter.to_agent_input`, so callers must await the result.
    """
    return _DEFAULT_CONVERTER.to_agent_input(thread_items)
1007