openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
2bfec1a2f0be308053924ba673398cd18a038422

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1615lines · modecode

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._types import Omit
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
29base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
30api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters `client` puts on a bare `GET /foo` request."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many entries sit in the `_requests` list of the client's transport pool.

    Reaches into private httpx attributes, so the transport must be one of the
    stock httpx HTTP transports.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
49
50
51class TestOpenAI:
52 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
53
54 @pytest.mark.respx(base_url=base_url)
55 def test_raw_response(self, respx_mock: MockRouter) -> None:
56 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
57
58 response = self.client.post("/foo", cast_to=httpx.Response)
59 assert response.status_code == 200
60 assert isinstance(response, httpx.Response)
61 assert response.json() == {"foo": "bar"}
62
63 @pytest.mark.respx(base_url=base_url)
64 def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
65 respx_mock.post("/foo").mock(
66 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
67 )
68
69 response = self.client.post("/foo", cast_to=httpx.Response)
70 assert response.status_code == 200
71 assert isinstance(response, httpx.Response)
72 assert response.json() == {"foo": "bar"}
73
74 def test_copy(self) -> None:
75 copied = self.client.copy()
76 assert id(copied) != id(self.client)
77
78 copied = self.client.copy(api_key="another My API Key")
79 assert copied.api_key == "another My API Key"
80 assert self.client.api_key == "My API Key"
81
82 def test_copy_default_options(self) -> None:
83 # options that have a default are overridden correctly
84 copied = self.client.copy(max_retries=7)
85 assert copied.max_retries == 7
86 assert self.client.max_retries == 2
87
88 copied2 = copied.copy(max_retries=6)
89 assert copied2.max_retries == 6
90 assert copied.max_retries == 7
91
92 # timeout
93 assert isinstance(self.client.timeout, httpx.Timeout)
94 copied = self.client.copy(timeout=None)
95 assert copied.timeout is None
96 assert isinstance(self.client.timeout, httpx.Timeout)
97
98 def test_copy_default_headers(self) -> None:
99 client = OpenAI(
100 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
101 )
102 assert client.default_headers["X-Foo"] == "bar"
103
104 # does not override the already given value when not specified
105 copied = client.copy()
106 assert copied.default_headers["X-Foo"] == "bar"
107
108 # merges already given headers
109 copied = client.copy(default_headers={"X-Bar": "stainless"})
110 assert copied.default_headers["X-Foo"] == "bar"
111 assert copied.default_headers["X-Bar"] == "stainless"
112
113 # uses new values for any already given headers
114 copied = client.copy(default_headers={"X-Foo": "stainless"})
115 assert copied.default_headers["X-Foo"] == "stainless"
116
117 # set_default_headers
118
119 # completely overrides already set values
120 copied = client.copy(set_default_headers={})
121 assert copied.default_headers.get("X-Foo") is None
122
123 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
124 assert copied.default_headers["X-Bar"] == "Robert"
125
126 with pytest.raises(
127 ValueError,
128 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
129 ):
130 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
131
132 def test_copy_default_query(self) -> None:
133 client = OpenAI(
134 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
135 )
136 assert _get_params(client)["foo"] == "bar"
137
138 # does not override the already given value when not specified
139 copied = client.copy()
140 assert _get_params(copied)["foo"] == "bar"
141
142 # merges already given params
143 copied = client.copy(default_query={"bar": "stainless"})
144 params = _get_params(copied)
145 assert params["foo"] == "bar"
146 assert params["bar"] == "stainless"
147
148 # uses new values for any already given headers
149 copied = client.copy(default_query={"foo": "stainless"})
150 assert _get_params(copied)["foo"] == "stainless"
151
152 # set_default_query
153
154 # completely overrides already set values
155 copied = client.copy(set_default_query={})
156 assert _get_params(copied) == {}
157
158 copied = client.copy(set_default_query={"bar": "Robert"})
159 assert _get_params(copied)["bar"] == "Robert"
160
161 with pytest.raises(
162 ValueError,
163 # TODO: update
164 match="`default_query` and `set_default_query` arguments are mutually exclusive",
165 ):
166 client.copy(set_default_query={}, default_query={"foo": "Bar"})
167
168 def test_copy_signature(self) -> None:
169 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
170 init_signature = inspect.signature(
171 # mypy doesn't like that we access the `__init__` property.
172 self.client.__init__, # type: ignore[misc]
173 )
174 copy_signature = inspect.signature(self.client.copy)
175 exclude_params = {"transport", "proxies", "_strict_response_validation"}
176
177 for name in init_signature.parameters.keys():
178 if name in exclude_params:
179 continue
180
181 copy_param = copy_signature.parameters.get(name)
182 assert copy_param is not None, f"copy() signature is missing the {name} param"
183
184 def test_copy_build_request(self) -> None:
185 options = FinalRequestOptions(method="get", url="/foo")
186
187 def build_request(options: FinalRequestOptions) -> None:
188 client = self.client.copy()
189 client._build_request(options)
190
191 # ensure that the machinery is warmed up before tracing starts.
192 build_request(options)
193 gc.collect()
194
195 tracemalloc.start(1000)
196
197 snapshot_before = tracemalloc.take_snapshot()
198
199 ITERATIONS = 10
200 for _ in range(ITERATIONS):
201 build_request(options)
202
203 gc.collect()
204 snapshot_after = tracemalloc.take_snapshot()
205
206 tracemalloc.stop()
207
208 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
209 if diff.count == 0:
210 # Avoid false positives by considering only leaks (i.e. allocations that persist).
211 return
212
213 if diff.count % ITERATIONS != 0:
214 # Avoid false positives by considering only leaks that appear per iteration.
215 return
216
217 for frame in diff.traceback:
218 if any(
219 frame.filename.endswith(fragment)
220 for fragment in [
221 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
222 #
223 # removing the decorator fixes the leak for reasons we don't understand.
224 "openai/_legacy_response.py",
225 "openai/_response.py",
226 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
227 "openai/_compat.py",
228 # Standard library leaks we don't care about.
229 "/logging/__init__.py",
230 ]
231 ):
232 return
233
234 leaks.append(diff)
235
236 leaks: list[tracemalloc.StatisticDiff] = []
237 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
238 add_leak(leaks, diff)
239 if leaks:
240 for leak in leaks:
241 print("MEMORY LEAK:", leak)
242 for frame in leak.traceback:
243 print(frame)
244 raise AssertionError()
245
246 def test_request_timeout(self) -> None:
247 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
248 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
249 assert timeout == DEFAULT_TIMEOUT
250
251 request = self.client._build_request(
252 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
253 )
254 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
255 assert timeout == httpx.Timeout(100.0)
256
257 def test_client_timeout_option(self) -> None:
258 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
259
260 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
261 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
262 assert timeout == httpx.Timeout(0)
263
264 def test_http_client_timeout_option(self) -> None:
265 # custom timeout given to the httpx client should be used
266 with httpx.Client(timeout=None) as http_client:
267 client = OpenAI(
268 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
269 )
270
271 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
272 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
273 assert timeout == httpx.Timeout(None)
274
275 # no timeout given to the httpx client should not use the httpx default
276 with httpx.Client() as http_client:
277 client = OpenAI(
278 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
279 )
280
281 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
282 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
283 assert timeout == DEFAULT_TIMEOUT
284
285 # explicitly passing the default timeout currently results in it being ignored
286 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
287 client = OpenAI(
288 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
289 )
290
291 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
292 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
293 assert timeout == DEFAULT_TIMEOUT # our default
294
295 async def test_invalid_http_client(self) -> None:
296 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
297 async with httpx.AsyncClient() as http_client:
298 OpenAI(
299 base_url=base_url,
300 api_key=api_key,
301 _strict_response_validation=True,
302 http_client=cast(Any, http_client),
303 )
304
305 def test_default_headers_option(self) -> None:
306 client = OpenAI(
307 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
308 )
309 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
310 assert request.headers.get("x-foo") == "bar"
311 assert request.headers.get("x-stainless-lang") == "python"
312
313 client2 = OpenAI(
314 base_url=base_url,
315 api_key=api_key,
316 _strict_response_validation=True,
317 default_headers={
318 "X-Foo": "stainless",
319 "X-Stainless-Lang": "my-overriding-header",
320 },
321 )
322 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
323 assert request.headers.get("x-foo") == "stainless"
324 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
325
326 def test_validate_headers(self) -> None:
327 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
328 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
329 assert request.headers.get("Authorization") == f"Bearer {api_key}"
330
331 with pytest.raises(OpenAIError):
332 with update_env(**{"OPENAI_API_KEY": Omit()}):
333 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
334 _ = client2
335
336 def test_default_query_option(self) -> None:
337 client = OpenAI(
338 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
339 )
340 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
341 url = httpx.URL(request.url)
342 assert dict(url.params) == {"query_param": "bar"}
343
344 request = client._build_request(
345 FinalRequestOptions(
346 method="get",
347 url="/foo",
348 params={"foo": "baz", "query_param": "overriden"},
349 )
350 )
351 url = httpx.URL(request.url)
352 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
353
354 def test_request_extra_json(self) -> None:
355 request = self.client._build_request(
356 FinalRequestOptions(
357 method="post",
358 url="/foo",
359 json_data={"foo": "bar"},
360 extra_json={"baz": False},
361 ),
362 )
363 data = json.loads(request.content.decode("utf-8"))
364 assert data == {"foo": "bar", "baz": False}
365
366 request = self.client._build_request(
367 FinalRequestOptions(
368 method="post",
369 url="/foo",
370 extra_json={"baz": False},
371 ),
372 )
373 data = json.loads(request.content.decode("utf-8"))
374 assert data == {"baz": False}
375
376 # `extra_json` takes priority over `json_data` when keys clash
377 request = self.client._build_request(
378 FinalRequestOptions(
379 method="post",
380 url="/foo",
381 json_data={"foo": "bar", "baz": True},
382 extra_json={"baz": None},
383 ),
384 )
385 data = json.loads(request.content.decode("utf-8"))
386 assert data == {"foo": "bar", "baz": None}
387
388 def test_request_extra_headers(self) -> None:
389 request = self.client._build_request(
390 FinalRequestOptions(
391 method="post",
392 url="/foo",
393 **make_request_options(extra_headers={"X-Foo": "Foo"}),
394 ),
395 )
396 assert request.headers.get("X-Foo") == "Foo"
397
398 # `extra_headers` takes priority over `default_headers` when keys clash
399 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
400 FinalRequestOptions(
401 method="post",
402 url="/foo",
403 **make_request_options(
404 extra_headers={"X-Bar": "false"},
405 ),
406 ),
407 )
408 assert request.headers.get("X-Bar") == "false"
409
410 def test_request_extra_query(self) -> None:
411 request = self.client._build_request(
412 FinalRequestOptions(
413 method="post",
414 url="/foo",
415 **make_request_options(
416 extra_query={"my_query_param": "Foo"},
417 ),
418 ),
419 )
420 params = dict(request.url.params)
421 assert params == {"my_query_param": "Foo"}
422
423 # if both `query` and `extra_query` are given, they are merged
424 request = self.client._build_request(
425 FinalRequestOptions(
426 method="post",
427 url="/foo",
428 **make_request_options(
429 query={"bar": "1"},
430 extra_query={"foo": "2"},
431 ),
432 ),
433 )
434 params = dict(request.url.params)
435 assert params == {"bar": "1", "foo": "2"}
436
437 # `extra_query` takes priority over `query` when keys clash
438 request = self.client._build_request(
439 FinalRequestOptions(
440 method="post",
441 url="/foo",
442 **make_request_options(
443 query={"foo": "1"},
444 extra_query={"foo": "2"},
445 ),
446 ),
447 )
448 params = dict(request.url.params)
449 assert params == {"foo": "2"}
450
451 def test_multipart_repeating_array(self, client: OpenAI) -> None:
452 request = client._build_request(
453 FinalRequestOptions.construct(
454 method="get",
455 url="/foo",
456 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
457 json_data={"array": ["foo", "bar"]},
458 files=[("foo.txt", b"hello world")],
459 )
460 )
461
462 assert request.read().split(b"\r\n") == [
463 b"--6b7ba517decee4a450543ea6ae821c82",
464 b'Content-Disposition: form-data; name="array[]"',
465 b"",
466 b"foo",
467 b"--6b7ba517decee4a450543ea6ae821c82",
468 b'Content-Disposition: form-data; name="array[]"',
469 b"",
470 b"bar",
471 b"--6b7ba517decee4a450543ea6ae821c82",
472 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
473 b"Content-Type: application/octet-stream",
474 b"",
475 b"hello world",
476 b"--6b7ba517decee4a450543ea6ae821c82--",
477 b"",
478 ]
479
480 @pytest.mark.respx(base_url=base_url)
481 def test_basic_union_response(self, respx_mock: MockRouter) -> None:
482 class Model1(BaseModel):
483 name: str
484
485 class Model2(BaseModel):
486 foo: str
487
488 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
489
490 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
491 assert isinstance(response, Model2)
492 assert response.foo == "bar"
493
494 @pytest.mark.respx(base_url=base_url)
495 def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
496 """Union of objects with the same field name using a different type"""
497
498 class Model1(BaseModel):
499 foo: int
500
501 class Model2(BaseModel):
502 foo: str
503
504 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
505
506 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
507 assert isinstance(response, Model2)
508 assert response.foo == "bar"
509
510 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
511
512 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
513 assert isinstance(response, Model1)
514 assert response.foo == 1
515
516 @pytest.mark.respx(base_url=base_url)
517 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
518 """
519 Response that sets Content-Type to something other than application/json but returns json data
520 """
521
522 class Model(BaseModel):
523 foo: int
524
525 respx_mock.get("/foo").mock(
526 return_value=httpx.Response(
527 200,
528 content=json.dumps({"foo": 2}),
529 headers={"Content-Type": "application/text"},
530 )
531 )
532
533 response = self.client.get("/foo", cast_to=Model)
534 assert isinstance(response, Model)
535 assert response.foo == 2
536
537 def test_base_url_setter(self) -> None:
538 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
539 assert client.base_url == "https://example.com/from_init/"
540
541 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
542
543 assert client.base_url == "https://example.com/from_setter/"
544
545 def test_base_url_env(self) -> None:
546 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
547 client = OpenAI(api_key=api_key, _strict_response_validation=True)
548 assert client.base_url == "http://localhost:5000/from/env/"
549
550 @pytest.mark.parametrize(
551 "client",
552 [
553 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
554 OpenAI(
555 base_url="http://localhost:5000/custom/path/",
556 api_key=api_key,
557 _strict_response_validation=True,
558 http_client=httpx.Client(),
559 ),
560 ],
561 ids=["standard", "custom http client"],
562 )
563 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
564 request = client._build_request(
565 FinalRequestOptions(
566 method="post",
567 url="/foo",
568 json_data={"foo": "bar"},
569 ),
570 )
571 assert request.url == "http://localhost:5000/custom/path/foo"
572
573 @pytest.mark.parametrize(
574 "client",
575 [
576 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
577 OpenAI(
578 base_url="http://localhost:5000/custom/path/",
579 api_key=api_key,
580 _strict_response_validation=True,
581 http_client=httpx.Client(),
582 ),
583 ],
584 ids=["standard", "custom http client"],
585 )
586 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
587 request = client._build_request(
588 FinalRequestOptions(
589 method="post",
590 url="/foo",
591 json_data={"foo": "bar"},
592 ),
593 )
594 assert request.url == "http://localhost:5000/custom/path/foo"
595
596 @pytest.mark.parametrize(
597 "client",
598 [
599 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
600 OpenAI(
601 base_url="http://localhost:5000/custom/path/",
602 api_key=api_key,
603 _strict_response_validation=True,
604 http_client=httpx.Client(),
605 ),
606 ],
607 ids=["standard", "custom http client"],
608 )
609 def test_absolute_request_url(self, client: OpenAI) -> None:
610 request = client._build_request(
611 FinalRequestOptions(
612 method="post",
613 url="https://myapi.com/foo",
614 json_data={"foo": "bar"},
615 ),
616 )
617 assert request.url == "https://myapi.com/foo"
618
619 def test_copied_client_does_not_close_http(self) -> None:
620 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
621 assert not client.is_closed()
622
623 copied = client.copy()
624 assert copied is not client
625
626 del copied
627
628 assert not client.is_closed()
629
630 def test_client_context_manager(self) -> None:
631 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
632 with client as c2:
633 assert c2 is client
634 assert not c2.is_closed()
635 assert not client.is_closed()
636 assert client.is_closed()
637
638 @pytest.mark.respx(base_url=base_url)
639 def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
640 class Model(BaseModel):
641 foo: str
642
643 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
644
645 with pytest.raises(APIResponseValidationError) as exc:
646 self.client.get("/foo", cast_to=Model)
647
648 assert isinstance(exc.value.__cause__, ValidationError)
649
650 def test_client_max_retries_validation(self) -> None:
651 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
652 OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
653
654 @pytest.mark.respx(base_url=base_url)
655 def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
656 class Model(BaseModel):
657 name: str
658
659 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
660
661 stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
662 assert isinstance(stream, Stream)
663 stream.response.close()
664
665 @pytest.mark.respx(base_url=base_url)
666 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
667 class Model(BaseModel):
668 name: str
669
670 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
671
672 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
673
674 with pytest.raises(APIResponseValidationError):
675 strict_client.get("/foo", cast_to=Model)
676
677 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
678
679 response = client.get("/foo", cast_to=Model)
680 assert isinstance(response, str) # type: ignore[unreachable]
681
682 @pytest.mark.parametrize(
683 "remaining_retries,retry_after,timeout",
684 [
685 [3, "20", 20],
686 [3, "0", 0.5],
687 [3, "-10", 0.5],
688 [3, "60", 60],
689 [3, "61", 0.5],
690 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
691 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
692 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
693 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
694 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
695 [3, "99999999999999999999999999999999999", 0.5],
696 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
697 [3, "", 0.5],
698 [2, "", 0.5 * 2.0],
699 [1, "", 0.5 * 4.0],
700 ],
701 )
702 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
703 def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
704 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
705
706 headers = httpx.Headers({"retry-after": retry_after})
707 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
708 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
709 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
710
711 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
712 @pytest.mark.respx(base_url=base_url)
713 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
714 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
715
716 with pytest.raises(APITimeoutError):
717 self.client.post(
718 "/chat/completions",
719 body=cast(
720 object,
721 dict(
722 messages=[
723 {
724 "role": "user",
725 "content": "Say this is a test",
726 }
727 ],
728 model="gpt-3.5-turbo",
729 ),
730 ),
731 cast_to=httpx.Response,
732 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
733 )
734
735 assert _get_open_connections(self.client) == 0
736
737 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
738 @pytest.mark.respx(base_url=base_url)
739 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
740 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
741
742 with pytest.raises(APIStatusError):
743 self.client.post(
744 "/chat/completions",
745 body=cast(
746 object,
747 dict(
748 messages=[
749 {
750 "role": "user",
751 "content": "Say this is a test",
752 }
753 ],
754 model="gpt-3.5-turbo",
755 ),
756 ),
757 cast_to=httpx.Response,
758 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
759 )
760
761 assert _get_open_connections(self.client) == 0
762
763 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
764 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
765 @pytest.mark.respx(base_url=base_url)
766 def test_retries_taken(self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter) -> None:
767 client = client.with_options(max_retries=4)
768
769 nb_retries = 0
770
771 def retry_handler(_request: httpx.Request) -> httpx.Response:
772 nonlocal nb_retries
773 if nb_retries < failures_before_success:
774 nb_retries += 1
775 return httpx.Response(500)
776 return httpx.Response(200)
777
778 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
779
780 response = client.chat.completions.with_raw_response.create(
781 messages=[
782 {
783 "content": "string",
784 "role": "system",
785 }
786 ],
787 model="gpt-4o",
788 )
789
790 assert response.retries_taken == failures_before_success
791
792 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
793 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
794 @pytest.mark.respx(base_url=base_url)
795 def test_retries_taken_new_response_class(
796 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
797 ) -> None:
798 client = client.with_options(max_retries=4)
799
800 nb_retries = 0
801
802 def retry_handler(_request: httpx.Request) -> httpx.Response:
803 nonlocal nb_retries
804 if nb_retries < failures_before_success:
805 nb_retries += 1
806 return httpx.Response(500)
807 return httpx.Response(200)
808
809 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
810
811 with client.chat.completions.with_streaming_response.create(
812 messages=[
813 {
814 "content": "string",
815 "role": "system",
816 }
817 ],
818 model="gpt-4o",
819 ) as response:
820 assert response.retries_taken == failures_before_success
821
822
823class TestAsyncOpenAI:
824 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
825
826 @pytest.mark.respx(base_url=base_url)
827 @pytest.mark.asyncio
828 async def test_raw_response(self, respx_mock: MockRouter) -> None:
829 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
830
831 response = await self.client.post("/foo", cast_to=httpx.Response)
832 assert response.status_code == 200
833 assert isinstance(response, httpx.Response)
834 assert response.json() == {"foo": "bar"}
835
836 @pytest.mark.respx(base_url=base_url)
837 @pytest.mark.asyncio
838 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
839 respx_mock.post("/foo").mock(
840 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
841 )
842
843 response = await self.client.post("/foo", cast_to=httpx.Response)
844 assert response.status_code == 200
845 assert isinstance(response, httpx.Response)
846 assert response.json() == {"foo": "bar"}
847
848 def test_copy(self) -> None:
849 copied = self.client.copy()
850 assert id(copied) != id(self.client)
851
852 copied = self.client.copy(api_key="another My API Key")
853 assert copied.api_key == "another My API Key"
854 assert self.client.api_key == "My API Key"
855
856 def test_copy_default_options(self) -> None:
857 # options that have a default are overridden correctly
858 copied = self.client.copy(max_retries=7)
859 assert copied.max_retries == 7
860 assert self.client.max_retries == 2
861
862 copied2 = copied.copy(max_retries=6)
863 assert copied2.max_retries == 6
864 assert copied.max_retries == 7
865
866 # timeout
867 assert isinstance(self.client.timeout, httpx.Timeout)
868 copied = self.client.copy(timeout=None)
869 assert copied.timeout is None
870 assert isinstance(self.client.timeout, httpx.Timeout)
871
872 def test_copy_default_headers(self) -> None:
873 client = AsyncOpenAI(
874 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
875 )
876 assert client.default_headers["X-Foo"] == "bar"
877
878 # does not override the already given value when not specified
879 copied = client.copy()
880 assert copied.default_headers["X-Foo"] == "bar"
881
882 # merges already given headers
883 copied = client.copy(default_headers={"X-Bar": "stainless"})
884 assert copied.default_headers["X-Foo"] == "bar"
885 assert copied.default_headers["X-Bar"] == "stainless"
886
887 # uses new values for any already given headers
888 copied = client.copy(default_headers={"X-Foo": "stainless"})
889 assert copied.default_headers["X-Foo"] == "stainless"
890
891 # set_default_headers
892
893 # completely overrides already set values
894 copied = client.copy(set_default_headers={})
895 assert copied.default_headers.get("X-Foo") is None
896
897 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
898 assert copied.default_headers["X-Bar"] == "Robert"
899
900 with pytest.raises(
901 ValueError,
902 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
903 ):
904 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
905
def test_copy_default_query(self) -> None:
    """Check how `copy()` merges, overrides and resets default query parameters."""
    original = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_query={"foo": "bar"},
    )
    assert _get_params(original)["foo"] == "bar"

    # A plain copy keeps the query parameter of the source client.
    plain = original.copy()
    assert _get_params(plain)["foo"] == "bar"

    # `default_query` is merged into the parameters that are already set...
    merged = original.copy(default_query={"bar": "stainless"})
    merged_params = _get_params(merged)
    assert merged_params["foo"] == "bar"
    assert merged_params["bar"] == "stainless"

    # ...and its value wins when the parameter name is already present.
    overridden = original.copy(default_query={"foo": "stainless"})
    assert _get_params(overridden)["foo"] == "stainless"

    # `set_default_query` completely overrides the already set values.
    cleared = original.copy(set_default_query={})
    assert _get_params(cleared) == {}

    replaced = original.copy(set_default_query={"bar": "Robert"})
    assert _get_params(replaced)["bar"] == "Robert"

    # Supplying both keyword arguments in one call must be rejected.
    # TODO: update (marker carried over from the original test; what needs updating is not stated)
    expected_message = "`default_query` and `set_default_query` arguments are mutually exclusive"
    with pytest.raises(ValueError, match=expected_message):
        original.copy(set_default_query={}, default_query={"foo": "Bar"})
