openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
codex/bedrock-thin-provider-rewrite

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

examples/responses/websocket.py

436 lines · mode: code

1from __future__ import annotations
2
3import json
4import argparse
5from typing import TYPE_CHECKING, Dict, Union, Literal, Optional, TypedDict, NamedTuple, cast
6
7from openai import OpenAI
8from openai.types.responses import (
9 FunctionToolParam,
10 ToolChoiceOptions,
11 ResponseInputParam,
12 ResponseFailedEvent,
13 ResponseCompletedEvent,
14 ResponseInputItemParam,
15 ResponseIncompleteEvent,
16 ToolChoiceFunctionParam,
17)
18
19if TYPE_CHECKING:
20 from openai.resources.responses.responses import ResponsesConnection
21
# Names of the three demo tools; `parse_tool_name` narrows arbitrary strings to this type.
ToolName = Literal[
    "get_sku_inventory",
    "get_supplier_eta",
    "get_quality_alerts",
]
# Accepted `tool_choice` values: a preset option (the demo uses "none") or a specific
# function to force.
ToolChoice = Union[ToolChoiceOptions, ToolChoiceFunctionParam]
24
25
class DemoTurn(TypedDict):
    """One scripted user turn: the prompt to send and the tool that turn must call."""

    tool_name: ToolName  # forced via `tool_choice` on the turn's first request
    prompt: str  # user message sent as the turn's initial input
29
30
class SKUArguments(TypedDict):
    """Validated arguments shared by every demo tool: only the SKU to look up."""

    sku: str
33
34
class SKUInventoryOutput(TypedDict):
    """Canned payload that `call_tool` returns for `get_sku_inventory`."""

    sku: str
    warehouse: str
    # Unit counts; the demo data uses plain integers.
    on_hand_units: int
    reserved_units: int
    reorder_point: int
    safety_stock: int
42
43
class SupplierShipment(TypedDict):
    """One inbound shipment listed inside a supplier ETA payload."""

    shipment_id: str
    eta_date: str  # the demo data uses "YYYY-MM-DD" strings
    quantity: int
    risk: str  # the demo data uses "low" / "medium"
49
50
class SupplierETAOutput(TypedDict):
    """Canned payload that `call_tool` returns for `get_supplier_eta`."""

    sku: str
    supplier_shipments: list[SupplierShipment]
54
55
class QualityAlert(TypedDict):
    """One quality alert entry inside a quality-alerts payload."""

    alert_id: str
    status: str  # demo values: "open", "in_progress", "resolved"
    severity: str  # demo values: "high", "medium", "low"
    summary: str
61
62
class QualityAlertsOutput(TypedDict):
    """Canned payload that `call_tool` returns for `get_quality_alerts`."""

    sku: str
    alerts: list[QualityAlert]
66
67
class FunctionCallOutputItem(TypedDict):
    """Input item that hands a tool's result back to the model.

    `call_id` ties the result to the function call that requested it, and `output`
    carries the tool payload as a JSON string.
    """

    type: Literal["function_call_output"]
    call_id: str
    output: str
72
73
class FunctionCallRequest(NamedTuple):
    """A function call the model requested, captured from a `response.output_item.done` event."""

    name: str  # tool name as sent by the model (validated later by `parse_tool_name`)
    arguments_json: str  # raw JSON argument string (decoded later by `parse_sku_arguments`)
    call_id: str  # echoed back in the matching function_call_output item
78
79
class RunResponseResult(NamedTuple):
    """Outcome of one streamed response.

    Holds the concatenated text deltas, the response id used for chaining, and any
    function calls the model requested.
    """

    text: str
    response_id: str
    function_calls: list[FunctionCallRequest]
84
85
class RunTurnResult(NamedTuple):
    """Outcome of a full turn.

    Holds the stripped assistant text and the id of the turn's last response, which
    the next turn passes as `previous_response_id`.
    """

    assistant_text: str
    response_id: str
