openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.108.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1978 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import asyncio
10import inspect
11import tracemalloc
12from typing import Any, Union, Protocol, cast
13from unittest import mock
14from typing_extensions import Literal
15
16import httpx
17import pytest
18from respx import MockRouter
19from pydantic import ValidationError
20
21from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
22from openai._types import Omit
23from openai._utils import asyncify
24from openai._models import BaseModel, FinalRequestOptions
25from openai._streaming import Stream, AsyncStream
26from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
27from openai._base_client import (
28 DEFAULT_TIMEOUT,
29 HTTPX_DEFAULT_TIMEOUT,
30 BaseClient,
31 OtherPlatform,
32 DefaultHttpxClient,
33 DefaultAsyncHttpxClient,
34 get_platform,
35 make_request_options,
36)
37
38from .utils import update_env
39
40base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
41api_key = "My API Key"
42
43
class MockRequestCall(Protocol):
    """Structural type for any object that exposes an `httpx.Request` as `request`."""

    request: httpx.Request
46
47
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters found on a bare `GET /foo` request built by `client`."""
    probe = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(probe)
    return dict(httpx.URL(built.url).params)
52
53
54def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
55 return 0.1
56
57
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many requests the connection pool behind `client`'s transport is tracking.

    Reads private httpx attributes (`_transport`, `_pool`, `_requests`), so the
    transport must be one of httpx's own HTTP transports.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
64
65
66class TestOpenAI:
67 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
68
@pytest.mark.respx(base_url=base_url)
def test_raw_response(self, respx_mock: MockRouter) -> None:
    """`cast_to=httpx.Response` returns an `httpx.Response` carrying the mocked JSON body."""
    payload = {"foo": "bar"}
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json=payload))

    raw = self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == payload
77
@pytest.mark.respx(base_url=base_url)
def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
    """A JSON body served as `application/binary` is still readable through `.json()`."""
    binary_reply = httpx.Response(
        200,
        headers={"Content-Type": "application/binary"},
        content='{"foo": "bar"}',
    )
    respx_mock.post("/foo").mock(return_value=binary_reply)

    raw = self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
88
def test_copy(self) -> None:
    """`copy()` returns a distinct client and applies overrides without touching the source."""
    clone = self.client.copy()
    assert id(clone) != id(self.client)

    replacement_key = "another My API Key"
    clone = self.client.copy(api_key=replacement_key)
    assert clone.api_key == replacement_key
    assert self.client.api_key == "My API Key"
96
def test_copy_default_options(self) -> None:
    """Options that carry a default are overridden on the copy only."""
    # max_retries: each copy gets its own value, the parent keeps its own
    first = self.client.copy(max_retries=7)
    assert first.max_retries == 7
    assert self.client.max_retries == 2

    second = first.copy(max_retries=6)
    assert second.max_retries == 6
    assert first.max_retries == 7

    # timeout: `None` is accepted as an explicit override
    assert isinstance(self.client.timeout, httpx.Timeout)
    without_timeout = self.client.copy(timeout=None)
    assert without_timeout.timeout is None
    assert isinstance(self.client.timeout, httpx.Timeout)
112
def test_copy_default_headers(self) -> None:
    """`default_headers` merges into the copy, `set_default_headers` replaces, both together raise."""
    source = OpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    )
    assert source.default_headers["X-Foo"] == "bar"

    # a copy without arguments keeps the configured header
    untouched = source.copy()
    assert untouched.default_headers["X-Foo"] == "bar"

    # new headers are merged with the ones already given
    merged = source.copy(default_headers={"X-Bar": "stainless"})
    assert merged.default_headers["X-Foo"] == "bar"
    assert merged.default_headers["X-Bar"] == "stainless"

    # a clashing header takes the new value
    overridden = source.copy(default_headers={"X-Foo": "stainless"})
    assert overridden.default_headers["X-Foo"] == "stainless"

    # set_default_headers: completely overrides already set values
    cleared = source.copy(set_default_headers={})
    assert cleared.default_headers.get("X-Foo") is None

    replaced = source.copy(set_default_headers={"X-Bar": "Robert"})
    assert replaced.default_headers["X-Bar"] == "Robert"

    # the two arguments cannot be combined
    expected_error = "`default_headers` and `set_default_headers` arguments are mutually exclusive"
    with pytest.raises(ValueError, match=expected_error):
        source.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
146
def test_copy_default_query(self) -> None:
    """`default_query` merges into the copy, `set_default_query` replaces, both together raise."""
    source = OpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_query={"foo": "bar"},
    )
    assert _get_params(source)["foo"] == "bar"

    # a copy without arguments keeps the configured param
    untouched = source.copy()
    assert _get_params(untouched)["foo"] == "bar"

    # new params are merged with the ones already given
    merged = source.copy(default_query={"bar": "stainless"})
    merged_params = _get_params(merged)
    assert merged_params["foo"] == "bar"
    assert merged_params["bar"] == "stainless"

    # a clashing param takes the new value
    overridden = source.copy(default_query={"foo": "stainless"})
    assert _get_params(overridden)["foo"] == "stainless"

    # set_default_query: completely overrides already set values
    cleared = source.copy(set_default_query={})
    assert _get_params(cleared) == {}

    replaced = source.copy(set_default_query={"bar": "Robert"})
    assert _get_params(replaced)["bar"] == "Robert"

    # the two arguments cannot be combined
    expected_error = "`default_query` and `set_default_query` arguments are mutually exclusive"
    with pytest.raises(ValueError, match=expected_error):
        source.copy(set_default_query={}, default_query={"foo": "Bar"})
182
def test_copy_signature(self) -> None:
    """Every `__init__` parameter, apart from a few exclusions, must also be accepted by `copy()`."""
    # mypy doesn't like that we access the `__init__` property.
    init_params = inspect.signature(self.client.__init__).parameters  # type: ignore[misc]
    copy_params = inspect.signature(self.client.copy).parameters
    skipped = {"transport", "proxies", "_strict_response_validation"}

    for name in init_params:
        if name not in skipped:
            copy_param = copy_params.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"
198
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
def test_copy_build_request(self) -> None:
    """Repeatedly copying the client and building a request must not leave per-iteration allocations.

    Two `tracemalloc` snapshots are taken around a fixed number of iterations; any
    allocation whose count is a non-zero multiple of the iteration count, and whose
    traceback touches none of the tolerated files, is reported as a leak.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)

    snapshot_before = tracemalloc.take_snapshot()

    ITERATIONS = 10
    for _ in range(ITERATIONS):
        build_request(options)

    gc.collect()
    snapshot_after = tracemalloc.take_snapshot()

    tracemalloc.stop()

    # Files whose allocations are knowingly tolerated, matched against the end of the frame filename.
    tolerated_files = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def is_leak(diff: tracemalloc.StatisticDiff) -> bool:
        # Only allocations that persist and that appear once per iteration are
        # considered; everything else is treated as a false positive.
        if diff.count == 0 or diff.count % ITERATIONS != 0:
            return False
        return not any(frame.filename.endswith(tolerated_files) for frame in diff.traceback)

    leaks: list[tracemalloc.StatisticDiff] = [
        diff for diff in snapshot_after.compare_to(snapshot_before, "traceback") if is_leak(diff)
    ]
    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError()
