openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
98d8b2ac13ce7b22fd9ce237157bc1be593c7362

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1612 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._models import BaseModel, FinalRequestOptions
21from openai._constants import RAW_RESPONSE_HEADER
22from openai._streaming import Stream, AsyncStream
23from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
24from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
25
26from .utils import update_env
27
# Base URL every test client points at; overridable via TEST_API_BASE_URL (defaults to a local server on port 4010).
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential; tests only check that it is sent as a Bearer token and preserved across copies.
api_key = "My API Key"
30
31
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query parameters as a plain dict."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
36
37
38def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
39 return 0.1
40
41
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the length of the private ``_requests`` list on *client*'s transport connection pool.

    The leak tests expect this to be 0 once a failed, retried request has finished.
    """
    transport = client._client._transport
    # Only httpx's built-in (sync or async) transports expose the `_pool` internals inspected below.
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
48
49
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client: copying, request building, response parsing and retries."""

    # Shared strict client for tests that need no custom construction options.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leave per-iteration allocations behind."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Always stop tracing, even if building a request raised, so later tests are not traced.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected while copying the client")

    def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    # NOTE(review): async test without @pytest.mark.asyncio in the sync class — presumably relies on
    # pytest-asyncio "auto" mode being configured for the project; confirm.
    async def test_invalid_http_client(self) -> None:
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        # `client` comes from a fixture outside this file (presumably conftest.py).
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # Deliberately WITHOUT a trailing slash: this is the case this test exists to cover.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        # The client normalizes the base URL to end in "/" (see test_base_url_setter), so the
        # request path is appended instead of replacing the last path segment.
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        # dropping the copy must not close the HTTP client it shares with the original
        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # time.time() is frozen at 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, so the HTTP-date cases
    # below are relative to that instant. Per these cases, retry-after values outside (0, 60] seconds
    # (and unparsable ones) fall back to the computed backoff of 0.5s, doubling as retries are used.
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # NOTE(review): the wide relative tolerance presumably absorbs jitter in the backoff; confirm.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken(self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # fail with a 500 for the first `failures_before_success` calls, then succeed
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "content",
                    "role": "system",
                }
            ],
            model="gpt-4-turbo",
        )

        assert response.retries_taken == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # fail with a 500 for the first `failures_before_success` calls, then succeed
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "content",
                    "role": "system",
                }
            ],
            model="gpt-4-turbo",
        ) as response:
            assert response.retries_taken == failures_before_success
