openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v2.23.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

2197 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import asyncio
10import inspect
11import dataclasses
12import tracemalloc
13from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Protocol, Coroutine, cast
14from unittest import mock
15from typing_extensions import Literal, AsyncIterator, override
16
17import httpx
18import pytest
19from respx import MockRouter
20from pydantic import ValidationError
21
22from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
23from openai._types import Omit
24from openai._utils import asyncify
25from openai._models import BaseModel, FinalRequestOptions
26from openai._streaming import Stream, AsyncStream
27from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
28from openai._base_client import (
29 DEFAULT_TIMEOUT,
30 HTTPX_DEFAULT_TIMEOUT,
31 BaseClient,
32 OtherPlatform,
33 DefaultHttpxClient,
34 DefaultAsyncHttpxClient,
35 get_platform,
36 make_request_options,
37)
38
39from .utils import update_env
40
# Generic element type for the iterator helpers defined in this module.
T = TypeVar("T")
# Base URL the test clients point at; can be overridden through TEST_API_BASE_URL.
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder API key handed to every client constructed in this module.
api_key = "My API Key"
44
45
class MockRequestCall(Protocol):
    """Structural type for any object that exposes a `request` attribute typed as `httpx.Request`."""

    request: httpx.Request
48
49
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters `client` attaches to a bare `GET /foo` request.

    The request is only built (via `client._build_request`), never sent.
    """
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    url = httpx.URL(request.url)
    return dict(url.params)
54
55
56def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
57 return 0.1
58
59
def mirror_request_content(request: httpx.Request) -> httpx.Response:
    """Build a 200 response whose body is the body of `request`."""
    return httpx.Response(200, content=request.content)
62
63
# note: we can't use the httpx.MockTransport class as it consumes the request
#       body itself, which means we can't test that the body is read lazily
class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    """Transport that forwards every request to a user-supplied handler.

    The same class serves sync and async clients: `handle_request` requires a
    plain function handler, `handle_async_request` requires a coroutine function.
    """

    def __init__(
        self,
        handler: Callable[[httpx.Request], httpx.Response]
        | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]],
    ) -> None:
        """Store `handler`; it is validated lazily when a request is handled."""
        self.handler = handler

    @override
    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Pass `request` to the sync handler and return its response."""
        assert not inspect.iscoroutinefunction(self.handler), "handler must not be a coroutine function"
        assert inspect.isfunction(self.handler), "handler must be a function"
        return self.handler(request)

    @override
    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        """Await the async handler with `request` and return its response."""
        assert inspect.iscoroutinefunction(self.handler), "handler must be a coroutine function"
        return await self.handler(request)
90
91
@dataclasses.dataclass
class Counter:
    """Mutable integer tally; the iterator helpers bump `value` once per yielded item."""

    # Number of items counted so far.
    value: int = 0
95
96
97def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]:
98 for item in iterable:
99 if counter:
100 counter.value += 1
101 yield item
102
103
104async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]:
105 for item in iterable:
106 if counter:
107 counter.value += 1
108 yield item
109
110
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of entries in the request list of the client's connection pool.

    Reads private attributes (`_client._transport._pool._requests`), so it only
    works when the client uses a stock httpx HTTP transport.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    pool = transport._pool
    return len(pool._requests)
