openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.6.4

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_agents.py

1991lines · modecode

1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from importlib import import_module
6from typing import cast
7from unittest.mock import AsyncMock, Mock
8
9import pytest
10from agents import (
11 Agent,
12 GuardrailFunctionOutput,
13 InputGuardrail,
14 InputGuardrailResult,
15 InputGuardrailTripwireTriggered,
16 OutputGuardrail,
17 OutputGuardrailResult,
18 OutputGuardrailTripwireTriggered,
19 RawResponsesStreamEvent,
20 RunContextWrapper,
21 RunItemStreamEvent,
22 Runner,
23 RunResultStreaming,
24 StreamEvent,
25 ToolCallItem,
26)
27from openai.types.responses import (
28 EasyInputMessageParam,
29 ResponseFileSearchToolCall,
30 ResponseInputContentParam,
31 ResponseInputTextParam,
32 ResponseOutputItemAddedEvent,
33 ResponseOutputItemDoneEvent,
34 ResponseOutputMessage,
35 ResponseReasoningItem,
36)
37from openai.types.responses.response_content_part_added_event import (
38 ResponseContentPartAddedEvent,
39)
40from openai.types.responses.response_file_search_tool_call import Result
41from openai.types.responses.response_input_item_param import Message
42from openai.types.responses.response_output_text import (
43 AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
44)
45from openai.types.responses.response_output_text import (
46 AnnotationFileCitation as ResponsesAnnotationFileCitation,
47)
48from openai.types.responses.response_output_text import (
49 AnnotationFilePath as ResponsesAnnotationFilePath,
50)
51from openai.types.responses.response_output_text import (
52 AnnotationURLCitation as ResponsesAnnotationURLCitation,
53)
54from openai.types.responses.response_output_text import (
55 ResponseOutputText,
56)
57from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
58from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
59
60from chatkit.agents import (
61 AgentContext,
62 ResponseStreamConverter,
63 ThreadItemConverter,
64 accumulate_text,
65 simple_to_agent_input,
66 stream_agent_response,
67)
68from chatkit.types import (
69 Annotation,
70 AssistantMessageContent,
71 AssistantMessageContentPartAdded,
72 AssistantMessageContentPartAnnotationAdded,
73 AssistantMessageContentPartDone,
74 AssistantMessageContentPartTextDelta,
75 AssistantMessageItem,
76 Attachment,
77 ClientToolCallItem,
78 CustomSummary,
79 CustomTask,
80 DurationSummary,
81 FileSource,
82 GeneratedImage,
83 GeneratedImageItem,
84 GeneratedImageUpdated,
85 HiddenContextItem,
86 InferenceOptions,
87 Page,
88 StructuredInputAnswer,
89 StructuredInputFreeform,
90 StructuredInputItem,
91 StructuredInputMultipleChoice,
92 StructuredInputMultipleChoiceOption,
93 TaskItem,
94 ThoughtTask,
95 Thread,
96 ThreadItemAddedEvent,
97 ThreadItemDoneEvent,
98 ThreadItemUpdatedEvent,
99 ThreadStreamEvent,
100 URLSource,
101 UserMessageItem,
102 UserMessageTagContent,
103 UserMessageTextContent,
104 WidgetItem,
105 Workflow,
106 WorkflowItem,
107 WorkflowTaskAdded,
108 WorkflowTaskUpdated,
109)
110from chatkit.widgets import Card, Text
111
# Thread fixture shared by every test in this module.
thread = Thread(id="123", title="Test", created_at=datetime.now(), items=Page())

# Store double: item ids are derived from the item type, loading thread items
# returns an empty page, and adding an item is an awaitable no-op.
mock_store = Mock(
    generate_item_id=lambda item_type, thread, context: f"{item_type}_id",
    load_thread_items=AsyncMock(return_value=Page()),
    add_thread_item=AsyncMock(),
)
118
119
def make_queue_complete_sentinel():
    """Return a new ``QueueCompleteSentinel`` instance from the agents package.

    Several candidate modules are probed in order (presumably because the
    class lives in different places across agents versions -- confirm against
    the supported version range). A candidate is skipped when it cannot be
    imported or does not expose the class.

    Raises:
        ImportError: if none of the candidates provides the class.
    """
    candidate_modules = (
        "agents.result",
        "agents.run_internal.run_steps",
        "agents._run_impl",
    )

    for dotted_path in candidate_modules:
        try:
            candidate = import_module(dotted_path)
        except ImportError:
            continue

        sentinel_cls = getattr(candidate, "QueueCompleteSentinel", None)
        if sentinel_cls is None:
            continue
        return sentinel_cls()

    raise ImportError("QueueCompleteSentinel not found in installed agents package")
136
137
class RunResult(RunResultStreaming):
    """Streaming run result whose event queue is fed by hand from tests."""

    def add_event(self, event: StreamEvent):
        """Put ``event`` on the run's event queue without waiting."""
        self._event_queue.put_nowait(event)

    def done(self):
        """Mark the run complete and enqueue the queue-complete sentinel."""
        self._finish_test_stream()

    def throw_input_guardrails(self):
        """Store a tripped input-guardrail exception, then complete the run."""
        tripped_output = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = InputGuardrailResult(
            guardrail=Mock(spec=InputGuardrail),
            output=tripped_output,
        )
        self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
        self._finish_test_stream()

    def throw_output_guardrails(self):
        """Store a tripped output-guardrail exception, then complete the run."""
        tripped_output = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = OutputGuardrailResult(
            guardrail=Mock(spec=OutputGuardrail),
            output=tripped_output,
            agent=Mock(spec=Agent),
            agent_output=None,
        )
        self._stored_exception = OutputGuardrailTripwireTriggered(guardrail_result)
        self._finish_test_stream()

    def _finish_test_stream(self):
        """Set ``is_complete`` and push the sentinel that ends the event queue."""
        self.is_complete = True
        self._event_queue.put_nowait(make_queue_complete_sentinel())
