openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
d14ce4602e2002ef97ccff685ec14fb27c49c94c

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_agents.py

1928lines · modecode

1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from typing import cast
6from unittest.mock import AsyncMock, Mock
7
8import pytest
9from agents import (
10 Agent,
11 GuardrailFunctionOutput,
12 InputGuardrail,
13 InputGuardrailResult,
14 InputGuardrailTripwireTriggered,
15 OutputGuardrail,
16 OutputGuardrailResult,
17 OutputGuardrailTripwireTriggered,
18 RawResponsesStreamEvent,
19 RunContextWrapper,
20 RunItemStreamEvent,
21 Runner,
22 RunResultStreaming,
23 StreamEvent,
24 ToolCallItem,
25)
26from agents._run_impl import QueueCompleteSentinel
27from openai.types.responses import (
28 EasyInputMessageParam,
29 ResponseFileSearchToolCall,
30 ResponseInputContentParam,
31 ResponseInputTextParam,
32 ResponseOutputItemAddedEvent,
33 ResponseOutputItemDoneEvent,
34 ResponseOutputMessage,
35 ResponseReasoningItem,
36)
37from openai.types.responses.response_content_part_added_event import (
38 ResponseContentPartAddedEvent,
39)
40from openai.types.responses.response_file_search_tool_call import Result
41from openai.types.responses.response_input_item_param import Message
42from openai.types.responses.response_output_text import (
43 AnnotationContainerFileCitation as ResponsesAnnotationContainerFileCitation,
44)
45from openai.types.responses.response_output_text import (
46 AnnotationFileCitation as ResponsesAnnotationFileCitation,
47)
48from openai.types.responses.response_output_text import (
49 AnnotationFilePath as ResponsesAnnotationFilePath,
50)
51from openai.types.responses.response_output_text import (
52 AnnotationURLCitation as ResponsesAnnotationURLCitation,
53)
54from openai.types.responses.response_output_text import (
55 ResponseOutputText,
56)
57from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
58from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
59
60from chatkit.agents import (
61 AgentContext,
62 ResponseStreamConverter,
63 ThreadItemConverter,
64 accumulate_text,
65 simple_to_agent_input,
66 stream_agent_response,
67)
68from chatkit.types import (
69 Annotation,
70 AssistantMessageContent,
71 AssistantMessageContentPartAdded,
72 AssistantMessageContentPartAnnotationAdded,
73 AssistantMessageContentPartDone,
74 AssistantMessageContentPartTextDelta,
75 AssistantMessageItem,
76 Attachment,
77 ClientToolCallItem,
78 CustomSummary,
79 CustomTask,
80 DurationSummary,
81 FileSource,
82 GeneratedImage,
83 GeneratedImageItem,
84 GeneratedImageUpdated,
85 HiddenContextItem,
86 InferenceOptions,
87 Page,
88 TaskItem,
89 ThoughtTask,
90 Thread,
91 ThreadItemAddedEvent,
92 ThreadItemDoneEvent,
93 ThreadItemUpdatedEvent,
94 ThreadStreamEvent,
95 URLSource,
96 UserMessageItem,
97 UserMessageTagContent,
98 UserMessageTextContent,
99 WidgetItem,
100 Workflow,
101 WorkflowItem,
102 WorkflowTaskAdded,
103 WorkflowTaskUpdated,
104)
105from chatkit.widgets import Card, Text
106
# Thread shared by every test in this module.
thread = Thread(id="123", title="Test", created_at=datetime.now(), items=Page())

# Stand-in store: item ids are derived from the item type, loading thread
# items returns an empty page, and adding an item is an awaitable no-op mock.
mock_store = Mock(
    generate_item_id=lambda item_type, thread, context: f"{item_type}_id",
    load_thread_items=AsyncMock(return_value=Page()),
    add_thread_item=AsyncMock(),
)
113
114
class RunResult(RunResultStreaming):
    """``RunResultStreaming`` subclass whose event queue is driven by hand."""

    def add_event(self, event: StreamEvent):
        """Put ``event`` on the run's event queue without waiting."""
        self._event_queue.put_nowait(event)

    def done(self):
        """Mark the run complete and enqueue the queue-complete sentinel."""
        self.is_complete = True
        self._event_queue.put_nowait(QueueCompleteSentinel())

    def throw_input_guardrails(self):
        """Store a tripped input-guardrail exception, then complete the run."""
        tripped = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = InputGuardrailResult(
            guardrail=Mock(spec=InputGuardrail),
            output=tripped,
        )
        self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
        # Same completion steps as done(): flag the run and push the sentinel.
        self.done()

    def throw_output_guardrails(self):
        """Store a tripped output-guardrail exception, then complete the run."""
        tripped = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = OutputGuardrailResult(
            guardrail=Mock(spec=OutputGuardrail),
            output=tripped,
            agent=Mock(spec=Agent),
            agent_output=None,
        )
        self._stored_exception = OutputGuardrailTripwireTriggered(guardrail_result)
        # Same completion steps as done(): flag the run and push the sentinel.
        self.done()
150
151
def make_result() -> RunResult:
    """Build an empty, not-yet-complete ``RunResult`` with fresh queues."""
    event_queue = asyncio.Queue()
    input_guardrail_queue = asyncio.Queue()
    return RunResult(
        # Run inputs and bookkeeping.
        input=[],
        new_items=[],
        raw_responses=[],
        final_output=None,
        context_wrapper=Mock(spec=RunContextWrapper),
        current_agent=Agent(name="test"),
        current_turn=0,
        max_turns=10,
        trace=None,
        is_complete=False,
        # Guardrail results all start out empty.
        input_guardrail_results=[],
        output_guardrail_results=[],
        tool_input_guardrail_results=[],
        tool_output_guardrail_results=[],
        # Private streaming state; no tasks running and no stored exception.
        _current_agent_output_schema=None,
        _event_queue=event_queue,
        _input_guardrail_queue=input_guardrail_queue,
        _output_guardrails_task=None,
        _run_impl_task=None,
        _stored_exception=None,
    )
