openai/openai-python

Public

mirrored from https://github.com/openai/openai-python Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
86379b4471d67a9d2e85f0b0c098787fb99aa4e0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1446 lines · mode: code

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._response import APIResponse, AsyncAPIResponse
23from openai._constants import RAW_RESPONSE_HEADER
24from openai._streaming import Stream, AsyncStream
25from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
26from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
27
28from .utils import update_env
29
30base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
31api_key = "My API Key"
32
33
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters the client puts on a bare ``GET /foo`` request."""
    probe_options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(probe_options)
    return dict(httpx.URL(built.url).params)
38
39
40def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
41 return 0.1
42
43
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the length of the request list held by the client's httpx connection pool."""
    transport = client._client._transport
    # only the stock sync/async httpx transports are expected here
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
50
51
52class TestOpenAI:
53 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
54
55 @pytest.mark.respx(base_url=base_url)
56 def test_raw_response(self, respx_mock: MockRouter) -> None:
57 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
58
59 response = self.client.post("/foo", cast_to=httpx.Response)
60 assert response.status_code == 200
61 assert isinstance(response, httpx.Response)
62 assert response.json() == {"foo": "bar"}
63
64 @pytest.mark.respx(base_url=base_url)
65 def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
66 respx_mock.post("/foo").mock(
67 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
68 )
69
70 response = self.client.post("/foo", cast_to=httpx.Response)
71 assert response.status_code == 200
72 assert isinstance(response, httpx.Response)
73 assert response.json() == {"foo": "bar"}
74
75 def test_copy(self) -> None:
76 copied = self.client.copy()
77 assert id(copied) != id(self.client)
78
79 copied = self.client.copy(api_key="another My API Key")
80 assert copied.api_key == "another My API Key"
81 assert self.client.api_key == "My API Key"
82
83 def test_copy_default_options(self) -> None:
84 # options that have a default are overridden correctly
85 copied = self.client.copy(max_retries=7)
86 assert copied.max_retries == 7
87 assert self.client.max_retries == 2
88
89 copied2 = copied.copy(max_retries=6)
90 assert copied2.max_retries == 6
91 assert copied.max_retries == 7
92
93 # timeout
94 assert isinstance(self.client.timeout, httpx.Timeout)
95 copied = self.client.copy(timeout=None)
96 assert copied.timeout is None
97 assert isinstance(self.client.timeout, httpx.Timeout)
98
99 def test_copy_default_headers(self) -> None:
100 client = OpenAI(
101 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
102 )
103 assert client.default_headers["X-Foo"] == "bar"
104
105 # does not override the already given value when not specified
106 copied = client.copy()
107 assert copied.default_headers["X-Foo"] == "bar"
108
109 # merges already given headers
110 copied = client.copy(default_headers={"X-Bar": "stainless"})
111 assert copied.default_headers["X-Foo"] == "bar"
112 assert copied.default_headers["X-Bar"] == "stainless"
113
114 # uses new values for any already given headers
115 copied = client.copy(default_headers={"X-Foo": "stainless"})
116 assert copied.default_headers["X-Foo"] == "stainless"
117
118 # set_default_headers
119
120 # completely overrides already set values
121 copied = client.copy(set_default_headers={})
122 assert copied.default_headers.get("X-Foo") is None
123
124 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
125 assert copied.default_headers["X-Bar"] == "Robert"
126
127 with pytest.raises(
128 ValueError,
129 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
130 ):
131 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
132
133 def test_copy_default_query(self) -> None:
134 client = OpenAI(
135 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
136 )
137 assert _get_params(client)["foo"] == "bar"
138
139 # does not override the already given value when not specified
140 copied = client.copy()
141 assert _get_params(copied)["foo"] == "bar"
142
143 # merges already given params
144 copied = client.copy(default_query={"bar": "stainless"})
145 params = _get_params(copied)
146 assert params["foo"] == "bar"
147 assert params["bar"] == "stainless"
148
149 # uses new values for any already given headers
150 copied = client.copy(default_query={"foo": "stainless"})
151 assert _get_params(copied)["foo"] == "stainless"
152
153 # set_default_query
154
155 # completely overrides already set values
156 copied = client.copy(set_default_query={})
157 assert _get_params(copied) == {}
158
159 copied = client.copy(set_default_query={"bar": "Robert"})
160 assert _get_params(copied)["bar"] == "Robert"
161
162 with pytest.raises(
163 ValueError,
164 # TODO: update
165 match="`default_query` and `set_default_query` arguments are mutually exclusive",
166 ):
167 client.copy(set_default_query={}, default_query={"foo": "Bar"})
168
169 def test_copy_signature(self) -> None:
170 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
171 init_signature = inspect.signature(
172 # mypy doesn't like that we access the `__init__` property.
173 self.client.__init__, # type: ignore[misc]
174 )
175 copy_signature = inspect.signature(self.client.copy)
176 exclude_params = {"transport", "proxies", "_strict_response_validation"}
177
178 for name in init_signature.parameters.keys():
179 if name in exclude_params:
180 continue
181
182 copy_param = copy_signature.parameters.get(name)
183 assert copy_param is not None, f"copy() signature is missing the {name} param"
184
185 def test_copy_build_request(self) -> None:
186 options = FinalRequestOptions(method="get", url="/foo")
187
188 def build_request(options: FinalRequestOptions) -> None:
189 client = self.client.copy()
190 client._build_request(options)
191
192 # ensure that the machinery is warmed up before tracing starts.
193 build_request(options)
194 gc.collect()
195
196 tracemalloc.start(1000)
197
198 snapshot_before = tracemalloc.take_snapshot()
199
200 ITERATIONS = 10
201 for _ in range(ITERATIONS):
202 build_request(options)
203
204 gc.collect()
205 snapshot_after = tracemalloc.take_snapshot()
206
207 tracemalloc.stop()
208
209 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
210 if diff.count == 0:
211 # Avoid false positives by considering only leaks (i.e. allocations that persist).
212 return
213
214 if diff.count % ITERATIONS != 0:
215 # Avoid false positives by considering only leaks that appear per iteration.
216 return
217
218 for frame in diff.traceback:
219 if any(
220 frame.filename.endswith(fragment)
221 for fragment in [
222 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
223 #
224 # removing the decorator fixes the leak for reasons we don't understand.
225 "openai/_legacy_response.py",
226 "openai/_response.py",
227 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
228 "openai/_compat.py",
229 # Standard library leaks we don't care about.
230 "/logging/__init__.py",
231 ]
232 ):
233 return
234
235 leaks.append(diff)
236
237 leaks: list[tracemalloc.StatisticDiff] = []
238 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
239 add_leak(leaks, diff)
240 if leaks:
241 for leak in leaks:
242 print("MEMORY LEAK:", leak)
243 for frame in leak.traceback:
244 print(frame)
245 raise AssertionError()
246
247 def test_request_timeout(self) -> None:
248 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
249 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
250 assert timeout == DEFAULT_TIMEOUT
251
252 request = self.client._build_request(
253 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
254 )
255 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
256 assert timeout == httpx.Timeout(100.0)
257
258 def test_client_timeout_option(self) -> None:
259 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
260
261 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
262 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
263 assert timeout == httpx.Timeout(0)
264
265 def test_http_client_timeout_option(self) -> None:
266 # custom timeout given to the httpx client should be used
267 with httpx.Client(timeout=None) as http_client:
268 client = OpenAI(
269 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
270 )
271
272 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
273 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
274 assert timeout == httpx.Timeout(None)
275
276 # no timeout given to the httpx client should not use the httpx default
277 with httpx.Client() as http_client:
278 client = OpenAI(
279 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
280 )
281
282 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
283 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
284 assert timeout == DEFAULT_TIMEOUT
285
286 # explicitly passing the default timeout currently results in it being ignored
287 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
288 client = OpenAI(
289 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
290 )
291
292 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
293 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
294 assert timeout == DEFAULT_TIMEOUT # our default
295
296 def test_default_headers_option(self) -> None:
297 client = OpenAI(
298 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
299 )
300 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
301 assert request.headers.get("x-foo") == "bar"
302 assert request.headers.get("x-stainless-lang") == "python"
303
304 client2 = OpenAI(
305 base_url=base_url,
306 api_key=api_key,
307 _strict_response_validation=True,
308 default_headers={
309 "X-Foo": "stainless",
310 "X-Stainless-Lang": "my-overriding-header",
311 },
312 )
313 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
314 assert request.headers.get("x-foo") == "stainless"
315 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
316
317 def test_validate_headers(self) -> None:
318 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
319 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
320 assert request.headers.get("Authorization") == f"Bearer {api_key}"
321
322 with pytest.raises(OpenAIError):
323 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
324 _ = client2
325
326 def test_default_query_option(self) -> None:
327 client = OpenAI(
328 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
329 )
330 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
331 url = httpx.URL(request.url)
332 assert dict(url.params) == {"query_param": "bar"}
333
334 request = client._build_request(
335 FinalRequestOptions(
336 method="get",
337 url="/foo",
338 params={"foo": "baz", "query_param": "overriden"},
339 )
340 )
341 url = httpx.URL(request.url)
342 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
343
344 def test_request_extra_json(self) -> None:
345 request = self.client._build_request(
346 FinalRequestOptions(
347 method="post",
348 url="/foo",
349 json_data={"foo": "bar"},
350 extra_json={"baz": False},
351 ),
352 )
353 data = json.loads(request.content.decode("utf-8"))
354 assert data == {"foo": "bar", "baz": False}
355
356 request = self.client._build_request(
357 FinalRequestOptions(
358 method="post",
359 url="/foo",
360 extra_json={"baz": False},
361 ),
362 )
363 data = json.loads(request.content.decode("utf-8"))
364 assert data == {"baz": False}
365
366 # `extra_json` takes priority over `json_data` when keys clash
367 request = self.client._build_request(
368 FinalRequestOptions(
369 method="post",
370 url="/foo",
371 json_data={"foo": "bar", "baz": True},
372 extra_json={"baz": None},
373 ),
374 )
375 data = json.loads(request.content.decode("utf-8"))
376 assert data == {"foo": "bar", "baz": None}
377
378 def test_request_extra_headers(self) -> None:
379 request = self.client._build_request(
380 FinalRequestOptions(
381 method="post",
382 url="/foo",
383 **make_request_options(extra_headers={"X-Foo": "Foo"}),
384 ),
385 )
386 assert request.headers.get("X-Foo") == "Foo"
387
388 # `extra_headers` takes priority over `default_headers` when keys clash
389 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
390 FinalRequestOptions(
391 method="post",
392 url="/foo",
393 **make_request_options(
394 extra_headers={"X-Bar": "false"},
395 ),
396 ),
397 )
398 assert request.headers.get("X-Bar") == "false"
399
400 def test_request_extra_query(self) -> None:
401 request = self.client._build_request(
402 FinalRequestOptions(
403 method="post",
404 url="/foo",
405 **make_request_options(
406 extra_query={"my_query_param": "Foo"},
407 ),
408 ),
409 )
410 params = dict(request.url.params)
411 assert params == {"my_query_param": "Foo"}
412
413 # if both `query` and `extra_query` are given, they are merged
414 request = self.client._build_request(
415 FinalRequestOptions(
416 method="post",
417 url="/foo",
418 **make_request_options(
419 query={"bar": "1"},
420 extra_query={"foo": "2"},
421 ),
422 ),
423 )
424 params = dict(request.url.params)
425 assert params == {"bar": "1", "foo": "2"}
426
427 # `extra_query` takes priority over `query` when keys clash
428 request = self.client._build_request(
429 FinalRequestOptions(
430 method="post",
431 url="/foo",
432 **make_request_options(
433 query={"foo": "1"},
434 extra_query={"foo": "2"},
435 ),
436 ),
437 )
438 params = dict(request.url.params)
439 assert params == {"foo": "2"}
440
441 @pytest.mark.respx(base_url=base_url)
442 def test_basic_union_response(self, respx_mock: MockRouter) -> None:
443 class Model1(BaseModel):
444 name: str
445
446 class Model2(BaseModel):
447 foo: str
448
449 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
450
451 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
452 assert isinstance(response, Model2)
453 assert response.foo == "bar"
454
455 @pytest.mark.respx(base_url=base_url)
456 def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
457 """Union of objects with the same field name using a different type"""
458
459 class Model1(BaseModel):
460 foo: int
461
462 class Model2(BaseModel):
463 foo: str
464
465 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
466
467 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
468 assert isinstance(response, Model2)
469 assert response.foo == "bar"
470
471 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
472
473 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
474 assert isinstance(response, Model1)
475 assert response.foo == 1
476
477 @pytest.mark.respx(base_url=base_url)
478 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
479 """
480 Response that sets Content-Type to something other than application/json but returns json data
481 """
482
483 class Model(BaseModel):
484 foo: int
485
486 respx_mock.get("/foo").mock(
487 return_value=httpx.Response(
488 200,
489 content=json.dumps({"foo": 2}),
490 headers={"Content-Type": "application/text"},
491 )
492 )
493
494 response = self.client.get("/foo", cast_to=Model)
495 assert isinstance(response, Model)
496 assert response.foo == 2
497
498 def test_base_url_setter(self) -> None:
499 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
500 assert client.base_url == "https://example.com/from_init/"
501
502 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
503
504 assert client.base_url == "https://example.com/from_setter/"
505
506 def test_base_url_env(self) -> None:
507 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
508 client = OpenAI(api_key=api_key, _strict_response_validation=True)
509 assert client.base_url == "http://localhost:5000/from/env/"
510
511 @pytest.mark.parametrize(
512 "client",
513 [
514 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
515 OpenAI(
516 base_url="http://localhost:5000/custom/path/",
517 api_key=api_key,
518 _strict_response_validation=True,
519 http_client=httpx.Client(),
520 ),
521 ],
522 ids=["standard", "custom http client"],
523 )
524 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
525 request = client._build_request(
526 FinalRequestOptions(
527 method="post",
528 url="/foo",
529 json_data={"foo": "bar"},
530 ),
531 )
532 assert request.url == "http://localhost:5000/custom/path/foo"
533
534 @pytest.mark.parametrize(
535 "client",
536 [
537 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
538 OpenAI(
539 base_url="http://localhost:5000/custom/path/",
540 api_key=api_key,
541 _strict_response_validation=True,
542 http_client=httpx.Client(),
543 ),
544 ],
545 ids=["standard", "custom http client"],
546 )
547 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
548 request = client._build_request(
549 FinalRequestOptions(
550 method="post",
551 url="/foo",
552 json_data={"foo": "bar"},
553 ),
554 )
555 assert request.url == "http://localhost:5000/custom/path/foo"
556
557 @pytest.mark.parametrize(
558 "client",
559 [
560 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
561 OpenAI(
562 base_url="http://localhost:5000/custom/path/",
563 api_key=api_key,
564 _strict_response_validation=True,
565 http_client=httpx.Client(),
566 ),
567 ],
568 ids=["standard", "custom http client"],
569 )
570 def test_absolute_request_url(self, client: OpenAI) -> None:
571 request = client._build_request(
572 FinalRequestOptions(
573 method="post",
574 url="https://myapi.com/foo",
575 json_data={"foo": "bar"},
576 ),
577 )
578 assert request.url == "https://myapi.com/foo"
579
580 def test_copied_client_does_not_close_http(self) -> None:
581 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
582 assert not client.is_closed()
583
584 copied = client.copy()
585 assert copied is not client
586
587 del copied
588
589 assert not client.is_closed()
590
591 def test_client_context_manager(self) -> None:
592 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
593 with client as c2:
594 assert c2 is client
595 assert not c2.is_closed()
596 assert not client.is_closed()
597 assert client.is_closed()
598
599 @pytest.mark.respx(base_url=base_url)
600 def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
601 class Model(BaseModel):
602 foo: str
603
604 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
605
606 with pytest.raises(APIResponseValidationError) as exc:
607 self.client.get("/foo", cast_to=Model)
608
609 assert isinstance(exc.value.__cause__, ValidationError)
610
611 @pytest.mark.respx(base_url=base_url)
612 def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
613 class Model(BaseModel):
614 name: str
615
616 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
617
618 stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
619 assert isinstance(stream, Stream)
620 stream.response.close()
621
622 @pytest.mark.respx(base_url=base_url)
623 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
624 class Model(BaseModel):
625 name: str
626
627 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
628
629 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
630
631 with pytest.raises(APIResponseValidationError):
632 strict_client.get("/foo", cast_to=Model)
633
634 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
635
636 response = client.get("/foo", cast_to=Model)
637 assert isinstance(response, str) # type: ignore[unreachable]
638
639 @pytest.mark.parametrize(
640 "remaining_retries,retry_after,timeout",
641 [
642 [3, "20", 20],
643 [3, "0", 0.5],
644 [3, "-10", 0.5],
645 [3, "60", 60],
646 [3, "61", 0.5],
647 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
648 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
649 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
650 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
651 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
652 [3, "99999999999999999999999999999999999", 0.5],
653 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
654 [3, "", 0.5],
655 [2, "", 0.5 * 2.0],
656 [1, "", 0.5 * 4.0],
657 ],
658 )
659 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
660 def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
661 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
662
663 headers = httpx.Headers({"retry-after": retry_after})
664 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
665 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
666 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
667
668 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
669 @pytest.mark.respx(base_url=base_url)
670 def test_streaming_response(self) -> None:
671 response = self.client.post(
672 "/chat/completions",
673 body=dict(
674 messages=[
675 {
676 "role": "user",
677 "content": "Say this is a test",
678 }
679 ],
680 model="gpt-3.5-turbo",
681 ),
682 cast_to=APIResponse[bytes],
683 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
684 )
685
686 assert not cast(Any, response.is_closed)
687 assert _get_open_connections(self.client) == 1
688
689 for _ in response.iter_bytes():
690 ...
691
692 assert cast(Any, response.is_closed)
693 assert _get_open_connections(self.client) == 0
694
695 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
696 @pytest.mark.respx(base_url=base_url)
697 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
698 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
699
700 with pytest.raises(APITimeoutError):
701 self.client.post(
702 "/chat/completions",
703 body=dict(
704 messages=[
705 {
706 "role": "user",
707 "content": "Say this is a test",
708 }
709 ],
710 model="gpt-3.5-turbo",
711 ),
712 cast_to=httpx.Response,
713 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
714 )
715
716 assert _get_open_connections(self.client) == 0
717
718 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
719 @pytest.mark.respx(base_url=base_url)
720 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
721 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
722
723 with pytest.raises(APIStatusError):
724 self.client.post(
725 "/chat/completions",
726 body=dict(
727 messages=[
728 {
729 "role": "user",
730 "content": "Say this is a test",
731 }
732 ],
733 model="gpt-3.5-turbo",
734 ),
735 cast_to=httpx.Response,
736 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
737 )
738
739 assert _get_open_connections(self.client) == 0
740
741
742class TestAsyncOpenAI:
743 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
744
745 @pytest.mark.respx(base_url=base_url)
746 @pytest.mark.asyncio
747 async def test_raw_response(self, respx_mock: MockRouter) -> None:
748 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
749
750 response = await self.client.post("/foo", cast_to=httpx.Response)
751 assert response.status_code == 200
752 assert isinstance(response, httpx.Response)
753 assert response.json() == {"foo": "bar"}
754
755 @pytest.mark.respx(base_url=base_url)
756 @pytest.mark.asyncio
757 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
758 respx_mock.post("/foo").mock(
759 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
760 )
761
762 response = await self.client.post("/foo", cast_to=httpx.Response)
763 assert response.status_code == 200
764 assert isinstance(response, httpx.Response)
765 assert response.json() == {"foo": "bar"}
766
767 def test_copy(self) -> None:
768 copied = self.client.copy()
769 assert id(copied) != id(self.client)
770
771 copied = self.client.copy(api_key="another My API Key")
772 assert copied.api_key == "another My API Key"
773 assert self.client.api_key == "My API Key"
774
775 def test_copy_default_options(self) -> None:
776 # options that have a default are overridden correctly
777 copied = self.client.copy(max_retries=7)
778 assert copied.max_retries == 7
779 assert self.client.max_retries == 2
780
781 copied2 = copied.copy(max_retries=6)
782 assert copied2.max_retries == 6
783 assert copied.max_retries == 7
784
785 # timeout
786 assert isinstance(self.client.timeout, httpx.Timeout)
787 copied = self.client.copy(timeout=None)
788 assert copied.timeout is None
789 assert isinstance(self.client.timeout, httpx.Timeout)
790
791 def test_copy_default_headers(self) -> None:
792 client = AsyncOpenAI(
793 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
794 )
795 assert client.default_headers["X-Foo"] == "bar"
796
797 # does not override the already given value when not specified
798 copied = client.copy()
799 assert copied.default_headers["X-Foo"] == "bar"
800
801 # merges already given headers
802 copied = client.copy(default_headers={"X-Bar": "stainless"})
803 assert copied.default_headers["X-Foo"] == "bar"
804 assert copied.default_headers["X-Bar"] == "stainless"
805
806 # uses new values for any already given headers
807 copied = client.copy(default_headers={"X-Foo": "stainless"})
808 assert copied.default_headers["X-Foo"] == "stainless"
809
810 # set_default_headers
811
812 # completely overrides already set values
813 copied = client.copy(set_default_headers={})
814 assert copied.default_headers.get("X-Foo") is None
815
816 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
817 assert copied.default_headers["X-Bar"] == "Robert"
818
819 with pytest.raises(
820 ValueError,
821 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
822 ):
823 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
824
825 def test_copy_default_query(self) -> None:
826 client = AsyncOpenAI(
827 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
828 )
829 assert _get_params(client)["foo"] == "bar"
830
831 # does not override the already given value when not specified
832 copied = client.copy()
833 assert _get_params(copied)["foo"] == "bar"
834
835 # merges already given params
836 copied = client.copy(default_query={"bar": "stainless"})
837 params = _get_params(copied)
838 assert params["foo"] == "bar"
839 assert params["bar"] == "stainless"
840
841 # uses new values for any already given headers
842 copied = client.copy(default_query={"foo": "stainless"})
843 assert _get_params(copied)["foo"] == "stainless"
844
845 # set_default_query
846
847 # completely overrides already set values
848 copied = client.copy(set_default_query={})
849 assert _get_params(copied) == {}
850
851 copied = client.copy(set_default_query={"bar": "Robert"})
852 assert _get_params(copied)["bar"] == "Robert"
853
854 with pytest.raises(
855 ValueError,
856 # TODO: update
857 match="`default_query` and `set_default_query` arguments are mutually exclusive",
858 ):
859 client.copy(set_default_query={}, default_query={"foo": "Bar"})
860
861 def test_copy_signature(self) -> None:
862 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
863 init_signature = inspect.signature(
864 # mypy doesn't like that we access the `__init__` property.
865 self.client.__init__, # type: ignore[misc]
866 )
867 copy_signature = inspect.signature(self.client.copy)
868 exclude_params = {"transport", "proxies", "_strict_response_validation"}
869
870 for name in init_signature.parameters.keys():
871 if name in exclude_params:
872 continue
873
874 copy_param = copy_signature.parameters.get(name)
875 assert copy_param is not None, f"copy() signature is missing the {name} param"
876
def test_copy_build_request(self) -> None:
    """Check that copying the client and building a request does not leak memory per call.

    The copy + build cycle is run once to warm up, then `ITERATIONS` more times while
    `tracemalloc` is tracing. Any allocation that survives garbage collection and whose
    count is a multiple of `ITERATIONS` is reported as a leak, unless it originates from
    one of the explicitly ignored files.

    Raises:
        AssertionError: if at least one per-iteration leak is found.
    """
    ITERATIONS = 10

    # Files whose allocations are knowingly ignored.
    ignored_fragments = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        # Always stop tracing, even when building a request raised, so that
        # tracing does not stay enabled for the rest of the test session.
        tracemalloc.stop()

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        # `str.endswith` accepts a tuple, so one call checks every ignored fragment.
        if any(frame.filename.endswith(ignored_fragments) for frame in diff.traceback):
            return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)

    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} memory leak(s) detected, see the captured output for tracebacks")