261
def test_request_timeout(self) -> None:
    """Built requests carry `DEFAULT_TIMEOUT` unless the request options give their own timeout."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    default_timeout = httpx.Timeout(**default_request.extensions["timeout"])  # type: ignore
    assert default_timeout == DEFAULT_TIMEOUT

    custom_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    custom_request = self.client._build_request(custom_options)
    custom_timeout = httpx.Timeout(**custom_request.extensions["timeout"])  # type: ignore
    assert custom_timeout == httpx.Timeout(100.0)
272
def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor ends up on the requests it builds."""
    zero_timeout_client = OpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert applied == httpx.Timeout(0)
279
def test_http_client_timeout_option(self) -> None:
    """Check which timeout is applied for differently configured custom `httpx.Client`s."""

    def applied_timeout(http_client: httpx.Client) -> httpx.Timeout:
        # Wrap the given httpx client and report the timeout found on a built request.
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    with httpx.Client(timeout=None) as http_client:
        assert applied_timeout(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    with httpx.Client() as http_client:
        assert applied_timeout(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert applied_timeout(http_client) == DEFAULT_TIMEOUT  # our default
310
async def test_invalid_http_client(self) -> None:
    """Passing an `httpx.AsyncClient` to the synchronous client raises `TypeError`."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"):
        async with httpx.AsyncClient() as async_http_client:
            # cast through Any so the deliberately wrong argument passes type checking
            wrong_kind = cast(Any, async_http_client)
            OpenAI(
                base_url=base_url,
                api_key=api_key,
                _strict_response_validation=True,
                http_client=wrong_kind,
            )
320
def test_default_headers_option(self) -> None:
    """`default_headers` are sent on built requests and can override `x-stainless-lang`."""
    plain = OpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    plain_headers = plain._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert plain_headers.get("x-foo") == "bar"
    assert plain_headers.get("x-stainless-lang") == "python"

    custom_defaults = {
        "X-Foo": "stainless",
        "X-Stainless-Lang": "my-overriding-header",
    }
    overriding = OpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers=custom_defaults,
    )
    overriding_headers = overriding._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert overriding_headers.get("x-foo") == "stainless"
    assert overriding_headers.get("x-stainless-lang") == "my-overriding-header"
341
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token; constructing a client without any key raises."""
    authenticated = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    prepared = authenticated._prepare_options(FinalRequestOptions(method="get", url="/foo"))
    built = authenticated._build_request(prepared)

    assert built.headers.get("Authorization") == f"Bearer {api_key}"

    with pytest.raises(OpenAIError):
        # OPENAI_API_KEY is overridden with `Omit()` while the keyless client is constructed
        with update_env(**{"OPENAI_API_KEY": Omit()}):
            keyless = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
        _ = keyless
353
def test_default_query_option(self) -> None:
    """`default_query` is added to request URLs, and per-request params win on a key clash."""
    client = OpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    bare = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(bare.url).params) == {"query_param": "bar"}

    request_params = {"foo": "baz", "query_param": "overridden"}
    with_params = client._build_request(
        FinalRequestOptions(method="get", url="/foo", params=request_params),
    )
    assert dict(httpx.URL(with_params.url).params) == {"foo": "baz", "query_param": "overridden"}
371
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on a key clash."""

    def sent_body(options: FinalRequestOptions) -> Any:
        # Build the request and decode the JSON document it would send.
        request = self.client._build_request(options)
        return json.loads(request.content.decode("utf-8"))

    both = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
        extra_json={"baz": False},
    )
    assert sent_body(both) == {"foo": "bar", "baz": False}

    extra_only = FinalRequestOptions(
        method="post",
        url="/foo",
        extra_json={"baz": False},
    )
    assert sent_body(extra_only) == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    clashing = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar", "baz": True},
        extra_json={"baz": None},
    )
    assert sent_body(clashing) == {"foo": "bar", "baz": None}
405
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent on the request and win over `default_headers` on a key clash."""
    simple_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    simple = self.client._build_request(simple_options)
    assert simple.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_default = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Bar": "false"}),
    )
    clashing = client_with_default._build_request(clashing_options)
    assert clashing.headers.get("X-Bar") == "false"
427
def test_request_extra_query(self) -> None:
    """`extra_query` is merged with `query` and wins on a key clash."""

    def params_for(request_options: Any) -> dict[str, str]:
        # Build a POST /foo request from the given request options and return its URL params.
        request = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **request_options),
        )
        return dict(request.url.params)

    only_extra = make_request_options(extra_query={"my_query_param": "Foo"})
    assert params_for(only_extra) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    disjoint = make_request_options(query={"bar": "1"}, extra_query={"foo": "2"})
    assert params_for(disjoint) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    clashing = make_request_options(query={"foo": "1"}, extra_query={"foo": "2"})
    assert params_for(clashing) == {"foo": "2"}
