openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.106.1

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

2019 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import time
10import asyncio
11import inspect
12import subprocess
13import tracemalloc
14from typing import Any, Union, Protocol, cast
15from textwrap import dedent
16from unittest import mock
17from typing_extensions import Literal
18
19import httpx
20import pytest
21from respx import MockRouter
22from pydantic import ValidationError
23
24from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
25from openai._types import Omit
26from openai._models import BaseModel, FinalRequestOptions
27from openai._streaming import Stream, AsyncStream
28from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
29from openai._base_client import (
30 DEFAULT_TIMEOUT,
31 HTTPX_DEFAULT_TIMEOUT,
32 BaseClient,
33 DefaultHttpxClient,
34 DefaultAsyncHttpxClient,
35 make_request_options,
36)
37
38from .utils import update_env
39
# Server the tests talk to; override with the TEST_API_BASE_URL environment variable.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Dummy credential shared by the tests below.
api_key = "My API Key"
42
43
class MockRequestCall(Protocol):
    """Structural type for a recorded mock call that exposes the request it captured."""

    request: httpx.Request
46
47
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query parameters as a plain dict."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
52
53
54def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
55 return 0.1
56
57
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many requests the connection pool behind *client*'s transport is currently tracking.

    Fails with AssertionError if the transport is not a stock httpx sync/async transport.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