938
async def test_request_timeout(self) -> None:
    """A request built without a timeout gets `DEFAULT_TIMEOUT`; a per-request timeout overrides it."""
    plain_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**plain_request.extensions["timeout"]) == DEFAULT_TIMEOUT  # type: ignore

    overriding_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    overridden_request = self.client._build_request(overriding_options)
    assert httpx.Timeout(**overridden_request.extensions["timeout"]) == httpx.Timeout(100.0)  # type: ignore
949
async def test_client_timeout_option(self) -> None:
    """A `timeout` given to the client constructor is carried onto the requests it builds."""
    zero_timeout = httpx.Timeout(0)
    client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=zero_timeout,
    )

    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**built.extensions["timeout"]) == httpx.Timeout(0)  # type: ignore
958
async def test_http_client_timeout_option(self) -> None:
    """Check how the timeout of a user-supplied `httpx.AsyncClient` is reflected on built requests."""

    def timeout_used_with(http_client: httpx.AsyncClient) -> httpx.Timeout:
        # Wrap the given httpx client and report the timeout attached to a freshly built request.
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        assert timeout_used_with(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        assert timeout_used_with(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert timeout_used_with(http_client) == DEFAULT_TIMEOUT  # our default
989
def test_default_headers_option(self) -> None:
    """`default_headers` are sent on every request and may override the `X-Stainless-Lang` header."""
    with_custom_header = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    )
    sent = with_custom_header._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    overriding_headers = {
        "X-Foo": "stainless",
        "X-Stainless-Lang": "my-overriding-header",
    }
    with_override = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers=overriding_headers,
    )
    sent = with_override._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"
