openai/openai-python
Public · Mirrored from https://github.com/openai/openai-python · Available
examples/responses/websocket.py
436 lines · mode: blame
e273d622stainless-app[bot]4 months ago | 1 | from __future__ import annotations |
| 2 | | |
| 3 | import json | |
| 4 | import argparse | |
| 5 | from typing import TYPE_CHECKING, Dict, Union, Literal, Optional, TypedDict, NamedTuple, cast | |
| 6 | | |
| 7 | from openai import OpenAI | |
| 8 | from openai.types.responses import ( | |
| 9 | FunctionToolParam, | |
| 10 | ToolChoiceOptions, | |
| 11 | ResponseInputParam, | |
| 12 | ResponseFailedEvent, | |
| 13 | ResponseCompletedEvent, | |
| 14 | ResponseInputItemParam, | |
| 15 | ResponseIncompleteEvent, | |
| 16 | ToolChoiceFunctionParam, | |
| 17 | ) | |
| 18 | | |
| 19 | if TYPE_CHECKING: | |
| 20 | from openai.resources.responses.responses import ResponsesConnection | |
| 21 | | |
# Names of the three function tools declared in TOOLS and implemented by call_tool().
ToolName = Literal["get_sku_inventory", "get_supplier_eta", "get_quality_alerts"]
# Values accepted for `tool_choice`: a preset option (run_turn uses "none") or a
# specific function tool ({"type": "function", "name": ...}).
ToolChoice = Union[ToolChoiceOptions, ToolChoiceFunctionParam]


class DemoTurn(TypedDict):
    """One scripted user turn of the demo conversation."""

    # Function tool the model is forced to call on the turn's first request.
    tool_name: ToolName
    # User message sent as the turn's initial input.
    prompt: str
| 30 | | |
class SKUArguments(TypedDict):
    """Parsed tool arguments; every tool in TOOLS takes exactly one required string `sku`."""

    sku: str


class SKUInventoryOutput(TypedDict):
    """Payload returned by the `get_sku_inventory` tool."""

    sku: str
    warehouse: str
    on_hand_units: int
    reserved_units: int
    reorder_point: int
    safety_stock: int


class SupplierShipment(TypedDict):
    """One inbound shipment entry inside a SupplierETAOutput."""

    shipment_id: str
    # Date as a string; the demo data uses YYYY-MM-DD.
    eta_date: str
    quantity: int
    # Free-form risk label; the demo data uses "low" and "medium".
    risk: str


class SupplierETAOutput(TypedDict):
    """Payload returned by the `get_supplier_eta` tool."""

    sku: str
    supplier_shipments: list[SupplierShipment]


class QualityAlert(TypedDict):
    """One quality alert entry inside a QualityAlertsOutput."""

    alert_id: str
    # Free-form status; the demo data uses "open", "in_progress" and "resolved".
    status: str
    # Free-form severity; the demo data uses "high", "medium" and "low".
    severity: str
    summary: str


class QualityAlertsOutput(TypedDict):
    """Payload returned by the `get_quality_alerts` tool."""

    sku: str
    alerts: list[QualityAlert]
| 66 | | |
| 67 | | |
class FunctionCallOutputItem(TypedDict):
    """Input item that sends a locally executed tool's result back to the model."""

    type: Literal["function_call_output"]
    # Copied from the `call_id` of the function call being answered.
    call_id: str
    # The tool's result, JSON-encoded with json.dumps.
    output: str


class FunctionCallRequest(NamedTuple):
    """A function call the model emitted, captured from a finished output item."""

    name: str
    # Raw JSON string of the call's arguments, exactly as streamed by the model.
    arguments_json: str
    call_id: str


class RunResponseResult(NamedTuple):
    """Outcome of streaming a single response with run_response()."""

    # Concatenation of every `response.output_text.delta` received.
    text: str
    # Id of the terminal response, used as the next `previous_response_id`.
    response_id: str
    # Function calls requested by the model, in arrival order.
    function_calls: list[FunctionCallRequest]


class RunTurnResult(NamedTuple):
    """Outcome of a full user turn (possibly several chained responses) from run_turn()."""

    # All assistant text produced during the turn, joined and stripped.
    assistant_text: str
    # Id of the last response in the turn, used to chain the next turn.
    response_id: str


