openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.3.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

962 lines · mode: code

1import asyncio
2import json
3from collections import defaultdict
4from collections.abc import AsyncIterator
5from datetime import datetime
6from inspect import cleandoc
7from typing import (
8 Annotated,
9 Any,
10 AsyncGenerator,
11 Generic,
12 Sequence,
13 TypeVar,
14 cast,
15)
16
17from agents import (
18 InputGuardrailTripwireTriggered,
19 OutputGuardrailTripwireTriggered,
20 RunResultStreaming,
21 StreamEvent,
22 TResponseInputItem,
23)
24from openai.types.responses import (
25 EasyInputMessageParam,
26 ResponseFunctionToolCallParam,
27 ResponseInputContentParam,
28 ResponseInputMessageContentListParam,
29 ResponseInputTextParam,
30 ResponseOutputText,
31)
32from openai.types.responses.response_input_item_param import (
33 FunctionCallOutput,
34 Message,
35)
36from openai.types.responses.response_output_message import Content
37from openai.types.responses.response_output_text import (
38 Annotation as ResponsesAnnotation,
39)
40from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
41from typing_extensions import assert_never
42
43from .server import stream_widget
44from .store import Store, StoreItemType
45from .types import (
46 Annotation,
47 AssistantMessageContent,
48 AssistantMessageContentPartAdded,
49 AssistantMessageContentPartAnnotationAdded,
50 AssistantMessageContentPartDone,
51 AssistantMessageContentPartTextDelta,
52 AssistantMessageItem,
53 Attachment,
54 ClientToolCallItem,
55 DurationSummary,
56 EndOfTurnItem,
57 FileSource,
58 HiddenContextItem,
59 Task,
60 TaskItem,
61 ThoughtTask,
62 ThreadItem,
63 ThreadItemAddedEvent,
64 ThreadItemDoneEvent,
65 ThreadItemRemovedEvent,
66 ThreadItemUpdated,
67 ThreadMetadata,
68 ThreadStreamEvent,
69 URLSource,
70 UserMessageItem,
71 UserMessageTagContent,
72 UserMessageTextContent,
73 WidgetItem,
74 Workflow,
75 WorkflowItem,
76 WorkflowSummary,
77 WorkflowTaskAdded,
78 WorkflowTaskUpdated,
79)
80from .widgets import Markdown, Text, WidgetRoot
81
82
class ClientToolCall(BaseModel):
    """
    Value a tool method returns when the call has to be fulfilled on the client.

    Holds the client tool's ``name`` and the ``arguments`` mapping to hand to it.
    """

    # Name of the client-side tool to invoke.
    name: str
    # Keyword arguments for the client-side tool.
    arguments: dict[str, Any]
91
class _QueueCompleteSentinel:
    """Marker put on an AgentContext event queue to signal that no further events follow."""