117
118
119class TestOpenAI:
120 @pytest.mark.respx(base_url=base_url)
121 def test_raw_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
122 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
123
124 response = client.post("/foo", cast_to=httpx.Response)
125 assert response.status_code == 200
126 assert isinstance(response, httpx.Response)
127 assert response.json() == {"foo": "bar"}
128
129 @pytest.mark.respx(base_url=base_url)
130 def test_raw_response_for_binary(self, respx_mock: MockRouter, client: OpenAI) -> None:
131 respx_mock.post("/foo").mock(
132 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
133 )
134
135 response = client.post("/foo", cast_to=httpx.Response)
136 assert response.status_code == 200
137 assert isinstance(response, httpx.Response)
138 assert response.json() == {"foo": "bar"}
139
140 def test_copy(self, client: OpenAI) -> None:
141 copied = client.copy()
142 assert id(copied) != id(client)
143
144 copied = client.copy(api_key="another My API Key")
145 assert copied.api_key == "another My API Key"
146 assert client.api_key == "My API Key"
147
148 def test_copy_default_options(self, client: OpenAI) -> None:
149 # options that have a default are overridden correctly
150 copied = client.copy(max_retries=7)
151 assert copied.max_retries == 7
152 assert client.max_retries == 2
153
154 copied2 = copied.copy(max_retries=6)
155 assert copied2.max_retries == 6
156 assert copied.max_retries == 7
157
158 # timeout
159 assert isinstance(client.timeout, httpx.Timeout)
160 copied = client.copy(timeout=None)
161 assert copied.timeout is None
162 assert isinstance(client.timeout, httpx.Timeout)
163
164 def test_copy_default_headers(self) -> None:
165 client = OpenAI(
166 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
167 )
168 assert client.default_headers["X-Foo"] == "bar"
169
170 # does not override the already given value when not specified
171 copied = client.copy()
172 assert copied.default_headers["X-Foo"] == "bar"
173
174 # merges already given headers
175 copied = client.copy(default_headers={"X-Bar": "stainless"})
176 assert copied.default_headers["X-Foo"] == "bar"
177 assert copied.default_headers["X-Bar"] == "stainless"
178
179 # uses new values for any already given headers
180 copied = client.copy(default_headers={"X-Foo": "stainless"})
181 assert copied.default_headers["X-Foo"] == "stainless"
182
183 # set_default_headers
184
185 # completely overrides already set values
186 copied = client.copy(set_default_headers={})
187 assert copied.default_headers.get("X-Foo") is None
188
189 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
190 assert copied.default_headers["X-Bar"] == "Robert"
191
192 with pytest.raises(
193 ValueError,
194 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
195 ):
196 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
197 client.close()
198
199 def test_copy_default_query(self) -> None:
200 client = OpenAI(
201 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
202 )
203 assert _get_params(client)["foo"] == "bar"
204
205 # does not override the already given value when not specified
206 copied = client.copy()
207 assert _get_params(copied)["foo"] == "bar"
208
209 # merges already given params
210 copied = client.copy(default_query={"bar": "stainless"})
211 params = _get_params(copied)
212 assert params["foo"] == "bar"
213 assert params["bar"] == "stainless"
214
215 # uses new values for any already given headers
216 copied = client.copy(default_query={"foo": "stainless"})
217 assert _get_params(copied)["foo"] == "stainless"
218
219 # set_default_query
220
221 # completely overrides already set values
222 copied = client.copy(set_default_query={})
223 assert _get_params(copied) == {}
224
225 copied = client.copy(set_default_query={"bar": "Robert"})
226 assert _get_params(copied)["bar"] == "Robert"
227
228 with pytest.raises(
229 ValueError,
230 # TODO: update
231 match="`default_query` and `set_default_query` arguments are mutually exclusive",
232 ):
233 client.copy(set_default_query={}, default_query={"foo": "Bar"})
234
235 client.close()
236
237 def test_copy_signature(self, client: OpenAI) -> None:
238 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
239 init_signature = inspect.signature(
240 # mypy doesn't like that we access the `__init__` property.
241 client.__init__, # type: ignore[misc]
242 )
243 copy_signature = inspect.signature(client.copy)
244 exclude_params = {"transport", "proxies", "_strict_response_validation"}
245
246 for name in init_signature.parameters.keys():
247 if name in exclude_params:
248 continue
249
250 copy_param = copy_signature.parameters.get(name)
251 assert copy_param is not None, f"copy() signature is missing the {name} param"
252
253 @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
254 def test_copy_build_request(self, client: OpenAI) -> None:
255 options = FinalRequestOptions(method="get", url="/foo")
256
257 def build_request(options: FinalRequestOptions) -> None:
258 client_copy = client.copy()
259 client_copy._build_request(options)
260
261 # ensure that the machinery is warmed up before tracing starts.
262 build_request(options)
263 gc.collect()
264
265 tracemalloc.start(1000)
266
267 snapshot_before = tracemalloc.take_snapshot()
268
269 ITERATIONS = 10
270 for _ in range(ITERATIONS):
271 build_request(options)
272
273 gc.collect()
274 snapshot_after = tracemalloc.take_snapshot()
275
276 tracemalloc.stop()
277
278 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
279 if diff.count == 0:
280 # Avoid false positives by considering only leaks (i.e. allocations that persist).
281 return
282
283 if diff.count % ITERATIONS != 0:
284 # Avoid false positives by considering only leaks that appear per iteration.
285 return
286
287 for frame in diff.traceback:
288 if any(
289 frame.filename.endswith(fragment)
290 for fragment in [
291 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
292 #
293 # removing the decorator fixes the leak for reasons we don't understand.
294 "openai/_legacy_response.py",
295 "openai/_response.py",
296 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
297 "openai/_compat.py",
298 # Standard library leaks we don't care about.
299 "/logging/__init__.py",
300 ]
301 ):
302 return
303
304 leaks.append(diff)
305
306 leaks: list[tracemalloc.StatisticDiff] = []
307 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
308 add_leak(leaks, diff)
309 if leaks:
310 for leak in leaks:
311 print("MEMORY LEAK:", leak)
312 for frame in leak.traceback:
313 print(frame)
314 raise AssertionError()
315
316 def test_request_timeout(self, client: OpenAI) -> None:
317 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
318 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
319 assert timeout == DEFAULT_TIMEOUT
320
321 request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
322 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
323 assert timeout == httpx.Timeout(100.0)
324
325 def test_client_timeout_option(self) -> None:
326 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
327
328 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
329 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
330 assert timeout == httpx.Timeout(0)
331
332 client.close()
333
334 def test_http_client_timeout_option(self) -> None:
335 # custom timeout given to the httpx client should be used
336 with httpx.Client(timeout=None) as http_client:
337 client = OpenAI(
338 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
339 )
340
341 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
342 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
343 assert timeout == httpx.Timeout(None)
344
345 client.close()
346
347 # no timeout given to the httpx client should not use the httpx default
348 with httpx.Client() as http_client:
349 client = OpenAI(
350 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
351 )
352
353 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
354 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
355 assert timeout == DEFAULT_TIMEOUT
356
357 client.close()
358
359 # explicitly passing the default timeout currently results in it being ignored
360 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
361 client = OpenAI(
362 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
363 )
364
365 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
366 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
367 assert timeout == DEFAULT_TIMEOUT # our default
368
369 client.close()
370
371 async def test_invalid_http_client(self) -> None:
372 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
373 async with httpx.AsyncClient() as http_client:
374 OpenAI(
375 base_url=base_url,
376 api_key=api_key,
377 _strict_response_validation=True,
378 http_client=cast(Any, http_client),
379 )
380
381 def test_default_headers_option(self) -> None:
382 test_client = OpenAI(
383 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
384 )
385 request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
386 assert request.headers.get("x-foo") == "bar"
387 assert request.headers.get("x-stainless-lang") == "python"
388
389 test_client2 = OpenAI(
390 base_url=base_url,
391 api_key=api_key,
392 _strict_response_validation=True,
393 default_headers={
394 "X-Foo": "stainless",
395 "X-Stainless-Lang": "my-overriding-header",
396 },
397 )
398 request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
399 assert request.headers.get("x-foo") == "stainless"
400 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
401
402 test_client.close()
403 test_client2.close()
404
405 def test_validate_headers(self) -> None:
406 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
407 options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
408 request = client._build_request(options)
409
410 assert request.headers.get("Authorization") == f"Bearer {api_key}"
411
412 with pytest.raises(OpenAIError):
413 with update_env(**{"OPENAI_API_KEY": Omit()}):
414 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
415 _ = client2
416
417 def test_default_query_option(self) -> None:
418 client = OpenAI(
419 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
420 )
421 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
422 url = httpx.URL(request.url)
423 assert dict(url.params) == {"query_param": "bar"}
424
425 request = client._build_request(
426 FinalRequestOptions(
427 method="get",
428 url="/foo",
429 params={"foo": "baz", "query_param": "overridden"},
430 )
431 )
432 url = httpx.URL(request.url)
433 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
434
435 client.close()
436
437 def test_request_extra_json(self, client: OpenAI) -> None:
438 request = client._build_request(
439 FinalRequestOptions(
440 method="post",
441 url="/foo",
442 json_data={"foo": "bar"},
443 extra_json={"baz": False},
444 ),
445 )
446 data = json.loads(request.content.decode("utf-8"))
447 assert data == {"foo": "bar", "baz": False}
448
449 request = client._build_request(
450 FinalRequestOptions(
451 method="post",
452 url="/foo",
453 extra_json={"baz": False},
454 ),
455 )
456 data = json.loads(request.content.decode("utf-8"))
457 assert data == {"baz": False}
458
459 # `extra_json` takes priority over `json_data` when keys clash
460 request = client._build_request(
461 FinalRequestOptions(
462 method="post",
463 url="/foo",
464 json_data={"foo": "bar", "baz": True},
465 extra_json={"baz": None},
466 ),
467 )
468 data = json.loads(request.content.decode("utf-8"))
469 assert data == {"foo": "bar", "baz": None}
470
471 def test_request_extra_headers(self, client: OpenAI) -> None:
472 request = client._build_request(
473 FinalRequestOptions(
474 method="post",
475 url="/foo",
476 **make_request_options(extra_headers={"X-Foo": "Foo"}),
477 ),
478 )
479 assert request.headers.get("X-Foo") == "Foo"
480
481 # `extra_headers` takes priority over `default_headers` when keys clash
482 request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
483 FinalRequestOptions(
484 method="post",
485 url="/foo",
486 **make_request_options(
487 extra_headers={"X-Bar": "false"},
488 ),
489 ),
490 )
491 assert request.headers.get("X-Bar") == "false"
492
493 def test_request_extra_query(self, client: OpenAI) -> None:
494 request = client._build_request(
495 FinalRequestOptions(
496 method="post",
497 url="/foo",
498 **make_request_options(
499 extra_query={"my_query_param": "Foo"},
500 ),
501 ),
502 )
503 params = dict(request.url.params)
504 assert params == {"my_query_param": "Foo"}
505
506 # if both `query` and `extra_query` are given, they are merged
507 request = client._build_request(
508 FinalRequestOptions(
509 method="post",
510 url="/foo",
511 **make_request_options(
512 query={"bar": "1"},
513 extra_query={"foo": "2"},
514 ),
515 ),
516 )
517 params = dict(request.url.params)
518 assert params == {"bar": "1", "foo": "2"}
519
520 # `extra_query` takes priority over `query` when keys clash
521 request = client._build_request(
522 FinalRequestOptions(
523 method="post",
524 url="/foo",
525 **make_request_options(
526 query={"foo": "1"},
527 extra_query={"foo": "2"},
528 ),
529 ),
530 )
531 params = dict(request.url.params)
532 assert params == {"foo": "2"}
533
534 def test_multipart_repeating_array(self, client: OpenAI) -> None:
535 request = client._build_request(
536 FinalRequestOptions.construct(
537 method="post",
538 url="/foo",
539 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
540 json_data={"array": ["foo", "bar"]},
541 files=[("foo.txt", b"hello world")],
542 )
543 )
544
545 assert request.read().split(b"\r\n") == [
546 b"--6b7ba517decee4a450543ea6ae821c82",
547 b'Content-Disposition: form-data; name="array[]"',
548 b"",
549 b"foo",
550 b"--6b7ba517decee4a450543ea6ae821c82",
551 b'Content-Disposition: form-data; name="array[]"',
552 b"",
553 b"bar",
554 b"--6b7ba517decee4a450543ea6ae821c82",
555 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
556 b"Content-Type: application/octet-stream",
557 b"",
558 b"hello world",
559 b"--6b7ba517decee4a450543ea6ae821c82--",
560 b"",
561 ]
562
563 @pytest.mark.respx(base_url=base_url)
564 def test_binary_content_upload(self, respx_mock: MockRouter, client: OpenAI) -> None:
565 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
566
567 file_content = b"Hello, this is a test file."
568
569 response = client.post(
570 "/upload",
571 content=file_content,
572 cast_to=httpx.Response,
573 options={"headers": {"Content-Type": "application/octet-stream"}},
574 )
575
576 assert response.status_code == 200
577 assert response.request.headers["Content-Type"] == "application/octet-stream"
578 assert response.content == file_content
579
580 def test_binary_content_upload_with_iterator(self) -> None:
581 file_content = b"Hello, this is a test file."
582 counter = Counter()
583 iterator = _make_sync_iterator([file_content], counter=counter)
584
585 def mock_handler(request: httpx.Request) -> httpx.Response:
586 assert counter.value == 0, "the request body should not have been read"
587 return httpx.Response(200, content=request.read())
588
589 with OpenAI(
590 base_url=base_url,
591 api_key=api_key,
592 _strict_response_validation=True,
593 http_client=httpx.Client(transport=MockTransport(handler=mock_handler)),
594 ) as client:
595 response = client.post(
596 "/upload",
597 content=iterator,
598 cast_to=httpx.Response,
599 options={"headers": {"Content-Type": "application/octet-stream"}},
600 )
601
602 assert response.status_code == 200
603 assert response.request.headers["Content-Type"] == "application/octet-stream"
604 assert response.content == file_content
605 assert counter.value == 1
606
607 @pytest.mark.respx(base_url=base_url)
608 def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: OpenAI) -> None:
609 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
610
611 file_content = b"Hello, this is a test file."
612
613 with pytest.deprecated_call(
614 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
615 ):
616 response = client.post(
617 "/upload",
618 body=file_content,
619 cast_to=httpx.Response,
620 options={"headers": {"Content-Type": "application/octet-stream"}},
621 )
622
623 assert response.status_code == 200
624 assert response.request.headers["Content-Type"] == "application/octet-stream"
625 assert response.content == file_content
626
627 @pytest.mark.respx(base_url=base_url)
628 def test_basic_union_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
629 class Model1(BaseModel):
630 name: str
631
632 class Model2(BaseModel):
633 foo: str
634
635 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
636
637 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
638 assert isinstance(response, Model2)
639 assert response.foo == "bar"
640
641 @pytest.mark.respx(base_url=base_url)
642 def test_union_response_different_types(self, respx_mock: MockRouter, client: OpenAI) -> None:
643 """Union of objects with the same field name using a different type"""
644
645 class Model1(BaseModel):
646 foo: int
647
648 class Model2(BaseModel):
649 foo: str
650
651 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
652
653 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
654 assert isinstance(response, Model2)
655 assert response.foo == "bar"
656
657 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
658
659 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
660 assert isinstance(response, Model1)
661 assert response.foo == 1
662
663 @pytest.mark.respx(base_url=base_url)
664 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: OpenAI) -> None:
665 """
666 Response that sets Content-Type to something other than application/json but returns json data
667 """
668
669 class Model(BaseModel):
670 foo: int
671
672 respx_mock.get("/foo").mock(
673 return_value=httpx.Response(
674 200,
675 content=json.dumps({"foo": 2}),
676 headers={"Content-Type": "application/text"},
677 )
678 )
679
680 response = client.get("/foo", cast_to=Model)
681 assert isinstance(response, Model)
682 assert response.foo == 2
683
684 def test_base_url_setter(self) -> None:
685 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
686 assert client.base_url == "https://example.com/from_init/"
687
688 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
689
690 assert client.base_url == "https://example.com/from_setter/"
691
692 client.close()
693
694 def test_base_url_env(self) -> None:
695 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
696 client = OpenAI(api_key=api_key, _strict_response_validation=True)
697 assert client.base_url == "http://localhost:5000/from/env/"
698
699 @pytest.mark.parametrize(
700 "client",
701 [
702 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
703 OpenAI(
704 base_url="http://localhost:5000/custom/path/",
705 api_key=api_key,
706 _strict_response_validation=True,
707 http_client=httpx.Client(),
708 ),
709 ],
710 ids=["standard", "custom http client"],
711 )
712 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
713 request = client._build_request(
714 FinalRequestOptions(
715 method="post",
716 url="/foo",
717 json_data={"foo": "bar"},
718 ),
719 )
720 assert request.url == "http://localhost:5000/custom/path/foo"
721 client.close()
722
723 @pytest.mark.parametrize(
724 "client",
725 [
726 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
727 OpenAI(
728 base_url="http://localhost:5000/custom/path/",
729 api_key=api_key,
730 _strict_response_validation=True,
731 http_client=httpx.Client(),
732 ),
733 ],
734 ids=["standard", "custom http client"],
735 )
736 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
737 request = client._build_request(
738 FinalRequestOptions(
739 method="post",
740 url="/foo",
741 json_data={"foo": "bar"},
742 ),
743 )
744 assert request.url == "http://localhost:5000/custom/path/foo"
745 client.close()
746
747 @pytest.mark.parametrize(
748 "client",
749 [
750 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
751 OpenAI(
752 base_url="http://localhost:5000/custom/path/",
753 api_key=api_key,
754 _strict_response_validation=True,
755 http_client=httpx.Client(),
756 ),
757 ],
758 ids=["standard", "custom http client"],
759 )
760 def test_absolute_request_url(self, client: OpenAI) -> None:
761 request = client._build_request(
762 FinalRequestOptions(
763 method="post",
764 url="https://myapi.com/foo",
765 json_data={"foo": "bar"},
766 ),
767 )
768 assert request.url == "https://myapi.com/foo"
769 client.close()
770
771 def test_copied_client_does_not_close_http(self) -> None:
772 test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
773 assert not test_client.is_closed()
774
775 copied = test_client.copy()
776 assert copied is not test_client
777
778 del copied
779
780 assert not test_client.is_closed()
781
782 def test_client_context_manager(self) -> None:
783 test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
784 with test_client as c2:
785 assert c2 is test_client
786 assert not c2.is_closed()
787 assert not test_client.is_closed()
788 assert test_client.is_closed()
789
790 @pytest.mark.respx(base_url=base_url)
791 def test_client_response_validation_error(self, respx_mock: MockRouter, client: OpenAI) -> None:
792 class Model(BaseModel):
793 foo: str
794
795 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
796
797 with pytest.raises(APIResponseValidationError) as exc:
798 client.get("/foo", cast_to=Model)
799
800 assert isinstance(exc.value.__cause__, ValidationError)
801
802 def test_client_max_retries_validation(self) -> None:
803 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
804 OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
805
806 @pytest.mark.respx(base_url=base_url)
807 def test_default_stream_cls(self, respx_mock: MockRouter, client: OpenAI) -> None:
808 class Model(BaseModel):
809 name: str
810
811 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
812
813 stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
814 assert isinstance(stream, Stream)
815 stream.response.close()
816
817 @pytest.mark.respx(base_url=base_url)
818 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
819 class Model(BaseModel):
820 name: str
821
822 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
823
824 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
825
826 with pytest.raises(APIResponseValidationError):
827 strict_client.get("/foo", cast_to=Model)
828
829 non_strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
830
831 response = non_strict_client.get("/foo", cast_to=Model)
832 assert isinstance(response, str) # type: ignore[unreachable]
833
834 strict_client.close()
835 non_strict_client.close()
836
837 @pytest.mark.parametrize(
838 "remaining_retries,retry_after,timeout",
839 [
840 [3, "20", 20],
841 [3, "0", 0.5],
842 [3, "-10", 0.5],
843 [3, "60", 60],
844 [3, "61", 0.5],
845 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
846 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
847 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
848 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
849 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
850 [3, "99999999999999999999999999999999999", 0.5],
851 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
852 [3, "", 0.5],
853 [2, "", 0.5 * 2.0],
854 [1, "", 0.5 * 4.0],
855 [-1100, "", 8], # test large number potentially overflowing
856 ],
857 )
858 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
859 def test_parse_retry_after_header(
860 self, remaining_retries: int, retry_after: str, timeout: float, client: OpenAI
861 ) -> None:
862 headers = httpx.Headers({"retry-after": retry_after})
863 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
864 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
865 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
866
867 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
868 @pytest.mark.respx(base_url=base_url)
869 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
870 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
871
872 with pytest.raises(APITimeoutError):
873 client.chat.completions.with_streaming_response.create(
874 messages=[
875 {
876 "content": "string",
877 "role": "developer",
878 }
879 ],
880 model="gpt-4o",
881 ).__enter__()
882
883 assert _get_open_connections(client) == 0
884
885 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
886 @pytest.mark.respx(base_url=base_url)
887 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
888 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
889
890 with pytest.raises(APIStatusError):
891 client.chat.completions.with_streaming_response.create(
892 messages=[
893 {
894 "content": "string",
895 "role": "developer",
896 }
897 ],
898 model="gpt-4o",
899 ).__enter__()
900 assert _get_open_connections(client) == 0
901
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
def test_retries_taken(
    self,
    client: OpenAI,
    failures_before_success: int,
    failure_mode: Literal["status", "exception"],
    respx_mock: MockRouter,
) -> None:
    """Check that the raw response reports how many retries were needed.

    The mocked endpoint fails ``failures_before_success`` times (either with an
    HTTP 500 or by raising, depending on ``failure_mode``) and then answers 200.
    Both ``retries_taken`` and the ``x-stainless-retry-count`` header of the final
    request must equal that number.
    """
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        # Succeed as soon as the configured number of failures has been produced.
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        if failure_mode == "exception":
            raise RuntimeError("oops")
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    raw = retrying_client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
    )

    assert raw.retries_taken == failures_before_success
    retry_count_header = raw.http_request.headers.get("x-stainless-retry-count")
    assert int(retry_count_header) == failures_before_success