64
65
66class TestOpenAI:
67 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
68
69 @pytest.mark.respx(base_url=base_url)
70 def test_raw_response(self, respx_mock: MockRouter) -> None:
71 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
72
73 response = self.client.post("/foo", cast_to=httpx.Response)
74 assert response.status_code == 200
75 assert isinstance(response, httpx.Response)
76 assert response.json() == {"foo": "bar"}
77
78 @pytest.mark.respx(base_url=base_url)
79 def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
80 respx_mock.post("/foo").mock(
81 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
82 )
83
84 response = self.client.post("/foo", cast_to=httpx.Response)
85 assert response.status_code == 200
86 assert isinstance(response, httpx.Response)
87 assert response.json() == {"foo": "bar"}
88
89 def test_copy(self) -> None:
90 copied = self.client.copy()
91 assert id(copied) != id(self.client)
92
93 copied = self.client.copy(api_key="another My API Key")
94 assert copied.api_key == "another My API Key"
95 assert self.client.api_key == "My API Key"
96
97 def test_copy_default_options(self) -> None:
98 # options that have a default are overridden correctly
99 copied = self.client.copy(max_retries=7)
100 assert copied.max_retries == 7
101 assert self.client.max_retries == 2
102
103 copied2 = copied.copy(max_retries=6)
104 assert copied2.max_retries == 6
105 assert copied.max_retries == 7
106
107 # timeout
108 assert isinstance(self.client.timeout, httpx.Timeout)
109 copied = self.client.copy(timeout=None)
110 assert copied.timeout is None
111 assert isinstance(self.client.timeout, httpx.Timeout)
112
113 def test_copy_default_headers(self) -> None:
114 client = OpenAI(
115 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
116 )
117 assert client.default_headers["X-Foo"] == "bar"
118
119 # does not override the already given value when not specified
120 copied = client.copy()
121 assert copied.default_headers["X-Foo"] == "bar"
122
123 # merges already given headers
124 copied = client.copy(default_headers={"X-Bar": "stainless"})
125 assert copied.default_headers["X-Foo"] == "bar"
126 assert copied.default_headers["X-Bar"] == "stainless"
127
128 # uses new values for any already given headers
129 copied = client.copy(default_headers={"X-Foo": "stainless"})
130 assert copied.default_headers["X-Foo"] == "stainless"
131
132 # set_default_headers
133
134 # completely overrides already set values
135 copied = client.copy(set_default_headers={})
136 assert copied.default_headers.get("X-Foo") is None
137
138 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
139 assert copied.default_headers["X-Bar"] == "Robert"
140
141 with pytest.raises(
142 ValueError,
143 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
144 ):
145 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
146
147 def test_copy_default_query(self) -> None:
148 client = OpenAI(
149 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
150 )
151 assert _get_params(client)["foo"] == "bar"
152
153 # does not override the already given value when not specified
154 copied = client.copy()
155 assert _get_params(copied)["foo"] == "bar"
156
157 # merges already given params
158 copied = client.copy(default_query={"bar": "stainless"})
159 params = _get_params(copied)
160 assert params["foo"] == "bar"
161 assert params["bar"] == "stainless"
162
163 # uses new values for any already given headers
164 copied = client.copy(default_query={"foo": "stainless"})
165 assert _get_params(copied)["foo"] == "stainless"
166
167 # set_default_query
168
169 # completely overrides already set values
170 copied = client.copy(set_default_query={})
171 assert _get_params(copied) == {}
172
173 copied = client.copy(set_default_query={"bar": "Robert"})
174 assert _get_params(copied)["bar"] == "Robert"
175
176 with pytest.raises(
177 ValueError,
178 # TODO: update
179 match="`default_query` and `set_default_query` arguments are mutually exclusive",
180 ):
181 client.copy(set_default_query={}, default_query={"foo": "Bar"})
182
183 def test_copy_signature(self) -> None:
184 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
185 init_signature = inspect.signature(
186 # mypy doesn't like that we access the `__init__` property.
187 self.client.__init__, # type: ignore[misc]
188 )
189 copy_signature = inspect.signature(self.client.copy)
190 exclude_params = {"transport", "proxies", "_strict_response_validation"}
191
192 for name in init_signature.parameters.keys():
193 if name in exclude_params:
194 continue
195
196 copy_param = copy_signature.parameters.get(name)
197 assert copy_param is not None, f"copy() signature is missing the {name} param"
198
199 @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
200 def test_copy_build_request(self) -> None:
201 options = FinalRequestOptions(method="get", url="/foo")
202
203 def build_request(options: FinalRequestOptions) -> None:
204 client = self.client.copy()
205 client._build_request(options)
206
207 # ensure that the machinery is warmed up before tracing starts.
208 build_request(options)
209 gc.collect()
210
211 tracemalloc.start(1000)
212
213 snapshot_before = tracemalloc.take_snapshot()
214
215 ITERATIONS = 10
216 for _ in range(ITERATIONS):
217 build_request(options)
218
219 gc.collect()
220 snapshot_after = tracemalloc.take_snapshot()
221
222 tracemalloc.stop()
223
224 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
225 if diff.count == 0:
226 # Avoid false positives by considering only leaks (i.e. allocations that persist).
227 return
228
229 if diff.count % ITERATIONS != 0:
230 # Avoid false positives by considering only leaks that appear per iteration.
231 return
232
233 for frame in diff.traceback:
234 if any(
235 frame.filename.endswith(fragment)
236 for fragment in [
237 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
238 #
239 # removing the decorator fixes the leak for reasons we don't understand.
240 "openai/_legacy_response.py",
241 "openai/_response.py",
242 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
243 "openai/_compat.py",
244 # Standard library leaks we don't care about.
245 "/logging/__init__.py",
246 ]
247 ):
248 return
249
250 leaks.append(diff)
251
252 leaks: list[tracemalloc.StatisticDiff] = []
253 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
254 add_leak(leaks, diff)
255 if leaks:
256 for leak in leaks:
257 print("MEMORY LEAK:", leak)
258 for frame in leak.traceback:
259 print(frame)
260 raise AssertionError()
261
262 def test_request_timeout(self) -> None:
263 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
264 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
265 assert timeout == DEFAULT_TIMEOUT
266
267 request = self.client._build_request(
268 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
269 )
270 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
271 assert timeout == httpx.Timeout(100.0)
272
273 def test_client_timeout_option(self) -> None:
274 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
275
276 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
277 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
278 assert timeout == httpx.Timeout(0)
279
280 def test_http_client_timeout_option(self) -> None:
281 # custom timeout given to the httpx client should be used
282 with httpx.Client(timeout=None) as http_client:
283 client = OpenAI(
284 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
285 )
286
287 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
288 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
289 assert timeout == httpx.Timeout(None)
290
291 # no timeout given to the httpx client should not use the httpx default
292 with httpx.Client() as http_client:
293 client = OpenAI(
294 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
295 )
296
297 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
298 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
299 assert timeout == DEFAULT_TIMEOUT
300
301 # explicitly passing the default timeout currently results in it being ignored
302 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
303 client = OpenAI(
304 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
305 )
306
307 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
308 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
309 assert timeout == DEFAULT_TIMEOUT # our default
310
311 async def test_invalid_http_client(self) -> None:
312 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
313 async with httpx.AsyncClient() as http_client:
314 OpenAI(
315 base_url=base_url,
316 api_key=api_key,
317 _strict_response_validation=True,
318 http_client=cast(Any, http_client),
319 )
320
321 def test_default_headers_option(self) -> None:
322 client = OpenAI(
323 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
324 )
325 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
326 assert request.headers.get("x-foo") == "bar"
327 assert request.headers.get("x-stainless-lang") == "python"
328
329 client2 = OpenAI(
330 base_url=base_url,
331 api_key=api_key,
332 _strict_response_validation=True,
333 default_headers={
334 "X-Foo": "stainless",
335 "X-Stainless-Lang": "my-overriding-header",
336 },
337 )
338 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
339 assert request.headers.get("x-foo") == "stainless"
340 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
341
342 def test_validate_headers(self) -> None:
343 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
344 options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
345 request = client._build_request(options)
346
347 assert request.headers.get("Authorization") == f"Bearer {api_key}"
348
349 with pytest.raises(OpenAIError):
350 with update_env(**{"OPENAI_API_KEY": Omit()}):
351 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
352 _ = client2
353
354 def test_default_query_option(self) -> None:
355 client = OpenAI(
356 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
357 )
358 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
359 url = httpx.URL(request.url)
360 assert dict(url.params) == {"query_param": "bar"}
361
362 request = client._build_request(
363 FinalRequestOptions(
364 method="get",
365 url="/foo",
366 params={"foo": "baz", "query_param": "overridden"},
367 )
368 )
369 url = httpx.URL(request.url)
370 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
371
372 def test_request_extra_json(self) -> None:
373 request = self.client._build_request(
374 FinalRequestOptions(
375 method="post",
376 url="/foo",
377 json_data={"foo": "bar"},
378 extra_json={"baz": False},
379 ),
380 )
381 data = json.loads(request.content.decode("utf-8"))
382 assert data == {"foo": "bar", "baz": False}
383
384 request = self.client._build_request(
385 FinalRequestOptions(
386 method="post",
387 url="/foo",
388 extra_json={"baz": False},
389 ),
390 )
391 data = json.loads(request.content.decode("utf-8"))
392 assert data == {"baz": False}
393
394 # `extra_json` takes priority over `json_data` when keys clash
395 request = self.client._build_request(
396 FinalRequestOptions(
397 method="post",
398 url="/foo",
399 json_data={"foo": "bar", "baz": True},
400 extra_json={"baz": None},
401 ),
402 )
403 data = json.loads(request.content.decode("utf-8"))
404 assert data == {"foo": "bar", "baz": None}
405
406 def test_request_extra_headers(self) -> None:
407 request = self.client._build_request(
408 FinalRequestOptions(
409 method="post",
410 url="/foo",
411 **make_request_options(extra_headers={"X-Foo": "Foo"}),
412 ),
413 )
414 assert request.headers.get("X-Foo") == "Foo"
415
416 # `extra_headers` takes priority over `default_headers` when keys clash
417 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
418 FinalRequestOptions(
419 method="post",
420 url="/foo",
421 **make_request_options(
422 extra_headers={"X-Bar": "false"},
423 ),
424 ),
425 )
426 assert request.headers.get("X-Bar") == "false"
427
428 def test_request_extra_query(self) -> None:
429 request = self.client._build_request(
430 FinalRequestOptions(
431 method="post",
432 url="/foo",
433 **make_request_options(
434 extra_query={"my_query_param": "Foo"},
435 ),
436 ),
437 )
438 params = dict(request.url.params)
439 assert params == {"my_query_param": "Foo"}
440
441 # if both `query` and `extra_query` are given, they are merged
442 request = self.client._build_request(
443 FinalRequestOptions(
444 method="post",
445 url="/foo",
446 **make_request_options(
447 query={"bar": "1"},
448 extra_query={"foo": "2"},
449 ),
450 ),
451 )
452 params = dict(request.url.params)
453 assert params == {"bar": "1", "foo": "2"}
454
455 # `extra_query` takes priority over `query` when keys clash
456 request = self.client._build_request(
457 FinalRequestOptions(
458 method="post",
459 url="/foo",
460 **make_request_options(
461 query={"foo": "1"},
462 extra_query={"foo": "2"},
463 ),
464 ),
465 )
466 params = dict(request.url.params)
467 assert params == {"foo": "2"}
468
469 def test_multipart_repeating_array(self, client: OpenAI) -> None:
470 request = client._build_request(
471 FinalRequestOptions.construct(
472 method="post",
473 url="/foo",
474 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
475 json_data={"array": ["foo", "bar"]},
476 files=[("foo.txt", b"hello world")],
477 )
478 )
479
480 assert request.read().split(b"\r\n") == [
481 b"--6b7ba517decee4a450543ea6ae821c82",
482 b'Content-Disposition: form-data; name="array[]"',
483 b"",
484 b"foo",
485 b"--6b7ba517decee4a450543ea6ae821c82",
486 b'Content-Disposition: form-data; name="array[]"',
487 b"",
488 b"bar",
489 b"--6b7ba517decee4a450543ea6ae821c82",
490 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
491 b"Content-Type: application/octet-stream",
492 b"",
493 b"hello world",
494 b"--6b7ba517decee4a450543ea6ae821c82--",
495 b"",
496 ]
497
498 @pytest.mark.respx(base_url=base_url)
499 def test_basic_union_response(self, respx_mock: MockRouter) -> None:
500 class Model1(BaseModel):
501 name: str
502
503 class Model2(BaseModel):
504 foo: str
505
506 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
507
508 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
509 assert isinstance(response, Model2)
510 assert response.foo == "bar"
511
512 @pytest.mark.respx(base_url=base_url)
513 def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
514 """Union of objects with the same field name using a different type"""
515
516 class Model1(BaseModel):
517 foo: int
518
519 class Model2(BaseModel):
520 foo: str
521
522 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
523
524 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
525 assert isinstance(response, Model2)
526 assert response.foo == "bar"
527
528 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
529
530 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
531 assert isinstance(response, Model1)
532 assert response.foo == 1
533
534 @pytest.mark.respx(base_url=base_url)
535 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
536 """
537 Response that sets Content-Type to something other than application/json but returns json data
538 """
539
540 class Model(BaseModel):
541 foo: int
542
543 respx_mock.get("/foo").mock(
544 return_value=httpx.Response(
545 200,
546 content=json.dumps({"foo": 2}),
547 headers={"Content-Type": "application/text"},
548 )
549 )
550
551 response = self.client.get("/foo", cast_to=Model)
552 assert isinstance(response, Model)
553 assert response.foo == 2
554
555 def test_base_url_setter(self) -> None:
556 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
557 assert client.base_url == "https://example.com/from_init/"
558
559 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
560
561 assert client.base_url == "https://example.com/from_setter/"
562
563 def test_base_url_env(self) -> None:
564 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
565 client = OpenAI(api_key=api_key, _strict_response_validation=True)
566 assert client.base_url == "http://localhost:5000/from/env/"
567
568 @pytest.mark.parametrize(
569 "client",
570 [
571 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
572 OpenAI(
573 base_url="http://localhost:5000/custom/path/",
574 api_key=api_key,
575 _strict_response_validation=True,
576 http_client=httpx.Client(),
577 ),
578 ],
579 ids=["standard", "custom http client"],
580 )
581 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
582 request = client._build_request(
583 FinalRequestOptions(
584 method="post",
585 url="/foo",
586 json_data={"foo": "bar"},
587 ),
588 )
589 assert request.url == "http://localhost:5000/custom/path/foo"
590
591 @pytest.mark.parametrize(
592 "client",
593 [
594 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
595 OpenAI(
596 base_url="http://localhost:5000/custom/path/",
597 api_key=api_key,
598 _strict_response_validation=True,
599 http_client=httpx.Client(),
600 ),
601 ],
602 ids=["standard", "custom http client"],
603 )
604 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
605 request = client._build_request(
606 FinalRequestOptions(
607 method="post",
608 url="/foo",
609 json_data={"foo": "bar"},
610 ),
611 )
612 assert request.url == "http://localhost:5000/custom/path/foo"
613
614 @pytest.mark.parametrize(
615 "client",
616 [
617 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
618 OpenAI(
619 base_url="http://localhost:5000/custom/path/",
620 api_key=api_key,
621 _strict_response_validation=True,
622 http_client=httpx.Client(),
623 ),
624 ],
625 ids=["standard", "custom http client"],
626 )
627 def test_absolute_request_url(self, client: OpenAI) -> None:
628 request = client._build_request(
629 FinalRequestOptions(
630 method="post",
631 url="https://myapi.com/foo",
632 json_data={"foo": "bar"},
633 ),
634 )
635 assert request.url == "https://myapi.com/foo"
636
637 def test_copied_client_does_not_close_http(self) -> None:
638 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
639 assert not client.is_closed()
640
641 copied = client.copy()
642 assert copied is not client
643
644 del copied
645
646 assert not client.is_closed()
647
648 def test_client_context_manager(self) -> None:
649 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
650 with client as c2:
651 assert c2 is client
652 assert not c2.is_closed()
653 assert not client.is_closed()
654 assert client.is_closed()
655
656 @pytest.mark.respx(base_url=base_url)
657 def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
658 class Model(BaseModel):
659 foo: str
660
661 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
662
663 with pytest.raises(APIResponseValidationError) as exc:
664 self.client.get("/foo", cast_to=Model)
665
666 assert isinstance(exc.value.__cause__, ValidationError)
667
668 def test_client_max_retries_validation(self) -> None:
669 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
670 OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
671
672 @pytest.mark.respx(base_url=base_url)
673 def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
674 class Model(BaseModel):
675 name: str
676
677 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
678
679 stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
680 assert isinstance(stream, Stream)
681 stream.response.close()
682
683 @pytest.mark.respx(base_url=base_url)
684 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
685 class Model(BaseModel):
686 name: str
687
688 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
689
690 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
691
692 with pytest.raises(APIResponseValidationError):
693 strict_client.get("/foo", cast_to=Model)
694
695 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
696
697 response = client.get("/foo", cast_to=Model)
698 assert isinstance(response, str) # type: ignore[unreachable]
699
700 @pytest.mark.parametrize(
701 "remaining_retries,retry_after,timeout",
702 [
703 [3, "20", 20],
704 [3, "0", 0.5],
705 [3, "-10", 0.5],
706 [3, "60", 60],
707 [3, "61", 0.5],
708 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
709 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
710 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
711 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
712 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
713 [3, "99999999999999999999999999999999999", 0.5],
714 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
715 [3, "", 0.5],
716 [2, "", 0.5 * 2.0],
717 [1, "", 0.5 * 4.0],
718 [-1100, "", 8], # test large number potentially overflowing
719 ],
720 )
721 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
722 def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
723 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
724
725 headers = httpx.Headers({"retry-after": retry_after})
726 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
727 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
728 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
729
730 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
731 @pytest.mark.respx(base_url=base_url)
732 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
733 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
734
735 with pytest.raises(APITimeoutError):
736 client.chat.completions.with_streaming_response.create(
737 messages=[
738 {
739 "content": "string",
740 "role": "developer",
741 }
742 ],
743 model="gpt-4o",
744 ).__enter__()
745
746 assert _get_open_connections(self.client) == 0
747
748 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
749 @pytest.mark.respx(base_url=base_url)
750 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
751 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
752
753 with pytest.raises(APIStatusError):
754 client.chat.completions.with_streaming_response.create(
755 messages=[
756 {
757 "content": "string",
758 "role": "developer",
759 }
760 ],
761 model="gpt-4o",
762 ).__enter__()
763 assert _get_open_connections(self.client) == 0
764
765 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
766 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
767 @pytest.mark.respx(base_url=base_url)
768 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
769 def test_retries_taken(
770 self,
771 client: OpenAI,
772 failures_before_success: int,
773 failure_mode: Literal["status", "exception"],
774 respx_mock: MockRouter,
775 ) -> None:
776 client = client.with_options(max_retries=4)
777
778 nb_retries = 0
779
780 def retry_handler(_request: httpx.Request) -> httpx.Response:
781 nonlocal nb_retries
782 if nb_retries < failures_before_success:
783 nb_retries += 1
784 if failure_mode == "exception":
785 raise RuntimeError("oops")
786 return httpx.Response(500)
787 return httpx.Response(200)
788
789 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
790
791 response = client.chat.completions.with_raw_response.create(
792 messages=[
793 {
794 "content": "string",
795 "role": "developer",
796 }
797 ],
798 model="gpt-4o",
799 )
800
801 assert response.retries_taken == failures_before_success
802 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
803
804 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
805 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
806 @pytest.mark.respx(base_url=base_url)
807 def test_omit_retry_count_header(
808 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
809 ) -> None:
810 client = client.with_options(max_retries=4)
811
812 nb_retries = 0
813
814 def retry_handler(_request: httpx.Request) -> httpx.Response:
815 nonlocal nb_retries
816 if nb_retries < failures_before_success:
817 nb_retries += 1
818 return httpx.Response(500)
819 return httpx.Response(200)
820
821 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
822
823 response = client.chat.completions.with_raw_response.create(
824 messages=[
825 {
826 "content": "string",
827 "role": "developer",
828 }
829 ],
830 model="gpt-4o",
831 extra_headers={"x-stainless-retry-count": Omit()},
832 )
833
834 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
835
836 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
837 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
838 @pytest.mark.respx(base_url=base_url)
839 def test_overwrite_retry_count_header(
840 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
841 ) -> None:
842 client = client.with_options(max_retries=4)
843
844 nb_retries = 0
845
846 def retry_handler(_request: httpx.Request) -> httpx.Response:
847 nonlocal nb_retries
848 if nb_retries < failures_before_success:
849 nb_retries += 1
850 return httpx.Response(500)
851 return httpx.Response(200)
852
853 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
854
855 response = client.chat.completions.with_raw_response.create(
856 messages=[
857 {
858 "content": "string",
859 "role": "developer",
860 }
861 ],
862 model="gpt-4o",
863 extra_headers={"x-stainless-retry-count": "42"},
864 )
865
866 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
867
868 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
869 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
870 @pytest.mark.respx(base_url=base_url)
871 def test_retries_taken_new_response_class(
872 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
873 ) -> None:
874 client = client.with_options(max_retries=4)
875
876 nb_retries = 0
877
878 def retry_handler(_request: httpx.Request) -> httpx.Response:
879 nonlocal nb_retries
880 if nb_retries < failures_before_success:
881 nb_retries += 1
882 return httpx.Response(500)
883 return httpx.Response(200)
884
885 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
886
887 with client.chat.completions.with_streaming_response.create(
888 messages=[
889 {
890 "content": "string",
891 "role": "developer",
892 }
893 ],
894 model="gpt-4o",
895 ) as response:
896 assert response.retries_taken == failures_before_success
897 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
898
899 def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
900 # Test that the proxy environment variables are set correctly
901 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
902
903 client = DefaultHttpxClient()
904
905 mounts = tuple(client._mounts.items())
906 assert len(mounts) == 1
907 assert mounts[0][0].pattern == "https://"
908
909 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
910 def test_default_client_creation(self) -> None:
911 # Ensure that the client can be initialized without any exceptions
912 DefaultHttpxClient(
913 verify=True,
914 cert=None,
915 trust_env=True,
916 http1=True,
917 http2=False,
918 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
919 )
920
921 @pytest.mark.respx(base_url=base_url)
922 def test_follow_redirects(self, respx_mock: MockRouter) -> None:
923 # Test that the default follow_redirects=True allows following redirects
924 respx_mock.post("/redirect").mock(
925 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
926 )
927 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
928
929 response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
930 assert response.status_code == 200
931 assert response.json() == {"status": "ok"}
932
933 @pytest.mark.respx(base_url=base_url)
934 def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
935 # Test that follow_redirects=False prevents following redirects
936 respx_mock.post("/redirect").mock(
937 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
938 )
939
940 with pytest.raises(APIStatusError) as exc_info:
941 self.client.post(
942 "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
943 )
944
945 assert exc_info.value.response.status_code == 302
946 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
947
948 def test_api_key_before_after_refresh_provider(self) -> None:
949 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
950
951 assert client.api_key == ""
952 assert "Authorization" not in client.auth_headers
953
954 client._refresh_api_key()
955
956 assert client.api_key == "test_bearer_token"
957 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
958
959 def test_api_key_before_after_refresh_str(self) -> None:
960 client = OpenAI(base_url=base_url, api_key="test_api_key")
961
962 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
963 client._refresh_api_key()
964
965 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
966
967 @pytest.mark.respx()
968 def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
969 respx_mock.post(base_url + "/chat/completions").mock(
970 side_effect=[
971 httpx.Response(500, json={"error": "server error"}),
972 httpx.Response(200, json={"foo": "bar"}),
973 ]
974 )
975
976 counter = 0
977
978 def token_provider() -> str:
979 nonlocal counter
980
981 counter += 1
982
983 if counter == 1:
984 return "first"
985
986 return "second"
987
988 client = OpenAI(base_url=base_url, api_key=token_provider)
989 client.chat.completions.create(messages=[], model="gpt-4")
990
991 calls = cast("list[MockRequestCall]", respx_mock.calls)
992 assert len(calls) == 2
993
994 assert calls[0].request.headers.get("Authorization") == "Bearer first"
995 assert calls[1].request.headers.get("Authorization") == "Bearer second"
996
997 def test_copy_auth(self) -> None:
998 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
999 api_key=lambda: "test_bearer_token_2"
1000 )
1001 client._refresh_api_key()
1002 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
1003
1004
1005class TestAsyncOpenAI:
1006 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1007
1008 @pytest.mark.respx(base_url=base_url)
1009 @pytest.mark.asyncio
1010 async def test_raw_response(self, respx_mock: MockRouter) -> None:
1011 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1012
1013 response = await self.client.post("/foo", cast_to=httpx.Response)
1014 assert response.status_code == 200
1015 assert isinstance(response, httpx.Response)
1016 assert response.json() == {"foo": "bar"}
1017
1018 @pytest.mark.respx(base_url=base_url)
1019 @pytest.mark.asyncio
1020 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
1021 respx_mock.post("/foo").mock(
1022 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
1023 )
1024
1025 response = await self.client.post("/foo", cast_to=httpx.Response)
1026 assert response.status_code == 200
1027 assert isinstance(response, httpx.Response)
1028 assert response.json() == {"foo": "bar"}
1029
1030 def test_copy(self) -> None:
1031 copied = self.client.copy()
1032 assert id(copied) != id(self.client)
1033
1034 copied = self.client.copy(api_key="another My API Key")
1035 assert copied.api_key == "another My API Key"
1036 assert self.client.api_key == "My API Key"
1037
1038 def test_copy_default_options(self) -> None:
1039 # options that have a default are overridden correctly
1040 copied = self.client.copy(max_retries=7)
1041 assert copied.max_retries == 7
1042 assert self.client.max_retries == 2
1043
1044 copied2 = copied.copy(max_retries=6)
1045 assert copied2.max_retries == 6
1046 assert copied.max_retries == 7
1047
1048 # timeout
1049 assert isinstance(self.client.timeout, httpx.Timeout)
1050 copied = self.client.copy(timeout=None)
1051 assert copied.timeout is None
1052 assert isinstance(self.client.timeout, httpx.Timeout)
1053
1054 def test_copy_default_headers(self) -> None:
1055 client = AsyncOpenAI(
1056 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1057 )
1058 assert client.default_headers["X-Foo"] == "bar"
1059
1060 # does not override the already given value when not specified
1061 copied = client.copy()
1062 assert copied.default_headers["X-Foo"] == "bar"
1063
1064 # merges already given headers
1065 copied = client.copy(default_headers={"X-Bar": "stainless"})
1066 assert copied.default_headers["X-Foo"] == "bar"
1067 assert copied.default_headers["X-Bar"] == "stainless"
1068
1069 # uses new values for any already given headers
1070 copied = client.copy(default_headers={"X-Foo": "stainless"})
1071 assert copied.default_headers["X-Foo"] == "stainless"
1072
1073 # set_default_headers
1074
1075 # completely overrides already set values
1076 copied = client.copy(set_default_headers={})
1077 assert copied.default_headers.get("X-Foo") is None
1078
1079 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
1080 assert copied.default_headers["X-Bar"] == "Robert"
1081
1082 with pytest.raises(
1083 ValueError,
1084 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
1085 ):
1086 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1087
1088 def test_copy_default_query(self) -> None:
1089 client = AsyncOpenAI(
1090 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
1091 )
1092 assert _get_params(client)["foo"] == "bar"
1093
1094 # does not override the already given value when not specified
1095 copied = client.copy()
1096 assert _get_params(copied)["foo"] == "bar"
1097
1098 # merges already given params
1099 copied = client.copy(default_query={"bar": "stainless"})
1100 params = _get_params(copied)
1101 assert params["foo"] == "bar"
1102 assert params["bar"] == "stainless"
1103
1104 # uses new values for any already given headers
1105 copied = client.copy(default_query={"foo": "stainless"})
1106 assert _get_params(copied)["foo"] == "stainless"
1107
1108 # set_default_query
1109
1110 # completely overrides already set values
1111 copied = client.copy(set_default_query={})
1112 assert _get_params(copied) == {}
1113
1114 copied = client.copy(set_default_query={"bar": "Robert"})
1115 assert _get_params(copied)["bar"] == "Robert"
1116
1117 with pytest.raises(
1118 ValueError,
1119 # TODO: update
1120 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1121 ):
1122 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1123
1124 def test_copy_signature(self) -> None:
1125 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1126 init_signature = inspect.signature(
1127 # mypy doesn't like that we access the `__init__` property.
1128 self.client.__init__, # type: ignore[misc]
1129 )
1130 copy_signature = inspect.signature(self.client.copy)
1131 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1132
1133 for name in init_signature.parameters.keys():
1134 if name in exclude_params:
1135 continue
1136
1137 copy_param = copy_signature.parameters.get(name)
1138 assert copy_param is not None, f"copy() signature is missing the {name} param"
1139
1140 @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
1141 def test_copy_build_request(self) -> None:
1142 options = FinalRequestOptions(method="get", url="/foo")
1143
1144 def build_request(options: FinalRequestOptions) -> None:
1145 client = self.client.copy()
1146 client._build_request(options)
1147
1148 # ensure that the machinery is warmed up before tracing starts.
1149 build_request(options)
1150 gc.collect()
1151
1152 tracemalloc.start(1000)
1153
1154 snapshot_before = tracemalloc.take_snapshot()
1155
1156 ITERATIONS = 10
1157 for _ in range(ITERATIONS):
1158 build_request(options)
1159
1160 gc.collect()
1161 snapshot_after = tracemalloc.take_snapshot()
1162
1163 tracemalloc.stop()
1164
1165 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
1166 if diff.count == 0:
1167 # Avoid false positives by considering only leaks (i.e. allocations that persist).
1168 return
1169
1170 if diff.count % ITERATIONS != 0:
1171 # Avoid false positives by considering only leaks that appear per iteration.
1172 return
1173
1174 for frame in diff.traceback:
1175 if any(
1176 frame.filename.endswith(fragment)
1177 for fragment in [
1178 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
1179 #
1180 # removing the decorator fixes the leak for reasons we don't understand.
1181 "openai/_legacy_response.py",
1182 "openai/_response.py",
1183 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
1184 "openai/_compat.py",
1185 # Standard library leaks we don't care about.
1186 "/logging/__init__.py",
1187 ]
1188 ):
1189 return
1190
1191 leaks.append(diff)
1192
1193 leaks: list[tracemalloc.StatisticDiff] = []
1194 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
1195 add_leak(leaks, diff)
1196 if leaks:
1197 for leak in leaks:
1198 print("MEMORY LEAK:", leak)
1199 for frame in leak.traceback:
1200 print(frame)
1201 raise AssertionError()
1202
1203 async def test_request_timeout(self) -> None:
1204 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
1205 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1206 assert timeout == DEFAULT_TIMEOUT
1207
1208 request = self.client._build_request(
1209 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1210 )
1211 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1212 assert timeout == httpx.Timeout(100.0)
1213
1214 async def test_client_timeout_option(self) -> None:
1215 client = AsyncOpenAI(
1216 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
1217 )
1218
1219 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1220 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1221 assert timeout == httpx.Timeout(0)
1222
1223 async def test_http_client_timeout_option(self) -> None:
1224 # custom timeout given to the httpx client should be used
1225 async with httpx.AsyncClient(timeout=None) as http_client:
1226 client = AsyncOpenAI(
1227 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1228 )
1229
1230 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1231 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1232 assert timeout == httpx.Timeout(None)
1233
1234 # no timeout given to the httpx client should not use the httpx default
1235 async with httpx.AsyncClient() as http_client:
1236 client = AsyncOpenAI(
1237 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1238 )
1239
1240 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1241 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1242 assert timeout == DEFAULT_TIMEOUT
1243
1244 # explicitly passing the default timeout currently results in it being ignored
1245 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1246 client = AsyncOpenAI(
1247 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1248 )
1249
1250 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1251 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1252 assert timeout == DEFAULT_TIMEOUT # our default
1253
1254 def test_invalid_http_client(self) -> None:
1255 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1256 with httpx.Client() as http_client:
1257 AsyncOpenAI(
1258 base_url=base_url,
1259 api_key=api_key,
1260 _strict_response_validation=True,
1261 http_client=cast(Any, http_client),
1262 )
1263
1264 def test_default_headers_option(self) -> None:
1265 client = AsyncOpenAI(
1266 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1267 )
1268 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1269 assert request.headers.get("x-foo") == "bar"
1270 assert request.headers.get("x-stainless-lang") == "python"
1271
1272 client2 = AsyncOpenAI(
1273 base_url=base_url,
1274 api_key=api_key,
1275 _strict_response_validation=True,
1276 default_headers={
1277 "X-Foo": "stainless",
1278 "X-Stainless-Lang": "my-overriding-header",
1279 },
1280 )
1281 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1282 assert request.headers.get("x-foo") == "stainless"
1283 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1284
1285 async def test_validate_headers(self) -> None:
1286 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1287 options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
1288 request = client._build_request(options)
1289 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1290
1291 with pytest.raises(OpenAIError):
1292 with update_env(**{"OPENAI_API_KEY": Omit()}):
1293 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1294 _ = client2
1295
1296 def test_default_query_option(self) -> None:
1297 client = AsyncOpenAI(
1298 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1299 )
1300 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1301 url = httpx.URL(request.url)
1302 assert dict(url.params) == {"query_param": "bar"}
1303
1304 request = client._build_request(
1305 FinalRequestOptions(
1306 method="get",
1307 url="/foo",
1308 params={"foo": "baz", "query_param": "overridden"},
1309 )
1310 )
1311 url = httpx.URL(request.url)
1312 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1313
1314 def test_request_extra_json(self) -> None:
1315 request = self.client._build_request(
1316 FinalRequestOptions(
1317 method="post",
1318 url="/foo",
1319 json_data={"foo": "bar"},
1320 extra_json={"baz": False},
1321 ),
1322 )
1323 data = json.loads(request.content.decode("utf-8"))
1324 assert data == {"foo": "bar", "baz": False}
1325
1326 request = self.client._build_request(
1327 FinalRequestOptions(
1328 method="post",
1329 url="/foo",
1330 extra_json={"baz": False},
1331 ),
1332 )
1333 data = json.loads(request.content.decode("utf-8"))
1334 assert data == {"baz": False}
1335
1336 # `extra_json` takes priority over `json_data` when keys clash
1337 request = self.client._build_request(
1338 FinalRequestOptions(
1339 method="post",
1340 url="/foo",
1341 json_data={"foo": "bar", "baz": True},
1342 extra_json={"baz": None},
1343 ),
1344 )
1345 data = json.loads(request.content.decode("utf-8"))
1346 assert data == {"foo": "bar", "baz": None}
1347
1348 def test_request_extra_headers(self) -> None:
1349 request = self.client._build_request(
1350 FinalRequestOptions(
1351 method="post",
1352 url="/foo",
1353 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1354 ),
1355 )
1356 assert request.headers.get("X-Foo") == "Foo"
1357
1358 # `extra_headers` takes priority over `default_headers` when keys clash
1359 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
1360 FinalRequestOptions(
1361 method="post",
1362 url="/foo",
1363 **make_request_options(
1364 extra_headers={"X-Bar": "false"},
1365 ),
1366 ),
1367 )
1368 assert request.headers.get("X-Bar") == "false"
1369
1370 def test_request_extra_query(self) -> None:
1371 request = self.client._build_request(
1372 FinalRequestOptions(
1373 method="post",
1374 url="/foo",
1375 **make_request_options(
1376 extra_query={"my_query_param": "Foo"},
1377 ),
1378 ),
1379 )
1380 params = dict(request.url.params)
1381 assert params == {"my_query_param": "Foo"}
1382
1383 # if both `query` and `extra_query` are given, they are merged
1384 request = self.client._build_request(
1385 FinalRequestOptions(
1386 method="post",
1387 url="/foo",
1388 **make_request_options(
1389 query={"bar": "1"},
1390 extra_query={"foo": "2"},
1391 ),
1392 ),
1393 )
1394 params = dict(request.url.params)
1395 assert params == {"bar": "1", "foo": "2"}
1396
1397 # `extra_query` takes priority over `query` when keys clash
1398 request = self.client._build_request(
1399 FinalRequestOptions(
1400 method="post",
1401 url="/foo",
1402 **make_request_options(
1403 query={"foo": "1"},
1404 extra_query={"foo": "2"},
1405 ),
1406 ),
1407 )
1408 params = dict(request.url.params)
1409 assert params == {"foo": "2"}
1410
    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        """A list in a multipart body is encoded as one repeated `name[]` part per element.

        The boundary is pinned through the Content-Type header so the encoded body can be
        compared byte-for-byte against the expected wire format.
        """
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="post",
                url="/foo",
                # Fixed boundary -> deterministic body for the exact comparison below.
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        # Expected: one `array[]` part per list element, then the file part (field name
        # "foo.txt", filename "upload", octet-stream), then the closing boundary.
        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]
