openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.6.5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_agents.py

2035 lines · mode: code

1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from importlib import import_module
6from typing import cast
7from unittest.mock import AsyncMock, Mock
8
9import pytest
10from agents import (
11 Agent,
12 GuardrailFunctionOutput,
13 InputGuardrail,
14 InputGuardrailResult,
15 InputGuardrailTripwireTriggered,
16 OutputGuardrail,
17 OutputGuardrailResult,
18 OutputGuardrailTripwireTriggered,
19 RawResponsesStreamEvent,
20 RunContextWrapper,
21 RunItemStreamEvent,
22 Runner,
23 RunResultStreaming,
24 StreamEvent,
25 ToolCallItem,
26)
27from openai.types.responses import (
28 EasyInputMessageParam,
29 ResponseFileSearchToolCall,
30 ResponseFunctionToolCall,
31 ResponseInputContentParam,
32 ResponseInputTextParam,
33 ResponseOutputItemAddedEvent,
34 ResponseOutputItemDoneEvent,
35 ResponseOutputMessage,
36 ResponseReasoningItem,
37)
38from openai.types.responses.response_content_part_added_event import (
39 ResponseContentPartAddedEvent,
40)
41from openai.types.responses.response_file_search_tool_call import Result
42from openai.types.responses.response_input_item_param import Message
43from openai.types.responses.response_output_text import (
44 AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
45)
46from openai.types.responses.response_output_text import (
47 AnnotationFileCitation as ResponsesAnnotationFileCitation,
48)
49from openai.types.responses.response_output_text import (
50 AnnotationFilePath as ResponsesAnnotationFilePath,
51)
52from openai.types.responses.response_output_text import (
53 AnnotationURLCitation as ResponsesAnnotationURLCitation,
54)
55from openai.types.responses.response_output_text import (
56 ResponseOutputText,
57)
58from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
59from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
60
61from chatkit.agents import (
62 AgentContext,
63 ClientToolCall,
64 ResponseStreamConverter,
65 ThreadItemConverter,
66 accumulate_text,
67 simple_to_agent_input,
68 stream_agent_response,
69)
70from chatkit.types import (
71 Annotation,
72 AssistantMessageContent,
73 AssistantMessageContentPartAdded,
74 AssistantMessageContentPartAnnotationAdded,
75 AssistantMessageContentPartDone,
76 AssistantMessageContentPartTextDelta,
77 AssistantMessageItem,
78 Attachment,
79 ClientToolCallItem,
80 CustomSummary,
81 CustomTask,
82 DurationSummary,
83 FileSource,
84 GeneratedImage,
85 GeneratedImageItem,
86 GeneratedImageUpdated,
87 HiddenContextItem,
88 InferenceOptions,
89 Page,
90 StructuredInputAnswer,
91 StructuredInputFreeform,
92 StructuredInputItem,
93 StructuredInputMultipleChoice,
94 StructuredInputMultipleChoiceOption,
95 TaskItem,
96 ThoughtTask,
97 Thread,
98 ThreadItemAddedEvent,
99 ThreadItemDoneEvent,
100 ThreadItemUpdatedEvent,
101 ThreadStreamEvent,
102 URLSource,
103 UserMessageItem,
104 UserMessageTagContent,
105 UserMessageTextContent,
106 WidgetItem,
107 Workflow,
108 WorkflowItem,
109 WorkflowTaskAdded,
110 WorkflowTaskUpdated,
111)
112from chatkit.widgets import Card, Text
113
# Shared thread that every test below uses as its conversation container.
thread = Thread(id="123", title="Test", created_at=datetime.now(), items=Page())

# Store double shared by the tests:
# - generate_item_id derives an id from the item type;
# - load_thread_items resolves to an empty Page;
# - add_thread_item is an AsyncMock that simply records its calls.
mock_store = Mock(
    generate_item_id=lambda item_type, thread, context: f"{item_type}_id",
    load_thread_items=AsyncMock(return_value=Page()),
    add_thread_item=AsyncMock(),
)
120
121
def make_queue_complete_sentinel():
    """Return a new instance of the agents SDK ``QueueCompleteSentinel``.

    The class is looked up in several candidate modules, in order. This is
    presumably because its location differs between ``agents`` versions.
    Modules that fail to import are skipped, and the first module exposing the
    attribute wins.

    Raises:
        ImportError: if no candidate module provides ``QueueCompleteSentinel``.
    """
    candidates = (
        "agents.result",
        "agents.run_internal.run_steps",
        "agents._run_impl",
    )
    for candidate in candidates:
        try:
            mod = import_module(candidate)
        except ImportError:
            # Module absent in this installation; try the next location.
            continue

        sentinel_cls = getattr(mod, "QueueCompleteSentinel", None)
        if sentinel_cls is None:
            continue
        return sentinel_cls()

    raise ImportError("QueueCompleteSentinel not found in installed agents package")
138
139
class RunResult(RunResultStreaming):
    """Hand-driven ``RunResultStreaming`` used by the tests.

    No agent actually runs. Tests push stream events onto the result's event
    queue and then finish the stream. Before finishing, a test can optionally
    record a guardrail tripwire exception as the stored exception.
    """

    def add_event(self, event: StreamEvent):
        """Enqueue ``event`` on the result's event queue."""
        self._event_queue.put_nowait(event)

    def done(self):
        """Mark the run complete and enqueue the end-of-stream sentinel."""
        self.is_complete = True
        self._event_queue.put_nowait(make_queue_complete_sentinel())

    def throw_input_guardrails(self):
        """Store a tripped input-guardrail exception, then finish the stream."""
        tripped = GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)
        self._stored_exception = InputGuardrailTripwireTriggered(
            InputGuardrailResult(guardrail=Mock(spec=InputGuardrail), output=tripped)
        )
        self.done()

    def throw_output_guardrails(self):
        """Store a tripped output-guardrail exception, then finish the stream."""
        tripped = GuardrailFunctionOutput(output_info=None, tripwire_triggered=True)
        self._stored_exception = OutputGuardrailTripwireTriggered(
            OutputGuardrailResult(
                guardrail=Mock(spec=OutputGuardrail),
                output=tripped,
                agent=Mock(spec=Agent),
                agent_output=None,
            )
        )
        self.done()
