openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
15488ce07cb97535d8564e82dd5cda3481bc1a81

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1391 lines · mode: code

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
# Mock server the tests talk to; override with TEST_API_BASE_URL to point elsewhere.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Dummy credential used by every client built in this module.
api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a GET request for "/foo" with *client* and return its query parameters as a dict."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many entries are in the connection pool's `_requests` list for *client*.

    The retry tests assert this is 0 after a failed request, i.e. nothing was left open.
    Only the default httpx transports (sync or async) are supported.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
49
50
class TestOpenAI:
    """Tests for the synchronous `OpenAI` client: copying, request building and response handling.

    Most tests share the class-level `client`; tests that need different constructor options
    build their own instance.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """`cast_to=httpx.Response` returns an `httpx.Response` carrying the mocked status and body."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A non-JSON content type still yields the raw `httpx.Response` when it is requested."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """`copy()` returns a distinct client, and overriding `api_key` only affects the copy."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """`max_retries` and `timeout` can be overridden on a copy without touching the source client."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """`default_headers` merges into a copy's headers; `set_default_headers` replaces them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """`default_query` merges into a copy's query params; `set_default_query` replaces them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every `__init__` parameter, apart from a few exclusions, must also be accepted by `copy()`."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory per iteration."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            ITERATIONS = 10
            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if building a request raises, so tracemalloc's overhead
            # does not carry over into the rest of the test session.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    def test_request_timeout(self) -> None:
        """Requests use `DEFAULT_TIMEOUT` unless the request options carry their own timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A `timeout` passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """How the timeout of a user-supplied `httpx.Client` interacts with the SDK default."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        """`default_headers` are sent and can override the SDK's own `X-Stainless-Lang` header."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token, and building a client without any key raises."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # `api_key=None` makes the client fall back to the OPENAI_API_KEY environment variable;
        # remove it (restored on exit) so the missing-key error is raised regardless of the
        # environment the test suite runs in.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)
            with pytest.raises(OpenAIError):
                OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)

    def test_default_query_option(self) -> None:
        """`default_query` is applied to requests, and per-request `params` override it per key."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """`extra_json` is merged into the JSON body and wins over `json_data` on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """`extra_headers` are sent and win over `default_headers` on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """`extra_query` is merged with `query` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union `cast_to` resolves to the member model whose fields match the response."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning `base_url` normalizes it with a trailing slash, as the constructor does."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """`OPENAI_BASE_URL` is used when no `base_url` argument is passed."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A relative request path is joined onto a base URL that ends with a slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # Deliberately no trailing slash: the client is expected to add one itself.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL given without a trailing slash is joined exactly like one that has it."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined onto the base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Dropping a copy of a client must leave the original client open."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied
        # `del` only drops this reference; force a collection so that any finalizer on the
        # copy has actually run before the original is checked.
        gc.collect()

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Using the client as a context manager yields the client itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """With strict validation, a schema mismatch raises `APIResponseValidationError` from a pydantic error."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """`stream=True` together with `stream_cls` returns a `Stream`."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Non-JSON text raises under strict validation and is returned as `str` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # `time.time` is pinned to 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT. Retry-After values
    # in (0, 60] seconds (as a number or an HTTP date) are honoured; anything else falls back to a
    # 0.5s backoff that doubles for each retry already used (remaining_retries 3 -> 2 -> 1).
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """The retry delay is derived from the Retry-After header or falls back to backoff."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # relative tolerance presumably absorbs the random jitter applied to the backoff
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """A request that keeps timing out ends in `APITimeoutError` with no pooled requests left."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=dict(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model="gpt-3.5-turbo",
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """A request that keeps getting 500s ends in `APIStatusError` with no pooled requests left."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=dict(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model="gpt-3.5-turbo",
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0
712
713
714class TestAsyncOpenAI:
715 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
716
717 @pytest.mark.respx(base_url=base_url)
718 @pytest.mark.asyncio
719 async def test_raw_response(self, respx_mock: MockRouter) -> None:
720 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
721
722 response = await self.client.post("/foo", cast_to=httpx.Response)
723 assert response.status_code == 200
724 assert isinstance(response, httpx.Response)
725 assert response.json() == {"foo": "bar"}
726
727 @pytest.mark.respx(base_url=base_url)
728 @pytest.mark.asyncio
729 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
730 respx_mock.post("/foo").mock(
731 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
732 )
733
734 response = await self.client.post("/foo", cast_to=httpx.Response)
735 assert response.status_code == 200
736 assert isinstance(response, httpx.Response)
737 assert response.json() == {"foo": "bar"}
738
739 def test_copy(self) -> None:
740 copied = self.client.copy()
741 assert id(copied) != id(self.client)
742
743 copied = self.client.copy(api_key="another My API Key")
744 assert copied.api_key == "another My API Key"
745 assert self.client.api_key == "My API Key"
746
747 def test_copy_default_options(self) -> None:
748 # options that have a default are overridden correctly
749 copied = self.client.copy(max_retries=7)
750 assert copied.max_retries == 7
751 assert self.client.max_retries == 2
752
753 copied2 = copied.copy(max_retries=6)
754 assert copied2.max_retries == 6
755 assert copied.max_retries == 7
756
757 # timeout
758 assert isinstance(self.client.timeout, httpx.Timeout)
759 copied = self.client.copy(timeout=None)
760 assert copied.timeout is None
761 assert isinstance(self.client.timeout, httpx.Timeout)
762
763 def test_copy_default_headers(self) -> None:
764 client = AsyncOpenAI(
765 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
766 )
767 assert client.default_headers["X-Foo"] == "bar"
768
769 # does not override the already given value when not specified
770 copied = client.copy()
771 assert copied.default_headers["X-Foo"] == "bar"
772
773 # merges already given headers
774 copied = client.copy(default_headers={"X-Bar": "stainless"})
775 assert copied.default_headers["X-Foo"] == "bar"
776 assert copied.default_headers["X-Bar"] == "stainless"
777
778 # uses new values for any already given headers
779 copied = client.copy(default_headers={"X-Foo": "stainless"})
780 assert copied.default_headers["X-Foo"] == "stainless"
781
782 # set_default_headers
783
784 # completely overrides already set values
785 copied = client.copy(set_default_headers={})
786 assert copied.default_headers.get("X-Foo") is None
787
788 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
789 assert copied.default_headers["X-Bar"] == "Robert"
790
791 with pytest.raises(
792 ValueError,
793 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
794 ):
795 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
796
797 def test_copy_default_query(self) -> None:
798 client = AsyncOpenAI(
799 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
800 )
801 assert _get_params(client)["foo"] == "bar"
802
803 # does not override the already given value when not specified
804 copied = client.copy()
805 assert _get_params(copied)["foo"] == "bar"
806
807 # merges already given params
808 copied = client.copy(default_query={"bar": "stainless"})
809 params = _get_params(copied)
810 assert params["foo"] == "bar"
811 assert params["bar"] == "stainless"
812
813 # uses new values for any already given headers
814 copied = client.copy(default_query={"foo": "stainless"})
815 assert _get_params(copied)["foo"] == "stainless"
816
817 # set_default_query
818
819 # completely overrides already set values
820 copied = client.copy(set_default_query={})
821 assert _get_params(copied) == {}
822
823 copied = client.copy(set_default_query={"bar": "Robert"})
824 assert _get_params(copied)["bar"] == "Robert"
825
826 with pytest.raises(
827 ValueError,
828 # TODO: update
829 match="`default_query` and `set_default_query` arguments are mutually exclusive",
830 ):
831 client.copy(set_default_query={}, default_query={"foo": "Bar"})
832
833 def test_copy_signature(self) -> None:
834 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
835 init_signature = inspect.signature(
836 # mypy doesn't like that we access the `__init__` property.
837 self.client.__init__, # type: ignore[misc]
838 )
839 copy_signature = inspect.signature(self.client.copy)
840 exclude_params = {"transport", "proxies", "_strict_response_validation"}
841
842 for name in init_signature.parameters.keys():
843 if name in exclude_params:
844 continue
845
846 copy_param = copy_signature.parameters.get(name)
847 assert copy_param is not None, f"copy() signature is missing the {name} param"
848
849 def test_copy_build_request(self) -> None:
850 options = FinalRequestOptions(method="get", url="/foo")
851
852 def build_request(options: FinalRequestOptions) -> None:
853 client = self.client.copy()
854 client._build_request(options)
855
856 # ensure that the machinery is warmed up before tracing starts.
857 build_request(options)
858 gc.collect()
859
860 tracemalloc.start(1000)
861
862 snapshot_before = tracemalloc.take_snapshot()
863
864 ITERATIONS = 10
865 for _ in range(ITERATIONS):
866 build_request(options)
867
868 gc.collect()
869 snapshot_after = tracemalloc.take_snapshot()
870
871 tracemalloc.stop()
872
873 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
874 if diff.count == 0:
875 # Avoid false positives by considering only leaks (i.e. allocations that persist).
876 return
877
878 if diff.count % ITERATIONS != 0:
879 # Avoid false positives by considering only leaks that appear per iteration.
880 return
881
882 for frame in diff.traceback:
883 if any(
884 frame.filename.endswith(fragment)
885 for fragment in [
886 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
887 #
888 # removing the decorator fixes the leak for reasons we don't understand.
889 "openai/_legacy_response.py",
890 "openai/_response.py",
891 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
892 "openai/_compat.py",
893 # Standard library leaks we don't care about.
894 "/logging/__init__.py",
895 ]
896 ):
897 return
898
899 leaks.append(diff)
900
901 leaks: list[tracemalloc.StatisticDiff] = []
902 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
903 add_leak(leaks, diff)
904 if leaks:
905 for leak in leaks:
906 print("MEMORY LEAK:", leak)
907 for frame in leak.traceback:
908 print(frame)
909 raise AssertionError()
910
async def test_request_timeout(self) -> None:
    """The client-wide default timeout is used unless a request supplies its own."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**default_request.extensions["timeout"]) == DEFAULT_TIMEOUT  # type: ignore

    overridden_request = self.client._build_request(
        FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    )
    assert httpx.Timeout(**overridden_request.extensions["timeout"]) == httpx.Timeout(100.0)  # type: ignore
