openai/chatkit-python
Public · mirrored from https://github.com/openai/chatkit-python · Available
tests/helpers/mock_store.py
334 lines · mode: code
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from pathlib import Path
from typing import Any

from pydantic import BaseModel

from chatkit.store import NotFoundError, Store
from chatkit.types import (
    Attachment,
    Page,
    ThreadItem,
    ThreadMetadata,
)
from tests._types import RequestContext

from .mock_widget import SampleWidget
| 17 | |
# Version of the table layout created by SQLiteStore. It is embedded in the
# default database filename, so bump it whenever the schema changes; a new
# version then gets a new database file instead of reusing an outdated one.
SCHEMA_VERSION: int = 5
| 19 | |
| 20 | |
class ThreadData(BaseModel):
    """JSON envelope persisted in the ``threads.data`` column.

    The field name ``thread`` is the top-level key of the stored JSON, so
    renaming it changes the on-disk format.
    """

    thread: ThreadMetadata
| 23 | |
| 24 | |
class ItemData(BaseModel):
    """JSON envelope persisted in the ``items.data`` column.

    The field name ``item`` is the top-level key of the stored JSON, so
    renaming it changes the on-disk format.
    """

    item: ThreadItem
| 27 | |
| 28 | |
class AttachmentData(BaseModel):
    """JSON envelope persisted in the ``files.data`` column.

    The field name ``attachment`` is the top-level key of the stored JSON, so
    renaming it changes the on-disk format.
    """

    attachment: Attachment
| 31 | |
| 32 | |
class SampleWidgetData(BaseModel):
    """JSON envelope persisted in the ``sample_widgets.data`` column.

    The field name ``widget`` is the top-level key of the stored JSON, so
    renaming it changes the on-disk format.
    """

    widget: SampleWidget