468
def test_multipart_repeating_array(self, client: OpenAI) -> None:
    """In a multipart body each array item becomes its own `array[]` part, followed by the file part."""
    delimiter = b"--6b7ba517decee4a450543ea6ae821c82"
    options = FinalRequestOptions.construct(
        method="post",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    request = client._build_request(options)

    array_part_header = b'Content-Disposition: form-data; name="array[]"'
    expected_lines = [
        delimiter,
        array_part_header,
        b"",
        b"foo",
        delimiter,
        array_part_header,
        b"",
        b"bar",
        delimiter,
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        delimiter + b"--",
        b"",
    ]

    assert request.read().split(b"\r\n") == expected_lines
497
@pytest.mark.respx(base_url=base_url)
def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """With a union `cast_to`, the member whose fields match the payload is returned."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    either_model = cast(Any, Union[Model1, Model2])
    parsed = self.client.get("/foo", cast_to=either_model)

    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
511
@pytest.mark.respx(base_url=base_url)
def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union of objects with the same field name using a different type"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    either_model = cast(Any, Union[Model1, Model2])

    # a string value resolves to the `str` variant
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    text_result = self.client.get("/foo", cast_to=either_model)
    assert isinstance(text_result, Model2)
    assert text_result.foo == "bar"

    # an integer value resolves to the `int` variant
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    number_result = self.client.get("/foo", cast_to=either_model)
    assert isinstance(number_result, Model1)
    assert number_result.foo == 1
533
@pytest.mark.respx(base_url=base_url)
def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    Response that sets Content-Type to something other than application/json but returns json data
    """

    class Model(BaseModel):
        foo: int

    mislabelled_reply = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled_reply)

    parsed = self.client.get("/foo", cast_to=Model)

    assert isinstance(parsed, Model)
    assert parsed.foo == 2
554
def test_base_url_setter(self) -> None:
    """`base_url` compares equal to the given URL plus a trailing slash, via `__init__` or the setter."""
    client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
    assert client.base_url == "https://example.com/from_init/"

    replacement = "https://example.com/from_setter"
    client.base_url = replacement  # type: ignore[assignment]

    assert client.base_url == "https://example.com/from_setter/"
562
def test_base_url_env(self) -> None:
    """Without an explicit `base_url`, the client takes it from `OPENAI_BASE_URL`."""
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        env_client = OpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == "http://localhost:5000/from/env/"
567
@pytest.mark.parametrize(
    "client",
    [
        OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
        OpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: OpenAI) -> None:
    """A base URL ending in `/` joins with `/foo` without doubling the slash."""
    options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
590
@pytest.mark.parametrize(
    "client",
    [
        # The base URLs deliberately omit the trailing slash: that is the case under test.
        OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
        OpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
    """A base URL without a trailing `/` still joins with `/foo` into a single well-formed path.

    The parametrized clients previously used the same slash-terminated base URL
    as `test_base_url_trailing_slash`, so this scenario was never exercised.
    """
    options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    request = client._build_request(options)
    assert request.url == "http://localhost:5000/custom/path/foo"
613
@pytest.mark.parametrize(
    "client",
    [
        OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
        OpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: OpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined onto the base URL."""
    options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "https://myapi.com/foo"
636
def test_copied_client_does_not_close_http(self) -> None:
    """Dropping a copy must leave the original client open."""
    original = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    clone = original.copy()
    assert clone is not original

    del clone

    assert not original.is_closed()
647
def test_client_context_manager(self) -> None:
    """Entering the client yields the client itself; leaving the `with` block closes it."""
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()
    assert client.is_closed()
655
@pytest.mark.respx(base_url=base_url)
def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit the model raises `APIResponseValidationError` caused by pydantic."""

    class Model(BaseModel):
        foo: str

    malformed = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=malformed))

    with pytest.raises(APIResponseValidationError) as raised:
        self.client.get("/foo", cast_to=Model)

    assert isinstance(raised.value.__cause__, ValidationError)
667
def test_client_max_retries_validation(self) -> None:
    """`max_retries=None` is rejected with a `TypeError` at construction time."""
    invalid_retries = cast(Any, None)
    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=invalid_retries)
671
@pytest.mark.respx(base_url=base_url)
def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """Requesting `stream=True` with `stream_cls=Stream[Model]` returns a `Stream` instance."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    streamed = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
    assert isinstance(streamed, Stream)
    # close the underlying response explicitly; the stream itself is never consumed here
    streamed.response.close()
682
@pytest.mark.respx(base_url=base_url)
def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A text body where a model is expected raises under strict validation, else comes back as `str`."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # strict validation: the mismatch is an error
    strict = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        strict.get("/foo", cast_to=Model)

    # lenient validation: the raw text is handed back
    lenient = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    result = lenient.get("/foo", cast_to=Model)
    assert isinstance(result, str)  # type: ignore[unreachable]
699
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
        [-1100, "", 8],  # test large number potentially overflowing
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay computed for each `retry-after` header value in the table.

    `time.time` is pinned to a fixed value so the HTTP-date rows are deterministic.
    The comparison is approximate, using the tolerance passed to `pytest.approx`.
    """
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    response_headers = httpx.Headers({"retry-after": retry_after})
    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)

    delay = client._calculate_retry_timeout(remaining_retries, request_options, response_headers)

    tolerance = 0.5 * 0.875
    assert delay == pytest.approx(timeout, tolerance)  # pyright: ignore[reportUnknownMemberType]
729
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Retried requests that all time out must not leave requests open in the connection pool.

    Every attempt against `/chat/completions` raises `httpx.TimeoutException`, so
    the streaming call is expected to end in `APITimeoutError`. The retry delay is
    patched to a small constant to keep the test fast.
    """
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    with pytest.raises(APITimeoutError):
        client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ).__enter__()

    # Inspect the client that actually issued the requests. The check previously
    # looked at `self.client`, which this test never uses.
    assert _get_open_connections(client) == 0
747
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Retried requests that all return HTTP 500 must not leave requests open in the connection pool.

    Every attempt against `/chat/completions` answers with status 500, so the
    streaming call is expected to end in `APIStatusError`. The retry delay is
    patched to a small constant to keep the test fast.
    """
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    with pytest.raises(APIStatusError):
        client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ).__enter__()

    # Inspect the client that actually issued the requests. The check previously
    # looked at `self.client`, which this test never uses.
    assert _get_open_connections(client) == 0
