openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
ee4063b8bc4c927542e2cdbc69ba11b76d28d2a3

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_agents.py

1310 lines · mode code

1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from typing import cast
6from unittest.mock import AsyncMock, Mock
7
8import pytest
9from agents import (
10 Agent,
11 GuardrailFunctionOutput,
12 InputGuardrail,
13 InputGuardrailResult,
14 InputGuardrailTripwireTriggered,
15 OutputGuardrail,
16 OutputGuardrailResult,
17 OutputGuardrailTripwireTriggered,
18 RawResponsesStreamEvent,
19 RunContextWrapper,
20 RunItemStreamEvent,
21 Runner,
22 RunResultStreaming,
23 StreamEvent,
24 ToolCallItem,
25)
26from agents._run_impl import QueueCompleteSentinel
27from openai.types.responses import (
28 EasyInputMessageParam,
29 ResponseFileSearchToolCall,
30 ResponseInputContentParam,
31 ResponseInputTextParam,
32 ResponseOutputItemAddedEvent,
33 ResponseOutputItemDoneEvent,
34 ResponseOutputMessage,
35 ResponseReasoningItem,
36)
37from openai.types.responses.response_content_part_added_event import (
38 ResponseContentPartAddedEvent,
39)
40from openai.types.responses.response_file_search_tool_call import Result
41from openai.types.responses.response_output_text import (
42 AnnotationFileCitation as ResponsesAnnotationFileCitation,
43)
44from openai.types.responses.response_output_text import (
45 AnnotationFilePath as ResponsesAnnotationFilePath,
46)
47from openai.types.responses.response_output_text import (
48 AnnotationURLCitation as ResponsesAnnotationURLCitation,
49)
50from openai.types.responses.response_output_text import (
51 ResponseOutputText,
52)
53from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
54from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
55
56from chatkit.agents import (
57 AgentContext,
58 ThreadItemConverter,
59 accumulate_text,
60 simple_to_agent_input,
61 stream_agent_response,
62)
63from chatkit.types import (
64 Annotation,
65 AssistantMessageContent,
66 AssistantMessageContentPartAdded,
67 AssistantMessageContentPartDone,
68 AssistantMessageContentPartTextDelta,
69 AssistantMessageItem,
70 Attachment,
71 ClientToolCallItem,
72 CustomSummary,
73 CustomTask,
74 DurationSummary,
75 FileSource,
76 InferenceOptions,
77 Page,
78 TaskItem,
79 ThoughtTask,
80 Thread,
81 ThreadItemAddedEvent,
82 ThreadItemDoneEvent,
83 ThreadItemUpdated,
84 ThreadStreamEvent,
85 URLSource,
86 UserMessageItem,
87 UserMessageTagContent,
88 UserMessageTextContent,
89 WidgetItem,
90 Workflow,
91 WorkflowItem,
92 WorkflowTaskAdded,
93 WorkflowTaskUpdated,
94)
95from chatkit.widgets import Card, Text
96
# Thread instance shared by the tests in this module.
thread = Thread(id="123", title="Test", created_at=datetime.now(), items=Page())


def _fake_generate_item_id(item_type, thread, context):
    """Return a deterministic id built only from ``item_type``."""
    return f"{item_type}_id"


# Store double: ids are deterministic, loading yields an empty page, and
# adding an item is a recorded async no-op.
mock_store = Mock()
mock_store.generate_item_id = _fake_generate_item_id
mock_store.load_thread_items = AsyncMock(return_value=Page())
mock_store.add_thread_item = AsyncMock()
103
104
class RunResult(RunResultStreaming):
    """Streaming run result that tests drive by hand.

    Events are pushed straight onto the inherited ``_event_queue`` and the
    run is finished (optionally with a stored guardrail exception) on demand.
    """

    def add_event(self, event: StreamEvent):
        """Put ``event`` on the event queue without waiting."""
        self._event_queue.put_nowait(event)

    def done(self):
        """Flag the run as complete and enqueue the completion sentinel."""
        self.is_complete = True
        self._event_queue.put_nowait(QueueCompleteSentinel())

    def throw_input_guardrails(self):
        """Store a tripped input-guardrail exception, then finish the run."""
        tripped = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = InputGuardrailResult(
            guardrail=Mock(spec=InputGuardrail),
            output=tripped,
        )
        self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
        self.done()

    def throw_output_guardrails(self):
        """Store a tripped output-guardrail exception, then finish the run."""
        tripped = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = OutputGuardrailResult(
            guardrail=Mock(spec=OutputGuardrail),
            output=tripped,
            agent=Mock(spec=Agent),
            agent_output=None,
        )
        self._stored_exception = OutputGuardrailTripwireTriggered(guardrail_result)
        self.done()
