openai/gpt-oss
Public · mirrored from https://github.com/openai/gpt-oss · Available
examples/streamlit/streamlit_chat.py
276 lines · mode: code
| 1 | import json |
| 2 | |
| 3 | import requests |
| 4 | import streamlit as st |
| 5 | |
# Default JSON Schema pre-filled in the "Function parameters" text area. When
# functions are enabled, run() json.loads() this text into the tool's
# "parameters" field.
DEFAULT_FUNCTION_PROPERTIES = """
{
    "type": "object",
    "properties": {
        "location": {
            "type": "string",
            "description": "The city and state, e.g. San Francisco, CA"
        }
    },
    "required": ["location"]
}
""".strip()

# Session state for chat: the conversation as a list of Responses-API items.
if "messages" not in st.session_state:
    st.session_state.messages = []

st.title("💬 Chatbot")

# Selectable models and the local Responses endpoint serving each of them.
options = ["large", "small"]
MODEL_URLS = {
    "large": "http://localhost:8000/v1/responses",
    "small": "http://localhost:8081/v1/responses",
}

if "model" not in st.session_state:
    # Seed the initial model from ?model=..., but only when it names a known
    # model: segmented_control rejects a `default` that is not in `options`.
    requested_model = st.query_params.get("model")
    st.session_state.model = requested_model if requested_model in options else "small"

selection = st.sidebar.segmented_control(
    "Model", options, selection_mode="single", default=st.session_state.model
)
# segmented_control returns None when the user clicks the active segment to
# deselect it. Fall back to the session default instead of writing None into
# the URL and silently routing requests to an arbitrary server.
if selection is None:
    selection = st.session_state.model
# Mirror the choice into the URL so a reload/bookmark keeps the same model.
st.query_params.update({"model": selection})

instructions = st.sidebar.text_area(
    "Instructions",
    value="You are a helpful assistant that can answer questions and help with tasks.",
)
effort = st.sidebar.radio(
    "Reasoning effort",
    ["low", "medium", "high"],
    index=1,
)
st.sidebar.divider()
st.sidebar.subheader("Functions")
use_functions = st.sidebar.toggle("Use functions", value=False)

# Built-in Tools section
st.sidebar.subheader("Built-in Tools")
use_browser_search = st.sidebar.toggle("Use browser search", value=False)
use_code_interpreter = st.sidebar.toggle("Use code interpreter", value=False)

# The user-defined function is only configurable (and only sent) when
# function calling is enabled; otherwise these names are None.
if use_functions:
    function_name = st.sidebar.text_input("Function name", value="get_weather")
    function_description = st.sidebar.text_area(
        "Function description", value="Get the weather for a given city"
    )
    function_parameters = st.sidebar.text_area(
        "Function parameters", value=DEFAULT_FUNCTION_PROPERTIES
    )
else:
    function_name = None
    function_description = None
    function_parameters = None
st.sidebar.divider()
temperature = st.sidebar.slider(
    "Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
)
max_output_tokens = st.sidebar.slider(
    "Max output tokens", min_value=1, max_value=131072, value=30000, step=1000
)
st.sidebar.divider()
debug_mode = st.sidebar.toggle("Debug mode", value=False)

# In debug mode, dump the raw conversation items into the sidebar.
if debug_mode:
    st.sidebar.divider()
    st.sidebar.code(json.dumps(st.session_state.messages, indent=2), "json")

# When False, the chat input box at the bottom of the page is not rendered.
render_input = True

# Endpoint for the selected model; `selection` is guaranteed to be in `options`.
URL = MODEL_URLS[selection]
| 89 | |
| 90 | |
def trigger_fake_tool(container):
    """Answer the pending function call with the user's output, then resume the model.

    Used as the ``on_click`` callback of the "Submit function output" form. The
    output text is read from session state key ``function_output`` (the form's
    text input), defaulting to "It's sunny!". A ``function_call_output`` item
    referencing the pending call's ``call_id`` is appended to the conversation
    and the model is run again.

    Args:
        container: Streamlit container that ``run`` renders the follow-up
            response into.
    """
    messages = st.session_state.messages
    # Only resume when the conversation really ends in an unanswered function
    # call. A stale or duplicate submission would otherwise re-prompt the model
    # with an unchanged history (and the old code raised IndexError when empty).
    if not messages or messages[-1].get("type") != "function_call":
        return

    pending_call = messages[-1]
    function_output = st.session_state.get("function_output", "It's sunny!")
    messages.append(
        {
            "type": "function_call_output",
            "call_id": pending_call.get("call_id"),
            "output": function_output,
        }
    )
    run(container)