175
176
async def all_events(
    events: AsyncIterator[ThreadStreamEvent],
) -> list[ThreadStreamEvent]:
    """Exhaust ``events`` and return everything it yielded, in order."""
    collected: list[ThreadStreamEvent] = []
    async for item in events:
        collected.append(item)
    return collected
181
182
async def test_returns_widget_item():
    """A widget streamed as a single value is expected as one done event."""

    def hello_card() -> Card:
        return Card(children=[Text(value="Hello, world!")])

    agent_context = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)
    await agent_context.stream_widget(hello_card())
    run.done()

    streamed = stream_agent_response(context=agent_context, result=run)
    events = await all_events(streamed)

    assert len(events) == 1
    only_event = events[0]
    assert isinstance(only_event, ThreadItemDoneEvent)
    assert isinstance(only_event.item, WidgetItem)
    assert only_event.item.widget == hello_card()
205
206
async def test_returns_widget_item_generator():
    """A widget generator whose text grows is expected as added/delta/done."""

    def render_widget(i: int) -> Card:
        # Card showing the first ``i`` characters of the greeting.
        return Card(children=[Text(id="text", value="Hello, world"[:i])])

    async def widget_generator():
        # First an empty text, then the complete 12-character greeting.
        for length in (0, 12):
            yield render_widget(length)

    agent_context = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    await agent_context.stream_widget(widget_generator())
    run.done()

    streamed = stream_agent_response(context=agent_context, result=run)
    events = await all_events(streamed)

    assert len(events) == 3
    added, updated, finished = events[0], events[1], events[2]

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.streaming_text.value_delta"
    assert updated.update.component_id == "text"
    assert updated.update.delta == "Hello, world"

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
248
249
async def test_returns_widget_full_replace_generator():
    """A generator yielding a different widget tree expects a root update."""

    def replacement_card() -> Card:
        return Card(children=[Text(key="other text", value="World!", streaming=False)])

    async def widget_generator():
        yield Card(children=[Text(id="text", value="Hello!")])
        yield replacement_card()

    agent_context = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    await agent_context.stream_widget(widget_generator())
    run.done()

    streamed = stream_agent_response(context=agent_context, result=run)
    events = await all_events(streamed)

    assert len(events) == 3
    added, updated, finished = events[0], events[1], events[2]

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="Hello!")])

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == replacement_card()

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == replacement_card()
289
290
async def test_accumulate_text():
    """``accumulate_text`` is expected to yield a growing ``Text`` widget.

    The expected sequence is the initial widget, one widget per text delta
    with the concatenated value, and a final widget with ``streaming=False``.
    """

    def delta(text: str) -> RawResponsesStreamEvent:
        # Wrap a text delta in the raw-response envelope the stream carries.
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta=text,
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )

    # The scripted result is the only run needed here; the previous
    # Runner.run_streamed(...) call was discarded immediately and has been
    # removed so the test no longer starts a real agent run.
    result = make_result()
    result.add_event(delta("Hello, "))
    result.add_event(delta("world!"))

    result.done()

    events = [
        event
        async for event in accumulate_text(
            result.stream_events(), Text(key="text", value="", streaming=True)
        )
    ]
    assert events == [
        Text(key="text", value="", streaming=True),
        Text(key="text", value="Hello, ", streaming=True),
        Text(key="text", value="Hello, world!", streaming=True),
        Text(key="text", value="Hello, world!", streaming=False),
    ]