820
821class TestAsyncOpenAI:
822 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
823
824 @pytest.mark.respx(base_url=base_url)
825 @pytest.mark.asyncio
826 async def test_raw_response(self, respx_mock: MockRouter) -> None:
827 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
828
829 response = await self.client.post("/foo", cast_to=httpx.Response)
830 assert response.status_code == 200
831 assert isinstance(response, httpx.Response)
832 assert response.json() == {"foo": "bar"}
833
834 @pytest.mark.respx(base_url=base_url)
835 @pytest.mark.asyncio
836 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
837 respx_mock.post("/foo").mock(
838 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
839 )
840
841 response = await self.client.post("/foo", cast_to=httpx.Response)
842 assert response.status_code == 200
843 assert isinstance(response, httpx.Response)
844 assert response.json() == {"foo": "bar"}
845
846 def test_copy(self) -> None:
847 copied = self.client.copy()
848 assert id(copied) != id(self.client)
849
850 copied = self.client.copy(api_key="another My API Key")
851 assert copied.api_key == "another My API Key"
852 assert self.client.api_key == "My API Key"
853
854 def test_copy_default_options(self) -> None:
855 # options that have a default are overridden correctly
856 copied = self.client.copy(max_retries=7)
857 assert copied.max_retries == 7
858 assert self.client.max_retries == 2
859
860 copied2 = copied.copy(max_retries=6)
861 assert copied2.max_retries == 6
862 assert copied.max_retries == 7
863
864 # timeout
865 assert isinstance(self.client.timeout, httpx.Timeout)
866 copied = self.client.copy(timeout=None)
867 assert copied.timeout is None
868 assert isinstance(self.client.timeout, httpx.Timeout)
869
870 def test_copy_default_headers(self) -> None:
871 client = AsyncOpenAI(
872 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
873 )
874 assert client.default_headers["X-Foo"] == "bar"
875
876 # does not override the already given value when not specified
877 copied = client.copy()
878 assert copied.default_headers["X-Foo"] == "bar"
879
880 # merges already given headers
881 copied = client.copy(default_headers={"X-Bar": "stainless"})
882 assert copied.default_headers["X-Foo"] == "bar"
883 assert copied.default_headers["X-Bar"] == "stainless"
884
885 # uses new values for any already given headers
886 copied = client.copy(default_headers={"X-Foo": "stainless"})
887 assert copied.default_headers["X-Foo"] == "stainless"
888
889 # set_default_headers
890
891 # completely overrides already set values
892 copied = client.copy(set_default_headers={})
893 assert copied.default_headers.get("X-Foo") is None
894
895 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
896 assert copied.default_headers["X-Bar"] == "Robert"
897
898 with pytest.raises(
899 ValueError,
900 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
901 ):
902 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
903
def test_copy_default_query(self) -> None:
    """`copy()` merges, overrides or fully replaces default query params depending on which kwarg is used."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
    )
    assert _get_params(client)["foo"] == "bar"

    # does not override the already given value when not specified
    copied = client.copy()
    assert _get_params(copied)["foo"] == "bar"

    # merges already given params
    copied = client.copy(default_query={"bar": "stainless"})
    params = _get_params(copied)
    assert params["foo"] == "bar"
    assert params["bar"] == "stainless"

    # uses new values for any already given params
    copied = client.copy(default_query={"foo": "stainless"})
    assert _get_params(copied)["foo"] == "stainless"

    # set_default_query

    # completely overrides already set values
    copied = client.copy(set_default_query={})
    assert _get_params(copied) == {}

    copied = client.copy(set_default_query={"bar": "Robert"})
    assert _get_params(copied)["bar"] == "Robert"

    # `default_query` and `set_default_query` cannot be combined.
    with pytest.raises(
        ValueError,
        match="`default_query` and `set_default_query` arguments are mutually exclusive",
    ):
        client.copy(set_default_query={}, default_query={"foo": "Bar"})
939
def test_copy_signature(self) -> None:
    """Every constructor parameter, except a few transport-level ones, must also be accepted by `.copy()`."""
    # mypy doesn't like that we access the `__init__` property.
    init_params = inspect.signature(self.client.__init__).parameters  # type: ignore[misc]
    copy_params = inspect.signature(self.client.copy).parameters
    skipped = {"transport", "proxies", "_strict_response_validation"}

    for param_name in init_params:
        if param_name in skipped:
            continue
        assert copy_params.get(param_name) is not None, f"copy() signature is missing the {param_name} param"
955
def test_copy_build_request(self) -> None:
    """Building requests from fresh `copy()`s of the client must not leak memory.

    Allocations are traced with `tracemalloc` across ``ITERATIONS`` copy+build cycles. Any allocation
    that survives, appears a multiple-of-``ITERATIONS`` number of times and does not come from an
    allow-listed file is reported as a leak.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    ITERATIONS = 10

    tracemalloc.start(1000)
    # Always stop tracing, even if building a request raises. Otherwise tracemalloc would stay
    # enabled (slowing down and skewing) every test that runs after this one.
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        tracemalloc.stop()

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        for frame in diff.traceback:
            if any(
                frame.filename.endswith(fragment)
                for fragment in [
                    # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                    #
                    # removing the decorator fixes the leak for reasons we don't understand.
                    "openai/_legacy_response.py",
                    "openai/_response.py",
                    # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                    "openai/_compat.py",
                    # Standard library leaks we don't care about.
                    "/logging/__init__.py",
                ]
            ):
                return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)
    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks above")
1017
async def test_request_timeout(self) -> None:
    """Requests use DEFAULT_TIMEOUT unless the request options carry their own timeout."""

    def timeout_of(opts: FinalRequestOptions) -> httpx.Timeout:
        built = self.client._build_request(opts)
        return httpx.Timeout(**built.extensions["timeout"])  # type: ignore

    assert timeout_of(FinalRequestOptions(method="get", url="/foo")) == DEFAULT_TIMEOUT

    with_override = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    assert timeout_of(with_override) == httpx.Timeout(100.0)