89
90
# Every payload shape `call_tool` can return; each one is JSON-encoded into a
# function_call_output item before it is sent back to the model.
ToolOutput = Union[
    SKUInventoryOutput,
    SupplierETAOutput,
    QualityAlertsOutput,
]
92
def _sku_function_tool(name: ToolName, description: str) -> FunctionToolParam:
    """Build a strict function-tool definition whose only argument is a required string `sku`.

    All three demo tools share this schema. Building it in one place keeps their
    definitions from drifting apart. Typing `name` as `ToolName` makes a tool that is
    missing from that Literal a type-check error.

    Each call returns freshly built dicts, so the entries in `TOOLS` share no mutable state.
    """
    return {
        "type": "function",
        "name": name,
        "description": description,
        "strict": True,
        "parameters": {
            "type": "object",
            "properties": {
                "sku": {
                    "type": "string",
                    "description": "Stock-keeping unit identifier, such as sku-froge-lily-pad-deluxe.",
                }
            },
            "required": ["sku"],
            # Strict mode requires closed objects: the model may send only `sku`.
            "additionalProperties": False,
        },
    }


# Tool definitions sent with every `response.create` request in this demo.
TOOLS: list[FunctionToolParam] = [
    _sku_function_tool("get_sku_inventory", "Return froge pond inventory details for a SKU."),
    _sku_function_tool("get_supplier_eta", "Return tadpole supplier restock ETA data for a SKU."),
    _sku_function_tool("get_quality_alerts", "Return recent froge quality alerts for a SKU."),
]
146
# Scripted conversation. Each turn forces one tool; later turns refer back to
# "the same SKU", relying on the chained previous_response_id for context.
DEMO_TURNS: list[DemoTurn] = [
    {
        "tool_name": "get_sku_inventory",
        "prompt": (
            "Use get_sku_inventory for sku='sku-froge-lily-pad-deluxe' "
            "and summarize current pond stock health in one sentence."
        ),
    },
    {
        "tool_name": "get_supplier_eta",
        "prompt": (
            "Now use get_supplier_eta for the same SKU "
            "and summarize restock ETA and tadpole shipment risk."
        ),
    },
    {
        "tool_name": "get_quality_alerts",
        "prompt": (
            "Finally use get_quality_alerts for the same SKU "
            "and summarize unresolved froge quality concerns in one short paragraph."
        ),
    },
]
161
# Value for the `OpenAI-Beta` request header; only sent when `--use-beta-header` is passed.
BETA_HEADER_VALUE = "responses_websockets=2026-02-06"
163
164
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from ``sys.argv``.

    Options:
        --model: model name sent in each ``response.create`` payload (default ``gpt-5.2``).
        --use-beta-header / --show-events / --show-tool-io: boolean switches, all off by default.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Run a 3-turn Responses WebSocket demo with function calling "
            "and chained previous_response_id."
        )
    )
    parser.add_argument("--model", default="gpt-5.2", help="Model used in the `response.create` payload.")

    # Every remaining option is a plain on/off switch; register them in a fixed order.
    switches = (
        ("--use-beta-header", f"Include `OpenAI-Beta: {BETA_HEADER_VALUE}` for beta websocket behavior."),
        ("--show-events", "Print non-text event types while streaming."),
        ("--show-tool-io", "Print each tool call and tool output payload."),
    )
    for flag, help_text in switches:
        parser.add_argument(flag, action="store_true", help=help_text)

    return parser.parse_args()
186
187
def parse_tool_name(name: str) -> ToolName:
    """Narrow a model-supplied function name to ``ToolName``.

    Raises:
        ValueError: if ``name`` is not one of the three demo tools.
    """
    known_tools = {"get_sku_inventory", "get_supplier_eta", "get_quality_alerts"}
    if name in known_tools:
        return cast(ToolName, name)
    raise ValueError(f"Unsupported tool requested: {name}")
192
193
def parse_sku_arguments(raw_arguments: str) -> SKUArguments:
    """Decode a tool call's JSON argument string into ``SKUArguments``.

    Keys other than ``sku`` are dropped.

    Raises:
        ValueError: if the text is not valid JSON (``json.JSONDecodeError`` subclasses
            ``ValueError``), is not a JSON object, or lacks a string ``sku``.
    """
    decoded = json.loads(raw_arguments)
    if isinstance(decoded, dict):
        sku = cast(Dict[str, object], decoded).get("sku")
        if isinstance(sku, str):
            return {"sku": sku}
        raise ValueError(f"Tool arguments must include a string `sku`: {raw_arguments}")
    raise ValueError(f"Tool arguments must be a JSON object: {raw_arguments}")
205
206
def _demo_inventory(sku: str) -> SKUInventoryOutput:
    """Build the canned `get_sku_inventory` payload for ``sku``."""
    return {
        "sku": sku,
        "warehouse": "pond-west-1",
        "on_hand_units": 84,
        "reserved_units": 26,
        "reorder_point": 60,
        "safety_stock": 40,
    }