941
def test_copy_signature(self) -> None:
    """Ensure the parameters accepted by the client constructor are also defined on `.copy()`."""
    # Constructor parameters that are intentionally not checked against `.copy()`.
    skipped = {"transport", "proxies", "_strict_response_validation"}

    init_params = inspect.signature(
        # mypy doesn't like that we access the `__init__` property.
        self.client.__init__,  # type: ignore[misc]
    ).parameters
    copy_params = inspect.signature(self.client.copy).parameters

    for name in init_params:
        if name not in skipped:
            matching = copy_params.get(name)
            assert matching is not None, f"copy() signature is missing the {name} param"
957
def test_copy_build_request(self) -> None:
    """Check that copying the client and building a request does not leak per call.

    Two `tracemalloc` snapshots are taken around `ITERATIONS` calls that copy the
    client and build a request. Any traceback whose live allocation count changed
    by a non-zero multiple of `ITERATIONS` is reported as a leak, unless one of its
    frames lives in a file on the ignore list below.

    Raises:
        AssertionError: if at least one suspected leak remains after filtering.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)
    try:
        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        # Always switch tracing off again, even when building a request raised, so a
        # failure here cannot leave tracemalloc enabled for the tests that run afterwards.
        tracemalloc.stop()

    # Files whose allocations are knowingly ignored (matched against the end of the frame's filename).
    ignored_suffixes = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        """Append `diff` to `leaks` unless it is filtered out as a false positive."""
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        # `str.endswith` accepts a tuple, so one call checks every ignored suffix.
        if any(frame.filename.endswith(ignored_suffixes) for frame in diff.traceback):
            return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)

    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} suspected memory leak(s) detected; see printed tracebacks")
1019
async def test_request_timeout(self) -> None:
    """A built request uses `DEFAULT_TIMEOUT` unless the request options carry their own timeout."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    default_timeout = httpx.Timeout(**default_request.extensions["timeout"])  # type: ignore
    assert default_timeout == DEFAULT_TIMEOUT

    # A timeout given on the request options is the one stored on the request.
    custom_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    custom_request = self.client._build_request(custom_options)
    custom_timeout = httpx.Timeout(**custom_request.extensions["timeout"])  # type: ignore
    assert custom_timeout == httpx.Timeout(100.0)