1028
async def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor is attached to the requests it builds."""
    zero_timeout = httpx.Timeout(0)
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=zero_timeout)

    extensions = client._build_request(FinalRequestOptions(method="get", url="/foo")).extensions
    assert httpx.Timeout(**extensions["timeout"]) == httpx.Timeout(0)  # type: ignore
1037
async def test_http_client_timeout_option(self) -> None:
    """A timeout configured on a custom httpx client is honoured, unless it equals httpx's own default."""

    def built_timeout(http_client: httpx.AsyncClient) -> httpx.Timeout:
        wrapper = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = wrapper._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        assert built_timeout(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        assert built_timeout(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert built_timeout(http_client) == DEFAULT_TIMEOUT  # our default
1068
def test_invalid_http_client(self) -> None:
    """Handing a synchronous `httpx.Client` to the async SDK client is rejected with a TypeError."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http_client:
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            http_client=cast(Any, sync_http_client),
        )
1078
def test_default_headers_option(self) -> None:
    """Client-level default headers are sent, and can even override the SDK's `X-Stainless-Lang` header."""
    plain = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    sent = plain._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    overriding = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    sent = overriding._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"
1099
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and a missing key is rejected when the client is constructed."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # With `api_key=None` the client infers the key from the OPENAI_API_KEY environment variable,
    # so hide it for this check. Otherwise the test fails on any machine that has it set.
    # `mock.patch.dict` restores os.environ on exit.
    with mock.patch.dict(os.environ):
        os.environ.pop("OPENAI_API_KEY", None)
        with pytest.raises(OpenAIError):
            client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2
1108
def test_default_query_option(self) -> None:
    """Default query params are always sent, but per-request `params` win on key clashes."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    def query_of(options: FinalRequestOptions) -> dict[str, str]:
        return dict(httpx.URL(client._build_request(options).url).params)

    assert query_of(FinalRequestOptions(method="get", url="/foo")) == {"query_param": "bar"}

    per_request = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overriden"},
    )
    assert query_of(per_request) == {"foo": "baz", "query_param": "overriden"}
1126
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on key clashes."""
    scenarios: list[tuple[FinalRequestOptions, dict[str, object]]] = [
        (
            FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"}, extra_json={"baz": False}),
            {"foo": "bar", "baz": False},
        ),
        (
            FinalRequestOptions(method="post", url="/foo", extra_json={"baz": False}),
            {"baz": False},
        ),
        # `extra_json` takes priority over `json_data` when keys clash
        (
            FinalRequestOptions(
                method="post", url="/foo", json_data={"foo": "bar", "baz": True}, extra_json={"baz": None}
            ),
            {"foo": "bar", "baz": None},
        ),
    ]

    for options, expected_body in scenarios:
        request = self.client._build_request(options)
        assert json.loads(request.content.decode("utf-8")) == expected_body
1160
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent with the request and override client-level default headers."""
    plain = self.client._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Foo": "Foo"})),
    )
    assert plain.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    with_defaults = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing = with_defaults._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Bar": "false"})),
    )
    assert clashing.headers.get("X-Bar") == "false"
1182
def test_request_extra_query(self) -> None:
    """`extra_query` is sent on its own, merged with `query`, and wins over `query` on key clashes."""
    scenarios = [
        (make_request_options(extra_query={"my_query_param": "Foo"}), {"my_query_param": "Foo"}),
        # if both `query` and `extra_query` are given, they are merged
        (make_request_options(query={"bar": "1"}, extra_query={"foo": "2"}), {"bar": "1", "foo": "2"}),
        # `extra_query` takes priority over `query` when keys clash
        (make_request_options(query={"foo": "1"}, extra_query={"foo": "2"}), {"foo": "2"}),
    ]

    for request_options, expected_params in scenarios:
        request = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **request_options),
        )
        assert dict(request.url.params) == expected_params
1223
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """List values in a multipart body become repeated `name[]` parts, followed by the uploaded file."""
    request = async_client._build_request(
        FinalRequestOptions.construct(
            method="get",
            url="/foo",
            headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
            json_data={"array": ["foo", "bar"]},
            files=[("foo.txt", b"hello world")],
        )
    )

    def part(*lines: bytes) -> list[bytes]:
        # Each part opens with the boundary delimiter line.
        return [b"--6b7ba517decee4a450543ea6ae821c82", *lines]

    expected_lines = [
        *part(b'Content-Disposition: form-data; name="array[]"', b"", b"foo"),
        *part(b'Content-Disposition: form-data; name="array[]"', b"", b"bar"),
        *part(
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
        ),
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]
    assert request.read().split(b"\r\n") == expected_lines
1252
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A `Union` cast resolves to the member whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    result = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(result, Model2)
    assert result.foo == "bar"
1266
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union members that share a field name but differ in its type are told apart by the value's type."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])
    cases: list[tuple[dict[str, Any], type[BaseModel]]] = [
        ({"foo": "bar"}, Model2),
        ({"foo": 1}, Model1),
    ]

    for payload, expected_model in cases:
        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=payload))

        response = await self.client.get("/foo", cast_to=union_type)
        assert isinstance(response, expected_model)
        assert response.foo == payload["foo"]
1288
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """A JSON body is still parsed into the model when the server labels it with a non-JSON Content-Type."""

    class Model(BaseModel):
        foo: int

    raw_body = json.dumps({"foo": 2})
    mislabelled = httpx.Response(200, content=raw_body, headers={"Content-Type": "application/text"})
    respx_mock.get("/foo").mock(return_value=mislabelled)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1309
def test_base_url_setter(self) -> None:
    """Assigning `base_url` after construction gets the same trailing-slash normalisation as the constructor."""
    client = AsyncOpenAI(
        base_url="https://example.com/from_init",
        api_key=api_key,
        _strict_response_validation=True,
    )
    assert client.base_url == "https://example.com/from_init/"

    client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert client.base_url == "https://example.com/from_setter/"
1319
def test_base_url_env(self) -> None:
    """Without a `base_url` argument, the client picks it up from OPENAI_BASE_URL."""
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        from_env = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert from_env.base_url == "http://localhost:5000/from/env/"
1324
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL that already ends in `/` is joined with a relative path without doubling the slash."""
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1349
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL given *without* a trailing slash still has relative paths appended beneath it.

    The parametrized clients deliberately omit the trailing `/`. Previously they were copies of the
    trailing-slash case, so the normalisation this test is named after was never exercised.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1374
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used verbatim instead of being joined onto the base URL."""
    options = FinalRequestOptions(method="post", url="https://myapi.com/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "https://myapi.com/foo"
1399
async def test_copied_client_does_not_close_http(self) -> None:
    """Discarding a copy of the client must not close the original client."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    # Drop the copy, then yield to the event loop briefly so anything scheduled as a result can run.
    del duplicate
    await asyncio.sleep(0.2)

    assert not original.is_closed()
1411
async def test_client_context_manager(self) -> None:
    """`async with` yields the client itself, keeps it open inside the block and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()

    assert client.is_closed()
1419
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit the model raises APIResponseValidationError chained from pydantic's error."""

    class Model(BaseModel):
        foo: str

    invalid_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=invalid_payload))

    with pytest.raises(APIResponseValidationError) as exc_info:
        await self.client.get("/foo", cast_to=Model)

    assert isinstance(exc_info.value.__cause__, ValidationError)
