openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
7327b5dbc039093f32c6b3bc8b7a0ecad950cd12

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

932 lines · mode: code

import asyncio
import json
from collections.abc import AsyncIterator
from datetime import datetime
from inspect import cleandoc
from typing import (
    Annotated,
    Any,
    AsyncGenerator,
    Generic,
    Sequence,
    TypeVar,
    cast,
)

from agents import (
    InputGuardrailTripwireTriggered,
    OutputGuardrailTripwireTriggered,
    RunResultStreaming,
    StreamEvent,
    TResponseInputItem,
)
from openai.types.responses import (
    EasyInputMessageParam,
    ResponseFunctionToolCallParam,
    ResponseInputContentParam,
    ResponseInputMessageContentListParam,
    ResponseInputTextParam,
    ResponseOutputText,
)
from openai.types.responses.response_input_item_param import (
    FunctionCallOutput,
    Message,
)
from openai.types.responses.response_output_message import Content
from openai.types.responses.response_output_text import (
    Annotation as ResponsesAnnotation,
)
from pydantic import BaseModel, ConfigDict, PrivateAttr, SkipValidation, TypeAdapter
from typing_extensions import assert_never

from .server import stream_widget
from .store import Store, StoreItemType
from .types import (
    Annotation,
    AssistantMessageContent,
    AssistantMessageContentPartAdded,
    AssistantMessageContentPartDone,
    AssistantMessageContentPartTextDelta,
    AssistantMessageItem,
    Attachment,
    ClientToolCallItem,
    DurationSummary,
    EndOfTurnItem,
    FileSource,
    HiddenContextItem,
    Task,
    TaskItem,
    ThoughtTask,
    ThreadItem,
    ThreadItemAddedEvent,
    ThreadItemDoneEvent,
    ThreadItemRemovedEvent,
    ThreadItemUpdated,
    ThreadMetadata,
    ThreadStreamEvent,
    URLSource,
    UserMessageItem,
    UserMessageTagContent,
    UserMessageTextContent,
    WidgetItem,
    Workflow,
    WorkflowItem,
    WorkflowSummary,
    WorkflowTaskAdded,
    WorkflowTaskUpdated,
)
from .widgets import Markdown, Text, WidgetRoot
79
80
class ClientToolCall(BaseModel):
    """
    Returned from tool methods to indicate a client-side tool call.

    When assigned to ``AgentContext.client_tool_call``, ``stream_agent_response``
    emits a ``ClientToolCallItem`` carrying this name and these arguments once
    the run has finished.
    """

    # Name of the client-side tool to invoke.
    name: str
    # Arguments forwarded to the client-side tool.
    arguments: dict[str, Any]