1030
async def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor is applied to the requests it builds."""
    zero_timeout_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert applied == httpx.Timeout(0)
1039
async def test_http_client_timeout_option(self) -> None:
    """Check which timeout ends up on a request when a custom httpx client is supplied."""

    def request_timeout(http_client: httpx.AsyncClient) -> httpx.Timeout:
        # Build a client around `http_client` and return the timeout stored on a built request.
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as without_timeout:
        assert request_timeout(without_timeout) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as unconfigured:
        assert request_timeout(unconfigured) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as httpx_default:
        assert request_timeout(httpx_default) == DEFAULT_TIMEOUT  # our default
1070
def test_invalid_http_client(self) -> None:
    """Passing a synchronous `httpx.Client` to the async client raises `TypeError`."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http_client:
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            # The cast hides the deliberate type mismatch from static checkers.
            http_client=cast(Any, sync_http_client),
        )
1080
def test_default_headers_option(self) -> None:
    """`default_headers` are sent on built requests and may replace the `x-stainless-lang` header."""
    options = FinalRequestOptions(method="get", url="/foo")

    with_extra_header = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    )
    sent = with_extra_header._build_request(options).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    # A user-supplied value for `X-Stainless-Lang` takes the place of "python".
    overriding_headers = {
        "X-Foo": "stainless",
        "X-Stainless-Lang": "my-overriding-header",
    }
    with_override = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers=overriding_headers,
    )
    sent = with_override._build_request(options).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"