940
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_omit_retry_count_header(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """Check that passing ``Omit()`` for ``x-stainless-retry-count`` removes the header.

    The endpoint answers 500 ``failures_before_success`` times before answering 200;
    the final request must not carry the retry-count header at all.
    """
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    raw = retrying_client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": Omit()},
    )

    sent_values = raw.http_request.headers.get_list("x-stainless-retry-count")
    assert len(sent_values) == 0
972
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_overwrite_retry_count_header(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """Check that a caller-supplied ``x-stainless-retry-count`` value is sent unchanged.

    The endpoint answers 500 ``failures_before_success`` times before answering 200;
    the final request must still carry the explicit value ``"42"``.
    """
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    raw = retrying_client.chat.completions.with_raw_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": "42"},
    )

    sent_value = raw.http_request.headers.get("x-stainless-retry-count")
    assert sent_value == "42"
1004
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
def test_retries_taken_new_response_class(
    self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """Check retry accounting on the streaming-response wrapper.

    Same scenario as the raw-response variant, but the request goes through
    ``with_streaming_response`` and the assertions run inside its context.
    """
    retrying_client = client.with_options(max_retries=4)

    failures_served = 0

    def flaky_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        if failures_served >= failures_before_success:
            return httpx.Response(200)

        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=flaky_handler)

    with retrying_client.chat.completions.with_streaming_response.create(
        messages=[
            {
                "content": "string",
                "role": "developer",
            }
        ],
        model="gpt-4o",
    ) as streamed:
        assert streamed.retries_taken == failures_before_success
        retry_count_header = streamed.http_request.headers.get("x-stainless-retry-count")
        assert int(retry_count_header) == failures_before_success