1010
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and constructing a client without one fails.

    Raises (inside the test): `OpenAIError` is expected when no API key is available.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # The client is documented to fall back to the `OPENAI_API_KEY` environment
    # variable when `api_key` is not given, so make sure it is unset here;
    # otherwise the outcome depends on the developer's / CI machine's environment.
    # `patch.dict` restores the original environment when the block exits.
    with mock.patch.dict(os.environ):
        os.environ.pop("OPENAI_API_KEY", None)

        with pytest.raises(OpenAIError):
            client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2
1019
def test_default_query_option(self) -> None:
    """`default_query` is added to every request URL; per-request params win on key clashes."""
    client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_query={"query_param": "bar"},
    )

    defaults_only = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(defaults_only.url).params) == {"query_param": "bar"}

    with_request_params = client._build_request(
        FinalRequestOptions(method="get", url="/foo", params={"foo": "baz", "query_param": "overriden"})
    )
    assert dict(httpx.URL(with_request_params.url).params) == {"foo": "baz", "query_param": "overriden"}
1037
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON request body and wins over `json_data` on key clashes."""

    def sent_body(**body_options: Any) -> Any:
        # Build a POST /foo request with the given body options and decode what would be sent.
        request = self.client._build_request(FinalRequestOptions(method="post", url="/foo", **body_options))
        return json.loads(request.content.decode("utf-8"))

    assert sent_body(json_data={"foo": "bar"}, extra_json={"baz": False}) == {"foo": "bar", "baz": False}

    assert sent_body(extra_json={"baz": False}) == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    clashing = sent_body(json_data={"foo": "bar", "baz": True}, extra_json={"baz": None})
    assert clashing == {"foo": "bar", "baz": None}
1071
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent with the request and override clashing `default_headers`."""
    plain_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    assert self.client._build_request(plain_options).headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_default = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Bar": "false"}),
    )
    assert client_with_default._build_request(clashing_options).headers.get("X-Bar") == "false"