173
174
def make_result() -> RunResult:
    """Build a fresh, not-yet-complete ``RunResult`` with empty queues and lists."""
    run_fields = {
        # Run inputs and accumulated outputs.
        "context_wrapper": Mock(spec=RunContextWrapper),
        "input": [],
        "new_items": [],
        "raw_responses": [],
        "final_output": None,
        # Agent / turn bookkeeping.
        "current_agent": Agent(name="test"),
        "current_turn": 0,
        "max_turns": 10,
        "_current_agent_output_schema": None,
        "trace": None,
        "is_complete": False,
        # Guardrail results.
        "input_guardrail_results": [],
        "output_guardrail_results": [],
        "tool_input_guardrail_results": [],
        "tool_output_guardrail_results": [],
        # Streaming internals: fresh queues, no background tasks, no error.
        "_event_queue": asyncio.Queue(),
        "_input_guardrail_queue": asyncio.Queue(),
        "_output_guardrails_task": None,
        "_run_impl_task": None,
        "_stored_exception": None,
    }
    return RunResult(**run_fields)
198
199
async def all_events(
    events: AsyncIterator[ThreadStreamEvent],
) -> list[ThreadStreamEvent]:
    """Drain the async iterator ``events`` and return its items as a list."""
    collected: list[ThreadStreamEvent] = []
    async for item in events:
        collected.append(item)
    return collected
204
205
async def test_returns_widget_item():
    """A widget streamed as a single value yields exactly one done event."""
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)
    await agent_context.stream_widget(Card(children=[Text(value="Hello, world!")]))
    run.done()

    stream = stream_agent_response(
        context=agent_context,
        result=run,
    )
    events = await all_events(stream)

    assert len(events) == 1
    only_event = events[0]
    assert isinstance(only_event, ThreadItemDoneEvent)
    assert isinstance(only_event.item, WidgetItem)
    assert only_event.item.widget == Card(children=[Text(value="Hello, world!")])
228
229
async def test_returns_widget_item_generator():
    """A widget generator with a growing text yields added, delta and done events."""
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem)))

    def render_widget(i: int) -> Card:
        # Card whose text is the first ``i`` characters of the greeting.
        visible = "Hello, world"[:i]
        return Card(children=[Text(id="text", value=visible)])

    async def widget_generator():
        for prefix_length in (0, 12):
            yield render_widget(prefix_length)

    await agent_context.stream_widget(widget_generator())
    run.done()

    stream = stream_agent_response(
        context=agent_context,
        result=run,
    )
    events = await all_events(stream)

    assert len(events) == 3

    added = events[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    updated = events[1]
    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.streaming_text.value_delta"
    assert updated.update.component_id == "text"
    assert updated.update.delta == "Hello, world"

    finished = events[2]
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
271
272
async def test_returns_widget_full_replace_generator():
    """A generator that yields a structurally different card emits a root update."""
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem)))

    async def widget_generator():
        yield Card(children=[Text(id="text", value="Hello!")])
        yield Card(children=[Text(key="other text", value="World!", streaming=False)])

    await agent_context.stream_widget(widget_generator())
    run.done()

    stream = stream_agent_response(
        context=agent_context,
        result=run,
    )
    events = await all_events(stream)

    assert len(events) == 3

    added = events[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="Hello!")])

    replaced = events[1]
    assert isinstance(replaced, ThreadItemUpdatedEvent)
    assert replaced.update.type == "widget.root.updated"
    assert replaced.update.widget == Card(
        children=[Text(key="other text", value="World!", streaming=False)]
    )

    finished = events[2]
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(key="other text", value="World!", streaming=False)]
    )
312
313
async def test_accumulate_text():
    """``accumulate_text`` yields the growing text widget, then a final non-streaming one."""

    def delta(text: str) -> RawResponsesStreamEvent:
        # Wrap ``text`` in a raw output-text delta event.
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta=text,
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )

    # Only the hand-fed result is used. The previous version also called
    # Runner.run_streamed() here and immediately discarded its result, which
    # started a real agent run that nothing ever awaited.
    result = make_result()
    result.add_event(delta("Hello, "))
    result.add_event(delta("world!"))

    result.done()

    events = [
        event
        async for event in accumulate_text(
            result.stream_events(), Text(key="text", value="", streaming=True)
        )
    ]
    assert events == [
        Text(key="text", value="", streaming=True),
        Text(key="text", value="Hello, ", streaming=True),
        Text(key="text", value="Hello, world!", streaming=True),
        Text(key="text", value="Hello, world!", streaming=False),
    ]
