openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.1.1

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_agents.py

1310lines · modeblame

f688d870victor-openai9 months ago1import asyncio
2import json
3from collections.abc import AsyncIterator
4from datetime import datetime
5from typing import cast
6from unittest.mock import AsyncMock, Mock
7
8import pytest
9from agents import (
10Agent,
11GuardrailFunctionOutput,
12InputGuardrail,
13InputGuardrailResult,
14InputGuardrailTripwireTriggered,
15OutputGuardrail,
16OutputGuardrailResult,
17OutputGuardrailTripwireTriggered,
18RawResponsesStreamEvent,
19RunContextWrapper,
20RunItemStreamEvent,
21Runner,
22RunResultStreaming,
23StreamEvent,
24ToolCallItem,
25)
26from agents._run_impl import QueueCompleteSentinel
27from openai.types.responses import (
28EasyInputMessageParam,
29ResponseFileSearchToolCall,
30ResponseInputContentParam,
31ResponseInputTextParam,
32ResponseOutputItemAddedEvent,
33ResponseOutputItemDoneEvent,
34ResponseOutputMessage,
35ResponseReasoningItem,
36)
37from openai.types.responses.response_content_part_added_event import (
38ResponseContentPartAddedEvent,
39)
40from openai.types.responses.response_file_search_tool_call import Result
41from openai.types.responses.response_output_text import (
42AnnotationFileCitation as ResponsesAnnotationFileCitation,
43)
44from openai.types.responses.response_output_text import (
45AnnotationFilePath as ResponsesAnnotationFilePath,
46)
47from openai.types.responses.response_output_text import (
48AnnotationURLCitation as ResponsesAnnotationURLCitation,
49)
50from openai.types.responses.response_output_text import (
51ResponseOutputText,
52)
53from openai.types.responses.response_text_delta_event import ResponseTextDeltaEvent
54from openai.types.responses.response_text_done_event import ResponseTextDoneEvent
55
56from chatkit.agents import (
57AgentContext,
58ThreadItemConverter,
59accumulate_text,
60simple_to_agent_input,
61stream_agent_response,
62)
63from chatkit.types import (
64Annotation,
65AssistantMessageContent,
66AssistantMessageContentPartAdded,
67AssistantMessageContentPartDone,
68AssistantMessageContentPartTextDelta,
69AssistantMessageItem,
70Attachment,
71ClientToolCallItem,
72CustomSummary,
73CustomTask,
74DurationSummary,
75FileSource,
76InferenceOptions,
77Page,
78TaskItem,
79ThoughtTask,
80Thread,
81ThreadItemAddedEvent,
82ThreadItemDoneEvent,
83ThreadItemUpdated,
84ThreadStreamEvent,
85URLSource,
86UserMessageItem,
87UserMessageTagContent,
88UserMessageTextContent,
89WidgetItem,
90Workflow,
91WorkflowItem,
92WorkflowTaskAdded,
93WorkflowTaskUpdated,
94)
95from chatkit.widgets import Card, Text
96
# Thread shared by every test in this module.
thread = Thread(
    id="123",
    title="Test",
    created_at=datetime.now(),
    items=Page(),
)


def _fake_item_id(item_type, thread, context):
    """Return a predictable id built only from the item type."""
    return f"{item_type}_id"


