openai/openai-python

Public

mirrored from https://github.com/openai/openai-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v2.33.0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

examples/responses/websocket.py

436lines · modeblame

e273d622stainless-app[bot]4 months ago1from __future__ import annotations
2
3import json
4import argparse
5from typing import TYPE_CHECKING, Dict, Union, Literal, Optional, TypedDict, NamedTuple, cast
6
7from openai import OpenAI
8from openai.types.responses import (
9FunctionToolParam,
10ToolChoiceOptions,
11ResponseInputParam,
12ResponseFailedEvent,
13ResponseCompletedEvent,
14ResponseInputItemParam,
15ResponseIncompleteEvent,
16ToolChoiceFunctionParam,
17)
18
19if TYPE_CHECKING:
20from openai.resources.responses.responses import ResponsesConnection
21
# Closed set of function names the demo exposes to the model.
ToolName = Literal["get_sku_inventory", "get_supplier_eta", "get_quality_alerts"]

# Anything accepted as `tool_choice`: a plain option string or a forced-function selector.
ToolChoice = Union[ToolChoiceOptions, ToolChoiceFunctionParam]
24
25
class DemoTurn(TypedDict):
    """One scripted user turn of the demo conversation."""

    # Function the model is forced to call for this turn.
    tool_name: ToolName
    # User message sent as the turn's input.
    prompt: str
29
30
class SKUArguments(TypedDict):
    """Validated arguments shared by every demo tool."""

    # SKU identifier the tool should look up.
    sku: str
33
34
class SKUInventoryOutput(TypedDict):
    """Payload returned by the `get_sku_inventory` tool."""

    sku: str
    warehouse: str
    # Stock counters and thresholds, all expressed as integers.
    on_hand_units: int
    reserved_units: int
    reorder_point: int
    safety_stock: int
42
43
class SupplierShipment(TypedDict):
    """A single inbound shipment entry inside a supplier ETA payload."""

    shipment_id: str
    # Date carried as a string (the demo data uses `YYYY-MM-DD`).
    eta_date: str
    quantity: int
    # Free-form risk label (the demo data uses "low" / "medium").
    risk: str
49
50
class SupplierETAOutput(TypedDict):
    """Payload returned by the `get_supplier_eta` tool."""

    sku: str
    # Shipments reported for the SKU.
    supplier_shipments: list[SupplierShipment]
54
55
class QualityAlert(TypedDict):
    """A single quality alert entry inside a quality-alerts payload."""

    alert_id: str
    # Workflow state label (the demo data uses "open", "in_progress", "resolved").
    status: str
    # Severity label (the demo data uses "high", "medium", "low").
    severity: str
    # Short human-readable description of the issue.
    summary: str
61
62
class QualityAlertsOutput(TypedDict):
    """Payload returned by the `get_quality_alerts` tool."""

    sku: str
    # Alerts reported for the SKU.
    alerts: list[QualityAlert]
66
67
class FunctionCallOutputItem(TypedDict):
    """Input item that carries a tool result back to the model."""

    # Discriminator; always the literal "function_call_output".
    type: Literal["function_call_output"]
    # Identifier of the function call this output answers.
    call_id: str
    # Tool result, serialised to a string (the demo sends JSON text).
    output: str
72
73
class FunctionCallRequest(NamedTuple):
    """A function call requested by the model, captured from the event stream."""

    # Function name as reported by the model (not yet validated).
    name: str
    # Raw JSON text of the call arguments.
    arguments_json: str
    # Identifier to echo back when returning the tool output.
    call_id: str
78
79
class RunResponseResult(NamedTuple):
    """Outcome of a single `response.create` round trip."""

    # Concatenated output text deltas (may be empty).
    text: str
    # Id of the finished response, used to chain the next request.
    response_id: str
    # Function calls the model asked for during this response.
    function_calls: list[FunctionCallRequest]
84
85
class RunTurnResult(NamedTuple):
    """Outcome of one full demo turn, including any tool round trips."""

    # Assistant text accumulated over the turn.
    assistant_text: str
    # Id of the last response in the turn, used to chain the next turn.
    response_id: str
