openai/chatkit-python

Public

mirrored from https://github.com/openai/chatkit-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
6df9ee60da4de3f063e752e60bd3ceef920f4e78

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_chatkit_server.py

1944lines · modeblame

db36b662Jiwon Kim7 months ago1import asyncio
f688d870victor-openai8 months ago2import sqlite3
3from contextlib import contextmanager
4from datetime import datetime
5from typing import Any, AsyncIterator, Callable, cast
6
7import pytest
8from helpers.mock_store import SQLiteStore
9from pydantic import AnyUrl, TypeAdapter
10
11from chatkit.actions import Action
12from chatkit.errors import ErrorCode
13from chatkit.server import (
14ChatKitServer,
15NonStreamingResult,
16StreamingResult,
17stream_widget,
18)
db36b662Jiwon Kim7 months ago19from chatkit.store import (
20AttachmentStore,
21NotFoundError,
22StoreItemType,
23default_generate_id,
24)
f688d870victor-openai8 months ago25from chatkit.types import (
26AssistantMessageContent,
db36b662Jiwon Kim7 months ago27AssistantMessageContentPartTextDelta,
f688d870victor-openai8 months ago28AssistantMessageItem,
29Attachment,
30AttachmentCreateParams,
31AttachmentDeleteParams,
32AttachmentsCreateReq,
33AttachmentsDeleteReq,
34ClientToolCallItem,
35FeedbackKind,
36FileAttachment,
37ImageAttachment,
38InferenceOptions,
39ItemFeedbackParams,
40ItemsFeedbackReq,
41ItemsListParams,
42ItemsListReq,
43LockedStatus,
44Page,
45ProgressUpdateEvent,
46Thread,
47ThreadAddClientToolOutputParams,
48ThreadAddUserMessageParams,
49ThreadCreatedEvent,
50ThreadCreateParams,
51ThreadCustomActionParams,
52ThreadDeleteParams,
53ThreadGetByIdParams,
54ThreadItem,
55ThreadItemAddedEvent,
56ThreadItemDoneEvent,
57ThreadItemRemovedEvent,
58ThreadItemReplacedEvent,
643a02f7Jiwon Kim7 months ago59ThreadItemUpdatedEvent,
f688d870victor-openai8 months ago60ThreadListParams,
61ThreadMetadata,
62ThreadRetryAfterItemParams,
63ThreadsAddClientToolOutputReq,
64ThreadsAddUserMessageReq,
65ThreadsCreateReq,
66ThreadsCustomActionReq,
67ThreadsDeleteReq,
68ThreadsGetByIdReq,
69ThreadsListReq,
70ThreadsRetryAfterItemReq,
71ThreadStreamEvent,
72ThreadsUpdateReq,
73ThreadUpdatedEvent,
74ThreadUpdateParams,
75ToolChoice,
76UserMessageInput,
77UserMessageItem,
78UserMessageTextContent,
79WidgetItem,
80WidgetRootUpdated,
81)
82from chatkit.widgets import Card, Text
83from tests._types import RequestContext
84from tests.test_store import make_thread_items
85
# Counter consumed by make_server(): every server gets its own shared-cache
# in-memory SQLite database named after the current value.
server_id: int = 0


# Request context used by the process_* helpers when a test passes none.
DEFAULT_CONTEXT = RequestContext(user_id="test_user")
90
91
class InMemoryFileStore(AttachmentStore):
    """Dict-backed ``AttachmentStore`` used by the attachment tests.

    Attachments live in ``self.files`` keyed by id. Preview/upload URLs are
    synthetic ``https://example.com/<id>/...`` addresses.
    """

    def __init__(self):
        # attachment id -> stored attachment (tests may also seed raw values)
        self.files: dict[str, Any] = {}

    async def delete_attachment(self, attachment_id: str, context: Any) -> None:
        """Remove *attachment_id*; raises ``KeyError`` if it is not stored."""
        del self.files[attachment_id]

    async def create_attachment(
        self, input: AttachmentCreateParams, context: RequestContext
    ) -> Attachment:
        """Create, store and return an attachment described by *input*.

        ``image/*`` MIME types produce an ``ImageAttachment`` (with a preview
        URL); every other MIME type produces a ``FileAttachment``.
        """
        attachment_id = self._next_attachment_id()
        upload_url = AnyUrl(f"https://example.com/{attachment_id}/upload")
        if input.mime_type.startswith("image/"):
            attachment = ImageAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                preview_url=AnyUrl(f"https://example.com/{attachment_id}/preview"),
                upload_url=upload_url,
            )
        else:
            attachment = FileAttachment(
                id=attachment_id,
                mime_type=input.mime_type,
                name=input.name,
                upload_url=upload_url,
            )
        self.files[attachment.id] = attachment
        return attachment

    def _next_attachment_id(self) -> str:
        """Return an unused ``atc_<n>`` id, starting the search at ``len(files) + 1``.

        Using ``len(self.files) + 1`` alone re-issues an existing id after a
        deletion (create atc_1, create atc_2, delete atc_1 -> "atc_2" again),
        which would silently overwrite the stored attachment. Probing forward
        keeps the old ids whenever there is no collision.
        """
        n = len(self.files) + 1
        while f"atc_{n}" in self.files:
            n += 1
        return f"atc_{n}"
121
122
def decode_event(event: bytes) -> ThreadStreamEvent:
    """Decode one ``data: <json>`` server-sent-events frame into a stream event.

    Only the first ``data: `` marker is treated as the field prefix. The JSON
    payload may itself contain the bytes ``data: `` (e.g. inside message
    text), and splitting on every occurrence would truncate the payload.

    Raises:
        IndexError: if the frame contains no ``data: `` marker.
    """
    payload = event.split(b"data: ", 1)[1]
    return TypeAdapter(ThreadStreamEvent).validate_json(payload)
125
126
async def decode_streaming_result(
    streaming_result: StreamingResult,
) -> list[ThreadStreamEvent]:
    """Drain *streaming_result* and decode every frame it yields, in order."""
    decoded: list[ThreadStreamEvent] = []
    async for frame in streaming_result.json_events:
        decoded.append(decode_event(frame))
    return decoded