| 35 | |
| 36 | |
class SQLiteStore(Store[RequestContext]):
    """SQLite-backed ``Store`` implementation used by the test suite.

    Threads, thread items, attachments and sample widgets each live in their
    own table. The object itself is serialized to JSON in the ``data`` column
    through the envelope models defined in this module; the remaining columns
    exist for lookup, ownership filtering and ordering.

    Every method opens a fresh connection via ``_transaction``. The work runs
    in a single transaction that is committed on success or rolled back on
    error, and the connection is always closed afterwards.
    """

    def __init__(self, db_path: str | None = None):
        # Without an explicit path, use a file next to this module whose name
        # embeds SCHEMA_VERSION, so bumping the version yields a fresh database
        # instead of reusing tables created with an older layout.
        self.db_path = (
            db_path or Path(__file__).parent / f"chatkit_v{SCHEMA_VERSION}.db"
        )
        self._create_tables()

    def _create_connection(self):
        """Open a new connection to ``self.db_path``.

        ``uri=True`` allows ``db_path`` to be given as a SQLite ``file:`` URI.
        NOTE: every call opens a separate connection, so a plain ``":memory:"``
        path would not share data between calls.
        """
        return sqlite3.connect(self.db_path, uri=True)

    @contextmanager
    def _transaction(self) -> Iterator[sqlite3.Connection]:
        """Yield a new connection wrapped in a single transaction.

        The transaction is committed when the ``with`` body finishes normally
        and rolled back if it raises. Using a sqlite3 connection as a context
        manager does NOT close it, so the connection is closed explicitly here
        rather than relying on garbage collection (Python 3.13+ emits a
        ResourceWarning for connections that are never closed).
        """
        conn = self._create_connection()
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    def _create_tables(self) -> None:
        """Create the items, threads, files and sample_widgets tables if missing."""
        with self._transaction() as conn:
            conn.execute(
                """CREATE TABLE IF NOT EXISTS items (
                    id TEXT PRIMARY KEY,
                    thread_id TEXT NOT NULL,
                    user_id TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    data TEXT NOT NULL
                )"""
            )

            conn.execute(
                """CREATE TABLE IF NOT EXISTS threads (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    created_at TEXT NOT NULL,
                    data TEXT NOT NULL
                )"""
            )

            conn.execute(
                """CREATE TABLE IF NOT EXISTS files (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    data TEXT NOT NULL
                )"""
            )

            conn.execute(
                """CREATE TABLE IF NOT EXISTS sample_widgets (
                    id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    data TEXT NOT NULL
                )"""
            )

    @staticmethod
    def _page_clauses(
        created_after: str | None, order: str, limit: int
    ) -> tuple[str, list[Any]]:
        """Build the keyset-pagination SQL shared by the list queries.

        Returns the SQL to append after an existing ``WHERE`` clause together
        with its bound parameters. One row beyond ``limit`` is requested so the
        caller can tell whether another page exists.

        Raises:
            ValueError: If ``order`` is not ``"asc"`` or ``"desc"``
                (case-insensitive).
        """
        direction = order.lower()
        # ORDER BY directions cannot be bound as ``?`` parameters, so the value
        # is whitelisted before it is interpolated into the SQL text.
        if direction not in ("asc", "desc"):
            raise ValueError(f"Invalid sort order: {order!r}")

        sql = ""
        params: list[Any] = []
        if created_after is not None:
            # NOTE(review): created_at is compared as ISO-8601 text and with a
            # strict inequality, so rows sharing the cursor's exact timestamp
            # are skipped, and ordering assumes a consistent timestamp format.
            sql += " AND created_at > ?" if direction == "asc" else " AND created_at < ?"
            params.append(created_after)
        sql += f" ORDER BY created_at {direction} LIMIT ?"
        params.append(limit + 1)
        return sql, params

    @staticmethod
    def _split_page(
        rows: list[Any], limit: int, after: str | None
    ) -> tuple[list[Any], bool, str | None]:
        """Trim the look-ahead row requested by ``_page_clauses``.

        Returns ``(page, has_more, next_after)``. When more rows exist,
        ``next_after`` is the id of the last object on the page. If the page is
        empty (e.g. ``limit == 0``), the incoming ``after`` cursor is passed
        through unchanged instead of indexing into an empty list.
        """
        if len(rows) <= limit:
            return rows, False, None
        page = rows[:limit]
        return page, True, (page[-1].id if page else after)

    async def load_thread(
        self, thread_id: str, context: RequestContext
    ) -> ThreadMetadata:
        """Return thread ``thread_id`` owned by ``context.user_id``.

        Raises:
            NotFoundError: If the user has no thread with that id.
        """
        with self._transaction() as conn:
            row = conn.execute(
                "SELECT data FROM threads WHERE id = ? AND user_id = ?",
                (thread_id, context.user_id),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"Thread {thread_id} not found")
        return ThreadData.model_validate_json(row[0]).thread

    async def save_thread(
        self, thread: ThreadMetadata, context: RequestContext
    ) -> None:
        """Insert or overwrite ``thread`` for ``context.user_id``.

        Delete-then-insert runs in one transaction, acting as an upsert scoped
        to this user. If another user already owns the same thread id, the
        INSERT violates the primary key and raises ``sqlite3.IntegrityError``.
        """
        thread_data = ThreadData(thread=thread)
        with self._transaction() as conn:
            conn.execute(
                "DELETE FROM threads WHERE id = ? AND user_id = ?",
                (thread.id, context.user_id),
            )
            conn.execute(
                "INSERT INTO threads (id, user_id, created_at, data) VALUES (?, ?, ?, ?)",
                (
                    thread.id,
                    context.user_id,
                    thread.created_at.isoformat(),
                    thread_data.model_dump_json(),
                ),
            )

    async def load_thread_items(
        self,
        thread_id: str,
        after: str | None,
        limit: int,
        order: str,
        context: RequestContext,
    ) -> Page[ThreadItem]:
        """Return one page of a thread's items sorted by ``created_at``.

        Args:
            thread_id: Thread whose items are listed.
            after: Id of the item to resume after, or ``None``/empty for the
                first page.
            limit: Maximum number of items to return.
            order: ``"asc"`` or ``"desc"`` (case-insensitive).
            context: Request context supplying ``user_id``.

        Raises:
            NotFoundError: If ``after`` names an item this user does not own.
            ValueError: If ``order`` is invalid.
        """
        with self._transaction() as conn:
            created_after: str | None = None
            if after:
                row = conn.execute(
                    "SELECT created_at FROM items WHERE id = ? AND user_id = ?",
                    (after, context.user_id),
                ).fetchone()
                if row is None:
                    raise NotFoundError(f"Item {after} not found")
                created_after = row[0]

            page_sql, page_params = self._page_clauses(created_after, order, limit)
            rows = conn.execute(
                "SELECT data FROM items WHERE thread_id = ? AND user_id = ?"
                + page_sql,
                [thread_id, context.user_id, *page_params],
            ).fetchall()

        items = [ItemData.model_validate_json(data).item for (data,) in rows]
        page, has_more, next_after = self._split_page(items, limit, after)
        return Page[ThreadItem](data=page, has_more=has_more, after=next_after)

    async def save_attachment(
        self, attachment: Attachment, context: RequestContext
    ) -> None:
        """Insert or replace ``attachment`` keyed by its id.

        NOTE(review): ``INSERT OR REPLACE`` on the ``id`` primary key also
        overwrites a row stored under a different ``user_id``.
        """
        with self._transaction() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO files (id, user_id, data) VALUES (?, ?, ?)",
                (
                    attachment.id,
                    context.user_id,
                    AttachmentData(attachment=attachment).model_dump_json(),
                ),
            )

    async def load_attachment(
        self, attachment_id: str, context: RequestContext
    ) -> Attachment:
        """Return the attachment with ``attachment_id`` regardless of owner.

        Raises:
            NotFoundError: If no such attachment exists.
        """
        with self._transaction() as conn:
            row = conn.execute(
                "SELECT data FROM files WHERE id = ?",  # TODO: consider checking user_id
                (attachment_id,),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"File {attachment_id} not found")
        return AttachmentData.model_validate_json(row[0]).attachment

    async def load_threads(
        self,
        limit: int,
        after: str | None,
        order: str,
        context: RequestContext,
    ) -> Page[ThreadMetadata]:
        """Return one page of ``context.user_id``'s threads sorted by ``created_at``.

        Args:
            limit: Maximum number of threads to return.
            after: Id of the thread to resume after, or ``None``/empty for the
                first page.
            order: ``"asc"`` or ``"desc"`` (case-insensitive).
            context: Request context supplying ``user_id``.

        Raises:
            ValueError: If ``after`` names an unknown thread or ``order`` is
                invalid.
        """
        with self._transaction() as conn:
            created_after: str | None = None
            if after:
                row = conn.execute(
                    "SELECT created_at FROM threads WHERE id = ? AND user_id = ?",
                    (after, context.user_id),
                ).fetchone()
                if row is None:
                    # NOTE(review): load_thread_items raises NotFoundError for
                    # the equivalent case; ValueError is kept here because
                    # existing callers may catch it.
                    raise ValueError(f"Thread {after} not found")
                created_after = row[0]

            page_sql, page_params = self._page_clauses(created_after, order, limit)
            rows = conn.execute(
                "SELECT data FROM threads WHERE user_id = ?" + page_sql,
                [context.user_id, *page_params],
            ).fetchall()

        threads = [ThreadData.model_validate_json(data).thread for (data,) in rows]
        page, has_more, next_after = self._split_page(threads, limit, after)
        return Page[ThreadMetadata](data=page, has_more=has_more, after=next_after)

    async def add_thread_item(
        self, thread_id: str, item: ThreadItem, context: RequestContext
    ) -> None:
        """Insert a new item into ``thread_id``.

        An existing item with the same id violates the primary key and raises
        ``sqlite3.IntegrityError``.
        """
        with self._transaction() as conn:
            conn.execute(
                "INSERT INTO items (id, thread_id, user_id, created_at, data) VALUES (?, ?, ?, ?, ?)",
                (
                    item.id,
                    thread_id,
                    context.user_id,
                    item.created_at.isoformat(),
                    ItemData(item=item).model_dump_json(),
                ),
            )

    async def save_item(
        self, thread_id: str, item: ThreadItem, context: RequestContext
    ) -> None:
        """Overwrite the stored JSON of an existing item.

        NOTE(review): this is an UPDATE only. It silently does nothing when no
        row matches ``(id, thread_id, user_id)``, and it does not refresh the
        ``created_at`` column used for ordering.
        """
        with self._transaction() as conn:
            conn.execute(
                "UPDATE items SET data = ? WHERE id = ? AND thread_id = ? AND user_id = ?",
                (
                    ItemData(item=item).model_dump_json(),
                    item.id,
                    thread_id,
                    context.user_id,
                ),
            )

    async def load_item(
        self, thread_id: str, item_id: str, context: RequestContext
    ) -> ThreadItem:
        """Return item ``item_id`` from ``thread_id`` owned by ``context.user_id``.

        Raises:
            NotFoundError: If no matching item exists.
        """
        with self._transaction() as conn:
            row = conn.execute(
                "SELECT data FROM items WHERE id = ? AND thread_id = ? AND user_id = ?",
                (item_id, thread_id, context.user_id),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"Item {item_id} not found in thread {thread_id}")
        return ItemData.model_validate_json(row[0]).item

    async def delete_thread(self, thread_id: str, context: RequestContext) -> None:
        """Delete the user's thread and all of its items in one transaction.

        Rows in the ``files`` table are not touched.
        """
        with self._transaction() as conn:
            conn.execute(
                "DELETE FROM threads WHERE id = ? AND user_id = ?",
                (thread_id, context.user_id),
            )
            conn.execute(
                "DELETE FROM items WHERE thread_id = ? AND user_id = ?",
                (thread_id, context.user_id),
            )

    async def delete_attachment(
        self, attachment_id: str, context: RequestContext
    ) -> None:
        """Delete the user's attachment; a missing row is silently ignored."""
        with self._transaction() as conn:
            conn.execute(
                "DELETE FROM files WHERE id = ? AND user_id = ?",
                (attachment_id, context.user_id),
            )

    async def delete_thread_item(
        self, thread_id: str, item_id: str, context: RequestContext
    ) -> None:
        """Delete one item from the user's thread; a missing row is ignored."""
        with self._transaction() as conn:
            conn.execute(
                "DELETE FROM items WHERE id = ? AND thread_id = ? AND user_id = ?",
                (item_id, thread_id, context.user_id),
            )

    async def save_sample_widget(
        self,
        widget: SampleWidget,
        context: RequestContext,
    ) -> None:
        """Insert or replace ``widget`` keyed by its id."""
        data = SampleWidgetData(widget=widget).model_dump_json()
        with self._transaction() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO sample_widgets (id, user_id, data) VALUES (?, ?, ?)",
                (
                    widget.id,
                    context.user_id,
                    data,
                ),
            )

    async def load_sample_widget(
        self, widget_id: str, context: RequestContext
    ) -> SampleWidget:
        """Return widget ``widget_id`` owned by ``context.user_id``.

        Raises:
            NotFoundError: If no matching widget exists.
        """
        with self._transaction() as conn:
            row = conn.execute(
                "SELECT data FROM sample_widgets WHERE id = ? AND user_id = ?",
                (widget_id, context.user_id),
            ).fetchone()
        if row is None:
            raise NotFoundError(f"SampleWidget {widget_id} not found.")
        return SampleWidgetData.model_validate_json(row[0]).widget
| 335 | |