openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.1.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1797 lines · mode: blame

f688d870victor-openai8 months ago1import sqlite3
2from contextlib import contextmanager
3from datetime import datetime
4from typing import Any, AsyncIterator, Callable, cast
5
6import pytest
7from helpers.mock_store import SQLiteStore
8from pydantic import AnyUrl, TypeAdapter
9
10from chatkit.actions import Action
11from chatkit.errors import ErrorCode
12from chatkit.server import (
13ChatKitServer,
14NonStreamingResult,
15StreamingResult,
16stream_widget,
17)
18from chatkit.store import AttachmentStore, NotFoundError
19from chatkit.types import (
20AssistantMessageContent,
21AssistantMessageItem,
22Attachment,
23AttachmentCreateParams,
24AttachmentDeleteParams,
25AttachmentsCreateReq,
26AttachmentsDeleteReq,
27ClientToolCallItem,
28FeedbackKind,
29FileAttachment,
30ImageAttachment,
31InferenceOptions,
32ItemFeedbackParams,
33ItemsFeedbackReq,
34ItemsListParams,
35ItemsListReq,
36LockedStatus,
37Page,
38ProgressUpdateEvent,
39Thread,
40ThreadAddClientToolOutputParams,
41ThreadAddUserMessageParams,
42ThreadCreatedEvent,
43ThreadCreateParams,
44ThreadCustomActionParams,
45ThreadDeleteParams,
46ThreadGetByIdParams,
47ThreadItem,
48ThreadItemAddedEvent,
49ThreadItemDoneEvent,
50ThreadItemRemovedEvent,
51ThreadItemReplacedEvent,
52ThreadItemUpdated,
53ThreadListParams,
54ThreadMetadata,
55ThreadRetryAfterItemParams,
56ThreadsAddClientToolOutputReq,
57ThreadsAddUserMessageReq,
58ThreadsCreateReq,
59ThreadsCustomActionReq,
60ThreadsDeleteReq,
61ThreadsGetByIdReq,
62ThreadsListReq,
63ThreadsRetryAfterItemReq,
64ThreadStreamEvent,
65ThreadsUpdateReq,
66ThreadUpdatedEvent,
67ThreadUpdateParams,
68ToolChoice,
69UserMessageInput,
70UserMessageItem,
71UserMessageTextContent,
72WidgetItem,
73WidgetRootUpdated,
74)
75from chatkit.widgets import Card, Text
76from tests._types import RequestContext
77from tests.test_store import make_thread_items
78
# Context passed to `process` by the test server helpers whenever a test does
# not supply its own.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")

# Counter embedded in each `make_server()` database name so every server gets
# its own shared in-memory SQLite database.
server_id = 0
83
84
class InMemoryFileStore(AttachmentStore):
    """Dictionary-backed ``AttachmentStore`` used by the attachment tests.

    Created attachments are kept in ``self.files`` keyed by attachment id; tests
    may also seed or inspect that dict directly.
    """

    def __init__(self):
        self.files = {}
        # Monotonic id counter. Deriving ids from ``len(self.files)`` would
        # reissue a live id after a delete (create atc_1 and atc_2, delete
        # atc_1, and the next create yields atc_2 again), silently overwriting
        # the existing attachment.
        self._last_id = 0

    def _new_attachment_id(self) -> str:
        """Return a fresh ``atc_<n>`` id that is not already in ``self.files``."""
        while True:
            self._last_id += 1
            candidate = f"atc_{self._last_id}"
            # Skip ids a test may have seeded into ``self.files`` by hand.
            if candidate not in self.files:
                return candidate

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove ``attachment_id``; raises ``KeyError`` if it is not stored."""
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return a new attachment with example.com URLs.

        MIME types starting with ``image/`` produce an ``ImageAttachment`` that
        also carries a preview URL; anything else produces a ``FileAttachment``.
        """
        attachment_id = self._new_attachment_id()
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_url=AnyUrl(f"https://example.com/{attachment_id}/upload"),
            )
        self.files[attachment.id] = attachment
        return attachment