# Store double: ids are deterministic, loading returns an empty page, and
# adding an item is an awaitable no-op.
mock_store = Mock()
mock_store.generate_item_id = _fake_item_id
mock_store.load_thread_items = AsyncMock(return_value=Page())
mock_store.add_thread_item = AsyncMock()
103
104
class RunResult(RunResultStreaming):
    """RunResultStreaming subclass whose event queue is scripted by the tests."""

    def add_event(self, event: StreamEvent):
        """Put ``event`` on the internal event queue without waiting."""
        self._event_queue.put_nowait(event)

    def done(self):
        """Mark the run complete and enqueue the completion sentinel."""
        self.is_complete = True
        self._event_queue.put_nowait(QueueCompleteSentinel())

    def throw_input_guardrails(self):
        """Store an input-guardrail tripwire exception, then complete the run."""
        tripped = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = InputGuardrailResult(
            guardrail=Mock(spec=InputGuardrail),
            output=tripped,
        )
        self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
        # Same completion steps as done(): set the flag, then push the sentinel.
        self.done()

    def throw_output_guardrails(self):
        """Store an output-guardrail tripwire exception, then complete the run."""
        tripped = GuardrailFunctionOutput(
            output_info=None,
            tripwire_triggered=True,
        )
        guardrail_result = OutputGuardrailResult(
            guardrail=Mock(spec=OutputGuardrail),
            output=tripped,
            agent=Mock(spec=Agent),
            agent_output=None,
        )
        self._stored_exception = OutputGuardrailTripwireTriggered(guardrail_result)
        # Same completion steps as done(): set the flag, then push the sentinel.
        self.done()
140
141
def make_result() -> RunResult:
    """Build a fresh, not-yet-complete ``RunResult`` with empty state.

    Every call creates new queues, so results from different tests never
    share events.
    """
    fields = dict(
        # Inputs and accumulated outputs start empty.
        input=[],
        new_items=[],
        raw_responses=[],
        final_output=None,
        # Guardrail bookkeeping starts empty.
        input_guardrail_results=[],
        output_guardrail_results=[],
        tool_input_guardrail_results=[],
        tool_output_guardrail_results=[],
        # Agent / turn state.
        context_wrapper=Mock(spec=RunContextWrapper),
        current_agent=Agent(name="test"),
        current_turn=0,
        max_turns=10,
        trace=None,
        is_complete=False,
        # Private streaming plumbing; the tests drive _event_queue directly.
        _current_agent_output_schema=None,
        _event_queue=asyncio.Queue(),
        _input_guardrail_queue=asyncio.Queue(),
        _output_guardrails_task=None,
        _run_impl_task=None,
        _stored_exception=None,
    )
    return RunResult(**fields)
165
166
async def all_events(
    events: AsyncIterator[ThreadStreamEvent],
) -> list[ThreadStreamEvent]:
    """Exhaust ``events`` and return everything it yielded, in order."""
    collected: list[ThreadStreamEvent] = []
    async for item in events:
        collected.append(item)
    return collected
171
172
async def test_returns_widget_item():
    """A widget streamed as a single value surfaces as one done event."""

    def hello_card() -> Card:
        return Card(children=[Text(value="Hello, world!")])

    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)
    await ctx.stream_widget(hello_card())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 1
    only_event = emitted[0]
    assert isinstance(only_event, ThreadItemDoneEvent)
    assert isinstance(only_event.item, WidgetItem)
    assert only_event.item.widget == hello_card()
195
196
async def test_returns_widget_item_generator():
    """A widget generator yields added, streaming-text delta, then done events."""
    full_text = "Hello, world"

    def render_widget(i: int) -> Card:
        return Card(children=[Text(id="text", value=full_text[:i])])

    async def widget_frames():
        # First frame has empty text, second frame carries the whole string.
        for length in (0, 12):
            yield render_widget(length)

    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    await ctx.stream_widget(widget_frames())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3
    added, updated, finished = emitted

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.streaming_text.value_delta"
    assert updated.update.component_id == "text"
    assert updated.update.delta == "Hello, world"

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
238
239
async def test_returns_widget_full_replace_generator():
    """A structurally different second frame is sent as a root-widget update."""

    def first_card() -> Card:
        return Card(children=[Text(id="text", value="Hello!")])

    def replacement_card() -> Card:
        return Card(children=[Text(key="other text", value="World!", streaming=False)])

    async def widget_frames():
        yield first_card()
        yield replacement_card()

    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    tool_called = RunItemStreamEvent(name="tool_called", item=Mock(spec=ToolCallItem))
    run.add_event(tool_called)

    await ctx.stream_widget(widget_frames())
    run.done()

    emitted = await all_events(stream_agent_response(context=ctx, result=run))

    assert len(emitted) == 3
    added, updated, finished = emitted

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == first_card()

    assert isinstance(updated, ThreadItemUpdated)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == replacement_card()

    assert isinstance(finished, ThreadItemDoneEvent)
    assert isinstance(finished.item, WidgetItem)
    assert finished.item.widget == replacement_card()