764
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
def test_retries_taken(
    self,
    client: OpenAI,
    failures_before_success: int,
    failure_mode: Literal["status", "exception"],
    respx_mock: MockRouter,
) -> None:
    """`retries_taken` and the `x-stainless-retry-count` header both equal the number of failed attempts."""
    client = client.with_options(max_retries=4)

    failures_seen = 0

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        # Fail the configured number of times (by status or by exception), then succeed.
        nonlocal failures_seen
        if failures_seen >= failures_before_success:
            return httpx.Response(200)
        failures_seen += 1
        if failure_mode == "exception":
            raise RuntimeError("oops")
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    response = client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
    )

    assert response.retries_taken == failures_before_success
    sent_count = response.http_request.headers.get("x-stainless-retry-count")
    assert int(sent_count) == failures_before_success
803
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_omit_retry_count_header(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """Passing `Omit()` for `x-stainless-retry-count` keeps the header off the final request."""
    client = client.with_options(max_retries=4)

    failures_seen = 0

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        # Answer 500 the configured number of times, then 200.
        nonlocal failures_seen
        if failures_seen >= failures_before_success:
            return httpx.Response(200)
        failures_seen += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    response = client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": Omit()},
    )

    retry_count_values = response.http_request.headers.get_list("x-stainless-retry-count")
    assert len(retry_count_values) == 0
835
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_overwrite_retry_count_header(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """A caller-supplied `x-stainless-retry-count` value is sent unchanged, whatever the retry count."""
    client = client.with_options(max_retries=4)

    failures_seen = 0

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        # Answer 500 the configured number of times, then 200.
        nonlocal failures_seen
        if failures_seen >= failures_before_success:
            return httpx.Response(200)
        failures_seen += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    response = client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": "42"},
    )

    sent_headers = response.http_request.headers
    assert sent_headers.get("x-stainless-retry-count") == "42"
867
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retries_taken_new_response_class(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """The streaming-response wrapper reports the same retry count as the request header."""
    client = client.with_options(max_retries=4)

    failures_seen = 0

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        # Answer 500 the configured number of times, then 200.
        nonlocal failures_seen
        if failures_seen >= failures_before_success:
            return httpx.Response(200)
        failures_seen += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    streaming_call = client.chat.completions.with_streaming_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
    )
    with streaming_call as response:
        assert response.retries_taken == failures_before_success
        sent_count = response.http_request.headers.get("x-stainless-retry-count")
        assert int(sent_count) == failures_before_success
898
def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """`DefaultHttpxClient` mounts exactly one proxy, for `https://`, when only `HTTPS_PROXY` is set.

    Other proxy variables present in the ambient environment would add mounts
    and break the count assertion, so they are removed first. `monkeypatch`
    restores the environment after the test.
    """
    # Delete in case our environment has any other proxy env vars set. This
    # happens before `HTTPS_PROXY` is set, and `https_proxy` is left out on
    # purpose: on platforms with case-insensitive variable names, deleting it
    # could remove the very value under test.
    for name in ("HTTP_PROXY", "ALL_PROXY", "NO_PROXY", "http_proxy", "all_proxy", "no_proxy"):
        monkeypatch.delenv(name, raising=False)

    # Test that the proxy environment variables are set correctly
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

    client = DefaultHttpxClient()

    mounts = tuple(client._mounts.items())
    assert len(mounts) == 1
    assert mounts[0][0].pattern == "https://"
908
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
def test_default_client_creation(self) -> None:
    """`DefaultHttpxClient` accepts the usual httpx client keyword arguments without raising."""
    pool_limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)

    # Ensure that the client can be initialized without any exceptions
    DefaultHttpxClient(
        verify=True,
        cert=None,
        trust_env=True,
        http1=True,
        http2=False,
        limits=pool_limits,
    )
920
@pytest.mark.respx(base_url=base_url)
def test_follow_redirects(self, respx_mock: MockRouter) -> None:
    """With default options, a 302 from `/redirect` is followed to `/redirected`."""
    # Test that the default follow_redirects=True allows following redirects
    redirect_reply = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    respx_mock.post("/redirect").mock(return_value=redirect_reply)
    respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

    final = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    assert final.status_code == 200
    assert final.json() == {"status": "ok"}
932
@pytest.mark.respx(base_url=base_url)
def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
    """With `follow_redirects=False`, the 302 surfaces as an `APIStatusError`."""
    target = f"{base_url}/redirected"
    respx_mock.post("/redirect").mock(return_value=httpx.Response(302, headers={"Location": target}))

    with pytest.raises(APIStatusError) as exc_info:
        self.client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    raised_response = exc_info.value.response
    assert raised_response.status_code == 302
    assert raised_response.headers["Location"] == f"{base_url}/redirected"
947
def test_api_key_before_after_refresh_provider(self) -> None:
    """A callable `api_key` is only resolved once `_refresh_api_key()` has run."""
    lazy_client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")

    # Before the refresh: no key and no auth header.
    assert lazy_client.api_key == ""
    assert "Authorization" not in lazy_client.auth_headers

    lazy_client._refresh_api_key()

    # After the refresh: the provider's value is used.
    assert lazy_client.api_key == "test_bearer_token"
    assert lazy_client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
958
def test_api_key_before_after_refresh_str(self) -> None:
    """A plain string `api_key` yields the same auth header before and after a refresh."""
    expected_header = "Bearer test_api_key"
    static_client = OpenAI(base_url=base_url, api_key="test_api_key")

    assert static_client.auth_headers.get("Authorization") == expected_header

    static_client._refresh_api_key()

    assert static_client.auth_headers.get("Authorization") == expected_header
