openai/openai-python

Public

mirrored from https://github.com/openai/openai-pythonAvailable

Code · Commits · Issues · Pull requests · Actions · Insights · Security
daaf92f3fdc127defbb90c9401a785c097d62890

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

2800 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import asyncio
10import inspect
11import dataclasses
12import tracemalloc
13from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Coroutine, cast
14from unittest import mock
15from typing_extensions import Literal, AsyncIterator, override
16
17import httpx
18import pytest
19from respx import MockRouter
20from pydantic import ValidationError
21from respx.models import Call as MockRequestCall
22
23from openai import OpenAI, AsyncOpenAI, OpenAIError, APIResponseValidationError
24from openai.auth import WorkloadIdentity
25from openai._types import Omit
26from openai._utils import asyncify
27from openai._models import BaseModel, FinalRequestOptions
28from openai._streaming import Stream, AsyncStream
29from openai._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError
30from openai._base_client import (
31 DEFAULT_TIMEOUT,
32 HTTPX_DEFAULT_TIMEOUT,
33 BaseClient,
34 OtherPlatform,
35 DefaultHttpxClient,
36 DefaultAsyncHttpxClient,
37 get_platform,
38 make_request_options,
39)
40
41from .utils import update_env
42
# Generic item type used by the iterator helpers in this module.
T = TypeVar("T")

# Target server for the tests; overridable through TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")

# Placeholder credentials handed to every client built in this module.
api_key = "My API Key"
admin_api_key = "My Admin API Key"

# Workload-identity configuration used by the mutual-exclusion test.
workload_identity: WorkloadIdentity = {
    "client_id": "client_123",
    "identity_provider_id": "provider_123",
    "service_account_id": "service_account_123",
    "provider": {
        "get_token": lambda: "external-subject-token",
        "token_type": "jwt",
    },
}
56
57
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters of a GET request to ``/foo`` built by ``client``."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
62
63
def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
    """Return a fixed 0.1 delay, ignoring every positional and keyword argument."""
    fixed_delay = 0.1
    return fixed_delay
66
67
def mirror_request_content(request: httpx.Request) -> httpx.Response:
    """Build a 200 response whose body is the content of ``request``."""
    echoed = request.content
    return httpx.Response(200, content=echoed)
70
71
# note: httpx.MockTransport is not usable here because it consumes the request
# body itself, which would make it impossible to test that the body is read lazily
class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    """Sync/async transport that hands every request to a caller-supplied handler."""

    def __init__(
        self,
        handler: Callable[[httpx.Request], httpx.Response]
        | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]],
    ) -> None:
        # Either a plain function (sync use) or a coroutine function (async use).
        self.handler = handler

    @override
    def handle_request(self, request: httpx.Request) -> httpx.Response:
        """Call the handler synchronously; it must be a plain, non-coroutine function."""
        handler = self.handler
        assert not inspect.iscoroutinefunction(handler), "handler must not be a coroutine function"
        assert inspect.isfunction(handler), "handler must be a function"
        return handler(request)

    @override
    async def handle_async_request(self, request: httpx.Request) -> httpx.Response:
        """Await the handler; it must be a coroutine function."""
        handler = self.handler
        assert inspect.iscoroutinefunction(handler), "handler must be a coroutine function"
        return await handler(request)
98
99
@dataclasses.dataclass
class Counter:
    """Mutable integer tally, starting at zero unless a value is given."""

    # Incremented by the iterator helpers each time an item is produced.
    value: int = 0
103
104
def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]:
    """Lazily yield each item of ``iterable``.

    When a truthy ``counter`` is supplied, its ``value`` is bumped by one
    immediately before each item is yielded.
    """
    for element in iterable:
        if counter:
            counter.value += 1
        yield element
110
111
async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]:
    """Asynchronously yield each item of ``iterable``.

    When a truthy ``counter`` is supplied, its ``value`` is bumped by one
    immediately before each item is yielded.
    """
    for element in iterable:
        if counter:
            counter.value += 1
        yield element
117
118
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many entries the client's transport pool holds in ``_requests``."""
    transport = client._client._transport
    # Only the stock httpx transports are expected here; both expose ``_pool``.
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
125
126
127class TestOpenAI:
@pytest.mark.respx(base_url=base_url)
def test_raw_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """``cast_to=httpx.Response`` returns the httpx response itself."""
    canned = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=canned)

    raw = client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
136
@pytest.mark.respx(base_url=base_url)
def test_raw_response_for_binary(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """A raw response served as ``application/binary`` still exposes its JSON body."""
    canned = httpx.Response(
        200,
        headers={"Content-Type": "application/binary"},
        content='{"foo": "bar"}',
    )
    respx_mock.post("/foo").mock(return_value=canned)

    raw = client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
147
def test_copy(self, client: OpenAI) -> None:
    """``copy()`` returns a new object and applies key overrides without touching the source client."""
    clone = client.copy()
    assert id(clone) != id(client)

    # overriding the API key affects only the copy
    clone = client.copy(api_key="another My API Key")
    assert clone.api_key == "another My API Key"
    assert client.api_key == "My API Key"

    # the same holds for the admin API key
    clone = client.copy(admin_api_key="another My Admin API Key")
    assert clone.admin_api_key == "another My Admin API Key"
    assert client.admin_api_key == "My Admin API Key"
159
def test_copy_default_options(self, client: OpenAI) -> None:
    """Options that have defaults (``max_retries``, ``timeout``) can be overridden on a copy."""
    # max_retries: each copy carries its own value, the source keeps its own
    clone = client.copy(max_retries=7)
    assert clone.max_retries == 7
    assert client.max_retries == 2

    clone_of_clone = clone.copy(max_retries=6)
    assert clone_of_clone.max_retries == 6
    assert clone.max_retries == 7

    # timeout: ``None`` on the copy leaves the source's httpx.Timeout in place
    assert isinstance(client.timeout, httpx.Timeout)
    clone = client.copy(timeout=None)
    assert clone.timeout is None
    assert isinstance(client.timeout, httpx.Timeout)
175
def test_copy_default_headers(self) -> None:
    """``copy()`` merges ``default_headers`` and replaces them wholesale via ``set_default_headers``."""
    source = OpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    )
    assert source.default_headers["X-Foo"] == "bar"

    # with no header arguments the existing value carries over
    clone = source.copy()
    assert clone.default_headers["X-Foo"] == "bar"

    # new headers are merged with the ones already present
    clone = source.copy(default_headers={"X-Bar": "stainless"})
    assert clone.default_headers["X-Foo"] == "bar"
    assert clone.default_headers["X-Bar"] == "stainless"

    # a header given again takes the new value
    clone = source.copy(default_headers={"X-Foo": "stainless"})
    assert clone.default_headers["X-Foo"] == "stainless"

    # set_default_headers: replaces the previously configured headers entirely
    clone = source.copy(set_default_headers={})
    assert clone.default_headers.get("X-Foo") is None

    clone = source.copy(set_default_headers={"X-Bar": "Robert"})
    assert clone.default_headers["X-Bar"] == "Robert"

    # the two arguments cannot be combined
    conflict_message = "`default_headers` and `set_default_headers` arguments are mutually exclusive"
    with pytest.raises(ValueError, match=conflict_message):
        source.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
    source.close()
214
def test_copy_default_query(self) -> None:
    """``copy()`` merges ``default_query`` and replaces it wholesale via ``set_default_query``."""
    source = OpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_query={"foo": "bar"},
    )
    assert _get_params(source)["foo"] == "bar"

    # with no query arguments the existing value carries over
    clone = source.copy()
    assert _get_params(clone)["foo"] == "bar"

    # new params are merged with the ones already present
    clone = source.copy(default_query={"bar": "stainless"})
    merged = _get_params(clone)
    assert merged["foo"] == "bar"
    assert merged["bar"] == "stainless"

    # a param given again takes the new value
    clone = source.copy(default_query={"foo": "stainless"})
    assert _get_params(clone)["foo"] == "stainless"

    # set_default_query: replaces the previously configured params entirely
    clone = source.copy(set_default_query={})
    assert _get_params(clone) == {}

    clone = source.copy(set_default_query={"bar": "Robert"})
    assert _get_params(clone)["bar"] == "Robert"

    # the two arguments cannot be combined
    conflict_message = "`default_query` and `set_default_query` arguments are mutually exclusive"
    with pytest.raises(ValueError, match=conflict_message):
        source.copy(set_default_query={}, default_query={"foo": "Bar"})

    source.close()
256
def test_copy_signature(self, client: OpenAI) -> None:
    """Every ``__init__`` parameter, apart from a few exclusions, must also be accepted by ``copy()``."""
    # mypy doesn't like that we access the `__init__` property.
    init_params = inspect.signature(client.__init__).parameters  # type: ignore[misc]
    copy_params = inspect.signature(client.copy).parameters
    excluded = {"transport", "proxies", "_strict_response_validation"}

    for name in init_params:
        if name not in excluded:
            assert copy_params.get(name) is not None, f"copy() signature is missing the {name} param"