175
176
def make_result() -> RunResult:
    """Build an idle ``RunResult`` with empty histories and fresh queues.

    The result starts with ``is_complete=False`` and no guardrail or run tasks.
    Tests feed its stream through ``add_event`` and end it with ``done``.
    """
    empty_histories = dict(
        input=[],
        new_items=[],
        raw_responses=[],
        input_guardrail_results=[],
        output_guardrail_results=[],
        tool_input_guardrail_results=[],
        tool_output_guardrail_results=[],
    )
    return RunResult(
        context_wrapper=Mock(spec=RunContextWrapper),
        current_agent=Agent(name="test"),
        current_turn=0,
        max_turns=10,
        final_output=None,
        trace=None,
        is_complete=False,
        _current_agent_output_schema=None,
        _event_queue=asyncio.Queue(),
        _input_guardrail_queue=asyncio.Queue(),
        _output_guardrails_task=None,
        _run_impl_task=None,
        _stored_exception=None,
        **empty_histories,
    )
200
201
async def all_events(
    events: AsyncIterator[ThreadStreamEvent],
) -> list[ThreadStreamEvent]:
    """Drain the async iterator ``events`` and return its items in order."""
    collected = []
    async for event in events:
        collected.append(event)
    return collected
206
207
async def test_returns_widget_item():
    """A complete widget streamed after a tool call yields one done event.

    The done event's item is a WidgetItem holding that widget.
    """
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    await ctx.stream_widget(Card(children=[Text(value="Hello, world!")]))
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 1
    [done_event] = emitted
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, WidgetItem)
    assert done_event.item.widget == Card(children=[Text(value="Hello, world!")])
230
231
async def test_returns_widget_item_generator():
    """Streaming a widget generator yields: added, a text-delta update, then done.

    The second yielded card only extends the value of the ``text`` component.
    The test expects that change as a ``widget.streaming_text.value_delta``
    update carrying the new characters.
    """
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem)))

    full_text = "Hello, world"

    def card_with_prefix(length: int) -> Card:
        # Card whose single text component shows the first `length` characters.
        return Card(children=[Text(id="text", value=full_text[:length])])

    async def widget_generator():
        for length in (0, 12):
            yield card_with_prefix(length)

    await ctx.stream_widget(widget_generator())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3
    added, updated, finished = emitted

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.streaming_text.value_delta"
    assert updated.update.component_id == "text"
    assert updated.update.delta == "Hello, world"

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
273
274
async def test_returns_widget_full_replace_generator():
    """A structurally different follow-up widget is sent as a full root update.

    The second card swaps the ``id="text"`` component for a keyed,
    non-streaming one. The test expects a ``widget.root.updated`` event
    carrying the new card, and a done event with that same card.
    """
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem)))

    def replacement_card() -> Card:
        # Fresh instance of the card that replaces the initial one.
        return Card(children=[Text(key="other text", value="World!", streaming=False)])

    async def widget_generator():
        yield Card(children=[Text(id="text", value="Hello!")])
        yield replacement_card()

    await ctx.stream_widget(widget_generator())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3
    added, updated, finished = emitted

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="Hello!")])

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == replacement_card()

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == replacement_card()
314
315
async def test_client_tool_call_uses_raw_function_call_id():
    """The emitted client tool call item takes its ids from the raw function call.

    Its ``id`` and ``call_id`` come from the ``ResponseFunctionToolCall``
    attached to the ``tool_called`` event.
    """
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
        client_tool_call=ClientToolCall(name="get_selection", arguments={}),
    )
    run = make_result()

    function_call = ResponseFunctionToolCall(
        id="fc_123",
        type="function_call",
        call_id="call_123",
        name="get_selection",
        arguments="{}",
        status="completed",
    )
    tool_called = RunItemStreamEvent(
        name="tool_called",
        item=ToolCallItem(agent=Agent("Assistant"), raw_item=function_call),
    )
    run.add_event(tool_called)
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 1
    [done_event] = emitted
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, ClientToolCallItem)
    assert done_event.item.id == "fc_123"
    assert done_event.item.call_id == "call_123"
356
357
async def test_accumulate_text():
    """``accumulate_text`` turns output-text deltas into successive Text snapshots.

    The expected sequence is:
    - the initial widget;
    - one snapshot per delta, holding the text accumulated so far;
    - a final copy with ``streaming=False`` after the stream completes.
    """

    def delta(text: str) -> RawResponsesStreamEvent:
        # Raw output-text delta for a fixed item and content index.
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta=text,
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )

    # Drive the stream by hand through the fake result. No real Runner run is
    # started: such a run would never be consumed by this test.
    result = make_result()
    result.add_event(delta("Hello, "))
    result.add_event(delta("world!"))
    result.done()

    events = [
        event
        async for event in accumulate_text(
            result.stream_events(), Text(key="text", value="", streaming=True)
        )
    ]
    assert events == [
        Text(key="text", value="", streaming=True),
        Text(key="text", value="Hello, ", streaming=True),
        Text(key="text", value="Hello, world!", streaming=True),
        Text(key="text", value="Hello, world!", streaming=False),
    ]
