openai/openai-python

Public

mirrored from https://github.com/openai/openai-python (Available)

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.20.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1491 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
# Server that every client in this module talks to (normally a local mock server);
# set TEST_API_BASE_URL to point the suite somewhere else.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Dummy credential passed to every client constructed in this module.
api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query parameters as a dict."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many requests *client*'s httpx connection pool is currently tracking.

    This reaches into private httpx internals (``_transport._pool._requests``). Only the default
    ``httpx.HTTPTransport`` / ``httpx.AsyncHTTPTransport`` are accepted.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
49
50
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    ``client`` is a single class-level instance (pointed at ``base_url``, strict response
    validation enabled). Every test that does not construct its own client shares it.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request in a loop must not leak memory.

        An allocation is reported only if it is still alive in the final snapshot and its block count
        is a multiple of the iteration count. Frames from the allow-listed files below are ignored.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        # Stop tracing even if building a request raises, so tracemalloc does not stay enabled for
        # the rest of the test session.
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            # `diff.count` is the number of blocks still allocated in the later snapshot.
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks")

    def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    async def test_invalid_http_client(self) -> None:
        # the sync client must reject an async httpx client
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The api key is sent as a bearer token, and a client without any key cannot be constructed."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # With `api_key=None` the client falls back to the OPENAI_API_KEY environment variable, so the
        # variable is removed for the duration of this check. Otherwise the test fails whenever a real
        # key is exported. `patch.dict` restores the original environment on exit.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)
            with pytest.raises(OpenAIError):
                OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        # per-request params are merged with, and take priority over, the default query
        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        # list values in a multipart body are sent as repeated `name[]` parts, followed by the files
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # no trailing slash here: the client is expected to add one (see test_base_url_setter),
            # so relative URLs still resolve under /custom/path/
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        # an absolute request URL is used as-is instead of being joined onto base_url
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        # dropping the copy must leave the original client open
        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        # without strict validation the raw text is returned instead of raising
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # time.time() is pinned to 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT, so the HTTP-date cases
    # below are relative to that instant. Per the table, retry-after delays in (0, 60] seconds are
    # honored; anything else falls back to 0.5s, doubled for each retry already used (max_retries=3).
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # NOTE(review): the loose relative tolerance presumably absorbs jitter in the backoff — confirm
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        # every retried attempt must have released its pooled connection
        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        # every retried attempt must have released its pooled connection
        assert _get_open_connections(self.client) == 0
761
762
763class TestAsyncOpenAI:
764 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
765
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response(self, respx_mock: MockRouter) -> None:
    """A POST with ``cast_to=httpx.Response`` hands back the raw httpx response."""
    payload = {"foo": "bar"}
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json=payload))

    raw = await self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == payload
775
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
    """A raw response served as ``application/binary`` can still have its body read as JSON."""
    binary_response = httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
    respx_mock.post("/foo").mock(return_value=binary_response)

    raw = await self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
787
def test_copy(self) -> None:
    """``copy()`` yields a new client; overriding ``api_key`` on it leaves the original untouched."""
    duplicate = self.client.copy()
    assert duplicate is not self.client

    rekeyed = self.client.copy(api_key="another My API Key")
    assert rekeyed.api_key == "another My API Key"
    assert self.client.api_key == "My API Key"
795
def test_copy_default_options(self) -> None:
    """Options with defaults (``max_retries``, ``timeout``) can be overridden per copy."""
    # max_retries: each copy carries its own value; the source keeps its default of 2
    retries_seven = self.client.copy(max_retries=7)
    assert retries_seven.max_retries == 7
    assert self.client.max_retries == 2

    retries_six = retries_seven.copy(max_retries=6)
    assert (retries_six.max_retries, retries_seven.max_retries) == (6, 7)

    # timeout: None on the copy stays None there, while the source keeps its httpx.Timeout
    assert isinstance(self.client.timeout, httpx.Timeout)
    untimed = self.client.copy(timeout=None)
    assert untimed.timeout is None
    assert isinstance(self.client.timeout, httpx.Timeout)