140
141
def make_result() -> RunResult:
    """Create a blank ``RunResult`` for a test to feed events into.

    All collections start empty, ``is_complete`` is ``False``, no exception
    is stored, and both queues are new ``asyncio.Queue`` instances.
    """
    initial_state = dict(
        context_wrapper=Mock(spec=RunContextWrapper),
        input=[],
        new_items=[],
        raw_responses=[],
        final_output=None,
        input_guardrail_results=[],
        output_guardrail_results=[],
        tool_input_guardrail_results=[],
        tool_output_guardrail_results=[],
        current_agent=Agent(name="test"),
        current_turn=0,
        max_turns=10,
        trace=None,
        is_complete=False,
        _current_agent_output_schema=None,
        _event_queue=asyncio.Queue(),
        _input_guardrail_queue=asyncio.Queue(),
        _output_guardrails_task=None,
        _run_impl_task=None,
        _stored_exception=None,
    )
    return RunResult(**initial_state)
165
166
async def all_events(
    events: AsyncIterator[ThreadStreamEvent],
) -> list[ThreadStreamEvent]:
    """Exhaust ``events`` and return everything it yielded, in order."""
    collected: list[ThreadStreamEvent] = []
    async for item in events:
        collected.append(item)
    return collected
171
172
async def test_returns_widget_item():
    """A widget streamed as a single value yields exactly one done event."""
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)
    await ctx.stream_widget(Card(children=[Text(value="Hello, world!")]))
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 1
    only = emitted[0]
    assert isinstance(only, ThreadItemDoneEvent)
    assert isinstance(only.item, WidgetItem)
    assert only.item.widget == Card(children=[Text(value="Hello, world!")])
195
196
async def test_returns_widget_item_generator():
    """A widget generator yields added, text-delta, and done events."""

    def card_showing(prefix_len: int) -> Card:
        # Same card each time; only the visible prefix of the text changes.
        return Card(children=[Text(id="text", value="Hello, world"[:prefix_len])])

    async def growing_widget():
        for prefix_len in (0, 12):
            yield card_showing(prefix_len)

    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem)))

    await ctx.stream_widget(growing_widget())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3

    added = emitted[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    updated = emitted[1]
    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.streaming_text.value_delta"
    assert updated.update.component_id == "text"
    assert updated.update.delta == "Hello, world"

    finished = emitted[2]
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
238
239
async def test_returns_widget_full_replace_generator():
    """A structurally different second widget yields a root-updated event."""

    def first_card() -> Card:
        return Card(children=[Text(id="text", value="Hello!")])

    def second_card() -> Card:
        return Card(children=[Text(key="other text", value="World!", streaming=False)])

    async def replacing_widget():
        yield first_card()
        yield second_card()

    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem)))

    await ctx.stream_widget(replacing_widget())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3

    added = emitted[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == first_card()

    updated = emitted[1]
    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == second_card()

    finished = emitted[2]
    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == second_card()
279
280
async def test_accumulate_text():
    """``accumulate_text`` grows a ``Text`` widget from output-text deltas.

    The expected sequence is the initial widget, one widget per delta with
    the text accumulated so far, and a final copy with ``streaming=False``.
    """

    def delta(text: str) -> RawResponsesStreamEvent:
        # Wrap ``text`` in the raw output-text delta event shape.
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta=text,
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )

    # The stream is fed from the in-memory double only; no real agent run is
    # started by this test.
    result = make_result()
    result.add_event(delta("Hello, "))
    result.add_event(delta("world!"))

    result.done()

    events = [
        event
        async for event in accumulate_text(
            result.stream_events(), Text(key="text", value="", streaming=True)
        )
    ]
    assert events == [
        Text(key="text", value="", streaming=True),
        Text(key="text", value="Hello, ", streaming=True),
        Text(key="text", value="Hello, world!", streaming=True),
        Text(key="text", value="Hello, world!", streaming=False),
    ]