89
90
# Any payload a demo tool can return.
ToolOutput = Union[SKUInventoryOutput, SupplierETAOutput, QualityAlertsOutput]
92
# Function tool definitions sent with every request. All three tools share the
# same strict, single-`sku` parameter schema and differ only in name and
# description, so they are generated from (name, description) pairs. Each entry
# gets its own freshly built schema dict.
TOOLS: list[FunctionToolParam] = [
    {
        "type": "function",
        "name": tool_name,
        "description": tool_description,
        "strict": True,
        "parameters": {
            "type": "object",
            "properties": {
                "sku": {
                    "type": "string",
                    "description": "Stock-keeping unit identifier, such as sku-froge-lily-pad-deluxe.",
                }
            },
            "required": ["sku"],
            "additionalProperties": False,
        },
    }
    for tool_name, tool_description in (
        ("get_sku_inventory", "Return froge pond inventory details for a SKU."),
        ("get_supplier_eta", "Return tadpole supplier restock ETA data for a SKU."),
        ("get_quality_alerts", "Return recent froge quality alerts for a SKU."),
    )
]
146
# Scripted conversation: one turn per tool, in the order they are run.
DEMO_TURNS: list[DemoTurn] = [
    {
        "tool_name": "get_sku_inventory",
        "prompt": (
            "Use get_sku_inventory for sku='sku-froge-lily-pad-deluxe' "
            "and summarize current pond stock health in one sentence."
        ),
    },
    {
        "tool_name": "get_supplier_eta",
        "prompt": (
            "Now use get_supplier_eta for the same SKU "
            "and summarize restock ETA and tadpole shipment risk."
        ),
    },
    {
        "tool_name": "get_quality_alerts",
        "prompt": (
            "Finally use get_quality_alerts for the same SKU "
            "and summarize unresolved froge quality concerns in one short paragraph."
        ),
    },
]
161
# Value sent in the `OpenAI-Beta` header when `--use-beta-header` is passed.
BETA_HEADER_VALUE = "responses_websockets=2026-02-06"
163
164
def parse_args() -> argparse.Namespace:
    """Parse the demo's command-line options from `sys.argv`.

    Returns:
        Namespace with `model` (str) and the boolean flags `use_beta_header`,
        `show_events` and `show_tool_io` (all default to False).
    """
    cli = argparse.ArgumentParser(
        description=("Run a 3-turn Responses WebSocket demo with function calling and chained previous_response_id.")
    )
    cli.add_argument("--model", default="gpt-5.2", help="Model used in the `response.create` payload.")

    # All remaining options are simple on/off switches; register them in order.
    switches = (
        ("--use-beta-header", f"Include `OpenAI-Beta: {BETA_HEADER_VALUE}` for beta websocket behavior."),
        ("--show-events", "Print non-text event types while streaming."),
        ("--show-tool-io", "Print each tool call and tool output payload."),
    )
    for flag, help_text in switches:
        cli.add_argument(flag, action="store_true", help=help_text)

    return cli.parse_args()
186
187
def parse_tool_name(name: str) -> ToolName:
    """Validate a model-supplied function name against the demo's tool set.

    Args:
        name: Function name taken from a function-call event.

    Returns:
        The same string, narrowed to `ToolName` for type checkers.

    Raises:
        ValueError: If `name` is not one of the three supported tools.
    """
    supported = {"get_sku_inventory", "get_supplier_eta", "get_quality_alerts"}
    if name in supported:
        # `cast` is a runtime no-op; it only narrows the static type.
        return cast("ToolName", name)
    raise ValueError(f"Unsupported tool requested: {name}")
192
193
def parse_sku_arguments(raw_arguments: str) -> SKUArguments:
    """Decode a function call's JSON argument text into `SKUArguments`.

    Args:
        raw_arguments: JSON text supplied by the model for the function call.

    Returns:
        A dict containing only the validated `sku` string; any other keys in
        the payload are dropped.

    Raises:
        ValueError: If the text is not valid JSON, does not decode to a JSON
            object, or has no string-valued `sku` key. Every message includes
            the offending raw text.
    """
    try:
        parsed_raw = json.loads(raw_arguments)
    except json.JSONDecodeError as err:
        # Report malformed JSON the same way as the other validation failures
        # (with the raw payload), keeping the decoder error as the cause.
        raise ValueError(f"Tool arguments must be valid JSON: {raw_arguments}") from err

    if not isinstance(parsed_raw, dict):
        raise ValueError(f"Tool arguments must be a JSON object: {raw_arguments}")

    parsed = cast(Dict[str, object], parsed_raw)
    sku_value = parsed.get("sku")
    if not isinstance(sku_value, str):
        raise ValueError(f"Tool arguments must include a string `sku`: {raw_arguments}")

    return {"sku": sku_value}
