openai/openai-python

Public

mirrored fromhttps://github.com/openai/openai-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.25.1

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1490lines · modecode

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._models import BaseModel, FinalRequestOptions
21from openai._constants import RAW_RESPONSE_HEADER
22from openai._streaming import Stream, AsyncStream
23from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
24from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
25
26from .utils import update_env
27
# Address of the mock API server the tests talk to; overridable through the
# TEST_API_BASE_URL environment variable.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")

# Placeholder credential handed to every client constructed in this module.
api_key = "My API Key"
30
31
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters ``client`` puts on a plain ``GET /foo`` request.

    The request is only built via ``client._build_request``; nothing is sent.
    """
    probe_options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(probe_options)
    return dict(httpx.URL(built.url).params)
36
37
38def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
39 return 0.1
40
41
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many entries the client's transport pool holds in ``_requests``.

    Reaches into private attributes of the underlying httpx client, so the
    transport must be one of httpx's built-in HTTP transports.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
48
49
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    Most tests either build a request with ``_build_request`` and inspect it,
    or issue a call against a ``respx`` mock router and inspect the parsed
    response. Nothing here needs a live API.
    """

    # Shared client with strict response validation; tests that need other
    # options construct their own client or use ``.copy()`` / ``.with_options()``.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` hands back the raw httpx response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """Same as ``test_raw_response`` but with an ``application/binary`` content type."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a distinct object and leaves the source client untouched."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options overridden in ``copy()`` apply to the copy only."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merges into a copy; ``set_default_headers`` replaces."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy; ``set_default_query`` replaces."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every ``__init__`` parameter (bar a few exclusions) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request repeatedly must not leak memory.

        Allocations are compared between two ``tracemalloc`` snapshots taken
        around ``ITERATIONS`` copy-and-build cycles. Any surviving allocation
        whose count is a multiple of ``ITERATIONS`` and whose traceback does
        not touch an allow-listed file is reported as a leak.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Always switch tracing off, even if building a request raised, so
            # a failure here does not leave tracemalloc running for later tests.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected; see captured output for tracebacks")

    def test_request_timeout(self) -> None:
        """A per-request timeout overrides the default one in the built request."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A timeout passed to the client constructor is used for built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """Timeout resolution when a caller-supplied ``httpx.Client`` is used."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    async def test_invalid_http_client(self) -> None:
        """Passing an async httpx client to the sync ``OpenAI`` client raises ``TypeError``."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent and may override the ``X-Stainless-Lang`` header."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key becomes a Bearer ``Authorization`` header; no key at all is an error."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # The library documents OPENAI_API_KEY as the fallback when `api_key` is
        # None, so drop it for this check; otherwise a key present in the
        # developer's/CI environment would make the expected error disappear.
        # `patch.dict` restores the environment on exit.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)

            with pytest.raises(OpenAIError):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
                _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` is applied to requests and can be overridden per request."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and win over ``default_headers`` on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is merged with ``query`` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """List values in a multipart body are emitted as repeated ``name[]`` parts.

        ``client`` is a pytest fixture defined outside this file.
        """
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """``base_url`` gains a trailing slash whether set via ``__init__`` or the setter."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base URL ending in ``/`` joins with ``/foo`` without doubling the slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # The base URL deliberately has NO trailing slash here: that is the
            # case this test exists to cover (the client normalises it, as
            # `test_base_url_setter` demonstrates).
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL without a trailing ``/`` still joins correctly with ``/foo``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to ``base_url``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy must not close the original client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """The client is its own context manager and is closed on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that fails model validation raises ``APIResponseValidationError``."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        # The pydantic error is chained as the cause.
        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected with ``TypeError``."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with ``stream_cls=Stream[Model]`` yields a ``Stream`` instance."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """A text body where a model is expected: strict mode raises, lax mode returns ``str``."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``_calculate_retry_timeout`` for various ``retry-after`` header values.

        ``time.time`` is pinned to a fixed value so the date-form cases are
        deterministic.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After timeouts exhaust the retries, the pool's ``_requests`` must be empty."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After HTTP 500s exhaust the retries, the pool's ``_requests`` must be empty."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0