88
89
90class _QueueCompleteSentinel: ...
91
92
93TContext = TypeVar("TContext")
94
95
class AgentContext(BaseModel, Generic[TContext]):
    """Per-run state shared between ChatKit and agent tool implementations.

    Carries the thread being responded to, the store used to allocate ids and
    persist items, and the caller's request context. The helper methods push
    thread events onto a private queue, which ``stream_agent_response`` merges
    into the response stream.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    # SkipValidation: the store object is used as-is, not validated by pydantic.
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    # When set, stream_agent_response emits a ClientToolCallItem at the end of the run.
    client_tool_call: ClientToolCall | None = None
    # The workflow currently receiving tasks, if any.
    workflow_item: WorkflowItem | None = None
    # Outgoing events from the helpers below. default_factory gives every
    # context a fresh queue instead of deriving it from a single instance
    # created at import time.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = PrivateAttr(
        default_factory=asyncio.Queue
    )

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """Allocate a new id of the given type from the store.

        Thread ids come from ``Store.generate_thread_id``; every other item type
        uses ``Store.generate_item_id`` scoped to ``thread`` (defaulting to this
        context's thread).
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """Queue the events that render ``widget`` (static or streamed) in this thread.

        Args:
            widget: A widget root, or an async generator yielding successive versions.
            copy_text: Optional text passed through to ``stream_widget``.
        """
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """Finish the active workflow and stream its done event.

        Args:
            summary: Summary to attach. When omitted and the workflow has no
                summary yet, a ``DurationSummary`` with the whole seconds elapsed
                since the workflow was created is used.
            expanded: Whether the finished workflow should stay expanded.

        Does nothing when no workflow is active.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """Make ``workflow`` the active workflow.

        The added event is streamed immediately for reasoning workflows and for
        workflows that already have tasks; otherwise it is deferred until the
        first task is added via ``add_workflow_task``.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """Replace the task at ``task_index`` in the active workflow and stream the update.

        Raises:
            ValueError: If no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdated(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """Append ``task`` to the active workflow, starting a custom workflow if none is active.

        For a non-reasoning workflow, adding its first task streams the
        (previously deferred) added event; otherwise a task-added update is
        streamed.
        """
        self.workflow_item = self.workflow_item or WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=Workflow(type="custom", tasks=[]),
            thread_id=self.thread.id,
        )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it sits at the last index.
        # tasks.index(task) would report an earlier position if an equal (==)
        # task was already present.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdated(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Queue ``event`` for delivery in the response stream."""
        await self._events.put(event)

    def _complete(self):
        # Signal the queue consumer that no further events will be queued.
        self._events.put_nowait(_QueueCompleteSentinel())
206
207
def _convert_content(content: Content) -> AssistantMessageContent:
    """Translate one Responses output content part into ChatKit message content.

    Output text keeps its text and its converted annotations; any other part
    is treated as a refusal and carried over as plain text without annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(text=content.refusal, annotations=[])

    converted = [
        chatkit_annotation
        for raw_annotation in content.annotations
        for chatkit_annotation in _convert_annotation(raw_annotation)
    ]
    return AssistantMessageContent(text=content.text, annotations=converted)
222
223
def _convert_annotation(
    annotation: ResponsesAnnotation,
) -> list[Annotation]:
    """Convert a Responses annotation into zero or more ChatKit annotations.

    URL citations map to a ``URLSource`` positioned at the citation's end
    index. File citations map to a ``FileSource`` when they carry a filename.
    Anything else (including file citations without a filename) yields an
    empty list.
    """
    # The OpenAI client sometimes parses the annotation delta event into the
    # wrong class, leaving a plain dict instead of a ResponsesAnnotation;
    # re-validate it into the proper model.
    if isinstance(annotation, dict):
        annotation = TypeAdapter(ResponsesAnnotation).validate_python(annotation)

    if annotation.type == "url_citation":
        return [
            Annotation(
                source=URLSource(url=annotation.url, title=annotation.title),
                index=annotation.end_index,
            )
        ]

    if annotation.type == "file_citation" and annotation.filename:
        name = annotation.filename
        return [
            Annotation(
                source=FileSource(filename=name, title=name),
                index=annotation.index,
            )
        ]

    return []
255
256
257T1 = TypeVar("T1")
258T2 = TypeVar("T2")
259
260
261async def _merge_generators(
262 a: AsyncIterator[T1],
263 b: AsyncIterator[T2],
264) -> AsyncIterator[T1 | T2]:
265 pending: list[AsyncIterator[T1 | T2]] = [a, b]
266 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
267 asyncio.ensure_future(g.__anext__()): g for g in pending
268 }
269 while len(pending_tasks) > 0:
270 done, _ = await asyncio.wait(
271 pending_tasks.keys(), return_when="FIRST_COMPLETED"
272 )
273 stop = False
274 for d in done:
275 try:
276 result = d.result()
277 yield result
278 dg = pending_tasks[d]
279 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
280 except StopAsyncIteration:
281 stop = True
282 finally:
283 del pending_tasks[d]
284 if stop:
285 for task in pending_tasks.keys():
286 if not task.cancel():
287 try:
288 yield task.result()
289 except asyncio.CancelledError:
290 pass
291 except asyncio.InvalidStateError:
292 pass
293 break
294
295
class _EventWrapper:
    """Tags a thread event that came from an AgentContext event queue.

    The merged stream in ``stream_agent_response`` uses this wrapper to tell
    helper-emitted events apart from Agents SDK stream events.
    """

    def __init__(self, event: ThreadStreamEvent):
        # The wrapped ChatKit thread event, forwarded unchanged.
        self.event = event
299
300
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """Async iterator over an AgentContext event queue.

    Each queued thread event is yielded wrapped in ``_EventWrapper``.
    Iteration ends once a ``_QueueCompleteSentinel`` is dequeued, and stays
    ended afterwards.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            next_item = await self.queue.get()
            if not isinstance(next_item, _QueueCompleteSentinel):
                return _EventWrapper(next_item)
            # Sentinel reached: remember it so later calls stop immediately.
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """Discard everything currently queued without awaiting, then mark iteration complete.

        Meant for cleanup paths that must not await. Any queued completion
        sentinel is discarded along with the events.
        """
        drained = False
        while not drained:
            try:
                self.queue.get_nowait()
            except asyncio.QueueEmpty:
                drained = True
        self.completed = True
334
335
class StreamingThoughtTracker(BaseModel):
    """Tracks the reasoning-summary thought that is currently being streamed.

    Deltas are appended to ``task`` only while they belong to the same
    Responses item (``item_id``) and summary part (``index``).
    """

    # Responses item id the streamed summary deltas belong to.
    item_id: str
    # summary_index of the summary part being streamed.
    index: int
    # ThoughtTask in the workflow whose content grows with each delta.
    task: ThoughtTask
340
341
async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """Convert an Agents SDK streamed run into ChatKit thread stream events.

    Merges the run's stream events with events pushed through ``context``
    helpers (``AgentContext.stream``, the workflow helpers, ``stream_widget``)
    and translates Responses events into ChatKit events:

    - message output items become ``AssistantMessageItem`` added/updated/done events;
    - reasoning items open a "reasoning" workflow whose summary text is
      surfaced as ``ThoughtTask`` entries;
    - an open workflow is closed when a visible item is added after it.

    If the thread's last item is a workflow, or a client tool call preceded by
    a workflow, that workflow is resumed. When an input or output guardrail
    tripwire fires, a ``ThreadItemRemovedEvent`` is emitted for every item this
    run produced, in the order they were produced. The exception is then
    re-raised.

    After the run completes, the remaining queued context events are flushed.
    A still-open workflow is then saved to the store, and a pending
    ``context.client_tool_call`` is emitted as a ``ClientToolCallItem``.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    # Ids of items emitted by this run, in emission order (a dict used as an
    # ordered set), so a guardrail tripwire removes them deterministically.
    produced_items: dict[str, None] = {}
    streaming_thought: None | StreamingThoughtTracker = None

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        # Close `item` (collapsed, with an elapsed-time summary if it has none)
        # and return its done event for the caller to yield.
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items[event.item.id] = None
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    # Remember the latest function call so a client tool call
                    # emitted at the end of the run can reuse its ids.
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items[current_item_id] = None
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                # Ignore annotation-added events; annotations are reflected in the final item content.
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items[ctx.workflow_item.id] = None
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items[item.id] = None
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                elif (
                    streaming_thought
                    and streaming_thought.task in ctx.workflow_item.workflow.tasks
                    and event.item_id == streaming_thought.item_id
                    and event.summary_index == streaming_thought.index
                ):
                    streaming_thought.task.content += event.delta
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskUpdated(
                            task=streaming_thought.task,
                            task_index=ctx.workflow_item.workflow.tasks.index(
                                streaming_thought.task
                            ),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    if (
                        streaming_thought
                        and streaming_thought.task in ctx.workflow_item.workflow.tasks
                        and event.item_id == streaming_thought.item_id
                        and event.summary_index == streaming_thought.index
                    ):
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=ctx.workflow_item.workflow.tasks.index(task),
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        ctx.workflow_item.workflow.tasks.append(task)
                        update = WorkflowTaskAdded(
                            task=task,
                            # Just appended, so it is the last task;
                            # tasks.index(task) could match an earlier thought
                            # with identical text.
                            task_index=len(ctx.workflow_item.workflow.tasks) - 1,
                        )
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items[item.id] = None
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
609
610
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """Stream progressively longer copies of a text widget as output text arrives.

    Yields ``base_widget`` unchanged first. After every raw
    ``response.output_text.delta`` event it yields a copy whose ``value`` holds
    all text received so far. A final copy carries the full text and
    ``streaming=False``.
    """
    yield base_widget
    accumulated = ""
    async for event in events:
        is_text_delta = (
            event.type == "raw_response_event"
            and event.data.type == "response.output_text.delta"
        )
        if not is_text_delta:
            continue
        accumulated += event.data.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
626
627
628class ThreadItemConverter:
629 """
630 Converts thread items to Agent SDK input items.
631 Widgets, Tasks, and Workflows have default conversions but can be customized.
632 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
633 Other item types are converted automatically.
634 """
635
636 async def attachment_to_message_content(
637 self, attachment: Attachment
638 ) -> ResponseInputContentParam:
639 """
640 Convert an attachment in a user message into a message content part to send to the model.
641 Required when attachments are enabled.
642 """
643 raise NotImplementedError(
644 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
645 )
646
647 async def tag_to_message_content(
648 self, tag: UserMessageTagContent
649 ) -> ResponseInputContentParam:
650 """
651 Convert a tag in a user message into a message content part to send to the model as context.
652 Required when tags are used.
653 """
654 raise NotImplementedError(
655 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
656 )
657
658 async def hidden_context_to_input(
659 self, item: HiddenContextItem
660 ) -> TResponseInputItem | list[TResponseInputItem] | None:
661 """
662 Convert a HiddenContextItem into input item(s) to send to the model.
663 Required when HiddenContextItem are used.
664 """
665 raise NotImplementedError(
666 "HiddenContextItem were present in a user message but Converter.hidden_context_to_input was not implemented"
667 )
668
669 async def task_to_input(
670 self, item: TaskItem
671 ) -> TResponseInputItem | list[TResponseInputItem] | None:
672 """
673 Convert a TaskItem into input item(s) to send to the model.
674 """
675 if item.task.type != "custom" or (
676 not item.task.title and not item.task.content
677 ):
678 return None
679 title = f"{item.task.title}" if item.task.title else ""
680 content = f"{item.task.content}" if item.task.content else ""
681 task_text = f"{title}: {content}" if title and content else title or content
682 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
683 return Message(
684 type="message",
685 content=[
686 ResponseInputTextParam(
687 type="input_text",
688 text=text,
689 )
690 ],
691 role="user",
692 )
693
694 async def workflow_to_input(
695 self, item: WorkflowItem
696 ) -> TResponseInputItem | list[TResponseInputItem] | None:
697 """
698 Convert a TaskItem into input item(s) to send to the model.
699 Returns WorkflowItem.response_items by default.
700 """
701 messages = []
702 for task in item.workflow.tasks:
703 if task.type != "custom" or (not task.title and not task.content):
704 continue
705
706 title = f"{task.title}" if task.title else ""
707 content = f"{task.content}" if task.content else ""
708 task_text = f"{title}: {content}" if title and content else title or content
709 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
710 messages.append(
711 Message(
712 type="message",
713 content=[
714 ResponseInputTextParam(
715 type="input_text",
716 text=text,
717 )
718 ],
719 role="user",
720 )
721 )
722 return messages
723
724 async def widget_to_input(
725 self, item: WidgetItem
726 ) -> TResponseInputItem | list[TResponseInputItem] | None:
727 """
728 Convert a WidgetItem into input item(s) to send to the model.
729 By default, WidgetItems converted to a text description with a JSON representation of the widget.
730 """
731 return Message(
732 type="message",
733 content=[
734 ResponseInputTextParam(
735 type="input_text",
736 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
737 + item.widget.model_dump_json(
738 exclude_unset=True, exclude_none=True
739 ),
740 )
741 ],
742 role="user",
743 )
744
745 async def user_message_to_input(
746 self, item: UserMessageItem, is_last_message: bool = True
747 ) -> TResponseInputItem | list[TResponseInputItem] | None:
748 # Build the user text exactly as typed, rendering tags as @key
749 message_text_parts: list[str] = []
750 # Track tags separately to add system context
751 raw_tags: list[UserMessageTagContent] = []
752
753 for part in item.content:
754 if isinstance(part, UserMessageTextContent):
755 message_text_parts.append(part.text)
756 elif isinstance(part, UserMessageTagContent):
757 message_text_parts.append(f"@{part.text}")
758 raw_tags.append(part)
759 else:
760 assert_never(part)
761
762 user_text_item = Message(
763 role="user",
764 type="message",
765 content=[
766 ResponseInputTextParam(
767 type="input_text", text="".join(message_text_parts)
768 ),
769 *[
770 await self.attachment_to_message_content(a)
771 for a in item.attachments
772 ],
773 ],
774 )
775
776 # Build system items (prepend later): quoted text and @-mention context
777 context_items: list[TResponseInputItem] = []
778
779 if item.quoted_text and is_last_message:
780 context_items.append(
781 Message(
782 role="user",
783 type="message",
784 content=[
785 ResponseInputTextParam(
786 type="input_text",
787 text=f"The user is referring to this in particular: \n{item.quoted_text}",
788 )
789 ],
790 )
791 )
792
793 # Dedupe tags (preserve order) and resolve to message content
794 if raw_tags:
795 seen, uniq_tags = set(), []
796 for t in raw_tags:
797 if t.text not in seen:
798 seen.add(t.text)
799 uniq_tags.append(t)
800
801 tag_content: ResponseInputMessageContentListParam = [
802 # should return summarized text items
803 await self.tag_to_message_content(tag)
804 for tag in uniq_tags
805 ]
806
807 if tag_content:
808 context_items.append(
809 Message(
810 role="user",
811 type="message",
812 content=[
813 ResponseInputTextParam(
814 type="input_text",
815 text=cleandoc("""
816 # User-provided context for @-mentions
817 - When referencing resolved entities, use their canonical names **without** '@'.
818 - The '@' form appears only in user text and should not be echoed.
819 """).strip(),
820 ),
821 *tag_content,
822 ],
823 )
824 )
825
826 return [user_text_item, *context_items]
827
828 async def assistant_message_to_input(
829 self, item: AssistantMessageItem
830 ) -> TResponseInputItem | list[TResponseInputItem] | None:
831 return EasyInputMessageParam(
832 type="message",
833 content=[
834 # content param doesn't support the assistant message content types
835 cast(
836 ResponseInputContentParam,
837 ResponseOutputText(
838 type="output_text",
839 text=c.text,
840 annotations=[], # TODO: these should be sent back as well
841 ).model_dump(),
842 )
843 for c in item.content
844 ],
845 role="assistant",
846 )
847
848 async def client_tool_call_to_input(
849 self, item: ClientToolCallItem
850 ) -> TResponseInputItem | list[TResponseInputItem] | None:
851 if item.status == "pending":
852 # Filter out pending tool calls - they cannot be sent to the model
853 return None
854
855 return [
856 ResponseFunctionToolCallParam(
857 type="function_call",
858 call_id=item.call_id,
859 name=item.name,
860 arguments=json.dumps(item.arguments),
861 ),
862 FunctionCallOutput(
863 type="function_call_output",
864 call_id=item.call_id,
865 output=json.dumps(item.output),
866 ),
867 ]
868
869 async def end_of_turn_to_input(
870 self, item: EndOfTurnItem
871 ) -> TResponseInputItem | list[TResponseInputItem] | None:
872 # Only used for UI hints - you shouldn't need to override this
873 return None
874
875 async def _thread_item_to_input_item(
876 self,
877 item: ThreadItem,
878 is_last_message: bool = True,
879 ) -> list[TResponseInputItem]:
880 match item:
881 case UserMessageItem():
882 out = await self.user_message_to_input(item, is_last_message) or []
883 return out if isinstance(out, list) else [out]
884 case AssistantMessageItem():
885 out = await self.assistant_message_to_input(item) or []
886 return out if isinstance(out, list) else [out]
887 case ClientToolCallItem():
888 out = await self.client_tool_call_to_input(item) or []
889 return out if isinstance(out, list) else [out]
890 case EndOfTurnItem():
891 out = await self.end_of_turn_to_input(item) or []
892 return out if isinstance(out, list) else [out]
893 case WidgetItem():
894 out = await self.widget_to_input(item) or []
895 return out if isinstance(out, list) else [out]
896 case WorkflowItem():
897 out = await self.workflow_to_input(item) or []
898 return out if isinstance(out, list) else [out]
899 case TaskItem():
900 out = await self.task_to_input(item) or []
901 return out if isinstance(out, list) else [out]
902 case HiddenContextItem():
903 out = await self.hidden_context_to_input(item) or []
904 return out if isinstance(out, list) else [out]
905 case _:
906 assert_never(item)
907
908 async def to_agent_input(
909 self,
910 thread_items: Sequence[ThreadItem] | ThreadItem,
911 ) -> list[TResponseInputItem]:
912 if isinstance(thread_items, Sequence):
913 # shallow copy in case caller mutates the list while we're iterating
914 thread_items = thread_items[:]
915 else:
916 thread_items = [thread_items]
917 output: list[TResponseInputItem] = []
918 for item in thread_items:
919 output.extend(
920 await self._thread_item_to_input_item(
921 item,
922 is_last_message=item is thread_items[-1],
923 )
924 )
925 return output
926
927
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """Convert thread items with a shared default ``ThreadItemConverter``.

    This is a plain function that returns the awaitable produced by
    ``ThreadItemConverter.to_agent_input``. Await the result to obtain the list
    of Agent SDK input items.
    """
    converter = _DEFAULT_CONVERTER
    return converter.to_agent_input(thread_items)
933