# Type variable for the caller-defined request context carried by AgentContext.
TContext = TypeVar("TContext")
96
97
class AgentContext(BaseModel, Generic[TContext]):
    """
    Per-run state handed to agent tools so they can emit ChatKit thread events.

    The helper methods put `ThreadStreamEvent`s on a private queue. Pass the
    context to `stream_agent_response` to interleave those events with the
    agent's own output.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    # When set, stream_agent_response emits a ClientToolCallItem for it after the run.
    client_tool_call: ClientToolCall | None = None
    # Workflow currently being built; None when no workflow is open.
    workflow_item: WorkflowItem | None = None
    # Events emitted by the helpers below, terminated by a _QueueCompleteSentinel.
    # NOTE(review): the default is a single Queue object created at class
    # definition; this relies on pydantic copying private-attribute defaults per
    # instance so contexts do not share a queue. Confirm for the pinned pydantic version.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """
        Generate a new ID through the store.

        Thread IDs come from `Store.generate_thread_id`. Any other item type is
        generated by `Store.generate_item_id` for `thread`, which defaults to
        this context's thread.
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """
        Stream a widget into the current thread.

        `widget` is either a finished widget or an async generator of widget
        states. Every event produced for it is queued for the client.
        """
        # `stream_widget` here resolves to the module-level helper imported from
        # .server, not to this method (methods are not in function scope).
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """
        Close the open workflow, if any, and stream its done event.

        Args:
            summary: Summary to attach. When omitted and the workflow has none,
                a DurationSummary of the whole seconds elapsed since the
                workflow was created is used.
            expanded: Value set on ``workflow.expanded`` before the done event
                is sent.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """
        Open a new workflow item in the current thread.

        Reasoning workflows, and workflows that already have tasks, are
        announced immediately. Otherwise the added event is deferred until the
        first task arrives through `add_workflow_task`.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """
        Replace the task at `task_index` in the open workflow and stream the update.

        Raises:
            ValueError: if no workflow is open.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdated(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """
        Append a task to the open workflow, starting a custom workflow if none is open.

        For a non-reasoning workflow, the first task triggers the (possibly
        deferred) added event for the whole workflow. Every other task is
        streamed as a WorkflowTaskAdded update.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdated(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        # The task was just appended, so it is last. list.index()
                        # compares with == and would report an earlier equal task.
                        task_index=len(workflow.tasks) - 1,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Queue an arbitrary thread stream event for the client."""
        await self._events.put(event)

    def _complete(self):
        """Queue the completion sentinel without awaiting, ending the event stream."""
        self._events.put_nowait(_QueueCompleteSentinel())
208
209
def _convert_content(content: Content) -> AssistantMessageContent:
    """
    Convert one Responses output-message content part into ChatKit assistant content.

    Output text keeps every annotation `_convert_annotation` can map and drops
    the rest. Any other part (a refusal) becomes plain text with no annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(text=content.refusal, annotations=[])

    converted = (_convert_annotation(raw) for raw in content.annotations)
    return AssistantMessageContent(
        text=content.text,
        annotations=[annotation for annotation in converted if annotation is not None],
    )
225
226
# Constructing a TypeAdapter builds a validation schema, so build it once at
# import time instead of once per streamed annotation.
_RESPONSES_ANNOTATION_ADAPTER = TypeAdapter[ResponsesAnnotation](ResponsesAnnotation)


def _convert_annotation(raw_annotation: object) -> Annotation | None:
    """
    Convert a Responses output-text annotation into a ChatKit Annotation.

    Returns:
        An Annotation with a FileSource (file and container-file citations) or
        a URLSource (URL citations). Returns None for citations without a
        filename and for annotation types with no ChatKit equivalent.

    Raises:
        The adapter's validation error if `raw_annotation` cannot be validated
        as a Responses annotation.
    """
    # The OpenAI client sometimes parses the annotation delta event into the
    # wrong class, leaving the annotation as a dict or an untyped object instead
    # of a ResponsesAnnotation, so re-validate it into the proper union member.
    annotation = _RESPONSES_ANNOTATION_ADAPTER.validate_python(raw_annotation)

    if annotation.type == "file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.index,
        )

    if annotation.type == "url_citation":
        return Annotation(
            source=URLSource(
                url=annotation.url,
                title=annotation.title,
            ),
            index=annotation.end_index,
        )

    if annotation.type == "container_file_citation":
        filename = annotation.filename
        if not filename:
            return None

        return Annotation(
            source=FileSource(filename=filename, title=filename),
            index=annotation.end_index,
        )

    return None
264
265
266T1 = TypeVar("T1")
267T2 = TypeVar("T2")
268
269
270async def _merge_generators(
271 a: AsyncIterator[T1],
272 b: AsyncIterator[T2],
273) -> AsyncIterator[T1 | T2]:
274 pending: list[AsyncIterator[T1 | T2]] = [a, b]
275 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
276 asyncio.ensure_future(g.__anext__()): g for g in pending
277 }
278 while len(pending_tasks) > 0:
279 done, _ = await asyncio.wait(
280 pending_tasks.keys(), return_when="FIRST_COMPLETED"
281 )
282 stop = False
283 for d in done:
284 try:
285 result = d.result()
286 yield result
287 dg = pending_tasks[d]
288 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
289 except StopAsyncIteration:
290 stop = True
291 finally:
292 del pending_tasks[d]
293 if stop:
294 for task in pending_tasks.keys():
295 if not task.cancel():
296 try:
297 yield task.result()
298 except asyncio.CancelledError:
299 pass
300 except asyncio.InvalidStateError:
301 pass
302 break
303
304
class _EventWrapper:
    """
    Box around a ThreadStreamEvent pulled from the AgentContext queue.

    The wrapper lets the merged stream tell context-helper events apart from
    the agent run's own stream events with a single isinstance check.
    """

    def __init__(self, event: ThreadStreamEvent):
        # The consumer unwraps this again via `.event`.
        self.event = event
