openai/openai-python

Public

mirrored from https://github.com/openai/openai-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.47.0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1631lines · modecode

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._types import Omit
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
29base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
30api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters the client attaches to a plain ``GET /foo`` request."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the length of the request list held by the pool of the client's httpx transport."""
    transport = client._client._transport
    # Only the concrete (sync or async) HTTP transports expose the private `_pool` attribute read below.
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
49
50
51class TestOpenAI:
52 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
53
54 @pytest.mark.respx(base_url=base_url)
55 def test_raw_response(self, respx_mock: MockRouter) -> None:
56 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
57
58 response = self.client.post("/foo", cast_to=httpx.Response)
59 assert response.status_code == 200
60 assert isinstance(response, httpx.Response)
61 assert response.json() == {"foo": "bar"}
62
63 @pytest.mark.respx(base_url=base_url)
64 def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
65 respx_mock.post("/foo").mock(
66 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
67 )
68
69 response = self.client.post("/foo", cast_to=httpx.Response)
70 assert response.status_code == 200
71 assert isinstance(response, httpx.Response)
72 assert response.json() == {"foo": "bar"}
73
74 def test_copy(self) -> None:
75 copied = self.client.copy()
76 assert id(copied) != id(self.client)
77
78 copied = self.client.copy(api_key="another My API Key")
79 assert copied.api_key == "another My API Key"
80 assert self.client.api_key == "My API Key"
81
82 def test_copy_default_options(self) -> None:
83 # options that have a default are overridden correctly
84 copied = self.client.copy(max_retries=7)
85 assert copied.max_retries == 7
86 assert self.client.max_retries == 2
87
88 copied2 = copied.copy(max_retries=6)
89 assert copied2.max_retries == 6
90 assert copied.max_retries == 7
91
92 # timeout
93 assert isinstance(self.client.timeout, httpx.Timeout)
94 copied = self.client.copy(timeout=None)
95 assert copied.timeout is None
96 assert isinstance(self.client.timeout, httpx.Timeout)
97
98 def test_copy_default_headers(self) -> None:
99 client = OpenAI(
100 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
101 )
102 assert client.default_headers["X-Foo"] == "bar"
103
104 # does not override the already given value when not specified
105 copied = client.copy()
106 assert copied.default_headers["X-Foo"] == "bar"
107
108 # merges already given headers
109 copied = client.copy(default_headers={"X-Bar": "stainless"})
110 assert copied.default_headers["X-Foo"] == "bar"
111 assert copied.default_headers["X-Bar"] == "stainless"
112
113 # uses new values for any already given headers
114 copied = client.copy(default_headers={"X-Foo": "stainless"})
115 assert copied.default_headers["X-Foo"] == "stainless"
116
117 # set_default_headers
118
119 # completely overrides already set values
120 copied = client.copy(set_default_headers={})
121 assert copied.default_headers.get("X-Foo") is None
122
123 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
124 assert copied.default_headers["X-Bar"] == "Robert"
125
126 with pytest.raises(
127 ValueError,
128 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
129 ):
130 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
131
132 def test_copy_default_query(self) -> None:
133 client = OpenAI(
134 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
135 )
136 assert _get_params(client)["foo"] == "bar"
137
138 # does not override the already given value when not specified
139 copied = client.copy()
140 assert _get_params(copied)["foo"] == "bar"
141
142 # merges already given params
143 copied = client.copy(default_query={"bar": "stainless"})
144 params = _get_params(copied)
145 assert params["foo"] == "bar"
146 assert params["bar"] == "stainless"
147
148 # uses new values for any already given headers
149 copied = client.copy(default_query={"foo": "stainless"})
150 assert _get_params(copied)["foo"] == "stainless"
151
152 # set_default_query
153
154 # completely overrides already set values
155 copied = client.copy(set_default_query={})
156 assert _get_params(copied) == {}
157
158 copied = client.copy(set_default_query={"bar": "Robert"})
159 assert _get_params(copied)["bar"] == "Robert"
160
161 with pytest.raises(
162 ValueError,
163 # TODO: update
164 match="`default_query` and `set_default_query` arguments are mutually exclusive",
165 ):
166 client.copy(set_default_query={}, default_query={"foo": "Bar"})
167
168 def test_copy_signature(self) -> None:
169 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
170 init_signature = inspect.signature(
171 # mypy doesn't like that we access the `__init__` property.
172 self.client.__init__, # type: ignore[misc]
173 )
174 copy_signature = inspect.signature(self.client.copy)
175 exclude_params = {"transport", "proxies", "_strict_response_validation"}
176
177 for name in init_signature.parameters.keys():
178 if name in exclude_params:
179 continue
180
181 copy_param = copy_signature.parameters.get(name)
182 assert copy_param is not None, f"copy() signature is missing the {name} param"
183
184 def test_copy_build_request(self) -> None:
185 options = FinalRequestOptions(method="get", url="/foo")
186
187 def build_request(options: FinalRequestOptions) -> None:
188 client = self.client.copy()
189 client._build_request(options)
190
191 # ensure that the machinery is warmed up before tracing starts.
192 build_request(options)
193 gc.collect()
194
195 tracemalloc.start(1000)
196
197 snapshot_before = tracemalloc.take_snapshot()
198
199 ITERATIONS = 10
200 for _ in range(ITERATIONS):
201 build_request(options)
202
203 gc.collect()
204 snapshot_after = tracemalloc.take_snapshot()
205
206 tracemalloc.stop()
207
208 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
209 if diff.count == 0:
210 # Avoid false positives by considering only leaks (i.e. allocations that persist).
211 return
212
213 if diff.count % ITERATIONS != 0:
214 # Avoid false positives by considering only leaks that appear per iteration.
215 return
216
217 for frame in diff.traceback:
218 if any(
219 frame.filename.endswith(fragment)
220 for fragment in [
221 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
222 #
223 # removing the decorator fixes the leak for reasons we don't understand.
224 "openai/_legacy_response.py",
225 "openai/_response.py",
226 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
227 "openai/_compat.py",
228 # Standard library leaks we don't care about.
229 "/logging/__init__.py",
230 ]
231 ):
232 return
233
234 leaks.append(diff)
235
236 leaks: list[tracemalloc.StatisticDiff] = []
237 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
238 add_leak(leaks, diff)
239 if leaks:
240 for leak in leaks:
241 print("MEMORY LEAK:", leak)
242 for frame in leak.traceback:
243 print(frame)
244 raise AssertionError()
245
246 def test_request_timeout(self) -> None:
247 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
248 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
249 assert timeout == DEFAULT_TIMEOUT
250
251 request = self.client._build_request(
252 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
253 )
254 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
255 assert timeout == httpx.Timeout(100.0)
256
257 def test_client_timeout_option(self) -> None:
258 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
259
260 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
261 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
262 assert timeout == httpx.Timeout(0)
263
264 def test_http_client_timeout_option(self) -> None:
265 # custom timeout given to the httpx client should be used
266 with httpx.Client(timeout=None) as http_client:
267 client = OpenAI(
268 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
269 )
270
271 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
272 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
273 assert timeout == httpx.Timeout(None)
274
275 # no timeout given to the httpx client should not use the httpx default
276 with httpx.Client() as http_client:
277 client = OpenAI(
278 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
279 )
280
281 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
282 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
283 assert timeout == DEFAULT_TIMEOUT
284
285 # explicitly passing the default timeout currently results in it being ignored
286 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
287 client = OpenAI(
288 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
289 )
290
291 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
292 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
293 assert timeout == DEFAULT_TIMEOUT # our default
294
295 async def test_invalid_http_client(self) -> None:
296 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
297 async with httpx.AsyncClient() as http_client:
298 OpenAI(
299 base_url=base_url,
300 api_key=api_key,
301 _strict_response_validation=True,
302 http_client=cast(Any, http_client),
303 )
304
305 def test_default_headers_option(self) -> None:
306 client = OpenAI(
307 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
308 )
309 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
310 assert request.headers.get("x-foo") == "bar"
311 assert request.headers.get("x-stainless-lang") == "python"
312
313 client2 = OpenAI(
314 base_url=base_url,
315 api_key=api_key,
316 _strict_response_validation=True,
317 default_headers={
318 "X-Foo": "stainless",
319 "X-Stainless-Lang": "my-overriding-header",
320 },
321 )
322 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
323 assert request.headers.get("x-foo") == "stainless"
324 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
325
326 def test_validate_headers(self) -> None:
327 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
328 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
329 assert request.headers.get("Authorization") == f"Bearer {api_key}"
330
331 with pytest.raises(OpenAIError):
332 with update_env(**{"OPENAI_API_KEY": Omit()}):
333 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
334 _ = client2
335
336 def test_default_query_option(self) -> None:
337 client = OpenAI(
338 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
339 )
340 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
341 url = httpx.URL(request.url)
342 assert dict(url.params) == {"query_param": "bar"}
343
344 request = client._build_request(
345 FinalRequestOptions(
346 method="get",
347 url="/foo",
348 params={"foo": "baz", "query_param": "overriden"},
349 )
350 )
351 url = httpx.URL(request.url)
352 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
353
354 def test_request_extra_json(self) -> None:
355 request = self.client._build_request(
356 FinalRequestOptions(
357 method="post",
358 url="/foo",
359 json_data={"foo": "bar"},
360 extra_json={"baz": False},
361 ),
362 )
363 data = json.loads(request.content.decode("utf-8"))
364 assert data == {"foo": "bar", "baz": False}
365
366 request = self.client._build_request(
367 FinalRequestOptions(
368 method="post",
369 url="/foo",
370 extra_json={"baz": False},
371 ),
372 )
373 data = json.loads(request.content.decode("utf-8"))
374 assert data == {"baz": False}
375
376 # `extra_json` takes priority over `json_data` when keys clash
377 request = self.client._build_request(
378 FinalRequestOptions(
379 method="post",
380 url="/foo",
381 json_data={"foo": "bar", "baz": True},
382 extra_json={"baz": None},
383 ),
384 )
385 data = json.loads(request.content.decode("utf-8"))
386 assert data == {"foo": "bar", "baz": None}
387
388 def test_request_extra_headers(self) -> None:
389 request = self.client._build_request(
390 FinalRequestOptions(
391 method="post",
392 url="/foo",
393 **make_request_options(extra_headers={"X-Foo": "Foo"}),
394 ),
395 )
396 assert request.headers.get("X-Foo") == "Foo"
397
398 # `extra_headers` takes priority over `default_headers` when keys clash
399 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
400 FinalRequestOptions(
401 method="post",
402 url="/foo",
403 **make_request_options(
404 extra_headers={"X-Bar": "false"},
405 ),
406 ),
407 )
408 assert request.headers.get("X-Bar") == "false"
409
410 def test_request_extra_query(self) -> None:
411 request = self.client._build_request(
412 FinalRequestOptions(
413 method="post",
414 url="/foo",
415 **make_request_options(
416 extra_query={"my_query_param": "Foo"},
417 ),
418 ),
419 )
420 params = dict(request.url.params)
421 assert params == {"my_query_param": "Foo"}
422
423 # if both `query` and `extra_query` are given, they are merged
424 request = self.client._build_request(
425 FinalRequestOptions(
426 method="post",
427 url="/foo",
428 **make_request_options(
429 query={"bar": "1"},
430 extra_query={"foo": "2"},
431 ),
432 ),
433 )
434 params = dict(request.url.params)
435 assert params == {"bar": "1", "foo": "2"}
436
437 # `extra_query` takes priority over `query` when keys clash
438 request = self.client._build_request(
439 FinalRequestOptions(
440 method="post",
441 url="/foo",
442 **make_request_options(
443 query={"foo": "1"},
444 extra_query={"foo": "2"},
445 ),
446 ),
447 )
448 params = dict(request.url.params)
449 assert params == {"foo": "2"}
450
451 def test_multipart_repeating_array(self, client: OpenAI) -> None:
452 request = client._build_request(
453 FinalRequestOptions.construct(
454 method="get",
455 url="/foo",
456 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
457 json_data={"array": ["foo", "bar"]},
458 files=[("foo.txt", b"hello world")],
459 )
460 )
461
462 assert request.read().split(b"\r\n") == [
463 b"--6b7ba517decee4a450543ea6ae821c82",
464 b'Content-Disposition: form-data; name="array[]"',
465 b"",
466 b"foo",
467 b"--6b7ba517decee4a450543ea6ae821c82",
468 b'Content-Disposition: form-data; name="array[]"',
469 b"",
470 b"bar",
471 b"--6b7ba517decee4a450543ea6ae821c82",
472 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
473 b"Content-Type: application/octet-stream",
474 b"",
475 b"hello world",
476 b"--6b7ba517decee4a450543ea6ae821c82--",
477 b"",
478 ]
479
480 @pytest.mark.respx(base_url=base_url)
481 def test_basic_union_response(self, respx_mock: MockRouter) -> None:
482 class Model1(BaseModel):
483 name: str
484
485 class Model2(BaseModel):
486 foo: str
487
488 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
489
490 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
491 assert isinstance(response, Model2)
492 assert response.foo == "bar"
493
494 @pytest.mark.respx(base_url=base_url)
495 def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
496 """Union of objects with the same field name using a different type"""
497
498 class Model1(BaseModel):
499 foo: int
500
501 class Model2(BaseModel):
502 foo: str
503
504 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
505
506 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
507 assert isinstance(response, Model2)
508 assert response.foo == "bar"
509
510 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
511
512 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
513 assert isinstance(response, Model1)
514 assert response.foo == 1
515
516 @pytest.mark.respx(base_url=base_url)
517 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
518 """
519 Response that sets Content-Type to something other than application/json but returns json data
520 """
521
522 class Model(BaseModel):
523 foo: int
524
525 respx_mock.get("/foo").mock(
526 return_value=httpx.Response(
527 200,
528 content=json.dumps({"foo": 2}),
529 headers={"Content-Type": "application/text"},
530 )
531 )
532
533 response = self.client.get("/foo", cast_to=Model)
534 assert isinstance(response, Model)
535 assert response.foo == 2
536
537 def test_base_url_setter(self) -> None:
538 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
539 assert client.base_url == "https://example.com/from_init/"
540
541 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
542
543 assert client.base_url == "https://example.com/from_setter/"
544
545 def test_base_url_env(self) -> None:
546 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
547 client = OpenAI(api_key=api_key, _strict_response_validation=True)
548 assert client.base_url == "http://localhost:5000/from/env/"
549
550 @pytest.mark.parametrize(
551 "client",
552 [
553 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
554 OpenAI(
555 base_url="http://localhost:5000/custom/path/",
556 api_key=api_key,
557 _strict_response_validation=True,
558 http_client=httpx.Client(),
559 ),
560 ],
561 ids=["standard", "custom http client"],
562 )
563 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
564 request = client._build_request(
565 FinalRequestOptions(
566 method="post",
567 url="/foo",
568 json_data={"foo": "bar"},
569 ),
570 )
571 assert request.url == "http://localhost:5000/custom/path/foo"
572
573 @pytest.mark.parametrize(
574 "client",
575 [
576 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
577 OpenAI(
578 base_url="http://localhost:5000/custom/path/",
579 api_key=api_key,
580 _strict_response_validation=True,
581 http_client=httpx.Client(),
582 ),
583 ],
584 ids=["standard", "custom http client"],
585 )
586 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
587 request = client._build_request(
588 FinalRequestOptions(
589 method="post",
590 url="/foo",
591 json_data={"foo": "bar"},
592 ),
593 )
594 assert request.url == "http://localhost:5000/custom/path/foo"
595
596 @pytest.mark.parametrize(
597 "client",
598 [
599 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
600 OpenAI(
601 base_url="http://localhost:5000/custom/path/",
602 api_key=api_key,
603 _strict_response_validation=True,
604 http_client=httpx.Client(),
605 ),
606 ],
607 ids=["standard", "custom http client"],
608 )
609 def test_absolute_request_url(self, client: OpenAI) -> None:
610 request = client._build_request(
611 FinalRequestOptions(
612 method="post",
613 url="https://myapi.com/foo",
614 json_data={"foo": "bar"},
615 ),
616 )
617 assert request.url == "https://myapi.com/foo"
618
619 def test_copied_client_does_not_close_http(self) -> None:
620 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
621 assert not client.is_closed()
622
623 copied = client.copy()
624 assert copied is not client
625
626 del copied
627
628 assert not client.is_closed()
629
630 def test_client_context_manager(self) -> None:
631 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
632 with client as c2:
633 assert c2 is client
634 assert not c2.is_closed()
635 assert not client.is_closed()
636 assert client.is_closed()
637
638 @pytest.mark.respx(base_url=base_url)
639 def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
640 class Model(BaseModel):
641 foo: str
642
643 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
644
645 with pytest.raises(APIResponseValidationError) as exc:
646 self.client.get("/foo", cast_to=Model)
647
648 assert isinstance(exc.value.__cause__, ValidationError)
649
650 def test_client_max_retries_validation(self) -> None:
651 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
652 OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
653
654 @pytest.mark.respx(base_url=base_url)
655 def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
656 class Model(BaseModel):
657 name: str
658
659 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
660
661 stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
662 assert isinstance(stream, Stream)
663 stream.response.close()
664
665 @pytest.mark.respx(base_url=base_url)
666 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
667 class Model(BaseModel):
668 name: str
669
670 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
671
672 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
673
674 with pytest.raises(APIResponseValidationError):
675 strict_client.get("/foo", cast_to=Model)
676
677 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
678
679 response = client.get("/foo", cast_to=Model)
680 assert isinstance(response, str) # type: ignore[unreachable]
681
682 @pytest.mark.parametrize(
683 "remaining_retries,retry_after,timeout",
684 [
685 [3, "20", 20],
686 [3, "0", 0.5],
687 [3, "-10", 0.5],
688 [3, "60", 60],
689 [3, "61", 0.5],
690 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
691 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
692 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
693 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
694 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
695 [3, "99999999999999999999999999999999999", 0.5],
696 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
697 [3, "", 0.5],
698 [2, "", 0.5 * 2.0],
699 [1, "", 0.5 * 4.0],
700 ],
701 )
702 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
703 def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
704 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
705
706 headers = httpx.Headers({"retry-after": retry_after})
707 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
708 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
709 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
710
711 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
712 @pytest.mark.respx(base_url=base_url)
713 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
714 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
715
716 with pytest.raises(APITimeoutError):
717 self.client.post(
718 "/chat/completions",
719 body=cast(
720 object,
721 dict(
722 messages=[
723 {
724 "role": "user",
725 "content": "Say this is a test",
726 }
727 ],
728 model="gpt-3.5-turbo",
729 ),
730 ),
731 cast_to=httpx.Response,
732 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
733 )
734
735 assert _get_open_connections(self.client) == 0
736
737 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
738 @pytest.mark.respx(base_url=base_url)
739 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
740 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
741
742 with pytest.raises(APIStatusError):
743 self.client.post(
744 "/chat/completions",
745 body=cast(
746 object,
747 dict(
748 messages=[
749 {
750 "role": "user",
751 "content": "Say this is a test",
752 }
753 ],
754 model="gpt-3.5-turbo",
755 ),
756 ),
757 cast_to=httpx.Response,
758 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
759 )
760
761 assert _get_open_connections(self.client) == 0
762
763 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
764 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
765 @pytest.mark.respx(base_url=base_url)
766 def test_retries_taken(self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter) -> None:
767 client = client.with_options(max_retries=4)
768
769 nb_retries = 0
770
771 def retry_handler(_request: httpx.Request) -> httpx.Response:
772 nonlocal nb_retries
773 if nb_retries < failures_before_success:
774 nb_retries += 1
775 return httpx.Response(500)
776 return httpx.Response(200)
777
778 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
779
780 response = client.chat.completions.with_raw_response.create(
781 messages=[
782 {
783 "content": "string",
784 "role": "system",
785 }
786 ],
787 model="gpt-4o",
788 )
789
790 assert response.retries_taken == failures_before_success
791 if failures_before_success == 0:
792 assert "x-stainless-retry-count" not in response.http_request.headers
793 else:
794 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
795
796 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
797 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
798 @pytest.mark.respx(base_url=base_url)
799 def test_retries_taken_new_response_class(
800 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
801 ) -> None:
802 client = client.with_options(max_retries=4)
803
804 nb_retries = 0
805
806 def retry_handler(_request: httpx.Request) -> httpx.Response:
807 nonlocal nb_retries
808 if nb_retries < failures_before_success:
809 nb_retries += 1
810 return httpx.Response(500)
811 return httpx.Response(200)
812
813 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
814
815 with client.chat.completions.with_streaming_response.create(
816 messages=[
817 {
818 "content": "string",
819 "role": "system",
820 }
821 ],
822 model="gpt-4o",
823 ) as response:
824 assert response.retries_taken == failures_before_success
825 if failures_before_success == 0:
826 assert "x-stainless-retry-count" not in response.http_request.headers
827 else:
828 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
829
830
831class TestAsyncOpenAI:
832 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
833
834 @pytest.mark.respx(base_url=base_url)
835 @pytest.mark.asyncio
836 async def test_raw_response(self, respx_mock: MockRouter) -> None:
837 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
838
839 response = await self.client.post("/foo", cast_to=httpx.Response)
840 assert response.status_code == 200
841 assert isinstance(response, httpx.Response)
842 assert response.json() == {"foo": "bar"}
843
844 @pytest.mark.respx(base_url=base_url)
845 @pytest.mark.asyncio
846 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
847 respx_mock.post("/foo").mock(
848 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
849 )
850
851 response = await self.client.post("/foo", cast_to=httpx.Response)
852 assert response.status_code == 200
853 assert isinstance(response, httpx.Response)
854 assert response.json() == {"foo": "bar"}
855
856 def test_copy(self) -> None:
857 copied = self.client.copy()
858 assert id(copied) != id(self.client)
859
860 copied = self.client.copy(api_key="another My API Key")
861 assert copied.api_key == "another My API Key"
862 assert self.client.api_key == "My API Key"
863
864 def test_copy_default_options(self) -> None:
865 # options that have a default are overridden correctly
866 copied = self.client.copy(max_retries=7)
867 assert copied.max_retries == 7
868 assert self.client.max_retries == 2
869
870 copied2 = copied.copy(max_retries=6)
871 assert copied2.max_retries == 6
872 assert copied.max_retries == 7
873
874 # timeout
875 assert isinstance(self.client.timeout, httpx.Timeout)
876 copied = self.client.copy(timeout=None)
877 assert copied.timeout is None
878 assert isinstance(self.client.timeout, httpx.Timeout)
879
880 def test_copy_default_headers(self) -> None:
881 client = AsyncOpenAI(
882 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
883 )
884 assert client.default_headers["X-Foo"] == "bar"
885
886 # does not override the already given value when not specified
887 copied = client.copy()
888 assert copied.default_headers["X-Foo"] == "bar"
889
890 # merges already given headers
891 copied = client.copy(default_headers={"X-Bar": "stainless"})
892 assert copied.default_headers["X-Foo"] == "bar"
893 assert copied.default_headers["X-Bar"] == "stainless"
894
895 # uses new values for any already given headers
896 copied = client.copy(default_headers={"X-Foo": "stainless"})
897 assert copied.default_headers["X-Foo"] == "stainless"
898
899 # set_default_headers
900
901 # completely overrides already set values
902 copied = client.copy(set_default_headers={})
903 assert copied.default_headers.get("X-Foo") is None
904
905 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
906 assert copied.default_headers["X-Bar"] == "Robert"
907
908 with pytest.raises(
909 ValueError,
910 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
911 ):
912 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
913
def test_copy_default_query(self) -> None:
    """`default_query` is merged into the copied client's query params; `set_default_query` replaces them."""
    original = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
    )
    assert _get_params(original)["foo"] == "bar"

    # a plain copy keeps the param that was already configured
    assert _get_params(original.copy())["foo"] == "bar"

    # `default_query` adds new params next to the existing ones
    merged = _get_params(original.copy(default_query={"bar": "stainless"}))
    assert merged["foo"] == "bar"
    assert merged["bar"] == "stainless"

    # ... and the new value wins when the same param is given again
    overridden = original.copy(default_query={"foo": "stainless"})
    assert _get_params(overridden)["foo"] == "stainless"

    # `set_default_query` drops every previously configured param
    cleared = original.copy(set_default_query={})
    assert _get_params(cleared) == {}

    replaced = original.copy(set_default_query={"bar": "Robert"})
    assert _get_params(replaced)["bar"] == "Robert"

    # the two arguments cannot be combined
    with pytest.raises(
        ValueError,
        match="`default_query` and `set_default_query` arguments are mutually exclusive",
    ):
        original.copy(set_default_query={}, default_query={"foo": "Bar"})