760
761
762class TestAsyncOpenAI:
763 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
764
765 @pytest.mark.respx(base_url=base_url)
766 @pytest.mark.asyncio
767 async def test_raw_response(self, respx_mock: MockRouter) -> None:
768 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
769
770 response = await self.client.post("/foo", cast_to=httpx.Response)
771 assert response.status_code == 200
772 assert isinstance(response, httpx.Response)
773 assert response.json() == {"foo": "bar"}
774
775 @pytest.mark.respx(base_url=base_url)
776 @pytest.mark.asyncio
777 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
778 respx_mock.post("/foo").mock(
779 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
780 )
781
782 response = await self.client.post("/foo", cast_to=httpx.Response)
783 assert response.status_code == 200
784 assert isinstance(response, httpx.Response)
785 assert response.json() == {"foo": "bar"}
786
787 def test_copy(self) -> None:
788 copied = self.client.copy()
789 assert id(copied) != id(self.client)
790
791 copied = self.client.copy(api_key="another My API Key")
792 assert copied.api_key == "another My API Key"
793 assert self.client.api_key == "My API Key"
794
795 def test_copy_default_options(self) -> None:
796 # options that have a default are overridden correctly
797 copied = self.client.copy(max_retries=7)
798 assert copied.max_retries == 7
799 assert self.client.max_retries == 2
800
801 copied2 = copied.copy(max_retries=6)
802 assert copied2.max_retries == 6
803 assert copied.max_retries == 7
804
805 # timeout
806 assert isinstance(self.client.timeout, httpx.Timeout)
807 copied = self.client.copy(timeout=None)
808 assert copied.timeout is None
809 assert isinstance(self.client.timeout, httpx.Timeout)
810
811 def test_copy_default_headers(self) -> None:
812 client = AsyncOpenAI(
813 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
814 )
815 assert client.default_headers["X-Foo"] == "bar"
816
817 # does not override the already given value when not specified
818 copied = client.copy()
819 assert copied.default_headers["X-Foo"] == "bar"
820
821 # merges already given headers
822 copied = client.copy(default_headers={"X-Bar": "stainless"})
823 assert copied.default_headers["X-Foo"] == "bar"
824 assert copied.default_headers["X-Bar"] == "stainless"
825
826 # uses new values for any already given headers
827 copied = client.copy(default_headers={"X-Foo": "stainless"})
828 assert copied.default_headers["X-Foo"] == "stainless"
829
830 # set_default_headers
831
832 # completely overrides already set values
833 copied = client.copy(set_default_headers={})
834 assert copied.default_headers.get("X-Foo") is None
835
836 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
837 assert copied.default_headers["X-Bar"] == "Robert"
838
839 with pytest.raises(
840 ValueError,
841 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
842 ):
843 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
844
845 def test_copy_default_query(self) -> None:
846 client = AsyncOpenAI(
847 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
848 )
849 assert _get_params(client)["foo"] == "bar"
850
851 # does not override the already given value when not specified
852 copied = client.copy()
853 assert _get_params(copied)["foo"] == "bar"
854
855 # merges already given params
856 copied = client.copy(default_query={"bar": "stainless"})
857 params = _get_params(copied)
858 assert params["foo"] == "bar"
859 assert params["bar"] == "stainless"
860
861 # uses new values for any already given headers
862 copied = client.copy(default_query={"foo": "stainless"})
863 assert _get_params(copied)["foo"] == "stainless"
864
865 # set_default_query
866
867 # completely overrides already set values
868 copied = client.copy(set_default_query={})
869 assert _get_params(copied) == {}
870
871 copied = client.copy(set_default_query={"bar": "Robert"})
872 assert _get_params(copied)["bar"] == "Robert"
873
874 with pytest.raises(
875 ValueError,
876 # TODO: update
877 match="`default_query` and `set_default_query` arguments are mutually exclusive",
878 ):
879 client.copy(set_default_query={}, default_query={"foo": "Bar"})
880
881 def test_copy_signature(self) -> None:
882 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
883 init_signature = inspect.signature(
884 # mypy doesn't like that we access the `__init__` property.
885 self.client.__init__, # type: ignore[misc]
886 )
887 copy_signature = inspect.signature(self.client.copy)
888 exclude_params = {"transport", "proxies", "_strict_response_validation"}
889
890 for name in init_signature.parameters.keys():
891 if name in exclude_params:
892 continue
893
894 copy_param = copy_signature.parameters.get(name)
895 assert copy_param is not None, f"copy() signature is missing the {name} param"
896
def test_copy_build_request(self) -> None:
    """Guard against per-call memory leaks in `copy()` followed by `_build_request()`.

    The pair is executed `ITERATIONS` times between two `tracemalloc` snapshots.
    The test fails if any allocation site outside the ignore list reports a
    non-zero block count that is a multiple of `ITERATIONS`.

    Raises:
        AssertionError: if at least one suspected leak is found; details of every
            leak are printed before raising.
    """
    ITERATIONS = 10

    # Files whose allocations are deliberately ignored when looking for leaks.
    ignored_suffixes = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        # Always stop tracing: if anything above raises, tracemalloc would otherwise
        # stay enabled for every test that runs afterwards.
        tracemalloc.stop()

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        # `str.endswith` accepts a tuple, so one call per frame covers every ignored suffix.
        if any(frame.filename.endswith(ignored_suffixes) for frame in diff.traceback):
            return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)

    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} suspected memory leak(s) detected; see printed tracebacks")
