openai/chatkit-python

Public

Mirrored from https://github.com/openai/chatkit-python — Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
367da6d8a75321471fbc01f324dcd8c06acd8add

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

977lines · modecode

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputMessageContentListParam,
29 ResponseInputTextParam,
30 ResponseOutputText,
31)
32from openai.types.responses.response_input_item_param import (
33 FunctionCallOutput,
34 Message,
35)
36from openai.types.responses.response_output_message import Content
37from openai.types.responses.response_output_text import (
38 Annotation as ResponsesAnnotation,
39)
40from openai.types.responses.response_output_text import (
41 AnnotationContainerFileCitation,
42 AnnotationFileCitation,
43 AnnotationFilePath,
44 AnnotationURLCitation,
45)
46from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
47from typing_extensions import assert_never
48
49from .server import stream_widget
50from .store import Store, StoreItemType
51from .types import (
52 Annotation,
53 AssistantMessageContent,
54 AssistantMessageContentPartAdded,
55 AssistantMessageContentPartAnnotationAdded,
56 AssistantMessageContentPartDone,
57 AssistantMessageContentPartTextDelta,
58 AssistantMessageItem,
59 Attachment,
60 ClientToolCallItem,
61 DurationSummary,
62 EndOfTurnItem,
63 FileSource,
64 HiddenContextItem,
65 Task,
66 TaskItem,
67 ThoughtTask,
68 ThreadItem,
69 ThreadItemAddedEvent,
70 ThreadItemDoneEvent,
71 ThreadItemRemovedEvent,
72 ThreadItemUpdated,
73 ThreadMetadata,
74 ThreadStreamEvent,
75 URLSource,
76 UserMessageItem,
77 UserMessageTagContent,
78 UserMessageTextContent,
79 WidgetItem,
80 Workflow,
81 WorkflowItem,
82 WorkflowSummary,
83 WorkflowTaskAdded,
84 WorkflowTaskUpdated,
85)
86from .widgets import Markdown, Text, WidgetRoot
87
88
class ClientToolCall(BaseModel):
    """
    Returned from tool methods to indicate a client-side tool call.
    """

    # Name of the client-side tool to call.
    name: str
    # Arguments for the call, keyed by argument name.
    arguments: dict[str, Any]
96
97
class _QueueCompleteSentinel:
    """Marker placed on the event queue to signal that no more events follow."""