949
def test_copy_signature(self) -> None:
    """Every `__init__` parameter, apart from a small exclusion list, must also be accepted by `.copy()`."""
    # mypy doesn't like that we access the `__init__` property.
    init_params = inspect.signature(self.client.__init__).parameters  # type: ignore[misc]
    copy_params = inspect.signature(self.client.copy).parameters
    excluded = {"transport", "proxies", "_strict_response_validation"}

    for name in init_params:
        if name not in excluded:
            assert name in copy_params, f"copy() signature is missing the {name} param"
965
def test_copy_build_request(self) -> None:
    """Copying the client and building a request repeatedly must not leak memory.

    Allocations are traced with `tracemalloc` across a fixed number of
    copy-and-build iterations. Any allocation that persists, occurs a
    multiple-of-iterations number of times and does not originate from a
    known, ignored source file is reported as a leak and fails the test.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    ITERATIONS = 10

    tracemalloc.start(1000)
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        # always switch tracing off again, even when building a request raises,
        # so that it does not stay enabled after this test.
        tracemalloc.stop()

    # source files whose allocations are deliberately not treated as leaks
    ignored_suffixes = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        # `str.endswith` accepts a tuple, so one call checks every ignored suffix
        if any(frame.filename.endswith(ignored_suffixes) for frame in diff.traceback):
            return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)

    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} memory leak(s) detected while copying the client; see output above")
1027
async def test_request_timeout(self) -> None:
    """A built request uses `DEFAULT_TIMEOUT` unless the request options supply their own timeout."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    default_timeout = httpx.Timeout(**default_request.extensions["timeout"])  # type: ignore
    assert default_timeout == DEFAULT_TIMEOUT

    # a per-request timeout takes effect on the built request
    custom_request = self.client._build_request(
        FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    )
    custom_timeout = httpx.Timeout(**custom_request.extensions["timeout"])  # type: ignore
    assert custom_timeout == httpx.Timeout(100.0)