394
395
async def test_input_item_converter_quotes_last_user_message():
    """Only the most recent user message's ``quoted_text`` is forwarded.

    Both messages carry quoted text. The expected output is the two message
    bodies followed by a single extra message quoting the second one; the
    first message's quote ("Hi!") does not appear.
    """

    def user_message(text: str, quoted_text: str) -> UserMessageItem:
        # Plain-text user message with the given quoted text.
        return UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text=quoted_text,
            created_at=datetime.now(),
        )

    items = [
        user_message("Hello!", quoted_text="Hi!"),
        user_message("I'm well, thank you!", quoted_text="How are you doing?"),
    ]

    input_items = await simple_to_agent_input(items)
    assert len(input_items) == 3
    assert input_items[0] == {
        "content": [
            {
                "text": "Hello!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[1] == {
        "content": [
            {
                "text": "I'm well, thank you!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[2] == {
        "content": [
            {
                "text": "The user is referring to this in particular: \nHow are you doing?",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
455
456
async def test_input_item_converter_to_input_items_mixed():
    """User, assistant and widget items each convert to one message input.

    - User text becomes ``input_text`` parts.
    - Assistant content becomes ``output_text`` parts.
    - A widget becomes a user message describing it, without ``created_at``.
    """

    def user_message(text: str, quoted: str) -> UserMessageItem:
        # Plain-text user message carrying quoted text.
        return UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text=quoted,
            created_at=datetime.now(),
        )

    def user_text_input(text: str) -> dict:
        # Expected converted form of a plain-text user message.
        return {
            "content": [{"text": text, "type": "input_text"}],
            "role": "user",
            "type": "message",
        }

    thread_items = [
        user_message("Hello!", "Hi!"),
        user_message("I'm well, thank you!", "How are you doing?"),
        AssistantMessageItem(
            id="123",
            content=[
                AssistantMessageContent(text="How are you doing?"),
                AssistantMessageContent(text="Can't do that"),
            ],
            thread_id=thread.id,
            created_at=datetime.now(),
        ),
        WidgetItem(
            id="wd_123",
            widget=Card(children=[Text(value="Hello, world!")]),
            thread_id=thread.id,
            created_at=datetime.now(),
        ),
    ]

    converted = await simple_to_agent_input(thread_items)
    assert len(converted) == 4
    assert converted[0] == user_text_input("Hello!")
    assert converted[1] == user_text_input("I'm well, thank you!")
    assert converted[2] == {
        "content": [
            {
                "annotations": [],
                "text": assistant_text,
                "logprobs": None,
                "type": "output_text",
            }
            for assistant_text in ("How are you doing?", "Can't do that")
        ],
        "type": "message",
        "role": "assistant",
    }

    assert "type" in converted[3]
    widget_message = cast(EasyInputMessageParam, converted[3])
    assert widget_message.get("type") == "message"
    assert widget_message.get("role") == "user"
    widget_text = widget_message.get("content")[0]["text"]  # type: ignore
    expected_preamble = (
        "The following graphical UI widget (id: wd_123) was displayed to the user"
    )
    assert expected_preamble in widget_text
    assert "Hello, world!" in widget_text
    assert "created_at" not in widget_text
545
546
async def test_input_item_converter_user_input_with_tags():
    """A tag renders inline as ``@text``, and a second context message follows.

    The second message holds the fixed @-mention instructions and then the
    output of the overridden ``tag_to_message_content``.
    """

    class TagValueConverter(ThreadItemConverter):
        async def tag_to_message_content(self, tag):
            # Tag display text followed by its "key" data value.
            return ResponseInputTextParam(
                type="input_text", text=" ".join((tag.text, tag.data["key"]))
            )

    tagged_message = UserMessageItem(
        id="123",
        content=[
            UserMessageTagContent(
                text="Hello!", type="input_tag", id="hello", data={"key": "value"}
            )
        ],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    agent_input = await TagValueConverter().to_agent_input([tagged_message])

    mention_instructions = (
        "# User-provided context for @-mentions\n"
        "- When referencing resolved entities, use their canonical names **without** '@'.\n"
        "- The '@' form appears only in user text and should not be echoed."
    )

    assert len(agent_input) == 2
    inline_message, context_message = agent_input[0], agent_input[1]
    assert inline_message == {
        "content": [{"text": "@Hello!", "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
    assert context_message == {
        "content": [
            {"text": mention_instructions, "type": "input_text"},
            {"text": "Hello! value", "type": "input_text"},
        ],
        "role": "user",
        "type": "message",
    }
596
597
async def test_input_item_converter_user_input_with_tags_throws_by_default():
    """Without an override, converting a tagged message raises NotImplementedError.

    The missing override is ``tag_to_message_content``.
    """
    tag = UserMessageTagContent(text="Hello!", type="input_tag", id="hello", data={})
    tagged_message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input([tagged_message])
616
617
async def test_input_item_converter_generated_image_item():
    """A generated image becomes a user message with a caption and the image URL.

    The image is referenced by URL (no ``file_id``) with ``detail="auto"``.
    """
    image_item = GeneratedImageItem(
        id="img_item_1",
        thread_id=thread.id,
        created_at=datetime.now(),
        image=GeneratedImage(id="img_1", url="https://example.com/img.png"),
    )

    converted = await simple_to_agent_input([image_item])
    assert len(converted) == 1

    message = cast(dict, converted[0])
    assert message.get("type") == "message"
    assert message.get("role") == "user"

    parts = cast(list, message.get("content"))
    caption, image_part = parts[0], parts[1]

    assert caption.get("type") == "input_text"
    assert caption.get("text") == "The following image was generated by the agent."

    assert image_part.get("type") == "input_image"
    assert image_part.get("file_id") is None
    assert image_part.get("image_url") == "https://example.com/img.png"
    assert image_part.get("detail") == "auto"
642
643
async def test_input_item_converter_generated_image_item_without_image():
    """A generated-image item that has no image contributes no input items."""
    imageless = GeneratedImageItem(
        id="img_item_1", thread_id=thread.id, created_at=datetime.now()
    )

    converted = await simple_to_agent_input([imageless])

    assert converted == []
655
656
async def test_input_item_converter_for_hidden_context_with_string_content():
    """String hidden context is wrapped in ``<HiddenContext>`` by the default
    converter."""
    hidden = HiddenContextItem(
        id="123",
        content="User pressed the red button",
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    expected_text = (
        "Hidden context for the agent (not shown to the user):\n"
        "<HiddenContext>\n"
        "User pressed the red button\n"
        "</HiddenContext>"
    )

    converted = await simple_to_agent_input([hidden])

    assert len(converted) == 1
    assert converted[0] == {
        "content": [{"text": expected_text, "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
680
681
async def test_input_item_converter_for_hidden_context_with_non_string_content():
    """Non-string hidden context needs a custom converter.

    The default conversion raises NotImplementedError. A subclass that
    overrides ``hidden_context_to_input`` can render the payload itself.
    """
    hidden_items = [
        HiddenContextItem(
            id="123",
            content={"harry": "potter", "hermione": "granger"},
            thread_id=thread.id,
            created_at=datetime.now(),
        )
    ]

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input(hidden_items)

    class KeyListingConverter(ThreadItemConverter):
        async def hidden_context_to_input(self, item: HiddenContextItem):
            # Render only the dict's keys, comma-separated, inside a marker tag.
            keys = ",".join(item.content)
            marker = ResponseInputTextParam(
                type="input_text", text=f"<HIDDEN>{keys}</HIDDEN>"
            )
            return Message(type="message", content=[marker], role="user")

    converted = await KeyListingConverter().to_agent_input(hidden_items)

    assert len(converted) == 1
    assert converted[0] == {
        "content": [{"text": "<HIDDEN>harry,hermione</HIDDEN>", "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
721
722
async def test_input_item_converter_with_client_tool_call():
    """Task and client-tool-call items convert to a task note plus a call pair.

    The expected output is:
    - the user message;
    - a message describing the task;
    - a ``function_call`` with JSON-encoded arguments;
    - a ``function_call_output`` with JSON-encoded output, sharing the
      ``call_id`` of the function call.
    """
    user_turn = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Call a client tool call xyz")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    task = TaskItem(
        id="tsk_123",
        created_at=datetime.now(),
        task=CustomTask(title="Called xyx"),
        thread_id=thread.id,
    )
    pending_call = ClientToolCallItem(
        id="ctc_123",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
    )
    completed_call = ClientToolCallItem(
        id="ctc_123_done",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
        status="completed",
        output={"success": True},
    )

    converted = await simple_to_agent_input(
        [user_turn, task, pending_call, completed_call]
    )

    task_text = (
        "A message was displayed to the user that the following task was performed:\n"
        "<Task>\n"
        "Called xyx\n"
        "</Task>"
    )
    assert len(converted) == 4
    assert converted[0] == {
        "content": [{"text": "Call a client tool call xyz", "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
    assert converted[1] == {
        "content": [{"text": task_text, "type": "input_text"}],
        "type": "message",
        "role": "user",
    }
    assert converted[2] == {
        "type": "function_call",
        "name": "xyz",
        "arguments": json.dumps({"foo": "bar"}),
        "call_id": "ctc_123",
    }
    assert converted[3] == {
        "type": "function_call_output",
        "call_id": "ctc_123",
        "output": json.dumps({"success": True}),
    }
793
794
async def test_input_item_converter_with_structured_input():
    """An answered structured-input item is summarized as one user message.

    The message states the status, then lists each question with its answer;
    a skipped answer is shown as ``skipped``.
    """
    subject_question = StructuredInputMultipleChoice(
        id="subject",
        question="Which subject is this lesson for?",
        options=[
            StructuredInputMultipleChoiceOption(value="Math"),
            StructuredInputMultipleChoiceOption(value="Science"),
        ],
        answer=StructuredInputAnswer(values=["Math"]),
    )
    details_question = StructuredInputFreeform(
        id="details",
        question="Anything else to know?",
        answer=StructuredInputAnswer(skipped=True),
    )
    form = StructuredInputItem(
        id="si_123",
        thread_id=thread.id,
        created_at=datetime.now(),
        status="answered",
        inputs=[subject_question, details_question],
    )

    converted = await simple_to_agent_input([form])

    summary = (
        "A structured input request was displayed to the user with the "
        "following status: answered\n"
        "<StructuredInput>\n"
        "- Which subject is this lesson for?: Math\n"
        "- Anything else to know?: skipped\n"
        "</StructuredInput>"
    )
    assert len(converted) == 1
    assert converted[0] == {
        "content": [{"text": summary, "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
833
834
async def test_stream_agent_response_yields_context_events_without_streaming_events():
    """Events streamed through the context are forwarded even when the run is silent.

    After the context event is yielded, the stream must stay open until the
    run result is marked done. It then ends with StopAsyncIteration.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    widget_added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )
    await context.stream(widget_added)

    response_streamer = stream_agent_response(context, result)
    first = await response_streamer.__anext__()
    assert first.type == "thread.item.added"

    pending = asyncio.ensure_future(response_streamer.__anext__())
    # Let the task run until it blocks. Checking right after ensure_future is
    # vacuous, because the task has not been scheduled yet.
    await asyncio.sleep(0)
    assert not pending.done()

    result.done()

    # Completing the run must end the stream.
    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done()
869
870
async def test_stream_agent_response_maps_events():
    """Context events and mapped raw-response events are both yielded.

    The stream then stays open until the run is marked done, and ends with
    StopAsyncIteration.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    widget_added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )
    await context.stream(widget_added)
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta="Hello, world!",
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    response_streamer = stream_agent_response(context, result)
    first = await response_streamer.__anext__()
    second = await response_streamer.__anext__()

    # The relative order of the context event and the mapped raw event is not
    # asserted; only the set of event types is checked.
    assert {first.type, second.type} == {
        "thread.item.added",
        "thread.item.updated",
    }

    pending = asyncio.ensure_future(response_streamer.__anext__())
    # Let the task run until it blocks. Checking right after ensure_future is
    # vacuous, because the task has not been scheduled yet.
    await asyncio.sleep(0)
    assert not pending.done()

    result.done()

    # Completing the run must end the stream.
    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done()
923
924
@pytest.mark.parametrize(
    "raw_event,expected_event",
    [
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDeltaEvent(
                    type="response.output_text.delta",
                    delta="Hello, world!",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=0,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartTextDelta(
                    content_index=0,
                    delta="Hello, world!",
                ),
            ),
        ),
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseContentPartAddedEvent(
                    type="response.content_part.added",
                    part=ResponseOutputText(
                        type="output_text",
                        text="New content",
                        annotations=[],
                    ),
                    content_index=1,
                    item_id="123",
                    output_index=0,
                    sequence_number=1,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartAdded(
                    content_index=1,
                    content=AssistantMessageContent(text="New content", annotations=[]),
                ),
            ),
        ),
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDoneEvent(
                    type="response.output_text.done",
                    text="Final text",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=2,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartDone(
                    content_index=0,
                    content=AssistantMessageContent(
                        text="Final text",
                        annotations=[],
                    ),
                ),
            ),
        ),
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=ResponsesAnnotationFileCitation(
                        type="file_citation",
                        file_id="file_123",
                        filename="file.txt",
                        index=5,
                    ),
                    content_index=0,
                    item_id="123",
                    annotation_index=0,
                    output_index=0,
                    sequence_number=3,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartAnnotationAdded(
                    content_index=0,
                    annotation_index=0,
                    annotation=Annotation(
                        source=FileSource(filename="file.txt", title="file.txt"),
                        index=5,
                    ),
                ),
            ),
        ),
    ],
)
async def test_event_mapping(raw_event, expected_event):
    """A single raw Responses stream event maps to the given ChatKit event.

    A falsy ``expected_event`` means the raw event should produce no output.
    """
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(raw_event)
    run.done()

    emitted = await all_events(stream_agent_response(ctx, run))

    expected = [expected_event] if expected_event else []
    assert emitted == expected
1043
1044
async def test_stream_agent_response_emits_annotation_added_events():
    """Annotation-added events are converted and their indices compacted.

    The file citation with an empty filename produces no event. The container
    file citation and the URL citation follow with annotation indices 0 and 1.
    """
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    item_id = "item_123"

    raw_annotations = [
        ResponsesAnnotationFileCitation(
            type="file_citation",
            file_id="file_invalid",
            filename="",
            index=0,
        ),
        ResponsesAnnotationContainerFileCitation(
            type="container_file_citation",
            container_id="container_1",
            file_id="file_123",
            filename="container.txt",
            start_index=0,
            end_index=3,
        ),
        ResponsesAnnotationURLCitation(
            type="url_citation",
            url="https://example.com",
            title="Example",
            start_index=1,
            end_index=5,
        ),
    ]
    # Each annotation's position doubles as its raw annotation index and
    # sequence number.
    for position, raw_annotation in enumerate(raw_annotations):
        run.add_event(
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=raw_annotation,
                    content_index=0,
                    item_id=item_id,
                    annotation_index=position,
                    output_index=0,
                    sequence_number=position,
                ),
            )
        )
    run.done()

    emitted = await all_events(stream_agent_response(ctx, run))

    container_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=0,
            annotation=Annotation(
                source=FileSource(filename="container.txt", title="container.txt"),
                index=3,
            ),
        ),
    )
    url_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=1,
            annotation=Annotation(
                source=URLSource(
                    url="https://example.com",
                    title="Example",
                ),
                index=5,
            ),
        ),
    )
    assert emitted == [container_update, url_update]
1128
1129
async def test_stream_agent_response_annotation_added_normalizes_annotations():
    """Plain-dict annotations are handled like typed annotation objects.

    The file citation with an empty filename produces no event. The container
    and URL citations are converted, and their indices are compacted to 0 and 1
    despite the dropped first event.
    """
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    item_id = "item_123"

    raw_annotations = [
        # Invalid file citation (empty filename): expected to be dropped.
        {
            "type": "file_citation",
            "file_id": "file_invalid",
            "filename": "",
            "index": 0,
        },
        {
            "type": "container_file_citation",
            "container_id": "container_1",
            "file_id": "file_123",
            "filename": "container.txt",
            "start_index": 0,
            "end_index": 3,
        },
        {
            "type": "url_citation",
            "url": "https://example.com",
            "title": "Example",
            "start_index": 1,
            "end_index": 5,
        },
    ]
    # Each annotation's position doubles as its raw annotation index and
    # sequence number.
    for position, raw_annotation in enumerate(raw_annotations):
        run.add_event(
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=raw_annotation,
                    content_index=0,
                    item_id=item_id,
                    annotation_index=position,
                    output_index=0,
                    sequence_number=position,
                ),
            )
        )
    run.done()

    emitted = await all_events(stream_agent_response(ctx, run))

    container_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=0,
            annotation=Annotation(
                source=FileSource(filename="container.txt", title="container.txt"),
                index=3,
            ),
        ),
    )
    url_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=1,
            annotation=Annotation(
                source=URLSource(url="https://example.com", title="Example"),
                index=5,
            ),
        ),
    )
    assert emitted == [container_update, url_update]
1214
1215
async def test_custom_annotation_conversion_used_by_stream_agent_response():
    """Citation hooks on a custom converter decide the emitted annotations.

    The streamed ``file_citation`` annotation delta must go through
    ``file_citation_to_annotation``, and the ``url_citation`` found on the
    completed message must go through ``url_citation_to_annotation``.  The
    resulting events carry exactly what those overridden hooks return, and each
    hook runs once, in stream order.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    class CustomResponseStreamConverter(ResponseStreamConverter):
        """Logs every hook invocation and returns fixed, recognizable annotations."""

        def __init__(self):
            super().__init__()
            self.calls: list[str] = []

        async def file_citation_to_annotation(self, file_citation):
            self.calls.append("file_citation")
            report = FileSource(
                filename="report.pdf",
                title="Usage Report",
                description="Usage report for the month of January",
            )
            return Annotation(source=report, index=111)

        async def url_citation_to_annotation(self, url_citation):
            self.calls.append("url_citation")
            link = URLSource(url="https://custom.example/url", title="Custom")
            return Annotation(source=link, index=222)

    converter = CustomResponseStreamConverter()

    # Incremental annotation on message "m_1" -> handled by the file_citation hook.
    annotation_added = Mock(
        type="response.output_text.annotation.added",
        annotation=ResponsesAnnotationFileCitation(
            type="file_citation",
            file_id="file_123",
            filename="file.txt",
            index=0,
        ),
        content_index=0,
        item_id="m_1",
        annotation_index=0,
        output_index=0,
        sequence_number=0,
    )
    # Finished message whose only text part has a URL citation -> url_citation hook.
    message_done = ResponseOutputItemDoneEvent(
        type="response.output_item.done",
        item=ResponseOutputMessage(
            id="m_1",
            content=[
                ResponseOutputText(
                    annotations=[
                        ResponsesAnnotationURLCitation(
                            type="url_citation",
                            url="https://example.com",
                            title="Example",
                            start_index=0,
                            end_index=7,
                        )
                    ],
                    text="Hello!",
                    type="output_text",
                )
            ],
            role="assistant",
            status="completed",
            type="message",
        ),
        output_index=0,
        sequence_number=1,
    )
    for payload in (annotation_added, message_done):
        result.add_event(RawResponsesStreamEvent(type="raw_response_event", data=payload))
    result.done()

    stream = stream_agent_response(context, result, converter=converter)
    events = await all_events(stream)
    assert len(events) == 2
    annotation_event, done_event = events

    # The streamed annotation is exactly what the file_citation hook returned.
    assert isinstance(annotation_event, ThreadItemUpdatedEvent)
    assert annotation_event.item_id == "m_1"
    assert annotation_event.update == AssistantMessageContentPartAnnotationAdded(
        content_index=0,
        annotation_index=0,
        annotation=Annotation(
            source=FileSource(
                filename="report.pdf",
                title="Usage Report",
                description="Usage report for the month of January",
            ),
            index=111,
        ),
    )

    # The final message's annotations are exactly what the url_citation hook returned.
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, AssistantMessageItem)
    assert done_event.item.id == "m_1"
    assert done_event.item.content == [
        AssistantMessageContent(
            text="Hello!",
            annotations=[
                Annotation(
                    source=URLSource(url="https://custom.example/url", title="Custom"),
                    index=222,
                )
            ],
        )
    ]
    assert converter.calls == ["file_citation", "url_citation"]
1335
1336
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
    """Already-streamed items are removed when a guardrail tripwire fires.

    Three items are queued:
    - "1" comes from the raw model stream.
    - "2" and "3" are pushed directly through ``context.stream``.

    After the first three events have been consumed, an input or output
    guardrail is tripped.  ``stream_agent_response`` must then raise the
    tripwire exception.  The events it yields before doing so must include a
    ``thread.item.removed`` event for each of the three items.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[], type="output_text", text="Hello, world!"
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    await context.stream(
        ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="2",
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
    )

    await context.stream(
        ThreadItemDoneEvent(
            item=WidgetItem(
                id="3",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(id="text", value="Hello, world!")]),
            ),
        )
    )

    iterator = stream_agent_response(context, result)

    # Consume the first three events before tripping the guardrail, so that
    # the items queued above have (presumably) been surfaced to the client and
    # must subsequently be removed.
    events = []
    async for event in iterator:
        events.append(event)
        if len(events) == 3:
            break

    if throw_guardrail == "input":
        result.throw_input_guardrails()
    else:
        result.throw_output_guardrails()

    # pytest.raises fails with "DID NOT RAISE" if the stream finishes normally.
    # Any other exception type propagates with its real traceback instead of
    # being swallowed and re-reported.  Events yielded before the tripwire
    # (the removals) are still collected into `events`.
    with pytest.raises(
        (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered)
    ):
        async for event in iterator:
            events.append(event)

    deleted_item_ids = {
        event.item_id for event in events if event.type == "thread.item.removed"
    }
    assert deleted_item_ids == {"1", "2", "3"}
1415
1416
async def test_stream_agent_response_assistant_message_content_types():
    """A completed assistant message is converted into one AssistantMessageItem.

    A single text part carries every Responses annotation type:
    - File, container-file and URL citations are converted.  Each is anchored
      at the citation's ``index`` or ``end_index``.
    - The ``file_path`` annotation is dropped.

    A second, unannotated text part converts with an empty annotation list.
    The preceding ``file_search_call`` item yields no thread event.
    """
    # Built once, up front: the previous version also constructed an unused
    # AgentContext here and discarded it before building the real one.
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # A completed file_search_call item; it should not surface as a thread event.
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseFileSearchToolCall(
                    id="fs_0",
                    queries=["Hello, world!"],
                    status="completed",
                    type="file_search_call",
                    results=[
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, world!",
                            score=1.0,
                        ),
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, friends!",
                            score=0.5,
                        ),
                    ],
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    # A completed assistant message: one text part carrying every annotation
    # type, followed by a text part with no annotations.
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[
                                ResponsesAnnotationFileCitation(
                                    type="file_citation",
                                    file_id="f_123",
                                    index=0,
                                    filename="test.txt",
                                ),
                                ResponsesAnnotationContainerFileCitation(
                                    type="container_file_citation",
                                    container_id="container_1",
                                    file_id="f_456",
                                    filename="container.txt",
                                    start_index=0,
                                    end_index=3,
                                ),
                                ResponsesAnnotationURLCitation(
                                    type="url_citation",
                                    url="https://www.google.com",
                                    title="Google",
                                    start_index=0,
                                    end_index=10,
                                ),
                                ResponsesAnnotationFilePath(
                                    type="file_path",
                                    file_id="123",
                                    index=0,
                                ),
                            ],
                            text="Hello, world!",
                            type="output_text",
                        ),
                        ResponseOutputText(
                            annotations=[],
                            text="Can't do that",
                            type="output_text",
                        ),
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    result.done()

    events = await all_events(stream_agent_response(context, result))
    # Only the message produces an event; the file_search_call is not surfaced.
    assert len(events) == 1
    assert isinstance(events[0], ThreadItemDoneEvent)
    message = events[0].item
    assert isinstance(message, AssistantMessageItem)
    assert message.content == [
        AssistantMessageContent(
            # Three of the four input annotations survive; file_path is dropped.
            annotations=[
                Annotation(
                    source=FileSource(
                        filename="test.txt",
                        title="test.txt",
                    ),
                    index=0,
                ),
                Annotation(
                    source=FileSource(
                        filename="container.txt",
                        title="container.txt",
                    ),
                    index=3,
                ),
                Annotation(
                    source=URLSource(
                        url="https://www.google.com",
                        title="Google",
                    ),
                    index=10,
                ),
            ],
            text="Hello, world!",
        ),
        AssistantMessageContent(text="Can't do that", annotations=[]),
    ]
    assert message.id == "1"
