openai/openai-python

Public

mirrored from https://github.com/openai/openai-python (Available)

CodeCommitsIssuesPull requestsActionsInsightsSecurity
dev/codex/package-manager-safety-dry-run

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

2572 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import asyncio
10import inspect
11import dataclasses
12import tracemalloc
13from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Protocol, Coroutine, cast
14from unittest import mock
15from typing_extensions import Literal, AsyncIterator, override
16
17import httpx
18import pytest
19from respx import MockRouter
20from pydantic import ValidationError
21
22from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
23from openai.auth import WorkloadIdentity
24from openai._types import Omit
25from openai._utils import asyncify
26from openai._models import BaseModel, FinalRequestOptions
27from openai._streaming import Stream, AsyncStream
28from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
29from openai._base_client import (
30 DEFAULT_TIMEOUT,
31 HTTPX_DEFAULT_TIMEOUT,
32 BaseClient,
33 OtherPlatform,
34 DefaultHttpxClient,
35 DefaultAsyncHttpxClient,
36 get_platform,
37 make_request_options,
38)
39
40from .utils import update_env
41
T = TypeVar("T")

# Tests talk to a local mock server unless TEST_API_BASE_URL points elsewhere.
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
api_key = "My API Key"


def _workload_identity_subject_token() -> str:
    """Return the fixed subject token handed out by the fake identity provider below."""
    return "external-subject-token"


# A fully populated workload-identity config, used by the workload-identity tests.
workload_identity: WorkloadIdentity = {
    "client_id": "client-id",
    "identity_provider_id": "identity-provider-id",
    "service_account_id": "service-account-id",
    "provider": {
        "token_type": "jwt",
        "get_token": _workload_identity_subject_token,
    },
}
54
55
class MockRequestCall(Protocol):
    """Structural type for any object exposing the sent `httpx.Request` as its `request` attribute."""

    request: httpx.Request
58
59
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare `GET /foo` request with *client* and return its query parameters as a dict."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
64
65
66def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
67 return 0.1
68
69
def mirror_request_content(request: httpx.Request) -> httpx.Response:
    """Echo handler: answer 200 with exactly the bytes that were sent as the request body."""
    body = request.content
    return httpx.Response(200, content=body)
72
73
class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    """Transport that passes every request straight to a user-supplied handler.

    `httpx.MockTransport` reads the request body itself, which would make it
    impossible to test that the body is consumed lazily. This transport leaves
    the body alone, so the handler decides when to read it.

    The sync path requires a plain (non-coroutine) function. The async path
    requires a coroutine function.
    """

    def __init__(
        self,
        handler: Callable[[httpx.Request], httpx.Response]
        | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]],
    ) -> None:
        self.handler = handler

    @override
    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        fn = self.handler
        assert not inspect.iscoroutinefunction(fn), "handler must not be a coroutine function"
        assert inspect.isfunction(fn), "handler must be a function"
        return fn(request)

    @override
    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        fn = self.handler
        assert inspect.iscoroutinefunction(fn), "handler must be a coroutine function"
        return await fn(request)
100
101
@dataclasses.dataclass
class Counter:
    """Mutable integer box that the iterator helpers increment once per item they yield."""

    value: int = 0
105
106
107def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]:
108 for item in iterable:
109 if counter:
110 counter.value += 1
111 yield item
112
113
114async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]:
115 for item in iterable:
116 if counter:
117 counter.value += 1
118 yield item
119
120
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the length of the connection pool's private `_requests` list.

    This is used as a proxy for "connections still in use". It reaches into private
    httpx/httpcore attributes, so it only works when the client uses httpx's
    standard HTTP transports, which the assertion checks.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