205
206
def call_tool(name: ToolName, arguments: SKUArguments) -> ToolOutput:
    """Return the canned demo payload for a tool call.

    The data is hard-coded; only the `sku` from `arguments` is echoed back.

    Args:
        name: Which tool to simulate.
        arguments: Validated arguments; must contain `sku`.

    Returns:
        A freshly built payload dict for the requested tool.

    Raises:
        ValueError: If `name` is not one of the known tools.
    """
    sku = arguments["sku"]

    if name == "get_sku_inventory":
        inventory: SKUInventoryOutput = {
            "sku": sku,
            "warehouse": "pond-west-1",
            "on_hand_units": 84,
            "reserved_units": 26,
            "reorder_point": 60,
            "safety_stock": 40,
        }
        return inventory
    elif name == "get_supplier_eta":
        shipments: list[SupplierShipment] = [
            {
                "shipment_id": "frog_ship_2201",
                "eta_date": "2026-02-24",
                "quantity": 180,
                "risk": "low",
            },
            {
                "shipment_id": "frog_ship_2205",
                "eta_date": "2026-03-03",
                "quantity": 220,
                "risk": "medium",
            },
        ]
        return {"sku": sku, "supplier_shipments": shipments}
    elif name == "get_quality_alerts":
        alerts: list[QualityAlert] = [
            {
                "alert_id": "frog_qa_781",
                "status": "open",
                "severity": "high",
                "summary": "Lily-pad coating chipping in lot LP-42",
            },
            {
                "alert_id": "frog_qa_795",
                "status": "in_progress",
                "severity": "medium",
                "summary": "Pond-crate scuff rate above threshold",
            },
            {
                "alert_id": "frog_qa_802",
                "status": "resolved",
                "severity": "low",
                "summary": "Froge label alignment issue corrected",
            },
        ]
        return {"sku": sku, "alerts": alerts}
    else:
        # Unreachable for well-typed callers; guards against unvalidated names.
        raise ValueError(f"Unknown tool: {name}")
265
266
def run_response(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    input_payload: Union[str, ResponseInputParam],
    tools: list[FunctionToolParam],
    tool_choice: ToolChoice,
    show_events: bool,
) -> RunResponseResult:
    """Send one `response.create` request and consume events until it finishes.

    Args:
        connection: Open Responses WebSocket connection to send on and read from.
        model: Model name for the request.
        previous_response_id: Id of the response to chain from, or None.
        input_payload: Prompt text or a list of input items.
        tools: Function tool definitions offered to the model.
        tool_choice: Tool selection policy for this request.
        show_events: When True, print the type of each event that is neither a
            text delta nor a completed function call.

    Returns:
        The joined output text, the finished response's id, and every function
        call the model requested.

    Raises:
        RuntimeError: On an `error` event, a failed/incomplete response, a
            `response.done` event without a string id, or if the event stream
            ends before any terminal event is seen.
    """
    connection.response.create(
        model=model,
        input=input_payload,
        stream=True,
        previous_response_id=previous_response_id,
        tools=tools,
        tool_choice=tool_choice,
    )

    chunks: list[str] = []
    requested_calls: list[FunctionCallRequest] = []
    final_id: Optional[str] = None

    for event in connection:
        if event.type == "response.output_text.delta":
            chunks.append(event.delta)
        elif event.type == "response.output_item.done" and event.item.type == "function_call":
            done_item = event.item
            requested_calls.append(
                FunctionCallRequest(
                    name=done_item.name,
                    arguments_json=done_item.arguments,
                    call_id=done_item.call_id,
                )
            )
        elif getattr(event, "type", None) == "error":
            raise RuntimeError(f"WebSocket error event: {event!r}")
        elif isinstance(event, (ResponseCompletedEvent, ResponseFailedEvent, ResponseIncompleteEvent)):
            final_id = event.response.id
            # Only a completed response is a success; failed/incomplete abort the demo.
            if not isinstance(event, ResponseCompletedEvent):
                raise RuntimeError(f"Response ended with {event.type} (id={final_id})")
            if show_events:
                print(f"[{event.type}]")
            break
        elif getattr(event, "type", None) == "response.done":
            # Responses over WebSocket currently emit `response.done` as the final event.
            # The payload still includes `response.id`, which we use for chaining.
            done_payload = getattr(event, "response", None)
            if isinstance(done_payload, dict):
                candidate_id = cast(Dict[str, object], done_payload).get("id")
            else:
                candidate_id = getattr(done_payload, "id", None)

            if not isinstance(candidate_id, str):
                raise RuntimeError(f"response.done event did not include a valid response.id: {event!r}")

            final_id = candidate_id
            if show_events:
                print("[response.done]")
            break
        elif show_events:
            print(f"[{event.type}]")

    if final_id is None:
        raise RuntimeError("No terminal response event received.")

    return RunResponseResult(
        text="".join(chunks),
        response_id=final_id,
        function_calls=requested_calls,
    )