966
@pytest.mark.respx()
def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
    """The api_key provider is consulted again for the retried request."""
    respx_mock.post(base_url + "/chat/completions").mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    issued: list[str] = []

    def token_provider() -> str:
        # The first invocation hands out "first"; every later one hands out "second".
        token = "second" if issued else "first"
        issued.append(token)
        return token

    retrying_client = OpenAI(base_url=base_url, api_key=token_provider)
    retrying_client.chat.completions.create(messages=[], model="gpt-4")

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 2

    first_call, second_call = recorded[0], recorded[1]
    assert first_call.request.headers.get("Authorization") == "Bearer first"
    assert second_call.request.headers.get("Authorization") == "Bearer second"
996
def test_copy_auth(self) -> None:
    """An `api_key` provider passed to `copy()` replaces the original provider."""
    original = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1")
    duplicate = original.copy(api_key=lambda: "test_bearer_token_2")

    duplicate._refresh_api_key()

    assert duplicate.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
1003
1004
1005class TestAsyncOpenAI:
1006 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1007
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response(self, respx_mock: MockRouter) -> None:
    """`cast_to=httpx.Response` hands back the raw httpx response object."""
    payload = {"foo": "bar"}
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json=payload))

    raw = await self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
1017
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
    """A raw response with a binary content type is still returned as `httpx.Response`."""
    binary_reply = httpx.Response(
        200,
        headers={"Content-Type": "application/binary"},
        content='{"foo": "bar"}',
    )
    respx_mock.post("/foo").mock(return_value=binary_reply)

    raw = await self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
1029
def test_copy(self) -> None:
    """`copy()` returns a distinct client and applies overrides without touching the source."""
    plain_copy = self.client.copy()
    assert id(plain_copy) != id(self.client)

    rekeyed_copy = self.client.copy(api_key="another My API Key")
    assert rekeyed_copy.api_key == "another My API Key"
    # The original client keeps its own key.
    assert self.client.api_key == "My API Key"
1037
def test_copy_default_options(self) -> None:
    """Options that carry a default (`max_retries`, `timeout`) are overridden per copy."""
    # max_retries: each copy gets its own value and its parent is unaffected
    first_copy = self.client.copy(max_retries=7)
    assert first_copy.max_retries == 7
    assert self.client.max_retries == 2

    second_copy = first_copy.copy(max_retries=6)
    assert second_copy.max_retries == 6
    assert first_copy.max_retries == 7

    # timeout: an explicit None is honoured on the copy only
    assert isinstance(self.client.timeout, httpx.Timeout)
    untimed_copy = self.client.copy(timeout=None)
    assert untimed_copy.timeout is None
    assert isinstance(self.client.timeout, httpx.Timeout)
1053
def test_copy_default_headers(self) -> None:
    """`copy()` merges `default_headers` and replaces them wholesale via `set_default_headers`."""
    source = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    )
    assert source.default_headers["X-Foo"] == "bar"

    # --- default_headers ---

    # an unspecified header keeps the value already given
    inherited = source.copy()
    assert inherited.default_headers["X-Foo"] == "bar"

    # new headers are merged with the existing ones
    merged = source.copy(default_headers={"X-Bar": "stainless"})
    assert merged.default_headers["X-Foo"] == "bar"
    assert merged.default_headers["X-Bar"] == "stainless"

    # a header given again takes the new value
    overridden = source.copy(default_headers={"X-Foo": "stainless"})
    assert overridden.default_headers["X-Foo"] == "stainless"

    # --- set_default_headers ---

    # existing values are discarded entirely
    cleared = source.copy(set_default_headers={})
    assert cleared.default_headers.get("X-Foo") is None

    replaced = source.copy(set_default_headers={"X-Bar": "Robert"})
    assert replaced.default_headers["X-Bar"] == "Robert"

    # passing both arguments at once is rejected
    with pytest.raises(
        ValueError,
        match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
    ):
        source.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1087
def test_copy_default_query(self) -> None:
    """`copy()` merges `default_query` and replaces it wholesale via `set_default_query`."""
    source = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_query={"foo": "bar"},
    )
    assert _get_params(source)["foo"] == "bar"

    # --- default_query ---

    # an unspecified param keeps the value already given
    inherited = source.copy()
    assert _get_params(inherited)["foo"] == "bar"

    # new params are merged with the existing ones
    merged = source.copy(default_query={"bar": "stainless"})
    merged_params = _get_params(merged)
    assert merged_params["foo"] == "bar"
    assert merged_params["bar"] == "stainless"

    # a param given again takes the new value
    overridden = source.copy(default_query={"foo": "stainless"})
    assert _get_params(overridden)["foo"] == "stainless"

    # --- set_default_query ---

    # existing values are discarded entirely
    cleared = source.copy(set_default_query={})
    assert _get_params(cleared) == {}

    replaced = source.copy(set_default_query={"bar": "Robert"})
    assert _get_params(replaced)["bar"] == "Robert"

    # passing both arguments at once is rejected
    with pytest.raises(
        ValueError,
        match="`default_query` and `set_default_query` arguments are mutually exclusive",
    ):
        source.copy(set_default_query={}, default_query={"foo": "Bar"})
1123
def test_copy_signature(self) -> None:
    """Every `__init__` parameter (bar a few exclusions) must also be accepted by `copy()`."""
    # Parameters that are intentionally absent from `copy()`.
    exclude_params = {"transport", "proxies", "_strict_response_validation"}

    init_params = inspect.signature(
        # mypy doesn't like that we access the `__init__` property.
        self.client.__init__,  # type: ignore[misc]
    ).parameters
    copy_params = inspect.signature(self.client.copy).parameters

    for name in init_params:
        if name not in exclude_params:
            assert name in copy_params, f"copy() signature is missing the {name} param"