# Any payload call_tool() can return.
ToolOutput = Union[SKUInventoryOutput, SupplierETAOutput, QualityAlertsOutput]
| 92 | | |
def _sku_tool(name: ToolName, description: str) -> FunctionToolParam:
    """Build a strict function-tool definition whose only parameter is a required string ``sku``.

    Args:
        name: Tool name exposed to the model.
        description: Human-readable description of what the tool returns.

    Returns:
        A fresh tool definition dict (nested dicts are not shared between calls).
    """
    return {
        "type": "function",
        "name": name,
        "description": description,
        "strict": True,
        "parameters": {
            "type": "object",
            "properties": {
                "sku": {
                    "type": "string",
                    "description": "Stock-keeping unit identifier, such as sku-froge-lily-pad-deluxe.",
                }
            },
            "required": ["sku"],
            # Strict mode: the model may not send any argument besides `sku`.
            "additionalProperties": False,
        },
    }


# Function tools offered on every request. All three share the same strict single-`sku`
# schema, so they are built from one helper to keep their definitions in sync.
TOOLS: list[FunctionToolParam] = [
    _sku_tool("get_sku_inventory", "Return froge pond inventory details for a SKU."),
    _sku_tool("get_supplier_eta", "Return tadpole supplier restock ETA data for a SKU."),
    _sku_tool("get_quality_alerts", "Return recent froge quality alerts for a SKU."),
]
| 146 | | |
# Scripted three-turn conversation. Each turn forces one tool; the later prompts refer to
# "the same SKU", which works because main() chains turns via previous_response_id.
DEMO_TURNS: list[DemoTurn] = [
    {
        "tool_name": "get_sku_inventory",
        "prompt": "Use get_sku_inventory for sku='sku-froge-lily-pad-deluxe' and summarize current pond stock health in one sentence.",
    },
    {
        "tool_name": "get_supplier_eta",
        "prompt": "Now use get_supplier_eta for the same SKU and summarize restock ETA and tadpole shipment risk.",
    },
    {
        "tool_name": "get_quality_alerts",
        "prompt": "Finally use get_quality_alerts for the same SKU and summarize unresolved froge quality concerns in one short paragraph.",
    },
]