308
309
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """
    Async iterator over an AgentContext event queue.

    Each dequeued event is yielded wrapped in `_EventWrapper`. Iteration ends
    for good once a `_QueueCompleteSentinel` is dequeued.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            next_item = await self.queue.get()
            if not isinstance(next_item, _QueueCompleteSentinel):
                return _EventWrapper(next_item)
            # Sentinel seen: remember it so later calls stop without awaiting.
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """
        Discard everything currently queued, without awaiting, and mark the iterator done.

        Meant for cleanup paths that must not await. Queued events and any
        completion sentinel are thrown away.
        """
        pending = self.queue
        while not pending.empty():
            pending.get_nowait()
        self.completed = True
343
344
class StreamingThoughtTracker(BaseModel):
    """
    Bookkeeping for the reasoning-summary thought currently streaming into a workflow.
    """

    # Responses item_id of the reasoning item whose summary is streaming.
    item_id: str
    # summary_index of the summary part being streamed.
    index: int
    # Task appended to the workflow and extended in place as deltas arrive.
    task: ThoughtTask
349
350
def _identity_index(items: Sequence[Any], target: object) -> int | None:
    """
    Return the position of `target` in `items` compared by identity, or None.

    `list.index` and `in` compare with `==`. NOTE(review): workflow tasks look
    like value-compared models, so two distinct tasks with identical fields
    would compare equal and `list.index` would report the earlier one.
    """
    for position, candidate in enumerate(items):
        if candidate is target:
            return position
    return None


async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Translate a streamed Agents SDK run into ChatKit thread stream events.

    Agent output (assistant messages, text deltas, annotations, reasoning
    summaries) is interleaved with events queued by `AgentContext` helpers.
    Reasoning summaries become thought tasks inside a workflow item. An open
    workflow is ended when a visible item, such as an assistant message,
    follows it.

    If the thread's last item is a workflow, or a client tool call directly
    after a workflow, that workflow is resumed.

    On an input/output guardrail tripwire, this yields a ThreadItemRemovedEvent
    for every item produced during this run, discards queued context events,
    and re-raises the tripwire exception.

    After a normal run, it flushes remaining context events and persists a
    still-open workflow with `store.add_thread_item`. If
    `context.client_tool_call` was set, it also yields a ClientToolCallItem
    done event.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None
    # item_id -> content_index -> annotation count
    item_annotation_count: defaultdict[str, defaultdict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        # Close `item` (collapsed, with a duration summary if it has none) and
        # build its done event; clears ctx.workflow_item if it is that workflow.
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remember the latest function call so a client tool call
                    # can reuse its IDs at the end of the run.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                annotation = _convert_annotation(event.annotation)
                if annotation:
                    # Manually track annotation indices per content part in case we drop an annotation that
                    # we can't convert to our internal representation (e.g. missing filename).
                    annotation_index = item_annotation_count[event.item_id][
                        event.content_index
                    ]
                    item_annotation_count[event.item_id][event.content_index] = (
                        annotation_index + 1
                    )
                    yield ThreadItemUpdated(
                        item_id=event.item_id,
                        update=AssistantMessageContentPartAnnotationAdded(
                            content_index=event.content_index,
                            annotation_index=annotation_index,
                            annotation=annotation,
                        ),
                    )
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                tasks = ctx.workflow_item.workflow.tasks
                # stream the first thought in a new workflow so that we can show it earlier
                if ctx.workflow_item.workflow.type == "reasoning" and len(tasks) == 0:
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    tasks.append(streaming_thought.task)
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    # Only extend the tracked thought while it still belongs to
                    # the current workflow (looked up by identity, not ==).
                    task_index = _identity_index(tasks, streaming_thought.task)
                    if task_index is not None:
                        streaming_thought.task.content += event.delta
                        yield ThreadItemUpdated(
                            item_id=ctx.workflow_item.id,
                            update=WorkflowTaskUpdated(
                                task=streaming_thought.task,
                                task_index=task_index,
                            ),
                        )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    tasks = ctx.workflow_item.workflow.tasks
                    streaming_index = None
                    if (
                        streaming_thought
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        streaming_index = _identity_index(
                            tasks, streaming_thought.task
                        )
                    if streaming_thought and streaming_index is not None:
                        # Finalize the thought that was streamed delta by delta.
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=streaming_index,
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        tasks.append(task)
                        update = WorkflowTaskAdded(
                            task=task,
                            # Just appended, so it is the last task.
                            task_index=len(tasks) - 1,
                        )
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
639
640
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """
    Turn agent stream events into progressively filled copies of a text widget.

    Yields ``base_widget`` itself first. Then, after every
    ``response.output_text.delta``, it yields a ``model_copy`` whose ``value`` is
    all text received so far. Finally it yields a copy with the full text and
    ``streaming=False``. Other events are ignored.
    """
    yield base_widget
    accumulated = ""
    async for event in events:
        if event.type != "raw_response_event":
            continue
        if event.data.type != "response.output_text.delta":
            continue
        accumulated = accumulated + event.data.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
656
657
658class ThreadItemConverter:
659 """
660 Converts thread items to Agent SDK input items.
661 Widgets, Tasks, and Workflows have default conversions but can be customized.
662 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
663 Other item types are converted automatically.
664 """
665
666 async def attachment_to_message_content(
667 self, attachment: Attachment
668 ) -> ResponseInputContentParam:
669 """
670 Convert an attachment in a user message into a message content part to send to the model.
671 Required when attachments are enabled.
672 """
673 raise NotImplementedError(
674 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
675 )
676
677 async def tag_to_message_content(
678 self, tag: UserMessageTagContent
679 ) -> ResponseInputContentParam:
680 """
681 Convert a tag in a user message into a message content part to send to the model as context.
682 Required when tags are used.
683 """
684 raise NotImplementedError(
685 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
686 )
687
688 async def hidden_context_to_input(
689 self, item: HiddenContextItem
690 ) -> TResponseInputItem | list[TResponseInputItem] | None:
691 """
692 Convert a HiddenContextItem into input item(s) to send to the model.
693 Required when HiddenContextItem are used.
694 """
695 raise NotImplementedError(
696 "HiddenContextItem were present in a user message but Converter.hidden_context_to_input was not implemented"
697 )
698
699 async def task_to_input(
700 self, item: TaskItem
701 ) -> TResponseInputItem | list[TResponseInputItem] | None:
702 """
703 Convert a TaskItem into input item(s) to send to the model.
704 """
705 if item.task.type != "custom" or (
706 not item.task.title and not item.task.content
707 ):
708 return None
709 title = f"{item.task.title}" if item.task.title else ""
710 content = f"{item.task.content}" if item.task.content else ""
711 task_text = f"{title}: {content}" if title and content else title or content
712 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
713 return Message(
714 type="message",
715 content=[
716 ResponseInputTextParam(
717 type="input_text",
718 text=text,
719 )
720 ],
721 role="user",
722 )
723
724 async def workflow_to_input(
725 self, item: WorkflowItem
726 ) -> TResponseInputItem | list[TResponseInputItem] | None:
727 """
728 Convert a TaskItem into input item(s) to send to the model.
729 Returns WorkflowItem.response_items by default.
730 """
731 messages = []
732 for task in item.workflow.tasks:
733 if task.type != "custom" or (not task.title and not task.content):
734 continue
735
736 title = f"{task.title}" if task.title else ""
737 content = f"{task.content}" if task.content else ""
738 task_text = f"{title}: {content}" if title and content else title or content
739 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
740 messages.append(
741 Message(
742 type="message",
743 content=[
744 ResponseInputTextParam(
745 type="input_text",
746 text=text,
747 )
748 ],
749 role="user",
750 )
751 )
752 return messages
753
754 async def widget_to_input(
755 self, item: WidgetItem
756 ) -> TResponseInputItem | list[TResponseInputItem] | None:
757 """
758 Convert a WidgetItem into input item(s) to send to the model.
759 By default, WidgetItems converted to a text description with a JSON representation of the widget.
760 """
761 return Message(
762 type="message",
763 content=[
764 ResponseInputTextParam(
765 type="input_text",
766 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
767 + item.widget.model_dump_json(
768 exclude_unset=True, exclude_none=True
769 ),
770 )
771 ],
772 role="user",
773 )
774
775 async def user_message_to_input(
776 self, item: UserMessageItem, is_last_message: bool = True
777 ) -> TResponseInputItem | list[TResponseInputItem] | None:
778 # Build the user text exactly as typed, rendering tags as @key
779 message_text_parts: list[str] = []
780 # Track tags separately to add system context
781 raw_tags: list[UserMessageTagContent] = []
782
783 for part in item.content:
784 if isinstance(part, UserMessageTextContent):
785 message_text_parts.append(part.text)
786 elif isinstance(part, UserMessageTagContent):
787 message_text_parts.append(f"@{part.text}")
788 raw_tags.append(part)
789 else:
790 assert_never(part)
791
792 user_text_item = Message(
793 role="user",
794 type="message",
795 content=[
796 ResponseInputTextParam(
797 type="input_text", text="".join(message_text_parts)
798 ),
799 *[
800 await self.attachment_to_message_content(a)
801 for a in item.attachments
802 ],
803 ],
804 )
805
806 # Build system items (prepend later): quoted text and @-mention context
807 context_items: list[TResponseInputItem] = []
808
809 if item.quoted_text and is_last_message:
810 context_items.append(
811 Message(
812 role="user",
813 type="message",
814 content=[
815 ResponseInputTextParam(
816 type="input_text",
817 text=f"The user is referring to this in particular: \n{item.quoted_text}",
818 )
819 ],
820 )
821 )
822
823 # Dedupe tags (preserve order) and resolve to message content
824 if raw_tags:
825 seen, uniq_tags = set(), []
826 for t in raw_tags:
827 if t.text not in seen:
828 seen.add(t.text)
829 uniq_tags.append(t)
830
831 tag_content: ResponseInputMessageContentListParam = [
832 # should return summarized text items
833 await self.tag_to_message_content(tag)
834 for tag in uniq_tags
835 ]
836
837 if tag_content:
838 context_items.append(
839 Message(
840 role="user",
841 type="message",
842 content=[
843 ResponseInputTextParam(
844 type="input_text",
845 text=cleandoc("""
846 # User-provided context for @-mentions
847 - When referencing resolved entities, use their canonical names **without** '@'.
848 - The '@' form appears only in user text and should not be echoed.
849 """).strip(),
850 ),
851 *tag_content,
852 ],
853 )
854 )
855
856 return [user_text_item, *context_items]
857
858 async def assistant_message_to_input(
859 self, item: AssistantMessageItem
860 ) -> TResponseInputItem | list[TResponseInputItem] | None:
861 return EasyInputMessageParam(
862 type="message",
863 content=[
864 # content param doesn't support the assistant message content types
865 cast(
866 ResponseInputContentParam,
867 ResponseOutputText(
868 type="output_text",
869 text=c.text,
870 annotations=[], # TODO: these should be sent back as well
871 ).model_dump(),
872 )
873 for c in item.content
874 ],
875 role="assistant",
876 )
877
878 async def client_tool_call_to_input(
879 self, item: ClientToolCallItem
880 ) -> TResponseInputItem | list[TResponseInputItem] | None:
881 if item.status == "pending":
882 # Filter out pending tool calls - they cannot be sent to the model
883 return None
884
885 return [
886 ResponseFunctionToolCallParam(
887 type="function_call",
888 call_id=item.call_id,
889 name=item.name,
890 arguments=json.dumps(item.arguments),
891 ),
892 FunctionCallOutput(
893 type="function_call_output",
894 call_id=item.call_id,
895 output=json.dumps(item.output),
896 ),
897 ]
898
899 async def end_of_turn_to_input(
900 self, item: EndOfTurnItem
901 ) -> TResponseInputItem | list[TResponseInputItem] | None:
902 # Only used for UI hints - you shouldn't need to override this
903 return None
904
905 async def _thread_item_to_input_item(
906 self,
907 item: ThreadItem,
908 is_last_message: bool = True,
909 ) -> list[TResponseInputItem]:
910 match item:
911 case UserMessageItem():
912 out = await self.user_message_to_input(item, is_last_message) or []
913 return out if isinstance(out, list) else [out]
914 case AssistantMessageItem():
915 out = await self.assistant_message_to_input(item) or []
916 return out if isinstance(out, list) else [out]
917 case ClientToolCallItem():
918 out = await self.client_tool_call_to_input(item) or []
919 return out if isinstance(out, list) else [out]
920 case EndOfTurnItem():
921 out = await self.end_of_turn_to_input(item) or []
922 return out if isinstance(out, list) else [out]
923 case WidgetItem():
924 out = await self.widget_to_input(item) or []
925 return out if isinstance(out, list) else [out]
926 case WorkflowItem():
927 out = await self.workflow_to_input(item) or []
928 return out if isinstance(out, list) else [out]
929 case TaskItem():
930 out = await self.task_to_input(item) or []
931 return out if isinstance(out, list) else [out]
932 case HiddenContextItem():
933 out = await self.hidden_context_to_input(item) or []
934 return out if isinstance(out, list) else [out]
935 case _:
936 assert_never(item)
937
938 async def to_agent_input(
939 self,
940 thread_items: Sequence[ThreadItem] | ThreadItem,
941 ) -> list[TResponseInputItem]:
942 if isinstance(thread_items, Sequence):
943 # shallow copy in case caller mutates the list while we're iterating
944 thread_items = thread_items[:]
945 else:
946 thread_items = [thread_items]
947 output: list[TResponseInputItem] = []
948 for item in thread_items:
949 output.extend(
950 await self._thread_item_to_input_item(
951 item,
952 is_last_message=item is thread_items[-1],
953 )
954 )
955 return output
956
957
# Module-level converter instance shared by every simple_to_agent_input call.
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """
    Convert thread items to Agents SDK input using the default ThreadItemConverter.

    Returns the un-awaited coroutine from `ThreadItemConverter.to_agent_input`,
    so callers must ``await`` the result.
    """
    converter = _DEFAULT_CONVERTER
    return converter.to_agent_input(thread_items)
963