openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
codex/gate-pypi-publish

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

2951 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import asyncio
10import inspect
11import dataclasses
12import tracemalloc
13from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Coroutine, cast
14from unittest import mock
15from typing_extensions import Literal, AsyncIterator, override
16
17import httpx
18import pytest
19from respx import MockRouter
20from pydantic import ValidationError
21from respx.models import Call as MockRequestCall
22
23from openai import OpenAI, AsyncOpenAI, OpenAIError, APIResponseValidationError
24from openai.auth import WorkloadIdentity
25from openai._types import Omit
26from openai._utils import asyncify
27from openai._models import BaseModel, FinalRequestOptions
28from openai._streaming import Stream, AsyncStream
29from openai._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
30from openai._base_client import (
31 DEFAULT_TIMEOUT,
32 HTTPX_DEFAULT_TIMEOUT,
33 BaseClient,
34 OtherPlatform,
35 DefaultHttpxClient,
36 DefaultAsyncHttpxClient,
37 get_platform,
38 make_request_options,
39)
40
41from .utils import update_env
42
T = TypeVar("T")

# Shared connection settings and credentials used by every test in this module.
# The mock API server address can be overridden with TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My API Key"
admin_api_key = "My Admin API Key"

# Workload-identity configuration whose token provider always returns the same
# fixed subject token.
workload_identity: WorkloadIdentity = {
    "identity_provider_id": "provider_123",
    "service_account_id": "service_account_123",
    "provider": {
        "get_token": lambda: "external-subject-token",
        "token_type": "jwt",
    },
}
55
56
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with ``client`` and return its query parameters."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
61
62
def _low_retry_timeout(*_ignored_args: Any, **_ignored_kwargs: Any) -> float:
    """Retry-timeout stand-in: ignore every argument and always return 0.1."""
    delay = 0.1
    return delay
65
66
def mirror_request_content(request: httpx.Request) -> httpx.Response:
    """Echo handler: answer with status 200 and a body identical to the request body."""
    echoed = request.content
    return httpx.Response(200, content=echoed)
69
70
# note: we can't use the httpx.MockTransport class as it consumes the request
# body itself, which means we can't test that the body is read lazily
class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    """Transport that passes each request straight to ``handler`` without reading its body.

    One class serves both client flavours. The sync entry point requires a plain
    (non-coroutine) function handler; the async entry point requires a coroutine
    function handler.
    """

    def __init__(
        self,
        handler: Callable[[httpx.Request], httpx.Response]
        | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]],
    ) -> None:
        self.handler = handler

    @override
    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        fn = self.handler
        assert not inspect.iscoroutinefunction(fn), "handler must not be a coroutine function"
        assert inspect.isfunction(fn), "handler must be a function"
        return fn(request)

    @override
    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        fn = self.handler
        assert inspect.iscoroutinefunction(fn), "handler must be a coroutine function"
        return await fn(request)
97
98
@dataclasses.dataclass
class Counter:
    """Mutable integer tally; starts at zero unless another value is given."""

    value: int = dataclasses.field(default=0)
102
103
def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]:
    """Lazily re-yield every item of ``iterable``.

    When ``counter`` is provided, its ``value`` is incremented just before each item
    is yielded, so callers can see how many items have been consumed so far.
    """
    for item in iterable:
        # Explicit None check: a truthiness test would silently skip counting for
        # any counter object that happens to evaluate as falsy.
        if counter is not None:
            counter.value += 1
        yield item
109
110
async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]:
    """Asynchronously re-yield every item of the (synchronous) ``iterable``.

    When ``counter`` is provided, its ``value`` is incremented just before each item
    is yielded, so callers can see how many items have been consumed so far.
    """
    for item in iterable:
        # Explicit None check: a truthiness test would silently skip counting for
        # any counter object that happens to evaluate as falsy.
        if counter is not None:
            counter.value += 1
        yield item
116
117
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of entries in the transport pool's private ``_requests`` collection.

    Only the default httpx transports are supported; any other transport trips the assert.
    NOTE(review): despite the name, this counts the pool's tracked requests, not connections.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