1101
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token; building a client without any key raises `OpenAIError`."""
    keyed = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    built = keyed._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert built.headers.get("Authorization") == f"Bearer {api_key}"

    # With `OPENAI_API_KEY` overridden by `Omit()` and no explicit key, construction must fail.
    with pytest.raises(OpenAIError), update_env(**{"OPENAI_API_KEY": Omit()}):
        keyless = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
        _ = keyless
1111
def test_default_query_option(self) -> None:
    """`default_query` is added to request URLs and is overridden by per-request params."""
    client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_query={"query_param": "bar"},
    )

    plain = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(plain.url).params) == {"query_param": "bar"}

    # Per-request params are merged in and win over the client default of the same name.
    per_request = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overriden"},
    )
    customised = client._build_request(per_request)
    assert dict(httpx.URL(customised.url).params) == {"foo": "baz", "query_param": "overriden"}
1129
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on key clashes."""

    def decoded_body(request: httpx.Request) -> Any:
        # Parse the request's UTF-8 encoded JSON body.
        return json.loads(request.content.decode("utf-8"))

    # Both `json_data` and `extra_json` given: the body contains the keys of both.
    merged = self.client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
            extra_json={"baz": False},
        ),
    )
    assert decoded_body(merged) == {"foo": "bar", "baz": False}

    # Only `extra_json` given: it forms the whole body.
    extra_only = self.client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            extra_json={"baz": False},
        ),
    )
    assert decoded_body(extra_only) == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    clashing = self.client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar", "baz": True},
            extra_json={"baz": None},
        ),
    )
    assert decoded_body(clashing) == {"foo": "bar", "baz": None}