1439
1440 @pytest.mark.respx(base_url=base_url)
1441 async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
1442 class Model1(BaseModel):
1443 name: str
1444
1445 class Model2(BaseModel):
1446 foo: str
1447
1448 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1449
1450 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1451 assert isinstance(response, Model2)
1452 assert response.foo == "bar"
1453
1454 @pytest.mark.respx(base_url=base_url)
1455 async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
1456 """Union of objects with the same field name using a different type"""
1457
1458 class Model1(BaseModel):
1459 foo: int
1460
1461 class Model2(BaseModel):
1462 foo: str
1463
1464 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1465
1466 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1467 assert isinstance(response, Model2)
1468 assert response.foo == "bar"
1469
1470 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1471
1472 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1473 assert isinstance(response, Model1)
1474 assert response.foo == 1
1475
1476 @pytest.mark.respx(base_url=base_url)
1477 async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
1478 """
1479 Response that sets Content-Type to something other than application/json but returns json data
1480 """
1481
1482 class Model(BaseModel):
1483 foo: int
1484
1485 respx_mock.get("/foo").mock(
1486 return_value=httpx.Response(
1487 200,
1488 content=json.dumps({"foo": 2}),
1489 headers={"Content-Type": "application/text"},
1490 )
1491 )
1492
1493 response = await self.client.get("/foo", cast_to=Model)
1494 assert isinstance(response, Model)
1495 assert response.foo == 2
1496
1497 def test_base_url_setter(self) -> None:
1498 client = AsyncOpenAI(
1499 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1500 )
1501 assert client.base_url == "https://example.com/from_init/"
1502
1503 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1504
1505 assert client.base_url == "https://example.com/from_setter/"
1506
1507 def test_base_url_env(self) -> None:
1508 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1509 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1510 assert client.base_url == "http://localhost:5000/from/env/"
1511
1512 @pytest.mark.parametrize(
1513 "client",
1514 [
1515 AsyncOpenAI(
1516 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1517 ),
1518 AsyncOpenAI(
1519 base_url="http://localhost:5000/custom/path/",
1520 api_key=api_key,
1521 _strict_response_validation=True,
1522 http_client=httpx.AsyncClient(),
1523 ),
1524 ],
1525 ids=["standard", "custom http client"],
1526 )
1527 def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1528 request = client._build_request(
1529 FinalRequestOptions(
1530 method="post",
1531 url="/foo",
1532 json_data={"foo": "bar"},
1533 ),
1534 )
1535 assert request.url == "http://localhost:5000/custom/path/foo"
1536
1537 @pytest.mark.parametrize(
1538 "client",
1539 [
1540 AsyncOpenAI(
1541 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1542 ),
1543 AsyncOpenAI(
1544 base_url="http://localhost:5000/custom/path/",
1545 api_key=api_key,
1546 _strict_response_validation=True,
1547 http_client=httpx.AsyncClient(),
1548 ),
1549 ],
1550 ids=["standard", "custom http client"],
1551 )
1552 def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1553 request = client._build_request(
1554 FinalRequestOptions(
1555 method="post",
1556 url="/foo",
1557 json_data={"foo": "bar"},
1558 ),
1559 )
1560 assert request.url == "http://localhost:5000/custom/path/foo"
1561
1562 @pytest.mark.parametrize(
1563 "client",
1564 [
1565 AsyncOpenAI(
1566 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1567 ),
1568 AsyncOpenAI(
1569 base_url="http://localhost:5000/custom/path/",
1570 api_key=api_key,
1571 _strict_response_validation=True,
1572 http_client=httpx.AsyncClient(),
1573 ),
1574 ],
1575 ids=["standard", "custom http client"],
1576 )
1577 def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1578 request = client._build_request(
1579 FinalRequestOptions(
1580 method="post",
1581 url="https://myapi.com/foo",
1582 json_data={"foo": "bar"},
1583 ),
1584 )
1585 assert request.url == "https://myapi.com/foo"
1586
1587 async def test_copied_client_does_not_close_http(self) -> None:
1588 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1589 assert not client.is_closed()
1590
1591 copied = client.copy()
1592 assert copied is not client
1593
1594 del copied
1595
1596 await asyncio.sleep(0.2)
1597 assert not client.is_closed()
1598
1599 async def test_client_context_manager(self) -> None:
1600 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1601 async with client as c2:
1602 assert c2 is client
1603 assert not c2.is_closed()
1604 assert not client.is_closed()
1605 assert client.is_closed()
1606
1607 @pytest.mark.respx(base_url=base_url)
1608 @pytest.mark.asyncio
1609 async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
1610 class Model(BaseModel):
1611 foo: str
1612
1613 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1614
1615 with pytest.raises(APIResponseValidationError) as exc:
1616 await self.client.get("/foo", cast_to=Model)
1617
1618 assert isinstance(exc.value.__cause__, ValidationError)
1619
1620 async def test_client_max_retries_validation(self) -> None:
1621 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1622 AsyncOpenAI(
1623 base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
1624 )
1625
1626 @pytest.mark.respx(base_url=base_url)
1627 @pytest.mark.asyncio
1628 async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
1629 class Model(BaseModel):
1630 name: str
1631
1632 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1633
1634 stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1635 assert isinstance(stream, AsyncStream)
1636 await stream.response.aclose()
1637
1638 @pytest.mark.respx(base_url=base_url)
1639 @pytest.mark.asyncio
1640 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1641 class Model(BaseModel):
1642 name: str
1643
1644 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1645
1646 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1647
1648 with pytest.raises(APIResponseValidationError):
1649 await strict_client.get("/foo", cast_to=Model)
1650
1651 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1652
1653 response = await client.get("/foo", cast_to=Model)
1654 assert isinstance(response, str) # type: ignore[unreachable]
1655
    # Expected retry delays for a given `retry-after` header. The rows show that delta-seconds
    # values in (0, 60] are used as-is, while 0, negative, > 60 and unparseable values fall back
    # to the computed backoff (0.5s with 3 retries remaining).
    # HTTP-date rows are relative to the mocked `time.time` below, which the rows imply is
    # Fri, 29 Sep 2023 16:26:37 GMT (e.g. 16:26:57 -> 20s, 16:27:37 -> 60s, 16:27:38 -> fallback).
    # Rows with an empty header show the backoff growing as fewer retries remain, with the
    # -1100 row expecting 8 (presumably a cap that also guards against overflow).
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    @pytest.mark.asyncio
    async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """`_calculate_retry_timeout` honours valid `retry-after` values and otherwise falls back to backoff."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        # max_retries is fixed at 3; the rows vary only `remaining_retries` and the header.
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # The second positional argument to `approx` is a relative tolerance (~44%); presumably
        # wide enough to absorb random jitter in the backoff -- TODO confirm.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1686
1687 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1688 @pytest.mark.respx(base_url=base_url)
1689 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1690 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1691
1692 with pytest.raises(APITimeoutError):
1693 await async_client.chat.completions.with_streaming_response.create(
1694 messages=[
1695 {
1696 "content": "string",
1697 "role": "developer",
1698 }
1699 ],
1700 model="gpt-4o",
1701 ).__aenter__()
1702
1703 assert _get_open_connections(self.client) == 0
1704
1705 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1706 @pytest.mark.respx(base_url=base_url)
1707 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1708 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1709
1710 with pytest.raises(APIStatusError):
1711 await async_client.chat.completions.with_streaming_response.create(
1712 messages=[
1713 {
1714 "content": "string",
1715 "role": "developer",
1716 }
1717 ],
1718 model="gpt-4o",
1719 ).__aenter__()
1720 assert _get_open_connections(self.client) == 0
1721
1722 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1723 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1724 @pytest.mark.respx(base_url=base_url)
1725 @pytest.mark.asyncio
1726 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
1727 async def test_retries_taken(
1728 self,
1729 async_client: AsyncOpenAI,
1730 failures_before_success: int,
1731 failure_mode: Literal["status", "exception"],
1732 respx_mock: MockRouter,
1733 ) -> None:
1734 client = async_client.with_options(max_retries=4)
1735
1736 nb_retries = 0
1737
1738 def retry_handler(_request: httpx.Request) -> httpx.Response:
1739 nonlocal nb_retries
1740 if nb_retries < failures_before_success:
1741 nb_retries += 1
1742 if failure_mode == "exception":
1743 raise RuntimeError("oops")
1744 return httpx.Response(500)
1745 return httpx.Response(200)
1746
1747 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1748
1749 response = await client.chat.completions.with_raw_response.create(
1750 messages=[
1751 {
1752 "content": "string",
1753 "role": "developer",
1754 }
1755 ],
1756 model="gpt-4o",
1757 )
1758
1759 assert response.retries_taken == failures_before_success
1760 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1761
1762 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1763 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1764 @pytest.mark.respx(base_url=base_url)
1765 @pytest.mark.asyncio
1766 async def test_omit_retry_count_header(
1767 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1768 ) -> None:
1769 client = async_client.with_options(max_retries=4)
1770
1771 nb_retries = 0
1772
1773 def retry_handler(_request: httpx.Request) -> httpx.Response:
1774 nonlocal nb_retries
1775 if nb_retries < failures_before_success:
1776 nb_retries += 1
1777 return httpx.Response(500)
1778 return httpx.Response(200)
1779
1780 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1781
1782 response = await client.chat.completions.with_raw_response.create(
1783 messages=[
1784 {
1785 "content": "string",
1786 "role": "developer",
1787 }
1788 ],
1789 model="gpt-4o",
1790 extra_headers={"x-stainless-retry-count": Omit()},
1791 )
1792
1793 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
1794
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_overwrite_retry_count_header(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """A caller-supplied ``x-stainless-retry-count`` value wins over the client's own count.

    The endpoint fails with HTTP 500 ``failures_before_success`` times first. The
    final request must still carry the explicit value ``"42"``.
    """
    remaining_failures = failures_before_success

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        nonlocal remaining_failures
        if remaining_failures <= 0:
            return httpx.Response(200)
        remaining_failures -= 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    retrying_client = async_client.with_options(max_retries=4)
    response = await retrying_client.chat.completions.with_raw_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": "42"},
    )

    sent_value = response.http_request.headers.get("x-stainless-retry-count")
    assert sent_value == "42"
1827
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken_new_response_class(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """The streaming response wrapper reports how many retries were performed.

    Two values must equal the number of HTTP 500 responses served before the 200:
    ``retries_taken`` and the ``x-stainless-retry-count`` header sent on the final attempt.
    """
    remaining_failures = failures_before_success

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        nonlocal remaining_failures
        if remaining_failures > 0:
            remaining_failures -= 1
            return httpx.Response(500)
        return httpx.Response(200)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    retrying_client = async_client.with_options(max_retries=4)
    async with retrying_client.chat.completions.with_streaming_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-4o",
    ) as response:
        sent_retry_count = response.http_request.headers.get("x-stainless-retry-count")
        assert response.retries_taken == failures_before_success
        assert int(sent_retry_count) == failures_before_success
1859
def test_get_platform(self) -> None:
    """Run ``asyncify(get_platform)`` under nest_asyncio in a child interpreter; require a clean exit.

    A previous implementation of asyncify could leave threads unterminated when
    used with nest_asyncio. ``nest_asyncio.apply()`` is global and cannot be
    un-applied, so the check runs in a separate process to avoid affecting other tests.

    Raises:
        AssertionError: if the child exits non-zero, or is still running after the timeout.
    """
    test_code = dedent("""
    import asyncio
    import nest_asyncio
    import threading

    from openai._utils import asyncify
    from openai._base_client import get_platform

    async def test_main() -> None:
        result = await asyncify(get_platform)()
        print(result)
        for thread in threading.enumerate():
            print(thread.name)

    nest_asyncio.apply()
    asyncio.run(test_main())
    """)
    timeout = 10  # seconds

    with subprocess.Popen(
        [sys.executable, "-c", test_code],
        text=True,
    ) as process:
        try:
            # Popen.wait blocks until the child exits or the timeout expires,
            # replacing the hand-rolled poll()/sleep() loop.
            return_code = process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # A hung child is the regression being guarded against; kill it so the
            # context manager's implicit wait() returns promptly.
            process.kill()
            raise AssertionError("calling get_platform using asyncify resulted in a hung process") from None

    if return_code != 0:
        raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
1904
async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """Setting ``HTTPS_PROXY`` yields exactly one ``https://`` mount on the default async client."""
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

    client = DefaultAsyncHttpxClient()
    try:
        mounts = tuple(client._mounts.items())
        assert len(mounts) == 1
        assert mounts[0][0].pattern == "https://"
    finally:
        # Release the client's transports even when an assertion above fails.
        await client.aclose()
