openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
examples/responses/websocket.py
436 lines · mode: code
| 1 | from __future__ import annotations |
| 2 | |
| 3 | import json |
| 4 | import argparse |
| 5 | from typing import TYPE_CHECKING, Dict, Union, Literal, Optional, TypedDict, NamedTuple, cast |
| 6 | |
| 7 | from openai import OpenAI |
| 8 | from openai.types.responses import ( |
| 9 | FunctionToolParam, |
| 10 | ToolChoiceOptions, |
| 11 | ResponseInputParam, |
| 12 | ResponseFailedEvent, |
| 13 | ResponseCompletedEvent, |
| 14 | ResponseInputItemParam, |
| 15 | ResponseIncompleteEvent, |
| 16 | ToolChoiceFunctionParam, |
| 17 | ) |
| 18 | |
| 19 | if TYPE_CHECKING: |
| 20 | from openai.resources.responses.responses import ResponsesConnection |
| 21 | |
# Names of the three function tools this demo exposes; they must match the
# "name" fields in TOOLS and the branches in call_tool.
ToolName = Literal[
    "get_sku_inventory",
    "get_supplier_eta",
    "get_quality_alerts",
]

# A tool_choice value: either an option string (run_turn uses "none") or a
# {"type": "function", "name": ...} dict that forces one specific function.
ToolChoice = Union[
    ToolChoiceOptions,
    ToolChoiceFunctionParam,
]
| 24 | |
| 25 | |
class DemoTurn(TypedDict):
    """One scripted user turn played by ``main``."""

    # Tool the turn forces via tool_choice (passed to run_turn as forced_tool_name).
    tool_name: ToolName
    # User message sent as the turn's first input.
    prompt: str
| 29 | |
| 30 | |
class SKUArguments(TypedDict):
    """Validated tool-call arguments; every demo tool takes only a string ``sku``."""

    sku: str
| 33 | |
| 34 | |
class SKUInventoryOutput(TypedDict):
    """Canned ``get_sku_inventory`` result built by ``call_tool``.

    Only ``sku`` varies with the input; the other values are fixed demo data.
    """

    sku: str
    warehouse: str
    on_hand_units: int
    reserved_units: int
    reorder_point: int
    safety_stock: int
| 42 | |
| 43 | |
class SupplierShipment(TypedDict):
    """One entry of the ``supplier_shipments`` list in ``SupplierETAOutput``."""

    shipment_id: str
    # Date as a string; the demo data uses ISO dates such as "2026-02-24".
    eta_date: str
    quantity: int
    # Free-form risk label; the demo data uses "low" and "medium".
    risk: str
| 49 | |
| 50 | |
class SupplierETAOutput(TypedDict):
    """Canned ``get_supplier_eta`` result built by ``call_tool``."""

    sku: str
    supplier_shipments: list[SupplierShipment]
| 54 | |
| 55 | |
class QualityAlert(TypedDict):
    """One entry of the ``alerts`` list in ``QualityAlertsOutput``."""

    alert_id: str
    # Free-form status; the demo data uses "open", "in_progress" and "resolved".
    status: str
    # Free-form severity; the demo data uses "high", "medium" and "low".
    severity: str
    summary: str
| 61 | |
| 62 | |
class QualityAlertsOutput(TypedDict):
    """Canned ``get_quality_alerts`` result built by ``call_tool``."""

    sku: str
    alerts: list[QualityAlert]
| 66 | |
| 67 | |
class FunctionCallOutputItem(TypedDict):
    """Input item that hands one tool result back to the model.

    ``run_turn`` builds one per executed function call and casts it to
    ``ResponseInputItemParam`` for the follow-up request.
    """

    type: Literal["function_call_output"]
    # Echoes the call_id of the function call being answered.
    call_id: str
    # JSON-encoded tool result (run_turn stores json.dumps(...) output here).
    output: str
| 72 | |
| 73 | |
class FunctionCallRequest(NamedTuple):
    """A function call emitted by the model, captured in ``run_response``.

    Built from ``response.output_item.done`` events whose item type is
    ``function_call``. Field order defines positional construction.
    """

    # Tool name as sent by the model; validated later by parse_tool_name.
    name: str
    # Raw JSON arguments string; decoded later by parse_sku_arguments.
    arguments_json: str
    # Identifier echoed back in the matching function_call_output item.
    call_id: str
| 78 | |
| 79 | |
class RunResponseResult(NamedTuple):
    """Outcome of one ``response.create`` request streamed by ``run_response``."""

    # Concatenation of all response.output_text.delta chunks (may be empty).
    text: str
    # Id of the finished response; used as previous_response_id for chaining.
    response_id: str
    # Function calls the response asked for, in stream order.
    function_calls: list[FunctionCallRequest]