327
328
async def test_input_item_converter_quotes_last_user_message():
    """Only the last user message's quoted text is expected in the input.

    Both user messages carry ``quoted_text``; the expected output has the two
    messages followed by a single extra user message quoting the second one.
    """
    # The unused nested ``throw_exception`` coroutine was removed: nothing in
    # this test referenced it.
    items = [
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="Hello!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="Hi!",
            created_at=datetime.now(),
        ),
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="I'm well, thank you!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="How are you doing?",
            created_at=datetime.now(),
        ),
    ]

    input_items = await simple_to_agent_input(items)
    assert len(input_items) == 3
    assert input_items[0] == {
        "content": [
            {
                "text": "Hello!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[1] == {
        "content": [
            {
                "text": "I'm well, thank you!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[2] == {
        "content": [
            {
                "text": "The user is referring to this in particular: \nHow are you doing?",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
388
389
async def test_input_item_converter_to_input_items_mixed():
    """User, assistant and widget items are each expected as one input item."""

    def user_text(text: str) -> dict:
        # Expected shape of a plain user text message.
        return {
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": text}],
        }

    def output_text(text: str) -> dict:
        # Expected shape of one assistant output-text part.
        return {
            "type": "output_text",
            "text": text,
            "annotations": [],
            "logprobs": None,
        }

    first_user = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Hello!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    second_user = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="I'm well, thank you!")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="How are you doing?",
        created_at=datetime.now(),
    )
    assistant = AssistantMessageItem(
        id="123",
        content=[
            AssistantMessageContent(text="How are you doing?"),
            AssistantMessageContent(text="Can't do that"),
        ],
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    widget = WidgetItem(
        id="wd_123",
        widget=Card(children=[Text(value="Hello, world!")]),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    input_items = await simple_to_agent_input(
        [first_user, second_user, assistant, widget]
    )

    assert len(input_items) == 4
    assert input_items[0] == user_text("Hello!")
    assert input_items[1] == user_text("I'm well, thank you!")
    assert input_items[2] == {
        "type": "message",
        "role": "assistant",
        "content": [
            output_text("How are you doing?"),
            output_text("Can't do that"),
        ],
    }

    # The widget is expected as a user message describing what was displayed.
    assert "type" in input_items[3]
    widget_item = cast(EasyInputMessageParam, input_items[3])
    assert widget_item.get("type") == "message"
    assert widget_item.get("role") == "user"
    text = widget_item.get("content")[0]["text"]  # type: ignore
    assert (
        "The following graphical UI widget (id: wd_123) was displayed to the user"
        in text
    )
    assert "Hello, world!" in text
    assert "created_at" not in text
478
479
async def test_input_item_converter_user_input_with_tags():
    """A custom tag converter's output is expected in a second user message."""

    class MyThreadItemConverter(ThreadItemConverter):
        async def tag_to_message_content(self, tag):
            # Combine the tag text with a value taken from its data payload.
            combined = tag.text + " " + tag.data["key"]
            return ResponseInputTextParam(type="input_text", text=combined)

    tag = UserMessageTagContent(
        text="Hello!", type="input_tag", id="hello", data={"key": "value"}
    )
    message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    converted = await MyThreadItemConverter().to_agent_input([message])

    mention_preamble = (
        "# User-provided context for @-mentions\n- When referencing resolved entities, use their canonical names **without** '@'.\n"
        + "- The '@' form appears only in user text and should not be echoed."
    )

    assert len(converted) == 2
    # The tag is expected inline in the user text as an @-mention.
    assert converted[0] == {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": "@Hello!"}],
    }
    # A second message is expected with the preamble and the converter output.
    assert converted[1] == {
        "type": "message",
        "role": "user",
        "content": [
            {"type": "input_text", "text": mention_preamble},
            {"type": "input_text", "text": "Hello! value"},
        ],
    }
529
530
async def test_input_item_converter_user_input_with_tags_throws_by_default():
    """``simple_to_agent_input`` is expected to reject tag content."""
    tag = UserMessageTagContent(text="Hello!", type="input_tag", id="hello", data={})
    message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input([message])
549
550
async def test_input_item_converter_generated_image_item():
    """A generated image is expected as a user message of text plus image."""
    image_item = GeneratedImageItem(
        id="img_item_1",
        thread_id=thread.id,
        created_at=datetime.now(),
        image=GeneratedImage(id="img_1", url="https://example.com/img.png"),
    )

    input_items = await simple_to_agent_input([image_item])
    assert len(input_items) == 1

    message = cast(dict, input_items[0])
    assert message.get("type") == "message"
    assert message.get("role") == "user"

    parts = cast(list, message.get("content"))
    caption, picture = parts[0], parts[1]

    assert caption.get("type") == "input_text"
    assert caption.get("text") == "The following image was generated by the agent."

    assert picture.get("type") == "input_image"
    assert picture.get("file_id") is None
    assert picture.get("image_url") == "https://example.com/img.png"
    assert picture.get("detail") == "auto"
575
576
async def test_input_item_converter_generated_image_item_without_image():
    """A generated-image item with no image is expected to yield no input."""
    imageless = GeneratedImageItem(
        id="img_item_1",
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    converted = await simple_to_agent_input([imageless])

    assert converted == []
588
589
async def test_input_item_converter_for_hidden_context_with_string_content():
    """String hidden context is expected wrapped in a user message."""
    hidden = HiddenContextItem(
        id="123",
        content="User pressed the red button",
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    # The default converter works for string content
    converted = await simple_to_agent_input([hidden])

    expected_text = "Hidden context for the agent (not shown to the user):\n<HiddenContext>\nUser pressed the red button\n</HiddenContext>"
    assert len(converted) == 1
    assert converted[0] == {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": expected_text}],
    }
613
614
async def test_input_item_converter_for_hidden_context_with_non_string_content():
    """Non-string hidden context needs a custom ``hidden_context_to_input``."""
    hidden_items = [
        HiddenContextItem(
            id="123",
            content={"harry": "potter", "hermione": "granger"},
            thread_id=thread.id,
            created_at=datetime.now(),
        )
    ]

    # Default converter should throw
    with pytest.raises(NotImplementedError):
        await simple_to_agent_input(hidden_items)

    class MyThreadItemConverter(ThreadItemConverter):
        async def hidden_context_to_input(self, item: HiddenContextItem):
            # Render only the dict keys, comma separated, inside a marker tag.
            content = ",".join(item.content.keys())
            text_part = ResponseInputTextParam(
                type="input_text", text=f"<HIDDEN>{content}</HIDDEN>"
            )
            return Message(type="message", content=[text_part], role="user")

    converted = await MyThreadItemConverter().to_agent_input(hidden_items)

    assert len(converted) == 1
    assert converted[0] == {
        "type": "message",
        "role": "user",
        "content": [
            {"type": "input_text", "text": "<HIDDEN>harry,hermione</HIDDEN>"},
        ],
    }
654
655
async def test_input_item_converter_with_client_tool_call():
    """A pending and a completed client tool call sharing a call id.

    The expected input is the user message, the task message, one function
    call and one function-call output.
    """
    tool_arguments = {"foo": "bar"}
    tool_output = {"success": True}

    user_message = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Call a client tool call xyz")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    task = TaskItem(
        id="tsk_123",
        created_at=datetime.now(),
        task=CustomTask(title="Called xyx"),
        thread_id=thread.id,
    )
    pending_call = ClientToolCallItem(
        id="ctc_123",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
    )
    completed_call = ClientToolCallItem(
        id="ctc_123_done",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
        status="completed",
        output={"success": True},
    )

    input_items = await simple_to_agent_input(
        [user_message, task, pending_call, completed_call]
    )

    assert len(input_items) == 4
    assert input_items[0] == {
        "type": "message",
        "role": "user",
        "content": [
            {"type": "input_text", "text": "Call a client tool call xyz"},
        ],
    }
    assert input_items[1] == {
        "type": "message",
        "role": "user",
        "content": [
            {
                "type": "input_text",
                "text": "A message was displayed to the user that the following task was performed:\n<Task>\nCalled xyx\n</Task>",
            },
        ],
    }
    assert input_items[2] == {
        "type": "function_call",
        "call_id": "ctc_123",
        "name": "xyz",
        "arguments": json.dumps(tool_arguments),
    }
    assert input_items[3] == {
        "type": "function_call_output",
        "call_id": "ctc_123",
        "output": json.dumps(tool_output),
    }
726
727
async def test_stream_agent_response_yields_context_events_without_streaming_events():
    """An event streamed through the context is yielded before any run event.

    After that first event the stream is expected to stay pending until the
    run is marked done, at which point iteration ends.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    added_event = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )

    await context.stream(added_event)

    response_streamer = stream_agent_response(context, result)
    first_event = await response_streamer.__anext__()

    assert first_event.type == "thread.item.added"

    # Schedule the next read; the task has not had a chance to run yet.
    future = asyncio.ensure_future(response_streamer.__anext__())
    assert future.done() is False

    result.done()

    # pytest.raises states the expectation directly and, unlike a bare
    # `assert False` inside `try`, still fails when assertions are disabled.
    with pytest.raises(StopAsyncIteration):
        await future

    assert future.done() is True
762
763
async def test_stream_agent_response_maps_events():
    """A context event and a raw text delta are both expected in the stream.

    The two events may arrive in either order, so only the set of their types
    is checked. The stream is then expected to stay pending until the run is
    marked done.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    added_event = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )

    await context.stream(added_event)
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta="Hello, world!",
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    response_streamer = stream_agent_response(context, result)
    event1 = await response_streamer.__anext__()
    event2 = await response_streamer.__anext__()

    assert {event1.type, event2.type} == {
        "thread.item.added",
        "thread.item.updated",
    }

    # Schedule the next read; the task has not had a chance to run yet.
    future = asyncio.ensure_future(response_streamer.__anext__())
    assert future.done() is False

    result.done()

    # pytest.raises states the expectation directly and, unlike a bare
    # `assert False` inside `try`, still fails when assertions are disabled.
    with pytest.raises(StopAsyncIteration):
        await future

    assert future.done() is True
816
817
@pytest.mark.parametrize(
    "raw_event,expected_event",
    [
        # Text delta -> content-part text delta update.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDeltaEvent(
                    type="response.output_text.delta",
                    delta="Hello, world!",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=0,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartTextDelta(
                    content_index=0,
                    delta="Hello, world!",
                ),
            ),
        ),
        # Content part added -> content-part added update.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseContentPartAddedEvent(
                    type="response.content_part.added",
                    part=ResponseOutputText(
                        type="output_text",
                        text="New content",
                        annotations=[],
                    ),
                    content_index=1,
                    item_id="123",
                    output_index=0,
                    sequence_number=1,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartAdded(
                    content_index=1,
                    content=AssistantMessageContent(text="New content", annotations=[]),
                ),
            ),
        ),
        # Text done -> content-part done update.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDoneEvent(
                    type="response.output_text.done",
                    text="Final text",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=2,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartDone(
                    content_index=0,
                    content=AssistantMessageContent(
                        text="Final text",
                        annotations=[],
                    ),
                ),
            ),
        ),
        # Annotation added (file citation) -> annotation-added update.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=ResponsesAnnotationFileCitation(
                        type="file_citation",
                        file_id="file_123",
                        filename="file.txt",
                        index=5,
                    ),
                    content_index=0,
                    item_id="123",
                    annotation_index=0,
                    output_index=0,
                    sequence_number=3,
                ),
            ),
            ThreadItemUpdatedEvent(
                item_id="123",
                update=AssistantMessageContentPartAnnotationAdded(
                    content_index=0,
                    annotation_index=0,
                    annotation=Annotation(
                        source=FileSource(filename="file.txt", title="file.txt"),
                        index=5,
                    ),
                ),
            ),
        ),
    ],
)
async def test_event_mapping(raw_event, expected_event):
    """Each raw response event is expected to map to exactly one thread event.

    A falsy ``expected_event`` means no event is expected at all.
    """
    agent_context = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    run.add_event(raw_event)
    run.done()

    events = await all_events(stream_agent_response(agent_context, run))

    expected = [expected_event] if expected_event else []
    assert events == expected
