openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
ec580cba878a096f48ea5fac46bcbf29b41ecc61

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_agents.py

1424 lines · mode: code

1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from typing import cast
6from unittest.mock import AsyncMock, Mock
7
8import pytest
9from agents import (
10 Agent,
11 GuardrailFunctionOutput,
12 InputGuardrail,
13 InputGuardrailResult,
14 InputGuardrailTripwireTriggered,
15 OutputGuardrail,
16 OutputGuardrailResult,
17 OutputGuardrailTripwireTriggered,
18 RawResponsesStreamEvent,
19 RunContextWrapper,
20 RunItemStreamEvent,
21 Runner,
22 RunResultStreaming,
23 StreamEvent,
24 ToolCallItem,
25)
26from agents._run_impl import QueueCompleteSentinel
27from openai.types.responses import (
28 EasyInputMessageParam,
29 ResponseFileSearchToolCall,
30 ResponseInputContentParam,
31 ResponseInputTextParam,
32 ResponseOutputItemAddedEvent,
33 ResponseOutputItemDoneEvent,
34 ResponseOutputMessage,
35 ResponseReasoningItem,
36)
37from openai.types.responses.response_content_part_added_event import (
38 ResponseContentPartAddedEvent,
39)
40from openai.types.responses.response_file_search_tool_call import Result
41from openai.types.responses.response_output_text import (
42 AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
43)
44from openai.types.responses.response_output_text import (
45 AnnotationFileCitation as ResponsesAnnotationFileCitation,
46)
47from openai.types.responses.response_output_text import (
48 AnnotationFilePath as ResponsesAnnotationFilePath,
49)
50from openai.types.responses.response_output_text import (
51 AnnotationURLCitation as ResponsesAnnotationURLCitation,
52)
53from openai.types.responses.response_output_text import (
54 ResponseOutputText,
55)
56from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
57from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
58
59from chatkit.agents import (
60 AgentContext,
61 ThreadItemConverter,
62 accumulate_text,
63 simple_to_agent_input,
64 stream_agent_response,
65)
66from chatkit.types import (
67 Annotation,
68 AssistantMessageContent,
69 AssistantMessageContentPartAdded,
70 AssistantMessageContentPartAnnotationAdded,
71 AssistantMessageContentPartDone,
72 AssistantMessageContentPartTextDelta,
73 AssistantMessageItem,
74 Attachment,
75 ClientToolCallItem,
76 CustomSummary,
77 CustomTask,
78 DurationSummary,
79 FileSource,
80 InferenceOptions,
81 Page,
82 TaskItem,
83 ThoughtTask,
84 Thread,
85 ThreadItemAddedEvent,
86 ThreadItemDoneEvent,
87 ThreadItemUpdated,
88 ThreadStreamEvent,
89 URLSource,
90 UserMessageItem,
91 UserMessageTagContent,
92 UserMessageTextContent,
93 WidgetItem,
94 Workflow,
95 WorkflowItem,
96 WorkflowTaskAdded,
97 WorkflowTaskUpdated,
98)
99from chatkit.widgets import Card, Text
100
# Thread fixture shared by the tests in this module.
thread = Thread(id="123", title="Test", created_at=datetime.now(), items=Page())


def _fixed_item_id(item_type, thread, context):
    """Return a deterministic item id derived only from the item type."""
    return f"{item_type}_id"