1038
async def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor is applied to the requests it builds."""
    zero_timeout_client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert applied == httpx.Timeout(0)
1047
async def test_http_client_timeout_option(self) -> None:
    """Check which timeout ends up on a request when a custom httpx client is supplied."""

    def effective_timeout(http_client: httpx.AsyncClient) -> httpx.Timeout:
        # build a request through an SDK client wrapping `http_client` and read back its timeout
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        assert effective_timeout(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        assert effective_timeout(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert effective_timeout(http_client) == DEFAULT_TIMEOUT  # our default
1078
def test_invalid_http_client(self) -> None:
    """Passing a synchronous `httpx.Client` to the async SDK client raises `TypeError`."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http_client:
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            # cast hides the deliberate type mismatch from the type checker
            http_client=cast(Any, sync_http_client),
        )
1088
def test_default_headers_option(self) -> None:
    """`default_headers` are sent on built requests and can override the `x-stainless-lang` header."""
    request_options = FinalRequestOptions(method="get", url="/foo")

    with_custom_header = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    headers = with_custom_header._build_request(request_options).headers
    assert headers.get("x-foo") == "bar"
    assert headers.get("x-stainless-lang") == "python"

    # a user-supplied header with the same name replaces the SDK-provided value
    with_override = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    headers = with_override._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert headers.get("x-foo") == "stainless"
    assert headers.get("x-stainless-lang") == "my-overriding-header"
