openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.3.9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1396 lines · mode: code

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._streaming import Stream, AsyncStream
23from openai._exceptions import (
24 OpenAIError,
25 APIStatusError,
26 APITimeoutError,
27 APIResponseValidationError,
28)
29from openai._base_client import (
30 DEFAULT_TIMEOUT,
31 HTTPX_DEFAULT_TIMEOUT,
32 BaseClient,
33 make_request_options,
34)
35
36from .utils import update_env
37
# Address of the mock API server every test client talks to; override with TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential handed to every client under test.
api_key = "My API Key"
40
41
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query string as a plain dict."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
46
47
def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
    """Drop-in replacement for ``BaseClient._calculate_retry_timeout`` used via ``mock.patch``.

    Ignores every argument and always returns a fixed 0.1 delay, so retry loops in the
    leak tests finish quickly.
    """
    fixed_delay = 0.1
    return fixed_delay
50
51
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many requests the connection pool behind *client*'s httpx transport is tracking.

    The retry tests use this to check that exhausted retries leave nothing behind in the pool.
    """
    transport = client._client._transport
    # The pool inspection below assumes one of httpx's built-in transports.
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    # NOTE(review): `_pool` and `_pool._requests` are private attributes of the httpx/httpcore
    # internals; confirm they still exist (and still reflect open connections) when upgrading.
    pool = transport._pool
    return len(pool._requests)
58
59
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client: copying, request building, response parsing and retries."""

    # Shared client for tests that need no custom construction arguments. With strict response
    # validation, responses that do not match the expected model raise instead of passing through.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` returns the raw HTTP response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A raw response is returned even when the server declares a non-JSON content type."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a distinct client and only overrides the options it is given."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options that have defaults can be overridden on a copy without affecting the source client."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merge into a copy, ``set_default_headers`` replace them, and both together is an error."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy, ``set_default_query`` replaces it, and both together is an error."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every constructor parameter (except the excluded ones) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory.

        Allocations are traced with ``tracemalloc`` across ``ITERATIONS`` builds. Any traceback
        whose surviving block count is a non-zero multiple of ``ITERATIONS`` (and that does not
        pass through an allow-listed file) is reported as a leak.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        # Record up to 1000 frames per allocation so the allow-list below can match deep frames.
        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)
            gc.collect()

            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if a build raises, so later tests do not keep running under tracemalloc.
            tracemalloc.stop()

        # Files whose allocations persist for known reasons and are not reported as leaks.
        ignored_suffixes = (
            # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
            #
            # removing the decorator fixes the leak for reasons we don't understand.
            "openai/_response.py",
            # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
            "openai/_compat.py",
            # Standard library leaks we don't care about.
            "/logging/__init__.py",
        )

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            # NOTE(review): `count` is the block count in the *new* snapshot, whereas `count_diff` is the
            # growth between snapshots; since tracing starts just before `snapshot_before`, they are close.
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            if any(frame.filename.endswith(ignored_suffixes) for frame in diff.traceback):
                return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the request options carry their own timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A client-level ``timeout`` is applied to the requests it builds."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """How the timeout of a user-supplied ``httpx.Client`` interacts with the SDK default."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent, and can override SDK-provided headers such as ``X-Stainless-Lang``."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token, and constructing a client without any key raises."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # Hide any OPENAI_API_KEY from the surrounding environment: the client falls back to it when
        # `api_key=None`, which would make this check depend on the machine running the tests.
        # `patch.dict` restores the original environment on exit.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)
            with pytest.raises(OpenAIError):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
                _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` params are sent and can be overridden per request."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and override the client's default headers."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is sent, merged with ``query``, and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member model that matches the response body."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning ``base_url`` after construction takes effect, with a trailing slash enforced."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """Relative request paths are appended to a base URL that ends with a slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # Deliberately no trailing slash: the client must insert one before joining paths.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """Relative request paths are joined onto a slash-less base URL as if the slash were present."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined onto the base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy must not close the original client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Entering the client as a context manager yields the client itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A body that does not match the model raises ``APIResponseValidationError`` caused by pydantic's error."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` without an explicit stream class returns a ``Stream``."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=Model, stream=True)
        assert isinstance(response, Stream)

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Plain text for a model response raises under strict validation and is returned as ``str`` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # time.time() is pinned to 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT, so the HTTP-date cases
    # below are 20s, 0s, -10s, 60s and 61s in the future. Per the expected values, a Retry-After
    # outside (0, 60] seconds, an unparsable value, or an empty header falls back to the default
    # back-off: 0.5s with 3 retries remaining, doubling for each retry already used.
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``_calculate_retry_timeout`` honours valid ``Retry-After`` headers and otherwise backs off."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # NOTE(review): the wide relative tolerance presumably absorbs jitter in the back-off; confirm.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on timeouts raises ``APITimeoutError`` and leaves the pool empty."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=dict(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model="gpt-3.5-turbo",
                ),
                cast_to=httpx.Response,
                options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on 500 responses raises ``APIStatusError`` and leaves the pool empty."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=dict(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model="gpt-3.5-turbo",
                ),
                cast_to=httpx.Response,
                options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
            )

        assert _get_open_connections(self.client) == 0