# Value of the `OpenAI-Beta` header, attached by main() only when --use-beta-header is passed.
BETA_HEADER_VALUE = "responses_websockets=2026-02-06"
| 163 | | |
| 164 | | |
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, so existing callers are unaffected;
            passing a list allows programmatic use and testing.

    Returns:
        Namespace with ``model``, ``use_beta_header``, ``show_events`` and ``show_tool_io``.
    """
    parser = argparse.ArgumentParser(
        description="Run a 3-turn Responses WebSocket demo with function calling and chained previous_response_id."
    )
    parser.add_argument("--model", default="gpt-5.2", help="Model used in the `response.create` payload.")
    parser.add_argument(
        "--use-beta-header",
        action="store_true",
        help=f"Include `OpenAI-Beta: {BETA_HEADER_VALUE}` for beta websocket behavior.",
    )
    parser.add_argument(
        "--show-events",
        action="store_true",
        help="Print non-text event types while streaming.",
    )
    parser.add_argument(
        "--show-tool-io",
        action="store_true",
        help="Print each tool call and tool output payload.",
    )
    return parser.parse_args(argv)
| 186 | | |
| 187 | | |
def parse_tool_name(name: str) -> ToolName:
    """Validate a model-supplied tool name and narrow it to ``ToolName``.

    The accepted names are read from the ``ToolName`` literal itself, so this check
    always matches the type alias instead of a separately maintained list.

    Raises:
        ValueError: if ``name`` is not one of the declared tools.
    """
    from typing import get_args

    if name not in get_args(ToolName):
        raise ValueError(f"Unsupported tool requested: {name}")
    return cast(ToolName, name)
| 192 | | |
| 193 | | |
def parse_sku_arguments(raw_arguments: str) -> SKUArguments:
    """Decode a function call's JSON arguments into an ``SKUArguments`` dict.

    Only the ``sku`` key is kept; any other keys are ignored.

    Raises:
        ValueError: if the text is not valid JSON (``json.JSONDecodeError``), is not a
            JSON object, or lacks a string ``sku`` value.
    """
    decoded = json.loads(raw_arguments)
    if isinstance(decoded, dict):
        candidate_sku = cast(Dict[str, object], decoded).get("sku")
        if isinstance(candidate_sku, str):
            return {"sku": candidate_sku}
        raise ValueError(f"Tool arguments must include a string `sku`: {raw_arguments}")
    raise ValueError(f"Tool arguments must be a JSON object: {raw_arguments}")
| 205 | | |
| 206 | | |
def call_tool(name: ToolName, arguments: SKUArguments) -> ToolOutput:
    """Execute one demo tool locally and return its canned payload.

    Each payload echoes ``arguments["sku"]``; every other value is fixed demo data.
    A brand-new dict is built on each call, so callers may mutate the result freely.

    Raises:
        KeyError: if ``arguments`` has no ``"sku"`` entry (looked up before the name is checked).
        ValueError: if ``name`` is not one of the three known tools.
    """
    requested_sku = arguments["sku"]

    def build_inventory() -> SKUInventoryOutput:
        return {
            "sku": requested_sku,
            "warehouse": "pond-west-1",
            "on_hand_units": 84,
            "reserved_units": 26,
            "reorder_point": 60,
            "safety_stock": 40,
        }

    def build_supplier_eta() -> SupplierETAOutput:
        return {
            "sku": requested_sku,
            "supplier_shipments": [
                {"shipment_id": "frog_ship_2201", "eta_date": "2026-02-24", "quantity": 180, "risk": "low"},
                {"shipment_id": "frog_ship_2205", "eta_date": "2026-03-03", "quantity": 220, "risk": "medium"},
            ],
        }

    def build_quality_alerts() -> QualityAlertsOutput:
        return {
            "sku": requested_sku,
            "alerts": [
                {
                    "alert_id": "frog_qa_781",
                    "status": "open",
                    "severity": "high",
                    "summary": "Lily-pad coating chipping in lot LP-42",
                },
                {
                    "alert_id": "frog_qa_795",
                    "status": "in_progress",
                    "severity": "medium",
                    "summary": "Pond-crate scuff rate above threshold",
                },
                {
                    "alert_id": "frog_qa_802",
                    "status": "resolved",
                    "severity": "low",
                    "summary": "Froge label alignment issue corrected",
                },
            ],
        }

    # Dispatch table from tool name to its payload builder.
    builders = {
        "get_sku_inventory": build_inventory,
        "get_supplier_eta": build_supplier_eta,
        "get_quality_alerts": build_quality_alerts,
    }
    builder = builders.get(name)
    if builder is None:
        raise ValueError(f"Unknown tool: {name}")
    return builder()
| 265 | | |
| 266 | | |
def _response_id_from_done_event(event: object) -> str:
    """Return the string ``response.id`` carried by a ``response.done`` event.

    The event's ``response`` payload is accepted either as a plain ``dict`` (read via
    its ``"id"`` key) or as an object exposing an ``id`` attribute.

    Raises:
        RuntimeError: if no string id can be found on the payload.
    """
    payload = getattr(event, "response", None)
    if isinstance(payload, dict):
        candidate = cast(Dict[str, object], payload).get("id")
    else:
        candidate = getattr(payload, "id", None)

    if isinstance(candidate, str):
        return candidate
    raise RuntimeError(f"response.done event did not include a valid response.id: {event!r}")


def run_response(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    input_payload: Union[str, ResponseInputParam],
    tools: list[FunctionToolParam],
    tool_choice: ToolChoice,
    show_events: bool,
) -> RunResponseResult:
    """Send one streaming ``response.create`` and consume events until the response ends.

    Text deltas are concatenated into the result's ``text``; each finished
    ``function_call`` output item is recorded (in arrival order) in ``function_calls``.
    The loop stops at the first terminal event: ``response.completed`` /
    ``response.failed`` / ``response.incomplete`` or ``response.done``. The id of that
    terminal response is returned so the caller can chain the next request.

    When ``show_events`` is true, the type of every event other than text deltas and
    captured function calls is printed.

    Raises:
        RuntimeError: on an ``error`` event, on a failed or incomplete response, when
            ``response.done`` has no string id, or when the stream ends without any
            terminal event.
    """
    connection.response.create(
        model=model,
        input=input_payload,
        stream=True,
        previous_response_id=previous_response_id,
        tools=tools,
        tool_choice=tool_choice,
    )

    streamed_text: list[str] = []
    requested_calls: list[FunctionCallRequest] = []
    terminal_id: Optional[str] = None
    terminal_event_types = (ResponseCompletedEvent, ResponseFailedEvent, ResponseIncompleteEvent)

    for event in connection:
        if event.type == "response.output_text.delta":
            streamed_text.append(event.delta)
        elif event.type == "response.output_item.done" and event.item.type == "function_call":
            requested_calls.append(
                FunctionCallRequest(
                    name=event.item.name,
                    arguments_json=event.item.arguments,
                    call_id=event.item.call_id,
                )
            )
        elif getattr(event, "type", None) == "error":
            raise RuntimeError(f"WebSocket error event: {event!r}")
        elif isinstance(event, terminal_event_types):
            terminal_id = event.response.id
            if not isinstance(event, ResponseCompletedEvent):
                raise RuntimeError(f"Response ended with {event.type} (id={terminal_id})")
            if show_events:
                print(f"[{event.type}]")
            break
        elif getattr(event, "type", None) == "response.done":
            # Responses over WebSocket currently emit `response.done` as the final event;
            # its payload still carries `response.id`, which is what chaining needs.
            terminal_id = _response_id_from_done_event(event)
            if show_events:
                print("[response.done]")
            break
        elif show_events:
            print(f"[{event.type}]")

    if terminal_id is None:
        raise RuntimeError("No terminal response event received.")

    return RunResponseResult(
        text="".join(streamed_text),
        response_id=terminal_id,
        function_calls=requested_calls,
    )
| 350 | | |
| 351 | | |
def _execute_function_calls(
    function_calls: list[FunctionCallRequest],
    *,
    show_tool_io: bool,
) -> ResponseInputParam:
    """Run each requested tool locally and wrap its result as a ``function_call_output`` item.

    Calls are processed in order. For each call the name is validated, the arguments
    are parsed, and the tool is executed before anything is printed. Any exception from
    those steps propagates immediately.
    """
    outputs: ResponseInputParam = []
    for call in function_calls:
        tool_name = parse_tool_name(call.name)
        tool_arguments = parse_sku_arguments(call.arguments_json)
        serialized_output = json.dumps(call_tool(tool_name, tool_arguments))
        if show_tool_io:
            print(f"[tool_call] {call.name}({call.arguments_json})")
            print(f"[tool_output] {serialized_output}")

        output_item: FunctionCallOutputItem = {
            "type": "function_call_output",
            "call_id": call.call_id,
            "output": serialized_output,
        }
        outputs.append(cast(ResponseInputItemParam, output_item))
    return outputs


def run_turn(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    turn_prompt: str,
    forced_tool_name: ToolName,
    show_events: bool,
    show_tool_io: bool,
) -> RunTurnResult:
    """Drive one user turn to completion, executing any tools the model requests.

    The first request sends ``turn_prompt`` and forces the model to call
    ``forced_tool_name``. While a response contains function calls, the tools are run
    locally and their outputs become the next request's input. That request is chained
    to the previous response via ``previous_response_id`` and sent with
    ``tool_choice="none"``. The loop ends at the first response with no function calls.

    Returns:
        The stripped concatenation of all assistant text produced during the turn, and
        the id of the last response (for chaining the next turn).
    """
    collected_text: list[str] = []
    next_input: Union[str, ResponseInputParam] = turn_prompt
    next_tool_choice: ToolChoice = {"type": "function", "name": forced_tool_name}
    chain_id = previous_response_id

    while True:
        outcome = run_response(
            connection=connection,
            model=model,
            previous_response_id=chain_id,
            input_payload=next_input,
            tools=TOOLS,
            tool_choice=next_tool_choice,
            show_events=show_events,
        )
        if outcome.text:
            collected_text.append(outcome.text)
        chain_id = outcome.response_id

        if not outcome.function_calls:
            break

        next_input = _execute_function_calls(outcome.function_calls, show_tool_io=show_tool_io)
        # Once tool results are supplied, ask for a plain answer rather than more tool calls.
        next_tool_choice = "none"

    return RunTurnResult(
        assistant_text="".join(collected_text).strip(),
        response_id=chain_id,
    )
| 409 | | |
| 410 | | |
def main() -> None:
    """Play every DEMO_TURNS entry over a single Responses WebSocket connection.

    Turns run in order through ``run_turn``. Each turn after the first is chained to
    the previous turn's final response id. The ``OpenAI-Beta`` header is attached only
    when ``--use-beta-header`` is given.
    """
    args = parse_args()

    client = OpenAI()
    headers: Dict[str, str] = {}
    if args.use_beta_header:
        headers["OpenAI-Beta"] = BETA_HEADER_VALUE

    with client.responses.connect(extra_headers=headers) as connection:
        chained_id: Optional[str] = None
        for turn_number, turn in enumerate(DEMO_TURNS, start=1):
            prompt = turn["prompt"]
            print(f"\n=== Turn {turn_number} ===")
            print(f"User: {prompt}")

            outcome = run_turn(
                connection=connection,
                model=args.model,
                previous_response_id=chained_id,
                turn_prompt=prompt,
                forced_tool_name=turn["tool_name"],
                show_events=args.show_events,
                show_tool_io=args.show_tool_io,
            )
            chained_id = outcome.response_id
            print(f"Assistant: {outcome.assistant_text}")


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()