openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
3f7a2fa23d77277ce5213176a0267a72b7e215ec

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

978 lines · mode: code

import asyncio
import json
from collections import defaultdict
from collections.abc import AsyncIterator
from datetime import datetime
from inspect import cleandoc
from typing import (
    Annotated,
    Any,
    AsyncGenerator,
    Generic,
    Sequence,
    TypeVar,
    cast,
)

from agents import (
    InputGuardrailTripwireTriggered,
    OutputGuardrailTripwireTriggered,
    RunResultStreaming,
    StreamEvent,
    TResponseInputItem,
)
from openai.types.responses import (
    EasyInputMessageParam,
    ResponseFunctionToolCallParam,
    ResponseInputContentParam,
    ResponseInputMessageContentListParam,
    ResponseInputTextParam,
    ResponseOutputText,
)
from openai.types.responses.response_input_item_param import (
    FunctionCallOutput,
    Message,
)
from openai.types.responses.response_output_message import Content
from openai.types.responses.response_output_text import (
    Annotation as ResponsesAnnotation,
)
from pydantic import (
    BaseModel,
    ConfigDict,
    PrivateAttr,
    SkipValidation,
    TypeAdapter,
    ValidationError,
)
from typing_extensions import assert_never

from .server import stream_widget
from .store import Store, StoreItemType
from .types import (
    Annotation,
    AssistantMessageContent,
    AssistantMessageContentPartAdded,
    AssistantMessageContentPartAnnotationAdded,
    AssistantMessageContentPartDone,
    AssistantMessageContentPartTextDelta,
    AssistantMessageItem,
    Attachment,
    ClientToolCallItem,
    DurationSummary,
    EndOfTurnItem,
    FileSource,
    HiddenContextItem,
    Task,
    TaskItem,
    ThoughtTask,
    ThreadItem,
    ThreadItemAddedEvent,
    ThreadItemDoneEvent,
    ThreadItemRemovedEvent,
    ThreadItemUpdatedEvent,
    ThreadMetadata,
    ThreadStreamEvent,
    URLSource,
    UserMessageItem,
    UserMessageTagContent,
    UserMessageTextContent,
    WidgetItem,
    Workflow,
    WorkflowItem,
    WorkflowSummary,
    WorkflowTaskAdded,
    WorkflowTaskUpdated,
)
from .widgets import Markdown, Text, WidgetRoot
81
82
class ClientToolCall(BaseModel):
    """
    Value returned from a tool method to request a tool call that runs on the
    client rather than on the server.
    """

    # Name of the client-side tool to call.
    name: str
    # Arguments for the client-side tool call.
    arguments: dict[str, Any]
90
91
class _QueueCompleteSentinel:
    """Marker put on an AgentContext event queue to end iteration over it."""
93
94
TContext = TypeVar("TContext")