# Store double: ids are predictable, loading thread items returns a default
# `Page()`, and adding an item is an awaitable no-op recorded by AsyncMock.
mock_store = Mock()
mock_store.generate_item_id = _fixed_item_id
mock_store.load_thread_items = AsyncMock(return_value=Page())
mock_store.add_thread_item = AsyncMock()
107
108
class RunResult(RunResultStreaming):
    """Test double for `RunResultStreaming` whose event queue is fed by hand."""

    def add_event(self, event: StreamEvent):
        """Enqueue `event` on the run's event queue without waiting."""
        self._event_queue.put_nowait(event)

    def done(self):
        """Mark the run complete and enqueue the completion sentinel."""
        self.is_complete = True
        self._event_queue.put_nowait(QueueCompleteSentinel())

    def throw_input_guardrails(self):
        """Store a tripped input-guardrail exception, then finish the run."""
        tripped_output = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = InputGuardrailResult(
            guardrail=Mock(spec=InputGuardrail),
            output=tripped_output,
        )
        self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
        # Same completion steps as done(): set the flag, then push the sentinel.
        self.done()

    def throw_output_guardrails(self):
        """Store a tripped output-guardrail exception, then finish the run."""
        tripped_output = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = OutputGuardrailResult(
            guardrail=Mock(spec=OutputGuardrail),
            output=tripped_output,
            agent=Mock(spec=Agent),
            agent_output=None,
        )
        self._stored_exception = OutputGuardrailTripwireTriggered(guardrail_result)
        # Same completion steps as done(): set the flag, then push the sentinel.
        self.done()
144
145
def make_result() -> RunResult:
    """Build a fresh, not-yet-complete `RunResult` with empty collections."""
    init_kwargs = {
        # Run inputs and the agent under test.
        "context_wrapper": Mock(spec=RunContextWrapper),
        "input": [],
        "current_agent": Agent(name="test"),
        "current_turn": 0,
        "max_turns": 10,
        # Nothing has been produced yet.
        "new_items": [],
        "raw_responses": [],
        "final_output": None,
        "is_complete": False,
        "trace": None,
        # Guardrail bookkeeping starts out empty.
        "input_guardrail_results": [],
        "output_guardrail_results": [],
        "tool_input_guardrail_results": [],
        "tool_output_guardrail_results": [],
        # Private streaming state; tests push onto `_event_queue` directly.
        "_current_agent_output_schema": None,
        "_event_queue": asyncio.Queue(),
        "_input_guardrail_queue": asyncio.Queue(),
        "_output_guardrails_task": None,
        "_run_impl_task": None,
        "_stored_exception": None,
    }
    return RunResult(**init_kwargs)
169
170
async def all_events(
    events: AsyncIterator[ThreadStreamEvent],
) -> list[ThreadStreamEvent]:
    """Drain `events` and return everything it yielded, in order."""
    collected: list[ThreadStreamEvent] = []
    async for item in events:
        collected.append(item)
    return collected
175
176
async def test_returns_widget_item():
    """A widget streamed as a single value yields exactly one done event."""

    def hello_card() -> Card:
        # Built fresh each time so the expectation never aliases the input.
        return Card(children=[Text(value="Hello, world!")])

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    result.add_event(tool_called)
    await context.stream_widget(hello_card())
    result.done()

    events = await all_events(stream_agent_response(context=context, result=result))

    assert len(events) == 1
    only_event = events[0]
    assert isinstance(only_event, ThreadItemDoneEvent)
    assert isinstance(only_event.item, WidgetItem)
    assert only_event.item.widget == hello_card()
