openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python (Available)

CodeCommitsIssuesPull requestsActionsInsightsSecurity
c9f5f09d50d3fa3029f1576bc8750f0f01286b16

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

chatkit/agents.py

933 lines · mode: code

1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from inspect import cleandoc
6from typing import (
7 Annotated,
8 Any,
9 AsyncGenerator,
10 Awaitable,
11 Generic,
12 Sequence,
13 TypeVar,
14 assert_never,
15 cast,
16)
17
18from agents import (
19 InputGuardrailTripwireTriggered,
20 OutputGuardrailTripwireTriggered,
21 RunResultStreaming,
22 StreamEvent,
23 TResponseInputItem,
24)
25from openai.types.responses import (
26 EasyInputMessageParam,
27 ResponseFunctionToolCallParam,
28 ResponseInputContentParam,
29 ResponseInputMessageContentListParam,
30 ResponseInputTextParam,
31 ResponseOutputText,
32)
33from openai.types.responses.response_input_item_param import (
34 FunctionCallOutput,
35 Message,
36)
37from openai.types.responses.response_output_message import Content
38from openai.types.responses.response_output_text import (
39 Annotation as ResponsesAnnotation,
40)
41from pydantic import BaseModel, ConfigDict, SkipValidation, TypeAdapter
42
43from .server import stream_widget
44from .store import Store, StoreItemType
45from .types import (
46 Annotation,
47 AssistantMessageContent,
48 AssistantMessageContentPartAdded,
49 AssistantMessageContentPartDone,
50 AssistantMessageContentPartTextDelta,
51 AssistantMessageItem,
52 Attachment,
53 ClientToolCallItem,
54 DurationSummary,
55 EndOfTurnItem,
56 FileSource,
57 HiddenContextItem,
58 Task,
59 TaskItem,
60 ThoughtTask,
61 ThreadItem,
62 ThreadItemAddedEvent,
63 ThreadItemDoneEvent,
64 ThreadItemRemovedEvent,
65 ThreadItemUpdated,
66 ThreadMetadata,
67 ThreadStreamEvent,
68 URLSource,
69 UserMessageItem,
70 UserMessageTagContent,
71 UserMessageTextContent,
72 WidgetItem,
73 Workflow,
74 WorkflowItem,
75 WorkflowSummary,
76 WorkflowTaskAdded,
77 WorkflowTaskUpdated,
78)
79from .widgets import Markdown, Text, WidgetRoot
80
81
class ClientToolCall(BaseModel):
    """
    Value returned from a tool method to request a client-side tool call.

    Assign it to ``AgentContext.client_tool_call``; ``stream_agent_response``
    then emits a ``ClientToolCallItem`` built from ``name`` and ``arguments``.
    """

    # Name of the tool the client should run.
    name: str
    # Keyword arguments for that tool, keyed by parameter name.
    arguments: dict[str, Any]
89
90
class _QueueCompleteSentinel:
    """
    Marker put on an ``AgentContext`` event queue to signal that no further
    events will be enqueued; the queue iterator stops when it dequeues one.
    """