99
100
# Type of the caller-supplied request context carried by AgentContext.
TContext = TypeVar("TContext")
102
103
class AgentContext(BaseModel, Generic[TContext]):
    """
    State shared between an agent run and the helpers that emit thread events.

    Helper methods push events onto the private ``_events`` queue;
    ``stream_agent_response`` reads that queue and forwards the events to the
    client alongside the events produced by the agent run itself.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Thread the current run belongs to.
    thread: ThreadMetadata
    # Backing store; used here for ID generation.
    store: Annotated[Store[TContext], SkipValidation]
    # Caller-supplied context forwarded to every store call.
    request_context: TContext
    previous_response_id: str | None = None
    # Set to request that a client tool call item is emitted at the end of the run.
    client_tool_call: ClientToolCall | None = None
    # The workflow currently in progress, if any.
    workflow_item: WorkflowItem | None = None
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """
        Generate a new ID of the given kind through the store.

        ``"thread"`` produces a thread ID; any other type produces an item ID
        scoped to ``thread`` (defaulting to the context's own thread).
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Queue every event produced by streaming ``widget`` into the thread."""
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """
        Finish the active workflow and queue its done event.

        Does nothing when no workflow is active. When neither ``summary`` nor an
        existing summary is available, a duration summary is computed from the
        workflow item's creation time.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """
        Make ``workflow`` the active workflow.

        The added event is queued immediately for reasoning workflows and for
        workflows that already have tasks; otherwise it is deferred until the
        first task is added.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """
        Replace the task at ``task_index`` and queue the matching update event.

        Raises:
            ValueError: if there is no active workflow.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdated(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """
        Append ``task`` to the active workflow, creating a custom one if needed.

        The first task of a non-reasoning workflow queues the (possibly
        deferred) added event for the whole workflow item; every other task
        queues a task-added update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so its position is the last index.
        # list.index() compares by equality and would report an earlier task
        # that happens to be equal to this one.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdated(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Queue ``event`` for delivery to the client."""
        await self._events.put(event)

    def _complete(self) -> None:
        """Mark the event queue as finished by enqueueing the completion sentinel."""
        self._events.put_nowait(_QueueCompleteSentinel())
214
215
def _convert_content(content: Content) -> AssistantMessageContent:
    """
    Convert a Responses output content part into assistant message content.

    Output text keeps the annotations that can be converted; any other part is
    treated as a refusal and carries its refusal text with no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    converted = (_convert_annotation(raw) for raw in content.annotations)
    return AssistantMessageContent(
        text=content.text,
        annotations=[annotation for annotation in converted if annotation is not None],
    )
231
232
233def _convert_annotation(raw_annotation: object) -> Annotation | None:
234 # There is a bug in the OpenAPI client that sometimes parses the annotation delta event into the wrong class
235 # resulting into annotation being a dict.
236 match raw_annotation:
237 case (
238 AnnotationFileCitation()
239 | AnnotationURLCitation()
240 | AnnotationContainerFileCitation()
241 | AnnotationFilePath()
242 ):
243 annotation = raw_annotation
244 case _:
245 annotation = TypeAdapter[ResponsesAnnotation](
246 ResponsesAnnotation
247 ).validate_python(raw_annotation)
248
249 if annotation.type == "file_citation":
250 filename = annotation.filename
251 if not filename:
252 return None
253
254 return Annotation(
255 source=FileSource(filename=filename, title=filename),
256 index=annotation.index,
257 )
258
259 if annotation.type == "url_citation":
260 return Annotation(
261 source=URLSource(
262 url=annotation.url,
263 title=annotation.title,
264 ),
265 index=annotation.end_index,
266 )
267
268 if annotation.type == "container_file_citation":
269 filename = annotation.filename
270 if not filename:
271 return None
272
273 return Annotation(
274 source=FileSource(filename=filename, title=filename),
275 index=annotation.end_index,
276 )
277
278 return None
279
280
281T1 = TypeVar("T1")
282T2 = TypeVar("T2")
283
284
285async def _merge_generators(
286 a: AsyncIterator[T1],
287 b: AsyncIterator[T2],
288) -> AsyncIterator[T1 | T2]:
289 pending: list[AsyncIterator[T1 | T2]] = [a, b]
290 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
291 asyncio.ensure_future(g.__anext__()): g for g in pending
292 }
293 while len(pending_tasks) > 0:
294 done, _ = await asyncio.wait(
295 pending_tasks.keys(), return_when="FIRST_COMPLETED"
296 )
297 stop = False
298 for d in done:
299 try:
300 result = d.result()
301 yield result
302 dg = pending_tasks[d]
303 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
304 except StopAsyncIteration:
305 stop = True
306 finally:
307 del pending_tasks[d]
308 if stop:
309 for task in pending_tasks.keys():
310 if not task.cancel():
311 try:
312 yield task.result()
313 except asyncio.CancelledError:
314 pass
315 except asyncio.InvalidStateError:
316 pass
317 break
318
319
class _EventWrapper:
    """Wraps an event taken from the context queue so it can be told apart from agent run events."""

    def __init__(self, event: ThreadStreamEvent):
        self.event = event
323
324
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """Async iterator over a queue that ends when the completion sentinel is read."""

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            next_item = await self.queue.get()
            if not isinstance(next_item, _QueueCompleteSentinel):
                return _EventWrapper(next_item)
            # Sentinel seen: remember it so later calls stop without waiting.
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Discard everything still queued, without awaiting, and mark the iterator completed.

        Meant for cleanup paths that must not await. Every queued item is
        dropped, including a completion sentinel if one is present.
        """
        try:
            while True:
                self.queue.get_nowait()
        except asyncio.QueueEmpty:
            pass
        self.completed = True
358
359
class StreamingThoughtTracker(BaseModel):
    # Tracks the reasoning summary that is currently being streamed as a thought.

    # ID of the Responses item the summary text belongs to.
    item_id: str
    # Summary index of the text within that item.
    index: int
    # Thought task that receives the streamed text.
    task: ThoughtTask
364
365
async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Convert a streaming agent run into thread stream events.

    Events from ``result.stream_events()`` are merged with the events that
    helpers queued on ``context``. Responses events are translated into
    assistant message and workflow updates; queued events are forwarded as-is.

    If an input or output guardrail tripwire is raised, a removal event is
    yielded for every item produced so far, queued events are discarded, and
    the exception is re-raised.

    After the run, any workflow still active is saved to the store so it can
    be continued on the next turn, and a client tool call item is emitted when
    ``context.client_tool_call`` is set.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close ``item`` and return its done event, clearing it from the context if active."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remember the latest function call; its IDs are reused for
                    # the client tool call item emitted after the run.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdated(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    # A message ends whatever workflow is in progress.
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        # The thought was already streamed; replace its text with
                        # the final version.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        workflow_tasks = ctx.workflow_item.workflow.tasks
                        workflow_tasks.append(task)
                        # The task was just appended, so it sits at the last
                        # index. list.index() compares by equality and would
                        # point at an earlier thought with the same text.
                        update = WorkflowTaskAdded(
                            task=task,
                            task_index=len(workflow_tasks) - 1,
                        )
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        # Take back everything this run has shown so far.
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
654
655
# Widget types whose text value can be streamed by accumulate_text.
TWidget = TypeVar("TWidget", bound=Markdown | Text)
657
658
async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """
    Yield copies of ``base_widget`` whose ``value`` grows with streamed text.

    The unmodified widget is yielded first, then one copy per output text
    delta holding all text so far, and finally a copy with ``streaming`` set
    to False.
    """
    accumulated = ""
    yield base_widget
    async for stream_event in events:
        if stream_event.type != "raw_response_event":
            continue
        payload = stream_event.data
        if payload.type != "response.output_text.delta":
            continue
        accumulated += payload.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
671
672
673class ThreadItemConverter:
674 """
675 Converts thread items to Agent SDK input items.
676 Widgets, Tasks, and Workflows have default conversions but can be customized.
677 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
678 Other item types are converted automatically.
679 """
680
681 async def attachment_to_message_content(
682 self, attachment: Attachment
683 ) -> ResponseInputContentParam:
684 """
685 Convert an attachment in a user message into a message content part to send to the model.
686 Required when attachments are enabled.
687 """
688 raise NotImplementedError(
689 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
690 )
691
692 async def tag_to_message_content(
693 self, tag: UserMessageTagContent
694 ) -> ResponseInputContentParam:
695 """
696 Convert a tag in a user message into a message content part to send to the model as context.
697 Required when tags are used.
698 """
699 raise NotImplementedError(
700 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
701 )
702
703 async def hidden_context_to_input(
704 self, item: HiddenContextItem
705 ) -> TResponseInputItem | list[TResponseInputItem] | None:
706 """
707 Convert a HiddenContextItem into input item(s) to send to the model.
708 Required when HiddenContextItem are used.
709 """
710 raise NotImplementedError(
711 "HiddenContextItem were present in a user message but Converter.hidden_context_to_input was not implemented"
712 )
713
714 async def task_to_input(
715 self, item: TaskItem
716 ) -> TResponseInputItem | list[TResponseInputItem] | None:
717 """
718 Convert a TaskItem into input item(s) to send to the model.
719 """
720 if item.task.type != "custom" or (
721 not item.task.title and not item.task.content
722 ):
723 return None
724 title = f"{item.task.title}" if item.task.title else ""
725 content = f"{item.task.content}" if item.task.content else ""
726 task_text = f"{title}: {content}" if title and content else title or content
727 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
728 return Message(
729 type="message",
730 content=[
731 ResponseInputTextParam(
732 type="input_text",
733 text=text,
734 )
735 ],
736 role="user",
737 )
738
739 async def workflow_to_input(
740 self, item: WorkflowItem
741 ) -> TResponseInputItem | list[TResponseInputItem] | None:
742 """
743 Convert a TaskItem into input item(s) to send to the model.
744 Returns WorkflowItem.response_items by default.
745 """
746 messages = []
747 for task in item.workflow.tasks:
748 if task.type != "custom" or (not task.title and not task.content):
749 continue
750
751 title = f"{task.title}" if task.title else ""
752 content = f"{task.content}" if task.content else ""
753 task_text = f"{title}: {content}" if title and content else title or content
754 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
755 messages.append(
756 Message(
757 type="message",
758 content=[
759 ResponseInputTextParam(
760 type="input_text",
761 text=text,
762 )
763 ],
764 role="user",
765 )
766 )
767 return messages
768
769 async def widget_to_input(
770 self, item: WidgetItem
771 ) -> TResponseInputItem | list[TResponseInputItem] | None:
772 """
773 Convert a WidgetItem into input item(s) to send to the model.
774 By default, WidgetItems converted to a text description with a JSON representation of the widget.
775 """
776 return Message(
777 type="message",
778 content=[
779 ResponseInputTextParam(
780 type="input_text",
781 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
782 + item.widget.model_dump_json(
783 exclude_unset=True, exclude_none=True
784 ),
785 )
786 ],
787 role="user",
788 )
789
790 async def user_message_to_input(
791 self, item: UserMessageItem, is_last_message: bool = True
792 ) -> TResponseInputItem | list[TResponseInputItem] | None:
793 # Build the user text exactly as typed, rendering tags as @key
794 message_text_parts: list[str] = []
795 # Track tags separately to add system context
796 raw_tags: list[UserMessageTagContent] = []
797
798 for part in item.content:
799 if isinstance(part, UserMessageTextContent):
800 message_text_parts.append(part.text)
801 elif isinstance(part, UserMessageTagContent):
802 message_text_parts.append(f"@{part.text}")
803 raw_tags.append(part)
804 else:
805 assert_never(part)
806
807 user_text_item = Message(
808 role="user",
809 type="message",
810 content=[
811 ResponseInputTextParam(
812 type="input_text", text="".join(message_text_parts)
813 ),
814 *[
815 await self.attachment_to_message_content(a)
816 for a in item.attachments
817 ],
818 ],
819 )
820
821 # Build system items (prepend later): quoted text and @-mention context
822 context_items: list[TResponseInputItem] = []
823
824 if item.quoted_text and is_last_message:
825 context_items.append(
826 Message(
827 role="user",
828 type="message",
829 content=[
830 ResponseInputTextParam(
831 type="input_text",
832 text=f"The user is referring to this in particular: \n{item.quoted_text}",
833 )
834 ],
835 )
836 )
837
838 # Dedupe tags (preserve order) and resolve to message content
839 if raw_tags:
840 seen, uniq_tags = set(), []
841 for t in raw_tags:
842 if t.text not in seen:
843 seen.add(t.text)
844 uniq_tags.append(t)
845
846 tag_content: ResponseInputMessageContentListParam = [
847 # should return summarized text items
848 await self.tag_to_message_content(tag)
849 for tag in uniq_tags
850 ]
851
852 if tag_content:
853 context_items.append(
854 Message(
855 role="user",
856 type="message",
857 content=[
858 ResponseInputTextParam(
859 type="input_text",
860 text=cleandoc("""
861 # User-provided context for @-mentions
862 - When referencing resolved entities, use their canonical names **without** '@'.
863 - The '@' form appears only in user text and should not be echoed.
864 """).strip(),
865 ),
866 *tag_content,
867 ],
868 )
869 )
870
871 return [user_text_item, *context_items]
872
873 async def assistant_message_to_input(
874 self, item: AssistantMessageItem
875 ) -> TResponseInputItem | list[TResponseInputItem] | None:
876 return EasyInputMessageParam(
877 type="message",
878 content=[
879 # content param doesn't support the assistant message content types
880 cast(
881 ResponseInputContentParam,
882 ResponseOutputText(
883 type="output_text",
884 text=c.text,
885 annotations=[], # TODO: these should be sent back as well
886 ).model_dump(),
887 )
888 for c in item.content
889 ],
890 role="assistant",
891 )
892
893 async def client_tool_call_to_input(
894 self, item: ClientToolCallItem
895 ) -> TResponseInputItem | list[TResponseInputItem] | None:
896 if item.status == "pending":
897 # Filter out pending tool calls - they cannot be sent to the model
898 return None
899
900 return [
901 ResponseFunctionToolCallParam(
902 type="function_call",
903 call_id=item.call_id,
904 name=item.name,
905 arguments=json.dumps(item.arguments),
906 ),
907 FunctionCallOutput(
908 type="function_call_output",
909 call_id=item.call_id,
910 output=json.dumps(item.output),
911 ),
912 ]
913
914 async def end_of_turn_to_input(
915 self, item: EndOfTurnItem
916 ) -> TResponseInputItem | list[TResponseInputItem] | None:
917 # Only used for UI hints - you shouldn't need to override this
918 return None
919
920 async def _thread_item_to_input_item(
921 self,
922 item: ThreadItem,
923 is_last_message: bool = True,
924 ) -> list[TResponseInputItem]:
925 match item:
926 case UserMessageItem():
927 out = await self.user_message_to_input(item, is_last_message) or []
928 return out if isinstance(out, list) else [out]
929 case AssistantMessageItem():
930 out = await self.assistant_message_to_input(item) or []
931 return out if isinstance(out, list) else [out]
932 case ClientToolCallItem():
933 out = await self.client_tool_call_to_input(item) or []
934 return out if isinstance(out, list) else [out]
935 case EndOfTurnItem():
936 out = await self.end_of_turn_to_input(item) or []
937 return out if isinstance(out, list) else [out]
938 case WidgetItem():
939 out = await self.widget_to_input(item) or []
940 return out if isinstance(out, list) else [out]
941 case WorkflowItem():
942 out = await self.workflow_to_input(item) or []
943 return out if isinstance(out, list) else [out]
944 case TaskItem():
945 out = await self.task_to_input(item) or []
946 return out if isinstance(out, list) else [out]
947 case HiddenContextItem():
948 out = await self.hidden_context_to_input(item) or []
949 return out if isinstance(out, list) else [out]
950 case _:
951 assert_never(item)
952
953 async def to_agent_input(
954 self,
955 thread_items: Sequence[ThreadItem] | ThreadItem,
956 ) -> list[TResponseInputItem]:
957 if isinstance(thread_items, Sequence):
958 # shallow copy in case caller mutates the list while we're iterating
959 thread_items = thread_items[:]
960 else:
961 thread_items = [thread_items]
962 output: list[TResponseInputItem] = []
963 for item in thread_items:
964 output.extend(
965 await self._thread_item_to_input_item(
966 item,
967 is_last_message=item is thread_items[-1],
968 )
969 )
970 return output
971
972
# Shared converter with the default conversions, used by simple_to_agent_input.
_DEFAULT_CONVERTER = ThreadItemConverter()
974
975
def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """
    Convert thread items to agent input using the default ThreadItemConverter.

    Returns the awaitable produced by ``ThreadItemConverter.to_agent_input``;
    the caller must await it.
    """
    conversion = _DEFAULT_CONVERTER.to_agent_input(thread_items)
    return conversion
978