1549
1550
async def test_stream_agent_response_image_generation_events():
    """An image_generation_call is surfaced as a GeneratedImageItem.

    ``response.output_item.added`` yields an added item with no image yet.  The
    matching ``response.output_item.done`` yields the finished item.  Its image
    URL is the call's base64 ``result`` wrapped in a PNG data URL by the
    default converter.  No further events follow.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    image_call_added = Mock(
        type="response.output_item.added",
        item=Mock(type="image_generation_call", id="img_call_1"),
        output_index=0,
        sequence_number=0,
    )
    image_call_done = Mock(
        type="response.output_item.done",
        item=Mock(type="image_generation_call", id="img_call_1", result="dGVzdA=="),
        output_index=0,
        sequence_number=1,
    )
    for payload in (image_call_added, image_call_done):
        result.add_event(RawResponsesStreamEvent(type="raw_response_event", data=payload))
    result.done()

    stream = stream_agent_response(context, result)

    # output_item.added -> an empty GeneratedImageItem.
    added = await anext(stream)
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, GeneratedImageItem)
    assert added.item.type == "generated_image"
    # NOTE: the thread item id is "message_id", not the call id; presumably
    # this is the id generated by the test store.
    assert added.item.id == "message_id"
    assert added.item.image is None

    # output_item.done -> the same item, now holding the image as a data URL.
    finished = await anext(stream)
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, GeneratedImageItem)
    assert finished.item.id == added.item.id
    assert finished.item.image == GeneratedImage(
        id="img_call_1", url="data:image/png;base64,dGVzdA=="
    )

    # The stream is exhausted after the done event.
    with pytest.raises(StopAsyncIteration):
        await anext(stream)
1601
1602
async def test_stream_agent_response_image_generation_events_with_custom_converter():
    """A converter's ``base64_image_to_url`` controls the final image URL.

    The finished image must be converted exactly once, with the call id, the
    base64 payload and no partial index.  The item's image URL is whatever the
    override returned.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    image_call_added = Mock(
        type="response.output_item.added",
        item=Mock(type="image_generation_call", id="img_call_1"),
        output_index=0,
        sequence_number=0,
    )
    image_call_done = Mock(
        type="response.output_item.done",
        item=Mock(type="image_generation_call", id="img_call_1", result="dGVzdA=="),
        output_index=0,
        sequence_number=1,
    )
    for payload in (image_call_added, image_call_done):
        result.add_event(RawResponsesStreamEvent(type="raw_response_event", data=payload))
    result.done()

    class CustomResponseStreamConverter(ResponseStreamConverter):
        """Records each conversion request and returns a deterministic https URL."""

        def __init__(self):
            super().__init__()
            self.calls: list[tuple[str, str, int | None]] = []

        async def base64_image_to_url(
            self,
            image_id: str,
            base64_image: str,
            partial_image_index: int | None = None,
        ) -> str:
            self.calls.append((image_id, base64_image, partial_image_index))
            return f"https://example.com/{image_id}"

    converter = CustomResponseStreamConverter()
    stream = stream_agent_response(context, result, converter=converter)

    added = await anext(stream)
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, GeneratedImageItem)
    assert added.item.image is None

    finished = await anext(stream)
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, GeneratedImageItem)
    # Exactly one conversion, for the final (non-partial) image.
    assert converter.calls == [("img_call_1", "dGVzdA==", None)]
    assert finished.item.image == GeneratedImage(
        id="img_call_1", url="https://example.com/img_call_1"
    )
    with pytest.raises(StopAsyncIteration):
        await anext(stream)