def _demo_supplier_eta(sku: str) -> SupplierETAOutput:
    """Build the canned `get_supplier_eta` payload for ``sku``."""
    first_shipment: SupplierShipment = {
        "shipment_id": "frog_ship_2201",
        "eta_date": "2026-02-24",
        "quantity": 180,
        "risk": "low",
    }
    second_shipment: SupplierShipment = {
        "shipment_id": "frog_ship_2205",
        "eta_date": "2026-03-03",
        "quantity": 220,
        "risk": "medium",
    }
    return {"sku": sku, "supplier_shipments": [first_shipment, second_shipment]}


def _demo_quality_alerts(sku: str) -> QualityAlertsOutput:
    """Build the canned `get_quality_alerts` payload for ``sku``."""
    alerts: list[QualityAlert] = [
        {
            "alert_id": "frog_qa_781",
            "status": "open",
            "severity": "high",
            "summary": "Lily-pad coating chipping in lot LP-42",
        },
        {
            "alert_id": "frog_qa_795",
            "status": "in_progress",
            "severity": "medium",
            "summary": "Pond-crate scuff rate above threshold",
        },
        {
            "alert_id": "frog_qa_802",
            "status": "resolved",
            "severity": "low",
            "summary": "Froge label alignment issue corrected",
        },
    ]
    return {"sku": sku, "alerts": alerts}


def call_tool(name: ToolName, arguments: SKUArguments) -> ToolOutput:
    """Run a demo tool locally and return its canned payload for the requested SKU.

    Every call builds a fresh payload, so callers may mutate the result freely.

    Raises:
        KeyError: if ``arguments`` has no ``sku`` (checked before the tool name).
        ValueError: if ``name`` is not a known tool.
    """
    sku = arguments["sku"]

    if name == "get_sku_inventory":
        return _demo_inventory(sku)
    if name == "get_supplier_eta":
        return _demo_supplier_eta(sku)
    if name == "get_quality_alerts":
        return _demo_quality_alerts(sku)

    # Unreachable for well-typed callers, but guards against unvalidated strings.
    raise ValueError(f"Unknown tool: {name}")
265
266
def _response_id_from_done_event(event: object) -> str:
    """Extract ``response.id`` from a ``response.done`` event.

    The event's ``response`` payload may be a plain dict or an object with an ``id``
    attribute; both shapes are handled.

    Raises:
        RuntimeError: if no string id can be found.
    """
    payload = getattr(event, "response", None)
    if isinstance(payload, dict):
        candidate = cast(Dict[str, object], payload).get("id")
    else:
        candidate = getattr(payload, "id", None)

    if not isinstance(candidate, str):
        raise RuntimeError(f"response.done event did not include a valid response.id: {event!r}")
    return candidate


def run_response(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    input_payload: Union[str, ResponseInputParam],
    tools: list[FunctionToolParam],
    tool_choice: ToolChoice,
    show_events: bool,
) -> RunResponseResult:
    """Send one streaming ``response.create`` request and consume events until it terminates.

    Text deltas are concatenated. Function calls are collected from
    ``response.output_item.done`` events whose item is a ``function_call``. The loop
    stops at the first terminal event (``response.completed`` / ``failed`` /
    ``incomplete``, or ``response.done``).

    Raises:
        RuntimeError: on an ``error`` event, a failed or incomplete response, a
            ``response.done`` event without an id, or if the event stream ends before
            any terminal event arrives.
    """
    connection.response.create(
        model=model,
        input=input_payload,
        stream=True,
        previous_response_id=previous_response_id,
        tools=tools,
        tool_choice=tool_choice,
    )

    collected_text: list[str] = []
    requested_calls: list[FunctionCallRequest] = []
    final_id: Optional[str] = None

    for event in connection:
        if event.type == "response.output_text.delta":
            collected_text.append(event.delta)
        elif event.type == "response.output_item.done" and event.item.type == "function_call":
            requested_calls.append(
                FunctionCallRequest(
                    name=event.item.name,
                    arguments_json=event.item.arguments,
                    call_id=event.item.call_id,
                )
            )
        elif getattr(event, "type", None) == "error":
            raise RuntimeError(f"WebSocket error event: {event!r}")
        elif isinstance(event, (ResponseCompletedEvent, ResponseFailedEvent, ResponseIncompleteEvent)):
            final_id = event.response.id
            if not isinstance(event, ResponseCompletedEvent):
                raise RuntimeError(f"Response ended with {event.type} (id={final_id})")
            if show_events:
                print(f"[{event.type}]")
            break
        elif getattr(event, "type", None) == "response.done":
            # Responses over WebSocket currently emit `response.done` as the final event.
            # The payload still includes `response.id`, which we use for chaining.
            final_id = _response_id_from_done_event(event)
            if show_events:
                print("[response.done]")
            break
        elif show_events:
            print(f"[{event.type}]")

    if final_id is None:
        raise RuntimeError("No terminal response event received.")

    return RunResponseResult(
        text="".join(collected_text),
        response_id=final_id,
        function_calls=requested_calls,
    )