1109
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and constructing a client without any key raises `OpenAIError`."""
    keyed_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    built = keyed_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert built.headers.get("Authorization") == f"Bearer {api_key}"

    # with `OPENAI_API_KEY` omitted from the environment and no explicit key, construction must fail
    with pytest.raises(OpenAIError), update_env(**{"OPENAI_API_KEY": Omit()}):
        _ = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1119
def test_default_query_option(self) -> None:
    """`default_query` params appear on built requests, and per-request params override them."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    plain = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(plain.url).params) == {"query_param": "bar"}

    # params given on the request are added, and replace a default with the same name
    with_params = client._build_request(
        FinalRequestOptions(
            method="get",
            url="/foo",
            params={"foo": "baz", "query_param": "overriden"},
        )
    )
    assert dict(httpx.URL(with_params.url).params) == {"foo": "baz", "query_param": "overriden"}
1137
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the request body and wins over `json_data` on key clashes."""

    def sent_body(**option_kwargs: Any) -> Any:
        # build a POST request with the given options and decode the JSON body it would send
        request = self.client._build_request(FinalRequestOptions(method="post", url="/foo", **option_kwargs))
        return json.loads(request.content.decode("utf-8"))

    # both sources are combined
    assert sent_body(json_data={"foo": "bar"}, extra_json={"baz": False}) == {"foo": "bar", "baz": False}

    # `extra_json` alone is enough to produce a body
    assert sent_body(extra_json={"baz": False}) == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    clashing = sent_body(json_data={"foo": "bar", "baz": True}, extra_json={"baz": None})
    assert clashing == {"foo": "bar", "baz": None}
1171
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent on the request and win over `default_headers` on key clashes."""
    simple = self.client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(extra_headers={"X-Foo": "Foo"}),
        ),
    )
    assert simple.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_default = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing = client_with_default._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                extra_headers={"X-Bar": "false"},
            ),
        ),
    )
    assert clashing.headers.get("X-Bar") == "false"