class AgentContext(BaseModel, Generic[TContext]):
    """
    Per-run context shared with agent code.

    Besides the thread, store and request context, it exposes helpers that
    emit additional ThreadStreamEvents (widgets, workflow/task updates, or any
    event via ``stream``). Those events are queued here and interleaved by
    ``stream_agent_response`` with the events produced by the agent run.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    # The store is used as-is; pydantic does not validate it.
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    # When set, stream_agent_response emits a ClientToolCallItem after the run.
    client_tool_call: ClientToolCall | None = None
    # Workflow currently being built, if any.
    workflow_item: WorkflowItem | None = None
    # Events emitted through this context. A factory gives each instance its
    # own queue instead of relying on pydantic deep-copying a queue object that
    # was created when the module was imported.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = PrivateAttr(
        default_factory=asyncio.Queue
    )

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """
        Generate a store id for a new item of ``type``.

        Thread ids come from ``Store.generate_thread_id``; item ids are
        generated for ``thread`` (defaulting to this context's thread).
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Stream a static or incrementally generated widget into the thread."""
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """
        Finish the current workflow and stream its done event.

        ``summary`` replaces any existing summary; if neither is present, a
        DurationSummary with the elapsed whole seconds since the workflow was
        created is used. Does nothing when no workflow is active.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """
        Make ``workflow`` the current workflow.

        For non-reasoning workflows without tasks the added event is deferred
        until the first task is added through ``add_workflow_task``.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """
        Replace the task at ``task_index`` in the current workflow and stream
        a WorkflowTaskUpdated event.

        Raises:
            ValueError: if no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdatedEvent(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """
        Append ``task`` to the current workflow and stream the change.

        A "custom" workflow is created if none is active. The first task of a
        non-reasoning workflow streams the whole workflow item (its added event
        was deferred until now); later tasks stream a WorkflowTaskAdded update.
        """
        if self.workflow_item is None:
            self.workflow_item = WorkflowItem(
                id=self.generate_id("workflow"),
                created_at=datetime.now(),
                workflow=Workflow(type="custom", tasks=[]),
                thread_id=self.thread.id,
            )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it is the last one. tasks.index(task)
        # compares by value and would report an earlier, equal task's index.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and task_index == 0:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdatedEvent(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Queue ``event`` to be emitted alongside the agent's stream events."""
        await self._events.put(event)

    def _complete(self):
        """Enqueue the sentinel that ends iteration over this context's events."""
        self._events.put_nowait(_QueueCompleteSentinel())
208
209
def _convert_content(content: Content) -> AssistantMessageContent:
    """
    Convert a Responses output content part into ChatKit assistant content.

    ``output_text`` parts keep their text plus every annotation that
    ``_convert_annotation`` can represent; any other part is treated as a
    refusal and its refusal text is used with no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(text=content.refusal, annotations=[])

    converted = (_convert_annotation(raw) for raw in content.annotations)
    return AssistantMessageContent(
        text=content.text,
        annotations=[annotation for annotation in converted if annotation is not None],
    )
225
226
# Building a TypeAdapter compiles a validation schema; do it once at import
# time instead of for every streamed annotation.
_RESPONSES_ANNOTATION_ADAPTER = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)


def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """
    Convert a Responses API annotation into a ChatKit ``Annotation``.

    Returns None when the annotation cannot be represented: file or container
    file citations without a filename, annotation types with no ChatKit
    equivalent, and payloads that fail to validate as a Responses annotation.
    """
    # The OpenAI client sometimes parses the annotation delta event into the
    # wrong class, leaving a dict or untyped object instead of a
    # ResponsesAnnotation, so re-validate before reading any fields.
    try:
        annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)
    except ValidationError:
        # Drop annotations we cannot parse rather than aborting the whole
        # response stream; callers already skip None results.
        return None

    if annotation.type == "file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.index,
        )

    if annotation.type == "url_citation":
        return Annotation(
            source=URLSource(
                url=annotation.url,
                title=annotation.title,
            ),
            index=annotation.end_index,
        )

    if annotation.type == "container_file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.end_index,
        )

    # Remaining annotation types have no ChatKit representation.
    return None
264
265
266T1 = TypeVar("T1")
267T2 = TypeVar("T2")
268
269
270async def _merge_generators(
271 a: AsyncIterator[T1],
272 b: AsyncIterator[T2],
273) -> AsyncIterator[T1 | T2]:
274 pending: list[AsyncIterator[T1 | T2]] = [a, b]
275 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
276 asyncio.ensure_future(g.__anext__()): g for g in pending
277 }
278 while len(pending_tasks) > 0:
279 done, _ = await asyncio.wait(
280 pending_tasks.keys(), return_when="FIRST_COMPLETED"
281 )
282 stop = False
283 for d in done:
284 try:
285 result = d.result()
286 yield result
287 dg = pending_tasks[d]
288 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
289 except StopAsyncIteration:
290 stop = True
291 finally:
292 del pending_tasks[d]
293 if stop:
294 for task in pending_tasks.keys():
295 if not task.cancel():
296 try:
297 yield task.result()
298 except asyncio.CancelledError:
299 pass
300 except asyncio.InvalidStateError:
301 pass
302 break
303
304
class _EventWrapper:
    """
    Box around a ThreadStreamEvent taken from an AgentContext event queue.

    Lets the merged stream distinguish context-emitted events from events
    produced by the agent run.
    """

    def __init__(self, event: ThreadStreamEvent):
        # The wrapped context event.
        self.event = event
308
309
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """
    Async iterator over an AgentContext event queue.

    Each queued event is yielded inside an ``_EventWrapper``. Iteration ends,
    permanently, once a ``_QueueCompleteSentinel`` is dequeued or the queue is
    drained with ``drain_and_complete``.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        # Becomes True after the sentinel is seen or the queue is drained.
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            next_item = await self.queue.get()
            if not isinstance(next_item, _QueueCompleteSentinel):
                return _EventWrapper(next_item)
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """
        Discard everything currently queued and mark this iterator completed.

        Never awaits, so it is safe on cleanup paths. Queued items, including
        any completion sentinel, are dropped.
        """
        pending = self.queue
        while not pending.empty():
            pending.get_nowait()
        self.completed = True
343
344
class StreamingThoughtTracker(BaseModel):
    """
    Tracks the reasoning-summary thought that is currently being streamed
    into a workflow.
    """

    # Responses item id of the reasoning item whose summary is streaming.
    item_id: str
    # summary_index of the summary part within that reasoning item.
    index: int
    # Workflow task that receives the streamed summary text.
    task: ThoughtTask
349
350
async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Translate a streamed Agents SDK run into ChatKit thread stream events.

    Raw events from ``result.stream_events()`` are merged with events emitted
    through ``context`` and converted into ChatKit events:
    - assistant message items, with their content deltas and annotations;
    - reasoning summaries, rendered as ThoughtTasks in a "reasoning" workflow.

    Workflow continuation:
    - If the thread's last item is a workflow, that workflow is continued.
    - The same applies if the last item is a client tool call and the one
      before it is a workflow.
    - A workflow still open when the run ends is saved to the store so the
      next turn can continue it.

    If ``context.client_tool_call`` is set, a ClientToolCallItem done event is
    emitted last.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            re-raised after a ThreadItemRemovedEvent is emitted for every item
            produced during this run.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    # Ids of items emitted during this run, removed again on a guardrail trip.
    produced_items: set[str] = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        # Collapse ``item``, give it a duration summary if it has none, stop
        # tracking it as current, and build its done event.
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdatedEvent(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdatedEvent(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        # The new thought was just appended, so it is last.
                        # tasks.index(task) compares by value and would return
                        # an earlier thought with identical text.
                        update = WorkflowTaskAdded(
                            task=task,
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdatedEvent(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
639
640
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """
    Turn agent output-text deltas into progressively filled copies of a widget.

    Yields, in order:
    - ``base_widget`` unchanged;
    - after every ``response.output_text.delta`` raw response event, a copy
      whose ``value`` is all text received so far;
    - once ``events`` is exhausted, a final copy with the full text and
      ``streaming=False``.

    All other events are ignored.
    """
    yield base_widget
    accumulated = ""
    async for stream_event in events:
        if stream_event.type != "raw_response_event":
            continue
        payload = stream_event.data
        if payload.type != "response.output_text.delta":
            continue
        accumulated += payload.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
656
657
658class ThreadItemConverter:
659 """
660 Converts thread items to Agent SDK input items.
661 Widgets, Tasks, and Workflows have default conversions but can be customized.
662 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
663 Other item types are converted automatically.
664 """
665
666 async def attachment_to_message_content(
667 self, attachment: Attachment
668 ) -> ResponseInputContentParam:
669 """
670 Convert an attachment in a user message into a message content part to send to the model.
671 Required when attachments are enabled.
672 """
673 raise NotImplementedError(
674 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
675 )
676
677 async def tag_to_message_content(
678 self, tag: UserMessageTagContent
679 ) -> ResponseInputContentParam:
680 """
681 Convert a tag in a user message into a message content part to send to the model as context.
682 Required when tags are used.
683 """
684 raise NotImplementedError(
685 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
686 )
687
688 async def hidden_context_to_input(
689 self, item: HiddenContextItem
690 ) -> TResponseInputItem | list[TResponseInputItem] | None:
691 """
692 Convert a HiddenContextItem into input item(s) to send to the model.
693 Required to override when HiddenContextItems with non-string content are used.
694 """
695 if not isinstance(item.content, str):
696 raise NotImplementedError(
697 "HiddenContextItems with non-string content were present in a user message but a Converter.hidden_context_to_input was not implemented"
698 )
699
700 text = (
701 "Hidden context for the agent (not shown to the user):\n"
702 f"<HiddenContext>\n{item.content}\n</HiddenContext>"
703 )
704 return Message(
705 type="message",
706 content=[
707 ResponseInputTextParam(
708 type="input_text",
709 text=text,
710 )
711 ],
712 role="user",
713 )
714
715 async def task_to_input(
716 self, item: TaskItem
717 ) -> TResponseInputItem | list[TResponseInputItem] | None:
718 """
719 Convert a TaskItem into input item(s) to send to the model.
720 """
721 if item.task.type != "custom" or (
722 not item.task.title and not item.task.content
723 ):
724 return None
725 title = f"{item.task.title}" if item.task.title else ""
726 content = f"{item.task.content}" if item.task.content else ""
727 task_text = f"{title}: {content}" if title and content else title or content
728 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
729 return Message(
730 type="message",
731 content=[
732 ResponseInputTextParam(
733 type="input_text",
734 text=text,
735 )
736 ],
737 role="user",
738 )
739
740 async def workflow_to_input(
741 self, item: WorkflowItem
742 ) -> TResponseInputItem | list[TResponseInputItem] | None:
743 """
744 Convert a TaskItem into input item(s) to send to the model.
745 Returns WorkflowItem.response_items by default.
746 """
747 messages = []
748 for task in item.workflow.tasks:
749 if task.type != "custom" or (not task.title and not task.content):
750 continue
751
752 title = f"{task.title}" if task.title else ""
753 content = f"{task.content}" if task.content else ""
754 task_text = f"{title}: {content}" if title and content else title or content
755 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
756 messages.append(
757 Message(
758 type="message",
759 content=[
760 ResponseInputTextParam(
761 type="input_text",
762 text=text,
763 )
764 ],
765 role="user",
766 )
767 )
768 return messages
769
770 async def widget_to_input(
771 self, item: WidgetItem
772 ) -> TResponseInputItem | list[TResponseInputItem] | None:
773 """
774 Convert a WidgetItem into input item(s) to send to the model.
775 By default, WidgetItems converted to a text description with a JSON representation of the widget.
776 """
777 return Message(
778 type="message",
779 content=[
780 ResponseInputTextParam(
781 type="input_text",
782 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
783 + item.widget.model_dump_json(
784 exclude_unset=True, exclude_none=True
785 ),
786 )
787 ],
788 role="user",
789 )
790
791 async def user_message_to_input(
792 self, item: UserMessageItem, is_last_message: bool = True
793 ) -> TResponseInputItem | list[TResponseInputItem] | None:
794 # Build the user text exactly as typed, rendering tags as @key
795 message_text_parts: list[str] = []
796 # Track tags separately to add system context
797 raw_tags: list[UserMessageTagContent] = []
798
799 for part in item.content:
800 if isinstance(part, UserMessageTextContent):
801 message_text_parts.append(part.text)
802 elif isinstance(part, UserMessageTagContent):
803 message_text_parts.append(f"@{part.text}")
804 raw_tags.append(part)
805 else:
806 assert_never(part)
807
808 user_text_item = Message(
809 role="user",
810 type="message",
811 content=[
812 ResponseInputTextParam(
813 type="input_text", text="".join(message_text_parts)
814 ),
815 *[
816 await self.attachment_to_message_content(a)
817 for a in item.attachments
818 ],
819 ],
820 )
821
822 # Build system items (prepend later): quoted text and @-mention context
823 context_items: list[TResponseInputItem] = []
824
825 if item.quoted_text and is_last_message:
826 context_items.append(
827 Message(
828 role="user",
829 type="message",
830 content=[
831 ResponseInputTextParam(
832 type="input_text",
833 text=f"The user is referring to this in particular: \n{item.quoted_text}",
834 )
835 ],
836 )
837 )
838
839 # Dedupe tags (preserve order) and resolve to message content
840 if raw_tags:
841 seen, uniq_tags = set(), []
842 for t in raw_tags:
843 if t.text not in seen:
844 seen.add(t.text)
845 uniq_tags.append(t)
846
847 tag_content: ResponseInputMessageContentListParam = [
848 # should return summarized text items
849 await self.tag_to_message_content(tag)
850 for tag in uniq_tags
851 ]
852
853 if tag_content:
854 context_items.append(
855 Message(
856 role="user",
857 type="message",
858 content=[
859 ResponseInputTextParam(
860 type="input_text",
861 text=cleandoc("""
862 # User-provided context for @-mentions
863 - When referencing resolved entities, use their canonical names **without** '@'.
864 - The '@' form appears only in user text and should not be echoed.
865 """).strip(),
866 ),
867 *tag_content,
868 ],
869 )
870 )
871
872 return [user_text_item, *context_items]
873
874 async def assistant_message_to_input(
875 self, item: AssistantMessageItem
876 ) -> TResponseInputItem | list[TResponseInputItem] | None:
877 return EasyInputMessageParam(
878 type="message",
879 content=[
880 # content param doesn't support the assistant message content types
881 cast(
882 ResponseInputContentParam,
883 ResponseOutputText(
884 type="output_text",
885 text=c.text,
886 annotations=[], # TODO: these should be sent back as well
887 ).model_dump(),
888 )
889 for c in item.content
890 ],
891 role="assistant",
892 )
893
894 async def client_tool_call_to_input(
895 self, item: ClientToolCallItem
896 ) -> TResponseInputItem | list[TResponseInputItem] | None:
897 if item.status == "pending":
898 # Filter out pending tool calls - they cannot be sent to the model
899 return None
900
901 return [
902 ResponseFunctionToolCallParam(
903 type="function_call",
904 call_id=item.call_id,
905 name=item.name,
906 arguments=json.dumps(item.arguments),
907 ),
908 FunctionCallOutput(
909 type="function_call_output",
910 call_id=item.call_id,
911 output=json.dumps(item.output),
912 ),
913 ]
914
915 async def end_of_turn_to_input(
916 self, item: EndOfTurnItem
917 ) -> TResponseInputItem | list[TResponseInputItem] | None:
918 # Only used for UI hints - you shouldn't need to override this
919 return None
920
921 async def _thread_item_to_input_item(
922 self,
923 item: ThreadItem,
924 is_last_message: bool = True,
925 ) -> list[TResponseInputItem]:
926 match item:
927 case UserMessageItem():
928 out = await self.user_message_to_input(item, is_last_message) or []
929 return out if isinstance(out, list) else [out]
930 case AssistantMessageItem():
931 out = await self.assistant_message_to_input(item) or []
932 return out if isinstance(out, list) else [out]
933 case ClientToolCallItem():
934 out = await self.client_tool_call_to_input(item) or []
935 return out if isinstance(out, list) else [out]
936 case EndOfTurnItem():
937 out = await self.end_of_turn_to_input(item) or []
938 return out if isinstance(out, list) else [out]
939 case WidgetItem():
940 out = await self.widget_to_input(item) or []
941 return out if isinstance(out, list) else [out]
942 case WorkflowItem():
943 out = await self.workflow_to_input(item) or []
944 return out if isinstance(out, list) else [out]
945 case TaskItem():
946 out = await self.task_to_input(item) or []
947 return out if isinstance(out, list) else [out]
948 case HiddenContextItem():
949 out = await self.hidden_context_to_input(item) or []
950 return out if isinstance(out, list) else [out]
951 case _:
952 assert_never(item)
953
954 async def to_agent_input(
955 self,
956 thread_items: Sequence[ThreadItem] | ThreadItem,
957 ) -> list[TResponseInputItem]:
958 if isinstance(thread_items, Sequence):
959 # shallow copy in case caller mutates the list while we're iterating
960 thread_items = thread_items[:]
961 else:
962 thread_items = [thread_items]
963 output: list[TResponseInputItem] = []
964 for item in thread_items:
965 output.extend(
966 await self._thread_item_to_input_item(
967 item,
968 is_last_message=item is thread_items[-1],
969 )
970 )
971 return output
972
973
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """
    Convert thread items to Agents SDK input using the default ThreadItemConverter.

    This is a plain function that returns the converter's coroutine, so the
    result must be awaited. The default converter raises NotImplementedError
    for attachments, tags, and HiddenContextItems with non-string content;
    subclass ThreadItemConverter to support them.
    """
    converter = _DEFAULT_CONVERTER
    return converter.to_agent_input(thread_items)
979