350
351
def _execute_function_call(function_call: FunctionCallRequest, *, show_tool_io: bool) -> ResponseInputItemParam:
    """Validate and run one requested tool, returning the function_call_output item for it.

    Raises:
        ValueError: propagated from tool-name or argument validation.
    """
    tool_name = parse_tool_name(function_call.name)
    arguments = parse_sku_arguments(function_call.arguments_json)
    output_payload = call_tool(tool_name, arguments)
    if show_tool_io:
        print(f"[tool_call] {function_call.name}({function_call.arguments_json})")
        print(f"[tool_output] {json.dumps(output_payload)}")

    output_item: FunctionCallOutputItem = {
        "type": "function_call_output",
        "call_id": function_call.call_id,
        "output": json.dumps(output_payload),
    }
    return cast(ResponseInputItemParam, output_item)


def run_turn(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    turn_prompt: str,
    forced_tool_name: ToolName,
    show_events: bool,
    show_tool_io: bool,
) -> RunTurnResult:
    """Drive one user turn to completion, executing any tool calls along the way.

    The first request sends ``turn_prompt`` and forces ``forced_tool_name``. Each
    follow-up sends the tool outputs with ``tool_choice="none"``. Every request is
    chained to the previous one via ``previous_response_id``. The turn ends when a
    response requests no function calls.

    Returns:
        The stripped concatenation of all text produced during the turn, plus the id of
        the turn's last response.
    """
    text_chunks: list[str] = []
    next_input: Union[str, ResponseInputParam] = turn_prompt
    next_tool_choice: ToolChoice = {"type": "function", "name": forced_tool_name}
    chain_id = previous_response_id

    while True:
        result = run_response(
            connection=connection,
            model=model,
            previous_response_id=chain_id,
            input_payload=next_input,
            tools=TOOLS,
            tool_choice=next_tool_choice,
            show_events=show_events,
        )
        if result.text:
            text_chunks.append(result.text)
        chain_id = result.response_id

        if not result.function_calls:
            return RunTurnResult(
                assistant_text="".join(text_chunks).strip(),
                response_id=chain_id,
            )

        # Calls are executed in the order the model requested them.
        next_input = [
            _execute_function_call(call, show_tool_io=show_tool_io) for call in result.function_calls
        ]
        # With tools disabled, the follow-up should answer in text, which ends the turn.
        next_tool_choice = "none"
409
410
def main() -> None:
    """Run the scripted demo turns over a single Responses WebSocket connection.

    Each turn is chained to the previous one via ``previous_response_id``. That is what
    lets later prompts refer to "the same SKU". The client is used as a context manager
    so its underlying HTTP resources are released when the demo finishes or fails.
    """
    args = parse_args()
    extra_headers = {"OpenAI-Beta": BETA_HEADER_VALUE} if args.use_beta_header else {}

    # Previously the client was never closed; `with` guarantees `client.close()` runs.
    with OpenAI() as client, client.responses.connect(extra_headers=extra_headers) as connection:
        previous_response_id: Optional[str] = None
        for index, turn in enumerate(DEMO_TURNS, start=1):
            print(f"\n=== Turn {index} ===")
            print(f"User: {turn['prompt']}")
            turn_result = run_turn(
                connection=connection,
                model=args.model,
                previous_response_id=previous_response_id,
                turn_prompt=turn["prompt"],
                forced_tool_name=turn["tool_name"],
                show_events=args.show_events,
                show_tool_io=args.show_tool_io,
            )
            previous_response_id = turn_result.response_id
            print(f"Assistant: {turn_result.assistant_text}")
433
434
# Script entry point; importing this module does not start the demo.
if __name__ == "__main__":
    main()
437