| 84 | |
| 85 | |
class RunTurnResult(NamedTuple):
    """Outcome of one user turn, which may span several responses."""

    # Text from every response in the turn, joined and stripped.
    assistant_text: str
    # Id of the turn's last response; main chains the next turn from it.
    response_id: str
| 89 | |
| 90 | |
# Any payload call_tool can return: one member per demo tool.
ToolOutput = Union[
    SKUInventoryOutput,
    SupplierETAOutput,
    QualityAlertsOutput,
]
| 92 | |
def _sku_function_tool(name: ToolName, description: str) -> FunctionToolParam:
    """Build a strict function-tool definition whose only parameter is a string ``sku``.

    All three demo tools share this schema. A fresh dict is built on every call,
    so the tool definitions never share mutable nested objects.
    """
    return {
        "type": "function",
        "name": name,
        "description": description,
        "strict": True,
        "parameters": {
            "type": "object",
            "properties": {
                "sku": {
                    "type": "string",
                    "description": "Stock-keeping unit identifier, such as sku-froge-lily-pad-deluxe.",
                }
            },
            "required": ["sku"],
            # Strict mode: the model may not send any argument besides `sku`.
            "additionalProperties": False,
        },
    }


# Tool definitions sent with every request; names are typed as ToolName so they
# stay in sync with the literal and with call_tool's branches.
TOOLS: list[FunctionToolParam] = [
    _sku_function_tool("get_sku_inventory", "Return froge pond inventory details for a SKU."),
    _sku_function_tool("get_supplier_eta", "Return tadpole supplier restock ETA data for a SKU."),
    _sku_function_tool("get_quality_alerts", "Return recent froge quality alerts for a SKU."),
]
| 146 | |
# Scripted conversation played by main. Each turn forces one tool; turns 2 and 3
# refer to "the same SKU", relying on previous_response_id chaining for context.
DEMO_TURNS: list[DemoTurn] = [
    DemoTurn(
        tool_name="get_sku_inventory",
        prompt="Use get_sku_inventory for sku='sku-froge-lily-pad-deluxe' and summarize current pond stock health in one sentence.",
    ),
    DemoTurn(
        tool_name="get_supplier_eta",
        prompt="Now use get_supplier_eta for the same SKU and summarize restock ETA and tadpole shipment risk.",
    ),
    DemoTurn(
        tool_name="get_quality_alerts",
        prompt="Finally use get_quality_alerts for the same SKU and summarize unresolved froge quality concerns in one short paragraph.",
    ),
]