1139
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
def test_copy_build_request(self) -> None:
    """Copying the client and building a request must not leak allocations per iteration.

    NOTE(review): the skip condition starts at Python 3.10 while the reason
    text mentions 3.12 -- confirm which version is intended.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        copied = self.client.copy()
        copied._build_request(options)

    # Warm-up call so one-off allocations happen before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)

    snapshot_before = tracemalloc.take_snapshot()

    ITERATIONS = 10
    for _ in range(ITERATIONS):
        build_request(options)

    gc.collect()
    snapshot_after = tracemalloc.take_snapshot()

    tracemalloc.stop()

    # Allocation sites that are known to leak and are deliberately ignored.
    ignored_suffixes = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def is_leak(diff: tracemalloc.StatisticDiff) -> bool:
        # Only allocations that persist count as leaks.
        if diff.count == 0:
            return False
        # Only allocations that recur once per iteration count as leaks.
        if diff.count % ITERATIONS != 0:
            return False
        return not any(frame.filename.endswith(ignored_suffixes) for frame in diff.traceback)

    leaks: list[tracemalloc.StatisticDiff] = [
        diff for diff in snapshot_after.compare_to(snapshot_before, "traceback") if is_leak(diff)
    ]
    if not leaks:
        return

    for leak in leaks:
        print("MEMORY LEAK:", leak)
        for frame in leak.traceback:
            print(frame)
    raise AssertionError()
1202
async def test_request_timeout(self) -> None:
    """A built request carries the default timeout unless the options override it."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**default_request.extensions["timeout"]) == DEFAULT_TIMEOUT  # type: ignore

    explicit_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    explicit_request = self.client._build_request(explicit_options)
    assert httpx.Timeout(**explicit_request.extensions["timeout"]) == httpx.Timeout(100.0)  # type: ignore
1213
async def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor is applied to built requests."""
    zero_timeout_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore

    assert applied == httpx.Timeout(0)
1222
async def test_http_client_timeout_option(self) -> None:
    """The timeout of a user-supplied httpx client is respected, except the httpx default."""

    def resolved_timeout(http_client: httpx.AsyncClient) -> httpx.Timeout:
        # Build a request through an SDK client wrapping `http_client` and read back its timeout.
        sdk_client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        built = sdk_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**built.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        assert resolved_timeout(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        assert resolved_timeout(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert resolved_timeout(http_client) == DEFAULT_TIMEOUT  # our default
1253
def test_invalid_http_client(self) -> None:
    """Passing a synchronous `httpx.Client` to the async SDK client raises `TypeError`."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http_client:
        AsyncOpenAI(
            http_client=cast(Any, sync_http_client),
            _strict_response_validation=True,
            api_key=api_key,
            base_url=base_url,
        )
1263
def test_default_headers_option(self) -> None:
    """`default_headers` are sent on requests and may override the built-in headers."""
    get_foo = FinalRequestOptions(method="get", url="/foo")

    extra_header_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    )
    sent = extra_header_client._build_request(get_foo).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    overriding_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    sent = overriding_client._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"
1284
async def test_validate_headers(self) -> None:
    """The api key becomes a Bearer header; constructing without any key raises `OpenAIError`."""
    keyed_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    prepared = await keyed_client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
    built = keyed_client._build_request(prepared)
    assert built.headers.get("Authorization") == f"Bearer {api_key}"

    # With OPENAI_API_KEY removed from the environment and no explicit key, construction must fail.
    with pytest.raises(OpenAIError), update_env(**{"OPENAI_API_KEY": Omit()}):
        AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1295
def test_default_query_option(self) -> None:
    """`default_query` is added to request URLs and can be overridden per request."""
    query_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_query={"query_param": "bar"},
    )

    plain = query_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(plain.url).params) == {"query_param": "bar"}

    per_request_options = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overridden"},
    )
    overridden = query_client._build_request(per_request_options)
    assert dict(httpx.URL(overridden.url).params) == {"foo": "baz", "query_param": "overridden"}
1313
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on key clashes."""

    def sent_body(**body_options: Any) -> Any:
        # Build a POST to /foo with the given body options and decode what would be sent.
        built = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **body_options),
        )
        return json.loads(built.content.decode("utf-8"))

    # both given: merged
    assert sent_body(json_data={"foo": "bar"}, extra_json={"baz": False}) == {"foo": "bar", "baz": False}

    # only `extra_json` given
    assert sent_body(extra_json={"baz": False}) == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    assert sent_body(json_data={"foo": "bar", "baz": True}, extra_json={"baz": None}) == {
        "foo": "bar",
        "baz": None,
    }
1347
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent and win over `default_headers` on key clashes."""
    simple_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    simple_request = self.client._build_request(simple_options)
    assert simple_request.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_default = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            extra_headers={"X-Bar": "false"},
        ),
    )
    clashing_request = client_with_default._build_request(clashing_options)
    assert clashing_request.headers.get("X-Bar") == "false"