131
132
@contextmanager
def make_server(
    responder: Callable[
        [ThreadMetadata, UserMessageItem | None, Any], AsyncIterator[ThreadStreamEvent]
    ]
    | None = None,
    handle_feedback: Callable[[str, list[str], FeedbackKind, Any], None] | None = None,
    action_callback: Callable[
        [ThreadMetadata, Action[str, Any], WidgetItem | None, Any],
        AsyncIterator[ThreadStreamEvent],
    ]
    | None = None,
    file_store: AttachmentStore | None = None,
):
    """Yield a ``ChatKitServer`` backed by a fresh in-memory SQLite store.

    Args:
        responder: Produces the stream returned by ``respond()``. When omitted,
            the server responds with an empty stream.
        handle_feedback: Called from ``add_feedback()``. When omitted, feedback
            is ignored.
        action_callback: Produces the stream returned by ``action()``. When
            omitted, ``action()`` raises ``ValueError``.
        file_store: Attachment store handed to the server (may be ``None``).

    The yielded server also has ``process_streaming`` / ``process_non_streaming``
    helpers. Each serializes a request model and runs it through ``process()``,
    falling back to ``DEFAULT_CONTEXT`` when the given context is falsy. Each
    then asserts the result kind it expects.
    """
    global server_id
    # Give every server its own named shared-cache in-memory database.
    db_path = f"file:{server_id}?mode=memory&cache=shared"
    server_id += 1
    # The shared in-memory database is deleted when its last connection
    # closes, so hold one connection open for the lifetime of the server.
    keepalive = sqlite3.connect(db_path, uri=True)

    class TestChatKitServer(ChatKitServer):
        def __init__(self):
            super().__init__(SQLiteStore(db_path), file_store)

        def action(
            self,
            thread: ThreadMetadata,
            action: Action[str, Any],
            sender: WidgetItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if action_callback is not None:
                return action_callback(thread, action, sender, context)
            raise ValueError("action_callback not wired up")

        def respond(
            self,
            thread: ThreadMetadata,
            input_user_message: UserMessageItem | None,
            context: Any,
        ) -> AsyncIterator[ThreadStreamEvent]:
            if responder is not None:
                return responder(thread, input_user_message, context)
            return self._empty_responder()

        async def _empty_responder(self) -> AsyncIterator[ThreadStreamEvent]:
            # The unreachable ``yield`` turns this into an async generator
            # that finishes without producing any event.
            return
            yield

        async def add_feedback(
            self,
            thread_id: str,
            item_ids: list[str],
            feedback: FeedbackKind,
            context: Any,
        ) -> None:
            if handle_feedback is not None:
                handle_feedback(thread_id, item_ids, feedback, context)

        async def process_streaming(
            self, request_obj, context: Any | None = None
        ) -> list[ThreadStreamEvent]:
            payload = request_obj.model_dump_json()
            result = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(result, StreamingResult)
            return await decode_streaming_result(result)

        async def process_non_streaming(self, request_obj, context: Any | None = None):
            payload = request_obj.model_dump_json()
            result = await self.process(payload, context or DEFAULT_CONTEXT)
            assert isinstance(result, NonStreamingResult)
            return result

    try:
        yield TestChatKitServer()
    finally:
        keepalive.close()
213
214
async def test_stream_cancellation_persists_pending_assistant_message_and_hidden_context():
    """Cancelling mid-stream keeps the partial assistant message and adds a hidden note."""

    async def cancelling_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Open an assistant message, stream one delta into it, then cancel
        # before the item is ever marked done.
        pending = AssistantMessageItem(
            id="assistant-message-pending",
            created_at=datetime.now(),
            content=[AssistantMessageContent(text="Hello, ")],
            thread_id=thread.id,
        )
        yield ThreadItemAddedEvent(item=pending)
        yield ThreadItemUpdatedEvent(
            item_id="assistant-message-pending",
            update=AssistantMessageContentPartTextDelta(content_index=0, delta="World!"),
        )
        raise asyncio.CancelledError()

    with make_server(cancelling_responder) as server:
        # Route sdk_hidden_context ids through the default generator; every
        # other item type keeps using the test store's own generator.
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(request.model_dump_json(), DEFAULT_CONTEXT)
        assert isinstance(stream, StreamingResult)

        # Collect frames one at a time so the ones seen before the
        # cancellation survive the exception.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for frame in stream.json_events:
                events.append(decode_event(frame))

        thread = next(e.thread for e in events if e.type == "thread.created")

        # The newest stored item is the hidden context written on cancellation.
        newest = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert hidden_context_item.content == (
            "The user cancelled the stream. Stop responding to the prior request."
        )

        # The partially streamed message was persisted with the delta applied.
        assistant_message_item = await server.store.load_item(
            thread.id, "assistant-message-pending", DEFAULT_CONTEXT
        )
        assert assistant_message_item.type == "assistant_message"
        assert assistant_message_item.content[0].text == "Hello, World!"
284
285
async def test_stream_cancellation_does_not_persist_pending_empty_assistant_message():
    """Cancelling mid-stream drops a pending assistant message that has no content."""

    async def cancelling_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Open an empty assistant message and cancel straight away.
        yield ThreadItemAddedEvent(
            item=AssistantMessageItem(
                id="assistant-message-pending",
                created_at=datetime.now(),
                content=[],
                thread_id=thread.id,
            )
        )
        raise asyncio.CancelledError()

    with make_server(cancelling_responder) as server:
        # Route sdk_hidden_context ids through the default generator; every
        # other item type keeps using the test store's own generator.
        store_generate_item_id = server.store.generate_item_id

        def generate_item_id(
            item_type: StoreItemType, thread: ThreadMetadata, context: Any
        ):
            if item_type != "sdk_hidden_context":
                return store_generate_item_id(item_type, thread, context)
            return default_generate_id("sdk_hidden_context")

        server.store.generate_item_id = generate_item_id  # type: ignore[method-assign]

        request = ThreadsCreateReq(
            params=ThreadCreateParams(
                input=UserMessageInput(
                    content=[UserMessageTextContent(text="Hello")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                )
            )
        )
        stream = await server.process(request.model_dump_json(), DEFAULT_CONTEXT)
        assert isinstance(stream, StreamingResult)

        # Collect frames one at a time so the ones seen before the
        # cancellation survive the exception.
        events: list[ThreadStreamEvent] = []
        with pytest.raises(asyncio.CancelledError):  # noqa: PT012
            async for frame in stream.json_events:
                events.append(decode_event(frame))

        thread = next(e.thread for e in events if e.type == "thread.created")

        # The newest stored item is the hidden context written on cancellation.
        newest = await server.store.load_thread_items(
            thread.id, None, 1, "desc", DEFAULT_CONTEXT
        )
        hidden_context_item = newest.data[-1]
        assert hidden_context_item.type == "sdk_hidden_context"
        assert hidden_context_item.content == (
            "The user cancelled the stream. Stop responding to the prior request."
        )

        # The empty pending message must not have been persisted.
        with pytest.raises(NotFoundError):
            await server.store.load_item(
                thread.id, "assistant-message-pending", DEFAULT_CONTEXT
            )
346
347
async def test_flows_context_to_responder():
    """The per-request context reaches both respond() and add_feedback()."""
    seen: dict[str, Any] = {}

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        seen["responder"] = context
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_1",
                created_at=datetime.now(),
                content=[AssistantMessageContent(text="Hello, world!")],
                thread_id=thread.id,
            ),
        )

    def record_feedback(
        thread_id: str, item_ids: list[str], feedback: FeedbackKind, context: Any
    ) -> None:
        seen["add_feedback"] = context

    # Deliberately different from DEFAULT_CONTEXT so the fallback cannot mask a bug.
    context = RequestContext(user_id="text_ctx")
    with make_server(responder, record_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input)),
            context,
        )
        assert seen.get("responder") == context

        thread = next(e.thread for e in events if e.type == "thread.created")
        assistant_item = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "assistant_message"
        )
        feedback_request = ItemsFeedbackReq(
            params=ItemFeedbackParams(
                thread_id=thread.id,
                item_ids=[assistant_item.id],
                kind="positive",
            )
        )
        await server.process_non_streaming(feedback_request, context)
        assert seen.get("add_feedback") == context