811
def test_copy_default_headers(self) -> None:
    """``copy()`` merges ``default_headers``, while ``set_default_headers`` replaces them outright."""
    source = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    assert source.default_headers["X-Foo"] == "bar"

    # a plain copy keeps the headers the source was built with
    plain = source.copy()
    assert plain.default_headers["X-Foo"] == "bar"

    # extra headers are layered on top of the existing ones
    merged = source.copy(default_headers={"X-Bar": "stainless"})
    assert merged.default_headers["X-Foo"] == "bar"
    assert merged.default_headers["X-Bar"] == "stainless"

    # a header given to copy() wins over an existing header of the same name
    overridden = source.copy(default_headers={"X-Foo": "stainless"})
    assert overridden.default_headers["X-Foo"] == "stainless"

    # set_default_headers discards whatever was there before
    emptied = source.copy(set_default_headers={})
    assert emptied.default_headers.get("X-Foo") is None

    replaced = source.copy(set_default_headers={"X-Bar": "Robert"})
    assert replaced.default_headers["X-Bar"] == "Robert"

    # the two styles cannot be combined in one call
    with pytest.raises(
        ValueError,
        match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
    ):
        source.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
845
def test_copy_default_query(self) -> None:
    """``copy()`` merges ``default_query``, while ``set_default_query`` replaces it outright."""
    source = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
    )
    assert _get_params(source)["foo"] == "bar"

    # a plain copy keeps the existing query params
    plain = source.copy()
    assert _get_params(plain)["foo"] == "bar"

    # new params are merged with the existing ones
    merged = _get_params(source.copy(default_query={"bar": "stainless"}))
    assert merged["foo"] == "bar"
    assert merged["bar"] == "stainless"

    # a param given to copy() wins over an existing param with the same key
    overridden = source.copy(default_query={"foo": "stainless"})
    assert _get_params(overridden)["foo"] == "stainless"

    # set_default_query discards whatever was there before
    emptied = source.copy(set_default_query={})
    assert _get_params(emptied) == {}

    replaced = source.copy(set_default_query={"bar": "Robert"})
    assert _get_params(replaced)["bar"] == "Robert"

    # the two styles cannot be combined in one call
    with pytest.raises(
        ValueError,
        # TODO: update
        match="`default_query` and `set_default_query` arguments are mutually exclusive",
    ):
        source.copy(set_default_query={}, default_query={"foo": "Bar"})
881
def test_copy_signature(self) -> None:
    """Every ``__init__`` parameter, apart from a few excluded ones, must also be accepted by ``copy()``."""
    init_params = inspect.signature(
        # mypy doesn't like that we access the `__init__` property.
        self.client.__init__,  # type: ignore[misc]
    ).parameters
    copy_params = inspect.signature(self.client.copy).parameters
    skipped = {"transport", "proxies", "_strict_response_validation"}

    for name in (param for param in init_params if param not in skipped):
        assert copy_params.get(name) is not None, f"copy() signature is missing the {name} param"
897
def test_copy_build_request(self) -> None:
    """Copying the async client and building a request in a loop must not leak memory.

    An allocation is reported only if it is still alive in the final snapshot and its block count is
    a multiple of the iteration count. Frames from the allow-listed files below are ignored.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    ITERATIONS = 10

    tracemalloc.start(1000)
    # Stop tracing even if building a request raises, so tracemalloc does not stay enabled for the
    # rest of the test session.
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        tracemalloc.stop()

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        # `diff.count` is the number of blocks still allocated in the later snapshot.
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        for frame in diff.traceback:
            if any(
                frame.filename.endswith(fragment)
                for fragment in [
                    # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                    #
                    # removing the decorator fixes the leak for reasons we don't understand.
                    "openai/_legacy_response.py",
                    "openai/_response.py",
                    # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                    "openai/_compat.py",
                    # Standard library leaks we don't care about.
                    "/logging/__init__.py",
                ]
            ):
                return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)
    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks")
959
async def test_request_timeout(self) -> None:
    """Requests use DEFAULT_TIMEOUT unless the request options supply their own timeout."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**default_request.extensions["timeout"]) == DEFAULT_TIMEOUT  # type: ignore

    override = httpx.Timeout(100.0)
    overridden_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=override))
    assert httpx.Timeout(**overridden_request.extensions["timeout"]) == httpx.Timeout(100.0)  # type: ignore
970
async def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor is attached to the requests it builds."""
    zero_timeout = httpx.Timeout(0)
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=zero_timeout)

    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**built.extensions["timeout"]) == httpx.Timeout(0)  # type: ignore
979
async def test_http_client_timeout_option(self) -> None:
    """Check which timeout ends up on requests when a custom `httpx.AsyncClient` is supplied."""

    def resolved_timeout(http_client: httpx.AsyncClient) -> httpx.Timeout:
        # Wrap `http_client` in an AsyncOpenAI, build one request and read back its timeout.
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        assert resolved_timeout(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        assert resolved_timeout(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert resolved_timeout(http_client) == DEFAULT_TIMEOUT  # our default
1010
def test_invalid_http_client(self) -> None:
    """Passing a synchronous `httpx.Client` to the async client raises `TypeError`."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http_client:
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            http_client=cast(Any, sync_http_client),
        )
