openai/openai-python

Public

mirrored fromhttps://github.com/openai/openai-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.14.0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1481lines · modecode

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
29base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
30api_key = "My API Key"
31
32
33def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
34 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
35 url = httpx.URL(request.url)
36 return dict(url.params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
43def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
44 transport = client._client._transport
45 assert isinstance(transport, httpx.HTTPTransport) or isinstance(transport, httpx.AsyncHTTPTransport)
46
47 pool = transport._pool
48 return len(pool._requests)
49
50
51class TestOpenAI:
52 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
53
54 @pytest.mark.respx(base_url=base_url)
55 def test_raw_response(self, respx_mock: MockRouter) -> None:
56 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
57
58 response = self.client.post("/foo", cast_to=httpx.Response)
59 assert response.status_code == 200
60 assert isinstance(response, httpx.Response)
61 assert response.json() == {"foo": "bar"}
62
63 @pytest.mark.respx(base_url=base_url)
64 def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
65 respx_mock.post("/foo").mock(
66 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
67 )
68
69 response = self.client.post("/foo", cast_to=httpx.Response)
70 assert response.status_code == 200
71 assert isinstance(response, httpx.Response)
72 assert response.json() == {"foo": "bar"}
73
74 def test_copy(self) -> None:
75 copied = self.client.copy()
76 assert id(copied) != id(self.client)
77
78 copied = self.client.copy(api_key="another My API Key")
79 assert copied.api_key == "another My API Key"
80 assert self.client.api_key == "My API Key"
81
82 def test_copy_default_options(self) -> None:
83 # options that have a default are overridden correctly
84 copied = self.client.copy(max_retries=7)
85 assert copied.max_retries == 7
86 assert self.client.max_retries == 2
87
88 copied2 = copied.copy(max_retries=6)
89 assert copied2.max_retries == 6
90 assert copied.max_retries == 7
91
92 # timeout
93 assert isinstance(self.client.timeout, httpx.Timeout)
94 copied = self.client.copy(timeout=None)
95 assert copied.timeout is None
96 assert isinstance(self.client.timeout, httpx.Timeout)
97
98 def test_copy_default_headers(self) -> None:
99 client = OpenAI(
100 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
101 )
102 assert client.default_headers["X-Foo"] == "bar"
103
104 # does not override the already given value when not specified
105 copied = client.copy()
106 assert copied.default_headers["X-Foo"] == "bar"
107
108 # merges already given headers
109 copied = client.copy(default_headers={"X-Bar": "stainless"})
110 assert copied.default_headers["X-Foo"] == "bar"
111 assert copied.default_headers["X-Bar"] == "stainless"
112
113 # uses new values for any already given headers
114 copied = client.copy(default_headers={"X-Foo": "stainless"})
115 assert copied.default_headers["X-Foo"] == "stainless"
116
117 # set_default_headers
118
119 # completely overrides already set values
120 copied = client.copy(set_default_headers={})
121 assert copied.default_headers.get("X-Foo") is None
122
123 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
124 assert copied.default_headers["X-Bar"] == "Robert"
125
126 with pytest.raises(
127 ValueError,
128 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
129 ):
130 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
131
132 def test_copy_default_query(self) -> None:
133 client = OpenAI(
134 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
135 )
136 assert _get_params(client)["foo"] == "bar"
137
138 # does not override the already given value when not specified
139 copied = client.copy()
140 assert _get_params(copied)["foo"] == "bar"
141
142 # merges already given params
143 copied = client.copy(default_query={"bar": "stainless"})
144 params = _get_params(copied)
145 assert params["foo"] == "bar"
146 assert params["bar"] == "stainless"
147
148 # uses new values for any already given headers
149 copied = client.copy(default_query={"foo": "stainless"})
150 assert _get_params(copied)["foo"] == "stainless"
151
152 # set_default_query
153
154 # completely overrides already set values
155 copied = client.copy(set_default_query={})
156 assert _get_params(copied) == {}
157
158 copied = client.copy(set_default_query={"bar": "Robert"})
159 assert _get_params(copied)["bar"] == "Robert"
160
161 with pytest.raises(
162 ValueError,
163 # TODO: update
164 match="`default_query` and `set_default_query` arguments are mutually exclusive",
165 ):
166 client.copy(set_default_query={}, default_query={"foo": "Bar"})
167
168 def test_copy_signature(self) -> None:
169 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
170 init_signature = inspect.signature(
171 # mypy doesn't like that we access the `__init__` property.
172 self.client.__init__, # type: ignore[misc]
173 )
174 copy_signature = inspect.signature(self.client.copy)
175 exclude_params = {"transport", "proxies", "_strict_response_validation"}
176
177 for name in init_signature.parameters.keys():
178 if name in exclude_params:
179 continue
180
181 copy_param = copy_signature.parameters.get(name)
182 assert copy_param is not None, f"copy() signature is missing the {name} param"
183
184 def test_copy_build_request(self) -> None:
185 options = FinalRequestOptions(method="get", url="/foo")
186
187 def build_request(options: FinalRequestOptions) -> None:
188 client = self.client.copy()
189 client._build_request(options)
190
191 # ensure that the machinery is warmed up before tracing starts.
192 build_request(options)
193 gc.collect()
194
195 tracemalloc.start(1000)
196
197 snapshot_before = tracemalloc.take_snapshot()
198
199 ITERATIONS = 10
200 for _ in range(ITERATIONS):
201 build_request(options)
202
203 gc.collect()
204 snapshot_after = tracemalloc.take_snapshot()
205
206 tracemalloc.stop()
207
208 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
209 if diff.count == 0:
210 # Avoid false positives by considering only leaks (i.e. allocations that persist).
211 return
212
213 if diff.count % ITERATIONS != 0:
214 # Avoid false positives by considering only leaks that appear per iteration.
215 return
216
217 for frame in diff.traceback:
218 if any(
219 frame.filename.endswith(fragment)
220 for fragment in [
221 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
222 #
223 # removing the decorator fixes the leak for reasons we don't understand.
224 "openai/_legacy_response.py",
225 "openai/_response.py",
226 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
227 "openai/_compat.py",
228 # Standard library leaks we don't care about.
229 "/logging/__init__.py",
230 ]
231 ):
232 return
233
234 leaks.append(diff)
235
236 leaks: list[tracemalloc.StatisticDiff] = []
237 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
238 add_leak(leaks, diff)
239 if leaks:
240 for leak in leaks:
241 print("MEMORY LEAK:", leak)
242 for frame in leak.traceback:
243 print(frame)
244 raise AssertionError()
245
246 def test_request_timeout(self) -> None:
247 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
248 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
249 assert timeout == DEFAULT_TIMEOUT
250
251 request = self.client._build_request(
252 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
253 )
254 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
255 assert timeout == httpx.Timeout(100.0)
256
257 def test_client_timeout_option(self) -> None:
258 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
259
260 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
261 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
262 assert timeout == httpx.Timeout(0)
263
264 def test_http_client_timeout_option(self) -> None:
265 # custom timeout given to the httpx client should be used
266 with httpx.Client(timeout=None) as http_client:
267 client = OpenAI(
268 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
269 )
270
271 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
272 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
273 assert timeout == httpx.Timeout(None)
274
275 # no timeout given to the httpx client should not use the httpx default
276 with httpx.Client() as http_client:
277 client = OpenAI(
278 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
279 )
280
281 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
282 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
283 assert timeout == DEFAULT_TIMEOUT
284
285 # explicitly passing the default timeout currently results in it being ignored
286 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
287 client = OpenAI(
288 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
289 )
290
291 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
292 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
293 assert timeout == DEFAULT_TIMEOUT # our default
294
295 async def test_invalid_http_client(self) -> None:
296 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
297 async with httpx.AsyncClient() as http_client:
298 OpenAI(
299 base_url=base_url,
300 api_key=api_key,
301 _strict_response_validation=True,
302 http_client=cast(Any, http_client),
303 )
304
305 def test_default_headers_option(self) -> None:
306 client = OpenAI(
307 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
308 )
309 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
310 assert request.headers.get("x-foo") == "bar"
311 assert request.headers.get("x-stainless-lang") == "python"
312
313 client2 = OpenAI(
314 base_url=base_url,
315 api_key=api_key,
316 _strict_response_validation=True,
317 default_headers={
318 "X-Foo": "stainless",
319 "X-Stainless-Lang": "my-overriding-header",
320 },
321 )
322 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
323 assert request.headers.get("x-foo") == "stainless"
324 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
325
326 def test_validate_headers(self) -> None:
327 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
328 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
329 assert request.headers.get("Authorization") == f"Bearer {api_key}"
330
331 with pytest.raises(OpenAIError):
332 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
333 _ = client2
334
335 def test_default_query_option(self) -> None:
336 client = OpenAI(
337 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
338 )
339 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
340 url = httpx.URL(request.url)
341 assert dict(url.params) == {"query_param": "bar"}
342
343 request = client._build_request(
344 FinalRequestOptions(
345 method="get",
346 url="/foo",
347 params={"foo": "baz", "query_param": "overriden"},
348 )
349 )
350 url = httpx.URL(request.url)
351 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
352
353 def test_request_extra_json(self) -> None:
354 request = self.client._build_request(
355 FinalRequestOptions(
356 method="post",
357 url="/foo",
358 json_data={"foo": "bar"},
359 extra_json={"baz": False},
360 ),
361 )
362 data = json.loads(request.content.decode("utf-8"))
363 assert data == {"foo": "bar", "baz": False}
364
365 request = self.client._build_request(
366 FinalRequestOptions(
367 method="post",
368 url="/foo",
369 extra_json={"baz": False},
370 ),
371 )
372 data = json.loads(request.content.decode("utf-8"))
373 assert data == {"baz": False}
374
375 # `extra_json` takes priority over `json_data` when keys clash
376 request = self.client._build_request(
377 FinalRequestOptions(
378 method="post",
379 url="/foo",
380 json_data={"foo": "bar", "baz": True},
381 extra_json={"baz": None},
382 ),
383 )
384 data = json.loads(request.content.decode("utf-8"))
385 assert data == {"foo": "bar", "baz": None}
386
387 def test_request_extra_headers(self) -> None:
388 request = self.client._build_request(
389 FinalRequestOptions(
390 method="post",
391 url="/foo",
392 **make_request_options(extra_headers={"X-Foo": "Foo"}),
393 ),
394 )
395 assert request.headers.get("X-Foo") == "Foo"
396
397 # `extra_headers` takes priority over `default_headers` when keys clash
398 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
399 FinalRequestOptions(
400 method="post",
401 url="/foo",
402 **make_request_options(
403 extra_headers={"X-Bar": "false"},
404 ),
405 ),
406 )
407 assert request.headers.get("X-Bar") == "false"
408
409 def test_request_extra_query(self) -> None:
410 request = self.client._build_request(
411 FinalRequestOptions(
412 method="post",
413 url="/foo",
414 **make_request_options(
415 extra_query={"my_query_param": "Foo"},
416 ),
417 ),
418 )
419 params = dict(request.url.params)
420 assert params == {"my_query_param": "Foo"}
421
422 # if both `query` and `extra_query` are given, they are merged
423 request = self.client._build_request(
424 FinalRequestOptions(
425 method="post",
426 url="/foo",
427 **make_request_options(
428 query={"bar": "1"},
429 extra_query={"foo": "2"},
430 ),
431 ),
432 )
433 params = dict(request.url.params)
434 assert params == {"bar": "1", "foo": "2"}
435
436 # `extra_query` takes priority over `query` when keys clash
437 request = self.client._build_request(
438 FinalRequestOptions(
439 method="post",
440 url="/foo",
441 **make_request_options(
442 query={"foo": "1"},
443 extra_query={"foo": "2"},
444 ),
445 ),
446 )
447 params = dict(request.url.params)
448 assert params == {"foo": "2"}
449
450 def test_multipart_repeating_array(self, client: OpenAI) -> None:
451 request = client._build_request(
452 FinalRequestOptions.construct(
453 method="get",
454 url="/foo",
455 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
456 json_data={"array": ["foo", "bar"]},
457 files=[("foo.txt", b"hello world")],
458 )
459 )
460
461 assert request.read().split(b"\r\n") == [
462 b"--6b7ba517decee4a450543ea6ae821c82",
463 b'Content-Disposition: form-data; name="array[]"',
464 b"",
465 b"foo",
466 b"--6b7ba517decee4a450543ea6ae821c82",
467 b'Content-Disposition: form-data; name="array[]"',
468 b"",
469 b"bar",
470 b"--6b7ba517decee4a450543ea6ae821c82",
471 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
472 b"Content-Type: application/octet-stream",
473 b"",
474 b"hello world",
475 b"--6b7ba517decee4a450543ea6ae821c82--",
476 b"",
477 ]
478
479 @pytest.mark.respx(base_url=base_url)
480 def test_basic_union_response(self, respx_mock: MockRouter) -> None:
481 class Model1(BaseModel):
482 name: str
483
484 class Model2(BaseModel):
485 foo: str
486
487 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
488
489 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
490 assert isinstance(response, Model2)
491 assert response.foo == "bar"
492
493 @pytest.mark.respx(base_url=base_url)
494 def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
495 """Union of objects with the same field name using a different type"""
496
497 class Model1(BaseModel):
498 foo: int
499
500 class Model2(BaseModel):
501 foo: str
502
503 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
504
505 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
506 assert isinstance(response, Model2)
507 assert response.foo == "bar"
508
509 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
510
511 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
512 assert isinstance(response, Model1)
513 assert response.foo == 1
514
515 @pytest.mark.respx(base_url=base_url)
516 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
517 """
518 Response that sets Content-Type to something other than application/json but returns json data
519 """
520
521 class Model(BaseModel):
522 foo: int
523
524 respx_mock.get("/foo").mock(
525 return_value=httpx.Response(
526 200,
527 content=json.dumps({"foo": 2}),
528 headers={"Content-Type": "application/text"},
529 )
530 )
531
532 response = self.client.get("/foo", cast_to=Model)
533 assert isinstance(response, Model)
534 assert response.foo == 2
535
536 def test_base_url_setter(self) -> None:
537 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
538 assert client.base_url == "https://example.com/from_init/"
539
540 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
541
542 assert client.base_url == "https://example.com/from_setter/"
543
544 def test_base_url_env(self) -> None:
545 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
546 client = OpenAI(api_key=api_key, _strict_response_validation=True)
547 assert client.base_url == "http://localhost:5000/from/env/"
548
549 @pytest.mark.parametrize(
550 "client",
551 [
552 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
553 OpenAI(
554 base_url="http://localhost:5000/custom/path/",
555 api_key=api_key,
556 _strict_response_validation=True,
557 http_client=httpx.Client(),
558 ),
559 ],
560 ids=["standard", "custom http client"],
561 )
562 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
563 request = client._build_request(
564 FinalRequestOptions(
565 method="post",
566 url="/foo",
567 json_data={"foo": "bar"},
568 ),
569 )
570 assert request.url == "http://localhost:5000/custom/path/foo"
571
572 @pytest.mark.parametrize(
573 "client",
574 [
575 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
576 OpenAI(
577 base_url="http://localhost:5000/custom/path/",
578 api_key=api_key,
579 _strict_response_validation=True,
580 http_client=httpx.Client(),
581 ),
582 ],
583 ids=["standard", "custom http client"],
584 )
585 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
586 request = client._build_request(
587 FinalRequestOptions(
588 method="post",
589 url="/foo",
590 json_data={"foo": "bar"},
591 ),
592 )
593 assert request.url == "http://localhost:5000/custom/path/foo"
594
595 @pytest.mark.parametrize(
596 "client",
597 [
598 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
599 OpenAI(
600 base_url="http://localhost:5000/custom/path/",
601 api_key=api_key,
602 _strict_response_validation=True,
603 http_client=httpx.Client(),
604 ),
605 ],
606 ids=["standard", "custom http client"],
607 )
608 def test_absolute_request_url(self, client: OpenAI) -> None:
609 request = client._build_request(
610 FinalRequestOptions(
611 method="post",
612 url="https://myapi.com/foo",
613 json_data={"foo": "bar"},
614 ),
615 )
616 assert request.url == "https://myapi.com/foo"
617
618 def test_copied_client_does_not_close_http(self) -> None:
619 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
620 assert not client.is_closed()
621
622 copied = client.copy()
623 assert copied is not client
624
625 del copied
626
627 assert not client.is_closed()
628
629 def test_client_context_manager(self) -> None:
630 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
631 with client as c2:
632 assert c2 is client
633 assert not c2.is_closed()
634 assert not client.is_closed()
635 assert client.is_closed()
636
637 @pytest.mark.respx(base_url=base_url)
638 def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
639 class Model(BaseModel):
640 foo: str
641
642 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
643
644 with pytest.raises(APIResponseValidationError) as exc:
645 self.client.get("/foo", cast_to=Model)
646
647 assert isinstance(exc.value.__cause__, ValidationError)
648
649 @pytest.mark.respx(base_url=base_url)
650 def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
651 class Model(BaseModel):
652 name: str
653
654 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
655
656 stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
657 assert isinstance(stream, Stream)
658 stream.response.close()
659
660 @pytest.mark.respx(base_url=base_url)
661 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
662 class Model(BaseModel):
663 name: str
664
665 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
666
667 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
668
669 with pytest.raises(APIResponseValidationError):
670 strict_client.get("/foo", cast_to=Model)
671
672 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
673
674 response = client.get("/foo", cast_to=Model)
675 assert isinstance(response, str) # type: ignore[unreachable]
676
677 @pytest.mark.parametrize(
678 "remaining_retries,retry_after,timeout",
679 [
680 [3, "20", 20],
681 [3, "0", 0.5],
682 [3, "-10", 0.5],
683 [3, "60", 60],
684 [3, "61", 0.5],
685 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
686 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
687 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
688 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
689 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
690 [3, "99999999999999999999999999999999999", 0.5],
691 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
692 [3, "", 0.5],
693 [2, "", 0.5 * 2.0],
694 [1, "", 0.5 * 4.0],
695 ],
696 )
697 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
698 def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
699 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
700
701 headers = httpx.Headers({"retry-after": retry_after})
702 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
703 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
704 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
705
706 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
707 @pytest.mark.respx(base_url=base_url)
708 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
709 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
710
711 with pytest.raises(APITimeoutError):
712 self.client.post(
713 "/chat/completions",
714 body=cast(
715 object,
716 dict(
717 messages=[
718 {
719 "role": "user",
720 "content": "Say this is a test",
721 }
722 ],
723 model="gpt-3.5-turbo",
724 ),
725 ),
726 cast_to=httpx.Response,
727 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
728 )
729
730 assert _get_open_connections(self.client) == 0
731
732 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
733 @pytest.mark.respx(base_url=base_url)
734 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
735 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
736
737 with pytest.raises(APIStatusError):
738 self.client.post(
739 "/chat/completions",
740 body=cast(
741 object,
742 dict(
743 messages=[
744 {
745 "role": "user",
746 "content": "Say this is a test",
747 }
748 ],
749 model="gpt-3.5-turbo",
750 ),
751 ),
752 cast_to=httpx.Response,
753 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
754 )
755
756 assert _get_open_connections(self.client) == 0
757
758
759class TestAsyncOpenAI:
760 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
761
762 @pytest.mark.respx(base_url=base_url)
763 @pytest.mark.asyncio
764 async def test_raw_response(self, respx_mock: MockRouter) -> None:
765 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
766
767 response = await self.client.post("/foo", cast_to=httpx.Response)
768 assert response.status_code == 200
769 assert isinstance(response, httpx.Response)
770 assert response.json() == {"foo": "bar"}
771
772 @pytest.mark.respx(base_url=base_url)
773 @pytest.mark.asyncio
774 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
775 respx_mock.post("/foo").mock(
776 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
777 )
778
779 response = await self.client.post("/foo", cast_to=httpx.Response)
780 assert response.status_code == 200
781 assert isinstance(response, httpx.Response)
782 assert response.json() == {"foo": "bar"}
783
784 def test_copy(self) -> None:
785 copied = self.client.copy()
786 assert id(copied) != id(self.client)
787
788 copied = self.client.copy(api_key="another My API Key")
789 assert copied.api_key == "another My API Key"
790 assert self.client.api_key == "My API Key"
791
792 def test_copy_default_options(self) -> None:
793 # options that have a default are overridden correctly
794 copied = self.client.copy(max_retries=7)
795 assert copied.max_retries == 7
796 assert self.client.max_retries == 2
797
798 copied2 = copied.copy(max_retries=6)
799 assert copied2.max_retries == 6
800 assert copied.max_retries == 7
801
802 # timeout
803 assert isinstance(self.client.timeout, httpx.Timeout)
804 copied = self.client.copy(timeout=None)
805 assert copied.timeout is None
806 assert isinstance(self.client.timeout, httpx.Timeout)
807
808 def test_copy_default_headers(self) -> None:
809 client = AsyncOpenAI(
810 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
811 )
812 assert client.default_headers["X-Foo"] == "bar"
813
814 # does not override the already given value when not specified
815 copied = client.copy()
816 assert copied.default_headers["X-Foo"] == "bar"
817
818 # merges already given headers
819 copied = client.copy(default_headers={"X-Bar": "stainless"})
820 assert copied.default_headers["X-Foo"] == "bar"
821 assert copied.default_headers["X-Bar"] == "stainless"
822
823 # uses new values for any already given headers
824 copied = client.copy(default_headers={"X-Foo": "stainless"})
825 assert copied.default_headers["X-Foo"] == "stainless"
826
827 # set_default_headers
828
829 # completely overrides already set values
830 copied = client.copy(set_default_headers={})
831 assert copied.default_headers.get("X-Foo") is None
832
833 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
834 assert copied.default_headers["X-Bar"] == "Robert"
835
836 with pytest.raises(
837 ValueError,
838 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
839 ):
840 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
841
842 def test_copy_default_query(self) -> None:
843 client = AsyncOpenAI(
844 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
845 )
846 assert _get_params(client)["foo"] == "bar"
847
848 # does not override the already given value when not specified
849 copied = client.copy()
850 assert _get_params(copied)["foo"] == "bar"
851
852 # merges already given params
853 copied = client.copy(default_query={"bar": "stainless"})
854 params = _get_params(copied)
855 assert params["foo"] == "bar"
856 assert params["bar"] == "stainless"
857
858 # uses new values for any already given headers
859 copied = client.copy(default_query={"foo": "stainless"})
860 assert _get_params(copied)["foo"] == "stainless"
861
862 # set_default_query
863
864 # completely overrides already set values
865 copied = client.copy(set_default_query={})
866 assert _get_params(copied) == {}
867
868 copied = client.copy(set_default_query={"bar": "Robert"})
869 assert _get_params(copied)["bar"] == "Robert"
870
871 with pytest.raises(
872 ValueError,
873 # TODO: update
874 match="`default_query` and `set_default_query` arguments are mutually exclusive",
875 ):
876 client.copy(set_default_query={}, default_query={"foo": "Bar"})
877
878 def test_copy_signature(self) -> None:
879 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
880 init_signature = inspect.signature(
881 # mypy doesn't like that we access the `__init__` property.
882 self.client.__init__, # type: ignore[misc]
883 )
884 copy_signature = inspect.signature(self.client.copy)
885 exclude_params = {"transport", "proxies", "_strict_response_validation"}
886
887 for name in init_signature.parameters.keys():
888 if name in exclude_params:
889 continue
890
891 copy_param = copy_signature.parameters.get(name)
892 assert copy_param is not None, f"copy() signature is missing the {name} param"
893
894 def test_copy_build_request(self) -> None:
895 options = FinalRequestOptions(method="get", url="/foo")
896
897 def build_request(options: FinalRequestOptions) -> None:
898 client = self.client.copy()
899 client._build_request(options)
900
901 # ensure that the machinery is warmed up before tracing starts.
902 build_request(options)
903 gc.collect()
904
905 tracemalloc.start(1000)
906
907 snapshot_before = tracemalloc.take_snapshot()
908
909 ITERATIONS = 10
910 for _ in range(ITERATIONS):
911 build_request(options)
912
913 gc.collect()
914 snapshot_after = tracemalloc.take_snapshot()
915
916 tracemalloc.stop()
917
918 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
919 if diff.count == 0:
920 # Avoid false positives by considering only leaks (i.e. allocations that persist).
921 return
922
923 if diff.count % ITERATIONS != 0:
924 # Avoid false positives by considering only leaks that appear per iteration.
925 return
926
927 for frame in diff.traceback:
928 if any(
929 frame.filename.endswith(fragment)
930 for fragment in [
931 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
932 #
933 # removing the decorator fixes the leak for reasons we don't understand.
934 "openai/_legacy_response.py",
935 "openai/_response.py",
936 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
937 "openai/_compat.py",
938 # Standard library leaks we don't care about.
939 "/logging/__init__.py",
940 ]
941 ):
942 return
943
944 leaks.append(diff)
945
946 leaks: list[tracemalloc.StatisticDiff] = []
947 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
948 add_leak(leaks, diff)
949 if leaks:
950 for leak in leaks:
951 print("MEMORY LEAK:", leak)
952 for frame in leak.traceback:
953 print(frame)
954 raise AssertionError()
955
956 async def test_request_timeout(self) -> None:
957 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
958 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
959 assert timeout == DEFAULT_TIMEOUT
960
961 request = self.client._build_request(
962 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
963 )
964 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
965 assert timeout == httpx.Timeout(100.0)
966
967 async def test_client_timeout_option(self) -> None:
968 client = AsyncOpenAI(
969 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
970 )
971
972 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
973 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
974 assert timeout == httpx.Timeout(0)
975
976 async def test_http_client_timeout_option(self) -> None:
977 # custom timeout given to the httpx client should be used
978 async with httpx.AsyncClient(timeout=None) as http_client:
979 client = AsyncOpenAI(
980 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
981 )
982
983 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
984 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
985 assert timeout == httpx.Timeout(None)
986
987 # no timeout given to the httpx client should not use the httpx default
988 async with httpx.AsyncClient() as http_client:
989 client = AsyncOpenAI(
990 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
991 )
992
993 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
994 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
995 assert timeout == DEFAULT_TIMEOUT
996
997 # explicitly passing the default timeout currently results in it being ignored
998 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
999 client = AsyncOpenAI(
1000 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1001 )
1002
1003 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1004 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1005 assert timeout == DEFAULT_TIMEOUT # our default
1006
1007 def test_invalid_http_client(self) -> None:
1008 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1009 with httpx.Client() as http_client:
1010 AsyncOpenAI(
1011 base_url=base_url,
1012 api_key=api_key,
1013 _strict_response_validation=True,
1014 http_client=cast(Any, http_client),
1015 )
1016
1017 def test_default_headers_option(self) -> None:
1018 client = AsyncOpenAI(
1019 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1020 )
1021 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1022 assert request.headers.get("x-foo") == "bar"
1023 assert request.headers.get("x-stainless-lang") == "python"
1024
1025 client2 = AsyncOpenAI(
1026 base_url=base_url,
1027 api_key=api_key,
1028 _strict_response_validation=True,
1029 default_headers={
1030 "X-Foo": "stainless",
1031 "X-Stainless-Lang": "my-overriding-header",
1032 },
1033 )
1034 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1035 assert request.headers.get("x-foo") == "stainless"
1036 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1037
1038 def test_validate_headers(self) -> None:
1039 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1040 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1041 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1042
1043 with pytest.raises(OpenAIError):
1044 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1045 _ = client2
1046
1047 def test_default_query_option(self) -> None:
1048 client = AsyncOpenAI(
1049 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1050 )
1051 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1052 url = httpx.URL(request.url)
1053 assert dict(url.params) == {"query_param": "bar"}
1054
1055 request = client._build_request(
1056 FinalRequestOptions(
1057 method="get",
1058 url="/foo",
1059 params={"foo": "baz", "query_param": "overriden"},
1060 )
1061 )
1062 url = httpx.URL(request.url)
1063 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
1064
1065 def test_request_extra_json(self) -> None:
1066 request = self.client._build_request(
1067 FinalRequestOptions(
1068 method="post",
1069 url="/foo",
1070 json_data={"foo": "bar"},
1071 extra_json={"baz": False},
1072 ),
1073 )
1074 data = json.loads(request.content.decode("utf-8"))
1075 assert data == {"foo": "bar", "baz": False}
1076
1077 request = self.client._build_request(
1078 FinalRequestOptions(
1079 method="post",
1080 url="/foo",
1081 extra_json={"baz": False},
1082 ),
1083 )
1084 data = json.loads(request.content.decode("utf-8"))
1085 assert data == {"baz": False}
1086
1087 # `extra_json` takes priority over `json_data` when keys clash
1088 request = self.client._build_request(
1089 FinalRequestOptions(
1090 method="post",
1091 url="/foo",
1092 json_data={"foo": "bar", "baz": True},
1093 extra_json={"baz": None},
1094 ),
1095 )
1096 data = json.loads(request.content.decode("utf-8"))
1097 assert data == {"foo": "bar", "baz": None}
1098
1099 def test_request_extra_headers(self) -> None:
1100 request = self.client._build_request(
1101 FinalRequestOptions(
1102 method="post",
1103 url="/foo",
1104 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1105 ),
1106 )
1107 assert request.headers.get("X-Foo") == "Foo"
1108
1109 # `extra_headers` takes priority over `default_headers` when keys clash
1110 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
1111 FinalRequestOptions(
1112 method="post",
1113 url="/foo",
1114 **make_request_options(
1115 extra_headers={"X-Bar": "false"},
1116 ),
1117 ),
1118 )
1119 assert request.headers.get("X-Bar") == "false"
1120
1121 def test_request_extra_query(self) -> None:
1122 request = self.client._build_request(
1123 FinalRequestOptions(
1124 method="post",
1125 url="/foo",
1126 **make_request_options(
1127 extra_query={"my_query_param": "Foo"},
1128 ),
1129 ),
1130 )
1131 params = dict(request.url.params)
1132 assert params == {"my_query_param": "Foo"}
1133
1134 # if both `query` and `extra_query` are given, they are merged
1135 request = self.client._build_request(
1136 FinalRequestOptions(
1137 method="post",
1138 url="/foo",
1139 **make_request_options(
1140 query={"bar": "1"},
1141 extra_query={"foo": "2"},
1142 ),
1143 ),
1144 )
1145 params = dict(request.url.params)
1146 assert params == {"bar": "1", "foo": "2"}
1147
1148 # `extra_query` takes priority over `query` when keys clash
1149 request = self.client._build_request(
1150 FinalRequestOptions(
1151 method="post",
1152 url="/foo",
1153 **make_request_options(
1154 query={"foo": "1"},
1155 extra_query={"foo": "2"},
1156 ),
1157 ),
1158 )
1159 params = dict(request.url.params)
1160 assert params == {"foo": "2"}
1161
1162 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1163 request = async_client._build_request(
1164 FinalRequestOptions.construct(
1165 method="get",
1166 url="/foo",
1167 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1168 json_data={"array": ["foo", "bar"]},
1169 files=[("foo.txt", b"hello world")],
1170 )
1171 )
1172
1173 assert request.read().split(b"\r\n") == [
1174 b"--6b7ba517decee4a450543ea6ae821c82",
1175 b'Content-Disposition: form-data; name="array[]"',
1176 b"",
1177 b"foo",
1178 b"--6b7ba517decee4a450543ea6ae821c82",
1179 b'Content-Disposition: form-data; name="array[]"',
1180 b"",
1181 b"bar",
1182 b"--6b7ba517decee4a450543ea6ae821c82",
1183 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
1184 b"Content-Type: application/octet-stream",
1185 b"",
1186 b"hello world",
1187 b"--6b7ba517decee4a450543ea6ae821c82--",
1188 b"",
1189 ]
1190
1191 @pytest.mark.respx(base_url=base_url)
1192 async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
1193 class Model1(BaseModel):
1194 name: str
1195
1196 class Model2(BaseModel):
1197 foo: str
1198
1199 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1200
1201 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1202 assert isinstance(response, Model2)
1203 assert response.foo == "bar"
1204
1205 @pytest.mark.respx(base_url=base_url)
1206 async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
1207 """Union of objects with the same field name using a different type"""
1208
1209 class Model1(BaseModel):
1210 foo: int
1211
1212 class Model2(BaseModel):
1213 foo: str
1214
1215 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1216
1217 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1218 assert isinstance(response, Model2)
1219 assert response.foo == "bar"
1220
1221 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1222
1223 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1224 assert isinstance(response, Model1)
1225 assert response.foo == 1
1226
1227 @pytest.mark.respx(base_url=base_url)
1228 async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
1229 """
1230 Response that sets Content-Type to something other than application/json but returns json data
1231 """
1232
1233 class Model(BaseModel):
1234 foo: int
1235
1236 respx_mock.get("/foo").mock(
1237 return_value=httpx.Response(
1238 200,
1239 content=json.dumps({"foo": 2}),
1240 headers={"Content-Type": "application/text"},
1241 )
1242 )
1243
1244 response = await self.client.get("/foo", cast_to=Model)
1245 assert isinstance(response, Model)
1246 assert response.foo == 2
1247
1248 def test_base_url_setter(self) -> None:
1249 client = AsyncOpenAI(
1250 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1251 )
1252 assert client.base_url == "https://example.com/from_init/"
1253
1254 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1255
1256 assert client.base_url == "https://example.com/from_setter/"
1257
1258 def test_base_url_env(self) -> None:
1259 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1260 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1261 assert client.base_url == "http://localhost:5000/from/env/"
1262
1263 @pytest.mark.parametrize(
1264 "client",
1265 [
1266 AsyncOpenAI(
1267 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1268 ),
1269 AsyncOpenAI(
1270 base_url="http://localhost:5000/custom/path/",
1271 api_key=api_key,
1272 _strict_response_validation=True,
1273 http_client=httpx.AsyncClient(),
1274 ),
1275 ],
1276 ids=["standard", "custom http client"],
1277 )
1278 def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1279 request = client._build_request(
1280 FinalRequestOptions(
1281 method="post",
1282 url="/foo",
1283 json_data={"foo": "bar"},
1284 ),
1285 )
1286 assert request.url == "http://localhost:5000/custom/path/foo"
1287
1288 @pytest.mark.parametrize(
1289 "client",
1290 [
1291 AsyncOpenAI(
1292 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1293 ),
1294 AsyncOpenAI(
1295 base_url="http://localhost:5000/custom/path/",
1296 api_key=api_key,
1297 _strict_response_validation=True,
1298 http_client=httpx.AsyncClient(),
1299 ),
1300 ],
1301 ids=["standard", "custom http client"],
1302 )
1303 def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1304 request = client._build_request(
1305 FinalRequestOptions(
1306 method="post",
1307 url="/foo",
1308 json_data={"foo": "bar"},
1309 ),
1310 )
1311 assert request.url == "http://localhost:5000/custom/path/foo"
1312
1313 @pytest.mark.parametrize(
1314 "client",
1315 [
1316 AsyncOpenAI(
1317 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1318 ),
1319 AsyncOpenAI(
1320 base_url="http://localhost:5000/custom/path/",
1321 api_key=api_key,
1322 _strict_response_validation=True,
1323 http_client=httpx.AsyncClient(),
1324 ),
1325 ],
1326 ids=["standard", "custom http client"],
1327 )
1328 def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1329 request = client._build_request(
1330 FinalRequestOptions(
1331 method="post",
1332 url="https://myapi.com/foo",
1333 json_data={"foo": "bar"},
1334 ),
1335 )
1336 assert request.url == "https://myapi.com/foo"
1337
1338 async def test_copied_client_does_not_close_http(self) -> None:
1339 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1340 assert not client.is_closed()
1341
1342 copied = client.copy()
1343 assert copied is not client
1344
1345 del copied
1346
1347 await asyncio.sleep(0.2)
1348 assert not client.is_closed()
1349
1350 async def test_client_context_manager(self) -> None:
1351 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1352 async with client as c2:
1353 assert c2 is client
1354 assert not c2.is_closed()
1355 assert not client.is_closed()
1356 assert client.is_closed()
1357
1358 @pytest.mark.respx(base_url=base_url)
1359 @pytest.mark.asyncio
1360 async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
1361 class Model(BaseModel):
1362 foo: str
1363
1364 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1365
1366 with pytest.raises(APIResponseValidationError) as exc:
1367 await self.client.get("/foo", cast_to=Model)
1368
1369 assert isinstance(exc.value.__cause__, ValidationError)
1370
1371 @pytest.mark.respx(base_url=base_url)
1372 @pytest.mark.asyncio
1373 async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
1374 class Model(BaseModel):
1375 name: str
1376
1377 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1378
1379 stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1380 assert isinstance(stream, AsyncStream)
1381 await stream.response.aclose()
1382
1383 @pytest.mark.respx(base_url=base_url)
1384 @pytest.mark.asyncio
1385 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1386 class Model(BaseModel):
1387 name: str
1388
1389 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1390
1391 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1392
1393 with pytest.raises(APIResponseValidationError):
1394 await strict_client.get("/foo", cast_to=Model)
1395
1396 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1397
1398 response = await client.get("/foo", cast_to=Model)
1399 assert isinstance(response, str) # type: ignore[unreachable]
1400
1401 @pytest.mark.parametrize(
1402 "remaining_retries,retry_after,timeout",
1403 [
1404 [3, "20", 20],
1405 [3, "0", 0.5],
1406 [3, "-10", 0.5],
1407 [3, "60", 60],
1408 [3, "61", 0.5],
1409 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1410 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1411 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1412 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1413 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1414 [3, "99999999999999999999999999999999999", 0.5],
1415 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1416 [3, "", 0.5],
1417 [2, "", 0.5 * 2.0],
1418 [1, "", 0.5 * 4.0],
1419 ],
1420 )
1421 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1422 @pytest.mark.asyncio
1423 async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
1424 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1425
1426 headers = httpx.Headers({"retry-after": retry_after})
1427 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1428 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
1429 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1430
1431 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1432 @pytest.mark.respx(base_url=base_url)
1433 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1434 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1435
1436 with pytest.raises(APITimeoutError):
1437 await self.client.post(
1438 "/chat/completions",
1439 body=cast(
1440 object,
1441 dict(
1442 messages=[
1443 {
1444 "role": "user",
1445 "content": "Say this is a test",
1446 }
1447 ],
1448 model="gpt-3.5-turbo",
1449 ),
1450 ),
1451 cast_to=httpx.Response,
1452 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1453 )
1454
1455 assert _get_open_connections(self.client) == 0
1456
1457 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1458 @pytest.mark.respx(base_url=base_url)
1459 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1460 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1461
1462 with pytest.raises(APIStatusError):
1463 await self.client.post(
1464 "/chat/completions",
1465 body=cast(
1466 object,
1467 dict(
1468 messages=[
1469 {
1470 "role": "user",
1471 "content": "Say this is a test",
1472 }
1473 ],
1474 model="gpt-3.5-turbo",
1475 ),
1476 ),
1477 cast_to=httpx.Response,
1478 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1479 )
1480
1481 assert _get_open_connections(self.client) == 0
1482