openai/chatkit-python

Public

Mirrored from https://github.com/openai/chatkit-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.6.5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/helpers/mock_store.py

334 lines · mode: code

import contextlib
import sqlite3
from pathlib import Path
from typing import Any

from pydantic import BaseModel

from chatkit.store import NotFoundError, Store
from chatkit.types import (
    Attachment,
    Page,
    ThreadItem,
    ThreadMetadata,
)
from tests._types import RequestContext

from .mock_widget import SampleWidget
17
# Bump every time the schema changes. The value is embedded in the default
# database file name, so a bump makes the store start from a fresh file.
SCHEMA_VERSION = 5
19
20
class ThreadData(BaseModel):
    """JSON envelope for a thread row.

    The store serializes this model into the ``data`` column of the
    ``threads`` table and parses it back when loading.
    """

    thread: ThreadMetadata
23
24
class ItemData(BaseModel):
    """JSON envelope for a thread-item row.

    The store serializes this model into the ``data`` column of the
    ``items`` table and parses it back when loading.
    """

    item: ThreadItem
27
28
class AttachmentData(BaseModel):
    """JSON envelope for an attachment row.

    The store serializes this model into the ``data`` column of the
    ``files`` table and parses it back when loading.
    """

    attachment: Attachment
31
32
class SampleWidgetData(BaseModel):
    """JSON envelope for a sample-widget row.

    The store serializes this model into the ``data`` column of the
    ``sample_widgets`` table and parses it back when loading.
    """

    widget: SampleWidget
