openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.104.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1889 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import time
10import asyncio
11import inspect
12import subprocess
13import tracemalloc
14from typing import Any, Union, cast
15from textwrap import dedent
16from unittest import mock
17from typing_extensions import Literal
18
19import httpx
20import pytest
21from respx import MockRouter
22from pydantic import ValidationError
23
24from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
25from openai._types import Omit
26from openai._models import BaseModel, FinalRequestOptions
27from openai._streaming import Stream, AsyncStream
28from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
29from openai._base_client import (
30 DEFAULT_TIMEOUT,
31 HTTPX_DEFAULT_TIMEOUT,
32 BaseClient,
33 DefaultHttpxClient,
34 DefaultAsyncHttpxClient,
35 make_request_options,
36)
37
38from .utils import update_env
39
# Root URL every test client talks to; can be overridden through the
# TEST_API_BASE_URL environment variable, otherwise a local address is used.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential passed to every client constructed in this module.
api_key = "My API Key"
42
43
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query string.

    The query parameters are returned as a plain ``dict`` so tests can compare
    them directly against the client's configured ``default_query``.
    """
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
48
49
50def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
51 return 0.1
52
53
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the length of the request list held by *client*'s httpx connection pool.

    The client must be backed by a default ``httpx.HTTPTransport`` or
    ``httpx.AsyncHTTPTransport``; any other transport fails the ``assert``.
    Tests use this after a failed request to check that nothing is left open.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
60
61
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    ``client`` below is shared by the tests that use ``self.client``. Tests that
    need different construction options build their own client, and several
    take the ``client`` pytest fixture instead.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    # NOTE(review): the condition skips on Python >= 3.10 while the reason text
    # mentions 3.12 -- confirm which version boundary is actually intended.
    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self) -> None:
        """Copying the client and building a request must not retain memory per iteration."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Always stop tracing, even if building a request raised, so later
            # tests don't keep running under tracemalloc.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    async def test_invalid_http_client(self) -> None:
        # The sync client must reject an async httpx client.
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # With no explicit key and OPENAI_API_KEY removed from the environment,
        # construction must fail.
        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        # A fixed boundary makes the encoded multipart body byte-for-byte predictable.
        request = client._build_request(
            FinalRequestOptions.construct(
                method="post",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    # The base URL here deliberately has NO trailing slash; otherwise this test
    # would just duplicate `test_base_url_trailing_slash`.
    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # `time.time` is frozen at 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, so the
    # HTTP-date cases below are offsets from that instant (+20s, 0s, -10s, +60s, +61s).
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # Wide relative tolerance, presumably to absorb randomized jitter in the delay.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()

        # Inspect the client that actually issued the requests; `self.client`
        # never sent anything here, so checking it could never detect a leak.
        assert _get_open_connections(client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()

        # Inspect the client that actually issued the requests (see the timeout test).
        assert _get_open_connections(client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Test that the proxy environment variables are set correctly
        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

        client = DefaultHttpxClient()

        mounts = tuple(client._mounts.items())
        assert len(mounts) == 1
        assert mounts[0][0].pattern == "https://"

    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
    def test_default_client_creation(self) -> None:
        # Ensure that the client can be initialized without any exceptions
        DefaultHttpxClient(
            verify=True,
            cert=None,
            trust_env=True,
            http1=True,
            http2=False,
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects(self, respx_mock: MockRouter) -> None:
        # Test that the default follow_redirects=True allows following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )
        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

        response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
        # Test that follow_redirects=False prevents following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )

        with pytest.raises(APIStatusError) as exc_info:
            self.client.post(
                "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
            )

        assert exc_info.value.response.status_code == 302
        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
941
942
943class TestAsyncOpenAI:
944 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
945
946 @pytest.mark.respx(base_url=base_url)
947 @pytest.mark.asyncio
948 async def test_raw_response(self, respx_mock: MockRouter) -> None:
949 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
950
951 response = await self.client.post("/foo", cast_to=httpx.Response)
952 assert response.status_code == 200
953 assert isinstance(response, httpx.Response)
954 assert response.json() == {"foo": "bar"}
955
956 @pytest.mark.respx(base_url=base_url)
957 @pytest.mark.asyncio
958 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
959 respx_mock.post("/foo").mock(
960 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
961 )
962
963 response = await self.client.post("/foo", cast_to=httpx.Response)
964 assert response.status_code == 200
965 assert isinstance(response, httpx.Response)
966 assert response.json() == {"foo": "bar"}
967
968 def test_copy(self) -> None:
969 copied = self.client.copy()
970 assert id(copied) != id(self.client)
971
972 copied = self.client.copy(api_key="another My API Key")
973 assert copied.api_key == "another My API Key"
974 assert self.client.api_key == "My API Key"
975
976 def test_copy_default_options(self) -> None:
977 # options that have a default are overridden correctly
978 copied = self.client.copy(max_retries=7)
979 assert copied.max_retries == 7
980 assert self.client.max_retries == 2
981
982 copied2 = copied.copy(max_retries=6)
983 assert copied2.max_retries == 6
984 assert copied.max_retries == 7
985
986 # timeout
987 assert isinstance(self.client.timeout, httpx.Timeout)
988 copied = self.client.copy(timeout=None)
989 assert copied.timeout is None
990 assert isinstance(self.client.timeout, httpx.Timeout)
991
992 def test_copy_default_headers(self) -> None:
993 client = AsyncOpenAI(
994 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
995 )
996 assert client.default_headers["X-Foo"] == "bar"
997
998 # does not override the already given value when not specified
999 copied = client.copy()
1000 assert copied.default_headers["X-Foo"] == "bar"
1001
1002 # merges already given headers
1003 copied = client.copy(default_headers={"X-Bar": "stainless"})
1004 assert copied.default_headers["X-Foo"] == "bar"
1005 assert copied.default_headers["X-Bar"] == "stainless"
1006
1007 # uses new values for any already given headers
1008 copied = client.copy(default_headers={"X-Foo": "stainless"})
1009 assert copied.default_headers["X-Foo"] == "stainless"
1010
1011 # set_default_headers
1012
1013 # completely overrides already set values
1014 copied = client.copy(set_default_headers={})
1015 assert copied.default_headers.get("X-Foo") is None
1016
1017 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
1018 assert copied.default_headers["X-Bar"] == "Robert"
1019
1020 with pytest.raises(
1021 ValueError,
1022 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
1023 ):
1024 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1025
1026 def test_copy_default_query(self) -> None:
1027 client = AsyncOpenAI(
1028 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
1029 )
1030 assert _get_params(client)["foo"] == "bar"
1031
1032 # does not override the already given value when not specified
1033 copied = client.copy()
1034 assert _get_params(copied)["foo"] == "bar"
1035
1036 # merges already given params
1037 copied = client.copy(default_query={"bar": "stainless"})
1038 params = _get_params(copied)
1039 assert params["foo"] == "bar"
1040 assert params["bar"] == "stainless"
1041
1042 # uses new values for any already given headers
1043 copied = client.copy(default_query={"foo": "stainless"})
1044 assert _get_params(copied)["foo"] == "stainless"
1045
1046 # set_default_query
1047
1048 # completely overrides already set values
1049 copied = client.copy(set_default_query={})
1050 assert _get_params(copied) == {}
1051
1052 copied = client.copy(set_default_query={"bar": "Robert"})
1053 assert _get_params(copied)["bar"] == "Robert"
1054
1055 with pytest.raises(
1056 ValueError,
1057 # TODO: update
1058 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1059 ):
1060 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1061
1062 def test_copy_signature(self) -> None:
1063 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1064 init_signature = inspect.signature(
1065 # mypy doesn't like that we access the `__init__` property.
1066 self.client.__init__, # type: ignore[misc]
1067 )
1068 copy_signature = inspect.signature(self.client.copy)
1069 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1070
1071 for name in init_signature.parameters.keys():
1072 if name in exclude_params:
1073 continue
1074
1075 copy_param = copy_signature.parameters.get(name)
1076 assert copy_param is not None, f"copy() signature is missing the {name} param"
1077
    # NOTE(review): the skip condition (>= 3.10) is broader than the reason text, which
    # mentions 3.12 — confirm which version boundary is actually intended.
    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self) -> None:
        """Check that repeatedly copying the client and building a request does not leak memory.

        Takes tracemalloc snapshots around ITERATIONS calls to `copy()` + `_build_request()`.
        It then fails if any allocation site (outside a small allow-list of known offenders)
        still holds a multiple of ITERATIONS blocks afterwards.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            # A fresh copy per call, so anything the copy retains shows up as growth.
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        # Keep up to 1000 frames per traceback so allocation sites can be matched by filename below.
        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        # Collect garbage first so only allocations that are still reachable are counted.
        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            # Append `diff` to `leaks` unless it is filtered out as a likely false positive.
            # NOTE(review): `diff.count` is the block count in the *after* snapshot, not the
            # delta (`count_diff`); confirm this is the intended quantity.
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        # Group differences by full traceback so each allocation site is judged on its own.
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            # Print every suspected leak with its traceback before failing, to aid debugging.
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()
1140
1141 async def test_request_timeout(self) -> None:
1142 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
1143 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1144 assert timeout == DEFAULT_TIMEOUT
1145
1146 request = self.client._build_request(
1147 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1148 )
1149 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1150 assert timeout == httpx.Timeout(100.0)
1151
1152 async def test_client_timeout_option(self) -> None:
1153 client = AsyncOpenAI(
1154 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
1155 )
1156
1157 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1158 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1159 assert timeout == httpx.Timeout(0)
1160
1161 async def test_http_client_timeout_option(self) -> None:
1162 # custom timeout given to the httpx client should be used
1163 async with httpx.AsyncClient(timeout=None) as http_client:
1164 client = AsyncOpenAI(
1165 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1166 )
1167
1168 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1169 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1170 assert timeout == httpx.Timeout(None)
1171
1172 # no timeout given to the httpx client should not use the httpx default
1173 async with httpx.AsyncClient() as http_client:
1174 client = AsyncOpenAI(
1175 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1176 )
1177
1178 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1179 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1180 assert timeout == DEFAULT_TIMEOUT
1181
1182 # explicitly passing the default timeout currently results in it being ignored
1183 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1184 client = AsyncOpenAI(
1185 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1186 )
1187
1188 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1189 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1190 assert timeout == DEFAULT_TIMEOUT # our default
1191
1192 def test_invalid_http_client(self) -> None:
1193 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1194 with httpx.Client() as http_client:
1195 AsyncOpenAI(
1196 base_url=base_url,
1197 api_key=api_key,
1198 _strict_response_validation=True,
1199 http_client=cast(Any, http_client),
1200 )
1201
1202 def test_default_headers_option(self) -> None:
1203 client = AsyncOpenAI(
1204 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1205 )
1206 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1207 assert request.headers.get("x-foo") == "bar"
1208 assert request.headers.get("x-stainless-lang") == "python"
1209
1210 client2 = AsyncOpenAI(
1211 base_url=base_url,
1212 api_key=api_key,
1213 _strict_response_validation=True,
1214 default_headers={
1215 "X-Foo": "stainless",
1216 "X-Stainless-Lang": "my-overriding-header",
1217 },
1218 )
1219 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1220 assert request.headers.get("x-foo") == "stainless"
1221 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1222
1223 def test_validate_headers(self) -> None:
1224 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1225 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1226 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1227
1228 with pytest.raises(OpenAIError):
1229 with update_env(**{"OPENAI_API_KEY": Omit()}):
1230 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1231 _ = client2
1232
1233 def test_default_query_option(self) -> None:
1234 client = AsyncOpenAI(
1235 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1236 )
1237 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1238 url = httpx.URL(request.url)
1239 assert dict(url.params) == {"query_param": "bar"}
1240
1241 request = client._build_request(
1242 FinalRequestOptions(
1243 method="get",
1244 url="/foo",
1245 params={"foo": "baz", "query_param": "overridden"},
1246 )
1247 )
1248 url = httpx.URL(request.url)
1249 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1250
1251 def test_request_extra_json(self) -> None:
1252 request = self.client._build_request(
1253 FinalRequestOptions(
1254 method="post",
1255 url="/foo",
1256 json_data={"foo": "bar"},
1257 extra_json={"baz": False},
1258 ),
1259 )
1260 data = json.loads(request.content.decode("utf-8"))
1261 assert data == {"foo": "bar", "baz": False}
1262
1263 request = self.client._build_request(
1264 FinalRequestOptions(
1265 method="post",
1266 url="/foo",
1267 extra_json={"baz": False},
1268 ),
1269 )
1270 data = json.loads(request.content.decode("utf-8"))
1271 assert data == {"baz": False}
1272
1273 # `extra_json` takes priority over `json_data` when keys clash
1274 request = self.client._build_request(
1275 FinalRequestOptions(
1276 method="post",
1277 url="/foo",
1278 json_data={"foo": "bar", "baz": True},
1279 extra_json={"baz": None},
1280 ),
1281 )
1282 data = json.loads(request.content.decode("utf-8"))
1283 assert data == {"foo": "bar", "baz": None}
1284
1285 def test_request_extra_headers(self) -> None:
1286 request = self.client._build_request(
1287 FinalRequestOptions(
1288 method="post",
1289 url="/foo",
1290 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1291 ),
1292 )
1293 assert request.headers.get("X-Foo") == "Foo"
1294
1295 # `extra_headers` takes priority over `default_headers` when keys clash
1296 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
1297 FinalRequestOptions(
1298 method="post",
1299 url="/foo",
1300 **make_request_options(
1301 extra_headers={"X-Bar": "false"},
1302 ),
1303 ),
1304 )
1305 assert request.headers.get("X-Bar") == "false"
1306
1307 def test_request_extra_query(self) -> None:
1308 request = self.client._build_request(
1309 FinalRequestOptions(
1310 method="post",
1311 url="/foo",
1312 **make_request_options(
1313 extra_query={"my_query_param": "Foo"},
1314 ),
1315 ),
1316 )
1317 params = dict(request.url.params)
1318 assert params == {"my_query_param": "Foo"}
1319
1320 # if both `query` and `extra_query` are given, they are merged
1321 request = self.client._build_request(
1322 FinalRequestOptions(
1323 method="post",
1324 url="/foo",
1325 **make_request_options(
1326 query={"bar": "1"},
1327 extra_query={"foo": "2"},
1328 ),
1329 ),
1330 )
1331 params = dict(request.url.params)
1332 assert params == {"bar": "1", "foo": "2"}
1333
1334 # `extra_query` takes priority over `query` when keys clash
1335 request = self.client._build_request(
1336 FinalRequestOptions(
1337 method="post",
1338 url="/foo",
1339 **make_request_options(
1340 query={"foo": "1"},
1341 extra_query={"foo": "2"},
1342 ),
1343 ),
1344 )
1345 params = dict(request.url.params)
1346 assert params == {"foo": "2"}
1347
1348 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1349 request = async_client._build_request(
1350 FinalRequestOptions.construct(
1351 method="post",
1352 url="/foo",
1353 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1354 json_data={"array": ["foo", "bar"]},
1355 files=[("foo.txt", b"hello world")],
1356 )
1357 )
1358
1359 assert request.read().split(b"\r\n") == [
1360 b"--6b7ba517decee4a450543ea6ae821c82",
1361 b'Content-Disposition: form-data; name="array[]"',
1362 b"",
1363 b"foo",
1364 b"--6b7ba517decee4a450543ea6ae821c82",
1365 b'Content-Disposition: form-data; name="array[]"',
1366 b"",
1367 b"bar",
1368 b"--6b7ba517decee4a450543ea6ae821c82",
1369 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
1370 b"Content-Type: application/octet-stream",
1371 b"",
1372 b"hello world",
1373 b"--6b7ba517decee4a450543ea6ae821c82--",
1374 b"",
1375 ]
1376
1377 @pytest.mark.respx(base_url=base_url)
1378 async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
1379 class Model1(BaseModel):
1380 name: str
1381
1382 class Model2(BaseModel):
1383 foo: str
1384
1385 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1386
1387 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1388 assert isinstance(response, Model2)
1389 assert response.foo == "bar"
1390
1391 @pytest.mark.respx(base_url=base_url)
1392 async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
1393 """Union of objects with the same field name using a different type"""
1394
1395 class Model1(BaseModel):
1396 foo: int
1397
1398 class Model2(BaseModel):
1399 foo: str
1400
1401 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1402
1403 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1404 assert isinstance(response, Model2)
1405 assert response.foo == "bar"
1406
1407 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1408
1409 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1410 assert isinstance(response, Model1)
1411 assert response.foo == 1
1412
1413 @pytest.mark.respx(base_url=base_url)
1414 async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
1415 """
1416 Response that sets Content-Type to something other than application/json but returns json data
1417 """
1418
1419 class Model(BaseModel):
1420 foo: int
1421
1422 respx_mock.get("/foo").mock(
1423 return_value=httpx.Response(
1424 200,
1425 content=json.dumps({"foo": 2}),
1426 headers={"Content-Type": "application/text"},
1427 )
1428 )
1429
1430 response = await self.client.get("/foo", cast_to=Model)
1431 assert isinstance(response, Model)
1432 assert response.foo == 2
1433
1434 def test_base_url_setter(self) -> None:
1435 client = AsyncOpenAI(
1436 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1437 )
1438 assert client.base_url == "https://example.com/from_init/"
1439
1440 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1441
1442 assert client.base_url == "https://example.com/from_setter/"
1443
1444 def test_base_url_env(self) -> None:
1445 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1446 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1447 assert client.base_url == "http://localhost:5000/from/env/"
1448
1449 @pytest.mark.parametrize(
1450 "client",
1451 [
1452 AsyncOpenAI(
1453 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1454 ),
1455 AsyncOpenAI(
1456 base_url="http://localhost:5000/custom/path/",
1457 api_key=api_key,
1458 _strict_response_validation=True,
1459 http_client=httpx.AsyncClient(),
1460 ),
1461 ],
1462 ids=["standard", "custom http client"],
1463 )
1464 def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1465 request = client._build_request(
1466 FinalRequestOptions(
1467 method="post",
1468 url="/foo",
1469 json_data={"foo": "bar"},
1470 ),
1471 )
1472 assert request.url == "http://localhost:5000/custom/path/foo"
1473
1474 @pytest.mark.parametrize(
1475 "client",
1476 [
1477 AsyncOpenAI(
1478 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1479 ),
1480 AsyncOpenAI(
1481 base_url="http://localhost:5000/custom/path/",
1482 api_key=api_key,
1483 _strict_response_validation=True,
1484 http_client=httpx.AsyncClient(),
1485 ),
1486 ],
1487 ids=["standard", "custom http client"],
1488 )
1489 def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1490 request = client._build_request(
1491 FinalRequestOptions(
1492 method="post",
1493 url="/foo",
1494 json_data={"foo": "bar"},
1495 ),
1496 )
1497 assert request.url == "http://localhost:5000/custom/path/foo"
1498
1499 @pytest.mark.parametrize(
1500 "client",
1501 [
1502 AsyncOpenAI(
1503 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1504 ),
1505 AsyncOpenAI(
1506 base_url="http://localhost:5000/custom/path/",
1507 api_key=api_key,
1508 _strict_response_validation=True,
1509 http_client=httpx.AsyncClient(),
1510 ),
1511 ],
1512 ids=["standard", "custom http client"],
1513 )
1514 def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1515 request = client._build_request(
1516 FinalRequestOptions(
1517 method="post",
1518 url="https://myapi.com/foo",
1519 json_data={"foo": "bar"},
1520 ),
1521 )
1522 assert request.url == "https://myapi.com/foo"
1523
1524 async def test_copied_client_does_not_close_http(self) -> None:
1525 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1526 assert not client.is_closed()
1527
1528 copied = client.copy()
1529 assert copied is not client
1530
1531 del copied
1532
1533 await asyncio.sleep(0.2)
1534 assert not client.is_closed()
1535
1536 async def test_client_context_manager(self) -> None:
1537 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1538 async with client as c2:
1539 assert c2 is client
1540 assert not c2.is_closed()
1541 assert not client.is_closed()
1542 assert client.is_closed()
1543
1544 @pytest.mark.respx(base_url=base_url)
1545 @pytest.mark.asyncio
1546 async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
1547 class Model(BaseModel):
1548 foo: str
1549
1550 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1551
1552 with pytest.raises(APIResponseValidationError) as exc:
1553 await self.client.get("/foo", cast_to=Model)
1554
1555 assert isinstance(exc.value.__cause__, ValidationError)
1556
1557 async def test_client_max_retries_validation(self) -> None:
1558 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1559 AsyncOpenAI(
1560 base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
1561 )
1562
1563 @pytest.mark.respx(base_url=base_url)
1564 @pytest.mark.asyncio
1565 async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
1566 class Model(BaseModel):
1567 name: str
1568
1569 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1570
1571 stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1572 assert isinstance(stream, AsyncStream)
1573 await stream.response.aclose()
1574
1575 @pytest.mark.respx(base_url=base_url)
1576 @pytest.mark.asyncio
1577 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1578 class Model(BaseModel):
1579 name: str
1580
1581 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1582
1583 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1584
1585 with pytest.raises(APIResponseValidationError):
1586 await strict_client.get("/foo", cast_to=Model)
1587
1588 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1589
1590 response = await client.get("/foo", cast_to=Model)
1591 assert isinstance(response, str) # type: ignore[unreachable]
1592
1593 @pytest.mark.parametrize(
1594 "remaining_retries,retry_after,timeout",
1595 [
1596 [3, "20", 20],
1597 [3, "0", 0.5],
1598 [3, "-10", 0.5],
1599 [3, "60", 60],
1600 [3, "61", 0.5],
1601 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1602 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1603 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1604 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1605 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1606 [3, "99999999999999999999999999999999999", 0.5],
1607 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1608 [3, "", 0.5],
1609 [2, "", 0.5 * 2.0],
1610 [1, "", 0.5 * 4.0],
1611 [-1100, "", 8], # test large number potentially overflowing
1612 ],
1613 )
1614 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1615 @pytest.mark.asyncio
1616 async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
1617 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1618
1619 headers = httpx.Headers({"retry-after": retry_after})
1620 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1621 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
1622 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1623
1624 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1625 @pytest.mark.respx(base_url=base_url)
1626 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1627 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1628
1629 with pytest.raises(APITimeoutError):
1630 await async_client.chat.completions.with_streaming_response.create(
1631 messages=[
1632 {
1633 "content": "string",
1634 "role": "developer",
1635 }
1636 ],
1637 model="gpt-4o",
1638 ).__aenter__()
1639
1640 assert _get_open_connections(self.client) == 0
1641
1642 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1643 @pytest.mark.respx(base_url=base_url)
1644 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1645 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1646
1647 with pytest.raises(APIStatusError):
1648 await async_client.chat.completions.with_streaming_response.create(
1649 messages=[
1650 {
1651 "content": "string",
1652 "role": "developer",
1653 }
1654 ],
1655 model="gpt-4o",
1656 ).__aenter__()
1657 assert _get_open_connections(self.client) == 0
1658
1659 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1660 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1661 @pytest.mark.respx(base_url=base_url)
1662 @pytest.mark.asyncio
1663 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
1664 async def test_retries_taken(
1665 self,
1666 async_client: AsyncOpenAI,
1667 failures_before_success: int,
1668 failure_mode: Literal["status", "exception"],
1669 respx_mock: MockRouter,
1670 ) -> None:
1671 client = async_client.with_options(max_retries=4)
1672
1673 nb_retries = 0
1674
1675 def retry_handler(_request: httpx.Request) -> httpx.Response:
1676 nonlocal nb_retries
1677 if nb_retries < failures_before_success:
1678 nb_retries += 1
1679 if failure_mode == "exception":
1680 raise RuntimeError("oops")
1681 return httpx.Response(500)
1682 return httpx.Response(200)
1683
1684 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1685
1686 response = await client.chat.completions.with_raw_response.create(
1687 messages=[
1688 {
1689 "content": "string",
1690 "role": "developer",
1691 }
1692 ],
1693 model="gpt-4o",
1694 )
1695
1696 assert response.retries_taken == failures_before_success
1697 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1698
1699 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1700 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1701 @pytest.mark.respx(base_url=base_url)
1702 @pytest.mark.asyncio
1703 async def test_omit_retry_count_header(
1704 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1705 ) -> None:
1706 client = async_client.with_options(max_retries=4)
1707
1708 nb_retries = 0
1709
1710 def retry_handler(_request: httpx.Request) -> httpx.Response:
1711 nonlocal nb_retries
1712 if nb_retries < failures_before_success:
1713 nb_retries += 1
1714 return httpx.Response(500)
1715 return httpx.Response(200)
1716
1717 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1718
1719 response = await client.chat.completions.with_raw_response.create(
1720 messages=[
1721 {
1722 "content": "string",
1723 "role": "developer",
1724 }
1725 ],
1726 model="gpt-4o",
1727 extra_headers={"x-stainless-retry-count": Omit()},
1728 )
1729
1730 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
1731
1732 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1733 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1734 @pytest.mark.respx(base_url=base_url)
1735 @pytest.mark.asyncio
1736 async def test_overwrite_retry_count_header(
1737 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1738 ) -> None:
1739 client = async_client.with_options(max_retries=4)
1740
1741 nb_retries = 0
1742
1743 def retry_handler(_request: httpx.Request) -> httpx.Response:
1744 nonlocal nb_retries
1745 if nb_retries < failures_before_success:
1746 nb_retries += 1
1747 return httpx.Response(500)
1748 return httpx.Response(200)
1749
1750 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1751
1752 response = await client.chat.completions.with_raw_response.create(
1753 messages=[
1754 {
1755 "content": "string",
1756 "role": "developer",
1757 }
1758 ],
1759 model="gpt-4o",
1760 extra_headers={"x-stainless-retry-count": "42"},
1761 )
1762
1763 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1764
1765 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1766 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1767 @pytest.mark.respx(base_url=base_url)
1768 @pytest.mark.asyncio
1769 async def test_retries_taken_new_response_class(
1770 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1771 ) -> None:
1772 client = async_client.with_options(max_retries=4)
1773
1774 nb_retries = 0
1775
1776 def retry_handler(_request: httpx.Request) -> httpx.Response:
1777 nonlocal nb_retries
1778 if nb_retries < failures_before_success:
1779 nb_retries += 1
1780 return httpx.Response(500)
1781 return httpx.Response(200)
1782
1783 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1784
1785 async with client.chat.completions.with_streaming_response.create(
1786 messages=[
1787 {
1788 "content": "string",
1789 "role": "developer",
1790 }
1791 ],
1792 model="gpt-4o",
1793 ) as response:
1794 assert response.retries_taken == failures_before_success
1795 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1796
def test_get_platform(self) -> None:
    """Run ``asyncify(get_platform)`` under ``nest_asyncio`` and require it to finish cleanly.

    A previous implementation of asyncify could leave threads unterminated when
    used with nest_asyncio. Since ``nest_asyncio.apply()`` is global and cannot be
    un-applied, the check runs in a separate interpreter process so it cannot
    affect other tests.

    Raises:
        AssertionError: if the child process exits with a non-zero code, or is
            still running after ``timeout`` seconds (treated as a hung process).
    """
    test_code = dedent("""
    import asyncio
    import nest_asyncio
    import threading

    from openai._utils import asyncify
    from openai._base_client import get_platform

    async def test_main() -> None:
        result = await asyncify(get_platform)()
        print(result)
        for thread in threading.enumerate():
            print(thread.name)

    nest_asyncio.apply()
    asyncio.run(test_main())
    """)
    timeout = 10  # seconds

    try:
        # subprocess.run enforces the deadline itself. On expiry it kills and reaps
        # the child before raising TimeoutExpired. This replaces the manual
        # poll()/sleep()/kill() loop, which only noticed the deadline at 0.1s steps.
        completed = subprocess.run(
            [sys.executable, "-c", test_code],
            text=True,
            timeout=timeout,
        )
    except subprocess.TimeoutExpired as err:
        raise AssertionError("calling get_platform using asyncify resulted in a hung process") from err

    if completed.returncode != 0:
        raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
1841
async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """An ``HTTPS_PROXY`` environment variable is picked up by the default async client.

    With only ``HTTPS_PROXY`` set, the freshly built client must register exactly
    one mount, keyed on the ``https://`` URL pattern.
    """
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

    http_client = DefaultAsyncHttpxClient()

    registered = list(http_client._mounts.items())
    assert len(registered) == 1
    url_pattern, _transport = registered[0]
    assert url_pattern.pattern == "https://"
1851
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
async def test_default_client_creation(self) -> None:
    """Constructing ``DefaultAsyncHttpxClient`` with explicit httpx options must not raise.

    DeprecationWarnings whose message matches ``.*deprecated.*`` are ignored for
    this test (see the ``filterwarnings`` mark).
    """
    pool_limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)
    DefaultAsyncHttpxClient(
        http2=False,
        http1=True,
        trust_env=True,
        cert=None,
        verify=True,
        limits=pool_limits,
    )
1863
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
    """Redirects are followed by default.

    ``POST /redirect`` answers 302 pointing at ``/redirected``. The client should
    follow it to ``GET /redirected`` and hand back that 200 JSON response.
    """
    redirect_target = f"{base_url}/redirected"
    respx_mock.post("/redirect").mock(
        return_value=httpx.Response(302, headers={"Location": redirect_target})
    )
    respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

    final_response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    assert final_response.status_code == 200
    assert final_response.json() == {"status": "ok"}
1875
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
    """With ``follow_redirects=False`` a 302 surfaces as an ``APIStatusError``.

    The error must carry the untouched redirect response: status 302 and the
    original ``Location`` header.
    """
    redirect_target = f"{base_url}/redirected"
    respx_mock.post("/redirect").mock(
        return_value=httpx.Response(302, headers={"Location": redirect_target})
    )

    with pytest.raises(APIStatusError) as raised:
        await self.client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    redirect_response = raised.value.response
    assert redirect_response.status_code == 302
    assert redirect_response.headers["Location"] == redirect_target
1890