1369
def test_request_extra_query(self) -> None:
    """`extra_query` is sent, merged with `query`, and wins on key clashes."""

    def sent_params(**request_options: Any) -> dict[str, str]:
        # Build a POST to /foo with the given request options and return its query params.
        built = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(**request_options),
            ),
        )
        return dict(built.url.params)

    assert sent_params(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert sent_params(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert sent_params(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1410
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """Array values in a multipart body are emitted as repeated `name[]` parts."""
    multipart_options = FinalRequestOptions.construct(
        method="post",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    built = async_client._build_request(multipart_options)

    expected_lines = [
        # first array element
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"foo",
        # second array element
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"bar",
        # the uploaded file
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        # closing boundary
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]
    assert built.read().split(b"\r\n") == expected_lines
1439
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A union `cast_to` resolves to the member whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    either_model = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=either_model)

    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1453
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union of objects with the same field name using a different type"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    either_model = cast(Any, Union[Model1, Model2])

    # A string value selects the str-typed model.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"

    # An integer value selects the int-typed model.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model1)
    assert parsed.foo == 1
1475
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    Response that sets Content-Type to something other than application/json but returns json data
    """

    class Model(BaseModel):
        foo: int

    mislabelled_reply = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled_reply)

    parsed = await self.client.get("/foo", cast_to=Model)

    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1496
def test_base_url_setter(self) -> None:
    """`base_url` gains a trailing slash both at construction and when reassigned."""
    mutable_client = AsyncOpenAI(
        base_url="https://example.com/from_init",
        api_key=api_key,
        _strict_response_validation=True,
    )
    assert mutable_client.base_url == "https://example.com/from_init/"

    mutable_client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

    assert mutable_client.base_url == "https://example.com/from_setter/"
1506
def test_base_url_env(self) -> None:
    """`OPENAI_BASE_URL` supplies the base URL when none is passed explicitly."""
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == "http://localhost:5000/from/env/"
1511
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL ending in `/` joins with a relative path without doubling the slash."""
    post_foo = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )

    built = client._build_request(post_foo)

    assert built.url == "http://localhost:5000/custom/path/foo"
1536
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL given *without* a trailing slash still joins correctly with a relative path.

    The parametrized clients deliberately omit the final `/` so that this
    test covers the case its name describes rather than repeating the
    trailing-slash test.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1561
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined to the base URL."""
    absolute_options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )

    built = client._build_request(absolute_options)

    assert built.url == "https://myapi.com/foo"
1586
async def test_copied_client_does_not_close_http(self) -> None:
    """Dropping a copy must leave the original client open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    copied = original.copy()
    assert copied is not original

    # Release the copy, then wait briefly before checking the original.
    del copied
    await asyncio.sleep(0.2)

    assert not original.is_closed()
1598
async def test_client_context_manager(self) -> None:
    """`async with client` yields the client itself and closes it on exit."""
    managed = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    async with managed as entered:
        assert entered is managed
        # Open while inside the block, whichever name is used.
        assert not entered.is_closed()
        assert not managed.is_closed()

    assert managed.is_closed()
1606
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that fails model validation raises `APIResponseValidationError`."""

    class Model(BaseModel):
        foo: str

    # `foo` is an object here, which does not satisfy the `str` field.
    invalid_reply = httpx.Response(200, json={"foo": {"invalid": True}})
    respx_mock.get("/foo").mock(return_value=invalid_reply)

    with pytest.raises(APIResponseValidationError) as exc:
        await self.client.get("/foo", cast_to=Model)

    # The pydantic error is chained as the cause.
    assert isinstance(exc.value.__cause__, ValidationError)
1619
async def test_client_max_retries_validation(self) -> None:
    """`max_retries=None` is rejected with a `TypeError` at construction time."""
    invalid_retries = cast(Any, None)

    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            max_retries=invalid_retries,
        )
1625
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """Requesting `stream=True` with `AsyncStream[Model]` returns an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    streamed = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])

    assert isinstance(streamed, AsyncStream)
    # Close the underlying response explicitly since the stream is never consumed.
    await streamed.response.aclose()
1637
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A text body where a model is expected raises in strict mode and passes through otherwise."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # strict validation: the mismatch is an error
    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    # lenient validation: the raw text is returned
    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    lenient_result = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(lenient_result, str)  # type: ignore[unreachable]
1655
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        # numeric `retry-after` values
        (3, "20", 20),
        (3, "0", 0.5),
        (3, "-10", 0.5),
        (3, "60", 60),
        (3, "61", 0.5),
        # HTTP-date `retry-after` values, relative to the patched `time.time`
        (3, "Fri, 29 Sep 2023 16:26:57 GMT", 20),
        (3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:27:37 GMT", 60),
        (3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5),
        # unparseable or out-of-range values
        (3, "99999999999999999999999999999999999", 0.5),
        (3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5),
        # no header value: the expected delay grows as fewer retries remain
        (3, "", 0.5),
        (2, "", 0.5 * 2.0),
        (1, "", 0.5 * 4.0),
        (-1100, "", 8),  # test large number potentially overflowing
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """`_calculate_retry_timeout` yields the expected delay for each `retry-after` value."""
    retry_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    response_headers = httpx.Headers({"retry-after": retry_after})
    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)

    calculated = retry_client._calculate_retry_timeout(remaining_retries, request_options, response_headers)

    # Compared with a relative tolerance rather than exactly.
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1686
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Exhausting retries on timeouts must leave no open connections on the client used."""
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    with pytest.raises(APITimeoutError):
        await async_client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ).__aenter__()

    # Inspect the client that actually issued the request; checking the
    # class-level `self.client` would pass regardless of any leak here.
    assert _get_open_connections(async_client) == 0
1704
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Exhausting retries on 500 responses must leave no open connections on the client used."""
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    with pytest.raises(APIStatusError):
        await async_client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ).__aenter__()

    # Inspect the client that actually issued the request; checking the
    # class-level `self.client` would pass regardless of any leak here.
    assert _get_open_connections(async_client) == 0