958
async def test_request_timeout(self) -> None:
    """Built requests carry `DEFAULT_TIMEOUT` unless the request options supply their own timeout."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    default_timeout = httpx.Timeout(**default_request.extensions["timeout"])  # type: ignore
    assert default_timeout == DEFAULT_TIMEOUT

    # a timeout given in the request options replaces the default
    override_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    override_request = self.client._build_request(override_options)
    override_timeout = httpx.Timeout(**override_request.extensions["timeout"])  # type: ignore
    assert override_timeout == httpx.Timeout(100.0)
969
async def test_client_timeout_option(self) -> None:
    """A `timeout` passed to the client constructor is applied to the requests it builds."""
    client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert applied == httpx.Timeout(0)
978
async def test_http_client_timeout_option(self) -> None:
    """Check which timeout ends up on a request when a custom httpx client is supplied."""

    def effective_timeout(http_client: httpx.AsyncClient) -> httpx.Timeout:
        # wrap `http_client` in an API client, build a request and read back its timeout
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        assert effective_timeout(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        assert effective_timeout(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert effective_timeout(http_client) == DEFAULT_TIMEOUT  # our default
1009
def test_invalid_http_client(self) -> None:
    """Handing a synchronous `httpx.Client` to the async client raises a `TypeError`."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http_client:
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            # cast hides the deliberate type mismatch from the type checker
            http_client=cast(Any, sync_http_client),
        )
1019
def test_default_headers_option(self) -> None:
    """`default_headers` appear on built requests and can replace the `x-stainless-lang` header."""
    with_extra_header = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    headers = with_extra_header._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert headers.get("x-foo") == "bar"
    assert headers.get("x-stainless-lang") == "python"

    # a default header with the same name as a built-in one takes its place
    overriding_headers = {
        "X-Foo": "stainless",
        "X-Stainless-Lang": "my-overriding-header",
    }
    with_override = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers=overriding_headers,
    )
    headers = with_override._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert headers.get("x-foo") == "stainless"
    assert headers.get("x-stainless-lang") == "my-overriding-header"
1040
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and constructing a client without one fails.

    Raises (inside the test): `OpenAIError` is expected when `api_key=None` is passed
    and no key can be found elsewhere.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # The client falls back to the `OPENAI_API_KEY` environment variable when no key is
    # passed, so an exported key would make the constructor succeed and this check fail.
    # `patch.dict` snapshots the environment and restores it (including the removed key)
    # when the block exits.
    with mock.patch.dict(os.environ):
        os.environ.pop("OPENAI_API_KEY", None)

        with pytest.raises(OpenAIError):
            client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2
1049
def test_default_query_option(self) -> None:
    """`default_query` params are added to request URLs; per-request params with the same key win."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    defaults_only = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(defaults_only.url).params) == {"query_param": "bar"}

    # request-level params are merged in and override the default of the same name
    request_options = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overriden"},
    )
    with_request_params = client._build_request(request_options)
    assert dict(httpx.URL(with_request_params.url).params) == {"foo": "baz", "query_param": "overriden"}
1067
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON request body and wins over `json_data` on key clashes."""

    def sent_body(**option_kwargs: Any) -> Any:
        # build a POST to `/foo` with the given options and decode the JSON body it would send
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **option_kwargs,
            ),
        )
        return json.loads(request.content.decode("utf-8"))

    merged = sent_body(json_data={"foo": "bar"}, extra_json={"baz": False})
    assert merged == {"foo": "bar", "baz": False}

    extra_only = sent_body(extra_json={"baz": False})
    assert extra_only == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    clashing = sent_body(json_data={"foo": "bar", "baz": True}, extra_json={"baz": None})
    assert clashing == {"foo": "bar", "baz": None}
1101
def test_request_extra_headers(self) -> None:
    """`extra_headers` reach the built request and beat `default_headers` of the same name."""
    plain_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    assert self.client._build_request(plain_options).headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_default = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            extra_headers={"X-Bar": "false"},
        ),
    )
    assert client_with_default._build_request(clashing_options).headers.get("X-Bar") == "false"