1035
def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """Check that ``DefaultHttpxClient`` mounts exactly one proxy for ``HTTPS_PROXY``.

    All proxy-related variables are removed first so that the outcome does not
    depend on the machine running the test: an inherited ``ALL_PROXY``, ``NO_PROXY``
    or lowercase variant would otherwise add mounts or override the value set here.
    """
    # Start from a clean slate. Both spellings are removed because proxy settings
    # are looked up under the upper- and lowercase names.
    for name in ("HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "NO_PROXY"):
        monkeypatch.delenv(name, raising=False)
        monkeypatch.delenv(name.lower(), raising=False)

    # Set the variable under test only after the cleanup, so that the cleanup cannot
    # remove it on platforms with case-insensitive environment names.
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

    client = DefaultHttpxClient()

    mounts = tuple(client._mounts.items())
    assert len(mounts) == 1
    assert mounts[0][0].pattern == "https://"
1047
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
def test_default_client_creation(self) -> None:
    """Smoke test: building ``DefaultHttpxClient`` with explicit options must not raise."""
    connection_limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)

    # Only successful construction matters here; the instance itself is not used.
    DefaultHttpxClient(
        verify=True,
        cert=None,
        trust_env=True,
        http1=True,
        http2=False,
        limits=connection_limits,
    )
1059
@pytest.mark.respx(base_url=base_url)
def test_follow_redirects(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Check that a 302 answer is followed when no redirect option is passed."""
    # The POST is answered with a redirect; the redirect target answers with JSON.
    redirect = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    respx_mock.post("/redirect").mock(return_value=redirect)
    respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

    final_response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    assert final_response.status_code == 200
    assert final_response.json() == {"status": "ok"}
1071
@pytest.mark.respx(base_url=base_url)
def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: OpenAI) -> None:
    """Check that ``follow_redirects=False`` surfaces the 302 as an ``APIStatusError``."""
    redirect = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    respx_mock.post("/redirect").mock(return_value=redirect)

    with pytest.raises(APIStatusError) as exc_info:
        client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    # The error must expose the untouched redirect response.
    rejected = exc_info.value.response
    assert rejected.status_code == 302
    assert rejected.headers["Location"] == f"{base_url}/redirected"
1084
def test_api_key_before_after_refresh_provider(self) -> None:
    """Check that a callable ``api_key`` is only resolved by ``_refresh_api_key()``.

    Before the refresh the key is empty and no ``Authorization`` header exists;
    afterwards both reflect the value returned by the provider.
    """

    def provide_token() -> str:
        return "test_bearer_token"

    provider_client = OpenAI(base_url=base_url, api_key=provide_token)

    assert provider_client.api_key == ""
    assert "Authorization" not in provider_client.auth_headers

    provider_client._refresh_api_key()

    assert provider_client.api_key == "test_bearer_token"
    assert provider_client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
1095
def test_api_key_before_after_refresh_str(self) -> None:
    """Check that a plain string ``api_key`` yields the same header before and after a refresh."""
    static_client = OpenAI(base_url=base_url, api_key="test_api_key")

    header_before = static_client.auth_headers.get("Authorization")
    assert header_before == "Bearer test_api_key"

    static_client._refresh_api_key()

    header_after = static_client.auth_headers.get("Authorization")
    assert header_after == "Bearer test_api_key"
1103
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx()
def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
    """Check that the API key provider is consulted again for a retried request.

    The first attempt is answered with HTTP 500 and the second with 200. The provider
    returns ``"first"`` on its first call and ``"second"`` afterwards, so each
    attempt must carry a different bearer token.

    The retry delay is patched to a short constant, as in the other retry tests,
    so the test does not wait for the regular backoff.
    """
    respx_mock.post(base_url + "/chat/completions").mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    counter = 0

    def token_provider() -> str:
        nonlocal counter

        counter += 1

        if counter == 1:
            return "first"

        return "second"

    client = OpenAI(base_url=base_url, api_key=token_provider)
    client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(calls) == 2

    assert calls[0].request.headers.get("Authorization") == "Bearer first"
    assert calls[1].request.headers.get("Authorization") == "Bearer second"
1133
def test_copy_auth(self) -> None:
    """Check that ``copy(api_key=...)`` replaces the key provider of the original client."""
    original = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1")
    copied = original.copy(api_key=lambda: "test_bearer_token_2")

    copied._refresh_api_key()

    assert copied.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
1140
1141
1142class TestAsyncOpenAI:
@pytest.mark.respx(base_url=base_url)
async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Check that ``cast_to=httpx.Response`` hands back the underlying HTTP response."""
    payload = {"foo": "bar"}
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json=payload))

    raw = await async_client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