921
async def test_client_timeout_option(self) -> None:
    """A timeout given to the client constructor is attached to every built request."""
    zero_timeout_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**built.extensions["timeout"]) == httpx.Timeout(0)  # type: ignore
930
async def test_http_client_timeout_option(self) -> None:
    """Check how a timeout configured on a user-supplied httpx client is resolved.

    Each scenario pairs the keyword arguments used to build the httpx client
    with the timeout we expect to find on a request built by the SDK client.
    """
    scenarios: list[tuple[dict[str, Any], httpx.Timeout]] = [
        # custom timeout given to the httpx client should be used
        ({"timeout": None}, httpx.Timeout(None)),
        # no timeout given to the httpx client should not use the httpx default
        ({}, DEFAULT_TIMEOUT),
        # explicitly passing the default timeout currently results in it being ignored
        ({"timeout": HTTPX_DEFAULT_TIMEOUT}, DEFAULT_TIMEOUT),
    ]

    for http_client_kwargs, expected_timeout in scenarios:
        async with httpx.AsyncClient(**http_client_kwargs) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            assert httpx.Timeout(**built.extensions["timeout"]) == expected_timeout  # type: ignore
961
def test_default_headers_option(self) -> None:
    """Default headers are sent, and may even replace the SDK's own headers."""
    plain_client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    sent = plain_client._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    overriding_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    sent = overriding_client._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"