279
280
async def test_accumulate_text():
    """accumulate_text folds text deltas into a growing ``Text`` widget.

    The expected sequence is the initial widget, one widget per delta with
    the concatenated text so far, and a final widget with ``streaming=False``.
    """

    def delta(text: str) -> RawResponsesStreamEvent:
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseTextDeltaEvent(
                type="response.output_text.delta",
                delta=text,
                content_index=0,
                item_id="123",
                logprobs=[],
                output_index=0,
                sequence_number=0,
            ),
        )

    # The stream is scripted entirely through make_result(); no real agent
    # run is started, so the test has no background task or network access.
    result = make_result()
    for chunk in ("Hello, ", "world!"):
        result.add_event(delta(chunk))
    result.done()

    events = [
        event
        async for event in accumulate_text(
            result.stream_events(), Text(key="text", value="", streaming=True)
        )
    ]
    assert events == [
        Text(key="text", value="", streaming=True),
        Text(key="text", value="Hello, ", streaming=True),
        Text(key="text", value="Hello, world!", streaming=True),
        Text(key="text", value="Hello, world!", streaming=False),
    ]
317
318
async def test_input_item_converter_quotes_last_user_message():
    """Only the last user message's quoted text is forwarded to the agent.

    Two quoted user messages go in; the output holds both message texts plus
    a single trailing message quoting the second item's ``quoted_text``.
    """

    def user_message(text: str, quoted: str) -> UserMessageItem:
        return UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text=quoted,
            created_at=datetime.now(),
        )

    def expected_user_input(text: str) -> dict:
        return {
            "content": [{"text": text, "type": "input_text"}],
            "role": "user",
            "type": "message",
        }

    items = [
        user_message("Hello!", "Hi!"),
        user_message("I'm well, thank you!", "How are you doing?"),
    ]

    input_items = await simple_to_agent_input(items)

    assert len(input_items) == 3
    assert input_items[0] == expected_user_input("Hello!")
    assert input_items[1] == expected_user_input("I'm well, thank you!")
    assert input_items[2] == expected_user_input(
        "The user is referring to this in particular: \nHow are you doing?"
    )