124
125
126class TestOpenAI:
127 @pytest.mark.respx(base_url=base_url)
128 def test_raw_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
129 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
130
131 response = client.post("/foo", cast_to=httpx.Response)
132 assert response.status_code == 200
133 assert isinstance(response, httpx.Response)
134 assert response.json() == {"foo": "bar"}
135
136 @pytest.mark.respx(base_url=base_url)
137 def test_raw_response_for_binary(self, respx_mock: MockRouter, client: OpenAI) -> None:
138 respx_mock.post("/foo").mock(
139 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
140 )
141
142 response = client.post("/foo", cast_to=httpx.Response)
143 assert response.status_code == 200
144 assert isinstance(response, httpx.Response)
145 assert response.json() == {"foo": "bar"}
146
147 def test_copy(self, client: OpenAI) -> None:
148 copied = client.copy()
149 assert id(copied) != id(client)
150
151 copied = client.copy(api_key="another My API Key")
152 assert copied.api_key == "another My API Key"
153 assert client.api_key == "My API Key"
154
155 copied = client.copy(admin_api_key="another My Admin API Key")
156 assert copied.admin_api_key == "another My Admin API Key"
157 assert client.admin_api_key == "My Admin API Key"
158
159 def test_copy_default_options(self, client: OpenAI) -> None:
160 # options that have a default are overridden correctly
161 copied = client.copy(max_retries=7)
162 assert copied.max_retries == 7
163 assert client.max_retries == 2
164
165 copied2 = copied.copy(max_retries=6)
166 assert copied2.max_retries == 6
167 assert copied.max_retries == 7
168
169 # timeout
170 assert isinstance(client.timeout, httpx.Timeout)
171 copied = client.copy(timeout=None)
172 assert copied.timeout is None
173 assert isinstance(client.timeout, httpx.Timeout)
174
175 def test_copy_default_headers(self) -> None:
176 client = OpenAI(
177 base_url=base_url,
178 api_key=api_key,
179 admin_api_key=admin_api_key,
180 _strict_response_validation=True,
181 default_headers={"X-Foo": "bar"},
182 )
183 assert client.default_headers["X-Foo"] == "bar"
184
185 # does not override the already given value when not specified
186 copied = client.copy()
187 assert copied.default_headers["X-Foo"] == "bar"
188
189 # merges already given headers
190 copied = client.copy(default_headers={"X-Bar": "stainless"})
191 assert copied.default_headers["X-Foo"] == "bar"
192 assert copied.default_headers["X-Bar"] == "stainless"
193
194 # uses new values for any already given headers
195 copied = client.copy(default_headers={"X-Foo": "stainless"})
196 assert copied.default_headers["X-Foo"] == "stainless"
197
198 # set_default_headers
199
200 # completely overrides already set values
201 copied = client.copy(set_default_headers={})
202 assert copied.default_headers.get("X-Foo") is None
203
204 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
205 assert copied.default_headers["X-Bar"] == "Robert"
206
207 with pytest.raises(
208 ValueError,
209 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
210 ):
211 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
212 client.close()
213
214 def test_copy_default_query(self) -> None:
215 client = OpenAI(
216 base_url=base_url,
217 api_key=api_key,
218 admin_api_key=admin_api_key,
219 _strict_response_validation=True,
220 default_query={"foo": "bar"},
221 )
222 assert _get_params(client)["foo"] == "bar"
223
224 # does not override the already given value when not specified
225 copied = client.copy()
226 assert _get_params(copied)["foo"] == "bar"
227
228 # merges already given params
229 copied = client.copy(default_query={"bar": "stainless"})
230 params = _get_params(copied)
231 assert params["foo"] == "bar"
232 assert params["bar"] == "stainless"
233
234 # uses new values for any already given headers
235 copied = client.copy(default_query={"foo": "stainless"})
236 assert _get_params(copied)["foo"] == "stainless"
237
238 # set_default_query
239
240 # completely overrides already set values
241 copied = client.copy(set_default_query={})
242 assert _get_params(copied) == {}
243
244 copied = client.copy(set_default_query={"bar": "Robert"})
245 assert _get_params(copied)["bar"] == "Robert"
246
247 with pytest.raises(
248 ValueError,
249 # TODO: update
250 match="`default_query` and `set_default_query` arguments are mutually exclusive",
251 ):
252 client.copy(set_default_query={}, default_query={"foo": "Bar"})
253
254 client.close()
255
256 def test_copy_signature(self, client: OpenAI) -> None:
257 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
258 init_signature = inspect.signature(
259 # mypy doesn't like that we access the `__init__` property.
260 client.__init__, # type: ignore[misc]
261 )
262 copy_signature = inspect.signature(client.copy)
263 exclude_params = {"transport", "proxies", "_strict_response_validation"}
264
265 for name in init_signature.parameters.keys():
266 if name in exclude_params:
267 continue
268
269 copy_param = copy_signature.parameters.get(name)
270 assert copy_param is not None, f"copy() signature is missing the {name} param"
271
272 @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
273 def test_copy_build_request(self, client: OpenAI) -> None:
274 options = FinalRequestOptions(method="get", url="/foo")
275
276 def build_request(options: FinalRequestOptions) -> None:
277 client_copy = client.copy()
278 client_copy._build_request(options)
279
280 # ensure that the machinery is warmed up before tracing starts.
281 build_request(options)
282 gc.collect()
283
284 tracemalloc.start(1000)
285
286 snapshot_before = tracemalloc.take_snapshot()
287
288 ITERATIONS = 10
289 for _ in range(ITERATIONS):
290 build_request(options)
291
292 gc.collect()
293 snapshot_after = tracemalloc.take_snapshot()
294
295 tracemalloc.stop()
296
297 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
298 if diff.count == 0:
299 # Avoid false positives by considering only leaks (i.e. allocations that persist).
300 return
301
302 if diff.count % ITERATIONS != 0:
303 # Avoid false positives by considering only leaks that appear per iteration.
304 return
305
306 for frame in diff.traceback:
307 if any(
308 frame.filename.endswith(fragment)
309 for fragment in [
310 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
311 #
312 # removing the decorator fixes the leak for reasons we don't understand.
313 "openai/_legacy_response.py",
314 "openai/_response.py",
315 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
316 "openai/_compat.py",
317 # Standard library leaks we don't care about.
318 "/logging/__init__.py",
319 ]
320 ):
321 return
322
323 leaks.append(diff)
324
325 leaks: list[tracemalloc.StatisticDiff] = []
326 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
327 add_leak(leaks, diff)
328 if leaks:
329 for leak in leaks:
330 print("MEMORY LEAK:", leak)
331 for frame in leak.traceback:
332 print(frame)
333 raise AssertionError()
334
335 def test_request_timeout(self, client: OpenAI) -> None:
336 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
337 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
338 assert timeout == DEFAULT_TIMEOUT
339
340 request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
341 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
342 assert timeout == httpx.Timeout(100.0)
343
344 def test_client_timeout_option(self) -> None:
345 client = OpenAI(
346 base_url=base_url,
347 api_key=api_key,
348 admin_api_key=admin_api_key,
349 _strict_response_validation=True,
350 timeout=httpx.Timeout(0),
351 )
352
353 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
354 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
355 assert timeout == httpx.Timeout(0)
356
357 client.close()
358
359 def test_http_client_timeout_option(self) -> None:
360 # custom timeout given to the httpx client should be used
361 with httpx.Client(timeout=None) as http_client:
362 client = OpenAI(
363 base_url=base_url,
364 api_key=api_key,
365 admin_api_key=admin_api_key,
366 _strict_response_validation=True,
367 http_client=http_client,
368 )
369
370 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
371 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
372 assert timeout == httpx.Timeout(None)
373
374 client.close()
375
376 # no timeout given to the httpx client should not use the httpx default
377 with httpx.Client() as http_client:
378 client = OpenAI(
379 base_url=base_url,
380 api_key=api_key,
381 admin_api_key=admin_api_key,
382 _strict_response_validation=True,
383 http_client=http_client,
384 )
385
386 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
387 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
388 assert timeout == DEFAULT_TIMEOUT
389
390 client.close()
391
392 # explicitly passing the default timeout currently results in it being ignored
393 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
394 client = OpenAI(
395 base_url=base_url,
396 api_key=api_key,
397 admin_api_key=admin_api_key,
398 _strict_response_validation=True,
399 http_client=http_client,
400 )
401
402 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
403 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
404 assert timeout == DEFAULT_TIMEOUT # our default
405
406 client.close()
407
408 async def test_invalid_http_client(self) -> None:
409 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
410 async with httpx.AsyncClient() as http_client:
411 OpenAI(
412 base_url=base_url,
413 api_key=api_key,
414 admin_api_key=admin_api_key,
415 _strict_response_validation=True,
416 http_client=cast(Any, http_client),
417 )
418
419 def test_default_headers_option(self) -> None:
420 test_client = OpenAI(
421 base_url=base_url,
422 api_key=api_key,
423 admin_api_key=admin_api_key,
424 _strict_response_validation=True,
425 default_headers={"X-Foo": "bar"},
426 )
427 request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
428 assert request.headers.get("x-foo") == "bar"
429 assert request.headers.get("x-stainless-lang") == "python"
430
431 test_client2 = OpenAI(
432 base_url=base_url,
433 api_key=api_key,
434 admin_api_key=admin_api_key,
435 _strict_response_validation=True,
436 default_headers={
437 "X-Foo": "stainless",
438 "X-Stainless-Lang": "my-overriding-header",
439 },
440 )
441 request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
442 assert request.headers.get("x-foo") == "stainless"
443 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
444
445 test_client.close()
446 test_client2.close()
447
448 def test_validate_headers(self) -> None:
449 client = OpenAI(
450 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
451 )
452 options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
453 request = client._build_request(options)
454
455 assert request.headers.get("Authorization") == f"Bearer {api_key}"
456
457 admin_request = client._build_request(
458 FinalRequestOptions(
459 method="get",
460 url="/organization/projects",
461 security={"admin_api_key_auth": True},
462 )
463 )
464 assert admin_request.headers.get("Authorization") == f"Bearer {admin_api_key}"
465
466 with update_env(**{"OPENAI_API_KEY": Omit()}):
467 admin_only = OpenAI(
468 base_url=base_url,
469 api_key=None,
470 admin_api_key=admin_api_key,
471 _strict_response_validation=True,
472 )
473 admin_only_request = admin_only._build_request(
474 FinalRequestOptions(
475 method="get",
476 url="/organization/projects",
477 security={"admin_api_key_auth": True},
478 )
479 )
480 assert admin_only_request.headers.get("Authorization") == f"Bearer {admin_api_key}"
481
482 with pytest.raises(
483 TypeError,
484 match="Could not resolve authentication method",
485 ):
486 admin_only._build_request(
487 FinalRequestOptions(
488 method="post",
489 url="/responses",
490 security={"bearer_auth": True},
491 )
492 )
493
494 with update_env(
495 **{
496 "OPENAI_API_KEY": Omit(),
497 "OPENAI_ADMIN_KEY": Omit(),
498 }
499 ):
500 no_credentials = OpenAI(
501 base_url=base_url,
502 api_key=None,
503 admin_api_key=None,
504 _enforce_credentials=False,
505 _strict_response_validation=True,
506 )
507 lowercase_auth_request = no_credentials._build_request(
508 FinalRequestOptions(method="get", url="/foo", headers={"authorization": "Bearer custom"})
509 )
510 assert lowercase_auth_request.headers.get("Authorization") == "Bearer custom"
511
512 omitted_auth_request = no_credentials._build_request(
513 FinalRequestOptions(method="get", url="/foo", headers={"authorization": Omit()})
514 )
515 assert "Authorization" not in omitted_auth_request.headers
516
517 with update_env(
518 **{
519 "OPENAI_API_KEY": Omit(),
520 "OPENAI_ADMIN_KEY": Omit(),
521 }
522 ):
523 with pytest.raises(OpenAIError, match="Missing credentials"):
524 OpenAI(base_url=base_url, api_key=None, admin_api_key=None, _strict_response_validation=True)
525
526 @pytest.mark.respx(base_url=base_url)
527 def test_api_key_provider_preserves_admin_auth(self, respx_mock: MockRouter) -> None:
528 respx_mock.get("/organization/projects").mock(return_value=httpx.Response(200, json={"ok": True}))
529
530 provider_called = False
531
532 def api_key_provider() -> str:
533 nonlocal provider_called
534 provider_called = True
535 return "dynamic-api-key"
536
537 client = OpenAI(base_url=base_url, api_key=api_key_provider, admin_api_key=admin_api_key)
538 response = client.get(
539 "/organization/projects",
540 cast_to=httpx.Response,
541 options={"security": {"admin_api_key_auth": True}},
542 )
543
544 assert response.request.headers.get("Authorization") == f"Bearer {admin_api_key}"
545 assert provider_called is False
546
547 def test_api_key_provider_does_not_fill_admin_auth(self) -> None:
548 provider_called = False
549
550 def api_key_provider() -> str:
551 nonlocal provider_called
552 provider_called = True
553 return "dynamic-api-key"
554
555 with update_env(OPENAI_ADMIN_KEY=Omit()):
556 client = OpenAI(base_url=base_url, api_key=api_key_provider, admin_api_key=None)
557 with pytest.raises(TypeError, match="Could not resolve authentication method"):
558 client.get(
559 "/organization/projects",
560 cast_to=httpx.Response,
561 options={"security": {"admin_api_key_auth": True}},
562 )
563
564 assert provider_called is False
565
566 @pytest.mark.respx(base_url=base_url)
567 def test_workload_identity_preserves_admin_auth(self, respx_mock: MockRouter) -> None:
568 respx_mock.get("/organization/projects").mock(return_value=httpx.Response(200, json={"ok": True}))
569
570 client = OpenAI(base_url=base_url, workload_identity=workload_identity, admin_api_key=admin_api_key)
571 response = client.get(
572 "/organization/projects",
573 cast_to=httpx.Response,
574 options={"security": {"admin_api_key_auth": True}},
575 )
576
577 assert response.request.headers.get("Authorization") == f"Bearer {admin_api_key}"
578
579 def test_workload_identity_is_mutually_exclusive_with_api_key(self) -> None:
580 with pytest.raises(
581 OpenAIError,
582 match="The `api_key` and `workload_identity` arguments are mutually exclusive",
583 ):
584 OpenAI(
585 base_url=base_url,
586 api_key=api_key,
587 workload_identity=workload_identity, # type: ignore[reportArgumentType]
588 organization="org_123",
589 _strict_response_validation=True,
590 )
591
592 def test_default_query_option(self) -> None:
593 client = OpenAI(
594 base_url=base_url,
595 api_key=api_key,
596 admin_api_key=admin_api_key,
597 _strict_response_validation=True,
598 default_query={"query_param": "bar"},
599 )
600 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
601 url = httpx.URL(request.url)
602 assert dict(url.params) == {"query_param": "bar"}
603
604 request = client._build_request(
605 FinalRequestOptions(
606 method="get",
607 url="/foo",
608 params={"foo": "baz", "query_param": "overridden"},
609 )
610 )
611 url = httpx.URL(request.url)
612 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
613
614 client.close()
615
616 def test_hardcoded_query_params_in_url(self, client: OpenAI) -> None:
617 request = client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true"))
618 url = httpx.URL(request.url)
619 assert dict(url.params) == {"beta": "true"}
620
621 request = client._build_request(
622 FinalRequestOptions(
623 method="get",
624 url="/foo?beta=true",
625 params={"limit": "10", "page": "abc"},
626 )
627 )
628 url = httpx.URL(request.url)
629 assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"}
630
631 request = client._build_request(
632 FinalRequestOptions(
633 method="get",
634 url="/files/a%2Fb?beta=true",
635 params={"limit": "10"},
636 )
637 )
638 assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10"
639
640 def test_request_extra_json(self, client: OpenAI) -> None:
641 request = client._build_request(
642 FinalRequestOptions(
643 method="post",
644 url="/foo",
645 json_data={"foo": "bar"},
646 extra_json={"baz": False},
647 ),
648 )
649 data = json.loads(request.content.decode("utf-8"))
650 assert data == {"foo": "bar", "baz": False}
651
652 request = client._build_request(
653 FinalRequestOptions(
654 method="post",
655 url="/foo",
656 extra_json={"baz": False},
657 ),
658 )
659 data = json.loads(request.content.decode("utf-8"))
660 assert data == {"baz": False}
661
662 # `extra_json` takes priority over `json_data` when keys clash
663 request = client._build_request(
664 FinalRequestOptions(
665 method="post",
666 url="/foo",
667 json_data={"foo": "bar", "baz": True},
668 extra_json={"baz": None},
669 ),
670 )
671 data = json.loads(request.content.decode("utf-8"))
672 assert data == {"foo": "bar", "baz": None}
673
674 def test_request_extra_headers(self, client: OpenAI) -> None:
675 request = client._build_request(
676 FinalRequestOptions(
677 method="post",
678 url="/foo",
679 **make_request_options(extra_headers={"X-Foo": "Foo"}),
680 ),
681 )
682 assert request.headers.get("X-Foo") == "Foo"
683
684 # `extra_headers` takes priority over `default_headers` when keys clash
685 request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
686 FinalRequestOptions(
687 method="post",
688 url="/foo",
689 **make_request_options(
690 extra_headers={"X-Bar": "false"},
691 ),
692 ),
693 )
694 assert request.headers.get("X-Bar") == "false"
695
696 def test_request_extra_query(self, client: OpenAI) -> None:
697 request = client._build_request(
698 FinalRequestOptions(
699 method="post",
700 url="/foo",
701 **make_request_options(
702 extra_query={"my_query_param": "Foo"},
703 ),
704 ),
705 )
706 params = dict(request.url.params)
707 assert params == {"my_query_param": "Foo"}
708
709 # if both `query` and `extra_query` are given, they are merged
710 request = client._build_request(
711 FinalRequestOptions(
712 method="post",
713 url="/foo",
714 **make_request_options(
715 query={"bar": "1"},
716 extra_query={"foo": "2"},
717 ),
718 ),
719 )
720 params = dict(request.url.params)
721 assert params == {"bar": "1", "foo": "2"}
722
723 # `extra_query` takes priority over `query` when keys clash
724 request = client._build_request(
725 FinalRequestOptions(
726 method="post",
727 url="/foo",
728 **make_request_options(
729 query={"foo": "1"},
730 extra_query={"foo": "2"},
731 ),
732 ),
733 )
734 params = dict(request.url.params)
735 assert params == {"foo": "2"}
736
737 def test_multipart_repeating_array(self, client: OpenAI) -> None:
738 request = client._build_request(
739 FinalRequestOptions.construct(
740 method="post",
741 url="/foo",
742 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
743 json_data={"array": ["foo", "bar"]},
744 files=[("foo.txt", b"hello world")],
745 )
746 )
747
748 assert request.read().split(b"\r\n") == [
749 b"--6b7ba517decee4a450543ea6ae821c82",
750 b'Content-Disposition: form-data; name="array[]"',
751 b"",
752 b"foo",
753 b"--6b7ba517decee4a450543ea6ae821c82",
754 b'Content-Disposition: form-data; name="array[]"',
755 b"",
756 b"bar",
757 b"--6b7ba517decee4a450543ea6ae821c82",
758 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
759 b"Content-Type: application/octet-stream",
760 b"",
761 b"hello world",
762 b"--6b7ba517decee4a450543ea6ae821c82--",
763 b"",
764 ]
765
766 @pytest.mark.respx(base_url=base_url)
767 def test_binary_content_upload(self, respx_mock: MockRouter, client: OpenAI) -> None:
768 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
769
770 file_content = b"Hello, this is a test file."
771
772 response = client.post(
773 "/upload",
774 content=file_content,
775 cast_to=httpx.Response,
776 options={"headers": {"Content-Type": "application/octet-stream"}},
777 )
778
779 assert response.status_code == 200
780 assert response.request.headers["Content-Type"] == "application/octet-stream"
781 assert response.content == file_content
782
783 def test_binary_content_upload_with_iterator(self) -> None:
784 file_content = b"Hello, this is a test file."
785 counter = Counter()
786 iterator = _make_sync_iterator([file_content], counter=counter)
787
788 def mock_handler(request: httpx.Request) -> httpx.Response:
789 assert counter.value == 0, "the request body should not have been read"
790 return httpx.Response(200, content=request.read())
791
792 with OpenAI(
793 base_url=base_url,
794 api_key=api_key,
795 admin_api_key=admin_api_key,
796 _strict_response_validation=True,
797 http_client=httpx.Client(transport=MockTransport(handler=mock_handler)),
798 ) as client:
799 response = client.post(
800 "/upload",
801 content=iterator,
802 cast_to=httpx.Response,
803 options={"headers": {"Content-Type": "application/octet-stream"}},
804 )
805
806 assert response.status_code == 200
807 assert response.request.headers["Content-Type"] == "application/octet-stream"
808 assert response.content == file_content
809 assert counter.value == 1
810
811 @pytest.mark.respx(base_url=base_url)
812 def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: OpenAI) -> None:
813 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
814
815 file_content = b"Hello, this is a test file."
816
817 with pytest.deprecated_call(
818 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
819 ):
820 response = client.post(
821 "/upload",
822 body=file_content,
823 cast_to=httpx.Response,
824 options={"headers": {"Content-Type": "application/octet-stream"}},
825 )
826
827 assert response.status_code == 200
828 assert response.request.headers["Content-Type"] == "application/octet-stream"
829 assert response.content == file_content
830
831 @pytest.mark.respx(base_url=base_url)
832 def test_basic_union_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
833 class Model1(BaseModel):
834 name: str
835
836 class Model2(BaseModel):
837 foo: str
838
839 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
840
841 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
842 assert isinstance(response, Model2)
843 assert response.foo == "bar"
844
845 @pytest.mark.respx(base_url=base_url)
846 def test_union_response_different_types(self, respx_mock: MockRouter, client: OpenAI) -> None:
847 """Union of objects with the same field name using a different type"""
848
849 class Model1(BaseModel):
850 foo: int
851
852 class Model2(BaseModel):
853 foo: str
854
855 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
856
857 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
858 assert isinstance(response, Model2)
859 assert response.foo == "bar"
860
861 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
862
863 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
864 assert isinstance(response, Model1)
865 assert response.foo == 1
866
867 @pytest.mark.respx(base_url=base_url)
868 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: OpenAI) -> None:
869 """
870 Response that sets Content-Type to something other than application/json but returns json data
871 """
872
873 class Model(BaseModel):
874 foo: int
875
876 respx_mock.get("/foo").mock(
877 return_value=httpx.Response(
878 200,
879 content=json.dumps({"foo": 2}),
880 headers={"Content-Type": "application/text"},
881 )
882 )
883
884 response = client.get("/foo", cast_to=Model)
885 assert isinstance(response, Model)
886 assert response.foo == 2
887
888 def test_base_url_setter(self) -> None:
889 client = OpenAI(
890 base_url="https://example.com/from_init",
891 api_key=api_key,
892 admin_api_key=admin_api_key,
893 _strict_response_validation=True,
894 )
895 assert client.base_url == "https://example.com/from_init/"
896
897 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
898
899 assert client.base_url == "https://example.com/from_setter/"
900
901 client.close()
902
903 def test_base_url_env(self) -> None:
904 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
905 client = OpenAI(api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True)
906 assert client.base_url == "http://localhost:5000/from/env/"
907
908 @pytest.mark.parametrize(
909 "client",
910 [
911 OpenAI(
912 base_url="http://localhost:5000/custom/path/",
913 api_key=api_key,
914 admin_api_key=admin_api_key,
915 _strict_response_validation=True,
916 ),
917 OpenAI(
918 base_url="http://localhost:5000/custom/path/",
919 api_key=api_key,
920 admin_api_key=admin_api_key,
921 _strict_response_validation=True,
922 http_client=httpx.Client(),
923 ),
924 ],
925 ids=["standard", "custom http client"],
926 )
927 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
928 request = client._build_request(
929 FinalRequestOptions(
930 method="post",
931 url="/foo",
932 json_data={"foo": "bar"},
933 ),
934 )
935 assert request.url == "http://localhost:5000/custom/path/foo"
936 client.close()
937
938 @pytest.mark.parametrize(
939 "client",
940 [
941 OpenAI(
942 base_url="http://localhost:5000/custom/path/",
943 api_key=api_key,
944 admin_api_key=admin_api_key,
945 _strict_response_validation=True,
946 ),
947 OpenAI(
948 base_url="http://localhost:5000/custom/path/",
949 api_key=api_key,
950 admin_api_key=admin_api_key,
951 _strict_response_validation=True,
952 http_client=httpx.Client(),
953 ),
954 ],
955 ids=["standard", "custom http client"],
956 )
957 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
958 request = client._build_request(
959 FinalRequestOptions(
960 method="post",
961 url="/foo",
962 json_data={"foo": "bar"},
963 ),
964 )
965 assert request.url == "http://localhost:5000/custom/path/foo"
966 client.close()
967
968 @pytest.mark.parametrize(
969 "client",
970 [
971 OpenAI(
972 base_url="http://localhost:5000/custom/path/",
973 api_key=api_key,
974 admin_api_key=admin_api_key,
975 _strict_response_validation=True,
976 ),
977 OpenAI(
978 base_url="http://localhost:5000/custom/path/",
979 api_key=api_key,
980 admin_api_key=admin_api_key,
981 _strict_response_validation=True,
982 http_client=httpx.Client(),
983 ),
984 ],
985 ids=["standard", "custom http client"],
986 )
987 def test_absolute_request_url(self, client: OpenAI) -> None:
988 request = client._build_request(
989 FinalRequestOptions(
990 method="post",
991 url="https://myapi.com/foo",
992 json_data={"foo": "bar"},
993 ),
994 )
995 assert request.url == "https://myapi.com/foo"
996 client.close()
997
998 def test_copied_client_does_not_close_http(self) -> None:
999 test_client = OpenAI(
1000 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
1001 )
1002 assert not test_client.is_closed()
1003
1004 copied = test_client.copy()
1005 assert copied is not test_client
1006
1007 del copied
1008
1009 assert not test_client.is_closed()
1010
1011 def test_client_context_manager(self) -> None:
1012 test_client = OpenAI(
1013 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
1014 )
1015 with test_client as c2:
1016 assert c2 is test_client
1017 assert not c2.is_closed()
1018 assert not test_client.is_closed()
1019 assert test_client.is_closed()
1020
1021 @pytest.mark.respx(base_url=base_url)
1022 def test_client_response_validation_error(self, respx_mock: MockRouter, client: OpenAI) -> None:
1023 class Model(BaseModel):
1024 foo: str
1025
1026 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1027
1028 with pytest.raises(APIResponseValidationError) as exc:
1029 client.get("/foo", cast_to=Model)
1030
1031 assert isinstance(exc.value.__cause__, ValidationError)
1032
1033 def test_client_max_retries_validation(self) -> None:
1034 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1035 OpenAI(
1036 base_url=base_url,
1037 api_key=api_key,
1038 admin_api_key=admin_api_key,
1039 _strict_response_validation=True,
1040 max_retries=cast(Any, None),
1041 )
1042
1043 @pytest.mark.respx(base_url=base_url)
1044 def test_default_stream_cls(self, respx_mock: MockRouter, client: OpenAI) -> None:
1045 class Model(BaseModel):
1046 name: str
1047
1048 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1049
1050 stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
1051 assert isinstance(stream, Stream)
1052 stream.response.close()
1053
1054 @pytest.mark.respx(base_url=base_url)
1055 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1056 class Model(BaseModel):
1057 name: str
1058
1059 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1060
1061 strict_client = OpenAI(
1062 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
1063 )
1064
1065 with pytest.raises(APIResponseValidationError):
1066 strict_client.get("/foo", cast_to=Model)
1067
1068 non_strict_client = OpenAI(
1069 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=False
1070 )
1071
1072 response = non_strict_client.get("/foo", cast_to=Model)
1073 assert isinstance(response, str) # type: ignore[unreachable]
1074
1075 strict_client.close()
1076 non_strict_client.close()
1077
1078 @pytest.mark.parametrize(
1079 "remaining_retries,retry_after,timeout",
1080 [
1081 [3, "20", 20],
1082 [3, "0", 0.5],
1083 [3, "-10", 0.5],
1084 [3, "60", 60],
1085 [3, "61", 0.5],
1086 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1087 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1088 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1089 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1090 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1091 [3, "99999999999999999999999999999999999", 0.5],
1092 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1093 [3, "", 0.5],
1094 [2, "", 0.5 * 2.0],
1095 [1, "", 0.5 * 4.0],
1096 [-1100, "", 8], # test large number potentially overflowing
1097 ],
1098 )
1099 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1100 def test_parse_retry_after_header(
1101 self, remaining_retries: int, retry_after: str, timeout: float, client: OpenAI
1102 ) -> None:
1103 headers = httpx.Headers({"retry-after": retry_after})
1104 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1105 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
1106 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1107
1108 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1109 @pytest.mark.respx(base_url=base_url)
1110 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
1111 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1112
1113 with pytest.raises(APITimeoutError):
1114 client.chat.completions.with_streaming_response.create(
1115 messages=[
1116 {
1117 "content": "string",
1118 "role": "developer",
1119 }
1120 ],
1121 model="gpt-5.4",
1122 ).__enter__()
1123
1124 assert _get_open_connections(client) == 0
1125
1126 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1127 @pytest.mark.respx(base_url=base_url)
1128 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
1129 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1130
1131 with pytest.raises(APIStatusError):
1132 client.chat.completions.with_streaming_response.create(
1133 messages=[
1134 {
1135 "content": "string",
1136 "role": "developer",
1137 }
1138 ],
1139 model="gpt-5.4",
1140 ).__enter__()
1141 assert _get_open_connections(client) == 0
1142
1143 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1144 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1145 @pytest.mark.respx(base_url=base_url)
1146 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
1147 def test_retries_taken(
1148 self,
1149 client: OpenAI,
1150 failures_before_success: int,
1151 failure_mode: Literal["status", "exception"],
1152 respx_mock: MockRouter,
1153 ) -> None:
1154 client = client.with_options(max_retries=4)
1155
1156 nb_retries = 0
1157
1158 def retry_handler(_request: httpx.Request) -> httpx.Response:
1159 nonlocal nb_retries
1160 if nb_retries < failures_before_success:
1161 nb_retries += 1
1162 if failure_mode == "exception":
1163 raise RuntimeError("oops")
1164 return httpx.Response(500)
1165 return httpx.Response(200)
1166
1167 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1168
1169 response = client.chat.completions.with_raw_response.create(
1170 messages=[
1171 {
1172 "content": "string",
1173 "role": "developer",
1174 }
1175 ],
1176 model="gpt-5.4",
1177 )
1178
1179 assert response.retries_taken == failures_before_success
1180 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1181
1182 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1183 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1184 @pytest.mark.respx(base_url=base_url)
1185 def test_omit_retry_count_header(
1186 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
1187 ) -> None:
1188 client = client.with_options(max_retries=4)
1189
1190 nb_retries = 0
1191
1192 def retry_handler(_request: httpx.Request) -> httpx.Response:
1193 nonlocal nb_retries
1194 if nb_retries < failures_before_success:
1195 nb_retries += 1
1196 return httpx.Response(500)
1197 return httpx.Response(200)
1198
1199 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1200
1201 response = client.chat.completions.with_raw_response.create(
1202 messages=[
1203 {
1204 "content": "string",
1205 "role": "developer",
1206 }
1207 ],
1208 model="gpt-5.4",
1209 extra_headers={"x-stainless-retry-count": Omit()},
1210 )
1211
1212 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
1213
1214 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1215 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1216 @pytest.mark.respx(base_url=base_url)
1217 def test_overwrite_retry_count_header(
1218 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
1219 ) -> None:
1220 client = client.with_options(max_retries=4)
1221
1222 nb_retries = 0
1223
1224 def retry_handler(_request: httpx.Request) -> httpx.Response:
1225 nonlocal nb_retries
1226 if nb_retries < failures_before_success:
1227 nb_retries += 1
1228 return httpx.Response(500)
1229 return httpx.Response(200)
1230
1231 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1232
1233 response = client.chat.completions.with_raw_response.create(
1234 messages=[
1235 {
1236 "content": "string",
1237 "role": "developer",
1238 }
1239 ],
1240 model="gpt-5.4",
1241 extra_headers={"x-stainless-retry-count": "42"},
1242 )
1243
1244 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1245
1246 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1247 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1248 @pytest.mark.respx(base_url=base_url)
1249 def test_retries_taken_new_response_class(
1250 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
1251 ) -> None:
1252 client = client.with_options(max_retries=4)
1253
1254 nb_retries = 0
1255
1256 def retry_handler(_request: httpx.Request) -> httpx.Response:
1257 nonlocal nb_retries
1258 if nb_retries < failures_before_success:
1259 nb_retries += 1
1260 return httpx.Response(500)
1261 return httpx.Response(200)
1262
1263 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1264
1265 with client.chat.completions.with_streaming_response.create(
1266 messages=[
1267 {
1268 "content": "string",
1269 "role": "developer",
1270 }
1271 ],
1272 model="gpt-5.4",
1273 ) as response:
1274 assert response.retries_taken == failures_before_success
1275 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1276
1277 def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
1278 # Test that the proxy environment variables are set correctly
1279 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
1280 # Delete in case our environment has any proxy env vars set
1281 monkeypatch.delenv("HTTP_PROXY", raising=False)
1282 monkeypatch.delenv("ALL_PROXY", raising=False)
1283 monkeypatch.delenv("NO_PROXY", raising=False)
1284 monkeypatch.delenv("http_proxy", raising=False)
1285 monkeypatch.delenv("https_proxy", raising=False)
1286 monkeypatch.delenv("all_proxy", raising=False)
1287 monkeypatch.delenv("no_proxy", raising=False)
1288
1289 client = DefaultHttpxClient()
1290
1291 mounts = tuple(client._mounts.items())
1292 assert len(mounts) == 1
1293 assert mounts[0][0].pattern == "https://"
1294
1295 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
1296 def test_default_client_creation(self) -> None:
1297 # Ensure that the client can be initialized without any exceptions
1298 DefaultHttpxClient(
1299 verify=True,
1300 cert=None,
1301 trust_env=True,
1302 http1=True,
1303 http2=False,
1304 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
1305 )
1306
1307 @pytest.mark.respx(base_url=base_url)
1308 def test_follow_redirects(self, respx_mock: MockRouter, client: OpenAI) -> None:
1309 # Test that the default follow_redirects=True allows following redirects
1310 respx_mock.post("/redirect").mock(
1311 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1312 )
1313 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
1314
1315 response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
1316 assert response.status_code == 200
1317 assert response.json() == {"status": "ok"}
1318
1319 @pytest.mark.respx(base_url=base_url)
1320 def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: OpenAI) -> None:
1321 # Test that follow_redirects=False prevents following redirects
1322 respx_mock.post("/redirect").mock(
1323 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1324 )
1325
1326 with pytest.raises(APIStatusError) as exc_info:
1327 client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)
1328
1329 assert exc_info.value.response.status_code == 302
1330 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
1331
1332 def test_api_key_before_after_refresh_provider(self) -> None:
1333 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
1334
1335 assert client.api_key == ""
1336 assert "Authorization" not in client.auth_headers
1337
1338 client._refresh_api_key()
1339
1340 assert client.api_key == "test_bearer_token"
1341 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
1342
1343 def test_api_key_before_after_refresh_str(self) -> None:
1344 client = OpenAI(base_url=base_url, api_key="test_api_key")
1345
1346 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
1347 client._refresh_api_key()
1348
1349 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
1350
1351 @pytest.mark.respx()
1352 def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
1353 respx_mock.post(base_url + "/chat/completions").mock(
1354 side_effect=[
1355 httpx.Response(500, json={"error": "server error"}),
1356 httpx.Response(200, json={"foo": "bar"}),
1357 ]
1358 )
1359
1360 counter = 0
1361
1362 def token_provider() -> str:
1363 nonlocal counter
1364
1365 counter += 1
1366
1367 if counter == 1:
1368 return "first"
1369
1370 return "second"
1371
1372 client = OpenAI(base_url=base_url, api_key=token_provider)
1373 client.chat.completions.create(messages=[], model="gpt-4")
1374
1375 calls = cast("list[MockRequestCall]", respx_mock.calls)
1376 assert len(calls) == 2
1377
1378 assert calls[0].request.headers.get("Authorization") == "Bearer first"
1379 assert calls[1].request.headers.get("Authorization") == "Bearer second"
1380
1381 def test_copy_auth(self) -> None:
1382 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
1383 api_key=lambda: "test_bearer_token_2"
1384 )
1385 client._refresh_api_key()
1386 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
1387
1388
1389class TestAsyncOpenAI:
1390 @pytest.mark.respx(base_url=base_url)
1391 async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1392 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1393
1394 response = await async_client.post("/foo", cast_to=httpx.Response)
1395 assert response.status_code == 200
1396 assert isinstance(response, httpx.Response)
1397 assert response.json() == {"foo": "bar"}
1398
1399 @pytest.mark.respx(base_url=base_url)
1400 async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1401 respx_mock.post("/foo").mock(
1402 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
1403 )
1404
1405 response = await async_client.post("/foo", cast_to=httpx.Response)
1406 assert response.status_code == 200
1407 assert isinstance(response, httpx.Response)
1408 assert response.json() == {"foo": "bar"}
1409
1410 def test_copy(self, async_client: AsyncOpenAI) -> None:
1411 copied = async_client.copy()
1412 assert id(copied) != id(async_client)
1413
1414 copied = async_client.copy(api_key="another My API Key")
1415 assert copied.api_key == "another My API Key"
1416 assert async_client.api_key == "My API Key"
1417
1418 copied = async_client.copy(admin_api_key="another My Admin API Key")
1419 assert copied.admin_api_key == "another My Admin API Key"
1420 assert async_client.admin_api_key == "My Admin API Key"
1421
1422 def test_copy_default_options(self, async_client: AsyncOpenAI) -> None:
1423 # options that have a default are overridden correctly
1424 copied = async_client.copy(max_retries=7)
1425 assert copied.max_retries == 7
1426 assert async_client.max_retries == 2
1427
1428 copied2 = copied.copy(max_retries=6)
1429 assert copied2.max_retries == 6
1430 assert copied.max_retries == 7
1431
1432 # timeout
1433 assert isinstance(async_client.timeout, httpx.Timeout)
1434 copied = async_client.copy(timeout=None)
1435 assert copied.timeout is None
1436 assert isinstance(async_client.timeout, httpx.Timeout)
1437
1438 async def test_copy_default_headers(self) -> None:
1439 client = AsyncOpenAI(
1440 base_url=base_url,
1441 api_key=api_key,
1442 admin_api_key=admin_api_key,
1443 _strict_response_validation=True,
1444 default_headers={"X-Foo": "bar"},
1445 )
1446 assert client.default_headers["X-Foo"] == "bar"
1447
1448 # does not override the already given value when not specified
1449 copied = client.copy()
1450 assert copied.default_headers["X-Foo"] == "bar"
1451
1452 # merges already given headers
1453 copied = client.copy(default_headers={"X-Bar": "stainless"})
1454 assert copied.default_headers["X-Foo"] == "bar"
1455 assert copied.default_headers["X-Bar"] == "stainless"
1456
1457 # uses new values for any already given headers
1458 copied = client.copy(default_headers={"X-Foo": "stainless"})
1459 assert copied.default_headers["X-Foo"] == "stainless"
1460
1461 # set_default_headers
1462
1463 # completely overrides already set values
1464 copied = client.copy(set_default_headers={})
1465 assert copied.default_headers.get("X-Foo") is None
1466
1467 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
1468 assert copied.default_headers["X-Bar"] == "Robert"
1469
1470 with pytest.raises(
1471 ValueError,
1472 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
1473 ):
1474 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1475 await client.close()
1476
1477 async def test_copy_default_query(self) -> None:
1478 client = AsyncOpenAI(
1479 base_url=base_url,
1480 api_key=api_key,
1481 admin_api_key=admin_api_key,
1482 _strict_response_validation=True,
1483 default_query={"foo": "bar"},
1484 )
1485 assert _get_params(client)["foo"] == "bar"
1486
1487 # does not override the already given value when not specified
1488 copied = client.copy()
1489 assert _get_params(copied)["foo"] == "bar"
1490
1491 # merges already given params
1492 copied = client.copy(default_query={"bar": "stainless"})
1493 params = _get_params(copied)
1494 assert params["foo"] == "bar"
1495 assert params["bar"] == "stainless"
1496
1497 # uses new values for any already given headers
1498 copied = client.copy(default_query={"foo": "stainless"})
1499 assert _get_params(copied)["foo"] == "stainless"
1500
1501 # set_default_query
1502
1503 # completely overrides already set values
1504 copied = client.copy(set_default_query={})
1505 assert _get_params(copied) == {}
1506
1507 copied = client.copy(set_default_query={"bar": "Robert"})
1508 assert _get_params(copied)["bar"] == "Robert"
1509
1510 with pytest.raises(
1511 ValueError,
1512 # TODO: update
1513 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1514 ):
1515 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1516
1517 await client.close()
1518
1519 def test_copy_signature(self, async_client: AsyncOpenAI) -> None:
1520 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1521 init_signature = inspect.signature(
1522 # mypy doesn't like that we access the `__init__` property.
1523 async_client.__init__, # type: ignore[misc]
1524 )
1525 copy_signature = inspect.signature(async_client.copy)
1526 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1527
1528 for name in init_signature.parameters.keys():
1529 if name in exclude_params:
1530 continue
1531
1532 copy_param = copy_signature.parameters.get(name)
1533 assert copy_param is not None, f"copy() signature is missing the {name} param"
1534
    # NOTE(review): the skip condition (`>= (3, 10)`) and its reason text ("started from 3.12")
    # disagree, so the test currently only runs on Python < 3.10. Confirm which version actually
    # introduced the leak before changing either one.
    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self, async_client: AsyncOpenAI) -> None:
        """Check that ``copy()`` followed by ``_build_request`` does not leak memory per call.

        The copy + build cycle runs ``ITERATIONS`` times under ``tracemalloc``. The test fails if any
        allocation traceback in the post-loop snapshot has a block count (``StatisticDiff.count``)
        that is a non-zero multiple of ``ITERATIONS``. Tracebacks that touch the listed
        known/accepted leak sources are excluded.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            # One unit of work: make a fresh copy of the client and build a request with it.
            client_copy = async_client.copy()
            client_copy._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        # Keep up to 1000 frames per allocation so tracebacks reach the allowlisted files below.
        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        # Collect garbage first so that only allocations that are still referenced show up.
        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            # Append `diff` to `leaks` unless it is filtered out as a false positive.
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        # Group by full traceback so each distinct allocation site is evaluated separately.
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            # Print every suspected leak with its traceback before failing, to aid debugging.
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()
1597
1598 async def test_request_timeout(self, async_client: AsyncOpenAI) -> None:
1599 request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
1600 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1601 assert timeout == DEFAULT_TIMEOUT
1602
1603 request = async_client._build_request(
1604 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1605 )
1606 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1607 assert timeout == httpx.Timeout(100.0)
1608
1609 async def test_client_timeout_option(self) -> None:
1610 client = AsyncOpenAI(
1611 base_url=base_url,
1612 api_key=api_key,
1613 admin_api_key=admin_api_key,
1614 _strict_response_validation=True,
1615 timeout=httpx.Timeout(0),
1616 )
1617
1618 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1619 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1620 assert timeout == httpx.Timeout(0)
1621
1622 await client.close()
1623
1624 async def test_http_client_timeout_option(self) -> None:
1625 # custom timeout given to the httpx client should be used
1626 async with httpx.AsyncClient(timeout=None) as http_client:
1627 client = AsyncOpenAI(
1628 base_url=base_url,
1629 api_key=api_key,
1630 admin_api_key=admin_api_key,
1631 _strict_response_validation=True,
1632 http_client=http_client,
1633 )
1634
1635 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1636 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1637 assert timeout == httpx.Timeout(None)
1638
1639 await client.close()
1640
1641 # no timeout given to the httpx client should not use the httpx default
1642 async with httpx.AsyncClient() as http_client:
1643 client = AsyncOpenAI(
1644 base_url=base_url,
1645 api_key=api_key,
1646 admin_api_key=admin_api_key,
1647 _strict_response_validation=True,
1648 http_client=http_client,
1649 )
1650
1651 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1652 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1653 assert timeout == DEFAULT_TIMEOUT
1654
1655 await client.close()
1656
1657 # explicitly passing the default timeout currently results in it being ignored
1658 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1659 client = AsyncOpenAI(
1660 base_url=base_url,
1661 api_key=api_key,
1662 admin_api_key=admin_api_key,
1663 _strict_response_validation=True,
1664 http_client=http_client,
1665 )
1666
1667 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1668 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1669 assert timeout == DEFAULT_TIMEOUT # our default
1670
1671 await client.close()
1672
1673 def test_invalid_http_client(self) -> None:
1674 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1675 with httpx.Client() as http_client:
1676 AsyncOpenAI(
1677 base_url=base_url,
1678 api_key=api_key,
1679 admin_api_key=admin_api_key,
1680 _strict_response_validation=True,
1681 http_client=cast(Any, http_client),
1682 )
1683
1684 async def test_default_headers_option(self) -> None:
1685 test_client = AsyncOpenAI(
1686 base_url=base_url,
1687 api_key=api_key,
1688 admin_api_key=admin_api_key,
1689 _strict_response_validation=True,
1690 default_headers={"X-Foo": "bar"},
1691 )
1692 request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
1693 assert request.headers.get("x-foo") == "bar"
1694 assert request.headers.get("x-stainless-lang") == "python"
1695
1696 test_client2 = AsyncOpenAI(
1697 base_url=base_url,
1698 api_key=api_key,
1699 admin_api_key=admin_api_key,
1700 _strict_response_validation=True,
1701 default_headers={
1702 "X-Foo": "stainless",
1703 "X-Stainless-Lang": "my-overriding-header",
1704 },
1705 )
1706 request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1707 assert request.headers.get("x-foo") == "stainless"
1708 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1709
1710 await test_client.close()
1711 await test_client2.close()
1712
1713 async def test_validate_headers(self) -> None:
1714 client = AsyncOpenAI(
1715 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
1716 )
1717 options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
1718 request = client._build_request(options)
1719 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1720
1721 admin_request = client._build_request(
1722 FinalRequestOptions(
1723 method="get",
1724 url="/organization/projects",
1725 security={"admin_api_key_auth": True},
1726 )
1727 )
1728 assert admin_request.headers.get("Authorization") == f"Bearer {admin_api_key}"
1729
1730 with update_env(**{"OPENAI_API_KEY": Omit()}):
1731 admin_only = AsyncOpenAI(
1732 base_url=base_url,
1733 api_key=None,
1734 admin_api_key=admin_api_key,
1735 _strict_response_validation=True,
1736 )
1737 admin_only_request = admin_only._build_request(
1738 FinalRequestOptions(
1739 method="get",
1740 url="/organization/projects",
1741 security={"admin_api_key_auth": True},
1742 )
1743 )
1744 assert admin_only_request.headers.get("Authorization") == f"Bearer {admin_api_key}"
1745
1746 with pytest.raises(
1747 TypeError,
1748 match="Could not resolve authentication method",
1749 ):
1750 admin_only._build_request(
1751 FinalRequestOptions(
1752 method="post",
1753 url="/responses",
1754 security={"bearer_auth": True},
1755 )
1756 )
1757
1758 with update_env(
1759 **{
1760 "OPENAI_API_KEY": Omit(),
1761 "OPENAI_ADMIN_KEY": Omit(),
1762 }
1763 ):
1764 no_credentials = AsyncOpenAI(
1765 base_url=base_url,
1766 api_key=None,
1767 admin_api_key=None,
1768 _enforce_credentials=False,
1769 _strict_response_validation=True,
1770 )
1771 lowercase_auth_request = no_credentials._build_request(
1772 FinalRequestOptions(method="get", url="/foo", headers={"authorization": "Bearer custom"})
1773 )
1774 assert lowercase_auth_request.headers.get("Authorization") == "Bearer custom"
1775
1776 omitted_auth_request = no_credentials._build_request(
1777 FinalRequestOptions(method="get", url="/foo", headers={"authorization": Omit()})
1778 )
1779 assert "Authorization" not in omitted_auth_request.headers
1780
1781 with update_env(
1782 **{
1783 "OPENAI_API_KEY": Omit(),
1784 "OPENAI_ADMIN_KEY": Omit(),
1785 }
1786 ):
1787 with pytest.raises(OpenAIError, match="Missing credentials"):
1788 AsyncOpenAI(base_url=base_url, api_key=None, admin_api_key=None, _strict_response_validation=True)
1789
1790 @pytest.mark.respx(base_url=base_url)
1791 async def test_api_key_provider_preserves_admin_auth(self, respx_mock: MockRouter) -> None:
1792 respx_mock.get("/organization/projects").mock(return_value=httpx.Response(200, json={"ok": True}))
1793
1794 provider_called = False
1795
1796 async def api_key_provider() -> str:
1797 nonlocal provider_called
1798 provider_called = True
1799 return "dynamic-api-key"
1800
1801 client = AsyncOpenAI(base_url=base_url, api_key=api_key_provider, admin_api_key=admin_api_key)
1802 response = await client.get(
1803 "/organization/projects",
1804 cast_to=httpx.Response,
1805 options={"security": {"admin_api_key_auth": True}},
1806 )
1807
1808 assert response.request.headers.get("Authorization") == f"Bearer {admin_api_key}"
1809 assert provider_called is False
1810
1811 async def test_api_key_provider_does_not_fill_admin_auth(self) -> None:
1812 provider_called = False
1813
1814 async def api_key_provider() -> str:
1815 nonlocal provider_called
1816 provider_called = True
1817 return "dynamic-api-key"
1818
1819 with update_env(OPENAI_ADMIN_KEY=Omit()):
1820 client = AsyncOpenAI(base_url=base_url, api_key=api_key_provider, admin_api_key=None)
1821 with pytest.raises(TypeError, match="Could not resolve authentication method"):
1822 await client.get(
1823 "/organization/projects",
1824 cast_to=httpx.Response,
1825 options={"security": {"admin_api_key_auth": True}},
1826 )
1827
1828 assert provider_called is False
1829
1830 @pytest.mark.respx(base_url=base_url)
1831 async def test_workload_identity_preserves_admin_auth(self, respx_mock: MockRouter) -> None:
1832 respx_mock.get("/organization/projects").mock(return_value=httpx.Response(200, json={"ok": True}))
1833
1834 client = AsyncOpenAI(base_url=base_url, workload_identity=workload_identity, admin_api_key=admin_api_key)
1835 response = await client.get(
1836 "/organization/projects",
1837 cast_to=httpx.Response,
1838 options={"security": {"admin_api_key_auth": True}},
1839 )
1840
1841 assert response.request.headers.get("Authorization") == f"Bearer {admin_api_key}"
1842
1843 async def test_default_query_option(self) -> None:
1844 client = AsyncOpenAI(
1845 base_url=base_url,
1846 api_key=api_key,
1847 admin_api_key=admin_api_key,
1848 _strict_response_validation=True,
1849 default_query={"query_param": "bar"},
1850 )
1851 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1852 url = httpx.URL(request.url)
1853 assert dict(url.params) == {"query_param": "bar"}
1854
1855 request = client._build_request(
1856 FinalRequestOptions(
1857 method="get",
1858 url="/foo",
1859 params={"foo": "baz", "query_param": "overridden"},
1860 )
1861 )
1862 url = httpx.URL(request.url)
1863 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1864
1865 await client.close()
1866
1867 async def test_hardcoded_query_params_in_url(self, async_client: AsyncOpenAI) -> None:
1868 request = async_client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true"))
1869 url = httpx.URL(request.url)
1870 assert dict(url.params) == {"beta": "true"}
1871
1872 request = async_client._build_request(
1873 FinalRequestOptions(
1874 method="get",
1875 url="/foo?beta=true",
1876 params={"limit": "10", "page": "abc"},
1877 )
1878 )
1879 url = httpx.URL(request.url)
1880 assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"}
1881
1882 request = async_client._build_request(
1883 FinalRequestOptions(
1884 method="get",
1885 url="/files/a%2Fb?beta=true",
1886 params={"limit": "10"},
1887 )
1888 )
1889 assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10"
1890
1891 def test_request_extra_json(self, client: OpenAI) -> None:
1892 request = client._build_request(
1893 FinalRequestOptions(
1894 method="post",
1895 url="/foo",
1896 json_data={"foo": "bar"},
1897 extra_json={"baz": False},
1898 ),
1899 )
1900 data = json.loads(request.content.decode("utf-8"))
1901 assert data == {"foo": "bar", "baz": False}
1902
1903 request = client._build_request(
1904 FinalRequestOptions(
1905 method="post",
1906 url="/foo",
1907 extra_json={"baz": False},
1908 ),
1909 )
1910 data = json.loads(request.content.decode("utf-8"))
1911 assert data == {"baz": False}
1912
1913 # `extra_json` takes priority over `json_data` when keys clash
1914 request = client._build_request(
1915 FinalRequestOptions(
1916 method="post",
1917 url="/foo",
1918 json_data={"foo": "bar", "baz": True},
1919 extra_json={"baz": None},
1920 ),
1921 )
1922 data = json.loads(request.content.decode("utf-8"))
1923 assert data == {"foo": "bar", "baz": None}
1924
1925 def test_request_extra_headers(self, client: OpenAI) -> None:
1926 request = client._build_request(
1927 FinalRequestOptions(
1928 method="post",
1929 url="/foo",
1930 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1931 ),
1932 )
1933 assert request.headers.get("X-Foo") == "Foo"
1934
1935 # `extra_headers` takes priority over `default_headers` when keys clash
1936 request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
1937 FinalRequestOptions(
1938 method="post",
1939 url="/foo",
1940 **make_request_options(
1941 extra_headers={"X-Bar": "false"},
1942 ),
1943 ),
1944 )
1945 assert request.headers.get("X-Bar") == "false"
1946
1947 def test_request_extra_query(self, client: OpenAI) -> None:
1948 request = client._build_request(
1949 FinalRequestOptions(
1950 method="post",
1951 url="/foo",
1952 **make_request_options(
1953 extra_query={"my_query_param": "Foo"},
1954 ),
1955 ),
1956 )
1957 params = dict(request.url.params)
1958 assert params == {"my_query_param": "Foo"}
1959
1960 # if both `query` and `extra_query` are given, they are merged
1961 request = client._build_request(
1962 FinalRequestOptions(
1963 method="post",
1964 url="/foo",
1965 **make_request_options(
1966 query={"bar": "1"},
1967 extra_query={"foo": "2"},
1968 ),
1969 ),
1970 )
1971 params = dict(request.url.params)
1972 assert params == {"bar": "1", "foo": "2"}
1973
1974 # `extra_query` takes priority over `query` when keys clash
1975 request = client._build_request(
1976 FinalRequestOptions(
1977 method="post",
1978 url="/foo",
1979 **make_request_options(
1980 query={"foo": "1"},
1981 extra_query={"foo": "2"},
1982 ),
1983 ),
1984 )
1985 params = dict(request.url.params)
1986 assert params == {"foo": "2"}
1987
1988 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1989 request = async_client._build_request(
1990 FinalRequestOptions.construct(
1991 method="post",
1992 url="/foo",
1993 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1994 json_data={"array": ["foo", "bar"]},
1995 files=[("foo.txt", b"hello world")],
1996 )
1997 )
1998
1999 assert request.read().split(b"\r\n") == [
2000 b"--6b7ba517decee4a450543ea6ae821c82",
2001 b'Content-Disposition: form-data; name="array[]"',
2002 b"",
2003 b"foo",
2004 b"--6b7ba517decee4a450543ea6ae821c82",
2005 b'Content-Disposition: form-data; name="array[]"',
2006 b"",
2007 b"bar",
2008 b"--6b7ba517decee4a450543ea6ae821c82",
2009 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
2010 b"Content-Type: application/octet-stream",
2011 b"",
2012 b"hello world",
2013 b"--6b7ba517decee4a450543ea6ae821c82--",
2014 b"",
2015 ]
2016
2017 @pytest.mark.respx(base_url=base_url)
2018 async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2019 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
2020
2021 file_content = b"Hello, this is a test file."
2022
2023 response = await async_client.post(
2024 "/upload",
2025 content=file_content,
2026 cast_to=httpx.Response,
2027 options={"headers": {"Content-Type": "application/octet-stream"}},
2028 )
2029
2030 assert response.status_code == 200
2031 assert response.request.headers["Content-Type"] == "application/octet-stream"
2032 assert response.content == file_content
2033
2034 async def test_binary_content_upload_with_asynciterator(self) -> None:
2035 file_content = b"Hello, this is a test file."
2036 counter = Counter()
2037 iterator = _make_async_iterator([file_content], counter=counter)
2038
2039 async def mock_handler(request: httpx.Request) -> httpx.Response:
2040 assert counter.value == 0, "the request body should not have been read"
2041 return httpx.Response(200, content=await request.aread())
2042
2043 async with AsyncOpenAI(
2044 base_url=base_url,
2045 api_key=api_key,
2046 admin_api_key=admin_api_key,
2047 _strict_response_validation=True,
2048 http_client=httpx.AsyncClient(transport=MockTransport(handler=mock_handler)),
2049 ) as client:
2050 response = await client.post(
2051 "/upload",
2052 content=iterator,
2053 cast_to=httpx.Response,
2054 options={"headers": {"Content-Type": "application/octet-stream"}},
2055 )
2056
2057 assert response.status_code == 200
2058 assert response.request.headers["Content-Type"] == "application/octet-stream"
2059 assert response.content == file_content
2060 assert counter.value == 1
2061
2062 @pytest.mark.respx(base_url=base_url)
2063 async def test_binary_content_upload_with_body_is_deprecated(
2064 self, respx_mock: MockRouter, async_client: AsyncOpenAI
2065 ) -> None:
2066 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
2067
2068 file_content = b"Hello, this is a test file."
2069
2070 with pytest.deprecated_call(
2071 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
2072 ):
2073 response = await async_client.post(
2074 "/upload",
2075 body=file_content,
2076 cast_to=httpx.Response,
2077 options={"headers": {"Content-Type": "application/octet-stream"}},
2078 )
2079
2080 assert response.status_code == 200
2081 assert response.request.headers["Content-Type"] == "application/octet-stream"
2082 assert response.content == file_content
2083
2084 @pytest.mark.respx(base_url=base_url)
2085 async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2086 class Model1(BaseModel):
2087 name: str
2088
2089 class Model2(BaseModel):
2090 foo: str
2091
2092 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
2093
2094 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
2095 assert isinstance(response, Model2)
2096 assert response.foo == "bar"
2097
2098 @pytest.mark.respx(base_url=base_url)
2099 async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2100 """Union of objects with the same field name using a different type"""
2101
2102 class Model1(BaseModel):
2103 foo: int
2104
2105 class Model2(BaseModel):
2106 foo: str
2107
2108 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
2109
2110 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
2111 assert isinstance(response, Model2)
2112 assert response.foo == "bar"
2113
2114 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
2115
2116 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
2117 assert isinstance(response, Model1)
2118 assert response.foo == 1
2119
2120 @pytest.mark.respx(base_url=base_url)
2121 async def test_non_application_json_content_type_for_json_data(
2122 self, respx_mock: MockRouter, async_client: AsyncOpenAI
2123 ) -> None:
2124 """
2125 Response that sets Content-Type to something other than application/json but returns json data
2126 """
2127
2128 class Model(BaseModel):
2129 foo: int
2130
2131 respx_mock.get("/foo").mock(
2132 return_value=httpx.Response(
2133 200,
2134 content=json.dumps({"foo": 2}),
2135 headers={"Content-Type": "application/text"},
2136 )
2137 )
2138
2139 response = await async_client.get("/foo", cast_to=Model)
2140 assert isinstance(response, Model)
2141 assert response.foo == 2
2142
2143 async def test_base_url_setter(self) -> None:
2144 client = AsyncOpenAI(
2145 base_url="https://example.com/from_init",
2146 api_key=api_key,
2147 admin_api_key=admin_api_key,
2148 _strict_response_validation=True,
2149 )
2150 assert client.base_url == "https://example.com/from_init/"
2151
2152 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
2153
2154 assert client.base_url == "https://example.com/from_setter/"
2155
2156 await client.close()
2157
2158 async def test_base_url_env(self) -> None:
2159 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
2160 client = AsyncOpenAI(api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True)
2161 assert client.base_url == "http://localhost:5000/from/env/"
2162
2163 @pytest.mark.parametrize(
2164 "client",
2165 [
2166 AsyncOpenAI(
2167 base_url="http://localhost:5000/custom/path/",
2168 api_key=api_key,
2169 admin_api_key=admin_api_key,
2170 _strict_response_validation=True,
2171 ),
2172 AsyncOpenAI(
2173 base_url="http://localhost:5000/custom/path/",
2174 api_key=api_key,
2175 admin_api_key=admin_api_key,
2176 _strict_response_validation=True,
2177 http_client=httpx.AsyncClient(),
2178 ),
2179 ],
2180 ids=["standard", "custom http client"],
2181 )
2182 async def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
2183 request = client._build_request(
2184 FinalRequestOptions(
2185 method="post",
2186 url="/foo",
2187 json_data={"foo": "bar"},
2188 ),
2189 )
2190 assert request.url == "http://localhost:5000/custom/path/foo"
2191 await client.close()
2192
2193 @pytest.mark.parametrize(
2194 "client",
2195 [
2196 AsyncOpenAI(
2197 base_url="http://localhost:5000/custom/path/",
2198 api_key=api_key,
2199 admin_api_key=admin_api_key,
2200 _strict_response_validation=True,
2201 ),
2202 AsyncOpenAI(
2203 base_url="http://localhost:5000/custom/path/",
2204 api_key=api_key,
2205 admin_api_key=admin_api_key,
2206 _strict_response_validation=True,
2207 http_client=httpx.AsyncClient(),
2208 ),
2209 ],
2210 ids=["standard", "custom http client"],
2211 )
2212 async def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
2213 request = client._build_request(
2214 FinalRequestOptions(
2215 method="post",
2216 url="/foo",
2217 json_data={"foo": "bar"},
2218 ),
2219 )
2220 assert request.url == "http://localhost:5000/custom/path/foo"
2221 await client.close()
2222
2223 @pytest.mark.parametrize(
2224 "client",
2225 [
2226 AsyncOpenAI(
2227 base_url="http://localhost:5000/custom/path/",
2228 api_key=api_key,
2229 admin_api_key=admin_api_key,
2230 _strict_response_validation=True,
2231 ),
2232 AsyncOpenAI(
2233 base_url="http://localhost:5000/custom/path/",
2234 api_key=api_key,
2235 admin_api_key=admin_api_key,
2236 _strict_response_validation=True,
2237 http_client=httpx.AsyncClient(),
2238 ),
2239 ],
2240 ids=["standard", "custom http client"],
2241 )
2242 async def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
2243 request = client._build_request(
2244 FinalRequestOptions(
2245 method="post",
2246 url="https://myapi.com/foo",
2247 json_data={"foo": "bar"},
2248 ),
2249 )
2250 assert request.url == "https://myapi.com/foo"
2251 await client.close()
2252
2253 async def test_copied_client_does_not_close_http(self) -> None:
2254 test_client = AsyncOpenAI(
2255 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
2256 )
2257 assert not test_client.is_closed()
2258
2259 copied = test_client.copy()
2260 assert copied is not test_client
2261
2262 del copied
2263
2264 await asyncio.sleep(0.2)
2265 assert not test_client.is_closed()
2266
2267 async def test_client_context_manager(self) -> None:
2268 test_client = AsyncOpenAI(
2269 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
2270 )
2271 async with test_client as c2:
2272 assert c2 is test_client
2273 assert not c2.is_closed()
2274 assert not test_client.is_closed()
2275 assert test_client.is_closed()
2276
2277 @pytest.mark.respx(base_url=base_url)
2278 async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2279 class Model(BaseModel):
2280 foo: str
2281
2282 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
2283
2284 with pytest.raises(APIResponseValidationError) as exc:
2285 await async_client.get("/foo", cast_to=Model)
2286
2287 assert isinstance(exc.value.__cause__, ValidationError)
2288
2289 async def test_client_max_retries_validation(self) -> None:
2290 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
2291 AsyncOpenAI(
2292 base_url=base_url,
2293 api_key=api_key,
2294 admin_api_key=admin_api_key,
2295 _strict_response_validation=True,
2296 max_retries=cast(Any, None),
2297 )
2298
2299 @pytest.mark.respx(base_url=base_url)
2300 async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2301 class Model(BaseModel):
2302 name: str
2303
2304 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
2305
2306 stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
2307 assert isinstance(stream, AsyncStream)
2308 await stream.response.aclose()
2309
2310 @pytest.mark.respx(base_url=base_url)
2311 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
2312 class Model(BaseModel):
2313 name: str
2314
2315 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
2316
2317 strict_client = AsyncOpenAI(
2318 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
2319 )
2320
2321 with pytest.raises(APIResponseValidationError):
2322 await strict_client.get("/foo", cast_to=Model)
2323
2324 non_strict_client = AsyncOpenAI(
2325 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=False
2326 )
2327
2328 response = await non_strict_client.get("/foo", cast_to=Model)
2329 assert isinstance(response, str) # type: ignore[unreachable]
2330
2331 await strict_client.close()
2332 await non_strict_client.close()
2333
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            # Numeric retry-after values (seconds): used as-is when in (0, 60], else the 0.5 s backoff applies.
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            # HTTP-date values, relative to the patched clock (16:26:37 GMT): same (0, 60] window.
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            # Huge, malformed, or empty values fall back to the backoff.
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            # With no usable header, the backoff doubles for each retry already taken (max_retries=3)...
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            # ...and is capped (8 s here) even for an extreme retry count.
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    # time.time() is pinned to 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT, so the
    # HTTP-date cases above resolve to fixed offsets (e.g. 16:26:57 GMT -> 20 s).
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    async def test_parse_retry_after_header(
        self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncOpenAI
    ) -> None:
        """Check the delay `_calculate_retry_timeout` derives from a `retry-after` header.

        Per the table, a header value (numeric seconds or an HTTP date) giving a
        delay in (0, 60] seconds is used directly. Anything else falls back to an
        exponential backoff: 0.5 s on the first retry, doubling with each retry
        already taken (3 minus ``remaining_retries``), capped at 8 s.
        """
        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
        # Loose relative tolerance (43.75%), presumably to absorb random jitter in the backoff — TODO confirm.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
2363
2364 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2365 @pytest.mark.respx(base_url=base_url)
2366 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2367 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
2368
2369 with pytest.raises(APITimeoutError):
2370 await async_client.chat.completions.with_streaming_response.create(
2371 messages=[
2372 {
2373 "content": "string",
2374 "role": "developer",
2375 }
2376 ],
2377 model="gpt-5.4",
2378 ).__aenter__()
2379
2380 assert _get_open_connections(async_client) == 0
2381
2382 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2383 @pytest.mark.respx(base_url=base_url)
2384 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2385 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
2386
2387 with pytest.raises(APIStatusError):
2388 await async_client.chat.completions.with_streaming_response.create(
2389 messages=[
2390 {
2391 "content": "string",
2392 "role": "developer",
2393 }
2394 ],
2395 model="gpt-5.4",
2396 ).__aenter__()
2397 assert _get_open_connections(async_client) == 0
2398
2399 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2400 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2401 @pytest.mark.respx(base_url=base_url)
2402 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
2403 async def test_retries_taken(
2404 self,
2405 async_client: AsyncOpenAI,
2406 failures_before_success: int,
2407 failure_mode: Literal["status", "exception"],
2408 respx_mock: MockRouter,
2409 ) -> None:
2410 client = async_client.with_options(max_retries=4)
2411
2412 nb_retries = 0
2413
2414 def retry_handler(_request: httpx.Request) -> httpx.Response:
2415 nonlocal nb_retries
2416 if nb_retries < failures_before_success:
2417 nb_retries += 1
2418 if failure_mode == "exception":
2419 raise RuntimeError("oops")
2420 return httpx.Response(500)
2421 return httpx.Response(200)
2422
2423 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2424
2425 response = await client.chat.completions.with_raw_response.create(
2426 messages=[
2427 {
2428 "content": "string",
2429 "role": "developer",
2430 }
2431 ],
2432 model="gpt-5.4",
2433 )
2434
2435 assert response.retries_taken == failures_before_success
2436 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
2437
2438 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2439 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2440 @pytest.mark.respx(base_url=base_url)
2441 async def test_omit_retry_count_header(
2442 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2443 ) -> None:
2444 client = async_client.with_options(max_retries=4)
2445
2446 nb_retries = 0
2447
2448 def retry_handler(_request: httpx.Request) -> httpx.Response:
2449 nonlocal nb_retries
2450 if nb_retries < failures_before_success:
2451 nb_retries += 1
2452 return httpx.Response(500)
2453 return httpx.Response(200)
2454
2455 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2456
2457 response = await client.chat.completions.with_raw_response.create(
2458 messages=[
2459 {
2460 "content": "string",
2461 "role": "developer",
2462 }
2463 ],
2464 model="gpt-5.4",
2465 extra_headers={"x-stainless-retry-count": Omit()},
2466 )
2467
2468 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
2469
2470 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2471 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2472 @pytest.mark.respx(base_url=base_url)
2473 async def test_overwrite_retry_count_header(
2474 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2475 ) -> None:
2476 client = async_client.with_options(max_retries=4)
2477
2478 nb_retries = 0
2479
2480 def retry_handler(_request: httpx.Request) -> httpx.Response:
2481 nonlocal nb_retries
2482 if nb_retries < failures_before_success:
2483 nb_retries += 1
2484 return httpx.Response(500)
2485 return httpx.Response(200)
2486
2487 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2488
2489 response = await client.chat.completions.with_raw_response.create(
2490 messages=[
2491 {
2492 "content": "string",
2493 "role": "developer",
2494 }
2495 ],
2496 model="gpt-5.4",
2497 extra_headers={"x-stainless-retry-count": "42"},
2498 )
2499
2500 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
2501
2502 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2503 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2504 @pytest.mark.respx(base_url=base_url)
2505 async def test_retries_taken_new_response_class(
2506 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2507 ) -> None:
2508 client = async_client.with_options(max_retries=4)
2509
2510 nb_retries = 0
2511
2512 def retry_handler(_request: httpx.Request) -> httpx.Response:
2513 nonlocal nb_retries
2514 if nb_retries < failures_before_success:
2515 nb_retries += 1
2516 return httpx.Response(500)
2517 return httpx.Response(200)
2518
2519 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2520
2521 async with client.chat.completions.with_streaming_response.create(
2522 messages=[
2523 {
2524 "content": "string",
2525 "role": "developer",
2526 }
2527 ],
2528 model="gpt-5.4",
2529 ) as response:
2530 assert response.retries_taken == failures_before_success
2531 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
2532
2533 async def test_get_platform(self) -> None:
2534 platform = await asyncify(get_platform)()
2535 assert isinstance(platform, (str, OtherPlatform))
2536
2537 async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
2538 # Test that the proxy environment variables are set correctly
2539 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
2540 # Delete in case our environment has any proxy env vars set
2541 monkeypatch.delenv("HTTP_PROXY", raising=False)
2542 monkeypatch.delenv("ALL_PROXY", raising=False)
2543 monkeypatch.delenv("NO_PROXY", raising=False)
2544 monkeypatch.delenv("http_proxy", raising=False)
2545 monkeypatch.delenv("https_proxy", raising=False)
2546 monkeypatch.delenv("all_proxy", raising=False)
2547 monkeypatch.delenv("no_proxy", raising=False)
2548
2549 client = DefaultAsyncHttpxClient()
2550
2551 mounts = tuple(client._mounts.items())
2552 assert len(mounts) == 1
2553 assert mounts[0][0].pattern == "https://"
2554
2555 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
2556 async def test_default_client_creation(self) -> None:
2557 # Ensure that the client can be initialized without any exceptions
2558 DefaultAsyncHttpxClient(
2559 verify=True,
2560 cert=None,
2561 trust_env=True,
2562 http1=True,
2563 http2=False,
2564 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
2565 )
2566
2567 @pytest.mark.respx(base_url=base_url)
2568 async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2569 # Test that the default follow_redirects=True allows following redirects
2570 respx_mock.post("/redirect").mock(
2571 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2572 )
2573 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
2574
2575 response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
2576 assert response.status_code == 200
2577 assert response.json() == {"status": "ok"}
2578
2579 @pytest.mark.respx(base_url=base_url)
2580 async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2581 # Test that follow_redirects=False prevents following redirects
2582 respx_mock.post("/redirect").mock(
2583 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2584 )
2585
2586 with pytest.raises(APIStatusError) as exc_info:
2587 await async_client.post(
2588 "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
2589 )
2590
2591 assert exc_info.value.response.status_code == 302
2592 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
2593
2594 async def test_api_key_before_after_refresh_provider(self) -> None:
2595 async def mock_api_key_provider():
2596 return "test_bearer_token"
2597
2598 client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
2599
2600 assert client.api_key == ""
2601 assert "Authorization" not in client.auth_headers
2602
2603 await client._refresh_api_key()
2604
2605 assert client.api_key == "test_bearer_token"
2606 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
2607
2608 async def test_api_key_before_after_refresh_str(self) -> None:
2609 client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")
2610
2611 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2612 await client._refresh_api_key()
2613
2614 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2615
2616 @pytest.mark.respx()
2617 async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
2618 respx_mock.post(base_url + "/chat/completions").mock(
2619 side_effect=[
2620 httpx.Response(500, json={"error": "server error"}),
2621 httpx.Response(200, json={"foo": "bar"}),
2622 ]
2623 )
2624
2625 counter = 0
2626
2627 async def token_provider() -> str:
2628 nonlocal counter
2629
2630 counter += 1
2631
2632 if counter == 1:
2633 return "first"
2634
2635 return "second"
2636
2637 client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
2638 await client.chat.completions.create(messages=[], model="gpt-4")
2639
2640 calls = cast("list[MockRequestCall]", respx_mock.calls)
2641 assert len(calls) == 2
2642
2643 assert calls[0].request.headers.get("Authorization") == "Bearer first"
2644 assert calls[1].request.headers.get("Authorization") == "Bearer second"
2645
2646 async def test_copy_auth(self) -> None:
2647 async def token_provider_1() -> str:
2648 return "test_bearer_token_1"
2649
2650 async def token_provider_2() -> str:
2651 return "test_bearer_token_2"
2652
2653 client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
2654 await client._refresh_api_key()
2655 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
2656
2657
class TestWorkloadIdentity401Retry:
    """Sync `OpenAI` client: how workload-identity auth reacts to 401 and non-401 error responses."""

    @staticmethod
    def _token_exchange_ok(access_token: str) -> httpx.Response:
        """Build a 200 token-exchange response issuing ``access_token`` with a 3600 s lifetime."""
        return httpx.Response(
            200,
            json={
                "access_token": access_token,
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )

    @staticmethod
    def _api_error(status_code: int, message: str) -> httpx.Response:
        """Build an API error response with the given status code and message."""
        return httpx.Response(status_code, json={"error": {"message": message, "type": "invalid_request_error"}})

    @pytest.mark.respx()
    def test_workload_identity_401_retry(self, respx_mock: MockRouter) -> None:
        """A 401 triggers a second token exchange and one retry carrying the new access token."""
        subject_tokens_issued = 0

        def provider() -> str:
            nonlocal subject_tokens_issued
            subject_tokens_issued += 1
            return f"external-subject-token-{subject_tokens_issued}"

        token_url = "https://auth.openai.com/oauth/token"
        completions_url = base_url + "/chat/completions"

        respx_mock.post(token_url).mock(
            side_effect=[
                self._token_exchange_ok("openai-access-token-1"),
                self._token_exchange_ok("openai-access-token-2"),
            ]
        )
        respx_mock.post(completions_url).mock(
            side_effect=[
                self._api_error(401, "Unauthorized"),
                httpx.Response(
                    200,
                    json={
                        "id": "chatcmpl-123",
                        "object": "chat.completion",
                        "created": 1234567890,
                        "model": "gpt-4",
                        "choices": [],
                    },
                ),
            ]
        )

        with OpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {"get_token": provider, "token_type": "jwt"},
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            client.chat.completions.create(messages=[], model="gpt-4")

        recorded = cast("list[MockRequestCall]", respx_mock.calls)
        # Expected order: exchange -> 401 -> exchange again -> successful retry.
        assert [call.request.url for call in recorded] == [
            httpx.URL(token_url),
            httpx.URL(completions_url),
            httpx.URL(token_url),
            httpx.URL(completions_url),
        ]
        assert recorded[1].request.headers.get("Authorization") == "Bearer openai-access-token-1"
        assert recorded[3].request.headers.get("Authorization") == "Bearer openai-access-token-2"

        assert subject_tokens_issued == 2

    @pytest.mark.respx()
    def test_401_without_workload_identity_no_retry(self, respx_mock: MockRouter) -> None:
        """With a plain API key, a 401 is raised immediately and the request is not retried."""
        respx_mock.post(base_url + "/chat/completions").mock(return_value=self._api_error(401, "Unauthorized"))

        with OpenAI(
            base_url=base_url,
            api_key="test-api-key",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as excinfo:
                client.chat.completions.create(messages=[], model="gpt-4")

        assert excinfo.value.status_code == 401

        recorded = cast("list[MockRequestCall]", respx_mock.calls)
        assert len(recorded) == 1

    @pytest.mark.respx()
    def test_non_401_errors_no_retry(self, respx_mock: MockRouter) -> None:
        """A non-401 error (403) is raised without a second token exchange or a retry."""
        subject_tokens_issued = 0

        def provider() -> str:
            nonlocal subject_tokens_issued
            subject_tokens_issued += 1
            return "external-subject-token"

        respx_mock.post("https://auth.openai.com/oauth/token").mock(
            return_value=self._token_exchange_ok("openai-access-token-1")
        )
        respx_mock.post(base_url + "/chat/completions").mock(return_value=self._api_error(403, "Forbidden"))

        with OpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {"get_token": provider, "token_type": "jwt"},
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as excinfo:
                client.chat.completions.create(messages=[], model="gpt-4")

        assert excinfo.value.status_code == 403

        # One token exchange plus the single failed API call.
        recorded = cast("list[MockRequestCall]", respx_mock.calls)
        assert len(recorded) == 2

        assert subject_tokens_issued == 1
2804
2805
class TestAsyncWorkloadIdentity401Retry:
    """Tests for how ``AsyncOpenAI`` reacts to 401 and non-401 API errors.

    Both the token-exchange endpoint and the chat completions endpoint are mocked with
    ``respx``. The assertions inspect the recorded calls (order, URLs, ``Authorization``
    headers) and count how often the subject-token provider was invoked.
    """

    @staticmethod
    def _token_response(access_token: str) -> httpx.Response:
        """Build a successful token-exchange response that issues ``access_token``."""
        return httpx.Response(
            200,
            json={
                "access_token": access_token,
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )

    @pytest.mark.respx()
    async def test_workload_identity_401_retry(self, respx_mock: MockRouter) -> None:
        """A 401 triggers a new token exchange and one retry with the new access token."""
        provider_call_count = 0

        def provider() -> str:
            nonlocal provider_call_count
            provider_call_count += 1
            # A distinct subject token per call makes repeated provider calls observable.
            return f"external-subject-token-{provider_call_count}"

        respx_mock.post("https://auth.openai.com/oauth/token").mock(
            side_effect=[
                self._token_response("openai-access-token-1"),
                self._token_response("openai-access-token-2"),
            ]
        )

        respx_mock.post(base_url + "/chat/completions").mock(
            side_effect=[
                httpx.Response(401, json={"error": {"message": "Unauthorized", "type": "invalid_request_error"}}),
                httpx.Response(
                    200,
                    json={
                        "id": "chatcmpl-123",
                        "object": "chat.completion",
                        "created": 1234567890,
                        "model": "gpt-4",
                        "choices": [],
                    },
                ),
            ]
        )

        async with AsyncOpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {
                    "get_token": provider,
                    "token_type": "jwt",
                },
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            completion = await client.chat.completions.create(messages=[], model="gpt-4")

        # The caller receives the body of the successful retried request.
        assert completion.id == "chatcmpl-123"

        calls = cast("list[MockRequestCall]", respx_mock.calls)
        assert len(calls) == 4

        # First attempt: exchange, then API request with the first access token.
        assert calls[0].request.url == httpx.URL("https://auth.openai.com/oauth/token")
        assert calls[1].request.url == httpx.URL(base_url + "/chat/completions")
        assert calls[1].request.headers.get("Authorization") == "Bearer openai-access-token-1"

        # After the 401: a fresh exchange...
        assert calls[2].request.url == httpx.URL("https://auth.openai.com/oauth/token")

        # ...and the retry carries the second access token.
        assert calls[3].request.url == httpx.URL(base_url + "/chat/completions")
        assert calls[3].request.headers.get("Authorization") == "Bearer openai-access-token-2"

        assert provider_call_count == 2

    @pytest.mark.respx()
    async def test_401_without_workload_identity_no_retry(self, respx_mock: MockRouter) -> None:
        """With a plain API key, a 401 is raised after a single request."""
        respx_mock.post(base_url + "/chat/completions").mock(
            return_value=httpx.Response(
                401, json={"error": {"message": "Unauthorized", "type": "invalid_request_error"}}
            )
        )

        async with AsyncOpenAI(
            base_url=base_url,
            api_key="test-api-key",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as exc_info:
                await client.chat.completions.create(messages=[], model="gpt-4")

        assert exc_info.value.status_code == 401

        calls = cast("list[MockRequestCall]", respx_mock.calls)
        assert len(calls) == 1

    @pytest.mark.respx()
    async def test_non_401_errors_no_retry(self, respx_mock: MockRouter) -> None:
        """A non-401 API error (403 here) is raised without a token refresh or retry."""
        provider_call_count = 0

        def provider() -> str:
            nonlocal provider_call_count
            provider_call_count += 1
            return "external-subject-token"

        respx_mock.post("https://auth.openai.com/oauth/token").mock(
            return_value=self._token_response("openai-access-token-1")
        )

        respx_mock.post(base_url + "/chat/completions").mock(
            return_value=httpx.Response(403, json={"error": {"message": "Forbidden", "type": "invalid_request_error"}})
        )

        async with AsyncOpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {
                    "get_token": provider,
                    "token_type": "jwt",
                },
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as exc_info:
                await client.chat.completions.create(messages=[], model="gpt-4")

        assert exc_info.value.status_code == 403

        calls = cast("list[MockRequestCall]", respx_mock.calls)
        # One token exchange followed by one API request: the 403 must not lead to
        # a second exchange or another attempt at the endpoint.
        assert len(calls) == 2
        assert calls[0].request.url == httpx.URL("https://auth.openai.com/oauth/token")
        assert calls[1].request.url == httpx.URL(base_url + "/chat/completions")
        assert calls[1].request.headers.get("Authorization") == "Bearer openai-access-token-1"

        assert provider_call_count == 1
2952