982
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and constructing without a key fails.

    openai-python documents that ``api_key=None`` falls back to the
    ``OPENAI_API_KEY`` environment variable. That variable is hidden for the
    negative check so the test does not depend on the developer's or CI's
    environment.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # `patch.dict` snapshots os.environ and restores it (including the popped key) on exit.
    with mock.patch.dict(os.environ):
        os.environ.pop("OPENAI_API_KEY", None)
        with pytest.raises(OpenAIError):
            AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
991
def test_default_query_option(self) -> None:
    """Default query params are always sent; per-request params win on a clash."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    without_params = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(without_params.url).params) == {"query_param": "bar"}

    with_params = client._build_request(
        FinalRequestOptions(
            method="get",
            url="/foo",
            params={"foo": "baz", "query_param": "overriden"},
        )
    )
    assert dict(httpx.URL(with_params.url).params) == {"foo": "baz", "query_param": "overriden"}
1009
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on a clash."""

    def sent_body(**body_options: Any) -> Any:
        # Build a POST /foo request with the given body options and decode what would be sent.
        built = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **body_options),
        )
        return json.loads(built.content.decode("utf-8"))

    assert sent_body(json_data={"foo": "bar"}, extra_json={"baz": False}) == {"foo": "bar", "baz": False}

    assert sent_body(extra_json={"baz": False}) == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    clashing = sent_body(json_data={"foo": "bar", "baz": True}, extra_json={"baz": None})
    assert clashing == {"foo": "bar", "baz": None}
1043
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent and override the client's default headers."""
    plain_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    assert self.client._build_request(plain_options).headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_defaults = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Bar": "false"}),
    )
    assert client_with_defaults._build_request(clashing_options).headers.get("X-Bar") == "false"