1123
def test_request_extra_query(self) -> None:
    """`extra_query` is sent as query params, merged with `query`, and wins when keys clash."""

    def sent_params(**request_kwargs: Any) -> dict[str, str]:
        # build a POST to `/foo` from the given request options and return its query params
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(**request_kwargs),
            ),
        )
        return dict(request.url.params)

    assert sent_params(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert sent_params(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert sent_params(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1164
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """Array values in a multipart body are written as one `array[]` form part per element."""
    options = FinalRequestOptions.construct(
        method="get",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    body_lines = async_client._build_request(options).read().split(b"\r\n")

    # one part per array element, then the file part, then the closing boundary
    expected_lines = [
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"foo",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"bar",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]
    assert body_lines == expected_lines
1193
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A `{"foo": ...}` payload cast to `Union[Model1, Model2]` is returned as `Model2`."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    payload = {"foo": "bar"}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=payload))

    union_type = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=union_type)

    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1207
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union members share the field name `foo` but type it differently (int vs. str)."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])

    # a string value is returned as the `str`-typed member
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"

    # an integer value is returned as the `int`-typed member
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model1)
    assert parsed.foo == 1
1229
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    A JSON body served with a Content-Type other than application/json is still parsed into the model
    """

    class Model(BaseModel):
        foo: int

    mocked_response = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mocked_response)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1250
def test_base_url_setter(self) -> None:
    """`base_url` ends with a trailing slash whether it was given to `__init__` or assigned later."""
    client = AsyncOpenAI(
        base_url="https://example.com/from_init",
        api_key=api_key,
        _strict_response_validation=True,
    )
    assert client.base_url == "https://example.com/from_init/"

    # assigning a plain string goes through the property setter
    client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert client.base_url == "https://example.com/from_setter/"
1260
def test_base_url_env(self) -> None:
    """With no explicit `base_url`, the client uses `OPENAI_BASE_URL` from the environment."""
    env_overrides = {"OPENAI_BASE_URL": "http://localhost:5000/from/env"}
    with update_env(**env_overrides):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == "http://localhost:5000/from/env/"
1265
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL ending in `/` joined with a `/foo` request path yields a single slash between them."""
    options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1290
@pytest.mark.parametrize(
    "client",
    [
        # NOTE: the base URLs deliberately have NO trailing slash; that is the case under test.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL without a trailing `/` is still joined correctly with a `/foo` request path.

    The client adds the trailing slash to `base_url` itself, so the built URL must be
    the same as when the slash is given explicitly.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1315
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined onto the client's base URL."""
    absolute_options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(absolute_options)
    assert built.url == "https://myapi.com/foo"
1340
async def test_copied_client_does_not_close_http(self) -> None:
    """Deleting a copy of a client leaves the original client open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    del duplicate

    # wait briefly before re-checking (presumably so any asynchronous cleanup can run)
    await asyncio.sleep(0.2)
    assert not original.is_closed()
1352
async def test_client_context_manager(self) -> None:
    """`async with client` yields the client itself, keeps it open inside and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()

    # leaving the block must have closed the client
    assert client.is_closed()
1360
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit the model raises `APIResponseValidationError` caused by a `ValidationError`."""

    class Model(BaseModel):
        foo: str

    # `foo` is declared as `str` but the response carries an object
    invalid_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=invalid_payload))

    with pytest.raises(APIResponseValidationError) as exc_info:
        await self.client.get("/foo", cast_to=Model)

    assert isinstance(exc_info.value.__cause__, ValidationError)
1373
async def test_client_max_retries_validation(self) -> None:
    """Passing `max_retries=None` to the constructor is rejected with a `TypeError`."""
    # cast hides the deliberately invalid value from the type checker
    invalid_retries = cast(Any, None)

    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            max_retries=invalid_retries,
        )
1379
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """A streaming POST made with `stream_cls=AsyncStream[Model]` returns an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    mocked_response = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=mocked_response)

    result = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(result, AsyncStream)

    # close the underlying response explicitly once the type has been checked
    await result.response.aclose()
1391
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """Plain-text body where a model is expected: strict validation raises, non-strict returns the text."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # strict response validation refuses the non-JSON body
    strict = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict.get("/foo", cast_to=Model)

    # without strict validation the raw text is handed back instead
    lenient = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    result = await lenient.get("/foo", cast_to=Model)
    assert isinstance(result, str)  # type: ignore[unreachable]
1409
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay computed from a `retry-after` header for each parametrized case.

    `time.time` is patched to a fixed value so the HTTP-date cases are deterministic.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    response_headers = httpx.Headers({"retry-after": retry_after})
    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)

    delay = client._calculate_retry_timeout(remaining_retries, request_options, response_headers)

    # compared with a relative tolerance rather than exactly
    assert delay == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1439
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """A request that keeps timing out ends in `APITimeoutError` and leaves no open connections."""
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    request_body = cast(
        object,
        {
            "messages": [
                {
                    "role": "user",
                    "content": "Say this is a test",
                }
            ],
            "model": "gpt-3.5-turbo",
        },
    )

    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=request_body,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    open_connections = _get_open_connections(self.client)
    assert open_connections == 0
1465
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """A request that keeps receiving HTTP 500 ends in `APIStatusError` and leaves no open connections."""
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    request_body = cast(
        object,
        {
            "messages": [
                {
                    "role": "user",
                    "content": "Say this is a test",
                }
            ],
            "model": "gpt-3.5-turbo",
        },
    )

    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=request_body,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    open_connections = _get_open_connections(self.client)
    assert open_connections == 0
1491