1163
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent on the request and win over `default_headers` on key clashes."""
    simple_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    simple = self.client._build_request(simple_options)
    assert simple.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_default = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            extra_headers={"X-Bar": "false"},
        ),
    )
    clashing = client_with_default._build_request(clashing_options)
    assert clashing.headers.get("X-Bar") == "false"
1185
def test_request_extra_query(self) -> None:
    """`extra_query` is added to the URL, merged with `query`, and wins on key clashes."""
    # Only `extra_query` given.
    extra_only = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            extra_query={"my_query_param": "Foo"},
        ),
    )
    built = self.client._build_request(extra_only)
    assert dict(built.url.params) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    disjoint = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            query={"bar": "1"},
            extra_query={"foo": "2"},
        ),
    )
    built = self.client._build_request(disjoint)
    assert dict(built.url.params) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    clashing = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            query={"foo": "1"},
            extra_query={"foo": "2"},
        ),
    )
    built = self.client._build_request(clashing)
    assert dict(built.url.params) == {"foo": "2"}
1226
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """A list in `json_data` is encoded as one `array[]` multipart field per element."""
    options = FinalRequestOptions.construct(
        method="get",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    request = async_client._build_request(options)

    # The encoded body, split on CRLF: two `array[]` parts, then the file part.
    expected_lines = [
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"foo",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"bar",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]
    body_lines = request.read().split(b"\r\n")
    assert body_lines == expected_lines
1255
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A response cast to a union of models is parsed as the member whose fields match."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    # The mocked payload only carries `foo`, which is the field `Model2` declares.
    mocked = httpx.Response(200, json={"foo": "bar"})
    respx_mock.get("/foo").mock(return_value=mocked)

    union_type = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1269
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union of objects that share a field name but declare different types for it"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])

    # A string value for `foo` is expected to resolve to the `str` model.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"

    # An integer value for `foo` is expected to resolve to the `int` model.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model1)
    assert parsed.foo == 1
1291
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    A JSON body is still parsed into the model when Content-Type is not application/json
    """

    class Model(BaseModel):
        foo: int

    # JSON payload served with a non-JSON content type.
    mislabelled = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1312
