openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
f43533b0dad1fc1013dbc615dd4bc192d484b94e

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_agents.py

1492 lines · mode: code

1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from typing import cast
6from unittest.mock import AsyncMock, Mock
7
8import pytest
9from agents import (
10 Agent,
11 GuardrailFunctionOutput,
12 InputGuardrail,
13 InputGuardrailResult,
14 InputGuardrailTripwireTriggered,
15 OutputGuardrail,
16 OutputGuardrailResult,
17 OutputGuardrailTripwireTriggered,
18 RawResponsesStreamEvent,
19 RunContextWrapper,
20 RunItemStreamEvent,
21 Runner,
22 RunResultStreaming,
23 StreamEvent,
24 ToolCallItem,
25)
26from agents._run_impl import QueueCompleteSentinel
27from openai.types.responses import (
28 EasyInputMessageParam,
29 ResponseFileSearchToolCall,
30 ResponseInputContentParam,
31 ResponseInputTextParam,
32 ResponseOutputItemAddedEvent,
33 ResponseOutputItemDoneEvent,
34 ResponseOutputMessage,
35 ResponseReasoningItem,
36)
37from openai.types.responses.response_content_part_added_event import (
38 ResponseContentPartAddedEvent,
39)
40from openai.types.responses.response_file_search_tool_call import Result
41from openai.types.responses.response_input_item_param import Message
42from openai.types.responses.response_output_text import (
43 AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
44)
45from openai.types.responses.response_output_text import (
46 AnnotationFileCitation as ResponsesAnnotationFileCitation,
47)
48from openai.types.responses.response_output_text import (
49 AnnotationFilePath as ResponsesAnnotationFilePath,
50)
51from openai.types.responses.response_output_text import (
52 AnnotationURLCitation as ResponsesAnnotationURLCitation,
53)
54from openai.types.responses.response_output_text import (
55 ResponseOutputText,
56)
57from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
58from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
59
60from chatkit.agents import (
61 AgentContext,
62 ThreadItemConverter,
63 accumulate_text,
64 simple_to_agent_input,
65 stream_agent_response,
66)
67from chatkit.types import (
68 Annotation,
69 AssistantMessageContent,
70 AssistantMessageContentPartAdded,
71 AssistantMessageContentPartAnnotationAdded,
72 AssistantMessageContentPartDone,
73 AssistantMessageContentPartTextDelta,
74 AssistantMessageItem,
75 Attachment,
76 ClientToolCallItem,
77 CustomSummary,
78 CustomTask,
79 DurationSummary,
80 FileSource,
81 HiddenContextItem,
82 InferenceOptions,
83 Page,
84 TaskItem,
85 ThoughtTask,
86 Thread,
87 ThreadItemAddedEvent,
88 ThreadItemDoneEvent,
89 ThreadItemUpdatedEvent,
90 ThreadStreamEvent,
91 URLSource,
92 UserMessageItem,
93 UserMessageTagContent,
94 UserMessageTextContent,
95 WidgetItem,
96 Workflow,
97 WorkflowItem,
98 WorkflowTaskAdded,
99 WorkflowTaskUpdated,
100)
101from chatkit.widgets import Card, Text
102
103thread = Thread(id="123", title="Test", created_at=datetime.now(), items=Page())
104
105mock_store = Mock()
106mock_store.generate_item_id = lambda item_type, thread, context: f"{item_type}_id"
107mock_store.load_thread_items = AsyncMock(return_value=Page())
108mock_store.add_thread_item = AsyncMock()
109
110
111class RunResult(RunResultStreaming):
112 def add_event(self, event: StreamEvent):
113 self._event_queue.put_nowait(event)
114
115 def done(self):
116 self.is_complete = True
117 self._event_queue.put_nowait(QueueCompleteSentinel())
118
119 def throw_input_guardrails(self):
120 self._stored_exception = InputGuardrailTripwireTriggered(
121 InputGuardrailResult(
122 guardrail=Mock(spec=InputGuardrail),
123 output=GuardrailFunctionOutput(
124 output_info=None,
125 tripwire_triggered=True,
126 ),
127 )
128 )
129 self.is_complete = True
130 self._event_queue.put_nowait(QueueCompleteSentinel())
131
132 def throw_output_guardrails(self):
133 self._stored_exception = OutputGuardrailTripwireTriggered(
134 OutputGuardrailResult(
135 guardrail=Mock(spec=OutputGuardrail),
136 output=GuardrailFunctionOutput(
137 output_info=None,
138 tripwire_triggered=True,
139 ),
140 agent=Mock(spec=Agent),
141 agent_output=None,
142 )
143 )
144 self.is_complete = True
145 self._event_queue.put_nowait(QueueCompleteSentinel())
146
147
148def make_result() -> RunResult:
149 return RunResult(
150 context_wrapper=Mock(spec=RunContextWrapper),
151 input=[],
152 tool_input_guardrail_results=[],
153 tool_output_guardrail_results=[],
154 new_items=[],
155 raw_responses=[],
156 final_output=None,
157 current_agent=Agent(name="test"),
158 current_turn=0,
159 max_turns=10,
160 _current_agent_output_schema=None,
161 trace=None,
162 is_complete=False,
163 _event_queue=asyncio.Queue(),
164 _input_guardrail_queue=asyncio.Queue(),
165 _output_guardrails_task=None,
166 _run_impl_task=None,
167 _stored_exception=None,
168 output_guardrail_results=[],
169 input_guardrail_results=[],
170 )
171
172
173async def all_events(
174 events: AsyncIterator[ThreadStreamEvent],
175) -> list[ThreadStreamEvent]:
176 return [event async for event in events]
177
178
179async def test_returns_widget_item():
180 context = AgentContext(
181 previous_response_id=None, thread=thread, store=mock_store, request_context=None
182 )
183 result = make_result()
184 result.add_event(
185 RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
186 )
187 await context.stream_widget(Card(children=[Text(value="Hello, world!")]))
188 result.done()
189
190 events = await all_events(
191 stream_agent_response(
192 context=context,
193 result=result,
194 )
195 )
196
197 assert len(events) == 1
198 assert isinstance(events[0], ThreadItemDoneEvent)
199 assert isinstance(events[0].item, WidgetItem)
200 assert events[0].item.widget == Card(children=[Text(value="Hello, world!")])
201
202
203async def test_returns_widget_item_generator():
204 context = AgentContext(
205 previous_response_id=None, thread=thread, store=mock_store, request_context=None
206 )
207 result = make_result()
208 result.add_event(
209 RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
210 )
211
212 def render_widget(i: int) -> Card:
213 return Card(children=[Text(id="text", value="Hello, world"[:i])])
214
215 async def widget_generator():
216 yield render_widget(0)
217 yield render_widget(12)
218
219 await context.stream_widget(widget_generator())
220 result.done()
221
222 events = await all_events(
223 stream_agent_response(
224 context=context,
225 result=result,
226 )
227 )
228
229 assert len(events) == 3
230 assert isinstance(events[0], ThreadItemAddedEvent)
231 assert isinstance(events[0].item, WidgetItem)
232 assert events[0].item.widget == Card(children=[Text(id="text", value="")])
233
234 assert isinstance(events[1], ThreadItemUpdatedEvent)
235 assert events[1].update.type == "widget.streaming_text.value_delta"
236 assert events[1].update.component_id == "text"
237 assert events[1].update.delta == "Hello, world"
238
239 assert isinstance(events[2], ThreadItemDoneEvent)
240 assert isinstance(events[2].item, WidgetItem)
241 assert events[2].item.widget == Card(
242 children=[Text(id="text", value="Hello, world")]
243 )
244
245
246async def test_returns_widget_full_replace_generator():
247 context = AgentContext(
248 previous_response_id=None, thread=thread, store=mock_store, request_context=None
249 )
250 result = make_result()
251 result.add_event(
252 RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
253 )
254
255 async def widget_generator():
256 yield Card(children=[Text(id="text", value="Hello!")])
257 yield Card(children=[Text(key="other text", value="World!", streaming=False)])
258
259 await context.stream_widget(widget_generator())
260 result.done()
261
262 events = await all_events(
263 stream_agent_response(
264 context=context,
265 result=result,
266 )
267 )
268
269 assert len(events) == 3
270 assert isinstance(events[0], ThreadItemAddedEvent)
271 assert isinstance(events[0].item, WidgetItem)
272 assert events[0].item.widget == Card(children=[Text(id="text", value="Hello!")])
273
274 assert isinstance(events[1], ThreadItemUpdatedEvent)
275 assert events[1].update.type == "widget.root.updated"
276 assert events[1].update.widget == Card(
277 children=[Text(key="other text", value="World!", streaming=False)]
278 )
279
280 assert isinstance(events[2], ThreadItemDoneEvent)
281 assert isinstance(events[2].item, WidgetItem)
282 assert events[2].item.widget == Card(
283 children=[Text(key="other text", value="World!", streaming=False)]
284 )
285
286
287async def test_accumulate_text():
288 def delta(text: str) -> RawResponsesStreamEvent:
289 return RawResponsesStreamEvent(
290 type="raw_response_event",
291 data=ResponseTextDeltaEvent(
292 type="response.output_text.delta",
293 delta=text,
294 content_index=0,
295 item_id="123",
296 logprobs=[],
297 output_index=0,
298 sequence_number=0,
299 ),
300 )
301
302 result = Runner.run_streamed(
303 Agent("Assistant", instructions="You are a helpful assistant."), "Say hello!"
304 )
305 result = make_result()
306 result.add_event(delta("Hello, "))
307 result.add_event(delta("world!"))
308
309 result.done()
310
311 events = [
312 event
313 async for event in accumulate_text(
314 result.stream_events(), Text(key="text", value="", streaming=True)
315 )
316 ]
317 assert events == [
318 Text(key="text", value="", streaming=True),
319 Text(key="text", value="Hello, ", streaming=True),
320 Text(key="text", value="Hello, world!", streaming=True),
321 Text(key="text", value="Hello, world!", streaming=False),
322 ]
323
324
325async def test_input_item_converter_quotes_last_user_message():
326 items = [
327 UserMessageItem(
328 id="123",
329 content=[UserMessageTextContent(text="Hello!")],
330 attachments=[],
331 inference_options=InferenceOptions(),
332 thread_id=thread.id,
333 quoted_text="Hi!",
334 created_at=datetime.now(),
335 ),
336 UserMessageItem(
337 id="123",
338 content=[UserMessageTextContent(text="I'm well, thank you!")],
339 attachments=[],
340 inference_options=InferenceOptions(),
341 thread_id=thread.id,
342 quoted_text="How are you doing?",
343 created_at=datetime.now(),
344 ),
345 ]
346
347 async def throw_exception(
348 _: Attachment,
349 ) -> ResponseInputContentParam:
350 raise Exception("Not implemented")
351
352 input_items = await simple_to_agent_input(items)
353 assert len(input_items) == 3
354 assert input_items[0] == {
355 "content": [
356 {
357 "text": "Hello!",
358 "type": "input_text",
359 },
360 ],
361 "role": "user",
362 "type": "message",
363 }
364 assert input_items[1] == {
365 "content": [
366 {
367 "text": "I'm well, thank you!",
368 "type": "input_text",
369 },
370 ],
371 "role": "user",
372 "type": "message",
373 }
374 assert input_items[2] == {
375 "content": [
376 {
377 "text": "The user is referring to this in particular: \nHow are you doing?",
378 "type": "input_text",
379 },
380 ],
381 "role": "user",
382 "type": "message",
383 }
384
385
386async def test_input_item_converter_to_input_items_mixed():
387 items = [
388 UserMessageItem(
389 id="123",
390 content=[UserMessageTextContent(text="Hello!")],
391 attachments=[],
392 inference_options=InferenceOptions(),
393 thread_id=thread.id,
394 quoted_text="Hi!",
395 created_at=datetime.now(),
396 ),
397 UserMessageItem(
398 id="123",
399 content=[UserMessageTextContent(text="I'm well, thank you!")],
400 attachments=[],
401 inference_options=InferenceOptions(),
402 thread_id=thread.id,
403 quoted_text="How are you doing?",
404 created_at=datetime.now(),
405 ),
406 AssistantMessageItem(
407 id="123",
408 content=[
409 AssistantMessageContent(text="How are you doing?"),
410 AssistantMessageContent(text="Can't do that"),
411 ],
412 thread_id=thread.id,
413 created_at=datetime.now(),
414 ),
415 WidgetItem(
416 id="wd_123",
417 widget=Card(children=[Text(value="Hello, world!")]),
418 thread_id=thread.id,
419 created_at=datetime.now(),
420 ),
421 ]
422
423 input_items = await simple_to_agent_input(items)
424 assert len(input_items) == 4
425 assert input_items[0] == {
426 "content": [
427 {
428 "text": "Hello!",
429 "type": "input_text",
430 },
431 ],
432 "role": "user",
433 "type": "message",
434 }
435 assert input_items[1] == {
436 "content": [
437 {
438 "text": "I'm well, thank you!",
439 "type": "input_text",
440 },
441 ],
442 "role": "user",
443 "type": "message",
444 }
445 assert input_items[2] == {
446 "content": [
447 {
448 "annotations": [],
449 "text": "How are you doing?",
450 "logprobs": None,
451 "type": "output_text",
452 },
453 {
454 "annotations": [],
455 "text": "Can't do that",
456 "logprobs": None,
457 "type": "output_text",
458 },
459 ],
460 "type": "message",
461 "role": "assistant",
462 }
463 assert "type" in input_items[3]
464 widget_item = cast(EasyInputMessageParam, input_items[3])
465 assert widget_item.get("type") == "message"
466 assert widget_item.get("role") == "user"
467 text = widget_item.get("content")[0]["text"] # type: ignore
468 assert (
469 "The following graphical UI widget (id: wd_123) was displayed to the user"
470 in text
471 )
472 assert "Hello, world!" in text
473 assert "created_at" not in text
474
475
476async def test_input_item_converter_user_input_with_tags():
477 class MyThreadItemConverter(ThreadItemConverter):
478 async def tag_to_message_content(self, tag):
479 return ResponseInputTextParam(
480 type="input_text", text=tag.text + " " + tag.data["key"]
481 )
482
483 items = [
484 UserMessageItem(
485 id="123",
486 content=[
487 UserMessageTagContent(
488 text="Hello!", type="input_tag", id="hello", data={"key": "value"}
489 )
490 ],
491 attachments=[],
492 inference_options=InferenceOptions(),
493 thread_id=thread.id,
494 created_at=datetime.now(),
495 )
496 ]
497 items = await MyThreadItemConverter().to_agent_input(items)
498
499 assert len(items) == 2
500 assert items[0] == {
501 "content": [
502 {
503 "text": "@Hello!",
504 "type": "input_text",
505 },
506 ],
507 "role": "user",
508 "type": "message",
509 }
510 assert items[1] == {
511 "content": [
512 {
513 "text": "# User-provided context for @-mentions\n- When referencing resolved entities, use their canonical names **without** '@'.\n"
514 + "- The '@' form appears only in user text and should not be echoed.",
515 "type": "input_text",
516 },
517 {
518 "text": "Hello! value",
519 "type": "input_text",
520 },
521 ],
522 "role": "user",
523 "type": "message",
524 }
525
526
527async def test_input_item_converter_user_input_with_tags_throws_by_default():
528 items = [
529 UserMessageItem(
530 id="123",
531 content=[
532 UserMessageTagContent(
533 text="Hello!", type="input_tag", id="hello", data={}
534 )
535 ],
536 attachments=[],
537 inference_options=InferenceOptions(),
538 thread_id=thread.id,
539 created_at=datetime.now(),
540 )
541 ]
542
543 with pytest.raises(NotImplementedError):
544 await simple_to_agent_input(items)
545
546
547async def test_input_item_converter_for_hidden_context_with_string_content():
548 items = [
549 HiddenContextItem(
550 id="123",
551 content="User pressed the red button",
552 thread_id=thread.id,
553 created_at=datetime.now(),
554 )
555 ]
556
557 # The default converter works for string content
558 items = await simple_to_agent_input(items)
559 assert len(items) == 1
560 assert items[0] == {
561 "content": [
562 {
563 "text": "Hidden context for the agent (not shown to the user):\n<HiddenContext>\nUser pressed the red button\n</HiddenContext>",
564 "type": "input_text",
565 },
566 ],
567 "role": "user",
568 "type": "message",
569 }
570
571
572async def test_input_item_converter_for_hidden_context_with_non_string_content():
573 items = [
574 HiddenContextItem(
575 id="123",
576 content={"harry": "potter", "hermione": "granger"},
577 thread_id=thread.id,
578 created_at=datetime.now(),
579 )
580 ]
581
582 # Default converter should throw
583 with pytest.raises(NotImplementedError):
584 await simple_to_agent_input(items)
585
586 class MyThreadItemConverter(ThreadItemConverter):
587 async def hidden_context_to_input(self, item: HiddenContextItem):
588 content = ",".join(item.content.keys())
589 return Message(
590 type="message",
591 content=[
592 ResponseInputTextParam(
593 type="input_text", text=f"<HIDDEN>{content}</HIDDEN>"
594 )
595 ],
596 role="user",
597 )
598
599 items = await MyThreadItemConverter().to_agent_input(items)
600 assert len(items) == 1
601 assert items[0] == {
602 "content": [
603 {
604 "text": "<HIDDEN>harry,hermione</HIDDEN>",
605 "type": "input_text",
606 },
607 ],
608 "role": "user",
609 "type": "message",
610 }
611
612
613async def test_input_item_converter_with_client_tool_call():
614 items = [
615 UserMessageItem(
616 id="123",
617 content=[UserMessageTextContent(text="Call a client tool call xyz")],
618 attachments=[],
619 inference_options=InferenceOptions(),
620 thread_id=thread.id,
621 quoted_text="Hi!",
622 created_at=datetime.now(),
623 ),
624 TaskItem(
625 id="tsk_123",
626 created_at=datetime.now(),
627 task=CustomTask(title="Called xyx"),
628 thread_id=thread.id,
629 ),
630 ClientToolCallItem(
631 id="ctc_123",
632 thread_id=thread.id,
633 created_at=datetime.now(),
634 name="xyz",
635 arguments={"foo": "bar"},
636 call_id="ctc_123",
637 ),
638 ClientToolCallItem(
639 id="ctc_123_done",
640 thread_id=thread.id,
641 created_at=datetime.now(),
642 name="xyz",
643 arguments={"foo": "bar"},
644 call_id="ctc_123",
645 status="completed",
646 output={"success": True},
647 ),
648 ]
649
650 input_items = await simple_to_agent_input(items)
651 assert len(input_items) == 4
652 assert input_items[0] == {
653 "content": [
654 {
655 "text": "Call a client tool call xyz",
656 "type": "input_text",
657 },
658 ],
659 "role": "user",
660 "type": "message",
661 }
662 assert input_items[1] == {
663 "content": [
664 {
665 "text": "A message was displayed to the user that the following task was performed:\n<Task>\nCalled xyx\n</Task>",
666 "type": "input_text",
667 },
668 ],
669 "type": "message",
670 "role": "user",
671 }
672 assert input_items[2] == {
673 "type": "function_call",
674 "name": "xyz",
675 "arguments": json.dumps({"foo": "bar"}),
676 "call_id": "ctc_123",
677 }
678 assert input_items[3] == {
679 "type": "function_call_output",
680 "call_id": "ctc_123",
681 "output": json.dumps({"success": True}),
682 }
683
684
685async def test_stream_agent_response_yields_context_events_without_streaming_events():
686 context = AgentContext(
687 previous_response_id=None, thread=thread, store=mock_store, request_context=None
688 )
689 result = make_result()
690
691 event = ThreadItemAddedEvent(
692 item=WidgetItem(
693 id="123",
694 created_at=datetime.now(),
695 thread_id=thread.id,
696 widget=Card(children=[Text(id="text", value="Hello, world!")]),
697 ),
698 )
699
700 await context.stream(event)
701
702 response_streamer = stream_agent_response(context, result)
703 event = await response_streamer.__anext__()
704
705 assert event.type == "thread.item.added"
706
707 future = asyncio.ensure_future(response_streamer.__anext__())
708 assert future.done() is False
709
710 result.done()
711
712 try:
713 await future
714 assert False, "expected StopAsyncIteration"
715 except StopAsyncIteration:
716 pass
717
718 assert future.done() is True
719
720
721async def test_stream_agent_response_maps_events():
722 context = AgentContext(
723 previous_response_id=None, thread=thread, store=mock_store, request_context=None
724 )
725 result = make_result()
726
727 event = ThreadItemAddedEvent(
728 item=WidgetItem(
729 id="123",
730 created_at=datetime.now(),
731 thread_id=thread.id,
732 widget=Card(children=[Text(id="text", value="Hello, world!")]),
733 ),
734 )
735
736 await context.stream(event)
737 result.add_event(
738 RawResponsesStreamEvent(
739 type="raw_response_event",
740 data=ResponseTextDeltaEvent(
741 type="response.output_text.delta",
742 delta="Hello, world!",
743 content_index=0,
744 item_id="123",
745 logprobs=[],
746 output_index=0,
747 sequence_number=0,
748 ),
749 )
750 )
751
752 response_streamer = stream_agent_response(context, result)
753 event1 = await response_streamer.__anext__()
754 event2 = await response_streamer.__anext__()
755
756 assert {event1.type, event2.type} == {
757 "thread.item.added",
758 "thread.item.updated",
759 }
760
761 future = asyncio.ensure_future(response_streamer.__anext__())
762 assert future.done() is False
763
764 result.done()
765
766 try:
767 await future
768 assert False, "expected StopAsyncIteration"
769 except StopAsyncIteration:
770 pass
771
772 assert future.done() is True
773
774
775@pytest.mark.parametrize(
776 "raw_event,expected_event",
777 [
778 (
779 RawResponsesStreamEvent(
780 type="raw_response_event",
781 data=ResponseTextDeltaEvent(
782 type="response.output_text.delta",
783 delta="Hello, world!",
784 content_index=0,
785 item_id="123",
786 logprobs=[],
787 output_index=0,
788 sequence_number=0,
789 ),
790 ),
791 ThreadItemUpdatedEvent(
792 item_id="123",
793 update=AssistantMessageContentPartTextDelta(
794 content_index=0,
795 delta="Hello, world!",
796 ),
797 ),
798 ),
799 (
800 RawResponsesStreamEvent(
801 type="raw_response_event",
802 data=ResponseContentPartAddedEvent(
803 type="response.content_part.added",
804 part=ResponseOutputText(
805 type="output_text",
806 text="New content",
807 annotations=[],
808 ),
809 content_index=1,
810 item_id="123",
811 output_index=0,
812 sequence_number=1,
813 ),
814 ),
815 ThreadItemUpdatedEvent(
816 item_id="123",
817 update=AssistantMessageContentPartAdded(
818 content_index=1,
819 content=AssistantMessageContent(text="New content", annotations=[]),
820 ),
821 ),
822 ),
823 (
824 RawResponsesStreamEvent(
825 type="raw_response_event",
826 data=ResponseTextDoneEvent(
827 type="response.output_text.done",
828 text="Final text",
829 content_index=0,
830 item_id="123",
831 logprobs=[],
832 output_index=0,
833 sequence_number=2,
834 ),
835 ),
836 ThreadItemUpdatedEvent(
837 item_id="123",
838 update=AssistantMessageContentPartDone(
839 content_index=0,
840 content=AssistantMessageContent(
841 text="Final text",
842 annotations=[],
843 ),
844 ),
845 ),
846 ),
847 (
848 RawResponsesStreamEvent(
849 type="raw_response_event",
850 data=Mock(
851 type="response.output_text.annotation.added",
852 annotation=ResponsesAnnotationFileCitation(
853 type="file_citation",
854 file_id="file_123",
855 filename="file.txt",
856 index=5,
857 ),
858 content_index=0,
859 item_id="123",
860 annotation_index=0,
861 output_index=0,
862 sequence_number=3,
863 ),
864 ),
865 ThreadItemUpdatedEvent(
866 item_id="123",
867 update=AssistantMessageContentPartAnnotationAdded(
868 content_index=0,
869 annotation_index=0,
870 annotation=Annotation(
871 source=FileSource(filename="file.txt", title="file.txt"),
872 index=5,
873 ),
874 ),
875 ),
876 ),
877 ],
878)
879async def test_event_mapping(raw_event, expected_event):
880 context = AgentContext(
881 previous_response_id=None, thread=thread, store=mock_store, request_context=None
882 )
883 result = make_result()
884
885 result.add_event(raw_event)
886 result.done()
887
888 events = await all_events(stream_agent_response(context, result))
889 if expected_event:
890 assert events == [expected_event]
891 else:
892 assert events == []
893
894
895async def test_stream_agent_response_emits_annotation_added_events():
896 context = AgentContext(
897 previous_response_id=None, thread=thread, store=mock_store, request_context=None
898 )
899 result = make_result()
900 item_id = "item_123"
901
902 def add_annotation_event(annotation, sequence_number):
903 result.add_event(
904 RawResponsesStreamEvent(
905 type="raw_response_event",
906 data=Mock(
907 type="response.output_text.annotation.added",
908 annotation=annotation,
909 content_index=0,
910 item_id=item_id,
911 annotation_index=sequence_number,
912 output_index=0,
913 sequence_number=sequence_number,
914 ),
915 )
916 )
917
918 add_annotation_event(
919 ResponsesAnnotationFileCitation(
920 type="file_citation",
921 file_id="file_invalid",
922 filename="",
923 index=0,
924 ),
925 sequence_number=0,
926 )
927 add_annotation_event(
928 ResponsesAnnotationContainerFileCitation(
929 type="container_file_citation",
930 container_id="container_1",
931 file_id="file_123",
932 filename="container.txt",
933 start_index=0,
934 end_index=3,
935 ),
936 sequence_number=1,
937 )
938 add_annotation_event(
939 ResponsesAnnotationURLCitation(
940 type="url_citation",
941 url="https://example.com",
942 title="Example",
943 start_index=1,
944 end_index=5,
945 ),
946 sequence_number=2,
947 )
948 result.done()
949
950 events = await all_events(stream_agent_response(context, result))
951 assert events == [
952 ThreadItemUpdatedEvent(
953 item_id=item_id,
954 update=AssistantMessageContentPartAnnotationAdded(
955 content_index=0,
956 annotation_index=0,
957 annotation=Annotation(
958 source=FileSource(filename="container.txt", title="container.txt"),
959 index=3,
960 ),
961 ),
962 ),
963 ThreadItemUpdatedEvent(
964 item_id=item_id,
965 update=AssistantMessageContentPartAnnotationAdded(
966 content_index=0,
967 annotation_index=1,
968 annotation=Annotation(
969 source=URLSource(
970 url="https://example.com",
971 title="Example",
972 ),
973 index=5,
974 ),
975 ),
976 ),
977 ]
978
979
980@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
981async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
982 context = AgentContext(
983 previous_response_id=None, thread=thread, store=mock_store, request_context=None
984 )
985 result = make_result()
986 result.add_event(
987 RawResponsesStreamEvent(
988 type="raw_response_event",
989 data=ResponseOutputItemAddedEvent(
990 type="response.output_item.added",
991 item=ResponseOutputMessage(
992 id="1",
993 content=[
994 ResponseOutputText(
995 annotations=[], type="output_text", text="Hello, world!"
996 )
997 ],
998 role="assistant",
999 status="completed",
1000 type="message",
1001 ),
1002 output_index=0,
1003 sequence_number=0,
1004 ),
1005 )
1006 )
1007 await context.stream(
1008 ThreadItemAddedEvent(
1009 item=AssistantMessageItem(
1010 id="2",
1011 content=[AssistantMessageContent(text="Hello, world!")],
1012 thread_id=thread.id,
1013 created_at=datetime.now(),
1014 ),
1015 )
1016 )
1017
1018 await context.stream(
1019 ThreadItemDoneEvent(
1020 item=WidgetItem(
1021 id="3",
1022 created_at=datetime.now(),
1023 thread_id=thread.id,
1024 widget=Card(children=[Text(id="text", value="Hello, world!")]),
1025 ),
1026 )
1027 )
1028
1029 iterator = stream_agent_response(context, result)
1030
1031 n = 3
1032 events = []
1033 # Grab first 3 events to
1034 async for event in iterator:
1035 n -= 1
1036 events.append(event)
1037 if n == 0:
1038 break
1039
1040 if throw_guardrail == "input":
1041 result.throw_input_guardrails()
1042 else:
1043 result.throw_output_guardrails()
1044
1045 try:
1046 async for event in iterator:
1047 events.append(event)
1048 assert False, "Guardrail should have been thrown from stream_agent_response"
1049 except (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered):
1050 pass
1051 except Exception as e:
1052 assert False, f"Unexpected exception: {e}"
1053
1054 deleted_item_ids = {
1055 event.item_id for event in events if event.type == "thread.item.removed"
1056 }
1057 assert deleted_item_ids == {"1", "2", "3"}
1058
1059
1060async def test_stream_agent_response_assistant_message_content_types():
1061 AgentContext(
1062 previous_response_id=None, thread=thread, store=mock_store, request_context=None
1063 )
1064 result = make_result()
1065
1066 result.add_event(
1067 RawResponsesStreamEvent(
1068 type="raw_response_event",
1069 data=ResponseOutputItemDoneEvent(
1070 type="response.output_item.done",
1071 item=ResponseFileSearchToolCall(
1072 id="fs_0",
1073 queries=["Hello, world!"],
1074 status="completed",
1075 type="file_search_call",
1076 results=[
1077 Result(
1078 file_id="f_123",
1079 filename="test.txt",
1080 text="Hello, world!",
1081 score=1.0,
1082 ),
1083 Result(
1084 file_id="f_123",
1085 filename="test.txt",
1086 text="Hello, friends!",
1087 score=0.5,
1088 ),
1089 ],
1090 ),
1091 output_index=0,
1092 sequence_number=0,
1093 ),
1094 )
1095 )
1096 result.add_event(
1097 RawResponsesStreamEvent(
1098 type="raw_response_event",
1099 data=ResponseOutputItemDoneEvent(
1100 type="response.output_item.done",
1101 item=ResponseOutputMessage(
1102 id="1",
1103 content=[
1104 ResponseOutputText(
1105 annotations=[
1106 ResponsesAnnotationFileCitation(
1107 type="file_citation",
1108 file_id="f_123",
1109 index=0,
1110 filename="test.txt",
1111 ),
1112 ResponsesAnnotationContainerFileCitation(
1113 type="container_file_citation",
1114 container_id="container_1",
1115 file_id="f_456",
1116 filename="container.txt",
1117 start_index=0,
1118 end_index=3,
1119 ),
1120 ResponsesAnnotationURLCitation(
1121 type="url_citation",
1122 url="https://www.google.com",
1123 title="Google",
1124 start_index=0,
1125 end_index=10,
1126 ),
1127 ResponsesAnnotationFilePath(
1128 type="file_path",
1129 file_id="123",
1130 index=0,
1131 ),
1132 ],
1133 text="Hello, world!",
1134 type="output_text",
1135 ),
1136 ResponseOutputText(
1137 annotations=[],
1138 text="Can't do that",
1139 type="output_text",
1140 ),
1141 ],
1142 role="assistant",
1143 status="completed",
1144 type="message",
1145 ),
1146 output_index=0,
1147 sequence_number=0,
1148 ),
1149 )
1150 )
1151
1152 result.done()
1153
1154 context = AgentContext(
1155 previous_response_id=None, thread=thread, store=mock_store, request_context=None
1156 )
1157 events = await all_events(stream_agent_response(context, result))
1158 assert len(events) == 1
1159 assert isinstance(events[0], ThreadItemDoneEvent)
1160 message = events[0].item
1161 assert isinstance(message, AssistantMessageItem)
1162 assert message.content == [
1163 AssistantMessageContent(
1164 annotations=[
1165 Annotation(
1166 source=FileSource(
1167 filename="test.txt",
1168 title="test.txt",
1169 ),
1170 index=0,
1171 ),
1172 Annotation(
1173 source=FileSource(
1174 filename="container.txt",
1175 title="container.txt",
1176 ),
1177 index=3,
1178 ),
1179 Annotation(
1180 source=URLSource(
1181 url="https://www.google.com",
1182 title="Google",
1183 ),
1184 index=10,
1185 ),
1186 ],
1187 text="Hello, world!",
1188 ),
1189 AssistantMessageContent(text="Can't do that", annotations=[]),
1190 ]
1191 assert message.id == "1"
1192
1193
1194async def test_workflow_streams_first_thought():
1195 context = AgentContext(
1196 previous_response_id=None, thread=thread, store=mock_store, request_context=None
1197 )
1198 result = make_result()
1199
1200 # first thought
1201 result.add_event(
1202 RawResponsesStreamEvent(
1203 type="raw_response_event",
1204 data=ResponseOutputItemAddedEvent(
1205 type="response.output_item.added",
1206 item=ResponseReasoningItem(
1207 id="resp_1",
1208 summary=[],
1209 type="reasoning",
1210 ),
1211 output_index=0,
1212 sequence_number=0,
1213 ),
1214 )
1215 )
1216 result.add_event(
1217 RawResponsesStreamEvent(
1218 type="raw_response_event",
1219 data=Mock(
1220 type="response.reasoning_summary_text.delta",
1221 item_id="resp_1",
1222 summary_index=0,
1223 delta="Think",
1224 ),
1225 )
1226 )
1227 result.add_event(
1228 RawResponsesStreamEvent(
1229 type="raw_response_event",
1230 data=Mock(
1231 type="response.reasoning_summary_text.delta",
1232 item_id="resp_1",
1233 summary_index=0,
1234 delta="ing 1",
1235 ),
1236 )
1237 )
1238 result.add_event(
1239 RawResponsesStreamEvent(
1240 type="raw_response_event",
1241 data=Mock(
1242 type="response.reasoning_summary_text.done",
1243 item_id="resp_1",
1244 summary_index=0,
1245 text="Thinking 1",
1246 ),
1247 )
1248 )
1249
1250 # second thought
1251 result.add_event(
1252 RawResponsesStreamEvent(
1253 type="raw_response_event",
1254 data=Mock(
1255 type="response.reasoning_summary_text.delta",
1256 item_id="resp_1",
1257 summary_index=1,
1258 delta="Think",
1259 ),
1260 )
1261 )
1262 result.add_event(
1263 RawResponsesStreamEvent(
1264 type="raw_response_event",
1265 data=Mock(
1266 type="response.reasoning_summary_text.delta",
1267 item_id="resp_1",
1268 summary_index=1,
1269 delta="ing 2",
1270 ),
1271 )
1272 )
1273 result.add_event(
1274 RawResponsesStreamEvent(
1275 type="raw_response_event",
1276 data=Mock(
1277 type="response.reasoning_summary_text.done",
1278 item_id="resp_1",
1279 summary_index=1,
1280 text="Thinking 2",
1281 ),
1282 )
1283 )
1284
1285 result.done()
1286 stream = stream_agent_response(context, result)
1287
1288 # Workflow added
1289 event = await anext(stream)
1290 assert isinstance(event, ThreadItemAddedEvent)
1291 assert context.workflow_item is not None
1292 assert context.workflow_item.workflow.type == "reasoning"
1293 assert len(context.workflow_item.workflow.tasks) == 0
1294 assert event == ThreadItemAddedEvent(item=context.workflow_item)
1295
1296 # First thought added
1297 event = await anext(stream)
1298 assert context.workflow_item is not None
1299 assert len(context.workflow_item.workflow.tasks) == 1
1300 assert isinstance(event, ThreadItemUpdatedEvent)
1301 assert event == ThreadItemUpdatedEvent(
1302 item_id=context.workflow_item.id,
1303 update=WorkflowTaskAdded(
1304 task=ThoughtTask(content="Think"),
1305 task_index=0,
1306 ),
1307 )
1308
1309 # First thought delta
1310 event = await anext(stream)
1311 assert context.workflow_item is not None
1312 assert len(context.workflow_item.workflow.tasks) == 1
1313 assert isinstance(event, ThreadItemUpdatedEvent)
1314 assert event == ThreadItemUpdatedEvent(
1315 item_id=context.workflow_item.id,
1316 update=WorkflowTaskUpdated(
1317 task=ThoughtTask(content="Thinking 1"),
1318 task_index=0,
1319 ),
1320 )
1321
1322 # First thought done
1323 event = await anext(stream)
1324 assert context.workflow_item is not None
1325 assert len(context.workflow_item.workflow.tasks) == 1
1326 assert isinstance(event, ThreadItemUpdatedEvent)
1327 assert event == ThreadItemUpdatedEvent(
1328 item_id=context.workflow_item.id,
1329 update=WorkflowTaskUpdated(
1330 task=ThoughtTask(content="Thinking 1"),
1331 task_index=0,
1332 ),
1333 )
1334
1335 # Second thought added (not streamed)
1336 event = await anext(stream)
1337 assert context.workflow_item is not None
1338 assert len(context.workflow_item.workflow.tasks) == 2
1339 assert isinstance(event, ThreadItemUpdatedEvent)
1340 assert event == ThreadItemUpdatedEvent(
1341 item_id=context.workflow_item.id,
1342 update=WorkflowTaskAdded(
1343 task=ThoughtTask(content="Thinking 2"),
1344 task_index=1,
1345 ),
1346 )
1347
1348 try:
1349 while True:
1350 await anext(stream)
1351 except StopAsyncIteration:
1352 pass
1353
1354
1355async def test_workflow_ends_on_message():
1356 context = AgentContext(
1357 previous_response_id=None, thread=thread, store=mock_store, request_context=None
1358 )
1359 result = make_result()
1360
1361 # first thought
1362 result.add_event(
1363 RawResponsesStreamEvent(
1364 type="raw_response_event",
1365 data=ResponseOutputItemAddedEvent(
1366 type="response.output_item.added",
1367 item=ResponseReasoningItem(
1368 id="resp_1",
1369 summary=[],
1370 type="reasoning",
1371 ),
1372 output_index=0,
1373 sequence_number=0,
1374 ),
1375 )
1376 )
1377 result.add_event(
1378 RawResponsesStreamEvent(
1379 type="raw_response_event",
1380 data=Mock(
1381 type="response.reasoning_summary_text.done",
1382 item_id="resp_1",
1383 summary_index=0,
1384 text="Thinking 1",
1385 ),
1386 )
1387 )
1388
1389 # not reasoning
1390 result.add_event(
1391 RawResponsesStreamEvent(
1392 type="raw_response_event",
1393 data=ResponseOutputItemAddedEvent(
1394 type="response.output_item.added",
1395 item=ResponseOutputMessage(
1396 id="m_1",
1397 content=[],
1398 role="assistant",
1399 status="in_progress",
1400 type="message",
1401 ),
1402 output_index=0,
1403 sequence_number=0,
1404 ),
1405 )
1406 )
1407
1408 result.done()
1409 stream = stream_agent_response(context, result)
1410
1411 # Workflow added
1412 event = await anext(stream)
1413 assert isinstance(event, ThreadItemAddedEvent)
1414 assert context.workflow_item is not None
1415 assert context.workflow_item.workflow.type == "reasoning"
1416 assert len(context.workflow_item.workflow.tasks) == 0
1417 assert event == ThreadItemAddedEvent(item=context.workflow_item)
1418
1419 # First thought done
1420 event = await anext(stream)
1421 assert context.workflow_item is not None
1422 assert len(context.workflow_item.workflow.tasks) == 1
1423 assert isinstance(event, ThreadItemUpdatedEvent)
1424 assert event == ThreadItemUpdatedEvent(
1425 item_id=context.workflow_item.id,
1426 update=WorkflowTaskAdded(
1427 task=ThoughtTask(content="Thinking 1"),
1428 task_index=0,
1429 ),
1430 )
1431
1432 # Workflow ended
1433 event = await anext(stream)
1434 assert isinstance(event, ThreadItemDoneEvent)
1435 assert event.item.type == "workflow"
1436 assert context.workflow_item is None
1437 # Summary and expanded are handled by the end_workflow method
1438 assert isinstance(event.item.workflow.summary, DurationSummary)
1439 assert event.item.workflow.expanded is False
1440
1441 try:
1442 while True:
1443 await anext(stream)
1444 except StopAsyncIteration:
1445 pass
1446
1447
1448async def test_existing_workflow_summary_not_overwritten_on_automatic_end():
1449 context = AgentContext(
1450 previous_response_id=None, thread=thread, store=mock_store, request_context=None
1451 )
1452 result = make_result()
1453 context.workflow_item = WorkflowItem(
1454 id="wf_1",
1455 created_at=datetime.now(),
1456 workflow=Workflow(type="custom", tasks=[], summary=CustomSummary(title="Test")),
1457 thread_id=thread.id,
1458 )
1459
1460 result.add_event(
1461 RawResponsesStreamEvent(
1462 type="raw_response_event",
1463 data=ResponseOutputItemAddedEvent(
1464 type="response.output_item.added",
1465 item=ResponseOutputMessage(
1466 id="m_1",
1467 content=[],
1468 role="assistant",
1469 status="in_progress",
1470 type="message",
1471 ),
1472 output_index=0,
1473 sequence_number=0,
1474 ),
1475 )
1476 )
1477
1478 result.done()
1479 stream = stream_agent_response(context, result)
1480
1481 event = await anext(stream)
1482
1483 assert isinstance(event, ThreadItemDoneEvent)
1484 assert context.workflow_item is None
1485 assert event.item.type == "workflow"
1486 assert event.item.workflow.summary == CustomSummary(title="Test")
1487
1488 try:
1489 while True:
1490 await anext(stream)
1491 except StopAsyncIteration:
1492 pass
1493