378
379
async def test_input_item_converter_to_input_items_mixed():
    """User, assistant and widget items each convert to their own input shape."""

    def user_message(text: str, quoted: str) -> UserMessageItem:
        return UserMessageItem(
            id="123",
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            quoted_text=quoted,
            created_at=datetime.now(),
        )

    def expected_user_input(text: str) -> dict:
        return {
            "content": [{"text": text, "type": "input_text"}],
            "role": "user",
            "type": "message",
        }

    def expected_output_text(text: str) -> dict:
        return {
            "annotations": [],
            "text": text,
            "logprobs": None,
            "type": "output_text",
        }

    assistant_message = AssistantMessageItem(
        id="123",
        content=[
            AssistantMessageContent(text="How are you doing?"),
            AssistantMessageContent(text="Can't do that"),
        ],
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    widget = WidgetItem(
        id="wd_123",
        widget=Card(children=[Text(value="Hello, world!")]),
        thread_id=thread.id,
        created_at=datetime.now(),
    )
    items = [
        user_message("Hello!", "Hi!"),
        user_message("I'm well, thank you!", "How are you doing?"),
        assistant_message,
        widget,
    ]

    input_items = await simple_to_agent_input(items)

    assert len(input_items) == 4
    assert input_items[0] == expected_user_input("Hello!")
    assert input_items[1] == expected_user_input("I'm well, thank you!")
    assert input_items[2] == {
        "content": [
            expected_output_text("How are you doing?"),
            expected_output_text("Can't do that"),
        ],
        "type": "message",
        "role": "assistant",
    }

    # The widget is described to the model as a user message.
    assert "type" in input_items[3]
    widget_item = cast(EasyInputMessageParam, input_items[3])
    assert widget_item.get("type") == "message"
    assert widget_item.get("role") == "user"
    text = widget_item.get("content")[0]["text"]  # type: ignore
    assert (
        "The following graphical UI widget (id: wd_123) was displayed to the user"
        in text
    )
    assert "Hello, world!" in text
    assert "created_at" not in text
468
469
async def test_input_item_converter_user_input_with_tags():
    """A converter overriding tag_to_message_content adds an @-mention context message."""

    class MyThreadItemConverter(ThreadItemConverter):
        async def tag_to_message_content(self, tag):
            combined = tag.text + " " + tag.data["key"]
            return ResponseInputTextParam(text=combined, type="input_text")

    tag = UserMessageTagContent(
        text="Hello!", type="input_tag", id="hello", data={"key": "value"}
    )
    thread_items = [
        UserMessageItem(
            id="123",
            content=[tag],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
    ]

    converted = await MyThreadItemConverter().to_agent_input(thread_items)

    mention_preamble = (
        "# User-provided context for @-mentions\n"
        "- When referencing resolved entities, use their canonical names **without** '@'.\n"
        "- The '@' form appears only in user text and should not be echoed."
    )

    assert len(converted) == 2
    # The tag appears inline in the user message with an '@' prefix.
    assert converted[0] == {
        "content": [{"text": "@Hello!", "type": "input_text"}],
        "role": "user",
        "type": "message",
    }
    # The second message holds the preamble and the converter's tag content.
    assert converted[1] == {
        "content": [
            {"text": mention_preamble, "type": "input_text"},
            {"text": "Hello! value", "type": "input_text"},
        ],
        "role": "user",
        "type": "message",
    }
519
520
async def test_input_item_converter_user_input_with_tags_throws_by_default():
    """simple_to_agent_input raises NotImplementedError for tag content."""
    tag = UserMessageTagContent(text="Hello!", type="input_tag", id="hello", data={})
    tagged_message = UserMessageItem(
        id="123",
        content=[tag],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        created_at=datetime.now(),
    )

    with pytest.raises(NotImplementedError):
        await simple_to_agent_input([tagged_message])
539
540
async def test_input_item_converter_with_client_tool_call():
    """Task and client tool call items convert to message / function_call inputs."""
    tool_args = {"foo": "bar"}
    tool_output = {"success": True}

    user_message = UserMessageItem(
        id="123",
        content=[UserMessageTextContent(text="Call a client tool call xyz")],
        attachments=[],
        inference_options=InferenceOptions(),
        thread_id=thread.id,
        quoted_text="Hi!",
        created_at=datetime.now(),
    )
    task = TaskItem(
        id="tsk_123",
        created_at=datetime.now(),
        task=CustomTask(title="Called xyx"),
        thread_id=thread.id,
    )
    pending_call = ClientToolCallItem(
        id="ctc_123",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
    )
    completed_call = ClientToolCallItem(
        id="ctc_123_done",
        thread_id=thread.id,
        created_at=datetime.now(),
        name="xyz",
        arguments={"foo": "bar"},
        call_id="ctc_123",
        status="completed",
        output={"success": True},
    )

    input_items = await simple_to_agent_input(
        [user_message, task, pending_call, completed_call]
    )

    assert len(input_items) == 4
    user_input, task_input, call_input, output_input = input_items

    assert user_input == {
        "content": [
            {"text": "Call a client tool call xyz", "type": "input_text"},
        ],
        "role": "user",
        "type": "message",
    }
    assert task_input == {
        "content": [
            {
                "text": "A message was displayed to the user that the following task was performed:\n<Task>\nCalled xyx\n</Task>",
                "type": "input_text",
            },
        ],
        "type": "message",
        "role": "user",
    }
    assert call_input == {
        "type": "function_call",
        "name": "xyz",
        "arguments": json.dumps(tool_args),
        "call_id": "ctc_123",
    }
    assert output_input == {
        "type": "function_call_output",
        "call_id": "ctc_123",
        "output": json.dumps(tool_output),
    }
611
612
async def test_stream_agent_response_yields_context_events_without_streaming_events():
    """Events pushed through the context are yielded even with no run events.

    After the context event is consumed the stream stays pending until the
    run result is marked done, at which point iteration stops.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    widget_added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )
    await context.stream(widget_added)

    response_streamer = stream_agent_response(context, result)
    first = await response_streamer.__anext__()
    assert first.type == "thread.item.added"

    # Nothing else is queued, so the next item cannot be ready yet.
    pending = asyncio.ensure_future(response_streamer.__anext__())
    assert pending.done() is False

    result.done()

    # Same idiom the rest of this module uses for expected exceptions.
    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done() is True
647
648
async def test_stream_agent_response_maps_events():
    """Context events and mapped run events are both yielded by the stream.

    The order of the two events is not asserted, only that one is an
    ``added`` event and the other an ``updated`` event.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    widget_added = ThreadItemAddedEvent(
        item=WidgetItem(
            id="123",
            created_at=datetime.now(),
            thread_id=thread.id,
            widget=Card(children=[Text(id="text", value="Hello, world!")]),
        ),
    )
    await context.stream(widget_added)

    text_delta = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseTextDeltaEvent(
            type="response.output_text.delta",
            delta="Hello, world!",
            content_index=0,
            item_id="123",
            logprobs=[],
            output_index=0,
            sequence_number=0,
        ),
    )
    result.add_event(text_delta)

    response_streamer = stream_agent_response(context, result)
    event1 = await response_streamer.__anext__()
    event2 = await response_streamer.__anext__()

    assert {event1.type, event2.type} == {
        "thread.item.added",
        "thread.item.updated",
    }

    # Both sources are drained, so the next item stays pending until done().
    pending = asyncio.ensure_future(response_streamer.__anext__())
    assert pending.done() is False

    result.done()

    # Same idiom the rest of this module uses for expected exceptions.
    with pytest.raises(StopAsyncIteration):
        await pending

    assert pending.done() is True
701
702
@pytest.mark.parametrize(
    "raw_event,expected_event",
    [
        # Text delta -> content-part text delta update.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDeltaEvent(
                    type="response.output_text.delta",
                    delta="Hello, world!",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=0,
                ),
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartTextDelta(
                    content_index=0,
                    delta="Hello, world!",
                ),
            ),
        ),
        # Content part added -> content-part added update.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseContentPartAddedEvent(
                    type="response.content_part.added",
                    part=ResponseOutputText(
                        type="output_text",
                        text="New content",
                        annotations=[],
                    ),
                    content_index=1,
                    item_id="123",
                    output_index=0,
                    sequence_number=1,
                ),
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartAdded(
                    content_index=1,
                    content=AssistantMessageContent(text="New content", annotations=[]),
                ),
            ),
        ),
        # Text done -> content-part done update.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=ResponseTextDoneEvent(
                    type="response.output_text.done",
                    text="Final text",
                    content_index=0,
                    item_id="123",
                    logprobs=[],
                    output_index=0,
                    sequence_number=2,
                ),
            ),
            ThreadItemUpdated(
                item_id="123",
                update=AssistantMessageContentPartDone(
                    content_index=0,
                    content=AssistantMessageContent(
                        text="Final text",
                        annotations=[],
                    ),
                ),
            ),
        ),
        # Annotation added -> expected to produce no thread event.
        (
            RawResponsesStreamEvent(
                type="raw_response_event",
                data=Mock(
                    type="response.output_text.annotation.added",
                    annotation=ResponsesAnnotationFileCitation(
                        type="file_citation",
                        file_id="file_123",
                        filename="file.txt",
                        index=5,
                    ),
                    content_index=0,
                    item_id="123",
                    annotation_index=0,
                    output_index=0,
                    sequence_number=3,
                ),
            ),
            None,
        ),
    ],
)
async def test_event_mapping(raw_event, expected_event):
    """Each raw response event maps to the expected thread event, or to none."""
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()
    run.add_event(raw_event)
    run.done()

    emitted = await all_events(stream_agent_response(ctx, run))

    # A falsy expectation means the raw event should be dropped entirely.
    wanted = [expected_event] if expected_event else []
    assert emitted == wanted