def test_base_url_setter(self) -> None:
    """`base_url` gains a trailing slash whether it is set in the constructor or via the setter."""
    client = AsyncOpenAI(
        base_url="https://example.com/from_init",
        api_key=api_key,
        _strict_response_validation=True,
    )
    assert client.base_url == "https://example.com/from_init/"

    # Assigning a plain string goes through the property setter.
    client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

    assert client.base_url == "https://example.com/from_setter/"
1322
def test_base_url_env(self) -> None:
    """With no explicit `base_url`, the client takes it from `OPENAI_BASE_URL`."""
    env_url = "http://localhost:5000/from/env"
    with update_env(OPENAI_BASE_URL=env_url):
        from_env = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert from_env.base_url == "http://localhost:5000/from/env/"
1327
@pytest.mark.parametrize(
    "client",
    [
        pytest.param(
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
            ),
            id="standard",
        ),
        pytest.param(
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
            id="custom http client",
        ),
    ],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A relative request path is appended to a base URL that ends with a slash."""
    options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1352
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            # Deliberately no trailing slash: that is the case this test is named after.
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A relative request path is joined correctly to a base URL given without a trailing slash.

    The clients were previously created with a base URL ending in "/", which made
    this test a duplicate of the trailing-slash test and left this case unexercised.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1377
@pytest.mark.parametrize(
    "client",
    [
        pytest.param(
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
            ),
            id="standard",
        ),
        pytest.param(
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
            id="custom http client",
        ),
    ],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined to the client's base URL."""
    options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "https://myapi.com/foo"