1914
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
async def test_default_client_creation(self) -> None:
    """Constructing ``DefaultAsyncHttpxClient`` with explicit transport options must not raise."""
    client = DefaultAsyncHttpxClient(
        verify=True,
        cert=None,
        trust_env=True,
        http1=True,
        http2=False,
        limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
    )
    # Close the client so its transport/connection pool is not leaked past the test.
    await client.aclose()
1926
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
    """With the default ``follow_redirects=True``, a 302 is followed to its ``Location``."""
    redirect_response = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    final_response = httpx.Response(200, json={"status": "ok"})
    respx_mock.post("/redirect").mock(return_value=redirect_response)
    respx_mock.get("/redirected").mock(return_value=final_response)

    result = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    assert result.status_code == 200
    assert result.json() == {"status": "ok"}
1938
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
    """With ``follow_redirects=False``, the 302 is not followed and surfaces as ``APIStatusError``."""
    redirect_target = f"{base_url}/redirected"
    respx_mock.post("/redirect").mock(return_value=httpx.Response(302, headers={"Location": redirect_target}))

    with pytest.raises(APIStatusError) as exc_info:
        await self.client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    # The error carries the raw redirect response rather than the redirect target's.
    raw_response = exc_info.value.response
    assert raw_response.status_code == 302
    assert raw_response.headers["Location"] == redirect_target