1020
def test_default_headers_option(self) -> None:
    """`default_headers` are sent on every request and can override the SDK's own headers."""
    simple = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    sent = simple._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    # A user-supplied `X-Stainless-Lang` replaces the value the SDK would send.
    overriding = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "stainless", "X-Stainless-Lang": "my-overriding-header"},
    )
    sent = overriding._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"
1041
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and a client without any key is rejected.

    Raises:
        AssertionError: if the Authorization header is wrong or no `OpenAIError` is raised.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # With `api_key=None` the client falls back to the `OPENAI_API_KEY` environment
    # variable, so hide it for this check; otherwise the outcome depends on the
    # developer's shell. `patch.dict` restores os.environ on exit.
    with mock.patch.dict(os.environ):
        os.environ.pop("OPENAI_API_KEY", None)
        with pytest.raises(OpenAIError):
            AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1050
def test_default_query_option(self) -> None:
    """`default_query` params are sent on requests; per-request params with the same key replace them."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    def query_of(options: FinalRequestOptions) -> dict[str, str]:
        # Build the request and return its query string as a plain dict.
        return dict(httpx.URL(client._build_request(options).url).params)

    assert query_of(FinalRequestOptions(method="get", url="/foo")) == {"query_param": "bar"}

    overridden = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overriden"},
    )
    assert query_of(overridden) == {"foo": "baz", "query_param": "overriden"}
1068
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins when keys clash with `json_data`."""

    def sent_body(**option_kwargs: Any) -> Any:
        # Build a POST /foo request and decode the JSON body that would be sent.
        request = self.client._build_request(FinalRequestOptions(method="post", url="/foo", **option_kwargs))
        return json.loads(request.content.decode("utf-8"))

    assert sent_body(json_data={"foo": "bar"}, extra_json={"baz": False}) == {"foo": "bar", "baz": False}

    # `extra_json` on its own becomes the whole body
    assert sent_body(extra_json={"baz": False}) == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    assert sent_body(json_data={"foo": "bar", "baz": True}, extra_json={"baz": None}) == {"foo": "bar", "baz": None}
1102
def test_request_extra_headers(self) -> None:
    """Per-request `extra_headers` are sent and take precedence over client `default_headers`."""
    plain = self.client._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Foo": "Foo"}))
    )
    assert plain.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    with_defaults = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing = with_defaults._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Bar": "false"}))
    )
    assert clashing.headers.get("X-Bar") == "false"