1402
async def test_copied_client_does_not_close_http(self) -> None:
    """Dropping a copy of a client must leave the original client open."""
    source = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not source.is_closed()

    duplicate = source.copy()
    assert duplicate is not source

    # Drop the only reference to the copy, then wait briefly before re-checking the original.
    del duplicate

    await asyncio.sleep(0.2)
    assert not source.is_closed()
1414
async def test_client_context_manager(self) -> None:
    """`async with client` yields the client itself and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with client as entered:
        # Inside the block both names refer to the same, still open, client.
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()
    assert client.is_closed()
1422
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that fails model validation raises `APIResponseValidationError` caused by `ValidationError`."""

    class Model(BaseModel):
        foo: str

    # `foo` is declared as `str`, but the mocked payload carries an object instead.
    invalid_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=invalid_payload))

    with pytest.raises(APIResponseValidationError) as exc_info:
        await self.client.get("/foo", cast_to=Model)

    assert isinstance(exc_info.value.__cause__, ValidationError)
1435
async def test_client_max_retries_validation(self) -> None:
    """Constructing a client with `max_retries=None` raises `TypeError`."""
    # The cast hides the deliberately invalid `None` from static type checkers.
    invalid_max_retries = cast(Any, None)
    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            max_retries=invalid_max_retries,
        )