1151
@pytest.mark.respx(base_url=base_url)
async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Check that a response labelled ``application/binary`` is still returned as ``httpx.Response``."""
    binary_reply = httpx.Response(
        200,
        headers={"Content-Type": "application/binary"},
        content='{"foo": "bar"}',
    )
    respx_mock.post("/foo").mock(return_value=binary_reply)

    raw = await async_client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
1162
def test_copy(self, async_client: AsyncOpenAI) -> None:
    """Check that ``copy()`` creates a new client and that overrides do not touch the original."""
    plain_copy = async_client.copy()
    assert id(plain_copy) != id(async_client)

    rekeyed_copy = async_client.copy(api_key="another My API Key")
    assert rekeyed_copy.api_key == "another My API Key"
    # The source client keeps its own key.
    assert async_client.api_key == "My API Key"
1170
def test_copy_default_options(self, async_client: AsyncOpenAI) -> None:
    """Check that ``max_retries`` and ``timeout`` overrides apply to the copy only."""
    # max_retries: each copy gets the new value, its source keeps the old one.
    first_copy = async_client.copy(max_retries=7)
    assert first_copy.max_retries == 7
    assert async_client.max_retries == 2

    second_copy = first_copy.copy(max_retries=6)
    assert second_copy.max_retries == 6
    assert first_copy.max_retries == 7

    # timeout: an explicit None is honoured on the copy without altering the source.
    assert isinstance(async_client.timeout, httpx.Timeout)
    untimed_copy = async_client.copy(timeout=None)
    assert untimed_copy.timeout is None
    assert isinstance(async_client.timeout, httpx.Timeout)
1186
async def test_copy_default_headers(self) -> None:
    """Check how ``copy()`` merges (``default_headers``) or replaces (``set_default_headers``) headers."""
    source = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    assert source.default_headers["X-Foo"] == "bar"

    # --- default_headers: merged into the existing ones ---

    # an argument-less copy keeps the configured value
    untouched = source.copy()
    assert untouched.default_headers["X-Foo"] == "bar"

    # a new header is added next to the existing one
    merged = source.copy(default_headers={"X-Bar": "stainless"})
    assert merged.default_headers["X-Foo"] == "bar"
    assert merged.default_headers["X-Bar"] == "stainless"

    # an existing header takes the newly given value
    overridden = source.copy(default_headers={"X-Foo": "stainless"})
    assert overridden.default_headers["X-Foo"] == "stainless"

    # --- set_default_headers: replaces the existing ones entirely ---

    cleared = source.copy(set_default_headers={})
    assert cleared.default_headers.get("X-Foo") is None

    replaced = source.copy(set_default_headers={"X-Bar": "Robert"})
    assert replaced.default_headers["X-Bar"] == "Robert"

    # the two arguments cannot be combined
    with pytest.raises(
        ValueError,
        match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
    ):
        source.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    await source.close()
1221
async def test_copy_default_query(self) -> None:
    """Check how ``copy()`` merges (``default_query``) or replaces (``set_default_query``) query params."""
    source = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
    )
    assert _get_params(source)["foo"] == "bar"

    # --- default_query: merged into the existing params ---

    # an argument-less copy keeps the configured value
    untouched = source.copy()
    assert _get_params(untouched)["foo"] == "bar"

    # a new param is added next to the existing one
    merged = source.copy(default_query={"bar": "stainless"})
    merged_params = _get_params(merged)
    assert merged_params["foo"] == "bar"
    assert merged_params["bar"] == "stainless"

    # an existing param takes the newly given value
    overridden = source.copy(default_query={"foo": "stainless"})
    assert _get_params(overridden)["foo"] == "stainless"

    # --- set_default_query: replaces the existing params entirely ---

    cleared = source.copy(set_default_query={})
    assert _get_params(cleared) == {}

    replaced = source.copy(set_default_query={"bar": "Robert"})
    assert _get_params(replaced)["bar"] == "Robert"

    # the two arguments cannot be combined
    with pytest.raises(
        ValueError,
        match="`default_query` and `set_default_query` arguments are mutually exclusive",
    ):
        source.copy(set_default_query={}, default_query={"foo": "Bar"})

    await source.close()
1259
def test_copy_signature(self, async_client: AsyncOpenAI) -> None:
    """Check that every ``__init__`` parameter, bar a few exclusions, is accepted by ``copy()``."""
    init_params = inspect.signature(
        # mypy doesn't like that we access the `__init__` property.
        async_client.__init__,  # type: ignore[misc]
    ).parameters
    copy_params = inspect.signature(async_client.copy).parameters

    # Constructor-only parameters that copy() is not expected to mirror.
    init_only = {"transport", "proxies", "_strict_response_validation"}

    for name in init_params:
        if name not in init_only:
            mirrored = copy_params.get(name)
            assert mirrored is not None, f"copy() signature is missing the {name} param"
1275
@pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
def test_copy_build_request(self, async_client: AsyncOpenAI) -> None:
    """Check that copying the client and building a request does not leak memory per call.

    Two ``tracemalloc`` snapshots are taken around ``ITERATIONS`` copy-and-build
    cycles. Any allocation difference whose count is a non-zero multiple of
    ``ITERATIONS`` and that does not originate from an ignored source file is
    reported as a leak and fails the test.
    """
    # NOTE(review): the skip condition starts at 3.10 while its reason mentions 3.12;
    # confirm which version is intended before changing either.
    ITERATIONS = 10

    # Allocations traced back to these files (matched by filename suffix) are ignored.
    IGNORED_SOURCES = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client_copy = async_client.copy()
        client_copy._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        # Always switch tracing off again, even if a build fails, so a tracer left
        # running cannot affect the tests that follow.
        tracemalloc.stop()

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        for frame in diff.traceback:
            # str.endswith accepts a tuple, so one call covers every ignored file.
            if frame.filename.endswith(IGNORED_SOURCES):
                return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)

    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError()
1338
async def test_request_timeout(self, async_client: AsyncOpenAI) -> None:
    """Check the timeout attached to a built request, with and without a per-request override."""
    # No override: the request carries the client default.
    default_request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    default_timeout = httpx.Timeout(**default_request.extensions["timeout"])  # type: ignore
    assert default_timeout == DEFAULT_TIMEOUT

    # Per-request override: the request carries exactly the given timeout.
    override_options = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    override_request = async_client._build_request(override_options)
    override_timeout = httpx.Timeout(**override_request.extensions["timeout"])  # type: ignore
    assert override_timeout == httpx.Timeout(100.0)
1349
async def test_client_timeout_option(self) -> None:
    """Check that a ``timeout`` passed to the client constructor is applied to built requests."""
    zero_timeout_client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
    )

    options = FinalRequestOptions(method="get", url="/foo")
    built = zero_timeout_client._build_request(options)
    applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert applied == httpx.Timeout(0)

    await zero_timeout_client.close()
1360
async def test_http_client_timeout_option(self) -> None:
    """Check which timeout wins when a custom ``httpx.AsyncClient`` is supplied.

    Three cases are covered: an explicit ``timeout=None`` on the httpx client, no
    timeout argument at all, and the httpx default passed explicitly.
    """
    # Case 1: a custom timeout given to the httpx client should be used.
    async with httpx.AsyncClient(timeout=None) as untimed_http:
        sdk_client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=untimed_http
        )

        built = sdk_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
        assert applied == httpx.Timeout(None)

        await sdk_client.close()

    # Case 2: no timeout given to the httpx client should not use the httpx default.
    async with httpx.AsyncClient() as plain_http:
        sdk_client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=plain_http
        )

        built = sdk_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
        assert applied == DEFAULT_TIMEOUT

        await sdk_client.close()

    # Case 3: explicitly passing the httpx default timeout currently results in it being ignored.
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as default_http:
        sdk_client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=default_http
        )

        built = sdk_client._build_request(FinalRequestOptions(method="get", url="/foo"))
        applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
        assert applied == DEFAULT_TIMEOUT  # our default

        await sdk_client.close()
1397
def test_invalid_http_client(self) -> None:
    """Check that handing a synchronous ``httpx.Client`` to ``AsyncOpenAI`` raises ``TypeError``."""
    # The cast hides the deliberate type mismatch from static checkers.
    with pytest.raises(TypeError, match="Invalid `http_client` arg"), httpx.Client() as sync_http:
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            http_client=cast(Any, sync_http),
        )
1407
async def test_default_headers_option(self) -> None:
    """Check that ``default_headers`` are sent and may override the built-in ``X-Stainless-Lang``."""
    options = FinalRequestOptions(method="get", url="/foo")

    # A custom header is added next to the built-in ones.
    extra_header_client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    sent = extra_header_client._build_request(options).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    # A built-in header is replaced when the caller provides the same name.
    overriding_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    sent = overriding_client._build_request(
        FinalRequestOptions(method="get", url="/foo")
    ).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"

    await extra_header_client.close()
    await overriding_client.close()
1431
async def test_validate_headers(self) -> None:
    """Check the ``Authorization`` header and the error raised when no API key is available.

    A client built with an API key must send it as a bearer token. Building a
    client with ``api_key=None`` while ``OPENAI_API_KEY`` is removed from the
    environment must raise ``OpenAIError``.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
    request = client._build_request(options)
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    with pytest.raises(OpenAIError):
        with update_env(**{"OPENAI_API_KEY": Omit()}):
            client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
        _ = client2

    # Release the client explicitly, as the other async tests in this class do.
    await client.close()