127
128
129class TestOpenAI:
130 @pytest.mark.respx(base_url=base_url)
131 def test_raw_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
132 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
133
134 response = client.post("/foo", cast_to=httpx.Response)
135 assert response.status_code == 200
136 assert isinstance(response, httpx.Response)
137 assert response.json() == {"foo": "bar"}
138
139 @pytest.mark.respx(base_url=base_url)
140 def test_raw_response_for_binary(self, respx_mock: MockRouter, client: OpenAI) -> None:
141 respx_mock.post("/foo").mock(
142 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
143 )
144
145 response = client.post("/foo", cast_to=httpx.Response)
146 assert response.status_code == 200
147 assert isinstance(response, httpx.Response)
148 assert response.json() == {"foo": "bar"}
149
150 def test_copy(self, client: OpenAI) -> None:
151 copied = client.copy()
152 assert id(copied) != id(client)
153
154 copied = client.copy(api_key="another My API Key")
155 assert copied.api_key == "another My API Key"
156 assert client.api_key == "My API Key"
157
158 def test_copy_default_options(self, client: OpenAI) -> None:
159 # options that have a default are overridden correctly
160 copied = client.copy(max_retries=7)
161 assert copied.max_retries == 7
162 assert client.max_retries == 2
163
164 copied2 = copied.copy(max_retries=6)
165 assert copied2.max_retries == 6
166 assert copied.max_retries == 7
167
168 # timeout
169 assert isinstance(client.timeout, httpx.Timeout)
170 copied = client.copy(timeout=None)
171 assert copied.timeout is None
172 assert isinstance(client.timeout, httpx.Timeout)
173
174 def test_copy_default_headers(self) -> None:
175 client = OpenAI(
176 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
177 )
178 assert client.default_headers["X-Foo"] == "bar"
179
180 # does not override the already given value when not specified
181 copied = client.copy()
182 assert copied.default_headers["X-Foo"] == "bar"
183
184 # merges already given headers
185 copied = client.copy(default_headers={"X-Bar": "stainless"})
186 assert copied.default_headers["X-Foo"] == "bar"
187 assert copied.default_headers["X-Bar"] == "stainless"
188
189 # uses new values for any already given headers
190 copied = client.copy(default_headers={"X-Foo": "stainless"})
191 assert copied.default_headers["X-Foo"] == "stainless"
192
193 # set_default_headers
194
195 # completely overrides already set values
196 copied = client.copy(set_default_headers={})
197 assert copied.default_headers.get("X-Foo") is None
198
199 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
200 assert copied.default_headers["X-Bar"] == "Robert"
201
202 with pytest.raises(
203 ValueError,
204 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
205 ):
206 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
207 client.close()
208
209 def test_copy_default_query(self) -> None:
210 client = OpenAI(
211 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
212 )
213 assert _get_params(client)["foo"] == "bar"
214
215 # does not override the already given value when not specified
216 copied = client.copy()
217 assert _get_params(copied)["foo"] == "bar"
218
219 # merges already given params
220 copied = client.copy(default_query={"bar": "stainless"})
221 params = _get_params(copied)
222 assert params["foo"] == "bar"
223 assert params["bar"] == "stainless"
224
225 # uses new values for any already given headers
226 copied = client.copy(default_query={"foo": "stainless"})
227 assert _get_params(copied)["foo"] == "stainless"
228
229 # set_default_query
230
231 # completely overrides already set values
232 copied = client.copy(set_default_query={})
233 assert _get_params(copied) == {}
234
235 copied = client.copy(set_default_query={"bar": "Robert"})
236 assert _get_params(copied)["bar"] == "Robert"
237
238 with pytest.raises(
239 ValueError,
240 # TODO: update
241 match="`default_query` and `set_default_query` arguments are mutually exclusive",
242 ):
243 client.copy(set_default_query={}, default_query={"foo": "Bar"})
244
245 client.close()
246
247 def test_copy_signature(self, client: OpenAI) -> None:
248 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
249 init_signature = inspect.signature(
250 # mypy doesn't like that we access the `__init__` property.
251 client.__init__, # type: ignore[misc]
252 )
253 copy_signature = inspect.signature(client.copy)
254 exclude_params = {"transport", "proxies", "_strict_response_validation"}
255
256 for name in init_signature.parameters.keys():
257 if name in exclude_params:
258 continue
259
260 copy_param = copy_signature.parameters.get(name)
261 assert copy_param is not None, f"copy() signature is missing the {name} param"
262
263 @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
264 def test_copy_build_request(self, client: OpenAI) -> None:
265 options = FinalRequestOptions(method="get", url="/foo")
266
267 def build_request(options: FinalRequestOptions) -> None:
268 client_copy = client.copy()
269 client_copy._build_request(options)
270
271 # ensure that the machinery is warmed up before tracing starts.
272 build_request(options)
273 gc.collect()
274
275 tracemalloc.start(1000)
276
277 snapshot_before = tracemalloc.take_snapshot()
278
279 ITERATIONS = 10
280 for _ in range(ITERATIONS):
281 build_request(options)
282
283 gc.collect()
284 snapshot_after = tracemalloc.take_snapshot()
285
286 tracemalloc.stop()
287
288 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
289 if diff.count == 0:
290 # Avoid false positives by considering only leaks (i.e. allocations that persist).
291 return
292
293 if diff.count % ITERATIONS != 0:
294 # Avoid false positives by considering only leaks that appear per iteration.
295 return
296
297 for frame in diff.traceback:
298 if any(
299 frame.filename.endswith(fragment)
300 for fragment in [
301 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
302 #
303 # removing the decorator fixes the leak for reasons we don't understand.
304 "openai/_legacy_response.py",
305 "openai/_response.py",
306 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
307 "openai/_compat.py",
308 # Standard library leaks we don't care about.
309 "/logging/__init__.py",
310 ]
311 ):
312 return
313
314 leaks.append(diff)
315
316 leaks: list[tracemalloc.StatisticDiff] = []
317 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
318 add_leak(leaks, diff)
319 if leaks:
320 for leak in leaks:
321 print("MEMORY LEAK:", leak)
322 for frame in leak.traceback:
323 print(frame)
324 raise AssertionError()
325
326 def test_request_timeout(self, client: OpenAI) -> None:
327 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
328 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
329 assert timeout == DEFAULT_TIMEOUT
330
331 request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
332 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
333 assert timeout == httpx.Timeout(100.0)
334
335 def test_client_timeout_option(self) -> None:
336 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
337
338 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
339 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
340 assert timeout == httpx.Timeout(0)
341
342 client.close()
343
344 def test_http_client_timeout_option(self) -> None:
345 # custom timeout given to the httpx client should be used
346 with httpx.Client(timeout=None) as http_client:
347 client = OpenAI(
348 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
349 )
350
351 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
352 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
353 assert timeout == httpx.Timeout(None)
354
355 client.close()
356
357 # no timeout given to the httpx client should not use the httpx default
358 with httpx.Client() as http_client:
359 client = OpenAI(
360 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
361 )
362
363 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
364 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
365 assert timeout == DEFAULT_TIMEOUT
366
367 client.close()
368
369 # explicitly passing the default timeout currently results in it being ignored
370 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
371 client = OpenAI(
372 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
373 )
374
375 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
376 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
377 assert timeout == DEFAULT_TIMEOUT # our default
378
379 client.close()
380
381 async def test_invalid_http_client(self) -> None:
382 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
383 async with httpx.AsyncClient() as http_client:
384 OpenAI(
385 base_url=base_url,
386 api_key=api_key,
387 _strict_response_validation=True,
388 http_client=cast(Any, http_client),
389 )
390
391 def test_default_headers_option(self) -> None:
392 test_client = OpenAI(
393 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
394 )
395 request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
396 assert request.headers.get("x-foo") == "bar"
397 assert request.headers.get("x-stainless-lang") == "python"
398
399 test_client2 = OpenAI(
400 base_url=base_url,
401 api_key=api_key,
402 _strict_response_validation=True,
403 default_headers={
404 "X-Foo": "stainless",
405 "X-Stainless-Lang": "my-overriding-header",
406 },
407 )
408 request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
409 assert request.headers.get("x-foo") == "stainless"
410 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
411
412 test_client.close()
413 test_client2.close()
414
415 def test_validate_headers(self) -> None:
416 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
417 options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
418 request = client._build_request(options)
419
420 assert request.headers.get("Authorization") == f"Bearer {api_key}"
421
422 with pytest.raises(OpenAIError):
423 with update_env(**{"OPENAI_API_KEY": Omit()}):
424 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
425 _ = client2
426
427 def test_workload_identity_is_mutually_exclusive_with_api_key(self) -> None:
428 with pytest.raises(
429 OpenAIError,
430 match="The `api_key` and `workload_identity` arguments are mutually exclusive",
431 ):
432 OpenAI(
433 base_url=base_url,
434 api_key=api_key,
435 workload_identity=workload_identity, # type: ignore[reportArgumentType]
436 organization="org_123",
437 _strict_response_validation=True,
438 )
439
440 def test_default_query_option(self) -> None:
441 client = OpenAI(
442 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
443 )
444 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
445 url = httpx.URL(request.url)
446 assert dict(url.params) == {"query_param": "bar"}
447
448 request = client._build_request(
449 FinalRequestOptions(
450 method="get",
451 url="/foo",
452 params={"foo": "baz", "query_param": "overridden"},
453 )
454 )
455 url = httpx.URL(request.url)
456 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
457
458 client.close()
459
460 def test_hardcoded_query_params_in_url(self, client: OpenAI) -> None:
461 request = client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true"))
462 url = httpx.URL(request.url)
463 assert dict(url.params) == {"beta": "true"}
464
465 request = client._build_request(
466 FinalRequestOptions(
467 method="get",
468 url="/foo?beta=true",
469 params={"limit": "10", "page": "abc"},
470 )
471 )
472 url = httpx.URL(request.url)
473 assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"}
474
475 request = client._build_request(
476 FinalRequestOptions(
477 method="get",
478 url="/files/a%2Fb?beta=true",
479 params={"limit": "10"},
480 )
481 )
482 assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10"
483
484 def test_request_extra_json(self, client: OpenAI) -> None:
485 request = client._build_request(
486 FinalRequestOptions(
487 method="post",
488 url="/foo",
489 json_data={"foo": "bar"},
490 extra_json={"baz": False},
491 ),
492 )
493 data = json.loads(request.content.decode("utf-8"))
494 assert data == {"foo": "bar", "baz": False}
495
496 request = client._build_request(
497 FinalRequestOptions(
498 method="post",
499 url="/foo",
500 extra_json={"baz": False},
501 ),
502 )
503 data = json.loads(request.content.decode("utf-8"))
504 assert data == {"baz": False}
505
506 # `extra_json` takes priority over `json_data` when keys clash
507 request = client._build_request(
508 FinalRequestOptions(
509 method="post",
510 url="/foo",
511 json_data={"foo": "bar", "baz": True},
512 extra_json={"baz": None},
513 ),
514 )
515 data = json.loads(request.content.decode("utf-8"))
516 assert data == {"foo": "bar", "baz": None}
517
518 def test_request_extra_headers(self, client: OpenAI) -> None:
519 request = client._build_request(
520 FinalRequestOptions(
521 method="post",
522 url="/foo",
523 **make_request_options(extra_headers={"X-Foo": "Foo"}),
524 ),
525 )
526 assert request.headers.get("X-Foo") == "Foo"
527
528 # `extra_headers` takes priority over `default_headers` when keys clash
529 request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
530 FinalRequestOptions(
531 method="post",
532 url="/foo",
533 **make_request_options(
534 extra_headers={"X-Bar": "false"},
535 ),
536 ),
537 )
538 assert request.headers.get("X-Bar") == "false"
539
540 def test_request_extra_query(self, client: OpenAI) -> None:
541 request = client._build_request(
542 FinalRequestOptions(
543 method="post",
544 url="/foo",
545 **make_request_options(
546 extra_query={"my_query_param": "Foo"},
547 ),
548 ),
549 )
550 params = dict(request.url.params)
551 assert params == {"my_query_param": "Foo"}
552
553 # if both `query` and `extra_query` are given, they are merged
554 request = client._build_request(
555 FinalRequestOptions(
556 method="post",
557 url="/foo",
558 **make_request_options(
559 query={"bar": "1"},
560 extra_query={"foo": "2"},
561 ),
562 ),
563 )
564 params = dict(request.url.params)
565 assert params == {"bar": "1", "foo": "2"}
566
567 # `extra_query` takes priority over `query` when keys clash
568 request = client._build_request(
569 FinalRequestOptions(
570 method="post",
571 url="/foo",
572 **make_request_options(
573 query={"foo": "1"},
574 extra_query={"foo": "2"},
575 ),
576 ),
577 )
578 params = dict(request.url.params)
579 assert params == {"foo": "2"}
580
581 def test_multipart_repeating_array(self, client: OpenAI) -> None:
582 request = client._build_request(
583 FinalRequestOptions.construct(
584 method="post",
585 url="/foo",
586 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
587 json_data={"array": ["foo", "bar"]},
588 files=[("foo.txt", b"hello world")],
589 )
590 )
591
592 assert request.read().split(b"\r\n") == [
593 b"--6b7ba517decee4a450543ea6ae821c82",
594 b'Content-Disposition: form-data; name="array[]"',
595 b"",
596 b"foo",
597 b"--6b7ba517decee4a450543ea6ae821c82",
598 b'Content-Disposition: form-data; name="array[]"',
599 b"",
600 b"bar",
601 b"--6b7ba517decee4a450543ea6ae821c82",
602 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
603 b"Content-Type: application/octet-stream",
604 b"",
605 b"hello world",
606 b"--6b7ba517decee4a450543ea6ae821c82--",
607 b"",
608 ]
609
610 @pytest.mark.respx(base_url=base_url)
611 def test_binary_content_upload(self, respx_mock: MockRouter, client: OpenAI) -> None:
612 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
613
614 file_content = b"Hello, this is a test file."
615
616 response = client.post(
617 "/upload",
618 content=file_content,
619 cast_to=httpx.Response,
620 options={"headers": {"Content-Type": "application/octet-stream"}},
621 )
622
623 assert response.status_code == 200
624 assert response.request.headers["Content-Type"] == "application/octet-stream"
625 assert response.content == file_content
626
627 def test_binary_content_upload_with_iterator(self) -> None:
628 file_content = b"Hello, this is a test file."
629 counter = Counter()
630 iterator = _make_sync_iterator([file_content], counter=counter)
631
632 def mock_handler(request: httpx.Request) -> httpx.Response:
633 assert counter.value == 0, "the request body should not have been read"
634 return httpx.Response(200, content=request.read())
635
636 with OpenAI(
637 base_url=base_url,
638 api_key=api_key,
639 _strict_response_validation=True,
640 http_client=httpx.Client(transport=MockTransport(handler=mock_handler)),
641 ) as client:
642 response = client.post(
643 "/upload",
644 content=iterator,
645 cast_to=httpx.Response,
646 options={"headers": {"Content-Type": "application/octet-stream"}},
647 )
648
649 assert response.status_code == 200
650 assert response.request.headers["Content-Type"] == "application/octet-stream"
651 assert response.content == file_content
652 assert counter.value == 1
653
654 @pytest.mark.respx(base_url=base_url)
655 def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: OpenAI) -> None:
656 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
657
658 file_content = b"Hello, this is a test file."
659
660 with pytest.deprecated_call(
661 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
662 ):
663 response = client.post(
664 "/upload",
665 body=file_content,
666 cast_to=httpx.Response,
667 options={"headers": {"Content-Type": "application/octet-stream"}},
668 )
669
670 assert response.status_code == 200
671 assert response.request.headers["Content-Type"] == "application/octet-stream"
672 assert response.content == file_content
673
674 @pytest.mark.respx(base_url=base_url)
675 def test_basic_union_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
676 class Model1(BaseModel):
677 name: str
678
679 class Model2(BaseModel):
680 foo: str
681
682 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
683
684 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
685 assert isinstance(response, Model2)
686 assert response.foo == "bar"
687
688 @pytest.mark.respx(base_url=base_url)
689 def test_union_response_different_types(self, respx_mock: MockRouter, client: OpenAI) -> None:
690 """Union of objects with the same field name using a different type"""
691
692 class Model1(BaseModel):
693 foo: int
694
695 class Model2(BaseModel):
696 foo: str
697
698 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
699
700 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
701 assert isinstance(response, Model2)
702 assert response.foo == "bar"
703
704 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
705
706 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
707 assert isinstance(response, Model1)
708 assert response.foo == 1
709
710 @pytest.mark.respx(base_url=base_url)
711 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: OpenAI) -> None:
712 """
713 Response that sets Content-Type to something other than application/json but returns json data
714 """
715
716 class Model(BaseModel):
717 foo: int
718
719 respx_mock.get("/foo").mock(
720 return_value=httpx.Response(
721 200,
722 content=json.dumps({"foo": 2}),
723 headers={"Content-Type": "application/text"},
724 )
725 )
726
727 response = client.get("/foo", cast_to=Model)
728 assert isinstance(response, Model)
729 assert response.foo == 2
730
731 def test_base_url_setter(self) -> None:
732 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
733 assert client.base_url == "https://example.com/from_init/"
734
735 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
736
737 assert client.base_url == "https://example.com/from_setter/"
738
739 client.close()
740
741 def test_base_url_env(self) -> None:
742 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
743 client = OpenAI(api_key=api_key, _strict_response_validation=True)
744 assert client.base_url == "http://localhost:5000/from/env/"
745
746 @pytest.mark.parametrize(
747 "client",
748 [
749 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
750 OpenAI(
751 base_url="http://localhost:5000/custom/path/",
752 api_key=api_key,
753 _strict_response_validation=True,
754 http_client=httpx.Client(),
755 ),
756 ],
757 ids=["standard", "custom http client"],
758 )
759 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
760 request = client._build_request(
761 FinalRequestOptions(
762 method="post",
763 url="/foo",
764 json_data={"foo": "bar"},
765 ),
766 )
767 assert request.url == "http://localhost:5000/custom/path/foo"
768 client.close()
769
770 @pytest.mark.parametrize(
771 "client",
772 [
773 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
774 OpenAI(
775 base_url="http://localhost:5000/custom/path/",
776 api_key=api_key,
777 _strict_response_validation=True,
778 http_client=httpx.Client(),
779 ),
780 ],
781 ids=["standard", "custom http client"],
782 )
783 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
784 request = client._build_request(
785 FinalRequestOptions(
786 method="post",
787 url="/foo",
788 json_data={"foo": "bar"},
789 ),
790 )
791 assert request.url == "http://localhost:5000/custom/path/foo"
792 client.close()
793
794 @pytest.mark.parametrize(
795 "client",
796 [
797 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
798 OpenAI(
799 base_url="http://localhost:5000/custom/path/",
800 api_key=api_key,
801 _strict_response_validation=True,
802 http_client=httpx.Client(),
803 ),
804 ],
805 ids=["standard", "custom http client"],
806 )
807 def test_absolute_request_url(self, client: OpenAI) -> None:
808 request = client._build_request(
809 FinalRequestOptions(
810 method="post",
811 url="https://myapi.com/foo",
812 json_data={"foo": "bar"},
813 ),
814 )
815 assert request.url == "https://myapi.com/foo"
816 client.close()
817
818 def test_copied_client_does_not_close_http(self) -> None:
819 test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
820 assert not test_client.is_closed()
821
822 copied = test_client.copy()
823 assert copied is not test_client
824
825 del copied
826
827 assert not test_client.is_closed()
828
829 def test_client_context_manager(self) -> None:
830 test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
831 with test_client as c2:
832 assert c2 is test_client
833 assert not c2.is_closed()
834 assert not test_client.is_closed()
835 assert test_client.is_closed()
836
837 @pytest.mark.respx(base_url=base_url)
838 def test_client_response_validation_error(self, respx_mock: MockRouter, client: OpenAI) -> None:
839 class Model(BaseModel):
840 foo: str
841
842 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
843
844 with pytest.raises(APIResponseValidationError) as exc:
845 client.get("/foo", cast_to=Model)
846
847 assert isinstance(exc.value.__cause__, ValidationError)
848
849 def test_client_max_retries_validation(self) -> None:
850 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
851 OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
852
853 @pytest.mark.respx(base_url=base_url)
854 def test_default_stream_cls(self, respx_mock: MockRouter, client: OpenAI) -> None:
855 class Model(BaseModel):
856 name: str
857
858 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
859
860 stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
861 assert isinstance(stream, Stream)
862 stream.response.close()
863
864 @pytest.mark.respx(base_url=base_url)
865 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
866 class Model(BaseModel):
867 name: str
868
869 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
870
871 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
872
873 with pytest.raises(APIResponseValidationError):
874 strict_client.get("/foo", cast_to=Model)
875
876 non_strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
877
878 response = non_strict_client.get("/foo", cast_to=Model)
879 assert isinstance(response, str) # type: ignore[unreachable]
880
881 strict_client.close()
882 non_strict_client.close()
883
884 @pytest.mark.parametrize(
885 "remaining_retries,retry_after,timeout",
886 [
887 [3, "20", 20],
888 [3, "0", 0.5],
889 [3, "-10", 0.5],
890 [3, "60", 60],
891 [3, "61", 0.5],
892 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
893 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
894 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
895 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
896 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
897 [3, "99999999999999999999999999999999999", 0.5],
898 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
899 [3, "", 0.5],
900 [2, "", 0.5 * 2.0],
901 [1, "", 0.5 * 4.0],
902 [-1100, "", 8], # test large number potentially overflowing
903 ],
904 )
905 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
906 def test_parse_retry_after_header(
907 self, remaining_retries: int, retry_after: str, timeout: float, client: OpenAI
908 ) -> None:
909 headers = httpx.Headers({"retry-after": retry_after})
910 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
911 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
912 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
913
914 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
915 @pytest.mark.respx(base_url=base_url)
916 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
917 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
918
919 with pytest.raises(APITimeoutError):
920 client.chat.completions.with_streaming_response.create(
921 messages=[
922 {
923 "content": "string",
924 "role": "developer",
925 }
926 ],
927 model="gpt-5.4",
928 ).__enter__()
929
930 assert _get_open_connections(client) == 0
931
932 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
933 @pytest.mark.respx(base_url=base_url)
934 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
935 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
936
937 with pytest.raises(APIStatusError):
938 client.chat.completions.with_streaming_response.create(
939 messages=[
940 {
941 "content": "string",
942 "role": "developer",
943 }
944 ],
945 model="gpt-5.4",
946 ).__enter__()
947 assert _get_open_connections(client) == 0
948
949 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
950 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
951 @pytest.mark.respx(base_url=base_url)
952 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
953 def test_retries_taken(
954 self,
955 client: OpenAI,
956 failures_before_success: int,
957 failure_mode: Literal["status", "exception"],
958 respx_mock: MockRouter,
959 ) -> None:
960 client = client.with_options(max_retries=4)
961
962 nb_retries = 0
963
964 def retry_handler(_request: httpx.Request) -> httpx.Response:
965 nonlocal nb_retries
966 if nb_retries < failures_before_success:
967 nb_retries += 1
968 if failure_mode == "exception":
969 raise RuntimeError("oops")
970 return httpx.Response(500)
971 return httpx.Response(200)
972
973 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
974
975 response = client.chat.completions.with_raw_response.create(
976 messages=[
977 {
978 "content": "string",
979 "role": "developer",
980 }
981 ],
982 model="gpt-5.4",
983 )
984
985 assert response.retries_taken == failures_before_success
986 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
987
988 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
989 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
990 @pytest.mark.respx(base_url=base_url)
991 def test_omit_retry_count_header(
992 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
993 ) -> None:
994 client = client.with_options(max_retries=4)
995
996 nb_retries = 0
997
998 def retry_handler(_request: httpx.Request) -> httpx.Response:
999 nonlocal nb_retries
1000 if nb_retries < failures_before_success:
1001 nb_retries += 1
1002 return httpx.Response(500)
1003 return httpx.Response(200)
1004
1005 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1006
1007 response = client.chat.completions.with_raw_response.create(
1008 messages=[
1009 {
1010 "content": "string",
1011 "role": "developer",
1012 }
1013 ],
1014 model="gpt-5.4",
1015 extra_headers={"x-stainless-retry-count": Omit()},
1016 )
1017
1018 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
1019
1020 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1021 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1022 @pytest.mark.respx(base_url=base_url)
1023 def test_overwrite_retry_count_header(
1024 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
1025 ) -> None:
1026 client = client.with_options(max_retries=4)
1027
1028 nb_retries = 0
1029
1030 def retry_handler(_request: httpx.Request) -> httpx.Response:
1031 nonlocal nb_retries
1032 if nb_retries < failures_before_success:
1033 nb_retries += 1
1034 return httpx.Response(500)
1035 return httpx.Response(200)
1036
1037 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1038
1039 response = client.chat.completions.with_raw_response.create(
1040 messages=[
1041 {
1042 "content": "string",
1043 "role": "developer",
1044 }
1045 ],
1046 model="gpt-5.4",
1047 extra_headers={"x-stainless-retry-count": "42"},
1048 )
1049
1050 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1051
1052 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1053 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1054 @pytest.mark.respx(base_url=base_url)
1055 def test_retries_taken_new_response_class(
1056 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
1057 ) -> None:
1058 client = client.with_options(max_retries=4)
1059
1060 nb_retries = 0
1061
1062 def retry_handler(_request: httpx.Request) -> httpx.Response:
1063 nonlocal nb_retries
1064 if nb_retries < failures_before_success:
1065 nb_retries += 1
1066 return httpx.Response(500)
1067 return httpx.Response(200)
1068
1069 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1070
1071 with client.chat.completions.with_streaming_response.create(
1072 messages=[
1073 {
1074 "content": "string",
1075 "role": "developer",
1076 }
1077 ],
1078 model="gpt-5.4",
1079 ) as response:
1080 assert response.retries_taken == failures_before_success
1081 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1082
1083 def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
1084 # Test that the proxy environment variables are set correctly
1085 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
1086 # Delete in case our environment has any proxy env vars set
1087 monkeypatch.delenv("HTTP_PROXY", raising=False)
1088 monkeypatch.delenv("ALL_PROXY", raising=False)
1089 monkeypatch.delenv("NO_PROXY", raising=False)
1090 monkeypatch.delenv("http_proxy", raising=False)
1091 monkeypatch.delenv("https_proxy", raising=False)
1092 monkeypatch.delenv("all_proxy", raising=False)
1093 monkeypatch.delenv("no_proxy", raising=False)
1094
1095 client = DefaultHttpxClient()
1096
1097 mounts = tuple(client._mounts.items())
1098 assert len(mounts) == 1
1099 assert mounts[0][0].pattern == "https://"
1100
1101 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
1102 def test_default_client_creation(self) -> None:
1103 # Ensure that the client can be initialized without any exceptions
1104 DefaultHttpxClient(
1105 verify=True,
1106 cert=None,
1107 trust_env=True,
1108 http1=True,
1109 http2=False,
1110 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
1111 )
1112
1113 @pytest.mark.respx(base_url=base_url)
1114 def test_follow_redirects(self, respx_mock: MockRouter, client: OpenAI) -> None:
1115 # Test that the default follow_redirects=True allows following redirects
1116 respx_mock.post("/redirect").mock(
1117 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1118 )
1119 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
1120
1121 response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
1122 assert response.status_code == 200
1123 assert response.json() == {"status": "ok"}
1124
1125 @pytest.mark.respx(base_url=base_url)
1126 def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: OpenAI) -> None:
1127 # Test that follow_redirects=False prevents following redirects
1128 respx_mock.post("/redirect").mock(
1129 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1130 )
1131
1132 with pytest.raises(APIStatusError) as exc_info:
1133 client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)
1134
1135 assert exc_info.value.response.status_code == 302
1136 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
1137
1138 def test_api_key_before_after_refresh_provider(self) -> None:
1139 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
1140
1141 assert client.api_key == ""
1142 assert "Authorization" not in client.auth_headers
1143
1144 client._refresh_api_key()
1145
1146 assert client.api_key == "test_bearer_token"
1147 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
1148
1149 def test_api_key_before_after_refresh_str(self) -> None:
1150 client = OpenAI(base_url=base_url, api_key="test_api_key")
1151
1152 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
1153 client._refresh_api_key()
1154
1155 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
1156
1157 @pytest.mark.respx()
1158 def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
1159 respx_mock.post(base_url + "/chat/completions").mock(
1160 side_effect=[
1161 httpx.Response(500, json={"error": "server error"}),
1162 httpx.Response(200, json={"foo": "bar"}),
1163 ]
1164 )
1165
1166 counter = 0
1167
1168 def token_provider() -> str:
1169 nonlocal counter
1170
1171 counter += 1
1172
1173 if counter == 1:
1174 return "first"
1175
1176 return "second"
1177
1178 client = OpenAI(base_url=base_url, api_key=token_provider)
1179 client.chat.completions.create(messages=[], model="gpt-4")
1180
1181 calls = cast("list[MockRequestCall]", respx_mock.calls)
1182 assert len(calls) == 2
1183
1184 assert calls[0].request.headers.get("Authorization") == "Bearer first"
1185 assert calls[1].request.headers.get("Authorization") == "Bearer second"
1186
1187 def test_copy_auth(self) -> None:
1188 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
1189 api_key=lambda: "test_bearer_token_2"
1190 )
1191 client._refresh_api_key()
1192 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
1193
1194
1195class TestAsyncOpenAI:
1196 @pytest.mark.respx(base_url=base_url)
1197 async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1198 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1199
1200 response = await async_client.post("/foo", cast_to=httpx.Response)
1201 assert response.status_code == 200
1202 assert isinstance(response, httpx.Response)
1203 assert response.json() == {"foo": "bar"}
1204
1205 @pytest.mark.respx(base_url=base_url)
1206 async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1207 respx_mock.post("/foo").mock(
1208 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
1209 )
1210
1211 response = await async_client.post("/foo", cast_to=httpx.Response)
1212 assert response.status_code == 200
1213 assert isinstance(response, httpx.Response)
1214 assert response.json() == {"foo": "bar"}
1215
1216 def test_copy(self, async_client: AsyncOpenAI) -> None:
1217 copied = async_client.copy()
1218 assert id(copied) != id(async_client)
1219
1220 copied = async_client.copy(api_key="another My API Key")
1221 assert copied.api_key == "another My API Key"
1222 assert async_client.api_key == "My API Key"
1223
1224 def test_copy_default_options(self, async_client: AsyncOpenAI) -> None:
1225 # options that have a default are overridden correctly
1226 copied = async_client.copy(max_retries=7)
1227 assert copied.max_retries == 7
1228 assert async_client.max_retries == 2
1229
1230 copied2 = copied.copy(max_retries=6)
1231 assert copied2.max_retries == 6
1232 assert copied.max_retries == 7
1233
1234 # timeout
1235 assert isinstance(async_client.timeout, httpx.Timeout)
1236 copied = async_client.copy(timeout=None)
1237 assert copied.timeout is None
1238 assert isinstance(async_client.timeout, httpx.Timeout)
1239
1240 async def test_copy_default_headers(self) -> None:
1241 client = AsyncOpenAI(
1242 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1243 )
1244 assert client.default_headers["X-Foo"] == "bar"
1245
1246 # does not override the already given value when not specified
1247 copied = client.copy()
1248 assert copied.default_headers["X-Foo"] == "bar"
1249
1250 # merges already given headers
1251 copied = client.copy(default_headers={"X-Bar": "stainless"})
1252 assert copied.default_headers["X-Foo"] == "bar"
1253 assert copied.default_headers["X-Bar"] == "stainless"
1254
1255 # uses new values for any already given headers
1256 copied = client.copy(default_headers={"X-Foo": "stainless"})
1257 assert copied.default_headers["X-Foo"] == "stainless"
1258
1259 # set_default_headers
1260
1261 # completely overrides already set values
1262 copied = client.copy(set_default_headers={})
1263 assert copied.default_headers.get("X-Foo") is None
1264
1265 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
1266 assert copied.default_headers["X-Bar"] == "Robert"
1267
1268 with pytest.raises(
1269 ValueError,
1270 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
1271 ):
1272 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1273 await client.close()
1274
1275 async def test_copy_default_query(self) -> None:
1276 client = AsyncOpenAI(
1277 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
1278 )
1279 assert _get_params(client)["foo"] == "bar"
1280
1281 # does not override the already given value when not specified
1282 copied = client.copy()
1283 assert _get_params(copied)["foo"] == "bar"
1284
1285 # merges already given params
1286 copied = client.copy(default_query={"bar": "stainless"})
1287 params = _get_params(copied)
1288 assert params["foo"] == "bar"
1289 assert params["bar"] == "stainless"
1290
1291 # uses new values for any already given headers
1292 copied = client.copy(default_query={"foo": "stainless"})
1293 assert _get_params(copied)["foo"] == "stainless"
1294
1295 # set_default_query
1296
1297 # completely overrides already set values
1298 copied = client.copy(set_default_query={})
1299 assert _get_params(copied) == {}
1300
1301 copied = client.copy(set_default_query={"bar": "Robert"})
1302 assert _get_params(copied)["bar"] == "Robert"
1303
1304 with pytest.raises(
1305 ValueError,
1306 # TODO: update
1307 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1308 ):
1309 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1310
1311 await client.close()
1312
1313 def test_copy_signature(self, async_client: AsyncOpenAI) -> None:
1314 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1315 init_signature = inspect.signature(
1316 # mypy doesn't like that we access the `__init__` property.
1317 async_client.__init__, # type: ignore[misc]
1318 )
1319 copy_signature = inspect.signature(async_client.copy)
1320 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1321
1322 for name in init_signature.parameters.keys():
1323 if name in exclude_params:
1324 continue
1325
1326 copy_param = copy_signature.parameters.get(name)
1327 assert copy_param is not None, f"copy() signature is missing the {name} param"
1328
    # NOTE(review): the condition skips on every Python >= 3.10, while the reason says the
    # leak started in 3.12 — confirm which version bound is actually intended.
    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self, async_client: AsyncOpenAI) -> None:
        """Repeatedly copying the client and building a request must not leak memory.

        Uses `tracemalloc` snapshots taken before and after `ITERATIONS` rounds of
        `copy()` + `_build_request()`. A persisting allocation whose count grows by a
        multiple of `ITERATIONS` is reported as a leak, unless its traceback passes
        through one of the allow-listed files below.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            # One unit of work: copy the client, then build (not send) a request.
            client_copy = async_client.copy()
            client_copy._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        # Keep up to 1000 frames per allocation traceback so leaks can be attributed.
        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        # Collect garbage first so only allocations that are still referenced are compared.
        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            # Appends `diff` to `leaks` unless one of the filters below rules it out.
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        # Group differences by full traceback so each leak site is reported separately.
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            # Print every leak with its traceback before failing, to aid debugging.
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()
1391
1392 async def test_request_timeout(self, async_client: AsyncOpenAI) -> None:
1393 request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
1394 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1395 assert timeout == DEFAULT_TIMEOUT
1396
1397 request = async_client._build_request(
1398 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1399 )
1400 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1401 assert timeout == httpx.Timeout(100.0)
1402
1403 async def test_client_timeout_option(self) -> None:
1404 client = AsyncOpenAI(
1405 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
1406 )
1407
1408 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1409 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1410 assert timeout == httpx.Timeout(0)
1411
1412 await client.close()
1413
1414 async def test_http_client_timeout_option(self) -> None:
1415 # custom timeout given to the httpx client should be used
1416 async with httpx.AsyncClient(timeout=None) as http_client:
1417 client = AsyncOpenAI(
1418 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1419 )
1420
1421 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1422 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1423 assert timeout == httpx.Timeout(None)
1424
1425 await client.close()
1426
1427 # no timeout given to the httpx client should not use the httpx default
1428 async with httpx.AsyncClient() as http_client:
1429 client = AsyncOpenAI(
1430 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1431 )
1432
1433 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1434 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1435 assert timeout == DEFAULT_TIMEOUT
1436
1437 await client.close()
1438
1439 # explicitly passing the default timeout currently results in it being ignored
1440 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1441 client = AsyncOpenAI(
1442 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1443 )
1444
1445 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1446 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1447 assert timeout == DEFAULT_TIMEOUT # our default
1448
1449 await client.close()
1450
1451 def test_invalid_http_client(self) -> None:
1452 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1453 with httpx.Client() as http_client:
1454 AsyncOpenAI(
1455 base_url=base_url,
1456 api_key=api_key,
1457 _strict_response_validation=True,
1458 http_client=cast(Any, http_client),
1459 )
1460
1461 async def test_default_headers_option(self) -> None:
1462 test_client = AsyncOpenAI(
1463 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1464 )
1465 request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
1466 assert request.headers.get("x-foo") == "bar"
1467 assert request.headers.get("x-stainless-lang") == "python"
1468
1469 test_client2 = AsyncOpenAI(
1470 base_url=base_url,
1471 api_key=api_key,
1472 _strict_response_validation=True,
1473 default_headers={
1474 "X-Foo": "stainless",
1475 "X-Stainless-Lang": "my-overriding-header",
1476 },
1477 )
1478 request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1479 assert request.headers.get("x-foo") == "stainless"
1480 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1481
1482 await test_client.close()
1483 await test_client2.close()
1484
1485 async def test_validate_headers(self) -> None:
1486 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1487 options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
1488 request = client._build_request(options)
1489 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1490
1491 with pytest.raises(OpenAIError):
1492 with update_env(**{"OPENAI_API_KEY": Omit()}):
1493 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1494 _ = client2
1495
1496 async def test_default_query_option(self) -> None:
1497 client = AsyncOpenAI(
1498 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1499 )
1500 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1501 url = httpx.URL(request.url)
1502 assert dict(url.params) == {"query_param": "bar"}
1503
1504 request = client._build_request(
1505 FinalRequestOptions(
1506 method="get",
1507 url="/foo",
1508 params={"foo": "baz", "query_param": "overridden"},
1509 )
1510 )
1511 url = httpx.URL(request.url)
1512 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1513
1514 await client.close()
1515
1516 async def test_hardcoded_query_params_in_url(self, async_client: AsyncOpenAI) -> None:
1517 request = async_client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true"))
1518 url = httpx.URL(request.url)
1519 assert dict(url.params) == {"beta": "true"}
1520
1521 request = async_client._build_request(
1522 FinalRequestOptions(
1523 method="get",
1524 url="/foo?beta=true",
1525 params={"limit": "10", "page": "abc"},
1526 )
1527 )
1528 url = httpx.URL(request.url)
1529 assert dict(url.params) == {"beta": "true", "limit": "10", "page": "abc"}
1530
1531 request = async_client._build_request(
1532 FinalRequestOptions(
1533 method="get",
1534 url="/files/a%2Fb?beta=true",
1535 params={"limit": "10"},
1536 )
1537 )
1538 assert request.url.raw_path == b"/files/a%2Fb?beta=true&limit=10"
1539
1540 def test_request_extra_json(self, client: OpenAI) -> None:
1541 request = client._build_request(
1542 FinalRequestOptions(
1543 method="post",
1544 url="/foo",
1545 json_data={"foo": "bar"},
1546 extra_json={"baz": False},
1547 ),
1548 )
1549 data = json.loads(request.content.decode("utf-8"))
1550 assert data == {"foo": "bar", "baz": False}
1551
1552 request = client._build_request(
1553 FinalRequestOptions(
1554 method="post",
1555 url="/foo",
1556 extra_json={"baz": False},
1557 ),
1558 )
1559 data = json.loads(request.content.decode("utf-8"))
1560 assert data == {"baz": False}
1561
1562 # `extra_json` takes priority over `json_data` when keys clash
1563 request = client._build_request(
1564 FinalRequestOptions(
1565 method="post",
1566 url="/foo",
1567 json_data={"foo": "bar", "baz": True},
1568 extra_json={"baz": None},
1569 ),
1570 )
1571 data = json.loads(request.content.decode("utf-8"))
1572 assert data == {"foo": "bar", "baz": None}
1573
1574 def test_request_extra_headers(self, client: OpenAI) -> None:
1575 request = client._build_request(
1576 FinalRequestOptions(
1577 method="post",
1578 url="/foo",
1579 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1580 ),
1581 )
1582 assert request.headers.get("X-Foo") == "Foo"
1583
1584 # `extra_headers` takes priority over `default_headers` when keys clash
1585 request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
1586 FinalRequestOptions(
1587 method="post",
1588 url="/foo",
1589 **make_request_options(
1590 extra_headers={"X-Bar": "false"},
1591 ),
1592 ),
1593 )
1594 assert request.headers.get("X-Bar") == "false"
1595
1596 def test_request_extra_query(self, client: OpenAI) -> None:
1597 request = client._build_request(
1598 FinalRequestOptions(
1599 method="post",
1600 url="/foo",
1601 **make_request_options(
1602 extra_query={"my_query_param": "Foo"},
1603 ),
1604 ),
1605 )
1606 params = dict(request.url.params)
1607 assert params == {"my_query_param": "Foo"}
1608
1609 # if both `query` and `extra_query` are given, they are merged
1610 request = client._build_request(
1611 FinalRequestOptions(
1612 method="post",
1613 url="/foo",
1614 **make_request_options(
1615 query={"bar": "1"},
1616 extra_query={"foo": "2"},
1617 ),
1618 ),
1619 )
1620 params = dict(request.url.params)
1621 assert params == {"bar": "1", "foo": "2"}
1622
1623 # `extra_query` takes priority over `query` when keys clash
1624 request = client._build_request(
1625 FinalRequestOptions(
1626 method="post",
1627 url="/foo",
1628 **make_request_options(
1629 query={"foo": "1"},
1630 extra_query={"foo": "2"},
1631 ),
1632 ),
1633 )
1634 params = dict(request.url.params)
1635 assert params == {"foo": "2"}
1636
1637 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1638 request = async_client._build_request(
1639 FinalRequestOptions.construct(
1640 method="post",
1641 url="/foo",
1642 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1643 json_data={"array": ["foo", "bar"]},
1644 files=[("foo.txt", b"hello world")],
1645 )
1646 )
1647
1648 assert request.read().split(b"\r\n") == [
1649 b"--6b7ba517decee4a450543ea6ae821c82",
1650 b'Content-Disposition: form-data; name="array[]"',
1651 b"",
1652 b"foo",
1653 b"--6b7ba517decee4a450543ea6ae821c82",
1654 b'Content-Disposition: form-data; name="array[]"',
1655 b"",
1656 b"bar",
1657 b"--6b7ba517decee4a450543ea6ae821c82",
1658 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
1659 b"Content-Type: application/octet-stream",
1660 b"",
1661 b"hello world",
1662 b"--6b7ba517decee4a450543ea6ae821c82--",
1663 b"",
1664 ]
1665
1666 @pytest.mark.respx(base_url=base_url)
1667 async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1668 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
1669
1670 file_content = b"Hello, this is a test file."
1671
1672 response = await async_client.post(
1673 "/upload",
1674 content=file_content,
1675 cast_to=httpx.Response,
1676 options={"headers": {"Content-Type": "application/octet-stream"}},
1677 )
1678
1679 assert response.status_code == 200
1680 assert response.request.headers["Content-Type"] == "application/octet-stream"
1681 assert response.content == file_content
1682
1683 async def test_binary_content_upload_with_asynciterator(self) -> None:
1684 file_content = b"Hello, this is a test file."
1685 counter = Counter()
1686 iterator = _make_async_iterator([file_content], counter=counter)
1687
1688 async def mock_handler(request: httpx.Request) -> httpx.Response:
1689 assert counter.value == 0, "the request body should not have been read"
1690 return httpx.Response(200, content=await request.aread())
1691
1692 async with AsyncOpenAI(
1693 base_url=base_url,
1694 api_key=api_key,
1695 _strict_response_validation=True,
1696 http_client=httpx.AsyncClient(transport=MockTransport(handler=mock_handler)),
1697 ) as client:
1698 response = await client.post(
1699 "/upload",
1700 content=iterator,
1701 cast_to=httpx.Response,
1702 options={"headers": {"Content-Type": "application/octet-stream"}},
1703 )
1704
1705 assert response.status_code == 200
1706 assert response.request.headers["Content-Type"] == "application/octet-stream"
1707 assert response.content == file_content
1708 assert counter.value == 1
1709
1710 @pytest.mark.respx(base_url=base_url)
1711 async def test_binary_content_upload_with_body_is_deprecated(
1712 self, respx_mock: MockRouter, async_client: AsyncOpenAI
1713 ) -> None:
1714 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
1715
1716 file_content = b"Hello, this is a test file."
1717
1718 with pytest.deprecated_call(
1719 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
1720 ):
1721 response = await async_client.post(
1722 "/upload",
1723 body=file_content,
1724 cast_to=httpx.Response,
1725 options={"headers": {"Content-Type": "application/octet-stream"}},
1726 )
1727
1728 assert response.status_code == 200
1729 assert response.request.headers["Content-Type"] == "application/octet-stream"
1730 assert response.content == file_content
1731
1732 @pytest.mark.respx(base_url=base_url)
1733 async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1734 class Model1(BaseModel):
1735 name: str
1736
1737 class Model2(BaseModel):
1738 foo: str
1739
1740 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1741
1742 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1743 assert isinstance(response, Model2)
1744 assert response.foo == "bar"
1745
1746 @pytest.mark.respx(base_url=base_url)
1747 async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1748 """Union of objects with the same field name using a different type"""
1749
1750 class Model1(BaseModel):
1751 foo: int
1752
1753 class Model2(BaseModel):
1754 foo: str
1755
1756 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1757
1758 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1759 assert isinstance(response, Model2)
1760 assert response.foo == "bar"
1761
1762 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1763
1764 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1765 assert isinstance(response, Model1)
1766 assert response.foo == 1
1767
1768 @pytest.mark.respx(base_url=base_url)
1769 async def test_non_application_json_content_type_for_json_data(
1770 self, respx_mock: MockRouter, async_client: AsyncOpenAI
1771 ) -> None:
1772 """
1773 Response that sets Content-Type to something other than application/json but returns json data
1774 """
1775
1776 class Model(BaseModel):
1777 foo: int
1778
1779 respx_mock.get("/foo").mock(
1780 return_value=httpx.Response(
1781 200,
1782 content=json.dumps({"foo": 2}),
1783 headers={"Content-Type": "application/text"},
1784 )
1785 )
1786
1787 response = await async_client.get("/foo", cast_to=Model)
1788 assert isinstance(response, Model)
1789 assert response.foo == 2
1790
1791 async def test_base_url_setter(self) -> None:
1792 client = AsyncOpenAI(
1793 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1794 )
1795 assert client.base_url == "https://example.com/from_init/"
1796
1797 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1798
1799 assert client.base_url == "https://example.com/from_setter/"
1800
1801 await client.close()
1802
1803 async def test_base_url_env(self) -> None:
1804 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1805 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1806 assert client.base_url == "http://localhost:5000/from/env/"
1807
1808 @pytest.mark.parametrize(
1809 "client",
1810 [
1811 AsyncOpenAI(
1812 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1813 ),
1814 AsyncOpenAI(
1815 base_url="http://localhost:5000/custom/path/",
1816 api_key=api_key,
1817 _strict_response_validation=True,
1818 http_client=httpx.AsyncClient(),
1819 ),
1820 ],
1821 ids=["standard", "custom http client"],
1822 )
1823 async def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1824 request = client._build_request(
1825 FinalRequestOptions(
1826 method="post",
1827 url="/foo",
1828 json_data={"foo": "bar"},
1829 ),
1830 )
1831 assert request.url == "http://localhost:5000/custom/path/foo"
1832 await client.close()
1833
1834 @pytest.mark.parametrize(
1835 "client",
1836 [
1837 AsyncOpenAI(
1838 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1839 ),
1840 AsyncOpenAI(
1841 base_url="http://localhost:5000/custom/path/",
1842 api_key=api_key,
1843 _strict_response_validation=True,
1844 http_client=httpx.AsyncClient(),
1845 ),
1846 ],
1847 ids=["standard", "custom http client"],
1848 )
1849 async def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1850 request = client._build_request(
1851 FinalRequestOptions(
1852 method="post",
1853 url="/foo",
1854 json_data={"foo": "bar"},
1855 ),
1856 )
1857 assert request.url == "http://localhost:5000/custom/path/foo"
1858 await client.close()
1859
1860 @pytest.mark.parametrize(
1861 "client",
1862 [
1863 AsyncOpenAI(
1864 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1865 ),
1866 AsyncOpenAI(
1867 base_url="http://localhost:5000/custom/path/",
1868 api_key=api_key,
1869 _strict_response_validation=True,
1870 http_client=httpx.AsyncClient(),
1871 ),
1872 ],
1873 ids=["standard", "custom http client"],
1874 )
1875 async def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1876 request = client._build_request(
1877 FinalRequestOptions(
1878 method="post",
1879 url="https://myapi.com/foo",
1880 json_data={"foo": "bar"},
1881 ),
1882 )
1883 assert request.url == "https://myapi.com/foo"
1884 await client.close()
1885
1886 async def test_copied_client_does_not_close_http(self) -> None:
1887 test_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1888 assert not test_client.is_closed()
1889
1890 copied = test_client.copy()
1891 assert copied is not test_client
1892
1893 del copied
1894
1895 await asyncio.sleep(0.2)
1896 assert not test_client.is_closed()
1897
1898 async def test_client_context_manager(self) -> None:
1899 test_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1900 async with test_client as c2:
1901 assert c2 is test_client
1902 assert not c2.is_closed()
1903 assert not test_client.is_closed()
1904 assert test_client.is_closed()
1905
1906 @pytest.mark.respx(base_url=base_url)
1907 async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1908 class Model(BaseModel):
1909 foo: str
1910
1911 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1912
1913 with pytest.raises(APIResponseValidationError) as exc:
1914 await async_client.get("/foo", cast_to=Model)
1915
1916 assert isinstance(exc.value.__cause__, ValidationError)
1917
1918 async def test_client_max_retries_validation(self) -> None:
1919 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1920 AsyncOpenAI(
1921 base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
1922 )
1923
1924 @pytest.mark.respx(base_url=base_url)
1925 async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1926 class Model(BaseModel):
1927 name: str
1928
1929 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1930
1931 stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1932 assert isinstance(stream, AsyncStream)
1933 await stream.response.aclose()
1934
1935 @pytest.mark.respx(base_url=base_url)
1936 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1937 class Model(BaseModel):
1938 name: str
1939
1940 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1941
1942 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1943
1944 with pytest.raises(APIResponseValidationError):
1945 await strict_client.get("/foo", cast_to=Model)
1946
1947 non_strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1948
1949 response = await non_strict_client.get("/foo", cast_to=Model)
1950 assert isinstance(response, str) # type: ignore[unreachable]
1951
1952 await strict_client.close()
1953 await non_strict_client.close()
1954
1955 @pytest.mark.parametrize(
1956 "remaining_retries,retry_after,timeout",
1957 [
1958 [3, "20", 20],
1959 [3, "0", 0.5],
1960 [3, "-10", 0.5],
1961 [3, "60", 60],
1962 [3, "61", 0.5],
1963 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1964 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1965 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1966 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1967 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1968 [3, "99999999999999999999999999999999999", 0.5],
1969 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1970 [3, "", 0.5],
1971 [2, "", 0.5 * 2.0],
1972 [1, "", 0.5 * 4.0],
1973 [-1100, "", 8], # test large number potentially overflowing
1974 ],
1975 )
1976 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1977 async def test_parse_retry_after_header(
1978 self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncOpenAI
1979 ) -> None:
1980 headers = httpx.Headers({"retry-after": retry_after})
1981 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1982 calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
1983 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1984
1985 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1986 @pytest.mark.respx(base_url=base_url)
1987 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1988 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1989
1990 with pytest.raises(APITimeoutError):
1991 await async_client.chat.completions.with_streaming_response.create(
1992 messages=[
1993 {
1994 "content": "string",
1995 "role": "developer",
1996 }
1997 ],
1998 model="gpt-5.4",
1999 ).__aenter__()
2000
2001 assert _get_open_connections(async_client) == 0
2002
2003 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2004 @pytest.mark.respx(base_url=base_url)
2005 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2006 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
2007
2008 with pytest.raises(APIStatusError):
2009 await async_client.chat.completions.with_streaming_response.create(
2010 messages=[
2011 {
2012 "content": "string",
2013 "role": "developer",
2014 }
2015 ],
2016 model="gpt-5.4",
2017 ).__aenter__()
2018 assert _get_open_connections(async_client) == 0
2019
2020 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2021 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2022 @pytest.mark.respx(base_url=base_url)
2023 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
2024 async def test_retries_taken(
2025 self,
2026 async_client: AsyncOpenAI,
2027 failures_before_success: int,
2028 failure_mode: Literal["status", "exception"],
2029 respx_mock: MockRouter,
2030 ) -> None:
2031 client = async_client.with_options(max_retries=4)
2032
2033 nb_retries = 0
2034
2035 def retry_handler(_request: httpx.Request) -> httpx.Response:
2036 nonlocal nb_retries
2037 if nb_retries < failures_before_success:
2038 nb_retries += 1
2039 if failure_mode == "exception":
2040 raise RuntimeError("oops")
2041 return httpx.Response(500)
2042 return httpx.Response(200)
2043
2044 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2045
2046 response = await client.chat.completions.with_raw_response.create(
2047 messages=[
2048 {
2049 "content": "string",
2050 "role": "developer",
2051 }
2052 ],
2053 model="gpt-5.4",
2054 )
2055
2056 assert response.retries_taken == failures_before_success
2057 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
2058
2059 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2060 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2061 @pytest.mark.respx(base_url=base_url)
2062 async def test_omit_retry_count_header(
2063 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2064 ) -> None:
2065 client = async_client.with_options(max_retries=4)
2066
2067 nb_retries = 0
2068
2069 def retry_handler(_request: httpx.Request) -> httpx.Response:
2070 nonlocal nb_retries
2071 if nb_retries < failures_before_success:
2072 nb_retries += 1
2073 return httpx.Response(500)
2074 return httpx.Response(200)
2075
2076 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2077
2078 response = await client.chat.completions.with_raw_response.create(
2079 messages=[
2080 {
2081 "content": "string",
2082 "role": "developer",
2083 }
2084 ],
2085 model="gpt-5.4",
2086 extra_headers={"x-stainless-retry-count": Omit()},
2087 )
2088
2089 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
2090
2091 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2092 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2093 @pytest.mark.respx(base_url=base_url)
2094 async def test_overwrite_retry_count_header(
2095 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2096 ) -> None:
2097 client = async_client.with_options(max_retries=4)
2098
2099 nb_retries = 0
2100
2101 def retry_handler(_request: httpx.Request) -> httpx.Response:
2102 nonlocal nb_retries
2103 if nb_retries < failures_before_success:
2104 nb_retries += 1
2105 return httpx.Response(500)
2106 return httpx.Response(200)
2107
2108 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2109
2110 response = await client.chat.completions.with_raw_response.create(
2111 messages=[
2112 {
2113 "content": "string",
2114 "role": "developer",
2115 }
2116 ],
2117 model="gpt-5.4",
2118 extra_headers={"x-stainless-retry-count": "42"},
2119 )
2120
2121 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
2122
2123 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2124 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2125 @pytest.mark.respx(base_url=base_url)
2126 async def test_retries_taken_new_response_class(
2127 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2128 ) -> None:
2129 client = async_client.with_options(max_retries=4)
2130
2131 nb_retries = 0
2132
2133 def retry_handler(_request: httpx.Request) -> httpx.Response:
2134 nonlocal nb_retries
2135 if nb_retries < failures_before_success:
2136 nb_retries += 1
2137 return httpx.Response(500)
2138 return httpx.Response(200)
2139
2140 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2141
2142 async with client.chat.completions.with_streaming_response.create(
2143 messages=[
2144 {
2145 "content": "string",
2146 "role": "developer",
2147 }
2148 ],
2149 model="gpt-5.4",
2150 ) as response:
2151 assert response.retries_taken == failures_before_success
2152 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
2153
2154 async def test_get_platform(self) -> None:
2155 platform = await asyncify(get_platform)()
2156 assert isinstance(platform, (str, OtherPlatform))
2157
2158 async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
2159 # Test that the proxy environment variables are set correctly
2160 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
2161 # Delete in case our environment has any proxy env vars set
2162 monkeypatch.delenv("HTTP_PROXY", raising=False)
2163 monkeypatch.delenv("ALL_PROXY", raising=False)
2164 monkeypatch.delenv("NO_PROXY", raising=False)
2165 monkeypatch.delenv("http_proxy", raising=False)
2166 monkeypatch.delenv("https_proxy", raising=False)
2167 monkeypatch.delenv("all_proxy", raising=False)
2168 monkeypatch.delenv("no_proxy", raising=False)
2169
2170 client = DefaultAsyncHttpxClient()
2171
2172 mounts = tuple(client._mounts.items())
2173 assert len(mounts) == 1
2174 assert mounts[0][0].pattern == "https://"
2175
2176 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
2177 async def test_default_client_creation(self) -> None:
2178 # Ensure that the client can be initialized without any exceptions
2179 DefaultAsyncHttpxClient(
2180 verify=True,
2181 cert=None,
2182 trust_env=True,
2183 http1=True,
2184 http2=False,
2185 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
2186 )
2187
2188 @pytest.mark.respx(base_url=base_url)
2189 async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2190 # Test that the default follow_redirects=True allows following redirects
2191 respx_mock.post("/redirect").mock(
2192 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2193 )
2194 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
2195
2196 response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
2197 assert response.status_code == 200
2198 assert response.json() == {"status": "ok"}
2199
2200 @pytest.mark.respx(base_url=base_url)
2201 async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2202 # Test that follow_redirects=False prevents following redirects
2203 respx_mock.post("/redirect").mock(
2204 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2205 )
2206
2207 with pytest.raises(APIStatusError) as exc_info:
2208 await async_client.post(
2209 "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
2210 )
2211
2212 assert exc_info.value.response.status_code == 302
2213 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
2214
2215 async def test_api_key_before_after_refresh_provider(self) -> None:
2216 async def mock_api_key_provider():
2217 return "test_bearer_token"
2218
2219 client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
2220
2221 assert client.api_key == ""
2222 assert "Authorization" not in client.auth_headers
2223
2224 await client._refresh_api_key()
2225
2226 assert client.api_key == "test_bearer_token"
2227 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
2228
2229 async def test_api_key_before_after_refresh_str(self) -> None:
2230 client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")
2231
2232 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2233 await client._refresh_api_key()
2234
2235 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2236
2237 @pytest.mark.respx()
2238 async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
2239 respx_mock.post(base_url + "/chat/completions").mock(
2240 side_effect=[
2241 httpx.Response(500, json={"error": "server error"}),
2242 httpx.Response(200, json={"foo": "bar"}),
2243 ]
2244 )
2245
2246 counter = 0
2247
2248 async def token_provider() -> str:
2249 nonlocal counter
2250
2251 counter += 1
2252
2253 if counter == 1:
2254 return "first"
2255
2256 return "second"
2257
2258 client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
2259 await client.chat.completions.create(messages=[], model="gpt-4")
2260
2261 calls = cast("list[MockRequestCall]", respx_mock.calls)
2262 assert len(calls) == 2
2263
2264 assert calls[0].request.headers.get("Authorization") == "Bearer first"
2265 assert calls[1].request.headers.get("Authorization") == "Bearer second"
2266
2267 async def test_copy_auth(self) -> None:
2268 async def token_provider_1() -> str:
2269 return "test_bearer_token_1"
2270
2271 async def token_provider_2() -> str:
2272 return "test_bearer_token_2"
2273
2274 client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
2275 await client._refresh_api_key()
2276 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
2277
2278
class TestWorkloadIdentity401Retry:
    """Sync client: how a 401 is handled when authenticating through workload identity."""

    _TOKEN_URL = "https://auth.openai.com/oauth/token"
    _CHAT_URL = base_url + "/chat/completions"

    @staticmethod
    def _token_exchange_response(access_token: str) -> httpx.Response:
        """A successful token-endpoint response issuing *access_token* with ``expires_in`` of 3600."""
        return httpx.Response(
            200,
            json={
                "access_token": access_token,
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )

    @staticmethod
    def _error_response(status_code: int, message: str) -> httpx.Response:
        """An API error response with the given status and an ``invalid_request_error`` type."""
        return httpx.Response(status_code, json={"error": {"message": message, "type": "invalid_request_error"}})

    @staticmethod
    def _identity_with(get_token: Callable[[], str]) -> WorkloadIdentity:
        """The module-level ``workload_identity`` with its provider replaced by a JWT *get_token* callback."""
        return {
            **workload_identity,
            "provider": {
                "get_token": get_token,
                "token_type": "jwt",
            },
        }

    @pytest.mark.respx()
    def test_workload_identity_401_retry(self, respx_mock: MockRouter) -> None:
        """A 401 triggers a second token exchange and a single replay carrying the new access token."""
        token_requests = 0

        def get_subject_token() -> str:
            nonlocal token_requests
            token_requests += 1
            return f"external-subject-token-{token_requests}"

        respx_mock.post(self._TOKEN_URL).mock(
            side_effect=[
                self._token_exchange_response("openai-access-token-1"),
                self._token_exchange_response("openai-access-token-2"),
            ]
        )
        respx_mock.post(self._CHAT_URL).mock(
            side_effect=[
                self._error_response(401, "Unauthorized"),
                httpx.Response(
                    200,
                    json={
                        "id": "chatcmpl-123",
                        "object": "chat.completion",
                        "created": 1234567890,
                        "model": "gpt-4",
                        "choices": [],
                    },
                ),
            ]
        )

        with OpenAI(
            base_url=base_url,
            workload_identity=self._identity_with(get_subject_token),
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            client.chat.completions.create(messages=[], model="gpt-4")

        calls = cast("list[MockRequestCall]", respx_mock.calls)
        # Expected order: token exchange, rejected request, second token exchange, successful replay.
        assert [call.request.url for call in calls] == [
            httpx.URL(self._TOKEN_URL),
            httpx.URL(self._CHAT_URL),
            httpx.URL(self._TOKEN_URL),
            httpx.URL(self._CHAT_URL),
        ]
        assert calls[1].request.headers.get("Authorization") == "Bearer openai-access-token-1"
        assert calls[3].request.headers.get("Authorization") == "Bearer openai-access-token-2"

        assert token_requests == 2

    @pytest.mark.respx()
    def test_401_without_workload_identity_no_retry(self, respx_mock: MockRouter) -> None:
        """With a plain API key, a 401 surfaces immediately as ``APIStatusError`` after a single request."""
        respx_mock.post(self._CHAT_URL).mock(return_value=self._error_response(401, "Unauthorized"))

        with OpenAI(base_url=base_url, api_key="test-api-key", _strict_response_validation=True) as client:
            with pytest.raises(APIStatusError) as exc_info:
                client.chat.completions.create(messages=[], model="gpt-4")

        assert exc_info.value.status_code == 401
        assert len(cast("list[MockRequestCall]", respx_mock.calls)) == 1

    @pytest.mark.respx()
    def test_non_401_errors_no_retry(self, respx_mock: MockRouter) -> None:
        """A non-401 error (403 here) is neither retried nor followed by another token exchange."""
        token_requests = 0

        def get_subject_token() -> str:
            nonlocal token_requests
            token_requests += 1
            return "external-subject-token"

        respx_mock.post(self._TOKEN_URL).mock(return_value=self._token_exchange_response("openai-access-token-1"))
        respx_mock.post(self._CHAT_URL).mock(return_value=self._error_response(403, "Forbidden"))

        with OpenAI(
            base_url=base_url,
            workload_identity=self._identity_with(get_subject_token),
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as exc_info:
                client.chat.completions.create(messages=[], model="gpt-4")

        assert exc_info.value.status_code == 403

        # One token exchange plus the single rejected API request.
        assert len(cast("list[MockRequestCall]", respx_mock.calls)) == 2
        assert token_requests == 1
2426
class TestAsyncWorkloadIdentity401Retry:
    """Async client: how a 401 is handled when authenticating through workload identity."""

    _TOKEN_URL = "https://auth.openai.com/oauth/token"
    _CHAT_URL = base_url + "/chat/completions"

    @staticmethod
    def _token_exchange_response(access_token: str) -> httpx.Response:
        """A successful token-endpoint response issuing *access_token* with ``expires_in`` of 3600."""
        return httpx.Response(
            200,
            json={
                "access_token": access_token,
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )

    @staticmethod
    def _error_response(status_code: int, message: str) -> httpx.Response:
        """An API error response with the given status and an ``invalid_request_error`` type."""
        return httpx.Response(status_code, json={"error": {"message": message, "type": "invalid_request_error"}})

    @staticmethod
    def _identity_with(get_token: Callable[[], str]) -> WorkloadIdentity:
        """The module-level ``workload_identity`` with its provider replaced by a JWT *get_token* callback."""
        return {
            **workload_identity,
            "provider": {
                "get_token": get_token,
                "token_type": "jwt",
            },
        }

    @pytest.mark.respx()
    async def test_workload_identity_401_retry(self, respx_mock: MockRouter) -> None:
        """A 401 triggers a second token exchange and a single replay carrying the new access token."""
        token_requests = 0

        def get_subject_token() -> str:
            nonlocal token_requests
            token_requests += 1
            return f"external-subject-token-{token_requests}"

        respx_mock.post(self._TOKEN_URL).mock(
            side_effect=[
                self._token_exchange_response("openai-access-token-1"),
                self._token_exchange_response("openai-access-token-2"),
            ]
        )
        respx_mock.post(self._CHAT_URL).mock(
            side_effect=[
                self._error_response(401, "Unauthorized"),
                httpx.Response(
                    200,
                    json={
                        "id": "chatcmpl-123",
                        "object": "chat.completion",
                        "created": 1234567890,
                        "model": "gpt-4",
                        "choices": [],
                    },
                ),
            ]
        )

        async with AsyncOpenAI(
            base_url=base_url,
            workload_identity=self._identity_with(get_subject_token),
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            await client.chat.completions.create(messages=[], model="gpt-4")

        calls = cast("list[MockRequestCall]", respx_mock.calls)
        # Expected order: token exchange, rejected request, second token exchange, successful replay.
        assert [call.request.url for call in calls] == [
            httpx.URL(self._TOKEN_URL),
            httpx.URL(self._CHAT_URL),
            httpx.URL(self._TOKEN_URL),
            httpx.URL(self._CHAT_URL),
        ]
        assert calls[1].request.headers.get("Authorization") == "Bearer openai-access-token-1"
        assert calls[3].request.headers.get("Authorization") == "Bearer openai-access-token-2"

        assert token_requests == 2

    @pytest.mark.respx()
    async def test_401_without_workload_identity_no_retry(self, respx_mock: MockRouter) -> None:
        """With a plain API key, a 401 surfaces immediately as ``APIStatusError`` after a single request."""
        respx_mock.post(self._CHAT_URL).mock(return_value=self._error_response(401, "Unauthorized"))

        async with AsyncOpenAI(base_url=base_url, api_key="test-api-key", _strict_response_validation=True) as client:
            with pytest.raises(APIStatusError) as exc_info:
                await client.chat.completions.create(messages=[], model="gpt-4")

        assert exc_info.value.status_code == 401
        assert len(cast("list[MockRequestCall]", respx_mock.calls)) == 1

    @pytest.mark.respx()
    async def test_non_401_errors_no_retry(self, respx_mock: MockRouter) -> None:
        """A non-401 error (403 here) is neither retried nor followed by another token exchange."""
        token_requests = 0

        def get_subject_token() -> str:
            nonlocal token_requests
            token_requests += 1
            return "external-subject-token"

        respx_mock.post(self._TOKEN_URL).mock(return_value=self._token_exchange_response("openai-access-token-1"))
        respx_mock.post(self._CHAT_URL).mock(return_value=self._error_response(403, "Forbidden"))

        async with AsyncOpenAI(
            base_url=base_url,
            workload_identity=self._identity_with(get_subject_token),
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as exc_info:
                await client.chat.completions.create(messages=[], model="gpt-4")

        assert exc_info.value.status_code == 403

        # One token exchange plus the single rejected API request.
        assert len(cast("list[MockRequestCall]", respx_mock.calls)) == 2
        assert token_requests == 1