1441
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """A `post` with `stream=True` and an `AsyncStream` stream class returns an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    mocked = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=mocked)

    result = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(result, AsyncStream)
    # Close the underlying HTTP response once the type check is done.
    await result.response.aclose()
1453
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A plain-text body where a model is expected: strict clients raise, lenient ones return `str`."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # Strict response validation rejects the non-JSON body.
    strict = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict.get("/foo", cast_to=Model)

    # Without strict validation the call succeeds and the result is a string.
    lenient = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    result = await lenient.get("/foo", cast_to=Model)
    assert isinstance(result, str)  # type: ignore[unreachable]
1471
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        (3, "20", 20),
        (3, "0", 0.5),
        (3, "-10", 0.5),
        (3, "60", 60),
        (3, "61", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:57 GMT", 20),
        (3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:27:37 GMT", 60),
        (3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5),
        (3, "99999999999999999999999999999999999", 0.5),
        (3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "", 0.5),
        (2, "", 0.5 * 2.0),
        (1, "", 0.5 * 4.0),
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay computed for a range of `retry-after` header values.

    `time.time` is patched to a fixed value, presumably so that the HTTP-date
    cases are evaluated against a stable clock.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    calculated = client._calculate_retry_timeout(
        remaining_retries,
        FinalRequestOptions(method="get", url="/foo", max_retries=3),
        httpx.Headers({"retry-after": retry_after}),
    )
    # The comparison is approximate, with a relative tolerance of 0.5 * 0.875.
    expected = pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
    assert calculated == expected
1501
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After a request fails with `APITimeoutError`, no connections remain open on the client."""
    # Every attempt against the mocked route raises a timeout.
    timeout_error = httpx.TimeoutException("Test timeout error")
    respx_mock.post("/chat/completions").mock(side_effect=timeout_error)

    payload = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }

    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=cast(object, payload),
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1527
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After a request fails with `APIStatusError`, no connections remain open on the client."""
    # Every attempt against the mocked route answers with HTTP 500.
    server_error = httpx.Response(500)
    respx_mock.post("/chat/completions").mock(return_value=server_error)

    payload = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }

    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=cast(object, payload),
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1553
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """`retries_taken` on a raw response equals the number of failed attempts before success."""
    retrying_client = async_client.with_options(max_retries=4)

    failures_served = 0

    def respond(_request: httpx.Request) -> httpx.Response:
        # Answer 500 until the configured number of failures has been served, then 200.
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=respond)

    raw = await retrying_client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "system",
            }
        ],
        model="gpt-4o",
    )

    assert raw.retries_taken == failures_before_success
1585
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken_new_response_class(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """`retries_taken` on a streaming response equals the number of failed attempts before success."""
    retrying_client = async_client.with_options(max_retries=4)

    failures_served = 0

    def respond(_request: httpx.Request) -> httpx.Response:
        # Answer 500 until the configured number of failures has been served, then 200.
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=respond)

    async with retrying_client.chat.completions.with_streaming_response.create(
        messages=[
            {
                "content": "string",
                "role": "system",
            }
        ],
        model="gpt-4o",
    ) as streamed:
        assert streamed.retries_taken == failures_before_success
1616