114
115
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one encoded stream frame (``b"data: <json>..."``) into an event.

    Everything after the first ``data: `` marker is validated as a
    ``ThreadStreamEvent``.

    Raises:
        IndexError: if the frame contains no ``data: `` marker.
    """
    # maxsplit=1: the JSON payload itself may contain the bytes "data: " (e.g.
    # inside message text); an unbounded split would cut the payload there.
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
118
119
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain ``streaming_result.json_events`` and decode every frame, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for raw_frame in streaming_result.json_events:
        decoded.append(decode_event(raw_frame))
    return decoded
124
125
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a test ``ChatKitServer`` backed by its own in-memory SQLite DB.

    Args:
        responder: Implementation of ``respond``. When ``None`` the server
            responds with an empty event stream.
        handle_feedback: Called from ``add_feedback``. When ``None`` feedback
            is ignored.
        action_callback: Implementation of ``action``. When ``None`` any custom
            action raises ``ValueError``.
        file_store: Attachment store handed to ``ChatKitServer``.

    Yields:
        A server exposing ``process_streaming`` / ``process_non_streaming``
        helpers that serialize a request model, call ``process`` (with
        ``DEFAULT_CONTEXT`` when no context is given) and assert the result kind.
    """
    global server_id
    # A unique database name per call keeps tests from sharing state.
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # Keep the shared in-memory database from being deleted when the last connection is closed
    db = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is None:
                raise ValueError("action_callback not wired up")
            return action_callback(thread, action, sender, context)

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is None:
                return self._empty_responder()
            return responder(thread, input_user_message, context)

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` makes this an async generator that
            # finishes without producing any events.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is None:
                return
            handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            """Process ``request_obj`` and return its decoded stream events."""
            # `is None` rather than `or`: an explicitly passed context must be
            # forwarded even if it happens to be falsy.
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            """Process ``request_obj`` and return the ``NonStreamingResult``."""
            result = await self.process(
                request_obj.model_dump_json(),
                DEFAULT_CONTEXT if context is None else context,
            )
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        db.close()
206
207
async def test_flows_context_to_responder():
    """The context given to ``process`` reaches both respond() and add_feedback()."""
    seen: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen["responder"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def record_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen["feedback"] = context

    ctx = RequestContext(user_id="text_ctx")
    with make_server(responder, record_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input)),
            ctx,
        )
        assert seen.get("responder") == ctx

        created_thread = next(e.thread for e in events if e.type == "thread.created")
        reply = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "assistant_message"
        )
        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=created_thread.id,
                item_ids=[reply.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, ctx)
        assert seen.get("feedback") == ctx
267
268
async def test_creates_thread():
    """threads.create emits an untitled thread.created plus the done user message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emit nothing; only the server's own events are inspected.
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
            quoted_text="Quoted text",
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(e for e in events if e.type == "thread.created")
        assert created.thread.title is None

        user_item = next(e.item for e in events if e.type == "thread.item.done")
        assert user_item.type == "user_message"
        assert user_item.content[0].text == "Hello, world!"
        assert user_item.quoted_text == "Quoted text"
302
303
async def test_saves_thread_title():
    """A title set by the responder is persisted and reported via thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        created = next(e.thread for e in events if e.type == "thread.created")

        stored = await server.store.load_thread(
            created.id, RequestContext(user_id="test_user")
        )
        assert stored.title == "Updated title"

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.title == "Updated title"
334
335
async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted; thread.updated is last."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        created = next(e.thread for e in events if e.type == "thread.created")

        stored = await server.store.load_thread(created.id, DEFAULT_CONTEXT)
        assert stored.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
363
364
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and sent in thread.updated."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    expected_status = LockedStatus(reason="Because")

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        created = next(e.thread for e in events if e.type == "thread.created")

        stored = await server.store.load_thread(
            created.id, RequestContext(user_id="test_user")
        )
        assert stored.status == expected_status

        final_event = events[-1]
        assert final_event.type == "thread.updated"
        assert final_event.thread.status == expected_status
394
395
async def test_emits_thread_updated_mid_stream_and_persists():
    """A mid-stream thread change emits thread.updated after the next done item.

    Streaming continues after that event, and the change is persisted.
    """

    def assistant_done(thread_id: str, item_id: str, text: str) -> ThreadItemDoneEvent:
        """Build a done event for a single-part assistant message."""
        return ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread_id,
            ),
        )

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # First assistant message
        yield assistant_done(thread.id, "assistant_a", "A")
        # Update the thread mid-stream
        thread.title = "Mid-stream title"
        # Second assistant message, after which thread.updated should be emitted
        yield assistant_done(thread.id, "assistant_b", "B")
        # Continue streaming after the update event
        yield assistant_done(thread.id, "assistant_c", "C")

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        b_index = next(
            i
            for i, e in enumerate(events)
            if isinstance(e, ThreadItemDoneEvent)
            and isinstance(e.item, AssistantMessageItem)
            and e.item.id == "assistant_b"
        )

        # thread.updated immediately follows assistant_b ...
        assert isinstance(events[b_index + 1], ThreadUpdatedEvent)
        update = cast(ThreadUpdatedEvent, events[b_index + 1])
        assert update.thread.title == "Mid-stream title"

        # ... and streaming carries on afterwards.
        assert isinstance(events[b_index + 2], ThreadItemDoneEvent)
        following = cast(ThreadItemDoneEvent, events[b_index + 2])
        assert isinstance(following.item, AssistantMessageItem)
        assert following.item.id == "assistant_c"

        # Persisted in the store
        persisted = await server.store.load_thread(created.id, DEFAULT_CONTEXT)
        assert persisted.title == "Mid-stream title"
470
471
async def test_respond_with_tool_call():
    """A client tool call ends the first turn; adding its output triggers a reply."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # Invoked again after the client tool output was added.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[
                        AssistantMessageContent(text="Glad the tool call succeeded!")
                    ],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(input, UserMessageItem):
            # First turn: ask the client to run a tool.
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert len(events) == 3
        created, user_done, tool_done = events

        assert created.type == "thread.created"
        thread = created.thread

        assert user_done.type == "thread.item.done"
        assert user_done.item.type == "user_message"

        assert tool_done.type == "thread.item.done"
        tool_call = tool_done.item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        follow_up = await server.process_streaming(
            ThreadsAddClientToolOutputReq(
                params=ThreadAddClientToolOutputParams(
                    thread_id=thread.id,
                    result={"text": "Wow!"},
                )
            )
        )

        assert len(follow_up) == 1
        reply = follow_up[0]
        assert reply.type == "thread.item.done"
        assert reply.item.type == "assistant_message"
538
539
async def test_removes_tool_call_if_no_output_provided():
    """A pending client tool call is dropped when a new user message arrives instead."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert isinstance(input, UserMessageItem)
        first_part = input.content[0]
        assert first_part.type == "input_text"
        if first_part.text != "Message 1":
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[AssistantMessageContent(text="All done!")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        else:
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    def message(text: str) -> UserMessageInput:
        """Plain-text user input with no attachments."""
        return UserMessageInput(
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=message("Message 1")))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=message("Message 2"),
                )
            )
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        # Newest first; the unanswered client tool call is gone.
        assert [item.type for item in page.data] == [
            "assistant_message",
            "user_message",
            "user_message",
        ]
604
605
async def test_respond_with_tool_status():
    """A progress update from the responder is streamed in order with the items."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_4",
                content=[AssistantMessageContent(text="Hello, world!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert [event.type for event in events] == [
            "thread.created",
            "thread.item.done",
            "progress_update",
            "thread.item.done",
        ]
638
639
async def test_list_threads_response():
    """threads.list pages 25 threads newest-first: 20, then 5 via the cursor."""
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for n in range(25):
            user_input = UserMessageInput(
                content=[UserMessageTextContent(text=f"Thread {n}")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
            await server.process_streaming(
                ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
            )

        first_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20

        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
672
673
async def test_list_items_response():
    """items.list pages a thread's items newest-first: 20, then the remaining 6."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # 25 items on top of the user message that created the thread.
        for n in range(25):
            yield ThreadItemDoneEvent(
                item=UserMessageItem(
                    id=f"msg_{n}",
                    content=[UserMessageTextContent(text=f"Message {n}")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    thread_id=thread.id,
                    created_at=datetime.now(),
                ),
            )

    items_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Thread 1")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        first_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        first_page = items_adapter.validate_json(first_result.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        rest_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(
                    thread_id=thread.id, after=first_page.data[-1].id
                )
            )
        )
        rest_page = items_adapter.validate_json(rest_result.json)
        assert len(rest_page.data) == 6  # 5 streamed items + the user message
        assert rest_page.has_more is False
723
724
async def test_calls_action():
    """A custom action on a widget reaches action() with its sender and streams updates."""
    received = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(key="text_1", value="test")]),
            ),
        )

    async def on_action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received.append((action, sender))
        assert sender
        replacement = Card(children=[Text(value="Email sent!")])
        yield ThreadItemUpdated(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=replacement),
        )

    with make_server(responder, action_callback=on_action) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")
        widget_item = next(
            e.item
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WidgetItem)
        )

        events = await server.process_streaming(
            ThreadsCustomActionReq(
                params=ThreadCustomActionParams(
                    thread_id=thread.id,
                    item_id=widget_item.id,
                    action=Action(type="create_user", payload={"user_id": "123"}),
                )
            )
        )

        assert received
        expected_call = (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )
        assert received[0] == expected_call

        assert len(events) == 1
        update_event = events[0]
        assert update_event.type == "thread.item.updated"
        assert isinstance(update_event, ThreadItemUpdated)
        assert update_event.update.type == "widget.root.updated"
        assert update_event.update.widget == Card(children=[Text(value="Email sent!")])
813
814
async def test_add_feedback():
    """items.feedback forwards thread id, item ids and kind to add_feedback()."""
    called = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        called.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1",
                content=[AssistantMessageContent(text="Feedback received!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                widget=Card(children=[Text(key="text", value="Widget content")]),
                thread_id=thread.id,
            ),
        )

    with make_server(responder, handle_feedback=handle_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test feedback")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        submissions: list[tuple[str, FeedbackKind]] = [
            ("assistant_msg_1", "positive"),
            ("widget_1", "negative"),
        ]
        for item_id, kind in submissions:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=thread.id,
                        item_ids=[item_id],
                        kind=kind,
                    )
                )
            )

        # Compare (thread_id, item_ids, feedback) of every recorded call.
        assert [call[:3] for call in called] == [
            (thread.id, [item_id], kind) for item_id, kind in submissions
        ]
889
890
async def test_get_thread_by_id():
    """threads.get_by_id returns the stored thread including its items."""
    # NOTE(review): the return value is discarded, so this call looks like dead
    # code; it is kept so this test's behavior is unchanged.
    make_thread_items()

    with make_server() as server:
        # Create a thread by sending a ThreadCreateRequest
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test thread")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Retrieve the thread by id
        fetched = await server.process_non_streaming(
            ThreadsGetByIdReq(params=ThreadGetByIdParams(thread_id=thread.id))
        )
        loaded = TypeAdapter(Thread).validate_json(fetched.json)
        assert loaded.id == thread.id
        assert loaded.title == thread.title

        first_item = loaded.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
922
923
async def test_create_file():
    """attachments.create with a non-image MIME type stores a "file" attachment."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "text/plain"

    with make_server(file_store=store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=file_name,
                size=1024,
                mime_type=file_content_type,
            )
        )
        response = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        # Verify response includes file
        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "file"
        expected_upload = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert attachment.upload_url == expected_upload
        assert attachment.id in store.files
948
949
async def test_create_image_file():
    """attachments.create with an image MIME type stores an "image" attachment."""
    store = InMemoryFileStore()
    file_name = "test-file-name"
    file_content_type = "image/png"

    with make_server(file_store=store) as server:
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name=file_name,
                size=1024,
                mime_type=file_content_type,
            )
        )
        response = await server.process_non_streaming(create_request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert attachment.id is not None
        assert attachment.mime_type == file_content_type
        assert attachment.name == file_name
        assert attachment.type == "image"

        url_base = f"https://example.com/{attachment.id}"
        assert attachment.preview_url == AnyUrl(f"{url_base}/preview")
        assert attachment.upload_url == AnyUrl(f"{url_base}/upload")

        assert attachment.id in store.files
979
980
async def test_create_file_without_filestore():
    """Creating an attachment on a server without an attachment store raises RuntimeError."""
    with pytest.raises(RuntimeError):
        create_request = AttachmentsCreateReq(
            params=AttachmentCreateParams(
                name="test-file-name",
                size=1024,
                mime_type="text/plain",
            )
        )
        with make_server() as server:
            await server.process_non_streaming(create_request)
993
994
async def test_delete_file():
    """attachments.delete removes the attachment from the configured store."""
    store = InMemoryFileStore()
    # Seed the store directly rather than going through attachments.create.
    store.files["test-file-id"] = {"id": "test-file-id"}

    delete_request = AttachmentsDeleteReq(
        params=AttachmentDeleteParams(attachment_id="test-file-id")
    )
    with make_server(file_store=store) as server:
        await server.process_non_streaming(delete_request)
        assert store.files == {}
1005
1006
async def test_delete_file_without_filestore():
    """Deleting an attachment on a server without an attachment store raises RuntimeError."""
    with pytest.raises(RuntimeError):
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id="test-file-id")
        )
        with make_server() as server:
            await server.process_non_streaming(delete_request)
1015
1016
async def test_list_threads_paginated_response():
    """threads.list with limit=20 over 25 threads: a full page, then 5 via ``after``."""
    with make_server() as server:
        for i in range(25):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {i}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        async def list_threads(params: ThreadListParams) -> Page[ThreadMetadata]:
            """Run a threads.list request and parse the returned page."""
            result = await server.process_non_streaming(ThreadsListReq(params=params))
            return TypeAdapter(Page[ThreadMetadata]).validate_json(result.json)

        page = await list_threads(ThreadListParams(limit=20))
        assert len(page.data) == 20
        assert page.has_more is True
        assert page.after is not None

        page = await list_threads(ThreadListParams(after=page.data[-1].id))
        assert len(page.data) == 5
        assert page.has_more is False
1045
1046
async def test_threads_update_request():
    """threads.update renames a thread and returns the updated thread."""
    with make_server() as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        response = await server.process_non_streaming(
            ThreadsUpdateReq(
                params=ThreadUpdateParams(thread_id=thread.id, title="New Title")
            )
        )
        renamed = TypeAdapter[Thread](Thread).validate_json(response.json)
        assert renamed.title == "New Title"
1070
1071
async def test_returns_widget_item_generator():
    """stream_widget turns growing text into value deltas between added/done events."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def render_widget(length: int) -> Card:
        """Card whose single text shows the first ``length`` chars of the message."""
        return Card(children=[Text(id="text", value="Hello, world"[:length])])

    async def widget_generator():
        for length in (0, 3, 6, 12):
            yield render_widget(length)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 5

    added = events[0]
    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    for update_event, expected_delta in zip(events[1:4], ["Hel", "lo,", " world"]):
        assert isinstance(update_event, ThreadItemUpdated)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    done = events[4]
    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(children=[Text(id="text", value="Hello, world")])
1113
1114
async def test_returns_widget_item_generator_full_replace():
    """When the widget text changes rather than grows, stream_widget sends a root replacement."""
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card(text: str) -> Card:
        """Fresh single-text card with component id "text"."""
        return Card(children=[Text(id="text", value=text)])

    async def widget_generator():
        for text in ("Hello", "World"):
            yield card(text)

    events = [event async for event in stream_widget(thread, widget_generator())]

    assert len(events) == 3
    added, replaced, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card("Hello")

    assert isinstance(replaced, ThreadItemUpdated)
    assert replaced.update.type == "widget.root.updated"
    assert replaced.update.widget == card("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card("World")
1138
1139
async def test_delete_thread():
    """Deleting a thread makes subsequent lookups of it raise NotFoundError."""
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Looking up the deleted thread must raise NotFoundError; pytest.raises
        # also fails the test if no exception is raised at all.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1171
1172
async def test_thread_item_removed_event_removes_item():
    """An item the responder saves and then removes must not remain in the store."""
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger removal")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        removal_emitted = any(
            ev.type == "thread.item.removed" and ev.item_id == removed_id
            for ev in events
        )
        assert removal_emitted

        # The removed item must be absent from the persisted thread history.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        assert removed_id not in {item.id for item in stored_items}
1217
1218
async def test_raising_in_responder_yields_error_event():
    """An exception in the responder surfaces as a STREAM_ERROR event, then the
    stream ends."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; makes this function an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        stream = result.__aiter__()

        # Consume events until the error event arrives.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # After the error event, the stream must be exhausted.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1260
1261
async def test_thread_item_replaced_event_replaces_item():
    """A replaced item is stored under the same id with the replacement's content."""
    original_id = "msg_to_replace"
    replacement_text = "This is the replacement content"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Persist a user message first...
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)
        # ...then swap it for an assistant message that reuses the same id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[AssistantMessageContent(text=replacement_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Trigger replacement")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        replaced_ids = [ev.item.id for ev in events if ev.type == "thread.item.replaced"]
        assert original_id in replaced_ids

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        stored = next(filter(lambda item: item.id == original_id, stored_items), None)
        assert stored is not None
        assert isinstance(stored, AssistantMessageItem)
        first_part = stored.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == replacement_text
1324
1325
async def test_thread_item_replaced_event_with_widget():
    """A widget item can be replaced in place by a new widget under the same id."""
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def widget_item(thread_id: str, widget: Card) -> WidgetItem:
        # Both the original and the replacement share `widget_id`.
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread_id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(item=widget_item(thread.id, original_widget))
        yield ThreadItemReplacedEvent(item=widget_item(thread.id, replacement_widget))

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[
                            UserMessageTextContent(text="Trigger widget replacement")
                        ],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # The stream must carry the replacement event with the new widget.
        replacement_events = [
            ev
            for ev in events
            if ev.type == "thread.item.replaced" and ev.item.id == widget_id
        ]
        assert replacement_events
        replacement_event = replacement_events[0]
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The store must hold the replacement, not the original.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        stored = next((item for item in stored_items if item.id == widget_id), None)
        assert stored is not None
        assert isinstance(stored, WidgetItem)
        stored_widget = stored.widget
        assert isinstance(stored_widget, Card)
        first_child = stored_widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1399
1400
async def test_retry_after_item():
    """Retrying after a user message re-runs the responder with that message and
    replaces the previous assistant reply in storage."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first call answers normally; every later call is a retry.
        item_id, text = (
            ("assistant_msg_1", "First response")
            if len(responder_calls) == 1
            else ("assistant_msg_2", "Retried response")
        )
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=item_id,
                content=[AssistantMessageContent(text=text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:

        async def stored_items(thread_id: str) -> list[ThreadItem]:
            listing = await server.process_non_streaming(
                ItemsListReq(params=ItemsListParams(thread_id=thread_id))
            )
            return TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data

        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        # Items are listed newest first: the reply, then the user message.
        before = await stored_items(thread.id)
        assert len(before) == 2
        reply_before, user_message = before
        assert reply_before.type == "assistant_message"
        assert reply_before.content[0].type == "output_text"
        assert reply_before.content[0].text == "First response"
        assert user_message.type == "user_message"
        assert user_message.content[0].text == "Hello, world!"

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message.id
                )
            )
        )

        # The retry produces exactly one new assistant reply.
        assert len(retry_events) == 1
        (retry_done,) = retry_events
        assert retry_done.type == "thread.item.done"
        assert retry_done.item.type == "assistant_message"
        assert retry_done.item.content[0].type == "output_text"
        assert retry_done.item.content[0].text == "Retried response"

        # Both responder calls received the same user message.
        assert len(responder_calls) == 2
        first_input, retry_input = responder_calls
        assert first_input == retry_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        # Storage still holds two items, with the old reply swapped out.
        after = await stored_items(thread.id)
        assert len(after) == 2
        reply_after, user_after = after
        assert reply_after.type == "assistant_message"
        assert reply_after.content[0].type == "output_text"
        assert reply_after.content[0].text == "Retried response"
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1497
1498
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ValueError."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(ev.thread for ev in events if ev.type == "thread.created")

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        # The newest item is the assistant reply, which is not a valid retry anchor.
        newest_item = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data[0]

        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(
                ThreadsRetryAfterItemReq(
                    params=ThreadRetryAfterItemParams(
                        thread_id=thread.id, item_id=newest_item.id
                    )
                )
            )
1545
1546
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only the final reply,
    leaving the earlier exchange untouched."""
    responder_calls = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Items are listed newest first: assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Pin the index assumption before using data[1] as the retry anchor.
        assert items_before.data[1].type == "user_message"
        assert items_before.data[1].content[0].text == "Second message"
        second_user_message_id = items_before.data[1].id

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message_id
                )
            )
        )

        # The retry emits exactly one new reply to the second message.
        assert len(retry_events) == 1
        assert retry_events[0].type == "thread.item.done"
        assert retry_events[0].item.type == "assistant_message"
        assert retry_events[0].item.content[0].text == "Response to: Second message"

        # The retry re-used the second user message as responder input.
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        assert responder_calls[2].type == "user_message"
        assert responder_calls[2].content[0].text == "Second message"

        # Only the last assistant response was removed and regenerated.
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # The first exchange is unchanged.
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # The second user message remains, and its reply is a new item.
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1650
1651
async def test_threads_create_passes_tools_to_responder():
    """The tool choice on a new thread's first message reaches the responder."""
    tool_choice = ToolChoice(id="web_search")
    # Record what the responder saw and assert afterwards. An exception raised
    # inside the responder is reported as a stream `error` event (see
    # test_raising_in_responder_yields_error_event), so asserting there alone
    # may not fail this test.
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        chosen = input.inference_options.tool_choice if input is not None else None
        seen_tool_choice_ids.append(chosen.id if chosen is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_create",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )

    # The responder ran once and received the requested tool choice.
    assert seen_tool_choice_ids == [tool_choice.id]
    # The responder's reply completed. The user message's own done event does
    # not satisfy this check.
    assert any(
        e.type == "thread.item.done" and e.item.type == "assistant_message"
        for e in events
    )
1685
1686
async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message re-sends that message's tool choice to the
    responder."""
    tool_choice = ToolChoice(id="web_search")
    # Recorded here and asserted outside the responder, because exceptions
    # inside the responder are reported as stream error events.
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        chosen = input.inference_options.tool_choice if input is not None else None
        seen_tool_choice_ids.append(chosen.id if chosen is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                # Distinct ids so the original reply and the retry never collide.
                id=f"assistant_msg_tools_retry_{len(seen_tool_choice_ids)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Retry with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        user_message = next(i for i in items if i.type == "user_message")

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message.id
                )
            )
        )

    # Both the original run and the retry saw the stored tool choice.
    assert seen_tool_choice_ids == [tool_choice.id, tool_choice.id]
    assert any(
        e.type == "thread.item.done" and e.item.type == "assistant_message"
        for e in retry_events
    )
1689
1690
async def test_threads_add_message_passes_tools_to_responder():
    """The tool choice on a follow-up user message reaches the responder."""
    tool_choice = ToolChoice(id="retrieval")
    # Recorded here and asserted outside the responder, because exceptions
    # inside the responder are reported as stream error events.
    seen_tool_choice_ids: list[str | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        chosen = input.inference_options.tool_choice if input is not None else None
        seen_tool_choice_ids.append(chosen.id if chosen is not None else None)

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                # Distinct ids: the responder runs for both the create and the add.
                id=f"assistant_msg_tools_add_{len(seen_tool_choice_ids)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Add a message with tools and verify they reach the responder.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )

    # One responder call for the create and one for the added message.
    assert seen_tool_choice_ids == [tool_choice.id, tool_choice.id]
    # The add-message stream contains the responder's reply, not just the user message.
    assert any(
        e.type == "thread.item.done" and e.item.type == "assistant_message"
        for e in events
    )
1740
1741
async def test_custom_id_generators_on_server():
    """Thread and user-message ids come from the store's id-generator hooks."""

    class UniqueIdStore(SQLiteStore):
        """SQLiteStore that hands out predictable, sequential ids."""

        def __init__(self, path):
            super().__init__(path)
            # Shared by both generators, so the call order is visible in the ids.
            self.counter = 0

        def generate_thread_id(self, context: Any) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context: Any,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits nothing: only the ids the server generates itself (thread and
        # user message) are under test here.
        return
        yield  # unreachable; makes this function an async generator

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        db_path = cast(SQLiteStore, server.store).db_path
        server.store = UniqueIdStore(db_path)

        # Create thread and verify thread id and user message id
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        # The thread id is generated first (counter == 1)...
        thread_created = next(e for e in events if e.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        # ...then the user message id (counter == 2), scoped to that thread.
        user_message = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"