811
812
@pytest.mark.parametrize("throw_guardrail", ["input", "output"])
async def test_stream_agent_response_yields_item_removed_event(throw_guardrail):
    """Items emitted before a guardrail trips are removed when it fires.

    Three items (ids "1", "2", "3") are produced, then an input or output
    guardrail exception is stored on the result. The stream must re-raise
    the guardrail exception after yielding a removal event for each item.
    """
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    result = make_result()

    # Item "1": an assistant message coming from the run's raw events.
    result.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseOutputMessage(
                    id="1",
                    content=[
                        ResponseOutputText(
                            annotations=[], type="output_text", text="Hello, world!"
                        )
                    ],
                    role="assistant",
                    status="completed",
                    type="message",
                ),
                output_index=0,
                sequence_number=0,
            ),
        )
    )
    # Item "2": an assistant message streamed through the context.
    await context.stream(
        ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="2",
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
                created_at=datetime.now(),
            ),
        )
    )
    # Item "3": a finished widget streamed through the context.
    await context.stream(
        ThreadItemDoneEvent(
            item=WidgetItem(
                id="3",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(id="text", value="Hello, world!")]),
            ),
        )
    )

    iterator = stream_agent_response(context, result)

    # Grab the first 3 events before tripping the guardrail.
    events = []
    async for event in iterator:
        events.append(event)
        if len(events) == 3:
            break

    if throw_guardrail == "input":
        result.throw_input_guardrails()
    else:
        result.throw_output_guardrails()

    # pytest.raises fails with "DID NOT RAISE" if the stream ends cleanly and
    # lets any other exception propagate with its original traceback.
    with pytest.raises(
        (InputGuardrailTripwireTriggered, OutputGuardrailTripwireTriggered)
    ):
        async for event in iterator:
            events.append(event)

    deleted_item_ids = {
        event.item_id for event in events if event.type == "thread.item.removed"
    }
    assert deleted_item_ids == {"1", "2", "3"}