272
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
def test_copy_build_request(self, client: OpenAI) -> None:
    """Copying the client and building a request must not leave per-iteration allocations behind.

    NOTE(review): the skip condition covers Python 3.10 and newer, while the
    reason text mentions 3.12 -- confirm which bound is intended.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        clone = client.copy()
        clone._build_request(options)

    # run once and collect garbage so one-off allocations happen before tracing starts
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)

    snapshot_before = tracemalloc.take_snapshot()

    ITERATIONS = 10
    for _ in range(ITERATIONS):
        build_request(options)

    gc.collect()
    snapshot_after = tracemalloc.take_snapshot()

    tracemalloc.stop()

    # Allocations whose traceback passes through one of these files are ignored.
    ignored_suffixes = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        # Only count allocations that persist (non-zero count) and that scale
        # with the loop (count is a multiple of ITERATIONS).
        if diff.count == 0 or diff.count % ITERATIONS != 0:
            return

        if any(frame.filename.endswith(ignored_suffixes) for frame in diff.traceback):
            return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)

    if not leaks:
        return

    # print every surviving allocation with its traceback, then fail
    for leak in leaks:
        print("MEMORY LEAK:", leak)
        for frame in leak.traceback:
            print(frame)
    raise AssertionError()
335
def test_request_timeout(self, client: OpenAI) -> None:
    """A built request carries ``DEFAULT_TIMEOUT`` unless the request options override it."""
    default_request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    default_timeout = httpx.Timeout(**default_request.extensions["timeout"])  # type: ignore
    assert default_timeout == DEFAULT_TIMEOUT

    custom_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    custom_request = client._build_request(custom_options)
    custom_timeout = httpx.Timeout(**custom_request.extensions["timeout"])  # type: ignore
    assert custom_timeout == httpx.Timeout(100.0)
344
def test_client_timeout_option(self) -> None:
    """A ``timeout`` given to the client constructor is applied to the requests it builds."""
    zero_timeout_client = OpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    effective = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert effective == httpx.Timeout(0)

    zero_timeout_client.close()
359
def test_http_client_timeout_option(self) -> None:
    """Check which timeout ends up on a request for three differently configured ``httpx.Client``s."""
    probe = FinalRequestOptions(method="get", url="/foo")

    # a custom timeout configured on the httpx client is used as-is
    with httpx.Client(timeout=None) as http_client:
        sdk_client = OpenAI(
            base_url=base_url,
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=http_client,
        )

        built = sdk_client._build_request(probe)
        effective = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
        assert effective == httpx.Timeout(None)

        sdk_client.close()

    # an httpx client created without a timeout gets our default, not httpx's
    with httpx.Client() as http_client:
        sdk_client = OpenAI(
            base_url=base_url,
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=http_client,
        )

        built = sdk_client._build_request(probe)
        effective = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
        assert effective == DEFAULT_TIMEOUT

        sdk_client.close()

    # explicitly passing the httpx default timeout currently results in it being ignored
    with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        sdk_client = OpenAI(
            base_url=base_url,
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=http_client,
        )

        built = sdk_client._build_request(probe)
        effective = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
        assert effective == DEFAULT_TIMEOUT  # our default

        sdk_client.close()
408
async def test_invalid_http_client(self) -> None:
    """Passing an ``httpx.AsyncClient`` to the sync client raises ``TypeError``."""
    expected_message = "Invalid `http_client` arg"
    with pytest.raises(TypeError, match=expected_message):
        async with httpx.AsyncClient() as async_http_client:
            OpenAI(
                base_url=base_url,
                api_key=api_key,
                admin_api_key=admin_api_key,
                _strict_response_validation=True,
                http_client=cast(Any, async_http_client),
            )
419
def test_default_headers_option(self) -> None:
    """``default_headers`` are sent on built requests and can override ``X-Stainless-Lang``."""
    probe = FinalRequestOptions(method="get", url="/foo")

    extra_header_client = OpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    )
    built = extra_header_client._build_request(probe)
    assert built.headers.get("x-foo") == "bar"
    assert built.headers.get("x-stainless-lang") == "python"

    overriding_client = OpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    built = overriding_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert built.headers.get("x-foo") == "stainless"
    assert built.headers.get("x-stainless-lang") == "my-overriding-header"

    extra_header_client.close()
    overriding_client.close()
448
def test_validate_headers(self) -> None:
    """Check which ``Authorization`` header is produced for each credential combination."""
    # NOTE(review): indentation was lost in this copy of the file; the extent of
    # the first ``update_env`` block below is a best guess -- confirm against the repo.
    both_keys = OpenAI(
        base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
    )
    prepared = both_keys._prepare_options(FinalRequestOptions(method="get", url="/foo"))
    regular_request = both_keys._build_request(prepared)

    assert regular_request.headers.get("Authorization") == f"Bearer {api_key}"

    # a request flagged for admin-key auth uses the admin key instead
    admin_options = FinalRequestOptions(
        method="get",
        url="/organization/projects",
        security={"admin_api_key_auth": True},
    )
    admin_request = both_keys._build_request(admin_options)
    assert admin_request.headers.get("Authorization") == f"Bearer {admin_api_key}"

    with update_env(**{"OPENAI_API_KEY": Omit()}):
        admin_only = OpenAI(
            base_url=base_url,
            api_key=None,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
        )
        admin_only_request = admin_only._build_request(
            FinalRequestOptions(
                method="get",
                url="/organization/projects",
                security={"admin_api_key_auth": True},
            )
        )
        assert admin_only_request.headers.get("Authorization") == f"Bearer {admin_api_key}"

        # without a regular API key, a bearer-auth request cannot be authenticated
        with pytest.raises(
            TypeError,
            match="Could not resolve authentication method",
        ):
            admin_only._build_request(
                FinalRequestOptions(
                    method="post",
                    url="/responses",
                    security={"bearer_auth": True},
                )
            )

    # with neither key available, constructing the client fails
    no_credentials_env = {
        "OPENAI_API_KEY": Omit(),
        "OPENAI_ADMIN_KEY": Omit(),
    }
    with update_env(**no_credentials_env):
        with pytest.raises(OpenAIError, match="Missing credentials"):
            OpenAI(base_url=base_url, api_key=None, admin_api_key=None, _strict_response_validation=True)
503
def test_workload_identity_is_mutually_exclusive_with_api_key(self) -> None:
    """Supplying both ``api_key`` and ``workload_identity`` raises ``OpenAIError``."""
    expected_message = "The `api_key` and `workload_identity` arguments are mutually exclusive"
    with pytest.raises(OpenAIError, match=expected_message):
        OpenAI(
            base_url=base_url,
            api_key=api_key,
            workload_identity=workload_identity,  # type: ignore[reportArgumentType]
            organization="org_123",
            _strict_response_validation=True,
        )
516
def test_default_query_option(self) -> None:
    """``default_query`` is added to built requests; per-request params win on key clashes."""
    query_client = OpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_query={"query_param": "bar"},
    )

    built = query_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(built.url).params) == {"query_param": "bar"}

    overriding_options = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overridden"},
    )
    built = query_client._build_request(overriding_options)
    assert dict(httpx.URL(built.url).params) == {"foo": "baz", "query_param": "overridden"}

    query_client.close()
540
def test_hardcoded_query_params_in_url(self, client: OpenAI) -> None:
    """Query params written into the URL survive and are combined with ``params``."""
    # only the query string from the URL
    built = client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true"))
    assert dict(httpx.URL(built.url).params) == {"beta": "true"}

    # URL query string plus explicit params
    combined_options = FinalRequestOptions(
        method="get",
        url="/foo?beta=true",
        params={"limit": "10", "page": "abc"},
    )
    built = client._build_request(combined_options)
    assert dict(httpx.URL(built.url).params) == {"beta": "true", "limit": "10", "page": "abc"}

    # the percent-encoded slash in the path is expected to stay encoded
    encoded_path_options = FinalRequestOptions(
        method="get",
        url="/files/a%2Fb?beta=true",
        params={"limit": "10"},
    )
    built = client._build_request(encoded_path_options)
    assert built.url.raw_path == b"/files/a%2Fb?beta=true&limit=10"
564
def test_request_extra_json(self, client: OpenAI) -> None:
    """``extra_json`` is merged into the JSON body and wins over ``json_data`` on key clashes."""

    def sent_body(options: FinalRequestOptions) -> Any:
        # decode the JSON payload of the request built from ``options``
        built = client._build_request(options)
        return json.loads(built.content.decode("utf-8"))

    merged = sent_body(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
            extra_json={"baz": False},
        ),
    )
    assert merged == {"foo": "bar", "baz": False}

    extra_only = sent_body(
        FinalRequestOptions(
            method="post",
            url="/foo",
            extra_json={"baz": False},
        ),
    )
    assert extra_only == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    clashing = sent_body(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar", "baz": True},
            extra_json={"baz": None},
        ),
    )
    assert clashing == {"foo": "bar", "baz": None}
598
def test_request_extra_headers(self, client: OpenAI) -> None:
    """``extra_headers`` are sent and win over ``default_headers`` on key clashes."""
    simple_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    built = client._build_request(simple_options)
    assert built.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_default = client.with_options(default_headers={"X-Bar": "true"})
    clashing_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            extra_headers={"X-Bar": "false"},
        ),
    )
    built = client_with_default._build_request(clashing_options)
    assert built.headers.get("X-Bar") == "false"
620
def test_request_extra_query(self, client: OpenAI) -> None:
    """``extra_query`` is sent, merged with ``query``, and wins on key clashes."""
    extra_only = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            extra_query={"my_query_param": "Foo"},
        ),
    )
    built = client._build_request(extra_only)
    assert dict(built.url.params) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    disjoint = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            query={"bar": "1"},
            extra_query={"foo": "2"},
        ),
    )
    built = client._build_request(disjoint)
    assert dict(built.url.params) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    clashing = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(
            query={"foo": "1"},
            extra_query={"foo": "2"},
        ),
    )
    built = client._build_request(clashing)
    assert dict(built.url.params) == {"foo": "2"}
661
def test_multipart_repeating_array(self, client: OpenAI) -> None:
    """A list in ``json_data`` is encoded as repeated ``array[]`` parts in a multipart body."""
    multipart_options = FinalRequestOptions.construct(
        method="post",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    built = client._build_request(multipart_options)

    # the body split on CRLF: one part per array element, then the file part
    expected_lines = [
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"foo",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"bar",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]
    assert built.read().split(b"\r\n") == expected_lines
690
@pytest.mark.respx(base_url=base_url)
def test_binary_content_upload(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Raw bytes passed as ``content`` are sent unchanged with the given content type."""
    respx_mock.post("/upload").mock(side_effect=mirror_request_content)

    payload = b"Hello, this is a test file."
    octet_stream = {"headers": {"Content-Type": "application/octet-stream"}}

    echoed = client.post(
        "/upload",
        content=payload,
        cast_to=httpx.Response,
        options=octet_stream,
    )

    assert echoed.status_code == 200
    assert echoed.request.headers["Content-Type"] == "application/octet-stream"
    assert echoed.content == payload