199
200
async def test_returns_widget_item_generator():
    """A widget generator yields added, text-delta update, and done events."""
    full_text = "Hello, world"

    def card_with(value: str) -> Card:
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        # First the empty text, then the complete 12-character text.
        for prefix_length in (0, 12):
            yield card_with(full_text[:prefix_length])

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    result.add_event(tool_called)

    await context.stream_widget(widget_generator())
    result.done()

    events = await all_events(stream_agent_response(context=context, result=result))

    assert len(events) == 3
    added, updated, finished = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card_with("")

    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.streaming_text.value_delta"
    assert updated.update.component_id == "text"
    assert updated.update.delta == "Hello, world"

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == card_with("Hello, world")
242
243
async def test_returns_widget_full_replace_generator():
    """A structurally different second widget is sent as a root update."""

    def first_card() -> Card:
        return Card(children=[Text(id="text", value="Hello!")])

    def replacement_card() -> Card:
        return Card(children=[Text(key="other text", value="World!", streaming=False)])

    async def widget_generator():
        yield first_card()
        yield replacement_card()

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    result.add_event(tool_called)

    await context.stream_widget(widget_generator())
    result.done()

    events = await all_events(stream_agent_response(context=context, result=result))

    assert len(events) == 3
    added, updated, finished = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == first_card()

    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == replacement_card()

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == replacement_card()
283
284
async def test_accumulate_text():
    """`accumulate_text` yields a widget snapshot as text deltas arrive.

    The expected sequence is the initial widget, one snapshot per delta with
    the accumulated value, and a last snapshot with `streaming=False`.
    """

    def delta(text: str) -> RawResponsesStreamEvent:
        """Wrap `text` in a raw `response.output_text.delta` stream event."""
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta=text,
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )

    # Events are fed by hand through make_result(); no agent run is started.
    result = make_result()
    result.add_event(delta("Hello, "))
    result.add_event(delta("world!"))
    result.done()

    events = [
        event
        async for event in accumulate_text(
            result.stream_events(), Text(key="text", value="", streaming=True)
        )
    ]
    assert events == [
        Text(key="text", value="", streaming=True),
        Text(key="text", value="Hello, ", streaming=True),
        Text(key="text", value="Hello, world!", streaming=True),
        Text(key="text", value="Hello, world!", streaming=False),
    ]
321
322
async def test_input_item_converter_quotes_last_user_message():
    """Only the last user message's quoted text becomes an extra input item.

    Two user messages both carry `quoted_text`; the converted input holds the
    two messages followed by one message quoting the second item's text.
    """

    def user_text_message(text: str) -> dict:
        """Expected agent-input shape for a plain user text message."""
        return {
            "content": [{"text": text, "type": "input_text"}],
            "role": "user",
            "type": "message",
        }

    items = [
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="Hello!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="Hi!",
            created_at=datetime.now(),
        ),
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="I'm well, thank you!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="How are you doing?",
            created_at=datetime.now(),
        ),
    ]

    input_items = await simple_to_agent_input(items)

    assert len(input_items) == 3
    assert input_items[0] == user_text_message("Hello!")
    assert input_items[1] == user_text_message("I'm well, thank you!")
    assert input_items[2] == user_text_message(
        "The user is referring to this in particular: \nHow are you doing?"
    )