# Type of the application-specific request context carried by AgentContext.
TContext = TypeVar("TContext")
95
96
class AgentContext(BaseModel, Generic[TContext]):
    """
    Context for a single agent run: the thread, the store and the request context.

    The helper methods (widgets, workflows, raw events) enqueue
    ThreadStreamEvents on a private queue; ``stream_agent_response`` merges
    that queue into the events it yields for the run.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    thread: ThreadMetadata
    store: Annotated[Store[TContext], SkipValidation]
    request_context: TContext
    previous_response_id: str | None = None
    # When set, stream_agent_response emits a ClientToolCallItem at the end of the run.
    client_tool_call: ClientToolCall | None = None
    # The workflow currently being built, if any.
    workflow_item: WorkflowItem | None = None
    # Events produced by the helpers below and consumed by stream_agent_response.
    # NOTE(review): this default relies on pydantic copying private-attribute
    # defaults per instance; PrivateAttr(default_factory=asyncio.Queue) would
    # make that explicit.
    _events: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel] = asyncio.Queue()

    def generate_id(
        self, type: StoreItemType, thread: ThreadMetadata | None = None
    ) -> str:
        """
        Generate a store id for a new item of ``type``.

        ``"thread"`` produces a new thread id; any other type produces an item
        id scoped to ``thread`` (defaulting to this context's thread).
        """
        if type == "thread":
            return self.store.generate_thread_id(self.request_context)
        return self.store.generate_item_id(
            type, thread or self.thread, self.request_context
        )

    async def stream_widget(
        self,
        widget: WidgetRoot | AsyncGenerator[WidgetRoot, None],
        copy_text: str | None = None,
    ) -> None:
        """
        Enqueue the thread events that display ``widget``.

        ``widget`` may be a single widget or an async generator of widgets;
        ``copy_text`` is forwarded to the module-level ``stream_widget``.
        """
        async for event in stream_widget(
            self.thread,
            widget,
            copy_text,
            lambda item_type: self.store.generate_item_id(
                item_type, self.thread, self.request_context
            ),
        ):
            await self._events.put(event)

    async def end_workflow(
        self, summary: WorkflowSummary | None = None, expanded: bool = False
    ) -> None:
        """
        Finish the active workflow and enqueue its ThreadItemDoneEvent.

        ``summary`` replaces the workflow summary. When it is omitted and the
        workflow has no summary yet, a DurationSummary with the whole seconds
        elapsed since the workflow was created is used. ``expanded`` sets
        whether the finished workflow is shown expanded. Does nothing when no
        workflow is active.
        """
        if not self.workflow_item:
            # No workflow to end
            return

        if summary is not None:
            self.workflow_item.workflow.summary = summary
        elif self.workflow_item.workflow.summary is None:
            # If no summary was set or provided, set a basic work summary
            delta = datetime.now() - self.workflow_item.created_at
            duration = int(delta.total_seconds())
            self.workflow_item.workflow.summary = DurationSummary(duration=duration)
        self.workflow_item.workflow.expanded = expanded
        await self.stream(ThreadItemDoneEvent(item=self.workflow_item))
        self.workflow_item = None

    async def start_workflow(self, workflow: Workflow) -> None:
        """
        Make ``workflow`` the active workflow.

        The ThreadItemAddedEvent is sent immediately for reasoning workflows
        and for workflows that already have tasks; otherwise it is deferred
        until ``add_workflow_task`` adds the first task.

        NOTE(review): an already-active workflow is replaced without a done
        event; call ``end_workflow`` first if it should be closed.
        """
        self.workflow_item = WorkflowItem(
            id=self.generate_id("workflow"),
            created_at=datetime.now(),
            workflow=workflow,
            thread_id=self.thread.id,
        )

        if workflow.type != "reasoning" and len(workflow.tasks) == 0:
            # Defer sending added event until we have tasks
            return

        await self.stream(ThreadItemAddedEvent(item=self.workflow_item))

    async def update_workflow_task(self, task: Task, task_index: int) -> None:
        """
        Replace the task at ``task_index`` in the active workflow and enqueue
        a WorkflowTaskUpdated event.

        Raises:
            ValueError: if no workflow is active.
        """
        if self.workflow_item is None:
            raise ValueError("Workflow is not set")
        # ensure reference is updated in case task is a copy
        self.workflow_item.workflow.tasks[task_index] = task
        await self.stream(
            ThreadItemUpdated(
                item_id=self.workflow_item.id,
                update=WorkflowTaskUpdated(
                    task=task,
                    task_index=task_index,
                ),
            )
        )

    async def add_workflow_task(self, task: Task) -> None:
        """
        Append ``task`` to the active workflow, creating a "custom" workflow if
        none is active.

        For a non-reasoning workflow, the first task sends the (deferred)
        ThreadItemAddedEvent for the whole workflow. Every other addition sends
        a WorkflowTaskAdded update.
        """
        if self.workflow_item is None:
            self.workflow_item = WorkflowItem(
                id=self.generate_id("workflow"),
                created_at=datetime.now(),
                workflow=Workflow(type="custom", tasks=[]),
                thread_id=self.thread.id,
            )
        workflow = self.workflow_item.workflow
        workflow.tasks.append(task)
        # The task was just appended, so it is last. list.index() matches by
        # ``==`` and would report an earlier task with identical fields.
        task_index = len(workflow.tasks) - 1

        if workflow.type != "reasoning" and len(workflow.tasks) == 1:
            await self.stream(ThreadItemAddedEvent(item=self.workflow_item))
        else:
            await self.stream(
                ThreadItemUpdated(
                    item_id=self.workflow_item.id,
                    update=WorkflowTaskAdded(
                        task=task,
                        task_index=task_index,
                    ),
                )
            )

    async def stream(self, event: ThreadStreamEvent) -> None:
        """Enqueue a raw thread event to be yielded by stream_agent_response."""
        await self._events.put(event)

    def _complete(self):
        """Enqueue the completion sentinel so iteration over the queue stops."""
        self._events.put_nowait(_QueueCompleteSentinel())
207
208
def _convert_content(content: Content) -> AssistantMessageContent:
    """
    Convert a Responses output content part into ChatKit assistant content.

    ``output_text`` parts keep their text and converted annotations; any other
    part is treated as a refusal, whose ``refusal`` text is used with no
    annotations.
    """
    if content.type != "output_text":
        return AssistantMessageContent(
            text=content.refusal,
            annotations=[],
        )

    converted: list[Annotation] = [
        chatkit_annotation
        for raw_annotation in content.annotations
        for chatkit_annotation in _convert_annotation(raw_annotation)
    ]
    return AssistantMessageContent(
        text=content.text,
        annotations=converted,
    )
223
224
def _convert_annotation(
    annotation: ResponsesAnnotation,
) -> list[Annotation]:
    """
    Map a Responses API annotation to a list of ChatKit annotations.

    ``url_citation`` maps to one URL-sourced annotation positioned at its
    ``end_index``. ``file_citation`` maps to one file-sourced annotation
    positioned at its ``index``, but only when it carries a filename. Every
    other case yields an empty list.
    """
    # Workaround: the client sometimes parses the annotation delta event into
    # the wrong class, leaving a plain dict here. Re-validate it into a typed
    # annotation.
    if isinstance(annotation, dict):
        annotation = TypeAdapter(ResponsesAnnotation).validate_python(annotation)

    if annotation.type == "url_citation":
        return [
            Annotation(
                source=URLSource(url=annotation.url, title=annotation.title),
                index=annotation.end_index,
            )
        ]

    if annotation.type == "file_citation" and annotation.filename:
        cited_file = annotation.filename
        return [
            Annotation(
                source=FileSource(filename=cited_file, title=cited_file),
                index=annotation.index,
            )
        ]

    return []
256
257
258T1 = TypeVar("T1")
259T2 = TypeVar("T2")
260
261
262async def _merge_generators(
263 a: AsyncIterator[T1],
264 b: AsyncIterator[T2],
265) -> AsyncIterator[T1 | T2]:
266 pending: list[AsyncIterator[T1 | T2]] = [a, b]
267 pending_tasks: dict[asyncio.Task, AsyncIterator[T1 | T2]] = {
268 asyncio.ensure_future(g.__anext__()): g for g in pending
269 }
270 while len(pending_tasks) > 0:
271 done, _ = await asyncio.wait(
272 pending_tasks.keys(), return_when="FIRST_COMPLETED"
273 )
274 stop = False
275 for d in done:
276 try:
277 result = d.result()
278 yield result
279 dg = pending_tasks[d]
280 pending_tasks[asyncio.ensure_future(dg.__anext__())] = dg
281 except StopAsyncIteration:
282 stop = True
283 finally:
284 del pending_tasks[d]
285 if stop:
286 for task in pending_tasks.keys():
287 if not task.cancel():
288 try:
289 yield task.result()
290 except asyncio.CancelledError:
291 pass
292 except asyncio.InvalidStateError:
293 pass
294 break
295
296
class _EventWrapper:
    """
    Tags a ThreadStreamEvent that came from the AgentContext event queue, so
    it can be told apart from Agents SDK stream events after the two sources
    are merged.
    """

    def __init__(self, event: ThreadStreamEvent):
        # The queued event, passed through to the caller unchanged.
        self.event = event
300
301
class _AsyncQueueIterator(AsyncIterator[_EventWrapper]):
    """
    Async iterator over an AgentContext event queue.

    Each dequeued event is yielded wrapped in an ``_EventWrapper``. Iteration
    ends when a ``_QueueCompleteSentinel`` is dequeued or after
    ``drain_and_complete``, and stays ended from then on.
    """

    def __init__(
        self, queue: asyncio.Queue[ThreadStreamEvent | _QueueCompleteSentinel]
    ):
        self.queue = queue
        # Flipped once the sentinel is seen or the queue has been drained.
        self.completed = False

    def __aiter__(self):
        return self

    async def __anext__(self):
        if not self.completed:
            next_item = await self.queue.get()
            if not isinstance(next_item, _QueueCompleteSentinel):
                return _EventWrapper(next_item)
            self.completed = True
        raise StopAsyncIteration

    def drain_and_complete(self) -> None:
        """
        Discard everything currently queued without awaiting, then mark this
        iterator completed.

        Meant for cleanup paths that must not await. Any queued completion
        sentinel is discarded along with the events.
        """
        pending = self.queue
        while not pending.empty():
            pending.get_nowait()
        self.completed = True
335
336
class StreamingThoughtTracker(BaseModel):
    """
    Tracks the reasoning-summary thought that is currently being streamed
    into a workflow task as deltas arrive.
    """

    # Responses API id of the item whose summary text is streaming.
    item_id: str
    # ``summary_index`` of the summary part being streamed.
    index: int
    # Task object appended to the workflow and extended with each delta.
    task: ThoughtTask
341
342
def _streaming_thought_index(
    tracker: StreamingThoughtTracker | None,
    workflow_item: WorkflowItem | None,
    item_id: str,
    summary_index: int,
) -> int | None:
    """
    Locate the thought that is currently being streamed within ``workflow_item``.

    Returns the position of ``tracker.task`` among the workflow's tasks when
    ``(item_id, summary_index)`` identifies the tracked summary part, and
    None otherwise. The task is matched by identity: ``in``/``list.index``
    compare with ``==`` and could select a different task with equal fields.
    """
    if (
        tracker is None
        or workflow_item is None
        or tracker.item_id != item_id
        or tracker.index != summary_index
    ):
        return None
    for position, task in enumerate(workflow_item.workflow.tasks):
        if task is tracker.task:
            return position
    return None


async def stream_agent_response(
    context: AgentContext, result: RunResultStreaming
) -> AsyncIterator[ThreadStreamEvent]:
    """
    Convert an Agents SDK streamed run into ChatKit thread stream events.

    Events from ``result.stream_events()`` are merged with events enqueued
    through ``context``'s helper methods. Assistant messages are streamed as
    added/updated/done events. Reasoning summaries become thought tasks in a
    "reasoning" workflow.

    A workflow left at the end of the previous turn (optionally followed by a
    client tool call) is continued. A workflow still active when the run
    finishes is saved to the store so the next turn can continue it. If
    ``context.client_tool_call`` is set, a ClientToolCallItem done event is
    emitted last.

    Raises:
        InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered:
            re-raised after a ThreadItemRemovedEvent has been yielded for every
            item produced during this run.
    """
    current_item_id = None
    current_tool_call = None
    ctx = context
    thread = context.thread
    queue_iterator = _AsyncQueueIterator(context._events)
    produced_items = set()
    streaming_thought: None | StreamingThoughtTracker = None

    # check if the last item in the thread was a workflow or a client tool call
    # if it was a client tool call, check if the second last item was a workflow
    # if either was, continue the workflow
    items = await context.store.load_thread_items(
        thread.id, None, 2, "desc", context.request_context
    )
    last_item = items.data[0] if len(items.data) > 0 else None
    second_last_item = items.data[1] if len(items.data) > 1 else None

    if last_item and last_item.type == "workflow":
        ctx.workflow_item = last_item
    elif (
        last_item
        and last_item.type == "client_tool_call"
        and second_last_item
        and second_last_item.type == "workflow"
    ):
        ctx.workflow_item = second_last_item

    def end_workflow(item: WorkflowItem):
        """Close ``item``: clear it as the active workflow, default its summary
        to the elapsed duration, collapse it, and return its done event."""
        if item == ctx.workflow_item:
            ctx.workflow_item = None
        delta = datetime.now() - item.created_at
        duration = int(delta.total_seconds())
        if item.workflow.summary is None:
            item.workflow.summary = DurationSummary(duration=duration)
        # Default to closing all workflows
        # To keep a workflow open on completion, close it explicitly with
        # AgentContext.end_workflow(expanded=True)
        item.workflow.expanded = False
        return ThreadItemDoneEvent(item=item)

    try:
        async for event in _merge_generators(result.stream_events(), queue_iterator):
            # Events emitted from agent context helpers
            if isinstance(event, _EventWrapper):
                event = event.event
                if (
                    event.type == "thread.item.added"
                    or event.type == "thread.item.done"
                ):
                    # End the current workflow if visual item is added after it
                    if (
                        ctx.workflow_item
                        and ctx.workflow_item.id != event.item.id
                        and event.item.type != "client_tool_call"
                        and event.item.type != "hidden_context_item"
                    ):
                        yield end_workflow(ctx.workflow_item)

                    # track the current workflow if one is added
                    if (
                        event.type == "thread.item.added"
                        and event.item.type == "workflow"
                    ):
                        ctx.workflow_item = event.item

                    # track integration produced items so we can clean them up if
                    # there is a guardrail tripwire
                    produced_items.add(event.item.id)
                yield event
                continue

            if event.type == "run_item_stream_event":
                event = event.item
                if (
                    event.type == "tool_call_item"
                    and event.raw_item.type == "function_call"
                ):
                    current_tool_call = event.raw_item.call_id
                    current_item_id = event.raw_item.id
                    assert current_item_id
                    produced_items.add(current_item_id)
                continue

            if event.type != "raw_response_event":
                # Ignore everything else that isn't a raw response event
                continue

            # Handle Responses events
            event = event.data
            if event.type == "response.content_part.added":
                if event.part.type == "reasoning_text":
                    continue
                content = _convert_content(event.part)
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartAdded(
                        content_index=event.content_index,
                        content=content,
                    ),
                )
            elif event.type == "response.output_text.delta":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartTextDelta(
                        content_index=event.content_index,
                        delta=event.delta,
                    ),
                )
            elif event.type == "response.output_text.done":
                yield ThreadItemUpdated(
                    item_id=event.item_id,
                    update=AssistantMessageContentPartDone(
                        content_index=event.content_index,
                        content=AssistantMessageContent(
                            text=event.text,
                            annotations=[],
                        ),
                    ),
                )
            elif event.type == "response.output_text.annotation.added":
                # Ignore annotation-added events; annotations are reflected in the final item content.
                continue
            elif event.type == "response.output_item.added":
                item = event.item
                if item.type == "reasoning" and not ctx.workflow_item:
                    ctx.workflow_item = WorkflowItem(
                        id=ctx.generate_id("workflow"),
                        created_at=datetime.now(),
                        workflow=Workflow(type="reasoning", tasks=[]),
                        thread_id=thread.id,
                    )
                    produced_items.add(ctx.workflow_item.id)
                    yield ThreadItemAddedEvent(item=ctx.workflow_item)
                if item.type == "message":
                    if ctx.workflow_item:
                        yield end_workflow(ctx.workflow_item)
                    produced_items.add(item.id)
                    yield ThreadItemAddedEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )
            elif event.type == "response.reasoning_summary_text.delta":
                if not ctx.workflow_item:
                    continue

                # stream the first thought in a new workflow so that we can show it earlier
                if (
                    ctx.workflow_item.workflow.type == "reasoning"
                    and len(ctx.workflow_item.workflow.tasks) == 0
                ):
                    streaming_thought = StreamingThoughtTracker(
                        item_id=event.item_id,
                        index=event.summary_index,
                        task=ThoughtTask(content=event.delta),
                    )
                    ctx.workflow_item.workflow.tasks.append(streaming_thought.task)
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=WorkflowTaskAdded(
                            task=streaming_thought.task,
                            task_index=0,
                        ),
                    )
                else:
                    thought_index = _streaming_thought_index(
                        streaming_thought,
                        ctx.workflow_item,
                        event.item_id,
                        event.summary_index,
                    )
                    if streaming_thought is not None and thought_index is not None:
                        streaming_thought.task.content += event.delta
                        yield ThreadItemUpdated(
                            item_id=ctx.workflow_item.id,
                            update=WorkflowTaskUpdated(
                                task=streaming_thought.task,
                                task_index=thought_index,
                            ),
                        )
            elif event.type == "response.reasoning_summary_text.done":
                if ctx.workflow_item:
                    tasks = ctx.workflow_item.workflow.tasks
                    thought_index = _streaming_thought_index(
                        streaming_thought,
                        ctx.workflow_item,
                        event.item_id,
                        event.summary_index,
                    )
                    if streaming_thought is not None and thought_index is not None:
                        task = streaming_thought.task
                        task.content = event.text
                        streaming_thought = None
                        update = WorkflowTaskUpdated(
                            task=task,
                            task_index=thought_index,
                        )
                    else:
                        task = ThoughtTask(content=event.text)
                        tasks.append(task)
                        # The new thought is last; list.index() matches by ``==``
                        # and could point at an earlier identical thought.
                        update = WorkflowTaskAdded(
                            task=task,
                            task_index=len(tasks) - 1,
                        )
                    yield ThreadItemUpdated(
                        item_id=ctx.workflow_item.id,
                        update=update,
                    )
            elif event.type == "response.output_item.done":
                item = event.item
                if item.type == "message":
                    produced_items.add(item.id)
                    yield ThreadItemDoneEvent(
                        item=AssistantMessageItem(
                            # Reusing the Responses message ID
                            id=item.id,
                            thread_id=thread.id,
                            content=[_convert_content(c) for c in item.content],
                            created_at=datetime.now(),
                        ),
                    )

    except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
        for item_id in produced_items:
            yield ThreadItemRemovedEvent(item_id=item_id)

        # Drain remaining events without processing them
        context._complete()
        queue_iterator.drain_and_complete()

        raise

    context._complete()

    # Drain remaining events
    async for event in queue_iterator:
        yield event.event

    # If there is still an active workflow at the end of the run, store
    # its current state so that we can continue it in the next turn.
    if ctx.workflow_item:
        await ctx.store.add_thread_item(
            thread.id, ctx.workflow_item, ctx.request_context
        )

    if context.client_tool_call:
        yield ThreadItemDoneEvent(
            item=ClientToolCallItem(
                id=current_item_id
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
                thread_id=thread.id,
                name=context.client_tool_call.name,
                arguments=context.client_tool_call.arguments,
                created_at=datetime.now(),
                call_id=current_tool_call
                or context.store.generate_item_id(
                    "tool_call", thread, context.request_context
                ),
            ),
        )
610
611
TWidget = TypeVar("TWidget", bound=Markdown | Text)


async def accumulate_text(
    events: AsyncIterator[StreamEvent],
    base_widget: TWidget,
) -> AsyncIterator[TWidget]:
    """
    Turn streamed output-text deltas into successive copies of a text widget.

    Yields ``base_widget`` unchanged first. After each
    ``response.output_text.delta`` raw response event, it yields a copy whose
    ``value`` is all text received so far. Once ``events`` is exhausted, it
    yields a final copy with the full text and ``streaming=False``.
    """
    accumulated = ""
    yield base_widget
    async for event in events:
        is_text_delta = (
            event.type == "raw_response_event"
            and event.data.type == "response.output_text.delta"
        )
        if not is_text_delta:
            continue
        accumulated += event.data.delta
        yield base_widget.model_copy(update={"value": accumulated})
    yield base_widget.model_copy(update={"value": accumulated, "streaming": False})
627
628
629class ThreadItemConverter:
630 """
631 Converts thread items to Agent SDK input items.
632 Widgets, Tasks, and Workflows have default conversions but can be customized.
633 Attachments, Tags, and HiddenContextItems require custom handling based on the use case.
634 Other item types are converted automatically.
635 """
636
637 def attachment_to_message_content(
638 self, attachment: Attachment
639 ) -> Awaitable[ResponseInputContentParam]:
640 """
641 Convert an attachment in a user message into a message content part to send to the model.
642 Required when attachments are enabled.
643 """
644 raise NotImplementedError(
645 "An Attachment was included in a UserMessageItem but Converter.attachment_to_message_content was not implemented"
646 )
647
648 def tag_to_message_content(
649 self, tag: UserMessageTagContent
650 ) -> ResponseInputContentParam:
651 """
652 Convert a tag in a user message into a message content part to send to the model as context.
653 Required when tags are used.
654 """
655 raise NotImplementedError(
656 "A Tag was included in a UserMessageItem but Converter.tag_to_message_content is not implemented"
657 )
658
659 def hidden_context_to_input(
660 self, item: HiddenContextItem
661 ) -> TResponseInputItem | list[TResponseInputItem] | None:
662 """
663 Convert a HiddenContextItem into input item(s) to send to the model.
664 Required when HiddenContextItem are used.
665 """
666 raise NotImplementedError(
667 "HiddenContextItem were present in a user message but Converter.hidden_context_to_input was not implemented"
668 )
669
670 def task_to_input(
671 self, item: TaskItem
672 ) -> TResponseInputItem | list[TResponseInputItem] | None:
673 """
674 Convert a TaskItem into input item(s) to send to the model.
675 """
676 if item.task.type != "custom" or (
677 not item.task.title and not item.task.content
678 ):
679 return None
680 title = f"{item.task.title}" if item.task.title else ""
681 content = f"{item.task.content}" if item.task.content else ""
682 task_text = f"{title}: {content}" if title and content else title or content
683 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
684 return Message(
685 type="message",
686 content=[
687 ResponseInputTextParam(
688 type="input_text",
689 text=text,
690 )
691 ],
692 role="user",
693 )
694
695 def workflow_to_input(
696 self, item: WorkflowItem
697 ) -> TResponseInputItem | list[TResponseInputItem] | None:
698 """
699 Convert a TaskItem into input item(s) to send to the model.
700 Returns WorkflowItem.response_items by default.
701 """
702 messages = []
703 for task in item.workflow.tasks:
704 if task.type != "custom" or (not task.title and not task.content):
705 continue
706
707 title = f"{task.title}" if task.title else ""
708 content = f"{task.content}" if task.content else ""
709 task_text = f"{title}: {content}" if title and content else title or content
710 text = f"A message was displayed to the user that the following task was performed:\n<Task>\n{task_text}\n</Task>"
711 messages.append(
712 Message(
713 type="message",
714 content=[
715 ResponseInputTextParam(
716 type="input_text",
717 text=text,
718 )
719 ],
720 role="user",
721 )
722 )
723 return messages
724
725 def widget_to_input(
726 self, item: WidgetItem
727 ) -> TResponseInputItem | list[TResponseInputItem] | None:
728 """
729 Convert a WidgetItem into input item(s) to send to the model.
730 By default, WidgetItems converted to a text description with a JSON representation of the widget.
731 """
732 return Message(
733 type="message",
734 content=[
735 ResponseInputTextParam(
736 type="input_text",
737 text=f"The following graphical UI widget (id: {item.id}) was displayed to the user:"
738 + item.widget.model_dump_json(
739 exclude_unset=True, exclude_none=True
740 ),
741 )
742 ],
743 role="user",
744 )
745
746 async def user_message_to_input(
747 self, item: UserMessageItem, is_last_message: bool = True
748 ) -> TResponseInputItem | list[TResponseInputItem] | None:
749 # Build the user text exactly as typed, rendering tags as @key
750 message_text_parts: list[str] = []
751 # Track tags separately to add system context
752 raw_tags: list[UserMessageTagContent] = []
753
754 for part in item.content:
755 if isinstance(part, UserMessageTextContent):
756 message_text_parts.append(part.text)
757 elif isinstance(part, UserMessageTagContent):
758 message_text_parts.append(f"@{part.text}")
759 raw_tags.append(part)
760 else:
761 assert_never(part)
762
763 user_text_item = Message(
764 role="user",
765 type="message",
766 content=[
767 ResponseInputTextParam(
768 type="input_text", text="".join(message_text_parts)
769 ),
770 *[
771 await self.attachment_to_message_content(a)
772 for a in item.attachments
773 ],
774 ],
775 )
776
777 # Build system items (prepend later): quoted text and @-mention context
778 context_items: list[TResponseInputItem] = []
779
780 if item.quoted_text and is_last_message:
781 context_items.append(
782 Message(
783 role="user",
784 type="message",
785 content=[
786 ResponseInputTextParam(
787 type="input_text",
788 text=f"The user is referring to this in particular: \n{item.quoted_text}",
789 )
790 ],
791 )
792 )
793
794 # Dedupe tags (preserve order) and resolve to message content
795 if raw_tags:
796 seen, uniq_tags = set(), []
797 for t in raw_tags:
798 if t.text not in seen:
799 seen.add(t.text)
800 uniq_tags.append(t)
801
802 tag_content: ResponseInputMessageContentListParam = [
803 # should return summarized text items
804 self.tag_to_message_content(tag)
805 for tag in uniq_tags
806 ]
807
808 if tag_content:
809 context_items.append(
810 Message(
811 role="user",
812 type="message",
813 content=[
814 ResponseInputTextParam(
815 type="input_text",
816 text=cleandoc("""
817 # User-provided context for @-mentions
818 - When referencing resolved entities, use their canonical names **without** '@'.
819 - The '@' form appears only in user text and should not be echoed.
820 """).strip(),
821 ),
822 *tag_content,
823 ],
824 )
825 )
826
827 return [user_text_item, *context_items]
828
829 async def assistant_message_to_input(
830 self, item: AssistantMessageItem
831 ) -> TResponseInputItem | list[TResponseInputItem] | None:
832 return EasyInputMessageParam(
833 type="message",
834 content=[
835 # content param doesn't support the assistant message content types
836 cast(
837 ResponseInputContentParam,
838 ResponseOutputText(
839 type="output_text",
840 text=c.text,
841 annotations=[], # TODO: these should be sent back as well
842 ).model_dump(),
843 )
844 for c in item.content
845 ],
846 role="assistant",
847 )
848
849 async def client_tool_call_to_input(
850 self, item: ClientToolCallItem
851 ) -> TResponseInputItem | list[TResponseInputItem] | None:
852 if item.status == "pending":
853 # Filter out pending tool calls - they cannot be sent to the model
854 return None
855
856 return [
857 ResponseFunctionToolCallParam(
858 type="function_call",
859 call_id=item.call_id,
860 name=item.name,
861 arguments=json.dumps(item.arguments),
862 ),
863 FunctionCallOutput(
864 type="function_call_output",
865 call_id=item.call_id,
866 output=json.dumps(item.output),
867 ),
868 ]
869
870 async def end_of_turn_to_input(
871 self, item: EndOfTurnItem
872 ) -> TResponseInputItem | list[TResponseInputItem] | None:
873 # Only used for UI hints - you shouldn't need to override this
874 return None
875
876 async def _thread_item_to_input_item(
877 self,
878 item: ThreadItem,
879 is_last_message: bool = True,
880 ) -> list[TResponseInputItem]:
881 match item:
882 case UserMessageItem():
883 out = await self.user_message_to_input(item, is_last_message) or []
884 return out if isinstance(out, list) else [out]
885 case AssistantMessageItem():
886 out = await self.assistant_message_to_input(item) or []
887 return out if isinstance(out, list) else [out]
888 case ClientToolCallItem():
889 out = await self.client_tool_call_to_input(item) or []
890 return out if isinstance(out, list) else [out]
891 case EndOfTurnItem():
892 out = await self.end_of_turn_to_input(item) or []
893 return out if isinstance(out, list) else [out]
894 case WidgetItem():
895 out = self.widget_to_input(item) or []
896 return out if isinstance(out, list) else [out]
897 case WorkflowItem():
898 out = self.workflow_to_input(item) or []
899 return out if isinstance(out, list) else [out]
900 case TaskItem():
901 out = self.task_to_input(item) or []
902 return out if isinstance(out, list) else [out]
903 case HiddenContextItem():
904 out = self.hidden_context_to_input(item) or []
905 return out if isinstance(out, list) else [out]
906 case _:
907 assert_never(item)
908
909 async def to_agent_input(
910 self,
911 thread_items: Sequence[ThreadItem] | ThreadItem,
912 ) -> list[TResponseInputItem]:
913 if isinstance(thread_items, Sequence):
914 # shallow copy in case caller mutates the list while we're iterating
915 thread_items = thread_items[:]
916 else:
917 thread_items = [thread_items]
918 output: list[TResponseInputItem] = []
919 for item in thread_items:
920 output.extend(
921 await self._thread_item_to_input_item(
922 item,
923 is_last_message=item is thread_items[-1],
924 )
925 )
926 return output
927
928
_DEFAULT_CONVERTER = ThreadItemConverter()


def simple_to_agent_input(thread_items: Sequence[ThreadItem] | ThreadItem):
    """
    Convert thread items to Agents SDK input with a shared default
    ``ThreadItemConverter``.

    This is not itself a coroutine function. It returns the awaitable created
    by ``ThreadItemConverter.to_agent_input``, so callers must ``await`` the
    result.
    """
    converter = _DEFAULT_CONVERTER
    pending_conversion = converter.to_agent_input(thread_items)
    return pending_conversion
934