317
318
async def test_input_item_converter_quotes_last_user_message():
    """Only the last user message's quoted text becomes an extra input item.

    Two user messages carry ``quoted_text``; the converted input holds both
    messages plus one trailing message quoting the second one.
    """
    items = [
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="Hello!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="Hi!",
            created_at=datetime.now(),
        ),
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="I'm well, thank you!")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="How are you doing?",
            created_at=datetime.now(),
        ),
    ]

    input_items = await simple_to_agent_input(items)
    assert len(input_items) == 3
    assert input_items[0] == {
        "content": [
            {
                "text": "Hello!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[1] == {
        "content": [
            {
                "text": "I'm well, thank you!",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
    assert input_items[2] == {
        "content": [
            {
                "text": "The user is referring to this in particular: \nHow are you doing?",
                "type": "input_text",
            },
        ],
        "role": "user",
        "type": "message",
    }
378
379
async def test_input_item_converter_to_input_items_mixed():
    """User, assistant, and widget items each convert to one input item."""

    def user_item(text: str, quoted: str) -> UserMessageItem:
        return UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text=quoted,
            created_at=datetime.now(),
        )

    def expected_user_message(text: str) -> dict:
        return {
            "type": "message",
            "role": "user",
            "content": [{"type": "input_text", "text": text}],
        }

    def expected_output_text(text: str) -> dict:
        return {
            "type": "output_text",
            "text": text,
            "annotations": [],
            "logprobs": None,
        }

    thread_items = [
        user_item("Hello!", "Hi!"),
        user_item("I'm well, thank you!", "How are you doing?"),
        AssistantMessageItem(
            id="123",
            content=[
                AssistantMessageContent(text="How are you doing?"),
                AssistantMessageContent(text="Can't do that"),
            ],
            thread_id=thread.id,
            created_at=datetime.now(),
        ),
        WidgetItem(
            id="wd_123",
            widget=Card(children=[Text(value="Hello, world!")]),
            thread_id=thread.id,
            created_at=datetime.now(),
        ),
    ]

    converted = await simple_to_agent_input(thread_items)
    assert len(converted) == 4
    assert converted[0] == expected_user_message("Hello!")
    assert converted[1] == expected_user_message("I'm well, thank you!")
    assert converted[2] == {
        "type": "message",
        "role": "assistant",
        "content": [
            expected_output_text("How are you doing?"),
            expected_output_text("Can't do that"),
        ],
    }

    # The widget is described to the model as a user message.
    assert "type" in converted[3]
    widget_message = cast(EasyInputMessageParam, converted[3])
    assert widget_message.get("type") == "message"
    assert widget_message.get("role") == "user"
    description = widget_message.get("content")[0]["text"]  # type: ignore
    assert (
        "The following graphical UI widget (id: wd_123) was displayed to the user"
        in description
    )
    assert "Hello, world!" in description
    assert "created_at" not in description
468
469
async def test_input_item_converter_user_input_with_tags():
    """A converter overriding ``tag_to_message_content`` resolves tags.

    The tag shows up as ``@Hello!`` in the user message, followed by a
    second message holding the @-mention preamble and the converted tag.
    """

    class TagAwareConverter(ThreadItemConverter):
        def tag_to_message_content(self, tag):
            rendered = " ".join((tag.text, tag.data["key"]))
            return ResponseInputTextParam(type="input_text", text=rendered)

    tagged_message = UserMessageItem(
        id="123",
        content=[
            UserMessageTagContent(
                text="Hello!", type="input_tag", id="hello", data={"key": "value"}
            )
        ],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    converted = await TagAwareConverter().to_agent_input([tagged_message])

    mention_preamble = (
        "# User-provided context for @-mentions\n- When referencing resolved entities, use their canonical names **without** '@'.\n"
        "- The '@' form appears only in user text and should not be echoed."
    )

    assert len(converted) == 2
    assert converted[0] == {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": "@Hello!"}],
    }
    assert converted[1] == {
        "type": "message",
        "role": "user",
        "content": [
            {"type": "input_text", "text": mention_preamble},
            {"type": "input_text", "text": "Hello! value"},
        ],
    }
519
520
async def test_input_item_converter_user_input_with_tags_throws_by_default():
    """``simple_to_agent_input`` raises ``NotImplementedError`` for tags."""
    tag = UserMessageTagContent(text="Hello!", type="input_tag", id="hello", data={})
    tagged_message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input([tagged_message])
539
540
async def test_input_item_converter_with_client_tool_call():
    """Task and client tool call items convert to the expected input items.

    A pending call becomes a ``function_call``; the completed call with the
    same ``call_id`` becomes a ``function_call_output``.
    """
    call_arguments = {"foo": "bar"}
    call_output = {"success": True}

    thread_items = [
        UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text="Call a client tool call xyz")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text="Hi!",
            created_at=datetime.now(),
        ),
        TaskItem(
            id="tsk_123",
            created_at=datetime.now(),
            task=CustomTask(title="Called xyx"),
            thread_id=thread.id,
        ),
        ClientToolCallItem(
            id="ctc_123",
            thread_id=thread.id,
            created_at=datetime.now(),
            name="xyz",
            arguments={"foo": "bar"},
            call_id="ctc_123",
        ),
        ClientToolCallItem(
            id="ctc_123_done",
            thread_id=thread.id,
            created_at=datetime.now(),
            name="xyz",
            arguments={"foo": "bar"},
            call_id="ctc_123",
            status="completed",
            output={"success": True},
        ),
    ]

    expected = [
        {
            "type": "message",
            "role": "user",
            "content": [
                {"type": "input_text", "text": "Call a client tool call xyz"},
            ],
        },
        {
            "type": "message",
            "role": "user",
            "content": [
                {
                    "type": "input_text",
                    "text": "A message was displayed to the user that the following task was performed:\n<Task>\nCalled xyx\n</Task>",
                },
            ],
        },
        {
            "type": "function_call",
            "call_id": "ctc_123",
            "name": "xyz",
            "arguments": json.dumps(call_arguments),
        },
        {
            "type": "function_call_output",
            "call_id": "ctc_123",
            "output": json.dumps(call_output),
        },
    ]

    converted = await simple_to_agent_input(thread_items)
    assert len(converted) == 4
    for position in range(4):
        assert converted[position] == expected[position]
611
612
async def test_stream_agent_response_yields_context_events_without_streaming_events():
    """Context events are yielded even when the run has produced no events.

    After the context event is consumed, the next read stays pending until
    the run is marked done, at which point the stream ends.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )

    await context.stream(added)

    response_streamer = stream_agent_response(context, result)
    first = await response_streamer.__anext__()

    assert first.type == "thread.item.added"

    # Nothing else is queued, so the next read must not have resolved yet.
    pending = asyncio.ensure_future(response_streamer.__anext__())
    assert pending.done() is False

    result.done()

    # Completing the run ends the stream instead of yielding another event.
    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done() is True
647
648
async def test_stream_agent_response_maps_events():
    """A context event and a run event are both surfaced by the stream.

    The two events may arrive in either order, so only the set of their
    types is checked. The stream then stays pending until the run is done.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )

    await context.stream(added)
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta="Hello, world!",
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    response_streamer = stream_agent_response(context, result)
    event1 = await response_streamer.__anext__()
    event2 = await response_streamer.__anext__()

    assert {event1.type, event2.type} == {
        "thread.item.added",
        "thread.item.updated",
    }

    # Both queued events are consumed, so the next read must still be pending.
    pending = asyncio.ensure_future(response_streamer.__anext__())
    assert pending.done() is False

    result.done()

    # Completing the run ends the stream instead of yielding another event.
    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done() is True
701
702
def _raw_response_event(data) -> RawResponsesStreamEvent:
    """Wrap ``data`` in a ``raw_response_event`` stream event."""
    return RawResponsesStreamEvent(type="raw_response_event", data=data)


@pytest.mark.parametrize(
    "raw_event,expected_event",
    [
        # Text delta -> content-part text delta update.
        (
            _raw_response_event(
                ResponseTextDeltaEvent(
                    type="response.output_text.delta",
                    delta="Hello, world!",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=0,
                )
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartTextDelta(
                    content_index=0,
                    delta="Hello, world!",
                ),
            ),
        ),
        # Content part added -> content-part added update.
        (
            _raw_response_event(
                ResponseContentPartAddedEvent(
                    type="response.content_part.added",
                    part=ResponseOutputText(
                        type="output_text",
                        text="New content",
                        annotations=[],
                    ),
                    content_index=1,
                    item_id="123",
                    output_index=0,
                    sequence_number=1,
                )
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartAdded(
                    content_index=1,
                    content=AssistantMessageContent(text="New content", annotations=[]),
                ),
            ),
        ),
        # Text done -> content-part done update.
        (
            _raw_response_event(
                ResponseTextDoneEvent(
                    type="response.output_text.done",
                    text="Final text",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=2,
                )
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartDone(
                    content_index=0,
                    content=AssistantMessageContent(
                        text="Final text",
                        annotations=[],
                    ),
                ),
            ),
        ),
        # Annotation added -> no thread event expected.
        (
            _raw_response_event(
                Mock(
                    type="response.output_text.annotation.added",
                    annotation=ResponsesAnnotationFileCitation(
                        type="file_citation",
                        file_id="file_123",
                        filename="file.txt",
                        index=5,
                    ),
                    content_index=0,
                    item_id="123",
                    annotation_index=0,
                    output_index=0,
                    sequence_number=3,
                )
            ),
            None,
        ),
    ],
)
async def test_event_mapping(raw_event, expected_event):
    """Each raw response event maps to the expected thread event, or none."""
    ctx = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    run = make_result()
    run.add_event(raw_event)
    run.done()

    emitted = await all_events(stream_agent_response(ctx, run))

    wanted = [expected_event] if expected_event else []
    assert emitted == wanted
811
812
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
    """A tripped guardrail removes already-streamed items and then raises.

    Three items (ids "1", "2", "3") are streamed before the guardrail
    trips. The stream must then emit a ``thread.item.removed`` event for
    each of them and propagate the guardrail exception.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[], type="output_text", text="Hello, world!"
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    await context.stream(
        ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="2",
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
    )

    await context.stream(
        ThreadItemDoneEvent(
            item=WidgetItem(
                id="3",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(id="text", value="Hello, world!")]),
            ),
        )
    )

    iterator = stream_agent_response(context, result)

    # Consume the three events queued above before tripping the guardrail.
    events = []
    async for event in iterator:
        events.append(event)
        if len(events) == 3:
            break

    if throw_guardrail == "input":
        result.throw_input_guardrails()
    else:
        result.throw_output_guardrails()

    # Draining the rest must end with the guardrail exception; any other
    # exception propagates with its own traceback.
    with pytest.raises(
        (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered)
    ):
        async for event in iterator:
            events.append(event)

    deleted_item_ids = {
        event.item_id for event in events if event.type == "thread.item.removed"
    }
    assert deleted_item_ids == {"1", "2", "3"}
891
892
async def test_stream_agent_response_assistant_message_content_types():
    """A finished output message converts to an assistant message item.

    The file-search call yields no event of its own. Of the three response
    annotations on the first part, the expected content holds only the file
    citation and the URL citation; the file-path annotation is absent.
    """
    result = make_result()

    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseFileSearchToolCall(
                    id="fs_0",
                    queries=["Hello, world!"],
                    status="completed",
                    type="file_search_call",
                    results=[
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, world!",
                            score=1.0,
                        ),
                        Result(
                            file_id="f_123",
                            filename="test.txt",
                            text="Hello, friends!",
                            score=0.5,
                        ),
                    ],
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[
                                ResponsesAnnotationFileCitation(
                                    type="file_citation",
                                    file_id="f_123",
                                    index=0,
                                    filename="test.txt",
                                ),
                                ResponsesAnnotationURLCitation(
                                    type="url_citation",
                                    url="https://www.google.com",
                                    title="Google",
                                    start_index=0,
                                    end_index=10,
                                ),
                                ResponsesAnnotationFilePath(
                                    type="file_path",
                                    file_id="123",
                                    index=0,
                                ),
                            ],
                            text="Hello, world!",
                            type="output_text",
                        ),
                        ResponseOutputText(
                            annotations=[],
                            text="Can't do that",
                            type="output_text",
                        ),
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )

    result.done()

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    events = await all_events(stream_agent_response(context, result))
    assert len(events) == 1
    assert isinstance(events[0], ThreadItemDoneEvent)
    message = events[0].item
    assert isinstance(message, AssistantMessageItem)
    assert message.content == [
        AssistantMessageContent(
            annotations=[
                Annotation(
                    source=FileSource(
                        filename="test.txt",
                        title="test.txt",
                    ),
                    index=0,
                ),
                # The URL citation's ``end_index`` (10) is the expected index.
                Annotation(
                    source=URLSource(
                        url="https://www.google.com",
                        title="Google",
                    ),
                    index=10,
                ),
            ],
            text="Hello, world!",
        ),
        AssistantMessageContent(text="Can't do that", annotations=[]),
    ]
    assert message.id == "1"
1010
1011
async def test_workflow_streams_first_thought():
    """The first reasoning summary streams incrementally; the second does not.

    Summary 0 produces a task-added event followed by task updates, while
    summary 1 only produces a single task-added event with its full text.
    """

    def raw(data) -> RawResponsesStreamEvent:
        return RawResponsesStreamEvent(type="raw_response_event", data=data)

    def summary_delta(summary_index: int, piece: str) -> RawResponsesStreamEvent:
        return raw(
            Mock(
                type="response.reasoning_summary_text.delta",
                item_id="resp_1",
                summary_index=summary_index,
                delta=piece,
            )
        )

    def summary_done(summary_index: int, full_text: str) -> RawResponsesStreamEvent:
        return raw(
            Mock(
                type="response.reasoning_summary_text.done",
                item_id="resp_1",
                summary_index=summary_index,
                text=full_text,
            )
        )

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    reasoning_started = raw(
        ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseReasoningItem(
                id="resp_1",
                summary=[],
                type="reasoning",
            ),
            output_index=0,
            sequence_number=0,
        )
    )
    queued = [
        # first thought
        reasoning_started,
        summary_delta(0, "Think"),
        summary_delta(0, "ing 1"),
        summary_done(0, "Thinking 1"),
        # second thought
        summary_delta(1, "Think"),
        summary_delta(1, "ing 2"),
        summary_done(1, "Thinking 2"),
    ]
    for queued_event in queued:
        result.add_event(queued_event)

    result.done()
    stream = stream_agent_response(context, result)

    async def expect_update(task_count: int, update) -> None:
        # Pull one event and compare it with the expected workflow update.
        emitted = await anext(stream)
        workflow_item = context.workflow_item
        assert workflow_item is not None
        assert len(workflow_item.workflow.tasks) == task_count
        assert isinstance(emitted, ThreadItemUpdated)
        assert emitted == ThreadItemUpdated(item_id=workflow_item.id, update=update)

    # Workflow added
    opened = await anext(stream)
    assert isinstance(opened, ThreadItemAddedEvent)
    assert context.workflow_item is not None
    assert context.workflow_item.workflow.type == "reasoning"
    assert len(context.workflow_item.workflow.tasks) == 0
    assert opened == ThreadItemAddedEvent(item=context.workflow_item)

    # First thought added
    await expect_update(
        1, WorkflowTaskAdded(task=ThoughtTask(content="Think"), task_index=0)
    )

    # First thought delta
    await expect_update(
        1, WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0)
    )

    # First thought done
    await expect_update(
        1, WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0)
    )

    # Second thought added (not streamed)
    await expect_update(
        2, WorkflowTaskAdded(task=ThoughtTask(content="Thinking 2"), task_index=1)
    )

    # Drain whatever remains until the stream ends.
    async for _ in stream:
        pass
1171
1172
async def test_workflow_ends_on_message():
    """A reasoning workflow is closed when an assistant message item starts.

    The closing done event carries a ``DurationSummary`` and a collapsed
    workflow, and the context no longer holds a workflow item.
    """

    def raw(data) -> RawResponsesStreamEvent:
        return RawResponsesStreamEvent(type="raw_response_event", data=data)

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # first thought
    result.add_event(
        raw(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseReasoningItem(
                    id="resp_1",
                    summary=[],
                    type="reasoning",
                ),
                output_index=0,
                sequence_number=0,
            )
        )
    )
    result.add_event(
        raw(
            Mock(
                type="response.reasoning_summary_text.done",
                item_id="resp_1",
                summary_index=0,
                text="Thinking 1",
            )
        )
    )

    # not reasoning
    result.add_event(
        raw(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseOutputMessage(
                    id="m_1",
                    content=[],
                    role="assistant",
                    status="in_progress",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            )
        )
    )

    result.done()
    stream = stream_agent_response(context, result)

    # Workflow added
    opened = await anext(stream)
    assert isinstance(opened, ThreadItemAddedEvent)
    assert context.workflow_item is not None
    assert context.workflow_item.workflow.type == "reasoning"
    assert len(context.workflow_item.workflow.tasks) == 0
    assert opened == ThreadItemAddedEvent(item=context.workflow_item)

    # First thought done
    thought = await anext(stream)
    assert context.workflow_item is not None
    assert len(context.workflow_item.workflow.tasks) == 1
    assert isinstance(thought, ThreadItemUpdated)
    assert thought == ThreadItemUpdated(
        item_id=context.workflow_item.id,
        update=WorkflowTaskAdded(
            task=ThoughtTask(content="Thinking 1"),
            task_index=0,
        ),
    )

    # Workflow ended
    closed = await anext(stream)
    assert isinstance(closed, ThreadItemDoneEvent)
    assert closed.item.type == "workflow"
    assert context.workflow_item is None
    # Summary and expanded are handled by the end_workflow method
    assert isinstance(closed.item.workflow.summary, DurationSummary)
    assert closed.item.workflow.expanded is False

    # Drain whatever remains until the stream ends.
    async for _ in stream:
        pass
1264
1265
async def test_existing_workflow_summary_not_overwritten_on_automatic_end():
    """Automatically ending a workflow keeps a summary that was already set."""
    preset_summary = CustomSummary(title="Test")

    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    context.workflow_item = WorkflowItem(
        id="wf_1",
        created_at=datetime.now(),
        workflow=Workflow(type="custom", tasks=[], summary=preset_summary),
        thread_id=thread.id,
    )

    # An assistant message starting is what ends the open workflow.
    message_started = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )
    result = make_result()
    result.add_event(message_started)
    result.done()

    stream = stream_agent_response(context, result)
    closed = await anext(stream)

    assert isinstance(closed, ThreadItemDoneEvent)
    assert context.workflow_item is None
    assert closed.item.type == "workflow"
    assert closed.item.workflow.summary == CustomSummary(title="Test")

    # Drain whatever remains until the stream ends.
    async for _ in stream:
        pass
1311