1721
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
async def test_retries_taken(
    self,
    async_client: AsyncOpenAI,
    failures_before_success: int,
    failure_mode: Literal["status", "exception"],
    respx_mock: MockRouter,
) -> None:
    """`retries_taken` and the retry-count header both equal the number of failed attempts."""
    client = async_client.with_options(max_retries=4)

    nb_retries = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        # Fail the first `failures_before_success` attempts, then succeed.
        nonlocal nb_retries
        if nb_retries >= failures_before_success:
            return httpx.Response(200)
        nb_retries += 1
        if failure_mode == "exception":
            raise RuntimeError("oops")
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    developer_message = {
        "content": "string",
        "role": "developer",
    }
    response = await client.chat.completions.with_raw_response.create(
        messages=[developer_message],
        model="gpt-4o",
    )

    assert response.retries_taken == failures_before_success
    assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1761
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_omit_retry_count_header(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """Passing `Omit()` for the retry-count header removes it from the request."""
    client = async_client.with_options(max_retries=4)

    nb_retries = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        # Fail the first `failures_before_success` attempts, then succeed.
        nonlocal nb_retries
        if nb_retries >= failures_before_success:
            return httpx.Response(200)
        nb_retries += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    developer_message = {
        "content": "string",
        "role": "developer",
    }
    response = await client.chat.completions.with_raw_response.create(
        messages=[developer_message],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": Omit()},
    )

    retry_count_values = response.http_request.headers.get_list("x-stainless-retry-count")
    assert len(retry_count_values) == 0
1794
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_overwrite_retry_count_header(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """A caller-supplied retry-count header is sent unchanged regardless of retries."""
    client = async_client.with_options(max_retries=4)

    nb_retries = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        # Fail the first `failures_before_success` attempts, then succeed.
        nonlocal nb_retries
        if nb_retries >= failures_before_success:
            return httpx.Response(200)
        nb_retries += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    developer_message = {
        "content": "string",
        "role": "developer",
    }
    response = await client.chat.completions.with_raw_response.create(
        messages=[developer_message],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": "42"},
    )

    assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1827
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken_new_response_class(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """The streaming-response wrapper reports as many retries as the mock served failures."""
    client = async_client.with_options(max_retries=4)

    failures_served = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        # Answer 500 until the parametrized number of failures is used up, then 200.
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    streaming_call = client.chat.completions.with_streaming_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-4o",
    )
    async with streaming_call as response:
        assert response.retries_taken == failures_before_success
        # The header on the request that finally succeeded must agree with retries_taken.
        sent_retry_count = response.http_request.headers.get("x-stainless-retry-count")
        assert int(sent_retry_count) == failures_before_success
1859
async def test_get_platform(self) -> None:
    """Awaiting ``asyncify(get_platform)()`` yields a ``str`` or an ``OtherPlatform``."""
    get_platform_async = asyncify(get_platform)
    detected = await get_platform_async()
    assert isinstance(detected, (str, OtherPlatform))
1863
async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """With only ``HTTPS_PROXY`` set, the default async client has exactly one ``https://`` mount.

    Args:
        monkeypatch: pytest fixture used to edit the process environment for
            the duration of this test only.
    """
    # Remove any proxy settings inherited from the developer/CI environment.
    # Each of them could contribute extra entries to the client's mounts and
    # break the exact-count assertion below. Both upper- and lowercase
    # spellings are cleared.
    for name in ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "NO_PROXY"):
        monkeypatch.delenv(name, raising=False)
        monkeypatch.delenv(name.lower(), raising=False)

    # Set the single variable under test last, so the cleanup above cannot
    # remove it on platforms where environment names are case-insensitive.
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

    client = DefaultAsyncHttpxClient()

    mounts = tuple(client._mounts.items())
    assert len(mounts) == 1
    assert mounts[0][0].pattern == "https://"
1873
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
async def test_default_client_creation(self) -> None:
    """Constructing ``DefaultAsyncHttpxClient`` with explicit options must not raise."""
    pool_limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)

    # No assertion needed: the test passes as long as construction succeeds.
    DefaultAsyncHttpxClient(
        verify=True,
        cert=None,
        trust_env=True,
        http1=True,
        http2=False,
        limits=pool_limits,
    )
1885
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
    """Without an explicit ``follow_redirects`` option, a 302 from POST is followed to its target."""
    redirect_response = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    final_response = httpx.Response(200, json={"status": "ok"})

    respx_mock.post("/redirect").mock(return_value=redirect_response)
    respx_mock.get("/redirected").mock(return_value=final_response)

    response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    # The caller sees the response of the redirect target, not the 302 itself.
    assert response.status_code == 200
    assert response.json() == {"status": "ok"}
1897
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
    """With ``follow_redirects=False`` the 302 is surfaced as an ``APIStatusError``."""
    redirect_response = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    respx_mock.post("/redirect").mock(return_value=redirect_response)

    with pytest.raises(APIStatusError) as exc_info:
        await self.client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    # The raised error carries the untouched redirect response.
    raised_response = exc_info.value.response
    assert raised_response.status_code == 302
    assert raised_response.headers["Location"] == f"{base_url}/redirected"
1912
@pytest.mark.asyncio
async def test_api_key_before_after_refresh_provider(self) -> None:
    """A callable ``api_key`` yields no credentials until ``_refresh_api_key`` has been awaited."""

    async def mock_api_key_provider():
        return "test_bearer_token"

    client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)

    # Before the refresh: empty key and no Authorization header.
    assert client.api_key == ""
    assert "Authorization" not in client.auth_headers

    await client._refresh_api_key()

    # After the refresh: the provider's value is used for both.
    assert client.api_key == "test_bearer_token"
    authorization = client.auth_headers.get("Authorization")
    assert authorization == "Bearer test_bearer_token"
1927
@pytest.mark.asyncio
async def test_api_key_before_after_refresh_str(self) -> None:
    """With a plain string ``api_key`` the Authorization header is the same before and after a refresh."""
    client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")

    authorization_before = client.auth_headers.get("Authorization")
    assert authorization_before == "Bearer test_api_key"

    await client._refresh_api_key()

    authorization_after = client.auth_headers.get("Authorization")
    assert authorization_after == "Bearer test_api_key"
1936
@pytest.mark.asyncio
@pytest.mark.respx()
async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
    """Each of the two requests carries the token the provider returned for it.

    The mocked endpoint answers 500 and then 200; the test expects two
    recorded calls, authorized with "first" and "second" respectively.
    """
    mocked_responses = [
        httpx.Response(500, json={"error": "server error"}),
        httpx.Response(200, json={"foo": "bar"}),
    ]
    respx_mock.post(base_url + "/chat/completions").mock(side_effect=mocked_responses)

    provider_calls = 0

    async def token_provider() -> str:
        # First invocation returns "first"; every later one returns "second".
        nonlocal provider_calls
        provider_calls += 1
        return "first" if provider_calls == 1 else "second"

    client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
    await client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(calls) == 2

    first_call, second_call = calls[0], calls[1]
    assert first_call.request.headers.get("Authorization") == "Bearer first"
    assert second_call.request.headers.get("Authorization") == "Bearer second"
1967
@pytest.mark.asyncio
async def test_copy_auth(self) -> None:
    """``copy(api_key=...)`` makes the copied client use the new token provider."""

    async def token_provider_1() -> str:
        return "test_bearer_token_1"

    async def token_provider_2() -> str:
        return "test_bearer_token_2"

    original = AsyncOpenAI(base_url=base_url, api_key=token_provider_1)
    client = original.copy(api_key=token_provider_2)

    await client._refresh_api_key()

    expected_headers = {"Authorization": "Bearer test_bearer_token_2"}
    assert client.auth_headers == expected_headers
1979