382
383
async def test_input_item_converter_to_input_items_mixed():
    """User, assistant and widget items each convert to one input item.

    The thread does not end with a user message, and no quoted-text item is
    expected in the output (four items in, four items out).
    """

    def user_text_message(text: str) -> dict:
        """Expected agent-input shape for a plain user text message."""
        return {
            "content": [{"text": text, "type": "input_text"}],
            "role": "user",
            "type": "message",
        }

    def output_text_part(text: str) -> dict:
        """Expected shape of one assistant output-text content part."""
        return {
            "annotations": [],
            "text": text,
            "logprobs": None,
            "type": "output_text",
        }

    first_user = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    second_user = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="I'm well, thank you!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="How are you doing?",
        created_at=datetime.now(),
    )
    assistant = AssistantMessageItem(
        id="123",
        content=[
            AssistantMessageContent(text="How are you doing?"),
            AssistantMessageContent(text="Can't do that"),
        ],
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    widget = WidgetItem(
        id="wd_123",
        widget=Card(children=[Text(value="Hello, world!")]),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    input_items = await simple_to_agent_input(
        [first_user, second_user, assistant, widget]
    )

    assert len(input_items) == 4
    assert input_items[0] == user_text_message("Hello!")
    assert input_items[1] == user_text_message("I'm well, thank you!")
    assert input_items[2] == {
        "content": [
            output_text_part("How are you doing?"),
            output_text_part("Can't do that"),
        ],
        "type": "message",
        "role": "assistant",
    }

    # The widget is described to the model as a user message in plain text.
    assert "type" in input_items[3]
    widget_item = cast(EasyInputMessageParam, input_items[3])
    assert widget_item.get("type") == "message"
    assert widget_item.get("role") == "user"
    text = widget_item.get("content")[0]["text"]  # type: ignore
    assert (
        "The following graphical UI widget (id: wd_123) was displayed to the user"
        in text
    )
    assert "Hello, world!" in text
    assert "created_at" not in text
472
473
async def test_input_item_converter_user_input_with_tags():
    """A converter overriding `tag_to_message_content` resolves @-mentions.

    The tag appears as "@Hello!" in the user message, and a second user
    message carries the mention preamble plus the converter's tag content.
    """

    class MyThreadItemConverter(ThreadItemConverter):
        async def tag_to_message_content(self, tag):
            resolved = f"{tag.text} {tag.data['key']}"
            return ResponseInputTextParam(type="input_text", text=resolved)

    tagged_message = UserMessageItem(
        id="123",
        content=[
            UserMessageTagContent(
                text="Hello!", type="input_tag", id="hello", data={"key": "value"}
            )
        ],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    converted = await MyThreadItemConverter().to_agent_input([tagged_message])

    mention_preamble = (
        "# User-provided context for @-mentions\n- When referencing resolved entities, use their canonical names **without** '@'.\n"
        + "- The '@' form appears only in user text and should not be echoed."
    )

    assert len(converted) == 2
    assert converted[0] == {
        "content": [{"text": "@Hello!", "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
    assert converted[1] == {
        "content": [
            {"text": mention_preamble, "type": "input_text"},
            {"text": "Hello! value", "type": "input_text"},
        ],
        "role": "user",
        "type": "message",
    }
523
524
async def test_input_item_converter_user_input_with_tags_throws_by_default():
    """`simple_to_agent_input` raises NotImplementedError for tag content."""
    tag = UserMessageTagContent(text="Hello!", type="input_tag", id="hello", data={})
    tagged_message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input([tagged_message])
543
544
async def test_input_item_converter_with_client_tool_call():
    """Client tool calls convert to a function call and its output.

    The pending call item becomes a `function_call`; the completed item with
    the same `call_id` becomes the matching `function_call_output`.
    """
    tool_arguments = {"foo": "bar"}
    tool_output = {"success": True}

    user_message = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Call a client tool call xyz")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    task = TaskItem(
        id="tsk_123",
        created_at=datetime.now(),
        task=CustomTask(title="Called xyx"),
        thread_id=thread.id,
    )
    pending_call = ClientToolCallItem(
        id="ctc_123",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
    )
    completed_call = ClientToolCallItem(
        id="ctc_123_done",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
        status="completed",
        output={"success": True},
    )

    input_items = await simple_to_agent_input(
        [user_message, task, pending_call, completed_call]
    )

    assert len(input_items) == 4
    assert input_items[0] == {
        "content": [{"text": "Call a client tool call xyz", "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
    assert input_items[1] == {
        "content": [
            {
                "text": "A message was displayed to the user that the following task was performed:\n<Task>\nCalled xyx\n</Task>",
                "type": "input_text",
            },
        ],
        "type": "message",
        "role": "user",
    }
    assert input_items[2] == {
        "type": "function_call",
        "name": "xyz",
        "arguments": json.dumps(tool_arguments),
        "call_id": "ctc_123",
    }
    assert input_items[3] == {
        "type": "function_call_output",
        "call_id": "ctc_123",
        "output": json.dumps(tool_output),
    }
615
616
async def test_stream_agent_response_yields_context_events_without_streaming_events():
    """Context events are yielded even when the run emits no stream events.

    After the context event is consumed the stream stays pending until the
    run is marked done, and then ends with StopAsyncIteration.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    widget_added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )
    await context.stream(widget_added)

    response_streamer = stream_agent_response(context, result)
    first_event = await response_streamer.__anext__()
    assert first_event.type == "thread.item.added"

    # Nothing else is queued, so the next read must still be waiting.
    pending = asyncio.ensure_future(response_streamer.__anext__())
    assert pending.done() is False

    result.done()

    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done() is True
651
652
async def test_stream_agent_response_maps_events():
    """Context events and mapped run events both come out of the stream.

    One widget event is streamed via the context and one text delta via the
    run; the first two yielded events are an "added" and an "updated" event
    in either order, after which the stream waits for the run to finish.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    widget_added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )
    await context.stream(widget_added)

    text_delta = ResponseTextDeltaEvent(
        type="response.output_text.delta",
        delta="Hello, world!",
        content_index=0,
        item_id="123",
        logprobs=[],
        output_index=0,
        sequence_number=0,
    )
    result.add_event(
        RawResponsesStreamEvent(type="raw_response_event", data=text_delta)
    )

    response_streamer = stream_agent_response(context, result)
    event1 = await response_streamer.__anext__()
    event2 = await response_streamer.__anext__()

    # Compared as a set: the relative order of the two sources is not asserted.
    assert {event1.type, event2.type} == {
        "thread.item.added",
        "thread.item.updated",
    }

    pending = asyncio.ensure_future(response_streamer.__anext__())
    assert pending.done() is False

    result.done()

    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done() is True
705
706
def _wrap_raw_event_for_mapping(data) -> RawResponsesStreamEvent:
    """Wrap a Responses API event payload in a raw run stream event."""
    return RawResponsesStreamEvent(type="raw_response_event", data=data)


@pytest.mark.parametrize(
    "raw_event,expected_event",
    [
        # Text delta -> content-part text delta.
        (
            _wrap_raw_event_for_mapping(
                ResponseTextDeltaEvent(
                    type="response.output_text.delta",
                    delta="Hello, world!",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=0,
                )
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartTextDelta(
                    content_index=0,
                    delta="Hello, world!",
                ),
            ),
        ),
        # Content part added -> content-part added.
        (
            _wrap_raw_event_for_mapping(
                ResponseContentPartAddedEvent(
                    type="response.content_part.added",
                    part=ResponseOutputText(
                        type="output_text",
                        text="New content",
                        annotations=[],
                    ),
                    content_index=1,
                    item_id="123",
                    output_index=0,
                    sequence_number=1,
                )
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartAdded(
                    content_index=1,
                    content=AssistantMessageContent(text="New content", annotations=[]),
                ),
            ),
        ),
        # Text done -> content-part done.
        (
            _wrap_raw_event_for_mapping(
                ResponseTextDoneEvent(
                    type="response.output_text.done",
                    text="Final text",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=2,
                )
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartDone(
                    content_index=0,
                    content=AssistantMessageContent(
                        text="Final text",
                        annotations=[],
                    ),
                ),
            ),
        ),
        # Annotation added (payload mocked) -> content-part annotation added.
        (
            _wrap_raw_event_for_mapping(
                Mock(
                    type="response.output_text.annotation.added",
                    annotation=ResponsesAnnotationFileCitation(
                        type="file_citation",
                        file_id="file_123",
                        filename="file.txt",
                        index=5,
                    ),
                    content_index=0,
                    item_id="123",
                    annotation_index=0,
                    output_index=0,
                    sequence_number=3,
                )
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartAnnotationAdded(
                    content_index=0,
                    annotation_index=0,
                    annotation=Annotation(
                        source=FileSource(filename="file.txt", title="file.txt"),
                        index=5,
                    ),
                ),
            ),
        ),
    ],
)
async def test_event_mapping(raw_event, expected_event):
    """Each raw Responses event maps to the expected thread stream event.

    A falsy `expected_event` means the raw event should produce no output.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(raw_event)
    result.done()

    events = await all_events(stream_agent_response(context, result))

    expected = [expected_event] if expected_event else []
    assert events == expected
825
826
async def test_stream_agent_response_emits_annotation_added_events():
    """Annotation-added events are emitted with renumbered indices.

    Three annotations are fed in; the expected output has only two updates
    (container file citation and URL citation) at annotation indices 0 and 1.
    The first input, a file citation with an empty filename, has no
    corresponding output event.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    item_id = "item_123"

    annotations = [
        ResponsesAnnotationFileCitation(
            type="file_citation",
            file_id="file_invalid",
            filename="",
            index=0,
        ),
        ResponsesAnnotationContainerFileCitation(
            type="container_file_citation",
            container_id="container_1",
            file_id="file_123",
            filename="container.txt",
            start_index=0,
            end_index=3,
        ),
        ResponsesAnnotationURLCitation(
            type="url_citation",
            url="https://example.com",
            title="Example",
            start_index=1,
            end_index=5,
        ),
    ]
    # The position doubles as both annotation_index and sequence_number.
    for position, annotation in enumerate(annotations):
        payload = Mock(
            type="response.output_text.annotation.added",
            annotation=annotation,
            content_index=0,
            item_id=item_id,
            annotation_index=position,
            output_index=0,
            sequence_number=position,
        )
        result.add_event(
            RawResponsesStreamEvent(type="raw_response_event", data=payload)
        )
    result.done()

    events = await all_events(stream_agent_response(context, result))

    expected_container = ThreadItemUpdated(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=0,
            annotation=Annotation(
                source=FileSource(filename="container.txt", title="container.txt"),
                index=3,
            ),
        ),
    )
    expected_url = ThreadItemUpdated(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=1,
            annotation=Annotation(
                source=URLSource(
                    url="https://example.com",
                    title="Example",
                ),
                index=5,
            ),
        ),
    )
    assert events == [expected_container, expected_url]
910
911
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
    """A tripped guardrail removes every item produced so far, then raises.

    Three items are produced (one from the run, two via the context). After
    the guardrail trips, the stream must emit `thread.item.removed` for all
    three ids and then raise the guardrail exception.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[], type="output_text", text="Hello, world!"
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    await context.stream(
        ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="2",
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
    )
    await context.stream(
        ThreadItemDoneEvent(
            item=WidgetItem(
                id="3",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(id="text", value="Hello, world!")]),
            ),
        )
    )

    iterator = stream_agent_response(context, result)

    # Grab the first 3 events before tripping the guardrail.
    events = []
    async for event in iterator:
        events.append(event)
        if len(events) == 3:
            break

    if throw_guardrail == "input":
        result.throw_input_guardrails()
    else:
        result.throw_output_guardrails()

    # pytest.raises fails the test if the stream ends without raising and
    # lets any unrelated exception propagate with its own traceback.
    with pytest.raises(
        (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered)
    ):
        async for event in iterator:
            events.append(event)

    deleted_item_ids = {
        event.item_id for event in events if event.type == "thread.item.removed"
    }
    assert deleted_item_ids == {"1", "2", "3"}
990
991
async def test_stream_agent_response_assistant_message_content_types():
    """A completed output message converts its content parts and annotations.

    A file-search call and an assistant message are fed in as done items.
    Exactly one event is expected: the assistant message, with three
    converted annotations (file, container file, URL); the `file_path`
    annotation has no counterpart in the expected content.
    """
    result = make_result()

    file_search_call = ResponseFileSearchToolCall(
        id="fs_0",
        queries=["Hello, world!"],
        status="completed",
        type="file_search_call",
        results=[
            Result(
                file_id="f_123",
                filename="test.txt",
                text="Hello, world!",
                score=1.0,
            ),
            Result(
                file_id="f_123",
                filename="test.txt",
                text="Hello, friends!",
                score=0.5,
            ),
        ],
    )
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=file_search_call,
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    annotated_text = ResponseOutputText(
        annotations=[
            ResponsesAnnotationFileCitation(
                type="file_citation",
                file_id="f_123",
                index=0,
                filename="test.txt",
            ),
            ResponsesAnnotationContainerFileCitation(
                type="container_file_citation",
                container_id="container_1",
                file_id="f_456",
                filename="container.txt",
                start_index=0,
                end_index=3,
            ),
            ResponsesAnnotationURLCitation(
                type="url_citation",
                url="https://www.google.com",
                title="Google",
                start_index=0,
                end_index=10,
            ),
            ResponsesAnnotationFilePath(
                type="file_path",
                file_id="123",
                index=0,
            ),
        ],
        text="Hello, world!",
        type="output_text",
    )
    plain_text = ResponseOutputText(
        annotations=[],
        text="Can't do that",
        type="output_text",
    )
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseOutputMessage(
                    id="1",
                    content=[annotated_text, plain_text],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    result.done()

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    events = await all_events(stream_agent_response(context, result))

    assert len(events) == 1
    assert isinstance(events[0], ThreadItemDoneEvent)
    message = events[0].item
    assert isinstance(message, AssistantMessageItem)
    assert message.content == [
        AssistantMessageContent(
            annotations=[
                Annotation(
                    source=FileSource(
                        filename="test.txt",
                        title="test.txt",
                    ),
                    index=0,
                ),
                Annotation(
                    source=FileSource(
                        filename="container.txt",
                        title="container.txt",
                    ),
                    index=3,
                ),
                Annotation(
                    source=URLSource(
                        url="https://www.google.com",
                        title="Google",
                    ),
                    index=10,
                ),
            ],
            text="Hello, world!",
        ),
        AssistantMessageContent(text="Can't do that", annotations=[]),
    ]
    assert message.id == "1"
1124
1125
async def test_workflow_streams_first_thought():
    """Only the first reasoning summary is streamed delta by delta.

    The first thought produces a task-added event followed by task updates;
    the second thought shows up as a single task-added event carrying its
    full text.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    def push(data) -> None:
        """Queue `data` on the run as a raw response stream event."""
        result.add_event(RawResponsesStreamEvent(type="raw_response_event", data=data))

    def summary_delta(summary_index: int, delta: str) -> Mock:
        return Mock(
            type="response.reasoning_summary_text.delta",
            item_id="resp_1",
            summary_index=summary_index,
            delta=delta,
        )

    def summary_done(summary_index: int, text: str) -> Mock:
        return Mock(
            type="response.reasoning_summary_text.done",
            item_id="resp_1",
            summary_index=summary_index,
            text=text,
        )

    # The reasoning item opens the workflow.
    push(
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(
                id="resp_1",
                summary=[],
                type="reasoning",
            ),
            output_index=0,
            sequence_number=0,
        )
    )

    # first thought
    push(summary_delta(0, "Think"))
    push(summary_delta(0, "ing 1"))
    push(summary_done(0, "Thinking 1"))

    # second thought
    push(summary_delta(1, "Think"))
    push(summary_delta(1, "ing 2"))
    push(summary_done(1, "Thinking 2"))

    result.done()
    stream = stream_agent_response(context, result)

    # Workflow added
    event = await anext(stream)
    assert isinstance(event, ThreadItemAddedEvent)
    assert context.workflow_item is not None
    assert context.workflow_item.workflow.type == "reasoning"
    assert len(context.workflow_item.workflow.tasks) == 0
    assert event == ThreadItemAddedEvent(item=context.workflow_item)

    # (task count after the event, expected workflow update), in stream order.
    expected_updates = [
        # First thought added
        (1, WorkflowTaskAdded(task=ThoughtTask(content="Think"), task_index=0)),
        # First thought delta
        (1, WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0)),
        # First thought done
        (1, WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0)),
        # Second thought added (not streamed)
        (2, WorkflowTaskAdded(task=ThoughtTask(content="Thinking 2"), task_index=1)),
    ]
    for task_count, update in expected_updates:
        event = await anext(stream)
        assert context.workflow_item is not None
        assert len(context.workflow_item.workflow.tasks) == task_count
        assert isinstance(event, ThreadItemUpdated)
        assert event == ThreadItemUpdated(
            item_id=context.workflow_item.id,
            update=update,
        )

    # Drain whatever remains until the stream is exhausted.
    async for _ in stream:
        pass
1285
1286
async def test_workflow_ends_on_message():
    """Check that a reasoning workflow is ended once a message item is streamed.

    The fake run emits a reasoning output item, one completed reasoning
    summary, and then an assistant message output item. The stream is expected
    to yield the workflow item, the thought task, and finally a
    ``ThreadItemDoneEvent`` for the workflow, after which
    ``context.workflow_item`` is ``None``.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # first thought
    reasoning_item_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(id="resp_1", summary=[], type="reasoning"),
            output_index=0,
            sequence_number=0,
        ),
    )
    thought_done = RawResponsesStreamEvent(
        type="raw_response_event",
        data=Mock(
            type="response.reasoning_summary_text.done",
            item_id="resp_1",
            summary_index=0,
            text="Thinking 1",
        ),
    )

    # not reasoning
    message_item_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )

    # Queue the events in the order the fake run should replay them.
    for queued_event in (reasoning_item_added, thought_done, message_item_added):
        result.add_event(queued_event)
    result.done()

    stream = stream_agent_response(context, result)

    # Workflow added
    workflow_added = await anext(stream)
    assert isinstance(workflow_added, ThreadItemAddedEvent)
    assert context.workflow_item is not None
    assert context.workflow_item.workflow.type == "reasoning"
    assert len(context.workflow_item.workflow.tasks) == 0
    assert workflow_added == ThreadItemAddedEvent(item=context.workflow_item)

    # First thought done
    thought_added = await anext(stream)
    assert context.workflow_item is not None
    assert len(context.workflow_item.workflow.tasks) == 1
    assert isinstance(thought_added, ThreadItemUpdated)
    expected_thought_added = ThreadItemUpdated(
        item_id=context.workflow_item.id,
        update=WorkflowTaskAdded(
            task=ThoughtTask(content="Thinking 1"),
            task_index=0,
        ),
    )
    assert thought_added == expected_thought_added

    # Workflow ended
    workflow_done = await anext(stream)
    assert isinstance(workflow_done, ThreadItemDoneEvent)
    assert workflow_done.item.type == "workflow"
    assert context.workflow_item is None
    # Summary and expanded are handled by the end_workflow method
    assert isinstance(workflow_done.item.workflow.summary, DurationSummary)
    assert workflow_done.item.workflow.expanded is False

    # Exhaust whatever the stream still has to yield; nothing further is checked.
    async for _ in stream:
        pass
1378
1379
async def test_existing_workflow_summary_not_overwritten_on_automatic_end():
    """Check that a preset workflow summary is kept when the workflow is ended.

    A custom workflow carrying ``CustomSummary(title="Test")`` is placed on
    the context before streaming. The only streamed event is an assistant
    message output item; the first yielded event is expected to be the
    ``ThreadItemDoneEvent`` for that workflow with its summary unchanged.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # A workflow that already has a summary before any event is streamed.
    preset_workflow = Workflow(
        type="custom", tasks=[], summary=CustomSummary(title="Test")
    )
    context.workflow_item = WorkflowItem(
        id="wf_1",
        created_at=datetime.now(),
        workflow=preset_workflow,
        thread_id=thread.id,
    )

    message_item_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )
    result.add_event(message_item_added)
    result.done()

    stream = stream_agent_response(context, result)

    workflow_done = await anext(stream)

    assert isinstance(workflow_done, ThreadItemDoneEvent)
    assert context.workflow_item is None
    assert workflow_done.item.type == "workflow"
    # Compare against a freshly built summary rather than the preset object.
    assert workflow_done.item.workflow.summary == CustomSummary(title="Test")

    # Exhaust whatever the stream still has to yield; nothing further is checked.
    async for _ in stream:
        pass
1425