1124
def test_request_extra_query(self) -> None:
    """`extra_query` is sent, merged with `query`, and wins on key clashes."""

    def sent_params(**request_options: Any) -> dict[str, str]:
        # Build a POST /foo request from `make_request_options(...)` and return its query params.
        request = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **make_request_options(**request_options))
        )
        return dict(request.url.params)

    assert sent_params(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert sent_params(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert sent_params(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1165
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """A list value in a multipart body is encoded as one `array[]` part per element."""
    boundary = "6b7ba517decee4a450543ea6ae821c82"
    request = async_client._build_request(
        FinalRequestOptions.construct(
            method="get",
            url="/foo",
            headers={"Content-Type": f"multipart/form-data; boundary={boundary}"},
            json_data={"array": ["foo", "bar"]},
            files=[("foo.txt", b"hello world")],
        )
    )

    delimiter = b"--" + boundary.encode()
    array_part = b'Content-Disposition: form-data; name="array[]"'
    expected_lines = [
        delimiter,
        array_part,
        b"",
        b"foo",
        delimiter,
        array_part,
        b"",
        b"bar",
        delimiter,
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        delimiter + b"--",
        b"",
    ]
    assert request.read().split(b"\r\n") == expected_lines
1194
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """Casting to a union yields the member whose fields match the response payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    union_target = cast(Any, Union[Model1, Model2])
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    result = await self.client.get("/foo", cast_to=union_target)
    assert isinstance(result, Model2)
    assert result.foo == "bar"
1208
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Members sharing a field name but differing in its type are picked by the payload's value type."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    cases = [
        ({"foo": "bar"}, Model2, "bar"),
        ({"foo": 1}, Model1, 1),
    ]
    for payload, expected_model, expected_foo in cases:
        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=payload))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, expected_model)
        assert response.foo == expected_foo
1230
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """A JSON body is still parsed when the response's Content-Type is not `application/json`."""

    class Model(BaseModel):
        foo: int

    mocked = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mocked)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1251
def test_base_url_setter(self) -> None:
    """`base_url` can be reassigned after construction and is read back with a trailing `/`."""
    client = AsyncOpenAI(
        base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
    )
    assert client.base_url == "https://example.com/from_init/"

    replacement = "https://example.com/from_setter"
    client.base_url = replacement  # type: ignore[assignment]
    assert client.base_url == "https://example.com/from_setter/"
1261
def test_base_url_env(self) -> None:
    """Without a `base_url` argument the client takes it from `OPENAI_BASE_URL`."""
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == "http://localhost:5000/from/env/"
1266
@pytest.mark.parametrize(
    argnames="client",
    argvalues=[
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A relative path is appended to a base URL that already ends with `/`."""
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1291
@pytest.mark.parametrize(
    "client",
    [
        # Deliberately WITHOUT a trailing slash: this is the case under test.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL configured without a trailing `/` still joins cleanly with a relative path.

    The request path must be appended under ``/custom/path/`` rather than
    replacing the last path segment.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1316
@pytest.mark.parametrize(
    argnames="client",
    argvalues=[
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is, ignoring the client's base URL."""
    absolute = "https://myapi.com/foo"
    built = client._build_request(FinalRequestOptions(method="post", url=absolute, json_data={"foo": "bar"}))
    assert built.url == "https://myapi.com/foo"
1341
async def test_copied_client_does_not_close_http(self) -> None:
    """Deleting a `.copy()` of a client must leave the original client open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    del duplicate

    # Presumably gives any cleanup triggered by dropping the copy a chance to run
    # on the event loop before re-checking.
    await asyncio.sleep(0.2)
    assert not original.is_closed()
1353
async def test_client_context_manager(self) -> None:
    """`async with` yields the client itself, keeps it open inside, and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()
    assert client.is_closed()
1361
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit the model raises `APIResponseValidationError` caused by pydantic."""

    class Model(BaseModel):
        foo: str

    bad_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=bad_payload))

    with pytest.raises(APIResponseValidationError) as excinfo:
        await self.client.get("/foo", cast_to=Model)

    assert isinstance(excinfo.value.__cause__, ValidationError)
1374
async def test_client_max_retries_validation(self) -> None:
    """Constructing a client with `max_retries=None` raises `TypeError`."""
    invalid_retries = cast(Any, None)
    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            max_retries=invalid_retries,
        )
1380
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """`stream=True` with `stream_cls=AsyncStream[...]` returns an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    streamed = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(streamed, AsyncStream)

    # release the underlying HTTP response explicitly
    await streamed.response.aclose()
1392
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """Plain text fails strict validation, but a non-strict client returns it as a `str`."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    response = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(response, str)  # type: ignore[unreachable]
1410
@pytest.mark.parametrize(
    argnames="remaining_retries,retry_after,timeout",
    argvalues=[
        (3, "20", 20),
        (3, "0", 0.5),
        (3, "-10", 0.5),
        (3, "60", 60),
        (3, "61", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:57 GMT", 20),
        (3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:27:37 GMT", 60),
        (3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5),
        (3, "99999999999999999999999999999999999", 0.5),
        (3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "", 0.5),
        (2, "", 0.5 * 2.0),
        (1, "", 0.5 * 4.0),
    ],
)
# 1696004797 is Fri, 29 Sep 2023 16:26:37 GMT, the reference point for the HTTP-date rows.
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check how a `retry-after` header turns into the computed retry delay.

    In the table, delays given in seconds or as an HTTP date are honoured when they
    are above 0 and at most 60 seconds. Zero, negative, too-large, unparsable or
    empty values fall back to a delay of 0.5s that doubles as `remaining_retries`
    drops (with `max_retries=3`).
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    calculated = client._calculate_retry_timeout(
        remaining_retries,
        FinalRequestOptions(method="get", url="/foo", max_retries=3),
        httpx.Headers({"retry-after": retry_after}),
    )
    # Relative tolerance of 0.4375 — presumably to absorb random jitter in the fallback delay.
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1440
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After every retry of a timing-out request fails, no pooled connection may remain open."""
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    payload = {
        "messages": [{"role": "user", "content": "Say this is a test"}],
        "model": "gpt-3.5-turbo",
    }
    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=cast(object, payload),
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1466
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After every retry of a 500-returning request fails, no pooled connection may remain open."""
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    payload = {
        "messages": [{"role": "user", "content": "Say this is a test"}],
        "model": "gpt-3.5-turbo",
    }
    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=cast(object, payload),
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1492