350
351
async def test_input_item_converter_quotes_last_user_message():
    """Only the last user message's ``quoted_text`` is turned into an extra input message."""
    items = [
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="Hello!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="Hi!",
            created_at=datetime.now(),
        ),
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="I'm well, thank you!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="How are you doing?",
            created_at=datetime.now(),
        ),
    ]

    # The unused local ``throw_exception`` coroutine that used to live here was
    # never passed to anything, so it has been removed.
    input_items = await simple_to_agent_input(items)
    assert len(input_items) == 3
    assert input_items[0] == {
        "content": [
            {
                "text": "Hello!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[1] == {
        "content": [
            {
                "text": "I'm well, thank you!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    # The first message's quote ("Hi!") is not expected anywhere in the output.
    assert input_items[2] == {
        "content": [
            {
                "text": "The user is referring to this in particular: \nHow are you doing?",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
411
412
async def test_input_item_converter_to_input_items_mixed():
    """User, assistant and widget items each convert to one agent input item."""

    def expected_user_message(text: str) -> dict:
        # Shape of a plain user text message in the agent input.
        return {
            "content": [{"text": text, "type": "input_text"}],
            "role": "user",
            "type": "message",
        }

    def expected_output_text(text: str) -> dict:
        # Shape of one assistant output-text part in the agent input.
        return {
            "annotations": [],
            "text": text,
            "logprobs": None,
            "type": "output_text",
        }

    first_user_message = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    second_user_message = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="I'm well, thank you!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="How are you doing?",
        created_at=datetime.now(),
    )
    assistant_message = AssistantMessageItem(
        id="123",
        content=[
            AssistantMessageContent(text="How are you doing?"),
            AssistantMessageContent(text="Can't do that"),
        ],
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    widget = WidgetItem(
        id="wd_123",
        widget=Card(children=[Text(value="Hello, world!")]),
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    items = [first_user_message, second_user_message, assistant_message, widget]

    input_items = await simple_to_agent_input(items)

    assert len(input_items) == 4
    assert input_items[0] == expected_user_message("Hello!")
    assert input_items[1] == expected_user_message("I'm well, thank you!")
    assert input_items[2] == {
        "content": [
            expected_output_text("How are you doing?"),
            expected_output_text("Can't do that"),
        ],
        "type": "message",
        "role": "assistant",
    }

    # The widget is described to the model as a user message, without timestamps.
    assert "type" in input_items[3]
    widget_item = cast(EasyInputMessageParam, input_items[3])
    assert widget_item.get("type") == "message"
    assert widget_item.get("role") == "user"
    text = widget_item.get("content")[0]["text"]  # type: ignore
    assert (
        "The following graphical UI widget (id: wd_123) was displayed to the user"
        in text
    )
    assert "Hello, world!" in text
    assert "created_at" not in text
501
502
async def test_input_item_converter_user_input_with_tags():
    """A converter overriding ``tag_to_message_content`` gets an extra @-mention context message."""

    class MyThreadItemConverter(ThreadItemConverter):
        async def tag_to_message_content(self, tag):
            # Render the tag as its text followed by the value stored under "key".
            rendered = tag.text + " " + tag.data["key"]
            return ResponseInputTextParam(type="input_text", text=rendered)

    tag = UserMessageTagContent(
        text="Hello!", type="input_tag", id="hello", data={"key": "value"}
    )
    thread_items = [
        UserMessageItem(
            id="123",
            content=[tag],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
    ]

    converted = await MyThreadItemConverter().to_agent_input(thread_items)

    assert len(converted) == 2

    # First message: the user text with the tag rendered as an @-mention.
    assert converted[0] == {
        "content": [
            {
                "text": "@Hello!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }

    # Second message: fixed instructions followed by the custom tag content.
    mention_instructions = (
        "# User-provided context for @-mentions\n- When referencing resolved entities, use their canonical names **without** '@'.\n"
        "- The '@' form appears only in user text and should not be echoed."
    )
    assert converted[1] == {
        "content": [
            {
                "text": mention_instructions,
                "type": "input_text",
            },
            {
                "text": "Hello! value",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
552
553
async def test_input_item_converter_user_input_with_tags_throws_by_default():
    """The default conversion raises ``NotImplementedError`` for tag content."""
    tag = UserMessageTagContent(text="Hello!", type="input_tag", id="hello", data={})
    tagged_message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input([tagged_message])
572
573
async def test_input_item_converter_generated_image_item():
    """A generated image becomes a user message with a caption and an image part."""
    image_item = GeneratedImageItem(
        id="img_item_1",
        thread_id=thread.id,
        created_at=datetime.now(),
        image=GeneratedImage(id="img_1", url="https://example.com/img.png"),
    )

    input_items = await simple_to_agent_input([image_item])
    assert len(input_items) == 1

    message = cast(dict, input_items[0])
    assert message.get("type") == "message"
    assert message.get("role") == "user"

    content = cast(list, message.get("content"))
    caption_part = content[0]
    image_part = content[1]

    assert caption_part.get("type") == "input_text"
    assert caption_part.get("text") == "The following image was generated by the agent."

    assert image_part.get("type") == "input_image"
    assert image_part.get("file_id") is None
    assert image_part.get("image_url") == "https://example.com/img.png"
    assert image_part.get("detail") == "auto"
598
599
async def test_input_item_converter_generated_image_item_without_image():
    """A generated-image item that carries no image converts to no input at all."""
    imageless_item = GeneratedImageItem(
        id="img_item_1",
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    input_items = await simple_to_agent_input([imageless_item])

    assert input_items == []
611
612
async def test_input_item_converter_for_hidden_context_with_string_content():
    """String hidden context is wrapped in ``<HiddenContext>`` by the default converter."""
    hidden_item = HiddenContextItem(
        id="123",
        content="User pressed the red button",
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    # The default converter works for string content
    converted = await simple_to_agent_input([hidden_item])

    expected_message = {
        "content": [
            {
                "text": "Hidden context for the agent (not shown to the user):\n<HiddenContext>\nUser pressed the red button\n</HiddenContext>",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert len(converted) == 1
    assert converted[0] == expected_message
636
637
async def test_input_item_converter_for_hidden_context_with_non_string_content():
    """Non-string hidden context needs a custom ``hidden_context_to_input`` override."""
    hidden_items = [
        HiddenContextItem(
            id="123",
            content={"harry": "potter", "hermione": "granger"},
            thread_id=thread.id,
            created_at=datetime.now(),
        )
    ]

    # Default converter should throw
    with pytest.raises(NotImplementedError):
        await simple_to_agent_input(hidden_items)

    class MyThreadItemConverter(ThreadItemConverter):
        async def hidden_context_to_input(self, item: HiddenContextItem):
            # Expose only the dictionary keys, comma-separated.
            joined_keys = ",".join(item.content.keys())
            text_part = ResponseInputTextParam(
                type="input_text", text=f"<HIDDEN>{joined_keys}</HIDDEN>"
            )
            return Message(
                type="message",
                content=[text_part],
                role="user",
            )

    converted = await MyThreadItemConverter().to_agent_input(hidden_items)

    assert len(converted) == 1
    assert converted[0] == {
        "content": [
            {
                "text": "<HIDDEN>harry,hermione</HIDDEN>",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
677
678
async def test_input_item_converter_with_client_tool_call():
    """A pending and a completed client tool call map to a function call and its output."""
    tool_arguments = {"foo": "bar"}
    tool_output = {"success": True}

    user_message = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Call a client tool call xyz")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    task = TaskItem(
        id="tsk_123",
        created_at=datetime.now(),
        task=CustomTask(title="Called xyx"),
        thread_id=thread.id,
    )
    pending_call = ClientToolCallItem(
        id="ctc_123",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
    )
    completed_call = ClientToolCallItem(
        id="ctc_123_done",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
        status="completed",
        output={"success": True},
    )

    input_items = await simple_to_agent_input(
        [user_message, task, pending_call, completed_call]
    )

    assert len(input_items) == 4
    user_input, task_input, call_input, output_input = input_items

    assert user_input == {
        "content": [
            {
                "text": "Call a client tool call xyz",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert task_input == {
        "content": [
            {
                "text": "A message was displayed to the user that the following task was performed:\n<Task>\nCalled xyx\n</Task>",
                "type": "input_text",
            },
        ],
        "type": "message",
        "role": "user",
    }
    assert call_input == {
        "type": "function_call",
        "name": "xyz",
        "arguments": json.dumps(tool_arguments),
        "call_id": "ctc_123",
    }
    assert output_input == {
        "type": "function_call_output",
        "call_id": "ctc_123",
        "output": json.dumps(tool_output),
    }
749
750
async def test_input_item_converter_with_structured_input():
    """An answered structured input is summarised as one user message."""
    subject_question = StructuredInputMultipleChoice(
        id="subject",
        question="Which subject is this lesson for?",
        options=[
            StructuredInputMultipleChoiceOption(value="Math"),
            StructuredInputMultipleChoiceOption(value="Science"),
        ],
        answer=StructuredInputAnswer(values=["Math"]),
    )
    details_question = StructuredInputFreeform(
        id="details",
        question="Anything else to know?",
        answer=StructuredInputAnswer(skipped=True),
    )
    structured_item = StructuredInputItem(
        id="si_123",
        thread_id=thread.id,
        created_at=datetime.now(),
        status="answered",
        inputs=[subject_question, details_question],
    )

    input_items = await simple_to_agent_input([structured_item])

    expected_message = {
        "content": [
            {
                "text": "A structured input request was displayed to the user with the following status: answered\n<StructuredInput>\n- Which subject is this lesson for?: Math\n- Anything else to know?: skipped\n</StructuredInput>",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert len(input_items) == 1
    assert input_items[0] == expected_message
789
790
async def test_stream_agent_response_yields_context_events_without_streaming_events():
    """Events pushed through the context are yielded even when the run emits nothing."""
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    added_event = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )

    await context.stream(added_event)

    response_streamer = stream_agent_response(context, result)
    streamed_event = await response_streamer.__anext__()

    assert streamed_event.type == "thread.item.added"

    # Ask for one more event; the run is not finished yet, so it must stay pending.
    future = asyncio.ensure_future(response_streamer.__anext__())
    assert future.done() is False

    result.done()

    # Once the run completes the stream is exhausted. pytest.raises fails the
    # test by itself if no StopAsyncIteration is raised.
    with pytest.raises(StopAsyncIteration):
        await future

    assert future.done() is True
825
826
async def test_stream_agent_response_maps_events():
    """Context events and mapped run events are both yielded, in either order."""
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    event = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )

    await context.stream(event)
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta="Hello, world!",
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    response_streamer = stream_agent_response(context, result)
    event1 = await response_streamer.__anext__()
    event2 = await response_streamer.__anext__()

    # Compared as a set: the relative order of the two sources is not asserted.
    assert {event1.type, event2.type} == {
        "thread.item.added",
        "thread.item.updated",
    }

    # Ask for one more event; the run is not finished yet, so it must stay pending.
    future = asyncio.ensure_future(response_streamer.__anext__())
    assert future.done() is False

    result.done()

    # Once the run completes the stream is exhausted. pytest.raises fails the
    # test by itself if no StopAsyncIteration is raised.
    with pytest.raises(StopAsyncIteration):
        await future

    assert future.done() is True
879
880
@pytest.mark.parametrize(
    "raw_event,expected_event",
    [
        # Text delta -> content-part text delta.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDeltaEvent(
                    type="response.output_text.delta",
                    delta="Hello, world!",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=0,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartTextDelta(
                    content_index=0,
                    delta="Hello, world!",
                ),
            ),
        ),
        # Content part added -> content-part added.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseContentPartAddedEvent(
                    type="response.content_part.added",
                    part=ResponseOutputText(
                        type="output_text",
                        text="New content",
                        annotations=[],
                    ),
                    content_index=1,
                    item_id="123",
                    output_index=0,
                    sequence_number=1,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartAdded(
                    content_index=1,
                    content=AssistantMessageContent(text="New content", annotations=[]),
                ),
            ),
        ),
        # Text done -> content-part done.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDoneEvent(
                    type="response.output_text.done",
                    text="Final text",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=2,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartDone(
                    content_index=0,
                    content=AssistantMessageContent(
                        text="Final text",
                        annotations=[],
                    ),
                ),
            ),
        ),
        # Annotation added (file citation) -> content-part annotation added.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=ResponsesAnnotationFileCitation(
                        type="file_citation",
                        file_id="file_123",
                        filename="file.txt",
                        index=5,
                    ),
                    content_index=0,
                    item_id="123",
                    annotation_index=0,
                    output_index=0,
                    sequence_number=3,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartAnnotationAdded(
                    content_index=0,
                    annotation_index=0,
                    annotation=Annotation(
                        source=FileSource(filename="file.txt", title="file.txt"),
                        index=5,
                    ),
                ),
            ),
        ),
    ],
)
async def test_event_mapping(raw_event, expected_event):
    """Each raw response event maps to the expected thread event (or to none)."""
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()

    run.add_event(raw_event)
    run.done()

    events = await all_events(stream_agent_response(agent_context, run))

    # A falsy ``expected_event`` means the raw event should produce no output.
    expected_events = [expected_event] if expected_event else []
    assert events == expected_events
999
1000
async def test_stream_agent_response_emits_annotation_added_events():
    """Typed annotation-added events become annotation updates; the empty-filename one is absent."""
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    item_id = "item_123"

    annotations = [
        ResponsesAnnotationFileCitation(
            type="file_citation",
            file_id="file_invalid",
            filename="",
            index=0,
        ),
        ResponsesAnnotationContainerFileCitation(
            type="container_file_citation",
            container_id="container_1",
            file_id="file_123",
            filename="container.txt",
            start_index=0,
            end_index=3,
        ),
        ResponsesAnnotationURLCitation(
            type="url_citation",
            url="https://example.com",
            title="Example",
            start_index=1,
            end_index=5,
        ),
    ]

    # Each annotation's position doubles as its annotation index and sequence number.
    for position, annotation in enumerate(annotations):
        run.add_event(
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=annotation,
                    content_index=0,
                    item_id=item_id,
                    annotation_index=position,
                    output_index=0,
                    sequence_number=position,
                ),
            )
        )
    run.done()

    events = await all_events(stream_agent_response(agent_context, run))

    container_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=0,
            annotation=Annotation(
                source=FileSource(filename="container.txt", title="container.txt"),
                index=3,
            ),
        ),
    )
    url_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=1,
            annotation=Annotation(
                source=URLSource(
                    url="https://example.com",
                    title="Example",
                ),
                index=5,
            ),
        ),
    )
    assert events == [container_update, url_update]
1084
1085
async def test_stream_agent_response_annotation_added_normalizes_annotations():
    """Dict-shaped annotations are converted like typed ones, with compacted indices."""
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    item_id = "item_123"

    raw_annotations = [
        # Invalid file citation is dropped (empty filename)
        {
            "type": "file_citation",
            "file_id": "file_invalid",
            "filename": "",
            "index": 0,
        },
        # Container + URL citations are converted; indices are compacted (0, 1)
        # despite the dropped first event.
        {
            "type": "container_file_citation",
            "container_id": "container_1",
            "file_id": "file_123",
            "filename": "container.txt",
            "start_index": 0,
            "end_index": 3,
        },
        {
            "type": "url_citation",
            "url": "https://example.com",
            "title": "Example",
            "start_index": 1,
            "end_index": 5,
        },
    ]

    # Each annotation's position doubles as its annotation index and sequence number.
    for position, raw_annotation in enumerate(raw_annotations):
        run.add_event(
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=raw_annotation,
                    content_index=0,
                    item_id=item_id,
                    annotation_index=position,
                    output_index=0,
                    sequence_number=position,
                ),
            )
        )

    run.done()

    events = await all_events(stream_agent_response(agent_context, run))

    container_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=0,
            annotation=Annotation(
                source=FileSource(filename="container.txt", title="container.txt"),
                index=3,
            ),
        ),
    )
    url_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=1,
            annotation=Annotation(
                source=URLSource(url="https://example.com", title="Example"),
                index=5,
            ),
        ),
    )
    assert events == [container_update, url_update]
1170
1171
async def test_custom_annotation_conversion_used_by_stream_agent_response():
    """A custom converter's annotation hooks are used for streamed and final annotations."""
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()

    class CustomResponseStreamConverter(ResponseStreamConverter):
        """Converter that records which hooks ran and returns fixed annotations."""

        def __init__(self):
            super().__init__()
            self.calls: list[str] = []

        async def file_citation_to_annotation(self, file_citation):
            self.calls.append("file_citation")
            return Annotation(
                source=FileSource(
                    filename="report.pdf",
                    title="Usage Report",
                    description="Usage report for the month of January",
                ),
                index=111,
            )

        async def url_citation_to_annotation(self, url_citation):
            self.calls.append("url_citation")
            return Annotation(
                source=URLSource(url="https://custom.example/url", title="Custom"),
                index=222,
            )

    converter = CustomResponseStreamConverter()

    # First raw event: a file citation added to message "m_1" while streaming.
    streamed_citation = ResponsesAnnotationFileCitation(
        type="file_citation",
        file_id="file_123",
        filename="file.txt",
        index=0,
    )
    annotation_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=Mock(
            type="response.output_text.annotation.added",
            annotation=streamed_citation,
            content_index=0,
            item_id="m_1",
            annotation_index=0,
            output_index=0,
            sequence_number=0,
        ),
    )
    run.add_event(annotation_added)

    # Second raw event: the finished message, carrying a URL citation.
    final_text = ResponseOutputText(
        annotations=[
            ResponsesAnnotationURLCitation(
                type="url_citation",
                url="https://example.com",
                title="Example",
                start_index=0,
                end_index=7,
            )
        ],
        text="Hello!",
        type="output_text",
    )
    message_done = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            item=ResponseOutputMessage(
                id="m_1",
                content=[final_text],
                role="assistant",
                status="completed",
                type="message",
            ),
            output_index=0,
            sequence_number=1,
        ),
    )
    run.add_event(message_done)
    run.done()

    events = await all_events(
        stream_agent_response(agent_context, run, converter=converter)
    )
    assert len(events) == 2

    annotation_event = events[0]
    assert isinstance(annotation_event, ThreadItemUpdatedEvent)
    assert annotation_event.item_id == "m_1"
    assert annotation_event.update == AssistantMessageContentPartAnnotationAdded(
        content_index=0,
        annotation_index=0,
        annotation=Annotation(
            source=FileSource(
                filename="report.pdf",
                title="Usage Report",
                description="Usage report for the month of January",
            ),
            index=111,
        ),
    )

    done_event = events[1]
    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, AssistantMessageItem)
    assert done_event.item.id == "m_1"
    assert done_event.item.content == [
        AssistantMessageContent(
            text="Hello!",
            annotations=[
                Annotation(
                    source=URLSource(url="https://custom.example/url", title="Custom"),
                    index=222,
                )
            ],
        )
    ]
    assert converter.calls == ["file_citation", "url_citation"]
1291
1292
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
    """A tripped guardrail yields removal events for the items already streamed.

    Three items are queued (ids "1", "2" and "3"): one through a raw response
    event and two through ``context.stream``. After three events have been
    consumed, an input or output guardrail is triggered. The stream is then
    expected to emit ``thread.item.removed`` events for all three ids before
    re-raising the guardrail exception.

    Args:
        throw_guardrail: ``"input"`` or ``"output"``; selects which guardrail
            helper on the fake result is invoked.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[], type="output_text", text="Hello, world!"
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    await context.stream(
        ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="2",
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
    )

    await context.stream(
        ThreadItemDoneEvent(
            item=WidgetItem(
                id="3",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(id="text", value="Hello, world!")]),
            ),
        )
    )

    iterator = stream_agent_response(context, result)

    # Grab the first 3 events (presumably one per item queued above) so they
    # have been emitted before the guardrail trips. Breaking out of the loop
    # leaves the iterator open for the second pass below.
    events = []
    async for event in iterator:
        events.append(event)
        if len(events) == 3:
            break

    if throw_guardrail == "input":
        result.throw_input_guardrails()
    else:
        result.throw_output_guardrails()

    # pytest.raises fails the test if nothing is raised and lets any other
    # exception propagate with its original traceback. Events yielded before
    # the exception surfaces are still collected.
    with pytest.raises(
        (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered)
    ):
        async for event in iterator:
            events.append(event)

    deleted_item_ids = {
        event.item_id for event in events if event.type == "thread.item.removed"
    }
    assert deleted_item_ids == {"1", "2", "3"}
1371
1372
async def test_stream_agent_response_assistant_message_content_types():
    """A completed assistant message is converted with its supported annotations.

    The stream carries a completed file-search call followed by a completed
    message with two text parts. The assertions expect exactly one thread
    event (the message's done event), with file, container-file and URL
    citations converted and the ``file_path`` annotation absent.
    """
    # Single context for the whole test; the original built one instance and
    # discarded it before building the one that was actually used.
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # Completed file-search call; expected to produce no thread event.
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseFileSearchToolCall(
                    id="fs_0",
                    queries=["Hello, world!"],
                    status="completed",
                    type="file_search_call",
                    results=[
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, world!",
                            score=1.0,
                        ),
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, friends!",
                            score=0.5,
                        ),
                    ],
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    # Completed message: the first text part carries one annotation of each
    # kind, the second text part carries none.
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[
                                ResponsesAnnotationFileCitation(
                                    type="file_citation",
                                    file_id="f_123",
                                    index=0,
                                    filename="test.txt",
                                ),
                                ResponsesAnnotationContainerFileCitation(
                                    type="container_file_citation",
                                    container_id="container_1",
                                    file_id="f_456",
                                    filename="container.txt",
                                    start_index=0,
                                    end_index=3,
                                ),
                                ResponsesAnnotationURLCitation(
                                    type="url_citation",
                                    url="https://www.google.com",
                                    title="Google",
                                    start_index=0,
                                    end_index=10,
                                ),
                                ResponsesAnnotationFilePath(
                                    type="file_path",
                                    file_id="123",
                                    index=0,
                                ),
                            ],
                            text="Hello, world!",
                            type="output_text",
                        ),
                        ResponseOutputText(
                            annotations=[],
                            text="Can't do that",
                            type="output_text",
                        ),
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    result.done()

    events = await all_events(stream_agent_response(context, result))
    assert len(events) == 1
    assert isinstance(events[0], ThreadItemDoneEvent)
    message = events[0].item
    assert isinstance(message, AssistantMessageItem)
    # Three of the four input annotations are expected; the range-based
    # citations are expected with index equal to their end_index.
    assert message.content == [
        AssistantMessageContent(
            annotations=[
                Annotation(
                    source=FileSource(
                        filename="test.txt",
                        title="test.txt",
                    ),
                    index=0,
                ),
                Annotation(
                    source=FileSource(
                        filename="container.txt",
                        title="container.txt",
                    ),
                    index=3,
                ),
                Annotation(
                    source=URLSource(
                        url="https://www.google.com",
                        title="Google",
                    ),
                    index=10,
                ),
            ],
            text="Hello, world!",
        ),
        AssistantMessageContent(text="Can't do that", annotations=[]),
    ]
    assert message.id == "1"
1505
1506
async def test_stream_agent_response_image_generation_events():
    """Image generation calls surface as a generated-image item.

    The ``added`` event is expected to yield an item without an image, and
    the ``done`` event the same item with the base64 result wrapped in a
    PNG data URL. The stream must be exhausted afterwards.
    """

    def image_call_event(event_type, sequence_number, **item_fields):
        # Raw stream event wrapping a mocked image_generation_call item.
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=Mock(
                type=event_type,
                item=Mock(type="image_generation_call", id="img_call_1", **item_fields),
                output_index=0,
                sequence_number=sequence_number,
            ),
        )

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(image_call_event("response.output_item.added", 0))
    result.add_event(
        image_call_event("response.output_item.done", 1, result="dGVzdA==")
    )
    result.done()

    stream = stream_agent_response(context, result)

    added = await anext(stream)
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, GeneratedImageItem)
    assert added.item.type == "generated_image"
    assert added.item.id == "message_id"
    assert added.item.image is None

    done = await anext(stream)
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, GeneratedImageItem)
    assert done.item.id == added.item.id
    assert done.item.image == GeneratedImage(
        id="img_call_1", url="data:image/png;base64,dGVzdA=="
    )

    # Nothing else should be emitted once the image is complete.
    with pytest.raises(StopAsyncIteration):
        await anext(stream)
1557
1558
async def test_stream_agent_response_image_generation_events_with_custom_converter():
    """A converter override decides the URL of the finished generated image.

    The custom ``base64_image_to_url`` records its arguments and returns an
    https URL. The done event is expected to carry that URL, and the hook is
    expected to have been called once with no partial image index.
    """

    class UrlRecordingConverter(ResponseStreamConverter):
        """Maps every image to an example.com URL and logs each call."""

        def __init__(self):
            super().__init__()
            self.calls: list[tuple[str, str, int | None]] = []

        async def base64_image_to_url(
            self,
            image_id: str,
            base64_image: str,
            partial_image_index: int | None = None,
        ) -> str:
            self.calls.append((image_id, base64_image, partial_image_index))
            return f"https://example.com/{image_id}"

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # (event type, mocked item) pairs; the position doubles as sequence_number.
    scripted_items = [
        (
            "response.output_item.added",
            Mock(type="image_generation_call", id="img_call_1"),
        ),
        (
            "response.output_item.done",
            Mock(type="image_generation_call", id="img_call_1", result="dGVzdA=="),
        ),
    ]
    for sequence_number, (event_type, item) in enumerate(scripted_items):
        result.add_event(
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type=event_type,
                    item=item,
                    output_index=0,
                    sequence_number=sequence_number,
                ),
            )
        )
    result.done()

    converter = UrlRecordingConverter()
    stream = stream_agent_response(context, result, converter=converter)

    added = await anext(stream)
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, GeneratedImageItem)
    assert added.item.image is None

    done = await anext(stream)
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, GeneratedImageItem)
    assert converter.calls == [("img_call_1", "dGVzdA==", None)]
    assert done.item.image == GeneratedImage(
        id="img_call_1", url="https://example.com/img_call_1"
    )

    with pytest.raises(StopAsyncIteration):
        await anext(stream)
1621
1622
async def test_stream_agent_response_image_generation_partial_progress():
    """Partial images are reported as progress updates on the generated image.

    With ``partial_images=3`` and a partial image at index 1, the update is
    expected to report a progress of one third. The partial and final images
    both use the default PNG data URL.
    """
    raw_payloads = [
        Mock(
            type="response.output_item.added",
            item=Mock(type="image_generation_call", id="img_call_1"),
            output_index=0,
            sequence_number=0,
        ),
        Mock(
            type="response.image_generation_call.partial_image",
            partial_image_b64="dGVzdA==",
            partial_image_index=1,
            item_id="img_call_1",
            output_index=0,
            sequence_number=1,
        ),
        Mock(
            type="response.output_item.done",
            item=Mock(type="image_generation_call", id="img_call_1", result="dGVzdA=="),
            output_index=0,
            sequence_number=2,
        ),
    ]

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    for payload in raw_payloads:
        result.add_event(
            RawResponsesStreamEvent(type="raw_response_event", data=payload)
        )
    result.done()

    events = await all_events(
        stream_agent_response(
            context, result, converter=ResponseStreamConverter(partial_images=3)
        )
    )

    assert len(events) == 3
    added_event, partial_event, done_event = events

    # The partial and the final payload are the same base64 string here.
    expected_image = GeneratedImage(
        id="img_call_1", url="data:image/png;base64,dGVzdA=="
    )

    assert isinstance(added_event, ThreadItemAddedEvent)
    assert isinstance(added_event.item, GeneratedImageItem)

    assert isinstance(partial_event, ThreadItemUpdatedEvent)
    assert isinstance(partial_event.update, GeneratedImageUpdated)
    assert partial_event.update.progress == pytest.approx(1 / 3)
    assert partial_event.update.image == expected_image

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, GeneratedImageItem)
    assert done_event.item.image == expected_image
1691
1692
async def test_workflow_streams_first_thought():
    """Only the first reasoning summary is streamed delta by delta.

    Expected event sequence: the workflow item is added, the first thought is
    added on its first delta, updated on the second delta and again on done,
    and the second thought is added once, with its full text, when it is done.
    """
    delta_type = "response.reasoning_summary_text.delta"
    done_type = "response.reasoning_summary_text.done"
    # (event type, summary_index, extra payload) for each scripted summary event.
    scripted_summary_events = [
        # first thought
        (delta_type, 0, {"delta": "Think"}),
        (delta_type, 0, {"delta": "ing 1"}),
        (done_type, 0, {"text": "Thinking 1"}),
        # second thought
        (delta_type, 1, {"delta": "Think"}),
        (delta_type, 1, {"delta": "ing 2"}),
        (done_type, 1, {"text": "Thinking 2"}),
    ]

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseReasoningItem(
                    id="resp_1",
                    summary=[],
                    type="reasoning",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    for event_type, summary_index, payload in scripted_summary_events:
        result.add_event(
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type=event_type,
                    item_id="resp_1",
                    summary_index=summary_index,
                    **payload,
                ),
            )
        )
    result.done()

    stream = stream_agent_response(context, result)

    # Workflow added
    event = await anext(stream)
    assert isinstance(event, ThreadItemAddedEvent)
    workflow_item = context.workflow_item
    assert workflow_item is not None
    assert workflow_item.workflow.type == "reasoning"
    assert len(workflow_item.workflow.tasks) == 0
    assert event == ThreadItemAddedEvent(item=workflow_item)

    # (task count after the event, expected update), in emission order.
    expected_updates = [
        # First thought added
        (1, WorkflowTaskAdded(task=ThoughtTask(content="Think"), task_index=0)),
        # First thought delta
        (1, WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0)),
        # First thought done
        (1, WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0)),
        # Second thought added (not streamed)
        (2, WorkflowTaskAdded(task=ThoughtTask(content="Thinking 2"), task_index=1)),
    ]
    for expected_task_count, expected_update in expected_updates:
        event = await anext(stream)
        # Re-read the workflow item after every event rather than caching it.
        workflow_item = context.workflow_item
        assert workflow_item is not None
        assert len(workflow_item.workflow.tasks) == expected_task_count
        assert isinstance(event, ThreadItemUpdatedEvent)
        assert event == ThreadItemUpdatedEvent(
            item_id=workflow_item.id,
            update=expected_update,
        )

    # Drain whatever remains so the generator runs to completion.
    async for _ in stream:
        pass
1852
1853
async def test_workflow_ends_on_message():
    """A reasoning workflow is closed when an assistant message begins.

    Expected event sequence: workflow added, first thought added from its
    ``done`` event, then a done event for the workflow carrying a duration
    summary and collapsed state, after which the context holds no workflow.
    """
    # first thought
    reasoning_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(
                id="resp_1",
                summary=[],
                type="reasoning",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )
    thought_done = RawResponsesStreamEvent(
        type="raw_response_event",
        data=Mock(
            type="response.reasoning_summary_text.done",
            item_id="resp_1",
            summary_index=0,
            text="Thinking 1",
        ),
    )
    # not reasoning
    message_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    for raw_event in (reasoning_added, thought_done, message_added):
        result.add_event(raw_event)
    result.done()

    stream = stream_agent_response(context, result)

    # Workflow added
    workflow_added = await anext(stream)
    assert isinstance(workflow_added, ThreadItemAddedEvent)
    assert context.workflow_item is not None
    assert context.workflow_item.workflow.type == "reasoning"
    assert len(context.workflow_item.workflow.tasks) == 0
    assert workflow_added == ThreadItemAddedEvent(item=context.workflow_item)

    # First thought done
    thought_added = await anext(stream)
    assert context.workflow_item is not None
    assert len(context.workflow_item.workflow.tasks) == 1
    assert isinstance(thought_added, ThreadItemUpdatedEvent)
    expected_update = WorkflowTaskAdded(
        task=ThoughtTask(content="Thinking 1"),
        task_index=0,
    )
    assert thought_added == ThreadItemUpdatedEvent(
        item_id=context.workflow_item.id,
        update=expected_update,
    )

    # Workflow ended
    workflow_done = await anext(stream)
    assert isinstance(workflow_done, ThreadItemDoneEvent)
    assert workflow_done.item.type == "workflow"
    assert context.workflow_item is None
    # Summary and expanded are handled by the end_workflow method
    assert isinstance(workflow_done.item.workflow.summary, DurationSummary)
    assert workflow_done.item.workflow.expanded is False

    # Drain whatever remains so the generator runs to completion.
    async for _ in stream:
        pass
1945
1946
async def test_existing_workflow_summary_not_overwritten_on_automatic_end():
    """Automatically ending a workflow keeps a summary that was already set.

    The context starts with a custom workflow that has a ``CustomSummary``.
    When an assistant message begins, the first event is expected to be the
    workflow's done event with that same summary, and the context's workflow
    slot is expected to be cleared.
    """
    preset_workflow = Workflow(
        type="custom", tasks=[], summary=CustomSummary(title="Test")
    )
    message_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    context.workflow_item = WorkflowItem(
        id="wf_1",
        created_at=datetime.now(),
        workflow=preset_workflow,
        thread_id=thread.id,
    )
    result.add_event(message_added)
    result.done()

    stream = stream_agent_response(context, result)
    first_event = await anext(stream)

    assert isinstance(first_event, ThreadItemDoneEvent)
    assert context.workflow_item is None
    assert first_event.item.type == "workflow"
    assert first_event.item.workflow.summary == CustomSummary(title="Test")

    # Drain whatever remains so the generator runs to completion.
    async for _ in stream:
        pass
1992