| 103 | |
| 104 | |
def _build_tools():
    """Return the Responses-API ``tools`` list selected in the sidebar.

    Raises:
        json.JSONDecodeError: if the "Function parameters" text is not valid JSON.
    """
    tools = []
    if use_functions:
        tools.append(
            {
                "type": "function",
                "name": function_name,
                "description": function_description,
                "parameters": json.loads(function_parameters),
            }
        )
    # Built-in tools are enabled by their sidebar toggles.
    if use_browser_search:
        tools.append({"type": "browser_search"})
    if use_code_interpreter:
        tools.append({"type": "code_interpreter"})
    return tools


def _iter_sse_events(response):
    """Yield the JSON object carried by each ``data:`` line of an SSE stream.

    Blank lines, non-``data:`` lines (e.g. ``event:``), empty payloads and
    payloads that are not a JSON object are skipped.
    """
    for line in response.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data:"):
            continue
        payload = line[len("data:") :].strip()
        if not payload:
            continue
        try:
            event = json.loads(payload)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(event, dict):
            yield event


def _show_function_output_form(container):
    """Render a form asking the user for the result of the pending function call."""
    with container.form("function_output_form"):
        st.text_input(
            "Enter function output",
            value=st.session_state.get("function_output", "It's sunny!"),
            key="function_output",
        )
        st.form_submit_button(
            "Submit function output",
            on_click=trigger_fake_tool,
            args=[container],
        )


def run(container):
    """Send the conversation to the Responses endpoint and stream the reply.

    Posts ``st.session_state.messages`` together with the sidebar settings to
    ``URL`` as a streaming request and renders reasoning, message text and
    tool-call items into ``container`` as server-sent events arrive. On
    ``response.completed`` the response's output items are appended to
    ``st.session_state.messages``; if the last one is a function call, a form
    is shown so the user can supply its output.

    Request, HTTP and parameter-JSON errors are reported in ``container``
    instead of raising.

    Args:
        container: Streamlit container to render the streamed output into.
    """
    try:
        tools = _build_tools()
    except json.JSONDecodeError as exc:
        container.error(f"Function parameters are not valid JSON: {exc}")
        return

    try:
        http_response = requests.post(
            URL,
            json={
                "input": st.session_state.messages,
                "stream": True,
                "instructions": instructions,
                "reasoning": {"effort": effort},
                "metadata": {"__debug": debug_mode},
                "tools": tools,
                "temperature": temperature,
                "max_output_tokens": max_output_tokens,
            },
            stream=True,
        )
    except requests.RequestException as exc:
        container.error(f"Request to {URL} failed: {exc}")
        return

    # `with` releases the streaming connection even if rendering raises.
    with http_response:
        if not http_response.ok:
            container.error(
                f"Server returned HTTP {http_response.status_code}: {http_response.text}"
            )
            return
        # SSE is always UTF-8, but requests falls back to ISO-8859-1 for text/*
        # responses without an explicit charset, which would garble non-ASCII
        # output when decode_unicode=True.
        http_response.encoding = "utf-8"

        placeholder = None  # where the current output item's streamed text goes
        text_delta = ""  # text accumulated for the current output item
        for data in _iter_sse_events(http_response):
            event_type = data.get("type", "")

            if event_type == "response.output_item.added":
                item = data.get("item") or {}
                output_type = item.get("type", "message")
                # Unhandled item types get no placeholder, so their deltas
                # cannot overwrite the previous item's bubble.
                placeholder = None
                if output_type == "message":
                    placeholder = container.chat_message("assistant").empty()
                elif output_type == "reasoning":
                    placeholder = container.chat_message(
                        "reasoning", avatar="🤔"
                    ).empty()
                elif output_type == "web_search_call":
                    output = container.chat_message("web_search_call", avatar="🌐")
                    output.code(
                        json.dumps(item.get("action", {}), indent=4),
                        language="json",
                    )
                    placeholder = output.empty()
                elif output_type == "code_interpreter_call":
                    placeholder = container.chat_message(
                        "code_interpreter_call", avatar="🧪"
                    ).empty()
                text_delta = ""

            elif event_type in (
                "response.reasoning_text.delta",
                "response.output_text.delta",
            ):
                text_delta += data.get("delta", "")
                if placeholder is not None:
                    placeholder.markdown(text_delta)

            elif event_type == "response.output_item.done":
                item = data.get("item") or {}
                item_type = item.get("type")
                if item_type == "function_call":
                    with container.chat_message("function_call", avatar="🔨"):
                        st.markdown(f"Called `{item.get('name')}`")
                        st.caption("Arguments")
                        st.code(item.get("arguments", ""), language="json")
                elif item_type in ("web_search_call", "code_interpreter_call"):
                    if placeholder is not None:
                        placeholder.markdown("✅ Done")

            elif event_type == "response.code_interpreter_call.in_progress":
                if placeholder is not None:
                    placeholder.markdown("⏳ Running")

            elif event_type == "response.code_interpreter_call.completed":
                if placeholder is not None:
                    placeholder.markdown("✅ Done")

            elif event_type == "response.completed":
                # Named `completed` so it does not shadow the HTTP response.
                completed = data.get("response") or {}
                if debug_mode:
                    container.expander("Debug", expanded=False).code(
                        (completed.get("metadata") or {}).get("__debug", ""),
                        language="text",
                    )
                messages = st.session_state.messages
                messages.extend(completed.get("output") or [])
                if messages and messages[-1].get("type") == "function_call":
                    _show_function_output_form(container)
            # Optionally handle other event types...