1442
async def test_default_query_option(self) -> None:
    """Check that ``default_query`` is sent and that per-request params override it."""
    query_client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    # Without per-request params only the default is present.
    plain_request = query_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    plain_params = dict(httpx.URL(plain_request.url).params)
    assert plain_params == {"query_param": "bar"}

    # Per-request params are added and win over a default of the same name.
    override_options = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overridden"},
    )
    override_request = query_client._build_request(override_options)
    override_params = dict(httpx.URL(override_request.url).params)
    assert override_params == {"foo": "baz", "query_param": "overridden"}

    await query_client.close()
1462
def test_request_extra_json(self, client: OpenAI) -> None:
    """Check how ``extra_json`` is combined with ``json_data`` in the request body."""
    # Both given: the two mappings are merged.
    merged_request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
            extra_json={"baz": False},
        ),
    )
    merged_body = json.loads(merged_request.content.decode("utf-8"))
    assert merged_body == {"foo": "bar", "baz": False}

    # Only `extra_json` given: it becomes the whole body.
    extra_only_request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            extra_json={"baz": False},
        ),
    )
    extra_only_body = json.loads(extra_only_request.content.decode("utf-8"))
    assert extra_only_body == {"baz": False}

    # Clashing keys: `extra_json` takes priority over `json_data`.
    clashing_request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar", "baz": True},
            extra_json={"baz": None},
        ),
    )
    clashing_body = json.loads(clashing_request.content.decode("utf-8"))
    assert clashing_body == {"foo": "bar", "baz": None}
1496
def test_request_extra_headers(self, client: OpenAI) -> None:
    """Check that ``extra_headers`` are sent and win over ``default_headers``."""
    simple_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Foo": "Foo"}),
    )
    simple_request = client._build_request(simple_options)
    assert simple_request.headers.get("X-Foo") == "Foo"

    # Clashing keys: `extra_headers` takes priority over `default_headers`.
    client_with_default = client.with_options(default_headers={"X-Bar": "true"})
    clashing_request = client_with_default._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                extra_headers={"X-Bar": "false"},
            ),
        ),
    )
    assert clashing_request.headers.get("X-Bar") == "false"
1518
def test_request_extra_query(self, client: OpenAI) -> None:
    """Check how ``extra_query`` is combined with ``query`` in the request URL."""
    # Only `extra_query` given: it is sent as-is.
    extra_only_request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                extra_query={"my_query_param": "Foo"},
            ),
        ),
    )
    assert dict(extra_only_request.url.params) == {"my_query_param": "Foo"}

    # Both given with distinct keys: they are merged.
    merged_request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                query={"bar": "1"},
                extra_query={"foo": "2"},
            ),
        ),
    )
    assert dict(merged_request.url.params) == {"bar": "1", "foo": "2"}

    # Clashing keys: `extra_query` takes priority over `query`.
    clashing_request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(
                query={"foo": "1"},
                extra_query={"foo": "2"},
            ),
        ),
    )
    assert dict(clashing_request.url.params) == {"foo": "2"}
1559
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """Check the multipart encoding of a list value sent together with a file.

    Each list element must become its own ``name="array[]"`` part, followed by the
    file part, all separated by the boundary given in the ``Content-Type`` header.
    """
    multipart_options = FinalRequestOptions.construct(
        method="post",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    request = async_client._build_request(multipart_options)

    expected_lines = [
        # first array element
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"foo",
        # second array element
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="array[]"',
        b"",
        b"bar",
        # file part
        b"--6b7ba517decee4a450543ea6ae821c82",
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        # closing boundary and trailing CRLF
        b"--6b7ba517decee4a450543ea6ae821c82--",
        b"",
    ]

    body_lines = request.read().split(b"\r\n")
    assert body_lines == expected_lines
1588
@pytest.mark.respx(base_url=base_url)
async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Check that raw bytes passed as ``content`` are sent unchanged with the given content type."""
    # The mocked endpoint answers using the `mirror_request_content` handler.
    respx_mock.post("/upload").mock(side_effect=mirror_request_content)

    payload = b"Hello, this is a test file."
    octet_stream_headers = {"Content-Type": "application/octet-stream"}

    echoed = await async_client.post(
        "/upload",
        content=payload,
        cast_to=httpx.Response,
        options={"headers": octet_stream_headers},
    )

    assert echoed.status_code == 200
    assert echoed.request.headers["Content-Type"] == "application/octet-stream"
    assert echoed.content == payload
1605
async def test_binary_content_upload_with_asynciterator(self) -> None:
    """Check that an async-iterator body is not consumed before the transport handles the request.

    The handler asserts that the counter is still zero when it is invoked; after
    the request the counter must read one.
    """
    payload = b"Hello, this is a test file."
    read_counter = Counter()
    body_chunks = _make_async_iterator([payload], counter=read_counter)

    async def echo_handler(request: httpx.Request) -> httpx.Response:
        assert read_counter.value == 0, "the request body should not have been read"
        return httpx.Response(200, content=await request.aread())

    mock_http_client = httpx.AsyncClient(transport=MockTransport(handler=echo_handler))

    async with AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        http_client=mock_http_client,
    ) as client:
        echoed = await client.post(
            "/upload",
            content=body_chunks,
            cast_to=httpx.Response,
            options={"headers": {"Content-Type": "application/octet-stream"}},
        )

        assert echoed.status_code == 200
        assert echoed.request.headers["Content-Type"] == "application/octet-stream"
        assert echoed.content == payload
        assert read_counter.value == 1
1632
@pytest.mark.respx(base_url=base_url)
async def test_binary_content_upload_with_body_is_deprecated(
    self, respx_mock: MockRouter, async_client: AsyncOpenAI
) -> None:
    """Check that raw bytes passed as ``body`` still upload but emit a deprecation warning."""
    respx_mock.post("/upload").mock(side_effect=mirror_request_content)

    payload = b"Hello, this is a test file."

    # The upload must happen inside the block so that the warning is captured.
    with pytest.deprecated_call(
        match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
    ):
        echoed = await async_client.post(
            "/upload",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {"Content-Type": "application/octet-stream"}},
        )

    assert echoed.status_code == 200
    assert echoed.request.headers["Content-Type"] == "application/octet-stream"
    assert echoed.content == payload
1654
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Check that a union ``cast_to`` resolves to the member whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    # The payload only has `foo`, so the second union member is expected.
    union_type = cast(Any, Union[Model1, Model2])
    parsed = await async_client.get("/foo", cast_to=union_type)

    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1668
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """Union of objects with the same field name using a different type"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])

    # A string value selects the `str` variant ...
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    as_text = await async_client.get("/foo", cast_to=union_type)
    assert isinstance(as_text, Model2)
    assert as_text.foo == "bar"

    # ... and an integer value selects the `int` variant.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

    as_number = await async_client.get("/foo", cast_to=union_type)
    assert isinstance(as_number, Model1)
    assert as_number.foo == 1
1690
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(
    self, respx_mock: MockRouter, async_client: AsyncOpenAI
) -> None:
    """
    Response that sets Content-Type to something other than application/json but returns json data
    """

    class Model(BaseModel):
        foo: int

    mislabelled_reply = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled_reply)

    parsed = await async_client.get("/foo", cast_to=Model)

    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1713