707
def test_binary_content_upload_with_iterator(self) -> None:
    """An iterator passed as ``content`` is not consumed before the transport reads the body."""
    payload = b"Hello, this is a test file."
    reads = Counter()
    body_chunks = _make_sync_iterator([payload], counter=reads)

    def mock_handler(request: httpx.Request) -> httpx.Response:
        # nothing may have been pulled from the iterator before the transport is invoked
        assert reads.value == 0, "the request body should not have been read"
        return httpx.Response(200, content=request.read())

    transport = MockTransport(handler=mock_handler)
    with OpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        http_client=httpx.Client(transport=transport),
    ) as client:
        echoed = client.post(
            "/upload",
            content=body_chunks,
            cast_to=httpx.Response,
            options={"headers": {"Content-Type": "application/octet-stream"}},
        )

        assert echoed.status_code == 200
        assert echoed.request.headers["Content-Type"] == "application/octet-stream"
        assert echoed.content == payload
        assert reads.value == 1
735
@pytest.mark.respx(base_url=base_url)
def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Raw bytes passed as ``body`` still upload but emit a deprecation warning."""
    respx_mock.post("/upload").mock(side_effect=mirror_request_content)

    payload = b"Hello, this is a test file."
    warning_text = "Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."

    with pytest.deprecated_call(match=warning_text):
        echoed = client.post(
            "/upload",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {"Content-Type": "application/octet-stream"}},
        )

    assert echoed.status_code == 200
    assert echoed.request.headers["Content-Type"] == "application/octet-stream"
    assert echoed.content == payload
755
@pytest.mark.respx(base_url=base_url)
def test_basic_union_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """A union ``cast_to`` resolves to the member whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    either_model = cast(Any, Union[Model1, Model2])
    parsed = client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
769
@pytest.mark.respx(base_url=base_url)
def test_union_response_different_types(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Union of objects with the same field name using a different type"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    either_model = cast(Any, Union[Model1, Model2])

    # a string value selects the str-typed model
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    parsed = client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"

    # an integer value selects the int-typed model
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

    parsed = client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model1)
    assert parsed.foo == 1
791
@pytest.mark.respx(base_url=base_url)
def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """
    Response that sets Content-Type to something other than application/json but returns json data
    """

    class Model(BaseModel):
        foo: int

    mislabelled = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled)

    parsed = client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
812
def test_base_url_setter(self) -> None:
    """``base_url`` gains a trailing slash both at construction and when reassigned."""
    configurable = OpenAI(
        base_url="https://example.com/from_init",
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
    )
    assert configurable.base_url == "https://example.com/from_init/"

    configurable.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

    assert configurable.base_url == "https://example.com/from_setter/"

    configurable.close()
827
def test_base_url_env(self) -> None:
    """``OPENAI_BASE_URL`` supplies the base URL when none is passed to the constructor.

    The client is used as a context manager so it is closed even when the
    assertion fails, matching the clean-up done by the neighbouring tests.
    """
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        with OpenAI(api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True) as client:
            assert client.base_url == "http://localhost:5000/from/env/"
832
@pytest.mark.parametrize(
    "client",
    [
        OpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
        ),
        OpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: OpenAI) -> None:
    """A base URL ending in ``/`` joins with a request path without doubling the slash."""
    post_options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(post_options)
    assert built.url == "http://localhost:5000/custom/path/foo"
    client.close()