1065
def test_request_extra_query(self) -> None:
    """`extra_query` is sent, merged with `query`, and wins on a key clash."""

    def sent_params(**request_options: Any) -> dict[str, str]:
        # Build a POST /foo request from `make_request_options(...)` and return its query params.
        built = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **make_request_options(**request_options)),
        )
        return dict(built.url.params)

    assert sent_params(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert sent_params(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert sent_params(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1106
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A union `cast_to` resolves to the member whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    payload = {"foo": "bar"}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=payload))

    union_type = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1119 assert response.foo == "bar"
1120
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union members share a field name but disagree on its type; the value's type picks the member."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])

    # A string `foo` resolves to Model2.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    as_text = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(as_text, Model2)
    assert as_text.foo == "bar"

    # An integer `foo` resolves to Model1.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    as_number = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(as_number, Model1)
    assert as_number.foo == 1
1142
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    A JSON body is still parsed into the model when the server labels it with a
    non-JSON Content-Type.
    """

    class Model(BaseModel):
        foo: int

    mislabelled_response = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled_response)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1163
def test_base_url_setter(self) -> None:
    """Assigning `base_url` after construction replaces it, with a trailing slash added."""
    client = AsyncOpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
    assert client.base_url == "https://example.com/from_init/"

    client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert client.base_url == "https://example.com/from_setter/"
1173
def test_base_url_env(self) -> None:
    """With no explicit `base_url`, the client takes it from `OPENAI_BASE_URL`."""
    env_overrides = {"OPENAI_BASE_URL": "http://localhost:5000/from/env"}
    with update_env(**env_overrides):
        client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)

    assert client.base_url == "http://localhost:5000/from/env/"
1178
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A relative request path is appended under a base URL that ends in '/'."""
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    assert client._build_request(options).url == "http://localhost:5000/custom/path/foo"
1203
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL given *without* a trailing slash still nests relative paths under it.

    The clients above deliberately omit the trailing '/' on the base URL; the
    client appends one (as `test_base_url_setter` shows), so the relative path
    must not replace the final `path` segment.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1228
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is, ignoring the configured base URL."""
    options = FinalRequestOptions(method="post", url="https://myapi.com/foo", json_data={"foo": "bar"})
    assert client._build_request(options).url == "https://myapi.com/foo"
1253
async def test_copied_client_does_not_close_http(self) -> None:
    """Discarding a copy of the client must leave the original client open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    # Drop the only reference to the copy, then yield to the event loop for a
    # moment so that any cleanup triggered by discarding it has a chance to run.
    del duplicate
    await asyncio.sleep(0.2)

    assert not original.is_closed()
1265
async def test_client_context_manager(self) -> None:
    """`async with` yields the client itself, keeps it open inside, and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()
    assert client.is_closed()
1273
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit the model raises APIResponseValidationError caused by pydantic's ValidationError."""

    class Model(BaseModel):
        foo: str

    bad_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=bad_payload))

    with pytest.raises(APIResponseValidationError) as excinfo:
        await self.client.get("/foo", cast_to=Model)

    assert isinstance(excinfo.value.__cause__, ValidationError)
1286
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """`stream=True` with an `AsyncStream` stream class returns an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    result = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(result, AsyncStream)

    # The stream is never consumed, so release the underlying response explicitly.
    await result.response.aclose()
1298
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A non-JSON body is rejected under strict validation and returned as a string otherwise."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    result = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(result, str)  # type: ignore[unreachable]
1316
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """A `retry-after` header (seconds or an HTTP date) sets the retry delay only when it is in range.

    `time.time` is frozen at 1696004797 (Fri, 29 Sep 2023 16:26:37 GMT), so
    date values are relative to that instant. Per the table above, delays in
    (0, 60] seconds are honoured. Anything else (zero, negative, too large,
    unparseable, or missing) falls back to a doubling backoff of 0.5s, 1s, 2s
    as remaining retries drop from 3 to 1. The comparison uses a relative
    tolerance, presumably because the fallback backoff is jittered.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    calculated = client._calculate_retry_timeout(
        remaining_retries,
        FinalRequestOptions(method="get", url="/foo", max_retries=3),
        httpx.Headers({"retry-after": retry_after}),
    )
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1346
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """Retrying until every attempt times out must leave no pooled connections open.

    The retry delay is patched to `_low_retry_timeout` (0.1s) to keep the test fast.
    """
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    chat_body = {
        "messages": [{"role": "user", "content": "Say this is a test"}],
        "model": "gpt-3.5-turbo",
    }
    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=chat_body,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1369
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """Retrying until every attempt returns HTTP 500 must leave no pooled connections open.

    The retry delay is patched to `_low_retry_timeout` (0.1s) to keep the test fast.
    """
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    chat_body = {
        "messages": [{"role": "user", "content": "Say this is a test"}],
        "model": "gpt-3.5-turbo",
    }
    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=chat_body,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1392