1953
@pytest.mark.asyncio
async def test_api_key_before_after_refresh_provider(self) -> None:
    """A callable ``api_key`` stays unresolved until ``_refresh_api_key`` is awaited."""

    async def provider() -> str:
        return "test_bearer_token"

    lazy_client = AsyncOpenAI(base_url=base_url, api_key=provider)

    # Before any refresh: no key and no Authorization header.
    assert lazy_client.api_key == ""
    assert "Authorization" not in lazy_client.auth_headers

    await lazy_client._refresh_api_key()

    # After the refresh: the provider's value is used as a bearer token.
    assert lazy_client.api_key == "test_bearer_token"
    assert lazy_client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
1968
@pytest.mark.asyncio
async def test_api_key_before_after_refresh_str(self) -> None:
    """A plain string ``api_key`` gives the same Authorization header before and after a refresh."""
    static_client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")

    header_before = static_client.auth_headers.get("Authorization")
    assert header_before == "Bearer test_api_key"

    await static_client._refresh_api_key()

    header_after = static_client.auth_headers.get("Authorization")
    assert header_after == "Bearer test_api_key"
1977
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.asyncio
@pytest.mark.respx()
async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
    """The api-key provider is consulted again on a retry, so each attempt uses a fresh token.

    The first request gets HTTP 500 and is retried. The provider returns "first"
    and then "second", and each attempt must send the matching bearer token.
    """
    respx_mock.post(base_url + "/chat/completions").mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    counter = 0

    async def token_provider() -> str:
        nonlocal counter

        counter += 1

        if counter == 1:
            return "first"

        return "second"

    # The retry backoff is patched to a tiny value (as in the other retry tests) so
    # the single retry does not sleep for the real delay. The context manager
    # closes the client's HTTP resources when the test ends.
    async with AsyncOpenAI(base_url=base_url, api_key=token_provider) as client:
        await client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(calls) == 2

    assert calls[0].request.headers.get("Authorization") == "Bearer first"
    assert calls[1].request.headers.get("Authorization") == "Bearer second"
2008
@pytest.mark.asyncio
async def test_copy_auth(self) -> None:
    """``copy(api_key=...)`` replaces the original token provider with the new one."""

    async def original_provider() -> str:
        return "test_bearer_token_1"

    async def replacement_provider() -> str:
        return "test_bearer_token_2"

    source_client = AsyncOpenAI(base_url=base_url, api_key=original_provider)
    copied_client = source_client.copy(api_key=replacement_provider)

    await copied_client._refresh_api_key()

    assert copied_client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
2020