350
351
def run_turn(
    *,
    connection: ResponsesConnection,
    model: str,
    previous_response_id: Optional[str],
    turn_prompt: str,
    forced_tool_name: ToolName,
    show_events: bool,
    show_tool_io: bool,
) -> RunTurnResult:
    """Run one user turn, answering tool calls until the model stops asking.

    The first request forces `forced_tool_name`. Each follow-up request sends
    the tool outputs with `tool_choice="none"` and chains from the previous
    response id.

    Args:
        connection: Open Responses WebSocket connection.
        model: Model name for every request in the turn.
        previous_response_id: Id to chain the first request from, or None.
        turn_prompt: User message for the turn.
        forced_tool_name: Tool the model must call on the first request.
        show_events: Forwarded to `run_response`.
        show_tool_io: When True, print each tool call and its JSON output.

    Returns:
        The stripped assistant text accumulated over the turn and the id of
        the last response.

    Raises:
        ValueError: If the model requests an unsupported tool or sends
            arguments that fail validation.
        RuntimeError: Propagated from `run_response`.
    """
    text_chunks: list[str] = []

    next_input: Union[str, ResponseInputParam] = turn_prompt
    next_tool_choice: ToolChoice = {"type": "function", "name": forced_tool_name}
    chain_from = previous_response_id

    while True:
        result = run_response(
            connection=connection,
            model=model,
            previous_response_id=chain_from,
            input_payload=next_input,
            tools=TOOLS,
            tool_choice=next_tool_choice,
            show_events=show_events,
        )

        if result.text:
            text_chunks.append(result.text)

        if not result.function_calls:
            # No tool calls pending: this response closes the turn.
            return RunTurnResult(
                assistant_text="".join(text_chunks).strip(),
                response_id=result.response_id,
            )

        outputs: ResponseInputParam = []
        for call in result.function_calls:
            payload = call_tool(parse_tool_name(call.name), parse_sku_arguments(call.arguments_json))
            serialized = json.dumps(payload)
            if show_tool_io:
                print(f"[tool_call] {call.name}({call.arguments_json})")
                print(f"[tool_output] {serialized}")

            output_item: FunctionCallOutputItem = {
                "type": "function_call_output",
                "call_id": call.call_id,
                "output": serialized,
            }
            outputs.append(cast(ResponseInputItemParam, output_item))

        # Feed the tool results back, chained to the response that requested them.
        chain_from = result.response_id
        next_input = outputs
        next_tool_choice = "none"
409
410
def main() -> None:
    """Run every scripted demo turn over a single WebSocket connection.

    Each turn chains from the previous turn's response id. The user prompt and
    the assistant's reply are printed for every turn.
    """
    args = parse_args()

    client = OpenAI()

    # Only attach the beta header when explicitly requested on the command line.
    extra_headers: Dict[str, str] = {}
    if args.use_beta_header:
        extra_headers["OpenAI-Beta"] = BETA_HEADER_VALUE

    with client.responses.connect(extra_headers=extra_headers) as connection:
        last_response_id: Optional[str] = None
        for turn_number, turn in enumerate(DEMO_TURNS, start=1):
            user_prompt = turn["prompt"]
            print(f"\n=== Turn {turn_number} ===")
            print(f"User: {user_prompt}")
            outcome = run_turn(
                connection=connection,
                model=args.model,
                previous_response_id=last_response_id,
                turn_prompt=user_prompt,
                forced_tool_name=turn["tool_name"],
                show_events=args.show_events,
                show_tool_io=args.show_tool_io,
            )
            last_response_id = outcome.response_id
            print(f"Assistant: {outcome.assistant_text}")
433
434
if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()