936
937
async def test_stream_agent_response_emits_annotation_added_events():
    """Typed annotation objects are expected as annotation-added updates.

    Three annotations are queued; only the container-file and URL citations
    are expected in the output, with annotation indices 0 and 1.
    """
    agent_context = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    item_id = "item_123"

    queued_annotations = [
        ResponsesAnnotationFileCitation(
            type="file_citation",
            file_id="file_invalid",
            filename="",
            index=0,
        ),
        ResponsesAnnotationContainerFileCitation(
            type="container_file_citation",
            container_id="container_1",
            file_id="file_123",
            filename="container.txt",
            start_index=0,
            end_index=3,
        ),
        ResponsesAnnotationURLCitation(
            type="url_citation",
            url="https://example.com",
            title="Example",
            start_index=1,
            end_index=5,
        ),
    ]
    # The position in the list serves as both the annotation index and the
    # sequence number of the raw event.
    for position, annotation in enumerate(queued_annotations):
        payload = Mock(
            type="response.output_text.annotation.added",
            annotation=annotation,
            content_index=0,
            item_id=item_id,
            annotation_index=position,
            output_index=0,
            sequence_number=position,
        )
        run.add_event(RawResponsesStreamEvent(type="raw_response_event", data=payload))
    run.done()

    events = await all_events(stream_agent_response(agent_context, run))

    container_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=0,
            annotation=Annotation(
                source=FileSource(filename="container.txt", title="container.txt"),
                index=3,
            ),
        ),
    )
    url_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=1,
            annotation=Annotation(
                source=URLSource(
                    url="https://example.com",
                    title="Example",
                ),
                index=5,
            ),
        ),
    )
    assert events == [container_update, url_update]