720
721class TestAsyncOpenAI:
722 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
723
724 @pytest.mark.respx(base_url=base_url)
725 @pytest.mark.asyncio
726 async def test_raw_response(self, respx_mock: MockRouter) -> None:
727 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
728
729 response = await self.client.post("/foo", cast_to=httpx.Response)
730 assert response.status_code == 200
731 assert isinstance(response, httpx.Response)
732 assert response.json() == {"foo": "bar"}
733
734 @pytest.mark.respx(base_url=base_url)
735 @pytest.mark.asyncio
736 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
737 respx_mock.post("/foo").mock(
738 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
739 )
740
741 response = await self.client.post("/foo", cast_to=httpx.Response)
742 assert response.status_code == 200
743 assert isinstance(response, httpx.Response)
744 assert response.json() == {"foo": "bar"}
745
746 def test_copy(self) -> None:
747 copied = self.client.copy()
748 assert id(copied) != id(self.client)
749
750 copied = self.client.copy(api_key="another My API Key")
751 assert copied.api_key == "another My API Key"
752 assert self.client.api_key == "My API Key"
753
754 def test_copy_default_options(self) -> None:
755 # options that have a default are overridden correctly
756 copied = self.client.copy(max_retries=7)
757 assert copied.max_retries == 7
758 assert self.client.max_retries == 2
759
760 copied2 = copied.copy(max_retries=6)
761 assert copied2.max_retries == 6
762 assert copied.max_retries == 7
763
764 # timeout
765 assert isinstance(self.client.timeout, httpx.Timeout)
766 copied = self.client.copy(timeout=None)
767 assert copied.timeout is None
768 assert isinstance(self.client.timeout, httpx.Timeout)
769
770 def test_copy_default_headers(self) -> None:
771 client = AsyncOpenAI(
772 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
773 )
774 assert client.default_headers["X-Foo"] == "bar"
775
776 # does not override the already given value when not specified
777 copied = client.copy()
778 assert copied.default_headers["X-Foo"] == "bar"
779
780 # merges already given headers
781 copied = client.copy(default_headers={"X-Bar": "stainless"})
782 assert copied.default_headers["X-Foo"] == "bar"
783 assert copied.default_headers["X-Bar"] == "stainless"
784
785 # uses new values for any already given headers
786 copied = client.copy(default_headers={"X-Foo": "stainless"})
787 assert copied.default_headers["X-Foo"] == "stainless"
788
789 # set_default_headers
790
791 # completely overrides already set values
792 copied = client.copy(set_default_headers={})
793 assert copied.default_headers.get("X-Foo") is None
794
795 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
796 assert copied.default_headers["X-Bar"] == "Robert"
797
798 with pytest.raises(
799 ValueError,
800 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
801 ):
802 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
803
804 def test_copy_default_query(self) -> None:
805 client = AsyncOpenAI(
806 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
807 )
808 assert _get_params(client)["foo"] == "bar"
809
810 # does not override the already given value when not specified
811 copied = client.copy()
812 assert _get_params(copied)["foo"] == "bar"
813
814 # merges already given params
815 copied = client.copy(default_query={"bar": "stainless"})
816 params = _get_params(copied)
817 assert params["foo"] == "bar"
818 assert params["bar"] == "stainless"
819
820 # uses new values for any already given headers
821 copied = client.copy(default_query={"foo": "stainless"})
822 assert _get_params(copied)["foo"] == "stainless"
823
824 # set_default_query
825
826 # completely overrides already set values
827 copied = client.copy(set_default_query={})
828 assert _get_params(copied) == {}
829
830 copied = client.copy(set_default_query={"bar": "Robert"})
831 assert _get_params(copied)["bar"] == "Robert"
832
833 with pytest.raises(
834 ValueError,
835 # TODO: update
836 match="`default_query` and `set_default_query` arguments are mutually exclusive",
837 ):
838 client.copy(set_default_query={}, default_query={"foo": "Bar"})
839
840 def test_copy_signature(self) -> None:
841 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
842 init_signature = inspect.signature(
843 # mypy doesn't like that we access the `__init__` property.
844 self.client.__init__, # type: ignore[misc]
845 )
846 copy_signature = inspect.signature(self.client.copy)
847 exclude_params = {"transport", "proxies", "_strict_response_validation"}
848
849 for name in init_signature.parameters.keys():
850 if name in exclude_params:
851 continue
852
853 copy_param = copy_signature.parameters.get(name)
854 assert copy_param is not None, f"copy() signature is missing the {name} param"
855
856 def test_copy_build_request(self) -> None:
857 options = FinalRequestOptions(method="get", url="/foo")
858
859 def build_request(options: FinalRequestOptions) -> None:
860 client = self.client.copy()
861 client._build_request(options)
862
863 # ensure that the machinery is warmed up before tracing starts.
864 build_request(options)
865 gc.collect()
866
867 tracemalloc.start(1000)
868
869 snapshot_before = tracemalloc.take_snapshot()
870
871 ITERATIONS = 10
872 for _ in range(ITERATIONS):
873 build_request(options)
874 gc.collect()
875
876 snapshot_after = tracemalloc.take_snapshot()
877
878 tracemalloc.stop()
879
880 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
881 if diff.count == 0:
882 # Avoid false positives by considering only leaks (i.e. allocations that persist).
883 return
884
885 if diff.count % ITERATIONS != 0:
886 # Avoid false positives by considering only leaks that appear per iteration.
887 return
888
889 for frame in diff.traceback:
890 if any(
891 frame.filename.endswith(fragment)
892 for fragment in [
893 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
894 #
895 # removing the decorator fixes the leak for reasons we don't understand.
896 "openai/_response.py",
897 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
898 "openai/_compat.py",
899 # Standard library leaks we don't care about.
900 "/logging/__init__.py",
901 ]
902 ):
903 return
904
905 leaks.append(diff)
906
907 leaks: list[tracemalloc.StatisticDiff] = []
908 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
909 add_leak(leaks, diff)
910 if leaks:
911 for leak in leaks:
912 print("MEMORY LEAK:", leak)
913 for frame in leak.traceback:
914 print(frame)
915 raise AssertionError()
916
917 async def test_request_timeout(self) -> None:
918 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
919 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
920 assert timeout == DEFAULT_TIMEOUT
921
922 request = self.client._build_request(
923 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
924 )
925 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
926 assert timeout == httpx.Timeout(100.0)
927
928 async def test_client_timeout_option(self) -> None:
929 client = AsyncOpenAI(
930 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
931 )
932
933 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
934 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
935 assert timeout == httpx.Timeout(0)
936
937 async def test_http_client_timeout_option(self) -> None:
938 # custom timeout given to the httpx client should be used
939 async with httpx.AsyncClient(timeout=None) as http_client:
940 client = AsyncOpenAI(
941 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
942 )
943
944 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
945 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
946 assert timeout == httpx.Timeout(None)
947
948 # no timeout given to the httpx client should not use the httpx default
949 async with httpx.AsyncClient() as http_client:
950 client = AsyncOpenAI(
951 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
952 )
953
954 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
955 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
956 assert timeout == DEFAULT_TIMEOUT
957
958 # explicitly passing the default timeout currently results in it being ignored
959 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
960 client = AsyncOpenAI(
961 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
962 )
963
964 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
965 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
966 assert timeout == DEFAULT_TIMEOUT # our default
967
968 def test_default_headers_option(self) -> None:
969 client = AsyncOpenAI(
970 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
971 )
972 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
973 assert request.headers.get("x-foo") == "bar"
974 assert request.headers.get("x-stainless-lang") == "python"
975
976 client2 = AsyncOpenAI(
977 base_url=base_url,
978 api_key=api_key,
979 _strict_response_validation=True,
980 default_headers={
981 "X-Foo": "stainless",
982 "X-Stainless-Lang": "my-overriding-header",
983 },
984 )
985 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
986 assert request.headers.get("x-foo") == "stainless"
987 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
988
989 def test_validate_headers(self) -> None:
990 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
991 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
992 assert request.headers.get("Authorization") == f"Bearer {api_key}"
993
994 with pytest.raises(OpenAIError):
995 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
996 _ = client2
997
998 def test_default_query_option(self) -> None:
999 client = AsyncOpenAI(
1000 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1001 )
1002 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1003 url = httpx.URL(request.url)
1004 assert dict(url.params) == {"query_param": "bar"}
1005
1006 request = client._build_request(
1007 FinalRequestOptions(
1008 method="get",
1009 url="/foo",
1010 params={"foo": "baz", "query_param": "overriden"},
1011 )
1012 )
1013 url = httpx.URL(request.url)
1014 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
1015
1016 def test_request_extra_json(self) -> None:
1017 request = self.client._build_request(
1018 FinalRequestOptions(
1019 method="post",
1020 url="/foo",
1021 json_data={"foo": "bar"},
1022 extra_json={"baz": False},
1023 ),
1024 )
1025 data = json.loads(request.content.decode("utf-8"))
1026 assert data == {"foo": "bar", "baz": False}
1027
1028 request = self.client._build_request(
1029 FinalRequestOptions(
1030 method="post",
1031 url="/foo",
1032 extra_json={"baz": False},
1033 ),
1034 )
1035 data = json.loads(request.content.decode("utf-8"))
1036 assert data == {"baz": False}
1037
1038 # `extra_json` takes priority over `json_data` when keys clash
1039 request = self.client._build_request(
1040 FinalRequestOptions(
1041 method="post",
1042 url="/foo",
1043 json_data={"foo": "bar", "baz": True},
1044 extra_json={"baz": None},
1045 ),
1046 )
1047 data = json.loads(request.content.decode("utf-8"))
1048 assert data == {"foo": "bar", "baz": None}
1049
1050 def test_request_extra_headers(self) -> None:
1051 request = self.client._build_request(
1052 FinalRequestOptions(
1053 method="post",
1054 url="/foo",
1055 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1056 ),
1057 )
1058 assert request.headers.get("X-Foo") == "Foo"
1059
1060 # `extra_headers` takes priority over `default_headers` when keys clash
1061 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
1062 FinalRequestOptions(
1063 method="post",
1064 url="/foo",
1065 **make_request_options(
1066 extra_headers={"X-Bar": "false"},
1067 ),
1068 ),
1069 )
1070 assert request.headers.get("X-Bar") == "false"
1071
1072 def test_request_extra_query(self) -> None:
1073 request = self.client._build_request(
1074 FinalRequestOptions(
1075 method="post",
1076 url="/foo",
1077 **make_request_options(
1078 extra_query={"my_query_param": "Foo"},
1079 ),
1080 ),
1081 )
1082 params = dict(request.url.params)
1083 assert params == {"my_query_param": "Foo"}
1084
1085 # if both `query` and `extra_query` are given, they are merged
1086 request = self.client._build_request(
1087 FinalRequestOptions(
1088 method="post",
1089 url="/foo",
1090 **make_request_options(
1091 query={"bar": "1"},
1092 extra_query={"foo": "2"},
1093 ),
1094 ),
1095 )
1096 params = dict(request.url.params)
1097 assert params == {"bar": "1", "foo": "2"}
1098
1099 # `extra_query` takes priority over `query` when keys clash
1100 request = self.client._build_request(
1101 FinalRequestOptions(
1102 method="post",
1103 url="/foo",
1104 **make_request_options(
1105 query={"foo": "1"},
1106 extra_query={"foo": "2"},
1107 ),
1108 ),
1109 )
1110 params = dict(request.url.params)
1111 assert params == {"foo": "2"}
1112
1113 @pytest.mark.respx(base_url=base_url)
1114 async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
1115 class Model1(BaseModel):
1116 name: str
1117
1118 class Model2(BaseModel):
1119 foo: str
1120
1121 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1122
1123 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1124 assert isinstance(response, Model2)
1125 assert response.foo == "bar"
1126
1127 @pytest.mark.respx(base_url=base_url)
1128 async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
1129 """Union of objects with the same field name using a different type"""
1130
1131 class Model1(BaseModel):
1132 foo: int
1133
1134 class Model2(BaseModel):
1135 foo: str
1136
1137 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1138
1139 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1140 assert isinstance(response, Model2)
1141 assert response.foo == "bar"
1142
1143 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1144
1145 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1146 assert isinstance(response, Model1)
1147 assert response.foo == 1
1148
1149 @pytest.mark.respx(base_url=base_url)
1150 async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
1151 """
1152 Response that sets Content-Type to something other than application/json but returns json data
1153 """
1154
1155 class Model(BaseModel):
1156 foo: int
1157
1158 respx_mock.get("/foo").mock(
1159 return_value=httpx.Response(
1160 200,
1161 content=json.dumps({"foo": 2}),
1162 headers={"Content-Type": "application/text"},
1163 )
1164 )
1165
1166 response = await self.client.get("/foo", cast_to=Model)
1167 assert isinstance(response, Model)
1168 assert response.foo == 2
1169
1170 def test_base_url_setter(self) -> None:
1171 client = AsyncOpenAI(
1172 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1173 )
1174 assert client.base_url == "https://example.com/from_init/"
1175
1176 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1177
1178 assert client.base_url == "https://example.com/from_setter/"
1179
1180 def test_base_url_env(self) -> None:
1181 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1182 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1183 assert client.base_url == "http://localhost:5000/from/env/"
1184
1185 @pytest.mark.parametrize(
1186 "client",
1187 [
1188 AsyncOpenAI(
1189 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1190 ),
1191 AsyncOpenAI(
1192 base_url="http://localhost:5000/custom/path/",
1193 api_key=api_key,
1194 _strict_response_validation=True,
1195 http_client=httpx.AsyncClient(),
1196 ),
1197 ],
1198 ids=["standard", "custom http client"],
1199 )
1200 def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1201 request = client._build_request(
1202 FinalRequestOptions(
1203 method="post",
1204 url="/foo",
1205 json_data={"foo": "bar"},
1206 ),
1207 )
1208 assert request.url == "http://localhost:5000/custom/path/foo"
1209
1210 @pytest.mark.parametrize(
1211 "client",
1212 [
1213 AsyncOpenAI(
1214 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1215 ),
1216 AsyncOpenAI(
1217 base_url="http://localhost:5000/custom/path/",
1218 api_key=api_key,
1219 _strict_response_validation=True,
1220 http_client=httpx.AsyncClient(),
1221 ),
1222 ],
1223 ids=["standard", "custom http client"],
1224 )
1225 def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1226 request = client._build_request(
1227 FinalRequestOptions(
1228 method="post",
1229 url="/foo",
1230 json_data={"foo": "bar"},
1231 ),
1232 )
1233 assert request.url == "http://localhost:5000/custom/path/foo"
1234
1235 @pytest.mark.parametrize(
1236 "client",
1237 [
1238 AsyncOpenAI(
1239 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1240 ),
1241 AsyncOpenAI(
1242 base_url="http://localhost:5000/custom/path/",
1243 api_key=api_key,
1244 _strict_response_validation=True,
1245 http_client=httpx.AsyncClient(),
1246 ),
1247 ],
1248 ids=["standard", "custom http client"],
1249 )
1250 def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1251 request = client._build_request(
1252 FinalRequestOptions(
1253 method="post",
1254 url="https://myapi.com/foo",
1255 json_data={"foo": "bar"},
1256 ),
1257 )
1258 assert request.url == "https://myapi.com/foo"
1259
1260 async def test_copied_client_does_not_close_http(self) -> None:
1261 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1262 assert not client.is_closed()
1263
1264 copied = client.copy()
1265 assert copied is not client
1266
1267 del copied
1268
1269 await asyncio.sleep(0.2)
1270 assert not client.is_closed()
1271
1272 async def test_client_context_manager(self) -> None:
1273 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1274 async with client as c2:
1275 assert c2 is client
1276 assert not c2.is_closed()
1277 assert not client.is_closed()
1278 assert client.is_closed()
1279
1280 @pytest.mark.respx(base_url=base_url)
1281 @pytest.mark.asyncio
1282 async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
1283 class Model(BaseModel):
1284 foo: str
1285
1286 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1287
1288 with pytest.raises(APIResponseValidationError) as exc:
1289 await self.client.get("/foo", cast_to=Model)
1290
1291 assert isinstance(exc.value.__cause__, ValidationError)
1292
1293 @pytest.mark.respx(base_url=base_url)
1294 @pytest.mark.asyncio
1295 async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
1296 class Model(BaseModel):
1297 name: str
1298
1299 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1300
1301 response = await self.client.post("/foo", cast_to=Model, stream=True)
1302 assert isinstance(response, AsyncStream)
1303
1304 @pytest.mark.respx(base_url=base_url)
1305 @pytest.mark.asyncio
1306 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1307 class Model(BaseModel):
1308 name: str
1309
1310 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1311
1312 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1313
1314 with pytest.raises(APIResponseValidationError):
1315 await strict_client.get("/foo", cast_to=Model)
1316
1317 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1318
1319 response = await client.get("/foo", cast_to=Model)
1320 assert isinstance(response, str) # type: ignore[unreachable]
1321
1322 @pytest.mark.parametrize(
1323 "remaining_retries,retry_after,timeout",
1324 [
1325 [3, "20", 20],
1326 [3, "0", 0.5],
1327 [3, "-10", 0.5],
1328 [3, "60", 60],
1329 [3, "61", 0.5],
1330 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1331 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1332 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1333 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1334 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1335 [3, "99999999999999999999999999999999999", 0.5],
1336 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1337 [3, "", 0.5],
1338 [2, "", 0.5 * 2.0],
1339 [1, "", 0.5 * 4.0],
1340 ],
1341 )
1342 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1343 @pytest.mark.asyncio
1344 async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
1345 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1346
1347 headers = httpx.Headers({"retry-after": retry_after})
1348 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1349 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
1350 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1351
1352 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1353 @pytest.mark.respx(base_url=base_url)
1354 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1355 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1356
1357 with pytest.raises(APITimeoutError):
1358 await self.client.post(
1359 "/chat/completions",
1360 body=dict(
1361 messages=[
1362 {
1363 "role": "user",
1364 "content": "Say this is a test",
1365 }
1366 ],
1367 model="gpt-3.5-turbo",
1368 ),
1369 cast_to=httpx.Response,
1370 options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
1371 )
1372
1373 assert _get_open_connections(self.client) == 0
1374
1375 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1376 @pytest.mark.respx(base_url=base_url)
1377 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1378 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1379
1380 with pytest.raises(APIStatusError):
1381 await self.client.post(
1382 "/chat/completions",
1383 body=dict(
1384 messages=[
1385 {
1386 "role": "user",
1387 "content": "Say this is a test",
1388 }
1389 ],
1390 model="gpt-3.5-turbo",
1391 ),
1392 cast_to=httpx.Response,
1393 options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
1394 )
1395
1396 assert _get_open_connections(self.client) == 0
1397