async def test_base_url_setter(self) -> None:
    """Check that ``base_url`` gains a trailing slash both from the constructor and from the setter."""
    url_client = AsyncOpenAI(
        base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
    )
    assert url_client.base_url == "https://example.com/from_init/"

    # Assigning a plain string goes through the property setter.
    url_client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert url_client.base_url == "https://example.com/from_setter/"

    await url_client.close()
1725
async def test_base_url_env(self) -> None:
    """Check that ``OPENAI_BASE_URL`` is used, with a trailing slash, when no ``base_url`` is passed."""
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "http://localhost:5000/from/env/"

    # Release the client explicitly, as the other async tests in this class do.
    await client.close()
1730
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
async def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """Check that a base URL ending in ``/`` joins with a relative path without doubling the slash."""
    options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)

    assert built.url == "http://localhost:5000/custom/path/foo"
    await client.close()
1756
@pytest.mark.parametrize(
    "client",
    [
        # The base URLs deliberately lack a trailing slash: that is the case under
        # test. With a trailing slash this test would merely repeat
        # `test_base_url_trailing_slash`.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
async def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """Check that a base URL without a trailing ``/`` still joins correctly with a relative path."""
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    # The client adds the missing slash to the base URL, so the path is appended
    # below `/custom/path/` rather than replacing its last segment.
    assert request.url == "http://localhost:5000/custom/path/foo"
    await client.close()
1782
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
async def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """Check that an absolute request URL is used as-is instead of being joined to the base URL."""
    options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)

    assert built.url == "https://myapi.com/foo"
    await client.close()
1808
async def test_copied_client_does_not_close_http(self) -> None:
    """Check that discarding a copy does not close the client it was copied from."""
    test_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not test_client.is_closed()

    copied = test_client.copy()
    assert copied is not test_client

    del copied

    # Yield to the event loop so any cleanup scheduled for the dropped copy can run
    # before the original is inspected.
    await asyncio.sleep(0.2)
    assert not test_client.is_closed()

    # Release the client explicitly, as the other async tests in this class do.
    await test_client.close()
1820
async def test_client_context_manager(self) -> None:
    """Check that ``async with`` yields the client itself and closes it on exit."""
    managed = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    async with managed as entered:
        # Inside the block the very same object is returned and it is still open.
        assert entered is managed
        assert not entered.is_closed()
        assert not managed.is_closed()

    assert managed.is_closed()
1828
@pytest.mark.respx(base_url=base_url)
async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """A body that does not fit the target model raises APIResponseValidationError.

    The raised error must be chained from the underlying pydantic ValidationError.
    """

    class Model(BaseModel):
        foo: str

    # `foo` is declared as `str`, so a nested object cannot validate against the model.
    mismatched_reply = httpx.Response(200, json={"foo": {"invalid": True}})
    respx_mock.get("/foo").mock(return_value=mismatched_reply)

    with pytest.raises(APIResponseValidationError) as raised:
        await async_client.get("/foo", cast_to=Model)

    underlying_error = raised.value.__cause__
    assert isinstance(underlying_error, ValidationError)
1840
async def test_client_max_retries_validation(self) -> None:
    """Passing `max_retries=None` to the constructor is rejected with a TypeError."""
    # Cast through Any so the deliberately invalid argument gets past static type checking.
    invalid_max_retries = cast(Any, None)
    with pytest.raises(TypeError, match=r"max_retries cannot be None"):
        AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            max_retries=invalid_max_retries,
        )
1846
@pytest.mark.respx(base_url=base_url)
async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """A streaming POST with an explicit `stream_cls` hands back an AsyncStream instance."""

    class Model(BaseModel):
        name: str

    mocked_reply = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=mocked_reply)

    result = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(result, AsyncStream)
    # The stream is never iterated in this test, so release its HTTP response explicitly.
    await result.response.aclose()