1665
1666
async def test_stream_agent_response_image_generation_partial_progress():
    """Partial image frames surface as GeneratedImageUpdated progress updates.

    The converter is configured with ``partial_images=3``.  A partial frame at
    ``partial_image_index=1`` must produce an update reporting progress 1/3.
    That update carries the frame as a PNG data URL and sits between the
    added and done events for the image item.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    def raw(payload):
        return RawResponsesStreamEvent(type="raw_response_event", data=payload)

    result.add_event(
        raw(
            Mock(
                type="response.output_item.added",
                item=Mock(type="image_generation_call", id="img_call_1"),
                output_index=0,
                sequence_number=0,
            )
        )
    )
    result.add_event(
        raw(
            Mock(
                type="response.image_generation_call.partial_image",
                partial_image_b64="dGVzdA==",
                partial_image_index=1,
                item_id="img_call_1",
                output_index=0,
                sequence_number=1,
            )
        )
    )
    result.add_event(
        raw(
            Mock(
                type="response.output_item.done",
                item=Mock(
                    type="image_generation_call", id="img_call_1", result="dGVzdA=="
                ),
                output_index=0,
                sequence_number=2,
            )
        )
    )
    result.done()

    events = await all_events(
        stream_agent_response(
            context, result, converter=ResponseStreamConverter(partial_images=3)
        )
    )

    assert len(events) == 3
    added_event, partial_event, done_event = events
    # Both the partial frame and the final image use the same payload here.
    expected_image = GeneratedImage(
        id="img_call_1", url="data:image/png;base64,dGVzdA=="
    )

    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, GeneratedImageItem)

    assert isinstance(partial_event, ThreadItemUpdatedEvent)
    assert isinstance(partial_event.update, GeneratedImageUpdated)
    assert partial_event.update.progress == pytest.approx(1 / 3)
    assert partial_event.update.image == expected_image

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, GeneratedImageItem)
    assert done_event.item.image == expected_image
1735
1736
async def test_workflow_streams_first_thought():
    """Only the first reasoning thought is streamed delta-by-delta.

    A reasoning item first opens a workflow with no tasks.  The first summary
    part gets a task on its first delta, which is then updated on the next
    delta and again on ``done``.  The second summary part is not streamed: its
    task is added once, already carrying its full text.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    def summary_delta(summary_index, delta):
        return Mock(
            type="response.reasoning_summary_text.delta",
            item_id="resp_1",
            summary_index=summary_index,
            delta=delta,
        )

    def summary_done(summary_index, text):
        return Mock(
            type="response.reasoning_summary_text.done",
            item_id="resp_1",
            summary_index=summary_index,
            text=text,
        )

    payloads = [
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(id="resp_1", summary=[], type="reasoning"),
            output_index=0,
            sequence_number=0,
        ),
        # first thought
        summary_delta(0, "Think"),
        summary_delta(0, "ing 1"),
        summary_done(0, "Thinking 1"),
        # second thought
        summary_delta(1, "Think"),
        summary_delta(1, "ing 2"),
        summary_done(1, "Thinking 2"),
    ]
    for payload in payloads:
        result.add_event(RawResponsesStreamEvent(type="raw_response_event", data=payload))

    result.done()
    stream = stream_agent_response(context, result)

    # Workflow item added, with no tasks yet.
    event = await anext(stream)
    assert isinstance(event, ThreadItemAddedEvent)
    assert context.workflow_item is not None
    assert context.workflow_item.workflow.type == "reasoning"
    assert len(context.workflow_item.workflow.tasks) == 0
    assert event == ThreadItemAddedEvent(item=context.workflow_item)

    # (update type, thought text, task index, number of tasks on the workflow)
    expected_updates = [
        (WorkflowTaskAdded, "Think", 0, 1),  # first thought, first delta
        (WorkflowTaskUpdated, "Thinking 1", 0, 1),  # first thought, second delta
        (WorkflowTaskUpdated, "Thinking 1", 0, 1),  # first thought done
        (WorkflowTaskAdded, "Thinking 2", 1, 2),  # second thought, not streamed
    ]
    for update_type, content, task_index, task_count in expected_updates:
        event = await anext(stream)
        workflow_item = context.workflow_item
        assert workflow_item is not None
        assert len(workflow_item.workflow.tasks) == task_count
        assert isinstance(event, ThreadItemUpdatedEvent)
        assert event == ThreadItemUpdatedEvent(
            item_id=workflow_item.id,
            update=update_type(
                task=ThoughtTask(content=content),
                task_index=task_index,
            ),
        )

    # Drain whatever remains so the generator runs to completion.
    async for _ in stream:
        pass