1021
1022
async def test_stream_agent_response_annotation_added_normalizes_annotations():
    """Annotations supplied as plain dicts are expected to convert too."""
    agent_context = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    item_id = "item_123"

    queued_annotations = [
        # Invalid file citation is dropped (empty filename)
        {
            "type": "file_citation",
            "file_id": "file_invalid",
            "filename": "",
            "index": 0,
        },
        # Container + URL citations are converted; indices are compacted (0, 1)
        # despite the dropped first event.
        {
            "type": "container_file_citation",
            "container_id": "container_1",
            "file_id": "file_123",
            "filename": "container.txt",
            "start_index": 0,
            "end_index": 3,
        },
        {
            "type": "url_citation",
            "url": "https://example.com",
            "title": "Example",
            "start_index": 1,
            "end_index": 5,
        },
    ]
    # The position in the list serves as both the annotation index and the
    # sequence number of the raw event.
    for position, annotation in enumerate(queued_annotations):
        payload = Mock(
            type="response.output_text.annotation.added",
            annotation=annotation,
            content_index=0,
            item_id=item_id,
            annotation_index=position,
            output_index=0,
            sequence_number=position,
        )
        run.add_event(RawResponsesStreamEvent(type="raw_response_event", data=payload))

    run.done()

    events = await all_events(stream_agent_response(agent_context, run))

    container_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=0,
            annotation=Annotation(
                source=FileSource(filename="container.txt", title="container.txt"),
                index=3,
            ),
        ),
    )
    url_update = ThreadItemUpdatedEvent(
        item_id=item_id,
        update=AssistantMessageContentPartAnnotationAdded(
            content_index=0,
            annotation_index=1,
            annotation=Annotation(
                source=URLSource(url="https://example.com", title="Example"),
                index=5,
            ),
        ),
    )
    assert events == [container_update, url_update]
1107
1108
async def test_custom_annotation_conversion_used_by_stream_agent_response():
    """A converter passed to ``stream_agent_response`` supplies annotations.

    The custom converter returns fixed annotations and records which hook was
    called; both the streamed annotation event and the final message item are
    expected to carry the converter's values.
    """

    class CustomResponseStreamConverter(ResponseStreamConverter):
        def __init__(self):
            super().__init__()
            # Names of the hooks invoked, in call order.
            self.calls: list[str] = []

        async def file_citation_to_annotation(self, file_citation):
            self.calls.append("file_citation")
            source = FileSource(
                filename="report.pdf",
                title="Usage Report",
                description="Usage report for the month of January",
            )
            return Annotation(source=source, index=111)

        async def url_citation_to_annotation(self, url_citation):
            self.calls.append("url_citation")
            source = URLSource(url="https://custom.example/url", title="Custom")
            return Annotation(source=source, index=222)

    agent_context = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    converter = CustomResponseStreamConverter()

    # First raw event: a streamed file-citation annotation on message m_1.
    annotation_added = RawResponsesStreamEvent(
        type="raw_response_event",
        data=Mock(
            type="response.output_text.annotation.added",
            annotation=ResponsesAnnotationFileCitation(
                type="file_citation",
                file_id="file_123",
                filename="file.txt",
                index=0,
            ),
            content_index=0,
            item_id="m_1",
            annotation_index=0,
            output_index=0,
            sequence_number=0,
        ),
    )
    # Second raw event: the finished message, carrying a URL citation.
    finished_message = ResponseOutputMessage(
        id="m_1",
        content=[
            ResponseOutputText(
                annotations=[
                    ResponsesAnnotationURLCitation(
                        type="url_citation",
                        url="https://example.com",
                        title="Example",
                        start_index=0,
                        end_index=7,
                    )
                ],
                text="Hello!",
                type="output_text",
            )
        ],
        role="assistant",
        status="completed",
        type="message",
    )
    item_done = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            item=finished_message,
            output_index=0,
            sequence_number=1,
        ),
    )

    run.add_event(annotation_added)
    run.add_event(item_done)
    run.done()

    events = await all_events(
        stream_agent_response(agent_context, run, converter=converter)
    )
    assert len(events) == 2
    annotation_event, done_event = events[0], events[1]

    assert isinstance(annotation_event, ThreadItemUpdatedEvent)
    assert annotation_event.item_id == "m_1"
    assert annotation_event.update == AssistantMessageContentPartAnnotationAdded(
        content_index=0,
        annotation_index=0,
        annotation=Annotation(
            source=FileSource(
                filename="report.pdf",
                title="Usage Report",
                description="Usage report for the month of January",
            ),
            index=111,
        ),
    )

    assert isinstance(done_event, ThreadItemDoneEvent)
    assert isinstance(done_event.item, AssistantMessageItem)
    assert done_event.item.id == "m_1"
    assert done_event.item.content == [
        AssistantMessageContent(
            text="Hello!",
            annotations=[
                Annotation(
                    source=URLSource(url="https://custom.example/url", title="Custom"),
                    index=222,
                )
            ],
        )
    ]
    assert converter.calls == ["file_citation", "url_citation"]