1093
def test_request_extra_query(self) -> None:
    """`extra_query` ends up in the URL, is merged with `query`, and wins on key clashes."""

    def url_params(**request_options: Any) -> dict[str, str]:
        # Build a POST /foo request from the given request options and return its query params.
        request = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **make_request_options(**request_options))
        )
        return dict(request.url.params)

    assert url_params(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert url_params(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert url_params(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1134
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A union `cast_to` resolves to the member model whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    mocked = httpx.Response(200, json={"foo": "bar"})
    respx_mock.get("/foo").mock(return_value=mocked)

    either_model = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1148
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union members sharing a field name but not its type are told apart by the payload's value type."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    either_model = cast(Any, Union[Model1, Model2])

    # a string value selects the `str` variant
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    as_text = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(as_text, Model2)
    assert as_text.foo == "bar"

    # an integer value selects the `int` variant
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    as_number = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(as_number, Model1)
    assert as_number.foo == 1
1170
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    A JSON body is still parsed into the model when Content-Type is not application/json
    """

    class Model(BaseModel):
        foo: int

    mislabelled = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1191
def test_base_url_setter(self) -> None:
    """`base_url` gains a trailing slash both when set via the constructor and via the setter."""
    initial_url = "https://example.com/from_init"
    client = AsyncOpenAI(base_url=initial_url, api_key=api_key, _strict_response_validation=True)
    assert client.base_url == "https://example.com/from_init/"

    client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert client.base_url == "https://example.com/from_setter/"
1201
def test_base_url_env(self) -> None:
    """With no explicit `base_url`, the client takes it from the `OPENAI_BASE_URL` environment variable."""
    env_overrides = {"OPENAI_BASE_URL": "http://localhost:5000/from/env"}
    with update_env(**env_overrides):
        from_env = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert from_env.base_url == "http://localhost:5000/from/env/"
1206
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL ending in `/` joins with a relative request path without doubling the slash."""
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1231
@pytest.mark.parametrize(
    "client",
    [
        # NOTE: the base URL deliberately has NO trailing slash here; with a trailing
        # slash this test would merely duplicate `test_base_url_trailing_slash`.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL without a trailing `/` must still join cleanly with a relative request path.

    The built request URL is expected to keep the custom path prefix and contain
    exactly one slash between the prefix and the request path.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1256
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined onto the client's base URL."""
    options = FinalRequestOptions(method="post", url="https://myapi.com/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "https://myapi.com/foo"
1281
async def test_copied_client_does_not_close_http(self) -> None:
    """Dropping a copy of a client must leave the original client open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    copied = original.copy()
    assert copied is not original

    # discard the copy, then wait briefly before re-checking the original
    del copied
    await asyncio.sleep(0.2)

    assert not original.is_closed()
1293
async def test_client_context_manager(self) -> None:
    """`async with client` yields the same client, keeps it open inside and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()

    # leaving the block closes the client
    assert client.is_closed()
1301
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit the model raises `APIResponseValidationError` caused by a pydantic `ValidationError`."""

    class Model(BaseModel):
        foo: str

    # `foo` should be a string, the mocked server sends an object instead
    invalid_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=invalid_payload))

    with pytest.raises(APIResponseValidationError) as raised:
        await self.client.get("/foo", cast_to=Model)

    assert isinstance(raised.value.__cause__, ValidationError)
1314
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """A request made with `stream=True` and a `stream_cls` returns an `AsyncStream` instance."""

    class Model(BaseModel):
        name: str

    mocked = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=mocked)

    result = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(result, AsyncStream)

    # close the underlying HTTP response explicitly since the stream is never consumed
    await result.response.aclose()
1326
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A plain-text body where a model is expected raises under strict validation, else is returned as `str`."""

    class Model(BaseModel):
        name: str

    text_response = httpx.Response(200, text="my-custom-format")
    respx_mock.get("/foo").mock(return_value=text_response)

    # strict validation: the mismatch is an error
    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    # non-strict validation: the raw text comes back instead of a model
    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    received = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(received, str)  # type: ignore[unreachable]
1344
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay computed from a `retry-after` header against the expected timeout.

    `time.time` is patched to a fixed value so that the date-formatted header
    values are evaluated against a known "now".
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    response_headers = httpx.Headers({"retry-after": retry_after})
    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)

    delay = client._calculate_retry_timeout(remaining_retries, request_options, response_headers)

    # compared with a relative tolerance rather than exactly
    assert delay == pytest.approx(timeout, rel=0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1374
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_streaming_response(self) -> None:
    """A streamed raw response holds one connection until its body is consumed, then releases it."""
    # NOTE(review): no route is registered and the `respx_mock` fixture is not requested,
    # so this presumably relies on a server listening at `base_url` - confirm.
    chat_request = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }
    response = await self.client.post(
        "/chat/completions",
        body=chat_request,
        cast_to=AsyncAPIResponse[bytes],
        options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
    )

    assert not cast(Any, response.is_closed)
    assert _get_open_connections(self.client) == 1

    # drain the body; the chunks themselves are irrelevant
    async for _chunk in response.iter_bytes():
        pass

    assert cast(Any, response.is_closed)
    assert _get_open_connections(self.client) == 0
1401
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """A request that keeps timing out ends in `APITimeoutError` and leaves no connection open."""
    always_times_out = httpx.TimeoutException("Test timeout error")
    respx_mock.post("/chat/completions").mock(side_effect=always_times_out)

    chat_request = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }

    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=chat_request,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1424
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """A request that keeps receiving HTTP 500 ends in `APIStatusError` and leaves no connection open."""
    server_error = httpx.Response(500)
    respx_mock.post("/chat/completions").mock(return_value=server_error)

    chat_request = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }

    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=chat_request,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1447