1193
def test_request_extra_query(self) -> None:
    """`extra_query` is merged with `query` and wins when both define the same param."""

    def built_params(**request_kwargs: Any) -> dict[str, str]:
        # build a POST request from the given request options and return its query params
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(**request_kwargs),
            ),
        )
        return dict(request.url.params)

    assert built_params(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert built_params(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert built_params(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1234
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """Array values in a multipart body are encoded as repeated `name[]` form fields."""
    multipart_options = FinalRequestOptions.construct(
        method="get",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    request = async_client._build_request(multipart_options)

    # one part per array element, followed by the uploaded file and the closing boundary
    expected_lines = [
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"foo",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"bar",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]
    assert request.read().split(b"\r\n") == expected_lines
1263
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A response cast to a union of models is parsed as the member whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    either_model = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1277
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union of objects with the same field name using a different type"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    either_model = cast(Any, Union[Model1, Model2])

    # a string payload selects the model whose field is a `str`
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"

    # an integer payload selects the model whose field is an `int`
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model1)
    assert parsed.foo == 1
1299
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    Response that sets Content-Type to something other than application/json but returns json data
    """

    class Model(BaseModel):
        foo: int

    mislabelled_response = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled_response)

    # the body is still parsed into the requested model
    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1320
def test_base_url_setter(self) -> None:
    """`base_url` gains a trailing slash both when given to the constructor and when assigned later."""
    client = AsyncOpenAI(
        base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
    )
    assert client.base_url == "https://example.com/from_init/"

    # assigning a plain string through the setter is normalised the same way
    client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert client.base_url == "https://example.com/from_setter/"
1330
def test_base_url_env(self) -> None:
    """When no `base_url` argument is given, the client takes it from `OPENAI_BASE_URL`."""
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        # the value read from the environment also gets a trailing slash
        assert env_client.base_url == "http://localhost:5000/from/env/"
1335
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL ending in `/` is joined with the request path without doubling the slash."""
    post_options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(post_options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1360
@pytest.mark.parametrize(
    "client",
    [
        # NOTE: the base URL deliberately has NO trailing slash; that is the case under test.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL without a trailing `/` is still joined correctly with the request path.

    The base URL's path must be kept in full, with exactly one slash between
    it and the request path.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1385
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined onto the client's base URL."""
    absolute_options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(absolute_options)
    assert built.url == "https://myapi.com/foo"
1410
async def test_copied_client_does_not_close_http(self) -> None:
    """Dropping a copied client must leave the original client open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    # drop the only reference to the copy
    del duplicate

    # wait briefly before re-checking the original
    await asyncio.sleep(0.2)
    assert not original.is_closed()
1422
async def test_client_context_manager(self) -> None:
    """`async with client` yields the same client, which is open inside the block and closed after it."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()
    # leaving the block closes the client
    assert client.is_closed()
1430
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit the target model raises `APIResponseValidationError` caused by pydantic."""

    class Model(BaseModel):
        foo: str

    # `foo` is an object here although the model declares it as a string
    invalid_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=invalid_payload))

    with pytest.raises(APIResponseValidationError) as exc_info:
        await self.client.get("/foo", cast_to=Model)

    # the underlying pydantic error is preserved as the cause
    assert isinstance(exc_info.value.__cause__, ValidationError)
1443
async def test_client_max_retries_validation(self) -> None:
    """Constructing a client with `max_retries=None` raises `TypeError`."""
    # cast hides the deliberately invalid value from the type checker
    invalid_max_retries = cast(Any, None)
    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=invalid_max_retries
        )
1449
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """Requesting `stream=True` with an `AsyncStream` stream class returns an `AsyncStream` instance."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    result = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(result, AsyncStream)
    # close the underlying response now that the check is done
    await result.response.aclose()
1461
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A plain-text body for a model-typed request fails under strict validation and is returned as `str` otherwise."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # strict validation: the mismatch is an error
    strict = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict.get("/foo", cast_to=Model)

    # non-strict validation: the raw text comes back instead
    lenient = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    raw = await lenient.get("/foo", cast_to=Model)
    assert isinstance(raw, str)  # type: ignore[unreachable]
1479
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay computed from a `retry-after` header given as seconds or as an HTTP date.

    `time.time` is patched to 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT,
    which is the reference point for the date-based cases.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    response_headers = httpx.Headers({"retry-after": retry_after})
    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)

    delay = client._calculate_retry_timeout(remaining_retries, request_options, response_headers)
    # compared approximately rather than exactly
    assert delay == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1509
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After a request fails with timeouts, the client must hold no open connections."""
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    payload = cast(
        object,
        {
            "messages": [
                {
                    "role": "user",
                    "content": "Say this is a test",
                }
            ],
            "model": "gpt-3.5-turbo",
        },
    )

    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1535
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """After a request fails with HTTP 500 responses, the client must hold no open connections."""
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    payload = cast(
        object,
        {
            "messages": [
                {
                    "role": "user",
                    "content": "Say this is a test",
                }
            ],
            "model": "gpt-3.5-turbo",
        },
    )

    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1561
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """The raw response reports how many retries were needed, and the request carries the retry-count header."""
    client = async_client.with_options(max_retries=4)

    failures_served = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        # answer with 500 until the configured number of failures has been served, then 200
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    response = await client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "system",
            }
        ],
        model="gpt-4o",
    )

    assert response.retries_taken == failures_before_success

    request_headers = response.http_request.headers
    if failures_before_success == 0:
        # the header is absent when the first attempt succeeded
        assert "x-stainless-retry-count" not in request_headers
    else:
        assert int(request_headers.get("x-stainless-retry-count")) == failures_before_success
1597
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_retries_taken_new_response_class(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """The streaming response reports how many retries were needed, and the request carries the retry-count header."""
    client = async_client.with_options(max_retries=4)

    failures_served = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        # answer with 500 until the configured number of failures has been served, then 200
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    async with client.chat.completions.with_streaming_response.create(
        messages=[
            {
                "content": "string",
                "role": "system",
            }
        ],
        model="gpt-4o",
    ) as response:
        assert response.retries_taken == failures_before_success

        request_headers = response.http_request.headers
        if failures_before_success == 0:
            # the header is absent when the first attempt succeeded
            assert "x-stainless-retry-count" not in request_headers
        else:
            assert int(request_headers.get("x-stainless-retry-count")) == failures_before_success