891
892
async def test_stream_agent_response_assistant_message_content_types():
    """A completed output message becomes one assistant message with annotations.

    The run emits a completed file-search call followed by a completed
    assistant message whose first part carries file-citation, URL-citation
    and file-path annotations. Only one event is expected: the done event
    for the assistant message, with file and URL annotations converted.
    """
    result = make_result()

    file_search_done = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            item=ResponseFileSearchToolCall(
                id="fs_0",
                queries=["Hello, world!"],
                status="completed",
                type="file_search_call",
                results=[
                    Result(
                        file_id="f_123",
                        filename="test.txt",
                        text="Hello, world!",
                        score=1.0,
                    ),
                    Result(
                        file_id="f_123",
                        filename="test.txt",
                        text="Hello, friends!",
                        score=0.5,
                    ),
                ],
            ),
            output_index=0,
            sequence_number=0,
        ),
    )
    message_done = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemDoneEvent(
            type="response.output_item.done",
            item=ResponseOutputMessage(
                id="1",
                content=[
                    ResponseOutputText(
                        annotations=[
                            ResponsesAnnotationFileCitation(
                                type="file_citation",
                                file_id="f_123",
                                index=0,
                                filename="test.txt",
                            ),
                            ResponsesAnnotationURLCitation(
                                type="url_citation",
                                url="https://www.google.com",
                                title="Google",
                                start_index=0,
                                end_index=10,
                            ),
                            ResponsesAnnotationFilePath(
                                type="file_path",
                                file_id="123",
                                index=0,
                            ),
                        ],
                        text="Hello, world!",
                        type="output_text",
                    ),
                    ResponseOutputText(
                        annotations=[],
                        text="Can't do that",
                        type="output_text",
                    ),
                ],
                role="assistant",
                status="completed",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )
    result.add_event(file_search_done)
    result.add_event(message_done)
    result.done()

    # A single context is enough; it is only needed once the stream is read.
    context = AgentContext(
        previous_response_id=None, thread=thread, store=mock_store, request_context=None
    )
    events = await all_events(stream_agent_response(context, result))

    assert len(events) == 1
    assert isinstance(events[0], ThreadItemDoneEvent)
    message = events[0].item
    assert isinstance(message, AssistantMessageItem)
    # The file_path annotation has no counterpart in the expected content.
    assert message.content == [
        AssistantMessageContent(
            annotations=[
                Annotation(
                    source=FileSource(
                        filename="test.txt",
                        title="test.txt",
                    ),
                    index=0,
                ),
                Annotation(
                    source=URLSource(
                        url="https://www.google.com",
                        title="Google",
                    ),
                    index=10,
                ),
            ],
            text="Hello, world!",
        ),
        AssistantMessageContent(text="Can't do that", annotations=[]),
    ]
    assert message.id == "1"
1010
1011
async def test_workflow_streams_first_thought():
    """The first reasoning summary streams delta by delta; the second does not.

    A reasoning item opens a workflow. The first summary produces a
    task-added event followed by task-updated events; the second summary
    produces a single task-added event carrying the full text.
    """
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()

    def raw(data) -> RawResponsesStreamEvent:
        return RawResponsesStreamEvent(type="raw_response_event", data=data)

    def summary_delta(summary_index: int, delta: str) -> RawResponsesStreamEvent:
        return raw(
            Mock(
                type="response.reasoning_summary_text.delta",
                item_id="resp_1",
                summary_index=summary_index,
                delta=delta,
            )
        )

    def summary_done(summary_index: int, text: str) -> RawResponsesStreamEvent:
        return raw(
            Mock(
                type="response.reasoning_summary_text.done",
                item_id="resp_1",
                summary_index=summary_index,
                text=text,
            )
        )

    scripted = [
        raw(
            ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=ResponseReasoningItem(
                    id="resp_1",
                    summary=[],
                    type="reasoning",
                ),
                output_index=0,
                sequence_number=0,
            )
        ),
        # first thought
        summary_delta(0, "Think"),
        summary_delta(0, "ing 1"),
        summary_done(0, "Thinking 1"),
        # second thought
        summary_delta(1, "Think"),
        summary_delta(1, "ing 2"),
        summary_done(1, "Thinking 2"),
    ]
    for scripted_event in scripted:
        run.add_event(scripted_event)

    run.done()
    stream = stream_agent_response(ctx, run)

    async def expect_update(expected_update, task_count: int) -> None:
        """Pull one event and check it is the given update on the workflow."""
        event = await anext(stream)
        assert ctx.workflow_item is not None
        assert len(ctx.workflow_item.workflow.tasks) == task_count
        assert isinstance(event, ThreadItemUpdated)
        assert event == ThreadItemUpdated(
            item_id=ctx.workflow_item.id,
            update=expected_update,
        )

    # Workflow added
    opened = await anext(stream)
    assert isinstance(opened, ThreadItemAddedEvent)
    assert ctx.workflow_item is not None
    assert ctx.workflow_item.workflow.type == "reasoning"
    assert len(ctx.workflow_item.workflow.tasks) == 0
    assert opened == ThreadItemAddedEvent(item=ctx.workflow_item)

    # First thought added
    await expect_update(
        WorkflowTaskAdded(task=ThoughtTask(content="Think"), task_index=0),
        task_count=1,
    )

    # First thought delta
    await expect_update(
        WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0),
        task_count=1,
    )

    # First thought done
    await expect_update(
        WorkflowTaskUpdated(task=ThoughtTask(content="Thinking 1"), task_index=0),
        task_count=1,
    )

    # Second thought added (not streamed)
    await expect_update(
        WorkflowTaskAdded(task=ThoughtTask(content="Thinking 2"), task_index=1),
        task_count=2,
    )

    # Drain whatever is left so the generator runs to completion.
    async for _ in stream:
        pass