1432
async def test_client_max_retries_validation(self) -> None:
    """`max_retries=None` is rejected with a TypeError when the client is constructed."""
    no_retries_value = cast(Any, None)
    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            max_retries=no_retries_value,
        )
1438
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """`stream=True` with an explicit `stream_cls` returns an AsyncStream instance."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    stream_type = AsyncStream[Model]
    result = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=stream_type)
    assert isinstance(result, AsyncStream)
    await result.response.aclose()
1450
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A non-JSON body fails under strict validation, but is returned as plain text when validation is lax."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    def make_client(strict: bool) -> AsyncOpenAI:
        return AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=strict)

    strict_client = make_client(True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    lenient_client = make_client(False)
    response = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(response, str)  # type: ignore[unreachable]
1468
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay derived from a `retry-after` header, with `time.time` frozen by the patch above.

    Values in the 0-60 second window (as seconds or an HTTP date) are used as-is. Anything else falls
    back to a default that, per the table, doubles as fewer retries remain. `pytest.approx` absorbs jitter.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    calculated = client._calculate_retry_timeout(
        remaining_retries,
        FinalRequestOptions(method="get", url="/foo", max_retries=3),
        httpx.Headers({"retry-after": retry_after}),
    )
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1498
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After retrying timeouts until giving up, no connection is left open in the pool."""
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    payload = dict(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
    )

    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=cast(object, payload),
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1524
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After retrying 500 responses until giving up, no connection is left open in the pool."""
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    payload = dict(
        messages=[
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        model="gpt-3.5-turbo",
    )

    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=cast(object, payload),
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1550
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """`retries_taken` on a raw response equals the number of 500s served before the first 200."""
    client = async_client.with_options(max_retries=4)

    served_errors = 0

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        nonlocal served_errors
        if served_errors >= failures_before_success:
            return httpx.Response(200)
        served_errors += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    raw = await client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "content",
                "role": "system",
            }
        ],
        model="gpt-4-turbo",
    )

    assert raw.retries_taken == failures_before_success
1582
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken_new_response_class(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """`retries_taken` on a streaming response equals the number of 500s served before the first 200."""
    client = async_client.with_options(max_retries=4)

    served_errors = 0

    def flaky_endpoint(_request: httpx.Request) -> httpx.Response:
        nonlocal served_errors
        if served_errors >= failures_before_success:
            return httpx.Response(200)
        served_errors += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_endpoint)

    async with client.chat.completions.with_streaming_response.create(
        messages=[
            {
                "content": "content",
                "role": "system",
            }
        ],
        model="gpt-4-turbo",
    ) as streamed:
        assert streamed.retries_taken == failures_before_success
1613