862
@pytest.mark.parametrize(
    "client",
    [
        # NOTE: these base URLs deliberately omit the trailing slash; with the
        # slash present this test would only repeat test_base_url_trailing_slash.
        OpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
        ),
        OpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
    """A base URL without a trailing ``/`` still joins correctly with a request path."""
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    # the path segment of the base URL must be kept, with exactly one slash before "foo"
    assert request.url == "http://localhost:5000/custom/path/foo"
    client.close()
892
@pytest.mark.parametrize(
    "client",
    [
        OpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
        ),
        OpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: OpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined to the base URL."""
    absolute_options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(absolute_options)
    assert built.url == "https://myapi.com/foo"
    client.close()
922
def test_copied_client_does_not_close_http(self) -> None:
    """Dropping a copy must leave the original client open.

    The original client is closed explicitly at the end so the test does not
    leave it open, as the neighbouring tests do for their clients.
    """
    test_client = OpenAI(
        base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
    )
    assert not test_client.is_closed()

    copied = test_client.copy()
    assert copied is not test_client

    # release the only reference to the copy
    del copied

    assert not test_client.is_closed()

    test_client.close()
935
def test_client_context_manager(self) -> None:
    """Entering the client yields the client itself; leaving the block closes it."""
    managed = OpenAI(
        base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
    )
    with managed as entered:
        assert entered is managed
        assert not entered.is_closed()
        assert not managed.is_closed()
    assert managed.is_closed()
945
@pytest.mark.respx(base_url=base_url)
def test_client_response_validation_error(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """A payload that does not fit the model raises ``APIResponseValidationError`` caused by pydantic."""

    class Model(BaseModel):
        foo: str

    mismatched = httpx.Response(200, json={"foo": {"invalid": True}})
    respx_mock.get("/foo").mock(return_value=mismatched)

    with pytest.raises(APIResponseValidationError) as raised:
        client.get("/foo", cast_to=Model)

    assert isinstance(raised.value.__cause__, ValidationError)
957
def test_client_max_retries_validation(self) -> None:
    """Constructing a client with ``max_retries=None`` raises ``TypeError``."""
    invalid_retries = cast(Any, None)
    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        OpenAI(
            base_url=base_url,
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            max_retries=invalid_retries,
        )
967
@pytest.mark.respx(base_url=base_url)
def test_default_stream_cls(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """``stream=True`` with ``stream_cls=Stream[Model]`` returns a ``Stream`` instance."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    streamed = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
    assert isinstance(streamed, Stream)
    # close the underlying response held by the stream
    streamed.response.close()
978
@pytest.mark.respx(base_url=base_url)
def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A plain-text body fails strict validation but is returned as ``str`` in non-strict mode."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # strict mode: the mismatch is an error
    strict = OpenAI(
        base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
    )

    with pytest.raises(APIResponseValidationError):
        strict.get("/foo", cast_to=Model)

    # non-strict mode: the text comes back as a string
    lenient = OpenAI(
        base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=False
    )

    received = lenient.get("/foo", cast_to=Model)
    assert isinstance(received, str)  # type: ignore[unreachable]

    strict.close()
    lenient.close()
1002
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        # numeric header values
        (3, "20", 20),
        (3, "0", 0.5),
        (3, "-10", 0.5),
        (3, "60", 60),
        (3, "61", 0.5),
        # HTTP-date header values, evaluated against the patched clock below
        (3, "Fri, 29 Sep 2023 16:26:57 GMT", 20),
        (3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:27:37 GMT", 60),
        (3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5),
        # oversized / malformed header values
        (3, "99999999999999999999999999999999999", 0.5),
        (3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5),
        # empty header value with a varying number of remaining retries
        (3, "", 0.5),
        (2, "", 0.5 * 2.0),
        (1, "", 0.5 * 4.0),
        (-1100, "", 8),  # test large number potentially overflowing
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
def test_parse_retry_after_header(
    self, remaining_retries: int, retry_after: str, timeout: float, client: OpenAI
) -> None:
    """Check the retry delay computed for a range of `retry-after` header values.

    `time.time` is patched to a fixed value so the HTTP-date cases do not depend
    on the wall clock. The comparison is approximate (tolerance `0.5 * 0.875`).
    """
    response_headers = httpx.Headers({"retry-after": retry_after})
    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
    actual = client._calculate_retry_timeout(remaining_retries, request_options, response_headers)
    assert actual == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1032
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """No connection stays open after a streaming request fails with timeouts."""
    simulated_timeout = httpx.TimeoutException("Test timeout error")
    respx_mock.post("/chat/completions").mock(side_effect=simulated_timeout)

    with pytest.raises(APITimeoutError):
        # `__enter__` is called by hand; no `with` block is opened for the response.
        client.chat.completions.with_streaming_response.create(
            messages=[{"content": "string", "role": "developer"}],
            model="gpt-5.4",
        ).__enter__()

    open_connections = _get_open_connections(client)
    assert open_connections == 0
1050
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """No connection stays open after a streaming request fails with HTTP 500 responses."""
    server_error = httpx.Response(500)
    respx_mock.post("/chat/completions").mock(return_value=server_error)

    with pytest.raises(APIStatusError):
        # `__enter__` is called by hand; no `with` block is opened for the response.
        client.chat.completions.with_streaming_response.create(
            messages=[{"content": "string", "role": "developer"}],
            model="gpt-5.4",
        ).__enter__()

    open_connections = _get_open_connections(client)
    assert open_connections == 0
1067
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
def test_retries_taken(
    self,
    client: OpenAI,
    failures_before_success: int,
    failure_mode: Literal["status", "exception"],
    respx_mock: MockRouter,
) -> None:
    """`retries_taken` and the retry-count header both equal the number of failed attempts."""
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        # Fail the first `failures_before_success` calls, then succeed.
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        if failure_mode == "exception":
            raise RuntimeError("oops")
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    response = retrying_client.chat.completions.with_raw_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-5.4",
    )

    assert response.retries_taken == failures_before_success
    sent_retry_count = response.http_request.headers.get("x-stainless-retry-count")
    assert int(sent_retry_count) == failures_before_success
1106
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_omit_retry_count_header(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """Passing `Omit()` for the retry-count header keeps it off the final request."""
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        # Answer 500 for the first `failures_before_success` calls, then 200.
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    response = retrying_client.chat.completions.with_raw_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-5.4",
        extra_headers={"x-stainless-retry-count": Omit()},
    )

    retry_count_values = response.http_request.headers.get_list("x-stainless-retry-count")
    assert len(retry_count_values) == 0
1138
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_overwrite_retry_count_header(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """A caller-supplied retry-count header value is sent unchanged on the final request."""
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        # Answer 500 for the first `failures_before_success` calls, then 200.
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    response = retrying_client.chat.completions.with_raw_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-5.4",
        extra_headers={"x-stainless-retry-count": "42"},
    )

    sent_retry_count = response.http_request.headers.get("x-stainless-retry-count")
    assert sent_retry_count == "42"
1170
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retries_taken_new_response_class(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """The streaming-response wrapper reports the same retry count as the request header."""
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        # Answer 500 for the first `failures_before_success` calls, then 200.
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    pending = retrying_client.chat.completions.with_streaming_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-5.4",
    )
    with pending as response:
        assert response.retries_taken == failures_before_success
        sent_retry_count = response.http_request.headers.get("x-stainless-retry-count")
        assert int(sent_retry_count) == failures_before_success
1201
def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """With only `HTTPS_PROXY` set, `DefaultHttpxClient` ends up with a single `https://` mount."""
    # The one proxy variable this test wants the client to pick up.
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
    # Remove every other proxy variable the host environment might define.
    for variable in (
        "HTTP_PROXY",
        "ALL_PROXY",
        "NO_PROXY",
        "http_proxy",
        "https_proxy",
        "all_proxy",
        "no_proxy",
    ):
        monkeypatch.delenv(variable, raising=False)

    http_client = DefaultHttpxClient()

    mounted = tuple(http_client._mounts.items())
    assert len(mounted) == 1
    url_pattern, _transport = mounted[0]
    assert url_pattern.pattern == "https://"
1219
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
def test_default_client_creation(self) -> None:
    """`DefaultHttpxClient` accepts these keyword arguments without raising."""
    pool_limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)
    # Nothing is asserted: successful construction is the whole test.
    DefaultHttpxClient(
        verify=True,
        cert=None,
        trust_env=True,
        http1=True,
        http2=False,
        limits=pool_limits,
    )
1231
@pytest.mark.respx(base_url=base_url)
def test_follow_redirects(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """With default options, a 302 from `/redirect` is followed to `/redirected`."""
    redirect = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    landing = httpx.Response(200, json={"status": "ok"})
    respx_mock.post("/redirect").mock(return_value=redirect)
    respx_mock.get("/redirected").mock(return_value=landing)

    final_response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    # The caller sees the response of the redirect target, not the 302 itself.
    assert final_response.status_code == 200
    assert final_response.json() == {"status": "ok"}
1243
@pytest.mark.respx(base_url=base_url)
def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """With `follow_redirects=False`, the 302 surfaces as an `APIStatusError`."""
    redirect = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    respx_mock.post("/redirect").mock(return_value=redirect)

    with pytest.raises(APIStatusError) as caught:
        client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    # The error carries the untouched redirect response.
    unfollowed = caught.value.response
    assert unfollowed.status_code == 302
    assert unfollowed.headers["Location"] == f"{base_url}/redirected"
1256
def test_api_key_before_after_refresh_provider(self) -> None:
    """A callable `api_key` only shows up in `api_key`/`auth_headers` after `_refresh_api_key()`."""
    lazy_client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")

    # Before the refresh: empty key and no Authorization header.
    assert lazy_client.api_key == ""
    assert "Authorization" not in lazy_client.auth_headers

    lazy_client._refresh_api_key()

    # After the refresh: the provider's value is used.
    assert lazy_client.api_key == "test_bearer_token"
    assert lazy_client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
1267
def test_api_key_before_after_refresh_str(self) -> None:
    """With a plain string `api_key`, `_refresh_api_key()` leaves the Authorization header as is."""
    expected_header = "Bearer test_api_key"
    static_client = OpenAI(base_url=base_url, api_key="test_api_key")

    assert static_client.auth_headers.get("Authorization") == expected_header

    static_client._refresh_api_key()

    assert static_client.auth_headers.get("Authorization") == expected_header
1275
@pytest.mark.respx()
def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
    """The retried request carries the token returned by a second call to the provider."""
    # First attempt fails with a 500, the second one succeeds.
    scripted_responses = [
        httpx.Response(500, json={"error": "server error"}),
        httpx.Response(200, json={"foo": "bar"}),
    ]
    respx_mock.post(base_url + "/chat/completions").mock(side_effect=scripted_responses)

    provider_calls = 0

    def token_provider() -> str:
        # Hand out "first" on the first call and "second" on every later call.
        nonlocal provider_calls
        provider_calls += 1
        return "first" if provider_calls == 1 else "second"

    client = OpenAI(base_url=base_url, api_key=token_provider)
    client.chat.completions.create(messages=[], model="gpt-4")

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 2

    first_attempt, second_attempt = recorded[0], recorded[1]
    assert first_attempt.request.headers.get("Authorization") == "Bearer first"
    assert second_attempt.request.headers.get("Authorization") == "Bearer second"
1305
def test_copy_auth(self) -> None:
    """An `api_key` provider passed to `copy()` replaces the original client's provider."""
    original = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1")
    copied = original.copy(api_key=lambda: "test_bearer_token_2")

    copied._refresh_api_key()

    assert copied.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
1312
1313
1314class TestAsyncOpenAI:
@pytest.mark.respx(base_url=base_url)
async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """`cast_to=httpx.Response` returns the raw `httpx.Response` with its JSON body intact."""
    payload = {"foo": "bar"}
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json=payload))

    raw = await async_client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
1323
@pytest.mark.respx(base_url=base_url)
async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """A response labelled `application/binary` is still returned raw and can be decoded as JSON."""
    binary_response = httpx.Response(
        200,
        headers={"Content-Type": "application/binary"},
        content='{"foo": "bar"}',
    )
    respx_mock.post("/foo").mock(return_value=binary_response)

    raw = await async_client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
1334
def test_copy(self, async_client: AsyncOpenAI) -> None:
    """`copy()` returns a new object and applies key overrides without touching the source client."""
    plain_copy = async_client.copy()
    assert id(plain_copy) != id(async_client)

    # Overriding the API key affects only the copy.
    with_new_key = async_client.copy(api_key="another My API Key")
    assert with_new_key.api_key == "another My API Key"
    assert async_client.api_key == "My API Key"

    # Same for the admin API key.
    with_new_admin_key = async_client.copy(admin_api_key="another My Admin API Key")
    assert with_new_admin_key.admin_api_key == "another My Admin API Key"
    assert async_client.admin_api_key == "My Admin API Key"
1346
def test_copy_default_options(self, async_client: AsyncOpenAI) -> None:
    """Options with defaults (`max_retries`, `timeout`) can be overridden per copy."""
    # max_retries: each copy gets its own value, the source keeps its own.
    first_copy = async_client.copy(max_retries=7)
    assert first_copy.max_retries == 7
    assert async_client.max_retries == 2

    second_copy = first_copy.copy(max_retries=6)
    assert second_copy.max_retries == 6
    assert first_copy.max_retries == 7

    # timeout: an explicit `None` is kept on the copy and does not leak back to the source.
    assert isinstance(async_client.timeout, httpx.Timeout)
    without_timeout = async_client.copy(timeout=None)
    assert without_timeout.timeout is None
    assert isinstance(async_client.timeout, httpx.Timeout)
1362
async def test_copy_default_headers(self) -> None:
    """Check how `copy()` merges (`default_headers`) or replaces (`set_default_headers`) headers.

    The client is opened with `async with` so it is closed even when one of the
    assertions fails, instead of only on the final line of the test.
    """
    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    ) as client:
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        # the two arguments cannot be combined
        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1401
async def test_copy_default_query(self) -> None:
    """Check how `copy()` merges (`default_query`) or replaces (`set_default_query`) query params.

    The client is opened with `async with` so it is closed even when one of the
    assertions fails, instead of only on the final line of the test.
    """
    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_query={"foo": "bar"},
    ) as client:
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        # the two arguments cannot be combined
        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})