1171
1172
async def test_workflow_ends_on_message():
    """A reasoning workflow is closed when a non-reasoning output item arrives."""
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()

    def output_item_added(item) -> RawResponsesStreamEvent:
        return RawResponsesStreamEvent(
            type="raw_response_event",
            data=ResponseOutputItemAddedEvent(
                type="response.output_item.added",
                item=item,
                output_index=0,
                sequence_number=0,
            ),
        )

    # first thought
    run.add_event(
        output_item_added(
            ResponseReasoningItem(
                id="resp_1",
                summary=[],
                type="reasoning",
            )
        )
    )
    run.add_event(
        RawResponsesStreamEvent(
            type="raw_response_event",
            data=Mock(
                type="response.reasoning_summary_text.done",
                item_id="resp_1",
                summary_index=0,
                text="Thinking 1",
            ),
        )
    )

    # not reasoning
    run.add_event(
        output_item_added(
            ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            )
        )
    )

    run.done()
    stream = stream_agent_response(ctx, run)

    # Workflow added
    opened = await anext(stream)
    assert isinstance(opened, ThreadItemAddedEvent)
    assert ctx.workflow_item is not None
    assert ctx.workflow_item.workflow.type == "reasoning"
    assert len(ctx.workflow_item.workflow.tasks) == 0
    assert opened == ThreadItemAddedEvent(item=ctx.workflow_item)

    # First thought done
    thought = await anext(stream)
    assert ctx.workflow_item is not None
    assert len(ctx.workflow_item.workflow.tasks) == 1
    assert isinstance(thought, ThreadItemUpdated)
    assert thought == ThreadItemUpdated(
        item_id=ctx.workflow_item.id,
        update=WorkflowTaskAdded(
            task=ThoughtTask(content="Thinking 1"),
            task_index=0,
        ),
    )

    # Workflow ended
    closed = await anext(stream)
    assert isinstance(closed, ThreadItemDoneEvent)
    assert closed.item.type == "workflow"
    assert ctx.workflow_item is None
    # Summary and expanded are handled by the end_workflow method
    assert isinstance(closed.item.workflow.summary, DurationSummary)
    assert closed.item.workflow.expanded is False

    # Drain whatever is left so the generator runs to completion.
    async for _ in stream:
        pass