1228
1229
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
    """A tripped guardrail must retract every item already streamed.

    Three items are put in flight: one through a raw response event (id "1")
    and two streamed directly through the context (ids "2" and "3"). After an
    input or output guardrail is triggered on the result, the stream is
    expected to emit a ``thread.item.removed`` event for each of those ids and
    then propagate the guardrail exception to the caller.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[], type="output_text", text="Hello, world!"
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    await context.stream(
        ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="2",
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
    )

    await context.stream(
        ThreadItemDoneEvent(
            item=WidgetItem(
                id="3",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(id="text", value="Hello, world!")]),
            ),
        )
    )

    iterator = stream_agent_response(context, result)

    # Grab only the first three events and stop, leaving the stream open so
    # the guardrail can be triggered while the response is still in flight.
    events = []
    async for event in iterator:
        events.append(event)
        if len(events) == 3:
            break

    if throw_guardrail == "input":
        result.throw_input_guardrails()
    else:
        result.throw_output_guardrails()

    # pytest.raises fails the test on its own when nothing is raised and lets
    # any unrelated exception surface with its original traceback, instead of
    # funnelling everything through a broad ``except Exception``.
    with pytest.raises(
        (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered)
    ):
        async for event in iterator:
            events.append(event)

    deleted_item_ids = {
        event.item_id for event in events if event.type == "thread.item.removed"
    }
    assert deleted_item_ids == {"1", "2", "3"}
1308
1309
async def test_stream_agent_response_assistant_message_content_types():
    """Check how the annotations of a completed message are converted.

    The stream carries a completed file-search tool call followed by a
    completed assistant message with two text parts. The test expects exactly
    one ``ThreadItemDoneEvent`` (for the message): the file, container-file
    and URL citations on the first part each map to an ``Annotation``, the
    ``file_path`` annotation has no counterpart in the expected output, and
    the second part keeps an empty annotation list.
    """
    # Built once and reused below; the original constructed an extra context
    # here and immediately discarded it.
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseFileSearchToolCall(
                    id="fs_0",
                    queries=["Hello, world!"],
                    status="completed",
                    type="file_search_call",
                    results=[
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, world!",
                            score=1.0,
                        ),
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, friends!",
                            score=0.5,
                        ),
                    ],
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[
                                ResponsesAnnotationFileCitation(
                                    type="file_citation",
                                    file_id="f_123",
                                    index=0,
                                    filename="test.txt",
                                ),
                                ResponsesAnnotationContainerFileCitation(
                                    type="container_file_citation",
                                    container_id="container_1",
                                    file_id="f_456",
                                    filename="container.txt",
                                    start_index=0,
                                    end_index=3,
                                ),
                                ResponsesAnnotationURLCitation(
                                    type="url_citation",
                                    url="https://www.google.com",
                                    title="Google",
                                    start_index=0,
                                    end_index=10,
                                ),
                                ResponsesAnnotationFilePath(
                                    type="file_path",
                                    file_id="123",
                                    index=0,
                                ),
                            ],
                            text="Hello, world!",
                            type="output_text",
                        ),
                        ResponseOutputText(
                            annotations=[],
                            text="Can't do that",
                            type="output_text",
                        ),
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    result.done()

    events = await all_events(stream_agent_response(context, result))

    # Only the message is expected to surface; the file-search call is not
    # expected to produce an event of its own.
    assert len(events) == 1
    assert isinstance(events[0], ThreadItemDoneEvent)
    message = events[0].item
    assert isinstance(message, AssistantMessageItem)
    assert message.content == [
        AssistantMessageContent(
            annotations=[
                Annotation(
                    source=FileSource(
                        filename="test.txt",
                        title="test.txt",
                    ),
                    index=0,
                ),
                Annotation(
                    source=FileSource(
                        filename="container.txt",
                        title="container.txt",
                    ),
                    index=3,
                ),
                Annotation(
                    source=URLSource(
                        url="https://www.google.com",
                        title="Google",
                    ),
                    index=10,
                ),
            ],
            text="Hello, world!",
        ),
        AssistantMessageContent(text="Can't do that", annotations=[]),
    ]
    assert message.id == "1"
1442
1443
async def test_stream_agent_response_image_generation_events():
    """An image generation call yields an added event, then a done event.

    The added event is expected to carry a ``GeneratedImageItem`` without an
    image; the done event is expected to carry the same item id with the
    base64 result wrapped in a PNG data URL.
    """
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run_result = make_result()

    # Raw payloads for the start and the completion of the image call.
    call_started = Mock(
        type="response.output_item.added",
        item=Mock(type="image_generation_call", id="img_call_1"),
        output_index=0,
        sequence_number=0,
    )
    call_finished = Mock(
        type="response.output_item.done",
        item=Mock(type="image_generation_call", id="img_call_1", result="dGVzdA=="),
        output_index=0,
        sequence_number=1,
    )
    for payload in (call_started, call_finished):
        run_result.add_event(
            RawResponsesStreamEvent(type="raw_response_event", data=payload)
        )
    run_result.done()

    event_stream = stream_agent_response(agent_context, run_result)

    added = await anext(event_stream)
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, GeneratedImageItem)
    assert added.item.type == "generated_image"
    assert added.item.id == "message_id"
    assert added.item.image is None

    finished = await anext(event_stream)
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, GeneratedImageItem)
    assert finished.item.id == added.item.id
    expected_image = GeneratedImage(
        id="img_call_1", url="data:image/png;base64,dGVzdA=="
    )
    assert finished.item.image == expected_image

    # Nothing else may follow the done event.
    with pytest.raises(StopAsyncIteration):
        await anext(event_stream)
1494
1495
async def test_stream_agent_response_image_generation_events_with_custom_converter():
    """A converter's ``base64_image_to_url`` override decides the image URL.

    The test expects the override to be called once, with the call id, the
    base64 result and no partial-image index, and its return value to be used
    as the URL of the final ``GeneratedImage``.
    """

    class RecordingConverter(ResponseStreamConverter):
        """Converter that logs each conversion request and returns a fixed URL."""

        def __init__(self):
            super().__init__()
            self.calls: list[tuple[str, str, int | None]] = []

        async def base64_image_to_url(
            self,
            image_id: str,
            base64_image: str,
            partial_image_index: int | None = None,
        ) -> str:
            request = (image_id, base64_image, partial_image_index)
            self.calls.append(request)
            return f"https://example.com/{image_id}"

    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run_result = make_result()

    call_started = Mock(
        type="response.output_item.added",
        item=Mock(type="image_generation_call", id="img_call_1"),
        output_index=0,
        sequence_number=0,
    )
    call_finished = Mock(
        type="response.output_item.done",
        item=Mock(type="image_generation_call", id="img_call_1", result="dGVzdA=="),
        output_index=0,
        sequence_number=1,
    )
    for payload in (call_started, call_finished):
        run_result.add_event(
            RawResponsesStreamEvent(type="raw_response_event", data=payload)
        )
    run_result.done()

    converter = RecordingConverter()
    event_stream = stream_agent_response(
        agent_context, run_result, converter=converter
    )

    added = await anext(event_stream)
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, GeneratedImageItem)
    assert added.item.image is None

    finished = await anext(event_stream)
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, GeneratedImageItem)
    assert converter.calls == [("img_call_1", "dGVzdA==", None)]
    expected_image = GeneratedImage(
        id="img_call_1", url="https://example.com/img_call_1"
    )
    assert finished.item.image == expected_image

    # Nothing else may follow the done event.
    with pytest.raises(StopAsyncIteration):
        await anext(event_stream)
1558
1559
async def test_stream_agent_response_image_generation_partial_progress():
    """A partial image surfaces as an update event carrying progress.

    With a converter configured for three partial images, the partial image
    with index 1 is expected to report a progress of one third, between the
    added and the done events.
    """
    agent_context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run_result = make_result()

    # Raw payloads in stream order: call started, one partial image, call done.
    payloads = (
        Mock(
            type="response.output_item.added",
            item=Mock(type="image_generation_call", id="img_call_1"),
            output_index=0,
            sequence_number=0,
        ),
        Mock(
            type="response.image_generation_call.partial_image",
            partial_image_b64="dGVzdA==",
            partial_image_index=1,
            item_id="img_call_1",
            output_index=0,
            sequence_number=1,
        ),
        Mock(
            type="response.output_item.done",
            item=Mock(
                type="image_generation_call", id="img_call_1", result="dGVzdA=="
            ),
            output_index=0,
            sequence_number=2,
        ),
    )
    for payload in payloads:
        run_result.add_event(
            RawResponsesStreamEvent(type="raw_response_event", data=payload)
        )
    run_result.done()

    converter = ResponseStreamConverter(partial_images=3)
    collected = await all_events(
        stream_agent_response(agent_context, run_result, converter=converter)
    )

    assert len(collected) == 3
    first, second, third = collected
    expected_image = GeneratedImage(
        id="img_call_1", url="data:image/png;base64,dGVzdA=="
    )

    assert isinstance(first, ThreadItemAddedEvent)
    assert isinstance(first.item, GeneratedImageItem)

    assert isinstance(second, ThreadItemUpdatedEvent)
    assert isinstance(second.update, GeneratedImageUpdated)
    assert second.update.progress == pytest.approx(1 / 3)
    assert second.update.image == expected_image

    assert isinstance(third, ThreadItemDoneEvent)
    assert isinstance(third.item, GeneratedImageItem)
    assert third.item.image == expected_image
1628
1629
async def test_workflow_streams_first_thought():
    """Only the first reasoning summary is expected to stream incrementally.

    The first summary produces a task-added event on its first delta and
    task-updated events afterwards; the second summary is expected to appear
    as a single task-added event carrying its full text.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    delta_type = "response.reasoning_summary_text.delta"
    done_type = "response.reasoning_summary_text.done"

    def push(data):
        # Wrap a payload in a raw response event and queue it on the result.
        result.add_event(
            RawResponsesStreamEvent(type="raw_response_event", data=data)
        )

    def summary_text(event_type, index, **fields):
        # Build a mocked reasoning-summary payload for item "resp_1".
        return Mock(type=event_type, item_id="resp_1", summary_index=index, **fields)

    push(
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(
                id="resp_1",
                summary=[],
                type="reasoning",
            ),
            output_index=0,
            sequence_number=0,
        )
    )

    # first thought
    push(summary_text(delta_type, 0, delta="Think"))
    push(summary_text(delta_type, 0, delta="ing 1"))
    push(summary_text(done_type, 0, text="Thinking 1"))

    # second thought
    push(summary_text(delta_type, 1, delta="Think"))
    push(summary_text(delta_type, 1, delta="ing 2"))
    push(summary_text(done_type, 1, text="Thinking 2"))

    result.done()
    stream = stream_agent_response(context, result)

    # Workflow added
    event = await anext(stream)
    assert isinstance(event, ThreadItemAddedEvent)
    workflow_item = context.workflow_item
    assert workflow_item is not None
    assert workflow_item.workflow.type == "reasoning"
    assert len(workflow_item.workflow.tasks) == 0
    assert event == ThreadItemAddedEvent(item=context.workflow_item)

    # First thought added
    event = await anext(stream)
    assert context.workflow_item is not None
    assert len(context.workflow_item.workflow.tasks) == 1
    assert isinstance(event, ThreadItemUpdatedEvent)
    expected = ThreadItemUpdatedEvent(
        item_id=context.workflow_item.id,
        update=WorkflowTaskAdded(task=ThoughtTask(content="Think"), task_index=0),
    )
    assert event == expected

    # First thought delta, then first thought done: both are expected to
    # report the complete text as an update of task 0.
    for _ in range(2):
        event = await anext(stream)
        assert context.workflow_item is not None
        assert len(context.workflow_item.workflow.tasks) == 1
        assert isinstance(event, ThreadItemUpdatedEvent)
        expected = ThreadItemUpdatedEvent(
            item_id=context.workflow_item.id,
            update=WorkflowTaskUpdated(
                task=ThoughtTask(content="Thinking 1"),
                task_index=0,
            ),
        )
        assert event == expected

    # Second thought added (not streamed)
    event = await anext(stream)
    assert context.workflow_item is not None
    assert len(context.workflow_item.workflow.tasks) == 2
    assert isinstance(event, ThreadItemUpdatedEvent)
    expected = ThreadItemUpdatedEvent(
        item_id=context.workflow_item.id,
        update=WorkflowTaskAdded(
            task=ThoughtTask(content="Thinking 2"),
            task_index=1,
        ),
    )
    assert event == expected

    # Drain whatever is left so the stream runs to completion.
    async for _ in stream:
        pass