1443
def test_copy_signature(self, async_client: AsyncOpenAI) -> None:
    """Every `__init__` parameter, apart from a small exclusion set, is also accepted by `copy()`."""
    init_params = inspect.signature(
        # mypy doesn't like that we access the `__init__` property.
        async_client.__init__,  # type: ignore[misc]
    ).parameters
    copy_params = inspect.signature(async_client.copy).parameters
    # Constructor-only parameters that `copy()` is not required to expose.
    not_required_on_copy = {"transport", "proxies", "_strict_response_validation"}

    for name in init_params:
        if name not in not_required_on_copy:
            assert copy_params.get(name) is not None, f"copy() signature is missing the {name} param"
1459
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
def test_copy_build_request(self, async_client: AsyncOpenAI) -> None:
    """Copying the client and building a request repeatedly must not leave per-iteration allocations.

    NOTE(review): the skip condition starts at Python 3.10 while the reason text
    mentions 3.12 - confirm which version is intended.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        copied_client = async_client.copy()
        copied_client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)
    snapshot_before = tracemalloc.take_snapshot()

    ITERATIONS = 10
    for _ in range(ITERATIONS):
        build_request(options)

    gc.collect()
    snapshot_after = tracemalloc.take_snapshot()
    tracemalloc.stop()

    # Allocation sites that are knowingly ignored, matched by filename suffix.
    ignored_suffixes = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        # Avoid false positives: keep only allocations that persist (non-zero count)
        # and that appear once per iteration (count is a multiple of ITERATIONS).
        if diff.count == 0 or diff.count % ITERATIONS != 0:
            return

        # A single frame from an ignored file is enough to discard the whole diff.
        if any(frame.filename.endswith(ignored_suffixes) for frame in diff.traceback):
            return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)

    if not leaks:
        return

    # Print every suspected leak with its traceback before failing the test.
    for leak in leaks:
        print("MEMORY LEAK:", leak)
        for frame in leak.traceback:
            print(frame)
    raise AssertionError()
1522
async def test_request_timeout(self, async_client: AsyncOpenAI) -> None:
    """A built request uses `DEFAULT_TIMEOUT` unless the request options carry their own timeout."""
    default_request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    default_timeout = httpx.Timeout(**default_request.extensions["timeout"])  # type: ignore
    assert default_timeout == DEFAULT_TIMEOUT

    # A per-request timeout takes the place of the default.
    custom_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    custom_request = async_client._build_request(custom_options)
    custom_timeout = httpx.Timeout(**custom_request.extensions["timeout"])  # type: ignore
    assert custom_timeout == httpx.Timeout(100.0)
1533
async def test_client_timeout_option(self) -> None:
    """A `timeout` given to the client constructor is applied to the requests it builds.

    The client is opened with `async with` so it is closed even when the
    assertion fails, instead of only on the final line of the test.
    """
    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    ) as client:
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)
1548
async def test_http_client_timeout_option(self) -> None:
    """Check which timeout a built request gets for differently configured custom httpx clients."""

    def make_client(http_client: httpx.AsyncClient) -> AsyncOpenAI:
        # Identical construction for every scenario; only the httpx client varies.
        return AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=http_client,
        )

    def built_timeout(target: AsyncOpenAI) -> httpx.Timeout:
        # Build a plain GET request and read back the timeout attached to it.
        request = target._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        client = make_client(http_client)
        assert built_timeout(client) == httpx.Timeout(None)
        await client.close()

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        client = make_client(http_client)
        assert built_timeout(client) == DEFAULT_TIMEOUT
        await client.close()

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        client = make_client(http_client)
        assert built_timeout(client) == DEFAULT_TIMEOUT  # our default
        await client.close()
1597
def test_invalid_http_client(self) -> None:
    """Giving the async client a synchronous `httpx.Client` raises a `TypeError`."""
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http_client:
        # `cast` hides the deliberate type mismatch from the type checker.
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
            http_client=cast(Any, sync_http_client),
        )
1608
async def test_default_headers_option(self) -> None:
    """`default_headers` are sent on built requests and can override a built-in header.

    Each client is opened with `async with` so it is closed even when an
    assertion fails, instead of only on the final lines of the test.
    """
    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_headers={"X-Foo": "bar"},
    ) as test_client:
        request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        # The custom header is added next to the built-in `x-stainless-lang` header.
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    ) as test_client2:
        request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        # A caller-supplied value replaces the built-in `x-stainless-lang` header.
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1637
async def test_validate_headers(self) -> None:
    """Check the `Authorization` header built for the standard and the admin-key security schemes."""

    def admin_options() -> FinalRequestOptions:
        # A fresh options object per request, marked as using admin API key auth.
        return FinalRequestOptions(
            method="get",
            url="/organization/projects",
            security={"admin_api_key_auth": True},
        )

    client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
    )

    # Plain request: the regular API key is used.
    prepared = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
    request = client._build_request(prepared)
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # Admin-key request: the admin API key is used.
    admin_request = client._build_request(admin_options())
    assert admin_request.headers.get("Authorization") == f"Bearer {admin_api_key}"

    with update_env(**{"OPENAI_API_KEY": Omit()}):
        # Client that only has an admin key.
        admin_only = AsyncOpenAI(
            base_url=base_url,
            api_key=None,
            admin_api_key=admin_api_key,
            _strict_response_validation=True,
        )

        admin_only_request = admin_only._build_request(admin_options())
        assert admin_only_request.headers.get("Authorization") == f"Bearer {admin_api_key}"

        # A bearer-auth request cannot be built without a regular API key.
        with pytest.raises(
            TypeError,
            match="Could not resolve authentication method",
        ):
            admin_only._build_request(
                FinalRequestOptions(
                    method="post",
                    url="/responses",
                    security={"bearer_auth": True},
                )
            )

    # With neither key available, constructing the client fails.
    with update_env(
        **{
            "OPENAI_API_KEY": Omit(),
            "OPENAI_ADMIN_KEY": Omit(),
        }
    ):
        with pytest.raises(OpenAIError, match="Missing credentials"):
            AsyncOpenAI(base_url=base_url, api_key=None, admin_api_key=None, _strict_response_validation=True)
1691
async def test_default_query_option(self) -> None:
    """`default_query` params appear on built requests and can be overridden per request.

    The client is opened with `async with` so it is closed even when an
    assertion fails, instead of only on the final line of the test.
    """
    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        default_query={"query_param": "bar"},
    ) as client:
        # No per-request params: only the default is sent.
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        # Per-request params are added, and win over the default on a key clash.
        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1715
async def test_hardcoded_query_params_in_url(self, async_client: AsyncOpenAI) -> None:
    """Query params written into the URL are kept and combined with `params`."""
    # URL-embedded params alone.
    only_hardcoded = async_client._build_request(FinalRequestOptions(method="get", url="/foo?beta=true"))
    assert dict(httpx.URL(only_hardcoded.url).params) == {"beta": "true"}

    # URL-embedded params together with explicit `params`.
    combined_options = FinalRequestOptions(
        method="get",
        url="/foo?beta=true",
        params={"limit": "10", "page": "abc"},
    )
    combined = async_client._build_request(combined_options)
    assert dict(httpx.URL(combined.url).params) == {"beta": "true", "limit": "10", "page": "abc"}

    # A percent-encoded path segment must come through unchanged in the raw path.
    encoded_options = FinalRequestOptions(
        method="get",
        url="/files/a%2Fb?beta=true",
        params={"limit": "10"},
    )
    encoded = async_client._build_request(encoded_options)
    assert encoded.url.raw_path == b"/files/a%2Fb?beta=true&limit=10"
1739
def test_request_extra_json(self, async_client: AsyncOpenAI) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on key clashes.

    Uses the `async_client` fixture so that this test, which belongs to the async
    test class, exercises the async client instead of the sync one.
    """
    request = async_client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
            extra_json={"baz": False},
        ),
    )
    data = json.loads(request.content.decode("utf-8"))
    assert data == {"foo": "bar", "baz": False}

    # `extra_json` on its own becomes the whole body
    request = async_client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            extra_json={"baz": False},
        ),
    )
    data = json.loads(request.content.decode("utf-8"))
    assert data == {"baz": False}

    # `extra_json` takes priority over `json_data` when keys clash
    request = async_client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar", "baz": True},
            extra_json={"baz": None},
        ),
    )
    data = json.loads(request.content.decode("utf-8"))
    assert data == {"foo": "bar", "baz": None}
1773
def test_request_extra_headers(self, async_client: AsyncOpenAI) -> None:
    """`extra_headers` are sent on the request and win over `default_headers` on key clashes.

    Uses the `async_client` fixture so that this test, which belongs to the async
    test class, exercises the async client instead of the sync one.
    """
    request = async_client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(extra_headers={"X-Foo": "Foo"}),
        ),
    )
    assert request.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    request = async_client.copy(default_headers={"X-Bar": "true"})._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                extra_headers={"X-Bar": "false"},
            ),
        ),
    )
    assert request.headers.get("X-Bar") == "false"
1795
def test_request_extra_query(self, async_client: AsyncOpenAI) -> None:
    """`extra_query` is merged with `query` and wins over it on key clashes.

    Uses the `async_client` fixture so that this test, which belongs to the async
    test class, exercises the async client instead of the sync one.
    """
    request = async_client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                extra_query={"my_query_param": "Foo"},
            ),
        ),
    )
    params = dict(request.url.params)
    assert params == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    request = async_client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                query={"bar": "1"},
                extra_query={"foo": "2"},
            ),
        ),
    )
    params = dict(request.url.params)
    assert params == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    request = async_client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                query={"foo": "1"},
                extra_query={"foo": "2"},
            ),
        ),
    )
    params = dict(request.url.params)
    assert params == {"foo": "2"}