1264
1265
async def test_existing_workflow_summary_not_overwritten_on_automatic_end():
    """A workflow that already has a summary keeps it when closed automatically."""
    ctx = AgentContext(
        previous_response_id=None,
        thread=thread,
        store=mock_store,
        request_context=None,
    )
    run = make_result()

    # Pre-existing custom workflow with its own summary.
    preset_workflow = Workflow(
        type="custom",
        tasks=[],
        summary=CustomSummary(title="Test"),
    )
    ctx.workflow_item = WorkflowItem(
        id="wf_1",
        created_at=datetime.now(),
        workflow=preset_workflow,
        thread_id=thread.id,
    )

    # An assistant message starting is what triggers the workflow to end.
    message_started = RawResponsesStreamEvent(
        type="raw_response_event",
        data=ResponseOutputItemAddedEvent(
            type="response.output_item.added",
            item=ResponseOutputMessage(
                id="m_1",
                content=[],
                role="assistant",
                status="in_progress",
                type="message",
            ),
            output_index=0,
            sequence_number=0,
        ),
    )
    run.add_event(message_started)

    run.done()
    stream = stream_agent_response(ctx, run)

    closed = await anext(stream)

    assert isinstance(closed, ThreadItemDoneEvent)
    assert ctx.workflow_item is None
    assert closed.item.type == "workflow"
    assert closed.item.workflow.summary == CustomSummary(title="Test")

    # Drain whatever is left so the generator runs to completion.
    async for _ in stream:
        pass