# Value for the optional `OpenAI-Beta` header, sent only with --use-beta-header.
BETA_HEADER_VALUE = "responses_websockets=2026-02-06"
| 163 | |
| 164 | |
def parse_args() -> argparse.Namespace:
    """Build the demo's command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``model`` (str, default "gpt-5.2") and the boolean
        switches ``use_beta_header``, ``show_events`` and ``show_tool_io``.
    """
    cli = argparse.ArgumentParser(
        description="Run a 3-turn Responses WebSocket demo with function calling and chained previous_response_id."
    )
    cli.add_argument("--model", default="gpt-5.2", help="Model used in the `response.create` payload.")

    # Every remaining option is a plain on/off switch.
    switches = (
        ("--use-beta-header", f"Include `OpenAI-Beta: {BETA_HEADER_VALUE}` for beta websocket behavior."),
        ("--show-events", "Print non-text event types while streaming."),
        ("--show-tool-io", "Print each tool call and tool output payload."),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()
| 186 | |
| 187 | |
def parse_tool_name(name: str) -> ToolName:
    """Validate a tool name sent by the model and narrow it to ``ToolName``.

    The accepted names are read from the ``ToolName`` literal itself instead of
    a hand-copied set, so the check cannot drift from the alias.

    Raises:
        ValueError: If ``name`` is not one of the ``ToolName`` literals.
    """
    from typing import get_args

    if name not in get_args(ToolName):
        raise ValueError(f"Unsupported tool requested: {name}")
    return cast(ToolName, name)
| 192 | |
| 193 | |
def parse_sku_arguments(raw_arguments: str) -> SKUArguments:
    """Decode a function call's JSON arguments and extract the ``sku`` string.

    Any keys other than ``sku`` are dropped from the returned dict.

    Raises:
        ValueError: If the text is not valid JSON (``json.JSONDecodeError``),
            is not a JSON object, or lacks a string ``sku`` value.
    """
    decoded = json.loads(raw_arguments)
    if isinstance(decoded, dict):
        fields = cast(Dict[str, object], decoded)
        candidate = fields.get("sku")
        if isinstance(candidate, str):
            return {"sku": candidate}
        raise ValueError(f"Tool arguments must include a string `sku`: {raw_arguments}")
    raise ValueError(f"Tool arguments must be a JSON object: {raw_arguments}")
| 205 | |
| 206 | |
def call_tool(name: ToolName, arguments: SKUArguments) -> ToolOutput:
    """Return canned demo data for the tool ``name``.

    Every payload echoes ``arguments["sku"]``; all other values are fixed
    fixtures, so no external system is contacted.

    Raises:
        KeyError: If ``arguments`` has no ``"sku"`` key.
        ValueError: If ``name`` is not one of the three demo tools.
    """
    requested_sku = arguments["sku"]

    if name == "get_sku_inventory":
        inventory: SKUInventoryOutput = {
            "sku": requested_sku,
            "warehouse": "pond-west-1",
            "on_hand_units": 84,
            "reserved_units": 26,
            "reorder_point": 60,
            "safety_stock": 40,
        }
        return inventory
    elif name == "get_supplier_eta":
        shipments: list[SupplierShipment] = [
            {"shipment_id": "frog_ship_2201", "eta_date": "2026-02-24", "quantity": 180, "risk": "low"},
            {"shipment_id": "frog_ship_2205", "eta_date": "2026-03-03", "quantity": 220, "risk": "medium"},
        ]
        return {"sku": requested_sku, "supplier_shipments": shipments}
    elif name == "get_quality_alerts":
        alerts: list[QualityAlert] = [
            {
                "alert_id": "frog_qa_781",
                "status": "open",
                "severity": "high",
                "summary": "Lily-pad coating chipping in lot LP-42",
            },
            {
                "alert_id": "frog_qa_795",
                "status": "in_progress",
                "severity": "medium",
                "summary": "Pond-crate scuff rate above threshold",
            },
            {
                "alert_id": "frog_qa_802",
                "status": "resolved",
                "severity": "low",
                "summary": "Froge label alignment issue corrected",
            },
        ]
        return {"sku": requested_sku, "alerts": alerts}
    else:
        raise ValueError(f"Unknown tool: {name}")
| 265 | |
| 266 | |
def _extract_done_response_id(event: object) -> str:
    """Return the response id carried by a ``response.done`` event.

    The event's ``response`` attribute is accepted either as a plain ``dict``
    (read with ``.get("id")``) or as an object exposing an ``id`` attribute.

    Raises:
        RuntimeError: If no string ``id`` can be found on the payload.
    """
    event_response = getattr(event, "response", None)
    raw_response_id: object
    if isinstance(event_response, dict):
        raw_response_id = cast(Dict[str, object], event_response).get("id")
    else:
        raw_response_id = getattr(event_response, "id", None)

    if not isinstance(raw_response_id, str):
        raise RuntimeError(f"response.done event did not include a valid response.id: {event!r}")
    return raw_response_id


def run_response(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    input_payload: Union[str, ResponseInputParam],
    tools: list[FunctionToolParam],
    tool_choice: ToolChoice,
    show_events: bool,
) -> RunResponseResult:
    """Send one streaming ``response.create`` request and read events until it finishes.

    Args:
        connection: Open Responses WebSocket connection; the request is sent on
            it and its events are consumed by iterating it.
        model, previous_response_id, input_payload, tools, tool_choice:
            Forwarded unchanged to ``connection.response.create``.
        show_events: When true, print ``[<event type>]`` for every non-text event.

    Returns:
        The concatenated text deltas, the final response id, and every function
        call the response emitted.

    Raises:
        RuntimeError: On an ``error`` event, a failed or incomplete response, a
            ``response.done`` event without an id, or if the event stream ends
            without any terminal event.
    """
    connection.response.create(
        model=model,
        input=input_payload,
        stream=True,
        previous_response_id=previous_response_id,
        tools=tools,
        tool_choice=tool_choice,
    )

    text_parts: list[str] = []
    function_calls: list[FunctionCallRequest] = []
    response_id: Optional[str] = None

    for event in connection:
        if event.type == "response.output_text.delta":
            # Text is collected silently; --show-events only reports non-text events.
            text_parts.append(event.delta)
            continue

        if event.type == "response.output_item.done" and event.item.type == "function_call":
            function_calls.append(
                FunctionCallRequest(
                    name=event.item.name,
                    arguments_json=event.item.arguments,
                    call_id=event.item.call_id,
                )
            )
            # This is a non-text event as well, so report it like every other one.
            if show_events:
                print(f"[{event.type}]")
            continue

        if getattr(event, "type", None) == "error":
            raise RuntimeError(f"WebSocket error event: {event!r}")

        if isinstance(event, (ResponseCompletedEvent, ResponseFailedEvent, ResponseIncompleteEvent)):
            response_id = event.response.id
            if not isinstance(event, ResponseCompletedEvent):
                raise RuntimeError(f"Response ended with {event.type} (id={response_id})")
            if show_events:
                print(f"[{event.type}]")
            break

        if getattr(event, "type", None) == "response.done":
            # Responses over WebSocket currently emit `response.done` as the final event.
            # The payload still includes `response.id`, which we use for chaining.
            response_id = _extract_done_response_id(event)
            if show_events:
                print("[response.done]")
            break

        if show_events:
            print(f"[{event.type}]")

    if response_id is None:
        raise RuntimeError("No terminal response event received.")

    return RunResponseResult(
        text="".join(text_parts),
        response_id=response_id,
        function_calls=function_calls,
    )
| 350 | |
| 351 | |
def run_turn(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    turn_prompt: str,
    forced_tool_name: ToolName,
    show_events: bool,
    show_tool_io: bool,
) -> RunTurnResult:
    """Run one user turn, executing tool calls until a response contains none.

    The first request sends ``turn_prompt`` and forces ``forced_tool_name`` via
    ``tool_choice``. Every function call in a response is executed locally with
    ``call_tool``. Its JSON-encoded result is sent back as a
    ``function_call_output`` item on a follow-up request, which is chained with
    ``previous_response_id`` and uses ``tool_choice="none"``.

    Returns:
        All text produced during the turn (joined and stripped) and the id of
        the turn's last response, for chaining the next turn.

    Raises:
        ValueError: If the model requests an unknown tool or sends malformed
            arguments.
        RuntimeError: Propagated from ``run_response``.
    """
    accumulated_text_parts: list[str] = []

    current_input: Union[str, ResponseInputParam] = turn_prompt
    current_tool_choice: ToolChoice = {"type": "function", "name": forced_tool_name}
    current_previous_response_id = previous_response_id

    while True:
        response_result = run_response(
            connection=connection,
            model=model,
            previous_response_id=current_previous_response_id,
            input_payload=current_input,
            tools=TOOLS,
            tool_choice=current_tool_choice,
            show_events=show_events,
        )

        if response_result.text:
            accumulated_text_parts.append(response_result.text)

        current_previous_response_id = response_result.response_id
        if not response_result.function_calls:
            break

        tool_outputs: ResponseInputParam = []
        for function_call in response_result.function_calls:
            if show_tool_io:
                # Log the request before executing it so a call that fails is still visible.
                print(f"[tool_call] {function_call.name}({function_call.arguments_json})")

            tool_name = parse_tool_name(function_call.name)
            arguments = parse_sku_arguments(function_call.arguments_json)
            # Serialize once; the same string is both logged and sent to the model.
            output_json = json.dumps(call_tool(tool_name, arguments))
            if show_tool_io:
                print(f"[tool_output] {output_json}")

            function_call_output: FunctionCallOutputItem = {
                "type": "function_call_output",
                "call_id": function_call.call_id,
                "output": output_json,
            }
            tool_outputs.append(cast(ResponseInputItemParam, function_call_output))

        current_input = tool_outputs
        # The tool already ran; the follow-up should only produce a text answer.
        current_tool_choice = "none"

    return RunTurnResult(
        assistant_text="".join(accumulated_text_parts).strip(),
        response_id=current_previous_response_id,
    )
| 409 | |
| 410 | |
def main() -> None:
    """Play every turn in ``DEMO_TURNS`` over a single Responses WebSocket connection.

    Each turn's final response id is passed as the next turn's
    ``previous_response_id`` so later turns continue the same conversation.
    The ``OpenAI-Beta`` header is attached only when ``--use-beta-header`` is given.
    """
    args = parse_args()

    client = OpenAI()
    extra_headers: Dict[str, str] = {}
    if args.use_beta_header:
        extra_headers["OpenAI-Beta"] = BETA_HEADER_VALUE

    with client.responses.connect(extra_headers=extra_headers) as connection:
        last_response_id: Optional[str] = None
        for turn_number, demo_turn in enumerate(DEMO_TURNS, start=1):
            prompt = demo_turn["prompt"]
            print(f"\n=== Turn {turn_number} ===")
            print(f"User: {prompt}")

            result = run_turn(
                connection=connection,
                model=args.model,
                previous_response_id=last_response_id,
                turn_prompt=prompt,
                forced_tool_name=demo_turn["tool_name"],
                show_events=args.show_events,
                show_tool_io=args.show_tool_io,
            )
            last_response_id = result.response_id
            print(f"Assistant: {result.assistant_text}")


if __name__ == "__main__":
    main()
| 437 | |