1836
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """In a multipart body, each array item is written as its own `array[]` form field."""
    multipart_options = FinalRequestOptions.construct(
        method="post",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    built = async_client._build_request(multipart_options)

    # Compare the body line by line: two `array[]` parts, then the file part.
    body_lines = built.read().split(b"\r\n")
    assert body_lines == [
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"foo",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"bar",
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]
1865
@pytest.mark.respx(base_url=base_url)
async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Raw bytes passed as `content` are sent with the given content type and come back unchanged."""
    respx_mock.post("/upload").mock(side_effect=mirror_request_content)

    payload = b"Hello, this is a test file."
    octet_stream = {"headers": {"Content-Type": "application/octet-stream"}}

    response = await async_client.post(
        "/upload",
        content=payload,
        cast_to=httpx.Response,
        options=octet_stream,
    )

    assert response.status_code == 200
    assert response.request.headers["Content-Type"] == "application/octet-stream"
    assert response.content == payload
1882
async def test_binary_content_upload_with_asynciterator(self) -> None:
    """An async-iterator body is not consumed before the transport handler reads it."""
    payload = b"Hello, this is a test file."
    counter = Counter()
    body_chunks = _make_async_iterator([payload], counter=counter)

    async def mock_handler(request: httpx.Request) -> httpx.Response:
        # Nothing may have pulled from the iterator before the handler does.
        assert counter.value == 0, "the request body should not have been read"
        echoed = await request.aread()
        return httpx.Response(200, content=echoed)

    transport = MockTransport(handler=mock_handler)
    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        admin_api_key=admin_api_key,
        _strict_response_validation=True,
        http_client=httpx.AsyncClient(transport=transport),
    ) as client:
        response = await client.post(
            "/upload",
            content=body_chunks,
            cast_to=httpx.Response,
            options={"headers": {"Content-Type": "application/octet-stream"}},
        )

        assert response.status_code == 200
        assert response.request.headers["Content-Type"] == "application/octet-stream"
        assert response.content == payload
        # After the request the counter must read exactly 1.
        assert counter.value == 1
1910
1911 @pytest.mark.respx(base_url=base_url)
1912 async def test_binary_content_upload_with_body_is_deprecated(
1913 self, respx_mock: MockRouter, async_client: AsyncOpenAI
1914 ) -> None:
1915 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
1916
1917 file_content = b"Hello, this is a test file."
1918
1919 with pytest.deprecated_call(
1920 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
1921 ):
1922 response = await async_client.post(
1923 "/upload",
1924 body=file_content,
1925 cast_to=httpx.Response,
1926 options={"headers": {"Content-Type": "application/octet-stream"}},
1927 )
1928
1929 assert response.status_code == 200
1930 assert response.request.headers["Content-Type"] == "application/octet-stream"
1931 assert response.content == file_content
1932
1933 @pytest.mark.respx(base_url=base_url)
1934 async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1935 class Model1(BaseModel):
1936 name: str
1937
1938 class Model2(BaseModel):
1939 foo: str
1940
1941 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1942
1943 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1944 assert isinstance(response, Model2)
1945 assert response.foo == "bar"
1946
1947 @pytest.mark.respx(base_url=base_url)
1948 async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1949 """Union of objects with the same field name using a different type"""
1950
1951 class Model1(BaseModel):
1952 foo: int
1953
1954 class Model2(BaseModel):
1955 foo: str
1956
1957 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1958
1959 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1960 assert isinstance(response, Model2)
1961 assert response.foo == "bar"
1962
1963 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1964
1965 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1966 assert isinstance(response, Model1)
1967 assert response.foo == 1
1968
1969 @pytest.mark.respx(base_url=base_url)
1970 async def test_non_application_json_content_type_for_json_data(
1971 self, respx_mock: MockRouter, async_client: AsyncOpenAI
1972 ) -> None:
1973 """
1974 Response that sets Content-Type to something other than application/json but returns json data
1975 """
1976
1977 class Model(BaseModel):
1978 foo: int
1979
1980 respx_mock.get("/foo").mock(
1981 return_value=httpx.Response(
1982 200,
1983 content=json.dumps({"foo": 2}),
1984 headers={"Content-Type": "application/text"},
1985 )
1986 )
1987
1988 response = await async_client.get("/foo", cast_to=Model)
1989 assert isinstance(response, Model)
1990 assert response.foo == 2
1991
1992 async def test_base_url_setter(self) -> None:
1993 client = AsyncOpenAI(
1994 base_url="https://example.com/from_init",
1995 api_key=api_key,
1996 admin_api_key=admin_api_key,
1997 _strict_response_validation=True,
1998 )
1999 assert client.base_url == "https://example.com/from_init/"
2000
2001 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
2002
2003 assert client.base_url == "https://example.com/from_setter/"
2004
2005 await client.close()
2006
2007 async def test_base_url_env(self) -> None:
2008 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
2009 client = AsyncOpenAI(api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True)
2010 assert client.base_url == "http://localhost:5000/from/env/"
2011
2012 @pytest.mark.parametrize(
2013 "client",
2014 [
2015 AsyncOpenAI(
2016 base_url="http://localhost:5000/custom/path/",
2017 api_key=api_key,
2018 admin_api_key=admin_api_key,
2019 _strict_response_validation=True,
2020 ),
2021 AsyncOpenAI(
2022 base_url="http://localhost:5000/custom/path/",
2023 api_key=api_key,
2024 admin_api_key=admin_api_key,
2025 _strict_response_validation=True,
2026 http_client=httpx.AsyncClient(),
2027 ),
2028 ],
2029 ids=["standard", "custom http client"],
2030 )
2031 async def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
2032 request = client._build_request(
2033 FinalRequestOptions(
2034 method="post",
2035 url="/foo",
2036 json_data={"foo": "bar"},
2037 ),
2038 )
2039 assert request.url == "http://localhost:5000/custom/path/foo"
2040 await client.close()
2041
2042 @pytest.mark.parametrize(
2043 "client",
2044 [
2045 AsyncOpenAI(
2046 base_url="http://localhost:5000/custom/path/",
2047 api_key=api_key,
2048 admin_api_key=admin_api_key,
2049 _strict_response_validation=True,
2050 ),
2051 AsyncOpenAI(
2052 base_url="http://localhost:5000/custom/path/",
2053 api_key=api_key,
2054 admin_api_key=admin_api_key,
2055 _strict_response_validation=True,
2056 http_client=httpx.AsyncClient(),
2057 ),
2058 ],
2059 ids=["standard", "custom http client"],
2060 )
2061 async def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
2062 request = client._build_request(
2063 FinalRequestOptions(
2064 method="post",
2065 url="/foo",
2066 json_data={"foo": "bar"},
2067 ),
2068 )
2069 assert request.url == "http://localhost:5000/custom/path/foo"
2070 await client.close()
2071
2072 @pytest.mark.parametrize(
2073 "client",
2074 [
2075 AsyncOpenAI(
2076 base_url="http://localhost:5000/custom/path/",
2077 api_key=api_key,
2078 admin_api_key=admin_api_key,
2079 _strict_response_validation=True,
2080 ),
2081 AsyncOpenAI(
2082 base_url="http://localhost:5000/custom/path/",
2083 api_key=api_key,
2084 admin_api_key=admin_api_key,
2085 _strict_response_validation=True,
2086 http_client=httpx.AsyncClient(),
2087 ),
2088 ],
2089 ids=["standard", "custom http client"],
2090 )
2091 async def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
2092 request = client._build_request(
2093 FinalRequestOptions(
2094 method="post",
2095 url="https://myapi.com/foo",
2096 json_data={"foo": "bar"},
2097 ),
2098 )
2099 assert request.url == "https://myapi.com/foo"
2100 await client.close()
2101
2102 async def test_copied_client_does_not_close_http(self) -> None:
2103 test_client = AsyncOpenAI(
2104 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
2105 )
2106 assert not test_client.is_closed()
2107
2108 copied = test_client.copy()
2109 assert copied is not test_client
2110
2111 del copied
2112
2113 await asyncio.sleep(0.2)
2114 assert not test_client.is_closed()
2115
2116 async def test_client_context_manager(self) -> None:
2117 test_client = AsyncOpenAI(
2118 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
2119 )
2120 async with test_client as c2:
2121 assert c2 is test_client
2122 assert not c2.is_closed()
2123 assert not test_client.is_closed()
2124 assert test_client.is_closed()
2125
2126 @pytest.mark.respx(base_url=base_url)
2127 async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2128 class Model(BaseModel):
2129 foo: str
2130
2131 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
2132
2133 with pytest.raises(APIResponseValidationError) as exc:
2134 await async_client.get("/foo", cast_to=Model)
2135
2136 assert isinstance(exc.value.__cause__, ValidationError)
2137
2138 async def test_client_max_retries_validation(self) -> None:
2139 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
2140 AsyncOpenAI(
2141 base_url=base_url,
2142 api_key=api_key,
2143 admin_api_key=admin_api_key,
2144 _strict_response_validation=True,
2145 max_retries=cast(Any, None),
2146 )
2147
2148 @pytest.mark.respx(base_url=base_url)
2149 async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2150 class Model(BaseModel):
2151 name: str
2152
2153 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
2154
2155 stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
2156 assert isinstance(stream, AsyncStream)
2157 await stream.response.aclose()
2158
2159 @pytest.mark.respx(base_url=base_url)
2160 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
2161 class Model(BaseModel):
2162 name: str
2163
2164 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
2165
2166 strict_client = AsyncOpenAI(
2167 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=True
2168 )
2169
2170 with pytest.raises(APIResponseValidationError):
2171 await strict_client.get("/foo", cast_to=Model)
2172
2173 non_strict_client = AsyncOpenAI(
2174 base_url=base_url, api_key=api_key, admin_api_key=admin_api_key, _strict_response_validation=False
2175 )
2176
2177 response = await non_strict_client.get("/foo", cast_to=Model)
2178 assert isinstance(response, str) # type: ignore[unreachable]
2179
2180 await strict_client.close()
2181 await non_strict_client.close()
2182
2183 @pytest.mark.parametrize(
2184 "remaining_retries,retry_after,timeout",
2185 [
2186 [3, "20", 20],
2187 [3, "0", 0.5],
2188 [3, "-10", 0.5],
2189 [3, "60", 60],
2190 [3, "61", 0.5],
2191 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
2192 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
2193 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
2194 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
2195 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
2196 [3, "99999999999999999999999999999999999", 0.5],
2197 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
2198 [3, "", 0.5],
2199 [2, "", 0.5 * 2.0],
2200 [1, "", 0.5 * 4.0],
2201 [-1100, "", 8], # test large number potentially overflowing
2202 ],
2203 )
2204 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
2205 async def test_parse_retry_after_header(
2206 self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncOpenAI
2207 ) -> None:
2208 headers = httpx.Headers({"retry-after": retry_after})
2209 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
2210 calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
2211 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
2212
2213 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2214 @pytest.mark.respx(base_url=base_url)
2215 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2216 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
2217
2218 with pytest.raises(APITimeoutError):
2219 await async_client.chat.completions.with_streaming_response.create(
2220 messages=[
2221 {
2222 "content": "string",
2223 "role": "developer",
2224 }
2225 ],
2226 model="gpt-5.4",
2227 ).__aenter__()
2228
2229 assert _get_open_connections(async_client) == 0
2230
2231 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2232 @pytest.mark.respx(base_url=base_url)
2233 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2234 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
2235
2236 with pytest.raises(APIStatusError):
2237 await async_client.chat.completions.with_streaming_response.create(
2238 messages=[
2239 {
2240 "content": "string",
2241 "role": "developer",
2242 }
2243 ],
2244 model="gpt-5.4",
2245 ).__aenter__()
2246 assert _get_open_connections(async_client) == 0
2247
2248 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2249 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2250 @pytest.mark.respx(base_url=base_url)
2251 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
2252 async def test_retries_taken(
2253 self,
2254 async_client: AsyncOpenAI,
2255 failures_before_success: int,
2256 failure_mode: Literal["status", "exception"],
2257 respx_mock: MockRouter,
2258 ) -> None:
2259 client = async_client.with_options(max_retries=4)
2260
2261 nb_retries = 0
2262
2263 def retry_handler(_request: httpx.Request) -> httpx.Response:
2264 nonlocal nb_retries
2265 if nb_retries < failures_before_success:
2266 nb_retries += 1
2267 if failure_mode == "exception":
2268 raise RuntimeError("oops")
2269 return httpx.Response(500)
2270 return httpx.Response(200)
2271
2272 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2273
2274 response = await client.chat.completions.with_raw_response.create(
2275 messages=[
2276 {
2277 "content": "string",
2278 "role": "developer",
2279 }
2280 ],
2281 model="gpt-5.4",
2282 )
2283
2284 assert response.retries_taken == failures_before_success
2285 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
2286
2287 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2288 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2289 @pytest.mark.respx(base_url=base_url)
2290 async def test_omit_retry_count_header(
2291 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2292 ) -> None:
2293 client = async_client.with_options(max_retries=4)
2294
2295 nb_retries = 0
2296
2297 def retry_handler(_request: httpx.Request) -> httpx.Response:
2298 nonlocal nb_retries
2299 if nb_retries < failures_before_success:
2300 nb_retries += 1
2301 return httpx.Response(500)
2302 return httpx.Response(200)
2303
2304 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2305
2306 response = await client.chat.completions.with_raw_response.create(
2307 messages=[
2308 {
2309 "content": "string",
2310 "role": "developer",
2311 }
2312 ],
2313 model="gpt-5.4",
2314 extra_headers={"x-stainless-retry-count": Omit()},
2315 )
2316
2317 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
2318
2319 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2320 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2321 @pytest.mark.respx(base_url=base_url)
2322 async def test_overwrite_retry_count_header(
2323 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2324 ) -> None:
2325 client = async_client.with_options(max_retries=4)
2326
2327 nb_retries = 0
2328
2329 def retry_handler(_request: httpx.Request) -> httpx.Response:
2330 nonlocal nb_retries
2331 if nb_retries < failures_before_success:
2332 nb_retries += 1
2333 return httpx.Response(500)
2334 return httpx.Response(200)
2335
2336 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2337
2338 response = await client.chat.completions.with_raw_response.create(
2339 messages=[
2340 {
2341 "content": "string",
2342 "role": "developer",
2343 }
2344 ],
2345 model="gpt-5.4",
2346 extra_headers={"x-stainless-retry-count": "42"},
2347 )
2348
2349 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
2350
2351 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2352 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2353 @pytest.mark.respx(base_url=base_url)
2354 async def test_retries_taken_new_response_class(
2355 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2356 ) -> None:
2357 client = async_client.with_options(max_retries=4)
2358
2359 nb_retries = 0
2360
2361 def retry_handler(_request: httpx.Request) -> httpx.Response:
2362 nonlocal nb_retries
2363 if nb_retries < failures_before_success:
2364 nb_retries += 1
2365 return httpx.Response(500)
2366 return httpx.Response(200)
2367
2368 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2369
2370 async with client.chat.completions.with_streaming_response.create(
2371 messages=[
2372 {
2373 "content": "string",
2374 "role": "developer",
2375 }
2376 ],
2377 model="gpt-5.4",
2378 ) as response:
2379 assert response.retries_taken == failures_before_success
2380 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
2381
2382 async def test_get_platform(self) -> None:
2383 platform = await asyncify(get_platform)()
2384 assert isinstance(platform, (str, OtherPlatform))
2385
2386 async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
2387 # Test that the proxy environment variables are set correctly
2388 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
2389 # Delete in case our environment has any proxy env vars set
2390 monkeypatch.delenv("HTTP_PROXY", raising=False)
2391 monkeypatch.delenv("ALL_PROXY", raising=False)
2392 monkeypatch.delenv("NO_PROXY", raising=False)
2393 monkeypatch.delenv("http_proxy", raising=False)
2394 monkeypatch.delenv("https_proxy", raising=False)
2395 monkeypatch.delenv("all_proxy", raising=False)
2396 monkeypatch.delenv("no_proxy", raising=False)
2397
2398 client = DefaultAsyncHttpxClient()
2399
2400 mounts = tuple(client._mounts.items())
2401 assert len(mounts) == 1
2402 assert mounts[0][0].pattern == "https://"
2403
2404 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
2405 async def test_default_client_creation(self) -> None:
2406 # Ensure that the client can be initialized without any exceptions
2407 DefaultAsyncHttpxClient(
2408 verify=True,
2409 cert=None,
2410 trust_env=True,
2411 http1=True,
2412 http2=False,
2413 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
2414 )
2415
2416 @pytest.mark.respx(base_url=base_url)
2417 async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2418 # Test that the default follow_redirects=True allows following redirects
2419 respx_mock.post("/redirect").mock(
2420 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2421 )
2422 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
2423
2424 response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
2425 assert response.status_code == 200
2426 assert response.json() == {"status": "ok"}
2427
2428 @pytest.mark.respx(base_url=base_url)
2429 async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2430 # Test that follow_redirects=False prevents following redirects
2431 respx_mock.post("/redirect").mock(
2432 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2433 )
2434
2435 with pytest.raises(APIStatusError) as exc_info:
2436 await async_client.post(
2437 "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
2438 )
2439
2440 assert exc_info.value.response.status_code == 302
2441 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
2442
2443 async def test_api_key_before_after_refresh_provider(self) -> None:
2444 async def mock_api_key_provider():
2445 return "test_bearer_token"
2446
2447 client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
2448
2449 assert client.api_key == ""
2450 assert "Authorization" not in client.auth_headers
2451
2452 await client._refresh_api_key()
2453
2454 assert client.api_key == "test_bearer_token"
2455 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
2456
2457 async def test_api_key_before_after_refresh_str(self) -> None:
2458 client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")
2459
2460 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2461 await client._refresh_api_key()
2462
2463 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2464
2465 @pytest.mark.respx()
2466 async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
2467 respx_mock.post(base_url + "/chat/completions").mock(
2468 side_effect=[
2469 httpx.Response(500, json={"error": "server error"}),
2470 httpx.Response(200, json={"foo": "bar"}),
2471 ]
2472 )
2473
2474 counter = 0
2475
2476 async def token_provider() -> str:
2477 nonlocal counter
2478
2479 counter += 1
2480
2481 if counter == 1:
2482 return "first"
2483
2484 return "second"
2485
2486 client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
2487 await client.chat.completions.create(messages=[], model="gpt-4")
2488
2489 calls = cast("list[MockRequestCall]", respx_mock.calls)
2490 assert len(calls) == 2
2491
2492 assert calls[0].request.headers.get("Authorization") == "Bearer first"
2493 assert calls[1].request.headers.get("Authorization") == "Bearer second"
2494
2495 async def test_copy_auth(self) -> None:
2496 async def token_provider_1() -> str:
2497 return "test_bearer_token_1"
2498
2499 async def token_provider_2() -> str:
2500 return "test_bearer_token_2"
2501
2502 client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
2503 await client._refresh_api_key()
2504 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
2505
2506
class TestWorkloadIdentity401Retry:
    """Sync client: when a failed request is, and is not, retried with a freshly exchanged token."""

    @pytest.mark.respx()
    def test_workload_identity_401_retry(self, respx_mock: MockRouter) -> None:
        """A 401 leads to a second token exchange and a retry carrying the second access token."""
        subject_tokens_issued = 0

        def issue_subject_token() -> str:
            nonlocal subject_tokens_issued
            subject_tokens_issued += 1
            return f"external-subject-token-{subject_tokens_issued}"

        token_endpoint = "https://auth.openai.com/oauth/token"
        completions_endpoint = base_url + "/chat/completions"

        # Successive exchanges hand out access tokens numbered 1 and 2.
        exchange_responses = [
            httpx.Response(
                200,
                json={
                    "access_token": f"openai-access-token-{sequence}",
                    "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                    "token_type": "Bearer",
                    "expires_in": 3600,
                },
            )
            for sequence in (1, 2)
        ]
        respx_mock.post(token_endpoint).mock(side_effect=exchange_responses)

        # The API rejects the first attempt with a 401 and accepts the second.
        unauthorized = httpx.Response(
            401, json={"error": {"message": "Unauthorized", "type": "invalid_request_error"}}
        )
        completion = httpx.Response(
            200,
            json={
                "id": "chatcmpl-123",
                "object": "chat.completion",
                "created": 1234567890,
                "model": "gpt-4",
                "choices": [],
            },
        )
        respx_mock.post(completions_endpoint).mock(side_effect=[unauthorized, completion])

        with OpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {
                    "get_token": issue_subject_token,
                    "token_type": "jwt",
                },
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            client.chat.completions.create(messages=[], model="gpt-4")

            calls = cast("list[MockRequestCall]", respx_mock.calls)
            assert len(calls) == 4

            first_exchange, first_attempt, second_exchange, second_attempt = calls[0], calls[1], calls[2], calls[3]

            assert first_exchange.request.url == httpx.URL("https://auth.openai.com/oauth/token")
            assert first_attempt.request.url == httpx.URL(base_url + "/chat/completions")
            assert first_attempt.request.headers.get("Authorization") == "Bearer openai-access-token-1"

            assert second_exchange.request.url == httpx.URL("https://auth.openai.com/oauth/token")

            assert second_attempt.request.url == httpx.URL(base_url + "/chat/completions")
            assert second_attempt.request.headers.get("Authorization") == "Bearer openai-access-token-2"

            assert subject_tokens_issued == 2

    @pytest.mark.respx()
    def test_401_without_workload_identity_no_retry(self, respx_mock: MockRouter) -> None:
        """With a plain API key a 401 is raised after a single request."""
        unauthorized = httpx.Response(
            401, json={"error": {"message": "Unauthorized", "type": "invalid_request_error"}}
        )
        respx_mock.post(base_url + "/chat/completions").mock(return_value=unauthorized)

        with OpenAI(
            base_url=base_url,
            api_key="test-api-key",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as caught:
                client.chat.completions.create(messages=[], model="gpt-4")

            assert caught.value.status_code == 401

            calls = cast("list[MockRequestCall]", respx_mock.calls)
            assert len(calls) == 1

    @pytest.mark.respx()
    def test_non_401_errors_no_retry(self, respx_mock: MockRouter) -> None:
        """A 403 under workload identity is raised after one exchange and one API request."""
        subject_tokens_issued = 0

        def issue_subject_token() -> str:
            nonlocal subject_tokens_issued
            subject_tokens_issued += 1
            return "external-subject-token"

        exchange_response = httpx.Response(
            200,
            json={
                "access_token": "openai-access-token-1",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
        respx_mock.post("https://auth.openai.com/oauth/token").mock(return_value=exchange_response)

        forbidden = httpx.Response(403, json={"error": {"message": "Forbidden", "type": "invalid_request_error"}})
        respx_mock.post(base_url + "/chat/completions").mock(return_value=forbidden)

        with OpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {
                    "get_token": issue_subject_token,
                    "token_type": "jwt",
                },
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as caught:
                client.chat.completions.create(messages=[], model="gpt-4")

            assert caught.value.status_code == 403

            # One token exchange plus one API call, and nothing after the 403.
            calls = cast("list[MockRequestCall]", respx_mock.calls)
            assert len(calls) == 2

            assert subject_tokens_issued == 1
2653
2654
class TestAsyncWorkloadIdentity401Retry:
    """Async client: when a failed request is, and is not, retried with a freshly exchanged token."""

    @pytest.mark.respx()
    async def test_workload_identity_401_retry(self, respx_mock: MockRouter) -> None:
        """A 401 leads to a second token exchange and a retry carrying the second access token."""
        subject_tokens_issued = 0

        def issue_subject_token() -> str:
            nonlocal subject_tokens_issued
            subject_tokens_issued += 1
            return f"external-subject-token-{subject_tokens_issued}"

        token_endpoint = "https://auth.openai.com/oauth/token"
        completions_endpoint = base_url + "/chat/completions"

        # Successive exchanges hand out access tokens numbered 1 and 2.
        exchange_responses = [
            httpx.Response(
                200,
                json={
                    "access_token": f"openai-access-token-{sequence}",
                    "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                    "token_type": "Bearer",
                    "expires_in": 3600,
                },
            )
            for sequence in (1, 2)
        ]
        respx_mock.post(token_endpoint).mock(side_effect=exchange_responses)

        # The API rejects the first attempt with a 401 and accepts the second.
        unauthorized = httpx.Response(
            401, json={"error": {"message": "Unauthorized", "type": "invalid_request_error"}}
        )
        completion = httpx.Response(
            200,
            json={
                "id": "chatcmpl-123",
                "object": "chat.completion",
                "created": 1234567890,
                "model": "gpt-4",
                "choices": [],
            },
        )
        respx_mock.post(completions_endpoint).mock(side_effect=[unauthorized, completion])

        async with AsyncOpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {
                    "get_token": issue_subject_token,
                    "token_type": "jwt",
                },
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            await client.chat.completions.create(messages=[], model="gpt-4")

            calls = cast("list[MockRequestCall]", respx_mock.calls)
            assert len(calls) == 4

            first_exchange, first_attempt, second_exchange, second_attempt = calls[0], calls[1], calls[2], calls[3]

            assert first_exchange.request.url == httpx.URL("https://auth.openai.com/oauth/token")
            assert first_attempt.request.url == httpx.URL(base_url + "/chat/completions")
            assert first_attempt.request.headers.get("Authorization") == "Bearer openai-access-token-1"

            assert second_exchange.request.url == httpx.URL("https://auth.openai.com/oauth/token")

            assert second_attempt.request.url == httpx.URL(base_url + "/chat/completions")
            assert second_attempt.request.headers.get("Authorization") == "Bearer openai-access-token-2"

            assert subject_tokens_issued == 2

    @pytest.mark.respx()
    async def test_401_without_workload_identity_no_retry(self, respx_mock: MockRouter) -> None:
        """With a plain API key a 401 is raised after a single request."""
        unauthorized = httpx.Response(
            401, json={"error": {"message": "Unauthorized", "type": "invalid_request_error"}}
        )
        respx_mock.post(base_url + "/chat/completions").mock(return_value=unauthorized)

        async with AsyncOpenAI(
            base_url=base_url,
            api_key="test-api-key",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as caught:
                await client.chat.completions.create(messages=[], model="gpt-4")

            assert caught.value.status_code == 401

            calls = cast("list[MockRequestCall]", respx_mock.calls)
            assert len(calls) == 1

    @pytest.mark.respx()
    async def test_non_401_errors_no_retry(self, respx_mock: MockRouter) -> None:
        """A 403 under workload identity is raised after one exchange and one API request."""
        subject_tokens_issued = 0

        def issue_subject_token() -> str:
            nonlocal subject_tokens_issued
            subject_tokens_issued += 1
            return "external-subject-token"

        exchange_response = httpx.Response(
            200,
            json={
                "access_token": "openai-access-token-1",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
        respx_mock.post("https://auth.openai.com/oauth/token").mock(return_value=exchange_response)

        forbidden = httpx.Response(403, json={"error": {"message": "Forbidden", "type": "invalid_request_error"}})
        respx_mock.post(base_url + "/chat/completions").mock(return_value=forbidden)

        async with AsyncOpenAI(
            base_url=base_url,
            workload_identity={
                **workload_identity,
                "provider": {
                    "get_token": issue_subject_token,
                    "token_type": "jwt",
                },
            },
            organization="org_123",
            project="proj_123",
            _strict_response_validation=True,
        ) as client:
            with pytest.raises(APIStatusError) as caught:
                await client.chat.completions.create(messages=[], model="gpt-4")

            assert caught.value.status_code == 403

            # One token exchange plus one API call, and nothing after the 403.
            calls = cast("list[MockRequestCall]", respx_mock.calls)
            assert len(calls) == 2

            assert subject_tokens_issued == 1