1789
1790
async def test_workflow_ends_on_message():
    """A reasoning workflow is expected to close when a message item arrives.

    After one completed thought, an assistant message is added; the test
    expects a done event for the workflow, the context's ``workflow_item`` to
    be cleared, a ``DurationSummary`` on the workflow and ``expanded`` False.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    def push(data):
        # Wrap a payload in a raw response event and queue it on the result.
        result.add_event(
            RawResponsesStreamEvent(type="raw_response_event", data=data)
        )

    # first thought
    push(
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(
                id="resp_1",
                summary=[],
                type="reasoning",
            ),
            output_index=0,
            sequence_number=0,
        )
    )
    push(
        Mock(
            type="response.reasoning_summary_text.done",
            item_id="resp_1",
            summary_index=0,
            text="Thinking 1",
        )
    )

    # not reasoning
    push(
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        )
    )

    result.done()
    stream = stream_agent_response(context, result)

    # Workflow added
    event = await anext(stream)
    assert isinstance(event, ThreadItemAddedEvent)
    workflow_item = context.workflow_item
    assert workflow_item is not None
    assert workflow_item.workflow.type == "reasoning"
    assert len(workflow_item.workflow.tasks) == 0
    assert event == ThreadItemAddedEvent(item=context.workflow_item)

    # First thought done
    event = await anext(stream)
    assert context.workflow_item is not None
    assert len(context.workflow_item.workflow.tasks) == 1
    assert isinstance(event, ThreadItemUpdatedEvent)
    expected = ThreadItemUpdatedEvent(
        item_id=context.workflow_item.id,
        update=WorkflowTaskAdded(
            task=ThoughtTask(content="Thinking 1"),
            task_index=0,
        ),
    )
    assert event == expected

    # Workflow ended
    event = await anext(stream)
    assert isinstance(event, ThreadItemDoneEvent)
    assert event.item.type == "workflow"
    assert context.workflow_item is None
    # Summary and expanded are handled by the end_workflow method
    finished_workflow = event.item.workflow
    assert isinstance(finished_workflow.summary, DurationSummary)
    assert finished_workflow.expanded is False

    # Drain whatever is left so the stream runs to completion.
    async for _ in stream:
        pass
1882
1883
async def test_existing_workflow_summary_not_overwritten_on_automatic_end():
    """A workflow that already has a summary is expected to keep it.

    The context starts with a custom workflow carrying a ``CustomSummary``;
    when an assistant message arrives, the test expects the workflow's done
    event to still carry that same summary.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # Pre-existing workflow with a summary set before streaming starts.
    preset_workflow = Workflow(
        type="custom", tasks=[], summary=CustomSummary(title="Test")
    )
    context.workflow_item = WorkflowItem(
        id="wf_1",
        created_at=datetime.now(),
        workflow=preset_workflow,
        thread_id=thread.id,
    )

    message_added = ResponseOutputItemAddedEvent(
        type="response.output_item.added",
        item=ResponseOutputMessage(
            id="m_1",
            content=[],
            role="assistant",
            status="in_progress",
            type="message",
        ),
        output_index=0,
        sequence_number=0,
    )
    result.add_event(
        RawResponsesStreamEvent(type="raw_response_event", data=message_added)
    )

    result.done()
    stream = stream_agent_response(context, result)

    first_event = await anext(stream)

    assert isinstance(first_event, ThreadItemDoneEvent)
    assert context.workflow_item is None
    assert first_event.item.type == "workflow"
    assert first_event.item.workflow.summary == CustomSummary(title="Test")

    # Drain whatever is left so the stream runs to completion.
    async for _ in stream:
        pass
1929