| 220 | |
| 221 | |
# Chat display: re-render the stored conversation items on every Streamlit rerun.
for msg in st.session_state.messages:
    msg_type = msg.get("type")
    if msg_type == "message":
        with st.chat_message(msg["role"]):
            # `or []` tolerates items whose content is missing or null.
            for item in msg.get("content") or []:
                if item.get("type") not in ("text", "output_text", "input_text"):
                    continue
                st.markdown(item["text"])
                # List citation URLs; skip the caption entirely when none of
                # the annotations carries a URL.
                urls = [
                    annotation.get("url")
                    for annotation in item.get("annotations") or []
                    if annotation.get("url")
                ]
                if urls:
                    annotation_lines = "\n".join(f"- {url}" for url in urls)
                    st.caption(f"**Annotations:**\n{annotation_lines}")
    elif msg_type == "reasoning":
        with st.chat_message("reasoning", avatar="🤔"):
            for item in msg.get("content") or []:
                if item.get("type") == "reasoning_text":
                    st.markdown(item["text"])
    elif msg_type == "function_call":
        with st.chat_message("function_call", avatar="🔨"):
            st.markdown(f"Called `{msg.get('name')}`")
            st.caption("Arguments")
            st.code(msg.get("arguments", ""), language="json")
    elif msg_type == "function_call_output":
        with st.chat_message("function_call_output", avatar="✅"):
            st.caption("Output")
            st.code(msg.get("output", ""), language="text")
    elif msg_type == "web_search_call":
        with st.chat_message("web_search_call", avatar="🌐"):
            st.code(json.dumps(msg.get("action", {}), indent=4), language="json")
            st.markdown("✅ Done")
    elif msg_type == "code_interpreter_call":
        with st.chat_message("code_interpreter_call", avatar="🧪"):
            st.markdown("✅ Done")

if render_input:
    # Input field: record the user's turn, echo it, then stream the reply.
    if prompt := st.chat_input("Type a message..."):
        st.session_state.messages.append(
            {
                "type": "message",
                "role": "user",
                "content": [{"type": "input_text", "text": prompt}],
            }
        )

        with st.chat_message("user"):
            st.markdown(prompt)

        run(st.container())
| 277 | |