1857
@pytest.mark.respx(base_url=base_url)
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A plain-text body where a model is expected fails in strict mode and is returned as `str` otherwise.

    Each client is opened with `async with` so it is closed even when an assertion
    fails; previously both clients were only closed on the success path.
    """

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # Strict validation: a body that cannot be parsed into `Model` must raise.
    async with AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True) as strict_client:
        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

    # Non-strict validation: the raw text is handed back instead of raising.
    async with AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=False
    ) as non_strict_client:
        response = await non_strict_client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]
1877
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
        [-1100, "", 8],  # test large number potentially overflowing
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
async def test_parse_retry_after_header(
    self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncOpenAI
) -> None:
    """Check the retry delay computed from a `retry-after` header against the expected timeout.

    `time.time` is patched to a fixed value so the HTTP-date cases in the table
    are evaluated against a constant clock.
    """
    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
    response_headers = httpx.Headers({"retry-after": retry_after})

    delay = async_client._calculate_retry_timeout(remaining_retries, request_options, response_headers)

    # Relative tolerance for the comparison.
    # NOTE(review): presumably sized to absorb jitter applied to the computed delay — confirm.
    relative_tolerance = 0.5 * 0.875
    assert delay == pytest.approx(timeout, relative_tolerance)  # pyright: ignore[reportUnknownMemberType]
1907
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """A request that keeps timing out raises APITimeoutError and leaves no connection open."""
    simulated_timeout = httpx.TimeoutException("Test timeout error")
    respx_mock.post("/chat/completions").mock(side_effect=simulated_timeout)

    with pytest.raises(APITimeoutError):
        # Enter the streaming context by hand: the failure is expected while the request is being sent.
        await async_client.chat.completions.with_streaming_response.create(
            messages=[{"content": "string", "role": "developer"}],
            model="gpt-4o",
        ).__aenter__()

    open_connections = _get_open_connections(async_client)
    assert open_connections == 0
1925
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """A request that keeps receiving HTTP 500 raises APIStatusError and leaves no connection open."""
    server_error = httpx.Response(500)
    respx_mock.post("/chat/completions").mock(return_value=server_error)

    with pytest.raises(APIStatusError):
        # Enter the streaming context by hand: the failure is expected while the request is being sent.
        await async_client.chat.completions.with_streaming_response.create(
            messages=[{"content": "string", "role": "developer"}],
            model="gpt-4o",
        ).__aenter__()

    open_connections = _get_open_connections(async_client)
    assert open_connections == 0
1942
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
@pytest.mark.parametrize("failure_mode", ["status", "exception"])
async def test_retries_taken(
    self,
    async_client: AsyncOpenAI,
    failures_before_success: int,
    failure_mode: Literal["status", "exception"],
    respx_mock: MockRouter,
) -> None:
    """`retries_taken` and the `x-stainless-retry-count` header both equal the number of failed attempts.

    Failures are simulated either as HTTP 500 responses or as exceptions raised
    by the mocked transport, depending on `failure_mode`.
    """
    client = async_client.with_options(max_retries=4)

    failures_served = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        # Once the configured number of failures has been served, every call succeeds.
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        if failure_mode == "exception":
            raise RuntimeError("oops")
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    response = await client.chat.completions.with_raw_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-4o",
    )

    assert response.retries_taken == failures_before_success
    retry_count_header = response.http_request.headers.get("x-stainless-retry-count")
    assert int(retry_count_header) == failures_before_success
1981
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_omit_retry_count_header(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """Passing `Omit()` for `x-stainless-retry-count` removes that header from the sent request."""
    client = async_client.with_options(max_retries=4)

    failures_served = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        # Once the configured number of failures has been served, every call succeeds.
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    response = await client.chat.completions.with_raw_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": Omit()},
    )

    retry_count_values = response.http_request.headers.get_list("x-stainless-retry-count")
    assert len(retry_count_values) == 0
2013
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_overwrite_retry_count_header(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """A caller-supplied `x-stainless-retry-count` value is sent as-is, whatever the number of retries."""
    client = async_client.with_options(max_retries=4)

    failures_served = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        # Once the configured number of failures has been served, every call succeeds.
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    response = await client.chat.completions.with_raw_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-4o",
        extra_headers={"x-stainless-retry-count": "42"},
    )

    sent_retry_count = response.http_request.headers.get("x-stainless-retry-count")
    assert sent_retry_count == "42"
2045
@pytest.mark.parametrize("failures_before_success", [0, 2, 4])
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retries_taken_new_response_class(
    self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
) -> None:
    """The streaming-response wrapper reports the same retry count as the request's retry-count header."""
    client = async_client.with_options(max_retries=4)

    failures_served = 0

    def retry_handler(_request: httpx.Request) -> httpx.Response:
        nonlocal failures_served
        # Once the configured number of failures has been served, every call succeeds.
        if failures_served >= failures_before_success:
            return httpx.Response(200)
        failures_served += 1
        return httpx.Response(500)

    respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

    async with client.chat.completions.with_streaming_response.create(
        messages=[{"content": "string", "role": "developer"}],
        model="gpt-4o",
    ) as response:
        assert response.retries_taken == failures_before_success
        retry_count_header = response.http_request.headers.get("x-stainless-retry-count")
        assert int(retry_count_header) == failures_before_success
2076
async def test_get_platform(self) -> None:
    """`get_platform`, wrapped with `asyncify` and awaited, yields a `str` or an `OtherPlatform`."""
    get_platform_async = asyncify(get_platform)
    detected = await get_platform_async()
    assert isinstance(detected, (str, OtherPlatform))
2080
async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """With only HTTPS_PROXY set, the default async HTTP client has exactly one mount, for `https://`."""
    # Clear HTTP_PROXY in case the surrounding environment defines it.
    monkeypatch.delenv("HTTP_PROXY", raising=False)
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

    http_client = DefaultAsyncHttpxClient()

    mounted = tuple(http_client._mounts.items())
    assert len(mounted) == 1
    url_pattern, _transport = mounted[0]
    assert url_pattern.pattern == "https://"
2092
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
async def test_default_client_creation(self) -> None:
    """Constructing `DefaultAsyncHttpxClient` with a set of explicit httpx options must not raise."""
    pool_limits = httpx.Limits(max_connections=100, max_keepalive_connections=20)
    DefaultAsyncHttpxClient(
        verify=True,
        cert=None,
        trust_env=True,
        http1=True,
        http2=False,
        limits=pool_limits,
    )
2104
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """By default a 302 response is followed and the redirect target's response is returned."""
    redirect_target = f"{base_url}/redirected"
    redirect_reply = httpx.Response(302, headers={"Location": redirect_target})
    respx_mock.post("/redirect").mock(return_value=redirect_reply)
    respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

    final_response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    assert final_response.status_code == 200
    assert final_response.json() == {"status": "ok"}
2116
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
    """With `follow_redirects=False` the 302 surfaces as an APIStatusError carrying the redirect response."""
    redirect_target = f"{base_url}/redirected"
    redirect_reply = httpx.Response(302, headers={"Location": redirect_target})
    respx_mock.post("/redirect").mock(return_value=redirect_reply)

    with pytest.raises(APIStatusError) as exc_info:
        await async_client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    unfollowed = exc_info.value.response
    assert unfollowed.status_code == 302
    assert unfollowed.headers["Location"] == redirect_target
2131
@pytest.mark.asyncio
async def test_api_key_before_after_refresh_provider(self) -> None:
    """With a callable key provider, the key and auth header stay empty until `_refresh_api_key` is awaited."""

    async def mock_api_key_provider():
        return "test_bearer_token"

    client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)

    # Before the refresh: no key is stored and no Authorization header is produced.
    assert client.api_key == ""
    headers_before = client.auth_headers
    assert "Authorization" not in headers_before

    await client._refresh_api_key()

    # After the refresh: the provider's value is stored and used as a bearer token.
    assert client.api_key == "test_bearer_token"
    headers_after = client.auth_headers
    assert headers_after.get("Authorization") == "Bearer test_bearer_token"
2146
@pytest.mark.asyncio
async def test_api_key_before_after_refresh_str(self) -> None:
    """With a plain string key, the Authorization header is the same before and after a refresh."""
    client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")

    header_before = client.auth_headers.get("Authorization")
    assert header_before == "Bearer test_api_key"

    await client._refresh_api_key()

    header_after = client.auth_headers.get("Authorization")
    assert header_after == "Bearer test_api_key"
2155
@pytest.mark.asyncio
@pytest.mark.respx()
async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
    """A retried request carries a freshly fetched token rather than reusing the first one.

    The mocked endpoint fails once with HTTP 500 and then succeeds, so exactly
    two requests are expected, each with a different bearer token.
    """
    scripted_replies = [
        httpx.Response(500, json={"error": "server error"}),
        httpx.Response(200, json={"foo": "bar"}),
    ]
    respx_mock.post(f"{base_url}/chat/completions").mock(side_effect=scripted_replies)

    tokens_issued = 0

    async def token_provider() -> str:
        nonlocal tokens_issued
        tokens_issued += 1
        # Only the very first call yields "first"; every later call yields "second".
        return "first" if tokens_issued == 1 else "second"

    client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
    await client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(calls) == 2

    first_authorization = calls[0].request.headers.get("Authorization")
    second_authorization = calls[1].request.headers.get("Authorization")
    assert first_authorization == "Bearer first"
    assert second_authorization == "Bearer second"
2186
@pytest.mark.asyncio
async def test_copy_auth(self) -> None:
    """A key provider passed to `copy()` replaces the one the original client was built with."""

    async def token_provider_1() -> str:
        return "test_bearer_token_1"

    async def token_provider_2() -> str:
        return "test_bearer_token_2"

    client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
    await client._refresh_api_key()

    expected_headers = {"Authorization": "Bearer test_bearer_token_2"}
    assert client.auth_headers == expected_headers
2198