407
408
async def test_creates_thread():
    """Creating a thread emits an untitled thread.created and the finished user message."""

    async def silent_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        return
        yield

    with make_server(silent_responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
            quoted_text="Quoted text",
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        created = next(e for e in events if e.type == "thread.created")
        assert created.thread.title is None

        user_item = next(e.item for e in events if e.type == "thread.item.done")
        assert user_item.type == "user_message"
        assert user_item.content[0].text == "Hello, world!"
        assert user_item.quoted_text == "Quoted text"
442
443
async def test_saves_thread_title():
    """A title set by the responder is persisted and reported via thread.updated."""

    async def retitling_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.title = "Updated title"
        return
        yield

    with make_server(retitling_responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Load with the shared DEFAULT_CONTEXT (what the request ran under),
        # like the sibling tests, instead of rebuilding an equal one by hand.
        stored = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored.title == "Updated title"
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.title == "Updated title"
474
475
async def test_saves_thread_metadata():
    """Metadata written by the responder is persisted and triggers thread.updated."""

    async def tagging_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.metadata["test_key"] = "test_value"
        return
        yield

    with make_server(tagging_responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        stored = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored.metadata["test_key"] == "test_value"
        assert events[-1].type == "thread.updated"
503
504
async def test_saves_thread_locked_fields():
    """A locked status set by the responder is persisted and reported via thread.updated."""

    async def locking_responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        thread.status = LockedStatus(reason="Because")
        return
        yield

    with make_server(locking_responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Load with the shared DEFAULT_CONTEXT (what the request ran under),
        # like the sibling tests, instead of rebuilding an equal one by hand.
        loaded = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert loaded.status == LockedStatus(reason="Because")
        assert events[-1].type == "thread.updated"
        assert events[-1].thread.status == LockedStatus(reason="Because")
534
535
async def test_emits_thread_updated_mid_stream_and_persists():
    """Retitling mid-stream emits thread.updated after the next item and is persisted."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        def assistant_done(item_id: str, text: str) -> ThreadItemDoneEvent:
            return ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id=item_id,
                    content=[AssistantMessageContent(text=text)],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )

        yield assistant_done("assistant_a", "A")
        # Change the thread between items; a thread.updated event is expected
        # right after the next item (assistant_b).
        thread.title = "Mid-stream title"
        yield assistant_done("assistant_b", "B")
        # Streaming must carry on normally after the update event.
        yield assistant_done("assistant_c", "C")

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if isinstance(e, ThreadCreatedEvent))

        idx_b = next(
            i
            for i, e in enumerate(events)
            if isinstance(e, ThreadItemDoneEvent)
            and isinstance(e.item, AssistantMessageItem)
            and e.item.id == "assistant_b"
        )

        assert isinstance(events[idx_b + 1], ThreadUpdatedEvent)
        update_event = cast(ThreadUpdatedEvent, events[idx_b + 1])
        assert update_event.thread.title == "Mid-stream title"

        assert isinstance(events[idx_b + 2], ThreadItemDoneEvent)
        following = cast(ThreadItemDoneEvent, events[idx_b + 2])
        assert isinstance(following.item, AssistantMessageItem)
        assert following.item.id == "assistant_c"

        stored = await server.store.load_thread(thread.id, DEFAULT_CONTEXT)
        assert stored.title == "Mid-stream title"
610
611
async def test_respond_with_tool_call():
    """A client tool call is streamed, and submitting its output resumes the responder."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        if input is None:
            # Resumed after the client tool output was added.
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[AssistantMessageContent(text="Glad the tool call succeeded!")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        elif isinstance(input, UserMessageItem):
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert [e.type for e in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "thread.item.done",
        ]
        thread = events[0].thread
        assert events[1].item.type == "user_message"

        tool_call = events[3].item
        assert tool_call.type == "client_tool_call"
        assert tool_call.id == "msg_1"
        assert tool_call.name == "tool_call_1"
        assert tool_call.arguments == {"arg1": "val1", "arg2": False}
        assert tool_call.call_id == "tool_call_1"

        events = await server.process_streaming(
            ThreadsAddClientToolOutputReq(
                params=ThreadAddClientToolOutputParams(
                    thread_id=thread.id,
                    result={"text": "Wow!"},
                )
            )
        )

        assert [e.type for e in events] == ["stream_options", "thread.item.done"]
        assert events[1].item.type == "assistant_message"
f688d870victor-openai8 months ago681
682
async def test_removes_tool_call_if_no_output_provided():
    """An unanswered client tool call is dropped once the user sends another message."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        assert isinstance(input, UserMessageItem)
        first_part = input.content[0]
        assert first_part.type == "input_text"
        if first_part.text != "Message 1":
            yield ThreadItemDoneEvent(
                item=AssistantMessageItem(
                    id="msg_2",
                    content=[AssistantMessageContent(text="All done!")],
                    created_at=datetime.now(),
                    thread_id=thread.id,
                ),
            )
        else:
            yield ThreadItemDoneEvent(
                item=ClientToolCallItem(
                    id="msg_1",
                    created_at=datetime.now(),
                    name="tool_call_1",
                    arguments={"arg1": "val1", "arg2": False},
                    call_id="tool_call_1",
                    thread_id=thread.id,
                ),
            )

    def text_input(text: str) -> UserMessageInput:
        return UserMessageInput(
            content=[UserMessageTextContent(text=text)],
            attachments=[],
            inference_options=InferenceOptions(),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=text_input("Message 1")))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Reply without ever supplying the tool output.
        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=text_input("Message 2"),
                )
            )
        )

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        assert [item.type for item in items.data] == [
            "assistant_message",
            "user_message",
            "user_message",
        ]
747
748
async def test_respond_with_tool_status():
    """Progress updates are streamed between stream_options and the final item."""

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ProgressUpdateEvent(text="Tool status")
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="msg_4",
                content=[AssistantMessageContent(text="Hello, world!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Hello, world!")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )

        assert [e.type for e in events] == [
            "thread.created",
            "thread.item.done",
            "stream_options",
            "progress_update",
            "thread.item.done",
        ]
f688d870victor-openai8 months ago782
783
async def test_list_threads_response():
    """Threads are listed newest-first and paginate via an ``after`` cursor."""
    with make_server() as server:
        for n in range(25):
            request = ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text=f"Thread {n}")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
            await server.process_streaming(request)

        page_adapter = TypeAdapter(Page[ThreadMetadata])

        first_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(first_response.json)
        assert len(first_page.data) == 20
        # Newest first
        assert first_page.data[0].created_at > first_page.data[1].created_at
        assert first_page.has_more is True
        assert first_page.after is not None

        second_response = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.data[-1].id))
        )
        second_page = page_adapter.validate_json(second_response.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False
        assert not second_page.after
816
817
async def test_list_items_response():
    """Items are listed newest-first, 20 per page, with the remainder on page two."""
    from datetime import timedelta

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Give every item a distinct, strictly increasing timestamp that is
        # later than the user message. Stamping each item with datetime.now()
        # in a tight loop can produce equal timestamps on a coarse clock, which
        # would make the newest-first assertion and the ``after``-cursor
        # pagination below depend on how ties happen to be ordered.
        base = datetime.now()
        for i in range(25):
            yield ThreadItemDoneEvent(
                item=UserMessageItem(
                    id=f"msg_{i}",
                    content=[UserMessageTextContent(text=f"Message {i}")],
                    attachments=[],
                    inference_options=InferenceOptions(),
                    thread_id=thread.id,
                    created_at=base + timedelta(seconds=i + 1),
                ),
            )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread 1")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items.data) == 20
        # Newest first
        assert items.data[0].created_at > items.data[1].created_at
        assert items.has_more is True
        assert items.after is not None

        list_result = await server.process_non_streaming(
            ItemsListReq(
                params=ItemsListParams(thread_id=thread.id, after=items.data[-1].id)
            )
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items.data) == 6  # +1 for user message
        assert items.has_more is False
867
868
async def test_calls_action():
    """A custom action reaches action() with its sender widget and streams the update."""
    received: list[tuple[Action[str, Any], WidgetItem | None]] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                thread_id=thread.id,
                widget=Card(children=[Text(key="text_1", value="test")]),
            ),
        )

    async def action(
        thread: ThreadMetadata,
        action: Action[str, Any],
        sender: WidgetItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        received.append((action, sender))
        assert sender
        yield ThreadItemUpdatedEvent(
            item_id=sender.id,
            update=WidgetRootUpdated(widget=Card(children=[Text(value="Email sent!")])),
        )

    with make_server(responder, action_callback=action) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Show widget")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")
        widget_item = next(
            e.item
            for e in events
            if isinstance(e, ThreadItemDoneEvent) and isinstance(e.item, WidgetItem)
        )

        events = await server.process_streaming(
            ThreadsCustomActionReq(
                params=ThreadCustomActionParams(
                    thread_id=thread.id,
                    item_id=widget_item.id,
                    action=Action(type="create_user", payload={"user_id": "123"}),
                )
            )
        )

        assert received
        assert received[0] == (
            Action(type="create_user", payload={"user_id": "123"}),
            widget_item,
        )

        assert [e.type for e in events] == ["stream_options", "thread.item.updated"]
        widget_update = events[1]
        assert isinstance(widget_update, ThreadItemUpdatedEvent)
        assert widget_update.update.type == "widget.root.updated"
        assert widget_update.update.widget == Card(children=[Text(value="Email sent!")])
f688d870victor-openai8 months ago958
959
async def test_add_feedback():
    """Feedback on messages and widgets alike is forwarded to add_feedback()."""
    recorded = []

    def handle_feedback(thread_id, item_ids, feedback, context):
        recorded.append((thread_id, item_ids, feedback, context))

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_1",
                content=[AssistantMessageContent(text="Feedback received!")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )
        yield ThreadItemDoneEvent(
            item=WidgetItem(
                id="widget_1",
                type="widget",
                created_at=datetime.now(),
                widget=Card(children=[Text(key="text", value="Widget content")]),
                thread_id=thread.id,
            ),
        )

    with make_server(responder, handle_feedback=handle_feedback) as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test feedback")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        submissions: list[tuple[str, FeedbackKind]] = [
            ("assistant_msg_1", "positive"),
            ("widget_1", "negative"),
        ]
        for item_id, kind in submissions:
            await server.process_non_streaming(
                ItemsFeedbackReq(
                    params=ItemFeedbackParams(
                        thread_id=thread.id,
                        item_ids=[item_id],
                        kind=kind,
                    )
                )
            )

        # Thread id, item ids and kind are checked; the context is not.
        assert [call[:3] for call in recorded] == [
            (thread.id, ["assistant_msg_1"], "positive"),
            (thread.id, ["widget_1"], "negative"),
        ]
1034
1035
async def test_get_thread_by_id():
    """threads.get_by_id returns the thread together with its user message."""
    # NOTE(review): the return value is discarded, so this call appears to
    # have no effect on the test — confirm and remove.
    make_thread_items()

    with make_server() as server:
        user_input = UserMessageInput(
            content=[UserMessageTextContent(text="Test thread")],
            attachments=[],
            inference_options=InferenceOptions(),
        )
        created_events = await server.process_streaming(
            ThreadsCreateReq(params=ThreadCreateParams(input=user_input))
        )
        thread = next(e.thread for e in created_events if e.type == "thread.created")

        response = await server.process_non_streaming(
            ThreadsGetByIdReq(params=ThreadGetByIdParams(thread_id=thread.id))
        )
        loaded_thread = TypeAdapter(Thread).validate_json(response.json)
        assert loaded_thread.id == thread.id
        assert loaded_thread.title == thread.title

        first_item = loaded_thread.items.data[0]
        assert first_item.type == "user_message"
        assert first_item.content[0].text == "Test thread"
1067
1068
async def test_create_file():
    """Creating a non-image attachment returns and stores a FileAttachment."""
    store = InMemoryFileStore()
    name, mime_type = "test-file-name", "text/plain"

    with make_server(file_store=store) as server:
        request = AttachmentsCreateReq(
            params=AttachmentCreateParams(name=name, size=1024, mime_type=mime_type)
        )
        response = await server.process_non_streaming(request)
        attachment = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert attachment.id is not None
        assert attachment.mime_type == mime_type
        assert attachment.name == name
        assert attachment.type == "file"
        expected_upload = AnyUrl(f"https://example.com/{attachment.id}/upload")
        assert attachment.upload_url == expected_upload
        assert attachment.id in store.files
1093
1094
async def test_create_image_file():
    """Creating an ``image/png`` attachment returns an image attachment.

    Image attachments carry both a preview URL and an upload URL, and the
    attachment must be registered in the backing file store.
    """
    store = InMemoryFileStore()
    name = "test-file-name"
    mime_type = "image/png"
    request = AttachmentsCreateReq(
        params=AttachmentCreateParams(
            name=name,
            size=1024,
            mime_type=mime_type,
        )
    )

    with make_server(file_store=store) as server:
        response = await server.process_non_streaming(request)
        created = TypeAdapter[Attachment](Attachment).validate_json(response.json)

        assert created.id is not None
        assert created.mime_type == mime_type
        assert created.name == name
        assert created.type == "image"

        base_url = f"https://example.com/{created.id}"
        assert created.preview_url == AnyUrl(f"{base_url}/preview")
        assert created.upload_url == AnyUrl(f"{base_url}/upload")

        assert created.id in store.files
1124
1125
async def test_create_file_without_filestore():
    """Creating an attachment without a configured file store raises ``RuntimeError``.

    ``pytest.raises`` wraps only the request itself, so a ``RuntimeError``
    raised while creating or tearing down the server cannot satisfy the check.
    """
    with make_server() as server:
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsCreateReq(
                    params=AttachmentCreateParams(
                        name="test-file-name",
                        size=1024,
                        mime_type="text/plain",
                    )
                )
            )
1138
1139
async def test_delete_file():
    """Deleting an attachment removes it from the backing file store."""
    attachment_id = "test-file-id"
    store = InMemoryFileStore()
    store.files[attachment_id] = {"id": attachment_id}

    with make_server(file_store=store) as server:
        delete_request = AttachmentsDeleteReq(
            params=AttachmentDeleteParams(attachment_id=attachment_id)
        )
        await server.process_non_streaming(delete_request)
        # The only stored file was the one we deleted.
        assert store.files == {}
1150
1151
async def test_delete_file_without_filestore():
    """Deleting an attachment without a configured file store raises ``RuntimeError``.

    ``pytest.raises`` wraps only the request itself, so a ``RuntimeError``
    raised while creating or tearing down the server cannot satisfy the check.
    """
    with make_server() as server:
        with pytest.raises(RuntimeError):
            await server.process_non_streaming(
                AttachmentsDeleteReq(
                    params=AttachmentDeleteParams(attachment_id="test-file-id")
                )
            )
1160
1161
async def test_list_threads_paginated_response():
    """Thread listing paginates via the ``after`` cursor returned on each page.

    Creates 25 threads, reads a first page limited to 20, then follows that
    page's ``after`` cursor to fetch the remaining 5. The two pages must not
    overlap and together must cover all 25 threads.
    """
    page_adapter = TypeAdapter(Page[ThreadMetadata])

    with make_server() as server:
        for i in range(25):
            await server.process_streaming(
                ThreadsCreateReq(
                    params=ThreadCreateParams(
                        input=UserMessageInput(
                            content=[UserMessageTextContent(text=f"Thread {i}")],
                            attachments=[],
                            inference_options=InferenceOptions(),
                        )
                    )
                )
            )

        list_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(limit=20))
        )
        first_page = page_adapter.validate_json(list_result.json)
        assert len(first_page.data) == 20
        assert first_page.has_more is True
        assert first_page.after is not None

        # Follow the cursor the server handed back instead of deriving one, so
        # the test actually exercises the pagination contract.
        list_result = await server.process_non_streaming(
            ThreadsListReq(params=ThreadListParams(after=first_page.after))
        )
        second_page = page_adapter.validate_json(list_result.json)
        assert len(second_page.data) == 5
        assert second_page.has_more is False

        # Pages must be disjoint and jointly cover every created thread.
        first_ids = {thread.id for thread in first_page.data}
        second_ids = {thread.id for thread in second_page.data}
        assert first_ids.isdisjoint(second_ids)
        assert len(first_ids | second_ids) == 25
1190
1191
async def test_threads_update_request():
    """Updating a thread's title returns the thread with the new title."""
    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello, world!")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server() as server:
        create_events = await server.process_streaming(create_request)
        created = next(
            evt.thread for evt in create_events if evt.type == "thread.created"
        )

        update_response = await server.process_non_streaming(
            ThreadsUpdateReq(
                params=ThreadUpdateParams(thread_id=created.id, title="New Title")
            )
        )
        renamed = TypeAdapter[Thread](Thread).validate_json(update_response.json)
        assert renamed.title == "New Title"
1215
1216
async def test_returns_widget_item_generator():
    """``stream_widget`` streams growing text as value-delta updates.

    The generator yields the same card with its ``text`` component holding
    successively longer prefixes of "Hello, world". The test expects one
    ``ThreadItemAddedEvent`` with the empty card, three
    ``widget.streaming_text.value_delta`` updates carrying only the appended
    text, and a final ``ThreadItemDoneEvent`` with the complete card.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )
    full_text = "Hello, world"

    def render_widget(length: int) -> Card:
        return Card(children=[Text(id="text", value=full_text[:length])])

    async def widget_generator():
        for length in (0, 3, 6, 12):
            yield render_widget(length)

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 5

    added, *updates, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == Card(children=[Text(id="text", value="")])

    for update_event, expected_delta in zip(updates, ["Hel", "lo,", " world"]):
        assert isinstance(update_event, ThreadItemUpdatedEvent)
        assert update_event.update.type == "widget.streaming_text.value_delta"
        assert update_event.update.component_id == "text"
        assert update_event.update.delta == expected_delta

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == Card(
        children=[Text(id="text", value="Hello, world")]
    )
1258
1259
async def test_returns_widget_item_generator_full_replace():
    """``stream_widget`` emits a root replacement when text is not appended.

    Going from "Hello" to "World" cannot be expressed as an appended suffix.
    The test expects the second card to arrive as a ``widget.root.updated``
    update carrying the whole widget.
    """
    thread = ThreadMetadata(
        id="test-thread-id", title="Test thread", metadata={}, created_at=datetime.now()
    )

    def card(value: str) -> Card:
        return Card(children=[Text(id="text", value=value)])

    async def widget_generator():
        yield card("Hello")
        yield card("World")

    events = [event async for event in stream_widget(thread, widget_generator())]
    assert len(events) == 3

    added, updated, done = events

    assert isinstance(added, ThreadItemAddedEvent)
    assert isinstance(added.item, WidgetItem)
    assert added.item.widget == card("Hello")

    assert isinstance(updated, ThreadItemUpdatedEvent)
    assert updated.update.type == "widget.root.updated"
    assert updated.update.widget == card("World")

    assert isinstance(done, ThreadItemDoneEvent)
    assert isinstance(done.item, WidgetItem)
    assert done.item.widget == card("World")
1283
1284
async def test_delete_thread():
    """Deleting a thread returns an empty JSON object and makes it unloadable.

    After deletion, loading the thread by id must raise ``NotFoundError``.
    """
    with make_server() as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Thread to delete")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        response = await server.process_non_streaming(
            ThreadsDeleteReq(params=ThreadDeleteParams(thread_id=thread.id))
        )
        assert response.json == b"{}"

        # Loading the deleted thread must fail with NotFoundError.
        with pytest.raises(NotFoundError) as exc_info:
            await server.process_non_streaming(
                ThreadsGetByIdReq(
                    params=ThreadGetByIdParams(thread_id=thread.id),
                )
            )
        assert f"Thread {thread.id} not found" in str(exc_info.value)
1316
1317
async def test_thread_item_removed_event_removes_item():
    """A ``ThreadItemRemovedEvent`` deletes a previously completed item.

    The responder completes an item and then removes it. The removal must be
    streamed to the client, and the item must no longer be listed for the
    thread.
    """
    removed_id = "msg_to_remove"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        doomed_item = UserMessageItem(
            id=removed_id,
            content=[UserMessageTextContent(text="To be removed")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=doomed_item)
        yield ThreadItemRemovedEvent(item_id=removed_id)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger removal")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        removal_streamed = any(
            evt.type == "thread.item.removed" and evt.item_id == removed_id
            for evt in events
        )
        assert removal_streamed

        # The removed item must be gone from the store.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        remaining_ids = [item.id for item in stored_items]
        assert removed_id not in remaining_ids
1362
1363
async def test_raising_in_responder_yields_error_event():
    """An exception raised by the responder surfaces as an error event.

    The streamed result must contain an ``error`` event whose code is
    ``ErrorCode.STREAM_ERROR``, after which the stream must end.
    """

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        raise ValueError("Test error")
        yield  # unreachable; makes this function an async generator

    with make_server(responder) as server:
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Test error")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            RequestContext(user_id="test_user"),
        )
        assert isinstance(result, StreamingResult)
        # Hold one iterator so the second phase resumes where the loop stopped.
        stream = result.__aiter__()

        # First, an error event is yielded.
        async for raw_event in stream:
            event = decode_event(raw_event)
            if event.type == "error":
                assert event.code == ErrorCode.STREAM_ERROR
                break
        else:
            pytest.fail("Expected error event")

        # Then the stream finishes.
        with pytest.raises(StopAsyncIteration):
            await stream.__anext__()
1405
1406
async def test_thread_item_replaced_event_replaces_item():
    """A ``ThreadItemReplacedEvent`` overwrites a stored item in place.

    The responder completes a user message and then replaces it, under the same
    id, with an assistant message. The replacement must be streamed, and
    listing the thread's items must return the replacement content.
    """
    original_id = "msg_to_replace"

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Complete the original item first...
        original_item = UserMessageItem(
            id=original_id,
            content=[UserMessageTextContent(text="Original content")],
            attachments=[],
            inference_options=InferenceOptions(),
            thread_id=thread.id,
            created_at=datetime.now(),
        )
        yield ThreadItemDoneEvent(item=original_item)

        # ...then replace it under the same id.
        yield ThreadItemReplacedEvent(
            item=AssistantMessageItem(
                id=original_id,
                content=[AssistantMessageContent(text="This is the replacement content")],
                created_at=datetime.now(),
                thread_id=thread.id,
            )
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        # The replacement event was streamed.
        replacement_streamed = any(
            evt.type == "thread.item.replaced" and evt.item.id == original_id
            for evt in events
        )
        assert replacement_streamed

        # The store now holds the replacement under the original id.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        replaced_item = next(
            (item for item in stored_items if item.id == original_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, AssistantMessageItem)
        first_part = replaced_item.content[0]
        assert isinstance(first_part, AssistantMessageContent)
        assert first_part.text == "This is the replacement content"
1469
1470
async def test_thread_item_replaced_event_with_widget():
    """A ``ThreadItemReplacedEvent`` can swap one widget item for another.

    The replacement must be streamed with the new widget, and listing the
    thread's items must return the replaced widget content.
    """
    widget_id = "widget_to_replace"
    original_widget = Card(children=[Text(value="Original widget content")])
    replacement_widget = Card(children=[Text(value="Replaced widget content")])

    def widget_item(widget: Card, thread_id: str) -> WidgetItem:
        return WidgetItem(
            id=widget_id,
            widget=widget,
            created_at=datetime.now(),
            thread_id=thread_id,
        )

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Complete the original widget, then replace it under the same id.
        yield ThreadItemDoneEvent(item=widget_item(original_widget, thread.id))
        yield ThreadItemReplacedEvent(
            item=widget_item(replacement_widget, thread.id)
        )

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Trigger widget replacement")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        # The replacement was streamed with the new widget.
        replacement_event = next(
            (
                evt
                for evt in events
                if evt.type == "thread.item.replaced" and evt.item.id == widget_id
            ),
            None,
        )
        assert replacement_event is not None
        assert isinstance(replacement_event.item, WidgetItem)
        assert replacement_event.item.widget == replacement_widget

        # The store now holds the replacement widget.
        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        stored_items = TypeAdapter(Page[ThreadItem]).validate_json(listing.json).data
        replaced_item = next(
            (item for item in stored_items if item.id == widget_id), None
        )
        assert replaced_item is not None
        assert isinstance(replaced_item, WidgetItem)
        stored_widget = replaced_item.widget
        assert isinstance(stored_widget, Card)
        first_child = stored_widget.children[0]
        assert isinstance(first_child, Text)
        assert first_child.value == "Replaced widget content"
1544
1545
async def test_retry_after_item():
    """Retrying after a user message regenerates the assistant reply.

    The first responder call answers "First response". Retrying after the
    user message re-invokes the responder with the same user message, and the
    stored reply is replaced by "Retried response". The thread still holds
    exactly two items afterwards.
    """
    responder_calls: list[Any] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)

        # The first call is the original turn; any later call is the retry.
        if len(responder_calls) == 1:
            reply_id, reply_text = "assistant_msg_1", "First response"
        else:
            reply_id, reply_text = "assistant_msg_2", "Retried response"

        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=reply_id,
                content=[AssistantMessageContent(text=reply_text)],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    items_adapter = TypeAdapter(Page[ThreadItem])

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello, world!")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        # Items are listed newest first: the reply, then the user message.
        before = items_adapter.validate_json(
            (
                await server.process_non_streaming(
                    ItemsListReq(params=ItemsListParams(thread_id=thread.id))
                )
            ).json
        )
        assert len(before.data) == 2
        reply_before, user_before = before.data
        assert reply_before.type == "assistant_message"
        assert reply_before.content[0].type == "output_text"
        assert reply_before.content[0].text == "First response"
        assert user_before.type == "user_message"
        assert user_before.content[0].text == "Hello, world!"

        # Retry after the user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_before.id
                )
            )
        )

        # The retry streamed stream options followed by the new reply.
        assert len(retry_events) == 2
        options_event, reply_event = retry_events
        assert options_event.type == "stream_options"
        assert reply_event.type == "thread.item.done"
        assert reply_event.item.type == "assistant_message"
        assert reply_event.item.content[0].type == "output_text"
        assert reply_event.item.content[0].text == "Retried response"

        # The responder saw the same user message both times.
        assert len(responder_calls) == 2
        first_input, retried_input = responder_calls
        assert first_input == retried_input
        assert first_input.type == "user_message"
        assert first_input.content[0].text == "Hello, world!"

        # Storage holds the user message plus only the regenerated reply.
        after = items_adapter.validate_json(
            (
                await server.process_non_streaming(
                    ItemsListReq(params=ItemsListParams(thread_id=thread.id))
                )
            ).json
        )
        assert len(after.data) == 2
        reply_after, user_after = after.data
        assert reply_after.type == "assistant_message"
        assert reply_after.content[0].type == "output_text"
        assert reply_after.content[0].text == "Retried response"
        assert user_after.type == "user_message"
        assert user_after.content[0].text == "Hello, world!"
1643
1644
async def test_retry_after_item_invalid_item():
    """Retrying after an item that is not a user message raises ``ValueError``."""

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        reply = AssistantMessageItem(
            id="assistant_msg_1",
            content=[AssistantMessageContent(text="Response")],
            created_at=datetime.now(),
            thread_id=thread.id,
        )
        yield ThreadItemDoneEvent(item=reply)

    create_request = ThreadsCreateReq(
        params=ThreadCreateParams(
            input=UserMessageInput(
                content=[UserMessageTextContent(text="Hello")],
                attachments=[],
                inference_options=InferenceOptions(),
            )
        )
    )

    with make_server(responder) as server:
        events = await server.process_streaming(create_request)
        thread = next(evt.thread for evt in events if evt.type == "thread.created")

        listing = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        page = TypeAdapter(Page[ThreadItem]).validate_json(listing.json)
        # The newest item is the assistant reply.
        newest_item_id = page.data[0].id

        # Retrying after an assistant message must be rejected.
        retry_request = ThreadsRetryAfterItemReq(
            params=ThreadRetryAfterItemParams(
                thread_id=thread.id, item_id=newest_item_id
            )
        )
        with pytest.raises(ValueError, match="is not a user message"):
            await server.process_streaming(retry_request)
1691
1692
async def test_retry_after_item_with_multiple_messages():
    """Retrying after the latest user message regenerates only its reply.

    The thread holds two user/assistant turns. Retrying after the second user
    message must re-invoke the responder with that message and replace only
    the second assistant reply. The first turn must be left untouched.
    """
    responder_calls: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata,
        input: UserMessageItem | None,
        context: Any,
    ) -> AsyncIterator[ThreadStreamEvent]:
        responder_calls.append(input)
        assert input is not None
        assert input.type == "user_message"
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id=f"assistant_msg_{len(responder_calls)}",
                content=[
                    AssistantMessageContent(
                        text=f"Response to: {input.content[0].text}"
                    )
                ],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="First message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            )
        )
        thread = next(
            event.thread for event in events if event.type == "thread.created"
        )

        await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second message")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    ),
                )
            )
        )

        # Items are listed newest first: assistant2, user2, assistant1, user1.
        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_before = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json)
        assert len(items_before.data) == 4

        # Guard the index-based lookup: data[1] must be the second user message.
        second_user_message = items_before.data[1]
        assert second_user_message.type == "user_message"
        assert second_user_message.content[0].text == "Second message"

        # Retry after the second user message.
        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=second_user_message.id
                )
            )
        )

        assert len(retry_events) == 2
        assert retry_events[0].type == "stream_options"
        assert retry_events[1].type == "thread.item.done"

        # The retry used the second user message.
        assert len(responder_calls) == 3  # Original 2 + 1 retry
        retried_input = responder_calls[2]
        assert retried_input is not None
        assert retried_input.type == "user_message"
        assert retried_input.content[0].text == "Second message"

        # Only the last assistant response was removed and regenerated.
        list_result_after = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items_after = TypeAdapter(Page[ThreadItem]).validate_json(
            list_result_after.json
        )
        assert len(items_after.data) == 4  # Still 4 items

        # The first user message and its reply are unchanged.
        assert items_after.data[3].type == "user_message"
        assert items_after.data[3].content[0].type == "input_text"
        assert items_after.data[3].content[0].text == "First message"
        assert items_after.data[2].type == "assistant_message"
        assert items_after.data[2].content[0].type == "output_text"
        assert items_after.data[2].content[0].text == "Response to: First message"

        # The second user message remains, but its reply is new.
        assert items_after.data[1].type == "user_message"
        assert items_after.data[1].content[0].text == "Second message"
        assert items_after.data[0].type == "assistant_message"
        assert items_after.data[0].content[0].type == "output_text"
        assert items_after.data[0].content[0].text == "Response to: Second message"
        assert items_after.data[0].id != items_before.data[0].id
1797
1798
async def test_threads_create_passes_tools_to_responder():
    """The ``tool_choice`` sent with ``threads.create`` reaches the responder.

    The responder only records its input, and the assertions run in the test
    body. A failure raised inside a responder can be turned into a stream
    error event rather than failing the test.
    """
    tool_choice = ToolChoice(id="web_search")
    received_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_create",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )

    # The responder ran to completion: its assistant message was streamed.
    # A bare "thread.item.done" check would also match the user message.
    assert any(
        e.type == "thread.item.done" and e.item.type == "assistant_message"
        for e in events
    )

    assert len(received_inputs) == 1
    received = received_inputs[0]
    assert received is not None
    assert received.inference_options.tool_choice is not None
    assert received.inference_options.tool_choice.id == tool_choice.id
1832
1833
async def test_retry_after_item_passes_tools_to_responder():
    """Retrying after a user message replays that message's ``tool_choice``.

    The thread is created with a tool choice, then retried after its user
    message. The responder's second invocation (the retry) must receive a user
    message carrying the same tool choice. Assertions run in the test body
    because a failure inside a responder can be turned into a stream error
    event rather than failing the test.
    """
    tool_choice = ToolChoice(id="web_search")
    received_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                # Distinct id per call so the retried reply never collides.
                id=f"assistant_msg_tools_retry_{len(received_inputs)}",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello with tools")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        list_result = await server.process_non_streaming(
            ItemsListReq(params=ItemsListParams(thread_id=thread.id))
        )
        items = TypeAdapter(Page[ThreadItem]).validate_json(list_result.json).data
        user_message = next(i for i in items if i.type == "user_message")

        retry_events = await server.process_streaming(
            ThreadsRetryAfterItemReq(
                params=ThreadRetryAfterItemParams(
                    thread_id=thread.id, item_id=user_message.id
                )
            )
        )

    # The retry ran the responder to completion.
    assert any(
        e.type == "thread.item.done" and e.item.type == "assistant_message"
        for e in retry_events
    )

    assert len(received_inputs) == 2  # Original turn + retry
    retried_input = received_inputs[1]
    assert retried_input is not None
    assert retried_input.inference_options.tool_choice is not None
    assert retried_input.inference_options.tool_choice.id == tool_choice.id
1836
1837
async def test_threads_add_message_passes_tools_to_responder():
    """The ``tool_choice`` sent with ``threads.add_user_message`` reaches the responder.

    The responder only records its inputs, and the assertions run in the test
    body. A failure raised inside a responder can be turned into a stream
    error event rather than failing the test.
    """
    tool_choice = ToolChoice(id="retrieval")
    received_inputs: list[UserMessageItem | None] = []

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        received_inputs.append(input)
        yield ThreadItemDoneEvent(
            item=AssistantMessageItem(
                id="assistant_msg_tools_add",
                content=[AssistantMessageContent(text="ok")],
                created_at=datetime.now(),
                thread_id=thread.id,
            ),
        )

    with make_server(responder) as server:
        # Create a thread first.
        events = await server.process_streaming(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Start")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    )
                )
            )
        )
        thread = next(e.thread for e in events if e.type == "thread.created")

        # Add a message with tools.
        events = await server.process_streaming(
            ThreadsAddUserMessageReq(
                params=ThreadAddUserMessageParams(
                    thread_id=thread.id,
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Second")],
                        attachments=[],
                        inference_options=InferenceOptions(tool_choice=tool_choice),
                    ),
                )
            )
        )

    # The responder ran to completion for the added message. A bare
    # "thread.item.done" check would also match the user message.
    assert any(
        e.type == "thread.item.done" and e.item.type == "assistant_message"
        for e in events
    )

    # One responder call for the create, one for the added message.
    assert len(received_inputs) == 2
    added_input = received_inputs[1]
    assert added_input is not None
    assert added_input.inference_options.tool_choice is not None
    assert added_input.inference_options.tool_choice.id == tool_choice.id
1887
1888
async def test_custom_id_generators_on_server():
    """The server uses the store's ``generate_thread_id``/``generate_item_id`` hooks.

    A store subclass draws both thread and item ids from one shared counter.
    Creating a thread must therefore yield thread id ``thr_custom_1`` and a
    user message id ``message_custom_2_thr_custom_1``.
    """

    class UniqueIdStore(SQLiteStore):
        def __init__(self, path):
            super().__init__(path)
            # Shared by thread and item id generation, so ids reflect call order.
            self.counter = 0

        def generate_thread_id(self, context) -> str:
            self.counter += 1
            return f"thr_custom_{self.counter}"

        def generate_item_id(
            self,
            item_type: str,
            thread: ThreadMetadata,
            context,
        ) -> str:
            self.counter += 1
            return f"{item_type}_custom_{self.counter}_{thread.id}"

    async def responder(
        thread: ThreadMetadata, input: UserMessageItem | None, context: Any
    ) -> AsyncIterator[ThreadStreamEvent]:
        # Emits no events; only server-generated ids are under test. The bare
        # `yield` after `return` makes this function an async generator.
        return
        yield

    with make_server(responder) as server:
        # Swap in our custom ID-generating store, reusing the same DB path.
        db_path = cast(SQLiteStore, server.store).db_path
        server.store = UniqueIdStore(db_path)

        # Create thread and verify thread id and user message id
        result = await server.process(
            ThreadsCreateReq(
                params=ThreadCreateParams(
                    input=UserMessageInput(
                        content=[UserMessageTextContent(text="Hello")],
                        attachments=[],
                        inference_options=InferenceOptions(),
                    )
                )
            ).model_dump_json(),
            DEFAULT_CONTEXT,
        )
        assert isinstance(result, StreamingResult)
        events = await decode_streaming_result(result)

        thread_created = next(e for e in events if e.type == "thread.created")
        assert thread_created.thread.id == "thr_custom_1"

        user_message = next(
            e.item
            for e in events
            if e.type == "thread.item.done" and e.item.type == "user_message"
        )
        assert user_message.id == "message_custom_2_thr_custom_1"