35
36
class SQLiteStore(Store[RequestContext]):
    """SQLite-backed ``Store`` implementation used by the tests.

    Threads, items, attachments and sample widgets each live in their own
    table. A row keeps the owning ``user_id`` (plus ``created_at`` for
    threads and items) in plain columns for filtering and ordering, and the
    full object as a JSON document in the ``data`` column.

    Every operation opens its own connection, runs inside one transaction,
    and closes the connection when it finishes.
    """

    # DDL for every table. ``IF NOT EXISTS`` makes it safe to run the whole
    # set each time a store is constructed.
    _TABLE_DDL = (
        """CREATE TABLE IF NOT EXISTS items (
            id TEXT PRIMARY KEY,
            thread_id TEXT NOT NULL,
            user_id TEXT NOT NULL,
            created_at TEXT NOT NULL,
            data TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS threads (
            id TEXT PRIMARY KEY,
            user_id TEXT NOT NULL,
            created_at TEXT NOT NULL,
            data TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS files (
            id TEXT PRIMARY KEY,
            user_id TEXT NOT NULL,
            data TEXT NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS sample_widgets (
            id TEXT PRIMARY KEY,
            user_id TEXT NOT NULL,
            data TEXT NOT NULL
        )""",
    )

    def __init__(self, db_path: str | Path | None = None):
        """Open (creating if needed) the database and ensure all tables exist.

        Args:
            db_path: Database location handed to ``sqlite3.connect`` with
                ``uri=True``. When omitted (or empty), a file named
                ``chatkit_v{SCHEMA_VERSION}.db`` next to this module is used.
        """
        self.db_path = (
            db_path or Path(__file__).parent / f"chatkit_v{SCHEMA_VERSION}.db"
        )
        self._create_tables()

    def _create_connection(self):
        """Return a new connection to ``self.db_path``; the caller must close it."""
        return sqlite3.connect(self.db_path, uri=True)

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection wrapped in a transaction, then close it.

        ``with connection:`` on its own only commits (or rolls back on an
        exception); it never closes the connection. This helper adds the
        missing ``close()`` so connections are not left to the garbage
        collector.
        """
        conn = self._create_connection()
        try:
            with conn:  # commit on success, roll back if the body raises
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _pagination_sql(order: str) -> tuple[str, str]:
        """Map ``order`` to ``(comparison operator, ORDER BY direction)``.

        Only fixed SQL fragments are ever returned, so the caller-supplied
        ``order`` string is never spliced into a query. Anything other than
        ``"asc"`` is treated as descending.
        """
        if order == "asc":
            return ">", "ASC"
        return "<", "DESC"

    def _create_tables(self):
        """Create the tables if they do not exist yet."""
        with self._connect() as conn:
            for ddl in self._TABLE_DDL:
                conn.execute(ddl)

    async def load_thread(
        self, thread_id: str, context: RequestContext
    ) -> ThreadMetadata:
        """Return the thread ``thread_id`` owned by ``context.user_id``.

        Raises:
            NotFoundError: No such thread exists for this user.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT data FROM threads WHERE id = ? AND user_id = ?",
                (thread_id, context.user_id),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"Thread {thread_id} not found")
        return ThreadData.model_validate_json(row[0]).thread

    async def save_thread(
        self, thread: ThreadMetadata, context: RequestContext
    ) -> None:
        """Insert ``thread`` for ``context.user_id``, replacing their old copy.

        The delete is scoped to the user, so if the id already belongs to a
        different user the insert fails on the primary key instead of
        overwriting that user's thread.
        """
        payload = ThreadData(thread=thread).model_dump_json()
        with self._connect() as conn:
            conn.execute(
                "DELETE FROM threads WHERE id = ? AND user_id = ?",
                (thread.id, context.user_id),
            )
            conn.execute(
                "INSERT INTO threads (id, user_id, created_at, data) VALUES (?, ?, ?, ?)",
                (
                    thread.id,
                    context.user_id,
                    thread.created_at.isoformat(),
                    payload,
                ),
            )

    async def load_thread_items(
        self,
        thread_id: str,
        after: str | None,
        limit: int,
        order: str,
        context: RequestContext,
    ) -> Page[ThreadItem]:
        """Return one page of items of ``thread_id`` ordered by ``created_at``.

        Args:
            thread_id: Thread whose items are listed.
            after: Id of the item to continue after, or ``None`` to start
                from the beginning.
            limit: Maximum number of items in the page.
            order: ``"asc"`` for oldest first; any other value means newest
                first.
            context: Supplies the ``user_id`` that rows are filtered by.

        Raises:
            NotFoundError: ``after`` is not an item owned by this user.
        """
        comparison, direction = self._pagination_sql(order)
        with self._connect() as conn:
            created_after: str | None = None
            if after:
                cursor_row = conn.execute(
                    "SELECT created_at FROM items WHERE id = ? AND user_id = ?",
                    (after, context.user_id),
                ).fetchone()
                if cursor_row is None:
                    raise NotFoundError(f"Item {after} not found")
                created_after = cursor_row[0]

            query = "SELECT id, data FROM items WHERE thread_id = ? AND user_id = ?"
            params: list[Any] = [thread_id, context.user_id]
            if created_after:
                # NOTE: the strict comparison skips rows that share the
                # cursor row's exact created_at value.
                query += f" AND created_at {comparison} ?"
                params.append(created_after)
            # `comparison` and `direction` are fixed literals from
            # _pagination_sql, never caller input.
            query += f" ORDER BY created_at {direction} LIMIT ?"
            # Fetch one extra row to learn whether another page exists.
            params.append(limit + 1)
            rows = conn.execute(query, params).fetchall()

        items = [ItemData.model_validate_json(row[1]).item for row in rows]
        has_more = len(items) > limit
        next_after: str | None = None
        if has_more:
            items = items[:limit]
            # `items` is empty here only when limit == 0.
            next_after = items[-1].id if items else None
        return Page[ThreadItem](data=items, has_more=has_more, after=next_after)

    async def save_attachment(
        self, attachment: Attachment, context: RequestContext
    ) -> None:
        """Insert ``attachment``, replacing any existing row with the same id."""
        payload = AttachmentData(attachment=attachment).model_dump_json()
        with self._connect() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO files (id, user_id, data) VALUES (?, ?, ?)",
                (attachment.id, context.user_id, payload),
            )

    async def load_attachment(
        self, attachment_id: str, context: RequestContext
    ) -> Attachment:
        """Return the attachment ``attachment_id``.

        The lookup is not scoped to ``context.user_id``.

        Raises:
            NotFoundError: No attachment with this id exists.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT data FROM files WHERE id = ?",  # TODO: consider checking user_id
                (attachment_id,),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"File {attachment_id} not found")
        return AttachmentData.model_validate_json(row[0]).attachment

    async def load_threads(
        self,
        limit: int,
        after: str | None,
        order: str,
        context: RequestContext,
    ) -> Page[ThreadMetadata]:
        """Return one page of the user's threads ordered by ``created_at``.

        Args:
            limit: Maximum number of threads in the page.
            after: Id of the thread to continue after, or ``None`` to start
                from the beginning.
            order: ``"asc"`` for oldest first; any other value means newest
                first.
            context: Supplies the ``user_id`` that rows are filtered by.

        Raises:
            ValueError: ``after`` is not a thread owned by this user. This
                differs from ``load_thread_items``, which raises
                ``NotFoundError``; the type is kept so existing callers
                are unaffected.
        """
        comparison, direction = self._pagination_sql(order)
        with self._connect() as conn:
            created_after: str | None = None
            if after:
                cursor_row = conn.execute(
                    "SELECT created_at FROM threads WHERE id = ? AND user_id = ?",
                    (after, context.user_id),
                ).fetchone()
                if cursor_row is None:
                    raise ValueError(f"Thread {after} not found")
                created_after = cursor_row[0]

            query = "SELECT data FROM threads WHERE user_id = ?"
            params: list[Any] = [context.user_id]
            if created_after:
                # NOTE: the strict comparison skips rows that share the
                # cursor row's exact created_at value.
                query += f" AND created_at {comparison} ?"
                params.append(created_after)
            # `comparison` and `direction` are fixed literals from
            # _pagination_sql, never caller input.
            query += f" ORDER BY created_at {direction} LIMIT ?"
            # Fetch one extra row to learn whether another page exists.
            params.append(limit + 1)
            rows = conn.execute(query, params).fetchall()

        threads = [ThreadData.model_validate_json(row[0]).thread for row in rows]
        has_more = len(threads) > limit
        next_after: str | None = None
        if has_more:
            threads = threads[:limit]
            # `threads` is empty here only when limit == 0.
            next_after = threads[-1].id if threads else None
        return Page[ThreadMetadata](
            data=threads, has_more=has_more, after=next_after
        )

    async def add_thread_item(
        self, thread_id: str, item: ThreadItem, context: RequestContext
    ) -> None:
        """Insert ``item`` as a new row belonging to ``thread_id``."""
        payload = ItemData(item=item).model_dump_json()
        with self._connect() as conn:
            conn.execute(
                "INSERT INTO items (id, thread_id, user_id, created_at, data) VALUES (?, ?, ?, ?, ?)",
                (
                    item.id,
                    thread_id,
                    context.user_id,
                    item.created_at.isoformat(),
                    payload,
                ),
            )

    async def save_item(
        self, thread_id: str, item: ThreadItem, context: RequestContext
    ) -> None:
        """Overwrite the stored JSON of an existing item.

        Only the ``data`` column is updated. If no row matches the item id,
        thread and user, nothing is changed.
        """
        payload = ItemData(item=item).model_dump_json()
        with self._connect() as conn:
            conn.execute(
                "UPDATE items SET data = ? WHERE id = ? AND thread_id = ? AND user_id = ?",
                (payload, item.id, thread_id, context.user_id),
            )

    async def load_item(
        self, thread_id: str, item_id: str, context: RequestContext
    ) -> ThreadItem:
        """Return item ``item_id`` of ``thread_id`` owned by the user.

        Raises:
            NotFoundError: No matching item exists.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT data FROM items WHERE id = ? AND thread_id = ? AND user_id = ?",
                (item_id, thread_id, context.user_id),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"Item {item_id} not found in thread {thread_id}")
        return ItemData.model_validate_json(row[0]).item

    async def delete_thread(self, thread_id: str, context: RequestContext) -> None:
        """Delete the user's thread ``thread_id`` together with its items."""
        owner_key = (thread_id, context.user_id)
        with self._connect() as conn:
            conn.execute(
                "DELETE FROM threads WHERE id = ? AND user_id = ?",
                owner_key,
            )
            conn.execute(
                "DELETE FROM items WHERE thread_id = ? AND user_id = ?",
                owner_key,
            )

    async def delete_attachment(
        self, attachment_id: str, context: RequestContext
    ) -> None:
        """Delete the user's attachment ``attachment_id`` if it exists."""
        with self._connect() as conn:
            conn.execute(
                "DELETE FROM files WHERE id = ? AND user_id = ?",
                (attachment_id, context.user_id),
            )

    async def delete_thread_item(
        self, thread_id: str, item_id: str, context: RequestContext
    ) -> None:
        """Delete item ``item_id`` from the user's thread if it exists."""
        with self._connect() as conn:
            conn.execute(
                "DELETE FROM items WHERE id = ? AND thread_id = ? AND user_id = ?",
                (item_id, thread_id, context.user_id),
            )

    async def save_sample_widget(
        self,
        widget: SampleWidget,
        context: RequestContext,
    ) -> None:
        """Insert ``widget``, replacing any existing row with the same id."""
        payload = SampleWidgetData(widget=widget).model_dump_json()
        with self._connect() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO sample_widgets (id, user_id, data) VALUES (?, ?, ?)",
                (widget.id, context.user_id, payload),
            )

    async def load_sample_widget(
        self, widget_id: str, context: RequestContext
    ) -> SampleWidget:
        """Return the user's sample widget ``widget_id``.

        Raises:
            NotFoundError: No matching widget exists for this user.
        """
        with self._connect() as conn:
            row = conn.execute(
                "SELECT data FROM sample_widgets WHERE id = ? AND user_id = ?",
                (widget_id, context.user_id),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"SampleWidget {widget_id} not found.")
        return SampleWidgetData.model_validate_json(row[0]).widget
335