1896
1897
async def test_workflow_ends_on_message():
    """A non-reasoning output item closes the open reasoning workflow.

    A single reasoning thought is recorded as a task.  When an assistant
    message item is then added, the workflow item is emitted as done, with a
    DurationSummary and ``expanded`` set to False.  ``context.workflow_item``
    is also cleared.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    def raw(payload):
        return RawResponsesStreamEvent(type="raw_response_event", data=payload)

    payloads = [
        # First thought, delivered as a single `done` with no preceding deltas.
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(id="resp_1", summary=[], type="reasoning"),
            output_index=0,
            sequence_number=0,
        ),
        Mock(
            type="response.reasoning_summary_text.done",
            item_id="resp_1",
            summary_index=0,
            text="Thinking 1",
        ),
        # Not reasoning: a regular assistant message.
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    ]
    for payload in payloads:
        result.add_event(raw(payload))

    result.done()
    stream = stream_agent_response(context, result)

    # 1. Workflow item added, still empty.
    added = await anext(stream)
    assert isinstance(added, ThreadItemAddedEvent)
    workflow_item = context.workflow_item
    assert workflow_item is not None
    assert workflow_item.workflow.type == "reasoning"
    assert len(workflow_item.workflow.tasks) == 0
    assert added == ThreadItemAddedEvent(item=workflow_item)

    # 2. The thought lands as a single task.
    task_added = await anext(stream)
    workflow_item = context.workflow_item
    assert workflow_item is not None
    assert len(workflow_item.workflow.tasks) == 1
    assert isinstance(task_added, ThreadItemUpdatedEvent)
    assert task_added == ThreadItemUpdatedEvent(
        item_id=workflow_item.id,
        update=WorkflowTaskAdded(
            task=ThoughtTask(content="Thinking 1"),
            task_index=0,
        ),
    )

    # 3. The assistant message ends the workflow.
    ended = await anext(stream)
    assert isinstance(ended, ThreadItemDoneEvent)
    assert ended.item.type == "workflow"
    assert context.workflow_item is None
    # Summary and expanded are handled by the end_workflow method.
    assert isinstance(ended.item.workflow.summary, DurationSummary)
    assert ended.item.workflow.expanded is False

    async for _ in stream:
        pass
1989
1990
async def test_existing_workflow_summary_not_overwritten_on_automatic_end():
    """Automatically ending a workflow keeps a summary it already had.

    A custom workflow carrying a CustomSummary is open when an assistant
    message arrives.  The resulting done event must still carry that
    CustomSummary, and ``context.workflow_item`` must be cleared.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    context.workflow_item = WorkflowItem(
        id="wf_1",
        created_at=datetime.now(),
        workflow=Workflow(type="custom", tasks=[], summary=CustomSummary(title="Test")),
        thread_id=thread.id,
    )

    message_added = ResponseOutputItemAddedEvent(
        type="response.output_item.added",
        item=ResponseOutputMessage(
            id="m_1",
            content=[],
            role="assistant",
            status="in_progress",
            type="message",
        ),
        output_index=0,
        sequence_number=0,
    )
    result.add_event(RawResponsesStreamEvent(type="raw_response_event", data=message_added))

    result.done()
    stream = stream_agent_response(context, result)

    done_event = await anext(stream)

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert context.workflow_item is None
    assert done_event.item.type == "workflow"
    # Compare against a freshly built summary (not the instance given to the
    # workflow) so an in-place mutation of the original would still be caught.
    assert done_event.item.workflow.summary == CustomSummary(title="Test")

    # Drain the remainder so the generator runs to completion.
    async for _ in stream:
        pass
2036