openai/openai-python

Public

Mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v2.22.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

2193 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import asyncio
10import inspect
11import dataclasses
12import tracemalloc
13from typing import Any, Union, TypeVar, Callable, Iterable, Iterator, Optional, Protocol, Coroutine, cast
14from unittest import mock
15from typing_extensions import Literal, AsyncIterator, override
16
17import httpx
18import pytest
19from respx import MockRouter
20from pydantic import ValidationError
21
22from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
23from openai._types import Omit
24from openai._utils import asyncify
25from openai._models import BaseModel, FinalRequestOptions
26from openai._streaming import Stream, AsyncStream
27from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
28from openai._base_client import (
29 DEFAULT_TIMEOUT,
30 HTTPX_DEFAULT_TIMEOUT,
31 BaseClient,
32 OtherPlatform,
33 DefaultHttpxClient,
34 DefaultAsyncHttpxClient,
35 get_platform,
36 make_request_options,
37)
38
39from .utils import update_env
40
# Generic element type used by the iterator helpers.
T = TypeVar("T")
# Address of the mock API server; can be overridden via TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential passed to every client built in this module.
api_key = "My API Key"
44
45
class MockRequestCall(Protocol):
    """Structural type for any object exposing the `httpx.Request` it was made with.

    Presumably describes a recorded mock call; confirm against its users.
    """

    # The request object associated with the call.
    request: httpx.Request
48
49
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare `GET /foo` request with `client` and return its query string as a dict."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
54
55
56def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
57 return 0.1
58
59
def mirror_request_content(request: httpx.Request) -> httpx.Response:
    """Echo handler: answer 200 with the request's body as the response body."""
    body = request.content
    return httpx.Response(200, content=body)
62
63
# note: we can't use the httpx.MockTransport class as it consumes the request
# body itself, which means we can't test that the body is read lazily
class MockTransport(httpx.BaseTransport, httpx.AsyncBaseTransport):
    """Transport that hands each request straight to `handler` without reading the body first.

    Sync clients need a plain function handler; async clients need a coroutine function.
    """

    def __init__(
        self,
        handler: Callable[[httpx.Request], httpx.Response]
        | Callable[[httpx.Request], Coroutine[Any, Any, httpx.Response]],
    ) -> None:
        self.handler = handler

    @override
    def handle_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        handler = self.handler
        # The sync path only accepts a regular (non-coroutine) Python function.
        assert not inspect.iscoroutinefunction(handler), "handler must not be a coroutine function"
        assert inspect.isfunction(handler), "handler must be a function"
        return handler(request)

    @override
    async def handle_async_request(
        self,
        request: httpx.Request,
    ) -> httpx.Response:
        handler = self.handler
        # The async path requires a coroutine function so the result can be awaited.
        assert inspect.iscoroutinefunction(handler), "handler must be a coroutine function"
        response = await handler(request)
        return response
90
91
@dataclasses.dataclass
class Counter:
    """Mutable integer box the iterator helpers increment once per yielded item."""

    # Number of items pulled so far; starts at zero.
    value: int = 0
95
96
97def _make_sync_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> Iterator[T]:
98 for item in iterable:
99 if counter:
100 counter.value += 1
101 yield item
102
103
104async def _make_async_iterator(iterable: Iterable[T], counter: Optional[Counter] = None) -> AsyncIterator[T]:
105 for item in iterable:
106 if counter:
107 counter.value += 1
108 yield item
109
110
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of entries in the transport pool's `_requests` list.

    Reaches into private httpx/httpcore internals (`_transport`, `_pool`, `_requests`).
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
117
118
119class TestOpenAI:
120 @pytest.mark.respx(base_url=base_url)
121 def test_raw_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
122 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
123
124 response = client.post("/foo", cast_to=httpx.Response)
125 assert response.status_code == 200
126 assert isinstance(response, httpx.Response)
127 assert response.json() == {"foo": "bar"}
128
129 @pytest.mark.respx(base_url=base_url)
130 def test_raw_response_for_binary(self, respx_mock: MockRouter, client: OpenAI) -> None:
131 respx_mock.post("/foo").mock(
132 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
133 )
134
135 response = client.post("/foo", cast_to=httpx.Response)
136 assert response.status_code == 200
137 assert isinstance(response, httpx.Response)
138 assert response.json() == {"foo": "bar"}
139
140 def test_copy(self, client: OpenAI) -> None:
141 copied = client.copy()
142 assert id(copied) != id(client)
143
144 copied = client.copy(api_key="another My API Key")
145 assert copied.api_key == "another My API Key"
146 assert client.api_key == "My API Key"
147
148 def test_copy_default_options(self, client: OpenAI) -> None:
149 # options that have a default are overridden correctly
150 copied = client.copy(max_retries=7)
151 assert copied.max_retries == 7
152 assert client.max_retries == 2
153
154 copied2 = copied.copy(max_retries=6)
155 assert copied2.max_retries == 6
156 assert copied.max_retries == 7
157
158 # timeout
159 assert isinstance(client.timeout, httpx.Timeout)
160 copied = client.copy(timeout=None)
161 assert copied.timeout is None
162 assert isinstance(client.timeout, httpx.Timeout)
163
    def test_copy_default_headers(self) -> None:
        """Check how `copy()` merges (`default_headers`) or replaces (`set_default_headers`) the default headers."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        # Supplying both the merging and the replacing argument is rejected.
        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
        client.close()
198
    def test_copy_default_query(self) -> None:
        """Check how `copy()` merges (`default_query`) or replaces (`set_default_query`) the default query params."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given query params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        # Supplying both the merging and the replacing argument is rejected.
        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

        client.close()
236
237 def test_copy_signature(self, client: OpenAI) -> None:
238 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
239 init_signature = inspect.signature(
240 # mypy doesn't like that we access the `__init__` property.
241 client.__init__, # type: ignore[misc]
242 )
243 copy_signature = inspect.signature(client.copy)
244 exclude_params = {"transport", "proxies", "_strict_response_validation"}
245
246 for name in init_signature.parameters.keys():
247 if name in exclude_params:
248 continue
249
250 copy_param = copy_signature.parameters.get(name)
251 assert copy_param is not None, f"copy() signature is missing the {name} param"
252
    # NOTE(review): the condition skips on Python >= 3.10 while the reason mentions 3.12;
    # confirm which version range is actually affected before narrowing the skip.
    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self, client: OpenAI) -> None:
        """Repeatedly copying the client and building a request must not leave allocations behind.

        Uses `tracemalloc` snapshots taken before and after ITERATIONS copy+build cycles and
        fails if any traceback shows allocations that persist in multiples of ITERATIONS,
        except in a small allow-list of known-leaky files.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client_copy = client.copy()
            client_copy._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        # Keep up to 1000 frames per allocation traceback.
        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            # NOTE(review): `StatisticDiff.count` is the block count in the *new* snapshot, not the
            # difference (`count_diff`); confirm that the checks below are meant to use the absolute count.
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            # Print every surviving diff with its full traceback before failing.
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()
315
316 def test_request_timeout(self, client: OpenAI) -> None:
317 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
318 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
319 assert timeout == DEFAULT_TIMEOUT
320
321 request = client._build_request(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
322 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
323 assert timeout == httpx.Timeout(100.0)
324
325 def test_client_timeout_option(self) -> None:
326 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
327
328 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
329 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
330 assert timeout == httpx.Timeout(0)
331
332 client.close()
333
    def test_http_client_timeout_option(self) -> None:
        """Check which timeout wins when the user supplies their own `httpx.Client`.

        An explicit non-default timeout on the httpx client is honored. No timeout, or
        httpx's own default timeout, falls back to the SDK's DEFAULT_TIMEOUT.
        """
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

            client.close()

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

            client.close()

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

            client.close()
370
371 async def test_invalid_http_client(self) -> None:
372 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
373 async with httpx.AsyncClient() as http_client:
374 OpenAI(
375 base_url=base_url,
376 api_key=api_key,
377 _strict_response_validation=True,
378 http_client=cast(Any, http_client),
379 )
380
381 def test_default_headers_option(self) -> None:
382 test_client = OpenAI(
383 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
384 )
385 request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
386 assert request.headers.get("x-foo") == "bar"
387 assert request.headers.get("x-stainless-lang") == "python"
388
389 test_client2 = OpenAI(
390 base_url=base_url,
391 api_key=api_key,
392 _strict_response_validation=True,
393 default_headers={
394 "X-Foo": "stainless",
395 "X-Stainless-Lang": "my-overriding-header",
396 },
397 )
398 request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
399 assert request.headers.get("x-foo") == "stainless"
400 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
401
402 test_client.close()
403 test_client2.close()
404
405 def test_validate_headers(self) -> None:
406 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
407 options = client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
408 request = client._build_request(options)
409
410 assert request.headers.get("Authorization") == f"Bearer {api_key}"
411
412 with pytest.raises(OpenAIError):
413 with update_env(**{"OPENAI_API_KEY": Omit()}):
414 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
415 _ = client2
416
417 def test_default_query_option(self) -> None:
418 client = OpenAI(
419 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
420 )
421 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
422 url = httpx.URL(request.url)
423 assert dict(url.params) == {"query_param": "bar"}
424
425 request = client._build_request(
426 FinalRequestOptions(
427 method="get",
428 url="/foo",
429 params={"foo": "baz", "query_param": "overridden"},
430 )
431 )
432 url = httpx.URL(request.url)
433 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
434
435 client.close()
436
437 def test_request_extra_json(self, client: OpenAI) -> None:
438 request = client._build_request(
439 FinalRequestOptions(
440 method="post",
441 url="/foo",
442 json_data={"foo": "bar"},
443 extra_json={"baz": False},
444 ),
445 )
446 data = json.loads(request.content.decode("utf-8"))
447 assert data == {"foo": "bar", "baz": False}
448
449 request = client._build_request(
450 FinalRequestOptions(
451 method="post",
452 url="/foo",
453 extra_json={"baz": False},
454 ),
455 )
456 data = json.loads(request.content.decode("utf-8"))
457 assert data == {"baz": False}
458
459 # `extra_json` takes priority over `json_data` when keys clash
460 request = client._build_request(
461 FinalRequestOptions(
462 method="post",
463 url="/foo",
464 json_data={"foo": "bar", "baz": True},
465 extra_json={"baz": None},
466 ),
467 )
468 data = json.loads(request.content.decode("utf-8"))
469 assert data == {"foo": "bar", "baz": None}
470
471 def test_request_extra_headers(self, client: OpenAI) -> None:
472 request = client._build_request(
473 FinalRequestOptions(
474 method="post",
475 url="/foo",
476 **make_request_options(extra_headers={"X-Foo": "Foo"}),
477 ),
478 )
479 assert request.headers.get("X-Foo") == "Foo"
480
481 # `extra_headers` takes priority over `default_headers` when keys clash
482 request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
483 FinalRequestOptions(
484 method="post",
485 url="/foo",
486 **make_request_options(
487 extra_headers={"X-Bar": "false"},
488 ),
489 ),
490 )
491 assert request.headers.get("X-Bar") == "false"
492
493 def test_request_extra_query(self, client: OpenAI) -> None:
494 request = client._build_request(
495 FinalRequestOptions(
496 method="post",
497 url="/foo",
498 **make_request_options(
499 extra_query={"my_query_param": "Foo"},
500 ),
501 ),
502 )
503 params = dict(request.url.params)
504 assert params == {"my_query_param": "Foo"}
505
506 # if both `query` and `extra_query` are given, they are merged
507 request = client._build_request(
508 FinalRequestOptions(
509 method="post",
510 url="/foo",
511 **make_request_options(
512 query={"bar": "1"},
513 extra_query={"foo": "2"},
514 ),
515 ),
516 )
517 params = dict(request.url.params)
518 assert params == {"bar": "1", "foo": "2"}
519
520 # `extra_query` takes priority over `query` when keys clash
521 request = client._build_request(
522 FinalRequestOptions(
523 method="post",
524 url="/foo",
525 **make_request_options(
526 query={"foo": "1"},
527 extra_query={"foo": "2"},
528 ),
529 ),
530 )
531 params = dict(request.url.params)
532 assert params == {"foo": "2"}
533
534 def test_multipart_repeating_array(self, client: OpenAI) -> None:
535 request = client._build_request(
536 FinalRequestOptions.construct(
537 method="post",
538 url="/foo",
539 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
540 json_data={"array": ["foo", "bar"]},
541 files=[("foo.txt", b"hello world")],
542 )
543 )
544
545 assert request.read().split(b"\r\n") == [
546 b"--6b7ba517decee4a450543ea6ae821c82",
547 b'Content-Disposition: form-data; name="array[]"',
548 b"",
549 b"foo",
550 b"--6b7ba517decee4a450543ea6ae821c82",
551 b'Content-Disposition: form-data; name="array[]"',
552 b"",
553 b"bar",
554 b"--6b7ba517decee4a450543ea6ae821c82",
555 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
556 b"Content-Type: application/octet-stream",
557 b"",
558 b"hello world",
559 b"--6b7ba517decee4a450543ea6ae821c82--",
560 b"",
561 ]
562
563 @pytest.mark.respx(base_url=base_url)
564 def test_binary_content_upload(self, respx_mock: MockRouter, client: OpenAI) -> None:
565 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
566
567 file_content = b"Hello, this is a test file."
568
569 response = client.post(
570 "/upload",
571 content=file_content,
572 cast_to=httpx.Response,
573 options={"headers": {"Content-Type": "application/octet-stream"}},
574 )
575
576 assert response.status_code == 200
577 assert response.request.headers["Content-Type"] == "application/octet-stream"
578 assert response.content == file_content
579
580 def test_binary_content_upload_with_iterator(self) -> None:
581 file_content = b"Hello, this is a test file."
582 counter = Counter()
583 iterator = _make_sync_iterator([file_content], counter=counter)
584
585 def mock_handler(request: httpx.Request) -> httpx.Response:
586 assert counter.value == 0, "the request body should not have been read"
587 return httpx.Response(200, content=request.read())
588
589 with OpenAI(
590 base_url=base_url,
591 api_key=api_key,
592 _strict_response_validation=True,
593 http_client=httpx.Client(transport=MockTransport(handler=mock_handler)),
594 ) as client:
595 response = client.post(
596 "/upload",
597 content=iterator,
598 cast_to=httpx.Response,
599 options={"headers": {"Content-Type": "application/octet-stream"}},
600 )
601
602 assert response.status_code == 200
603 assert response.request.headers["Content-Type"] == "application/octet-stream"
604 assert response.content == file_content
605 assert counter.value == 1
606
607 @pytest.mark.respx(base_url=base_url)
608 def test_binary_content_upload_with_body_is_deprecated(self, respx_mock: MockRouter, client: OpenAI) -> None:
609 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
610
611 file_content = b"Hello, this is a test file."
612
613 with pytest.deprecated_call(
614 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
615 ):
616 response = client.post(
617 "/upload",
618 body=file_content,
619 cast_to=httpx.Response,
620 options={"headers": {"Content-Type": "application/octet-stream"}},
621 )
622
623 assert response.status_code == 200
624 assert response.request.headers["Content-Type"] == "application/octet-stream"
625 assert response.content == file_content
626
627 @pytest.mark.respx(base_url=base_url)
628 def test_basic_union_response(self, respx_mock: MockRouter, client: OpenAI) -> None:
629 class Model1(BaseModel):
630 name: str
631
632 class Model2(BaseModel):
633 foo: str
634
635 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
636
637 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
638 assert isinstance(response, Model2)
639 assert response.foo == "bar"
640
641 @pytest.mark.respx(base_url=base_url)
642 def test_union_response_different_types(self, respx_mock: MockRouter, client: OpenAI) -> None:
643 """Union of objects with the same field name using a different type"""
644
645 class Model1(BaseModel):
646 foo: int
647
648 class Model2(BaseModel):
649 foo: str
650
651 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
652
653 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
654 assert isinstance(response, Model2)
655 assert response.foo == "bar"
656
657 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
658
659 response = client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
660 assert isinstance(response, Model1)
661 assert response.foo == 1
662
663 @pytest.mark.respx(base_url=base_url)
664 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter, client: OpenAI) -> None:
665 """
666 Response that sets Content-Type to something other than application/json but returns json data
667 """
668
669 class Model(BaseModel):
670 foo: int
671
672 respx_mock.get("/foo").mock(
673 return_value=httpx.Response(
674 200,
675 content=json.dumps({"foo": 2}),
676 headers={"Content-Type": "application/text"},
677 )
678 )
679
680 response = client.get("/foo", cast_to=Model)
681 assert isinstance(response, Model)
682 assert response.foo == 2
683
684 def test_base_url_setter(self) -> None:
685 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
686 assert client.base_url == "https://example.com/from_init/"
687
688 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
689
690 assert client.base_url == "https://example.com/from_setter/"
691
692 client.close()
693
694 def test_base_url_env(self) -> None:
695 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
696 client = OpenAI(api_key=api_key, _strict_response_validation=True)
697 assert client.base_url == "http://localhost:5000/from/env/"
698
699 @pytest.mark.parametrize(
700 "client",
701 [
702 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
703 OpenAI(
704 base_url="http://localhost:5000/custom/path/",
705 api_key=api_key,
706 _strict_response_validation=True,
707 http_client=httpx.Client(),
708 ),
709 ],
710 ids=["standard", "custom http client"],
711 )
712 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
713 request = client._build_request(
714 FinalRequestOptions(
715 method="post",
716 url="/foo",
717 json_data={"foo": "bar"},
718 ),
719 )
720 assert request.url == "http://localhost:5000/custom/path/foo"
721 client.close()
722
723 @pytest.mark.parametrize(
724 "client",
725 [
726 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
727 OpenAI(
728 base_url="http://localhost:5000/custom/path/",
729 api_key=api_key,
730 _strict_response_validation=True,
731 http_client=httpx.Client(),
732 ),
733 ],
734 ids=["standard", "custom http client"],
735 )
736 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
737 request = client._build_request(
738 FinalRequestOptions(
739 method="post",
740 url="/foo",
741 json_data={"foo": "bar"},
742 ),
743 )
744 assert request.url == "http://localhost:5000/custom/path/foo"
745 client.close()
746
747 @pytest.mark.parametrize(
748 "client",
749 [
750 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
751 OpenAI(
752 base_url="http://localhost:5000/custom/path/",
753 api_key=api_key,
754 _strict_response_validation=True,
755 http_client=httpx.Client(),
756 ),
757 ],
758 ids=["standard", "custom http client"],
759 )
760 def test_absolute_request_url(self, client: OpenAI) -> None:
761 request = client._build_request(
762 FinalRequestOptions(
763 method="post",
764 url="https://myapi.com/foo",
765 json_data={"foo": "bar"},
766 ),
767 )
768 assert request.url == "https://myapi.com/foo"
769 client.close()
770
771 def test_copied_client_does_not_close_http(self) -> None:
772 test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
773 assert not test_client.is_closed()
774
775 copied = test_client.copy()
776 assert copied is not test_client
777
778 del copied
779
780 assert not test_client.is_closed()
781
782 def test_client_context_manager(self) -> None:
783 test_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
784 with test_client as c2:
785 assert c2 is test_client
786 assert not c2.is_closed()
787 assert not test_client.is_closed()
788 assert test_client.is_closed()
789
790 @pytest.mark.respx(base_url=base_url)
791 def test_client_response_validation_error(self, respx_mock: MockRouter, client: OpenAI) -> None:
792 class Model(BaseModel):
793 foo: str
794
795 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
796
797 with pytest.raises(APIResponseValidationError) as exc:
798 client.get("/foo", cast_to=Model)
799
800 assert isinstance(exc.value.__cause__, ValidationError)
801
802 def test_client_max_retries_validation(self) -> None:
803 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
804 OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
805
806 @pytest.mark.respx(base_url=base_url)
807 def test_default_stream_cls(self, respx_mock: MockRouter, client: OpenAI) -> None:
808 class Model(BaseModel):
809 name: str
810
811 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
812
813 stream = client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
814 assert isinstance(stream, Stream)
815 stream.response.close()
816
817 @pytest.mark.respx(base_url=base_url)
818 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
819 class Model(BaseModel):
820 name: str
821
822 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
823
824 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
825
826 with pytest.raises(APIResponseValidationError):
827 strict_client.get("/foo", cast_to=Model)
828
829 non_strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
830
831 response = non_strict_client.get("/foo", cast_to=Model)
832 assert isinstance(response, str) # type: ignore[unreachable]
833
834 strict_client.close()
835 non_strict_client.close()
836
    # time.time() is frozen at 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, so the HTTP-date rows
    # below lie 20s, 0s, -10s, 60s and 61s in the future respectively.
    # Per the table, delays in (0, 60] seconds are honored; anything else (zero, negative, > 60,
    # unparsable, empty) falls back to a computed backoff. The backoff rows suggest
    # 0.5 * 2 ** (max_retries - remaining_retries), presumably capped at 8 — confirm in `_calculate_retry_timeout`.
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(
        self, remaining_retries: int, retry_after: str, timeout: float, client: OpenAI
    ) -> None:
        """The retry delay computed from a `retry-after` header matches the expected value for each case."""
        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # Relative tolerance of 0.4375; presumably wide enough to absorb any jitter in the backoff — confirm.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
866
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        """When every attempt times out, the error surfaces and the pool is left with no tracked requests."""
        # Every attempt raises a timeout; the patched calculator keeps retry waits at 0.1s.
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            # `.__enter__()` is called directly; presumably entering the streaming-response
            # context manager is what sends the request.
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()

        assert _get_open_connections(client) == 0
884
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        """When every attempt returns 500, the status error surfaces and the pool is left with no tracked requests."""
        # Every attempt gets a 500; the patched calculator keeps retry waits at 0.1s.
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            # `.__enter__()` is called directly; presumably entering the streaming-response
            # context manager is what sends the request.
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()
        assert _get_open_connections(client) == 0
901
    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """`retries_taken` and the x-stainless-retry-count header both report the retries before success."""
        # Allow 4 retries so even the largest parametrized failure count can still succeed.
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` attempts (500 or a raised error), then return 200.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
940
941 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
942 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
943 @pytest.mark.respx(base_url=base_url)
944 def test_omit_retry_count_header(
945 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
946 ) -> None:
947 client = client.with_options(max_retries=4)
948
949 nb_retries = 0
950
951 def retry_handler(_request: httpx.Request) -> httpx.Response:
952 nonlocal nb_retries
953 if nb_retries < failures_before_success:
954 nb_retries += 1
955 return httpx.Response(500)
956 return httpx.Response(200)
957
958 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
959
960 response = client.chat.completions.with_raw_response.create(
961 messages=[
962 {
963 "content": "string",
964 "role": "developer",
965 }
966 ],
967 model="gpt-4o",
968 extra_headers={"x-stainless-retry-count": Omit()},
969 )
970
971 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
972
973 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
974 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
975 @pytest.mark.respx(base_url=base_url)
976 def test_overwrite_retry_count_header(
977 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
978 ) -> None:
979 client = client.with_options(max_retries=4)
980
981 nb_retries = 0
982
983 def retry_handler(_request: httpx.Request) -> httpx.Response:
984 nonlocal nb_retries
985 if nb_retries < failures_before_success:
986 nb_retries += 1
987 return httpx.Response(500)
988 return httpx.Response(200)
989
990 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
991
992 response = client.chat.completions.with_raw_response.create(
993 messages=[
994 {
995 "content": "string",
996 "role": "developer",
997 }
998 ],
999 model="gpt-4o",
1000 extra_headers={"x-stainless-retry-count": "42"},
1001 )
1002
1003 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1004
1005 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1006 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1007 @pytest.mark.respx(base_url=base_url)
1008 def test_retries_taken_new_response_class(
1009 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
1010 ) -> None:
1011 client = client.with_options(max_retries=4)
1012
1013 nb_retries = 0
1014
1015 def retry_handler(_request: httpx.Request) -> httpx.Response:
1016 nonlocal nb_retries
1017 if nb_retries < failures_before_success:
1018 nb_retries += 1
1019 return httpx.Response(500)
1020 return httpx.Response(200)
1021
1022 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1023
1024 with client.chat.completions.with_streaming_response.create(
1025 messages=[
1026 {
1027 "content": "string",
1028 "role": "developer",
1029 }
1030 ],
1031 model="gpt-4o",
1032 ) as response:
1033 assert response.retries_taken == failures_before_success
1034 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1035
1036 def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
1037 # Test that the proxy environment variables are set correctly
1038 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
1039
1040 client = DefaultHttpxClient()
1041
1042 mounts = tuple(client._mounts.items())
1043 assert len(mounts) == 1
1044 assert mounts[0][0].pattern == "https://"
1045
1046 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
1047 def test_default_client_creation(self) -> None:
1048 # Ensure that the client can be initialized without any exceptions
1049 DefaultHttpxClient(
1050 verify=True,
1051 cert=None,
1052 trust_env=True,
1053 http1=True,
1054 http2=False,
1055 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
1056 )
1057
1058 @pytest.mark.respx(base_url=base_url)
1059 def test_follow_redirects(self, respx_mock: MockRouter, client: OpenAI) -> None:
1060 # Test that the default follow_redirects=True allows following redirects
1061 respx_mock.post("/redirect").mock(
1062 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1063 )
1064 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
1065
1066 response = client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
1067 assert response.status_code == 200
1068 assert response.json() == {"status": "ok"}
1069
1070 @pytest.mark.respx(base_url=base_url)
1071 def test_follow_redirects_disabled(self, respx_mock: MockRouter, client: OpenAI) -> None:
1072 # Test that follow_redirects=False prevents following redirects
1073 respx_mock.post("/redirect").mock(
1074 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1075 )
1076
1077 with pytest.raises(APIStatusError) as exc_info:
1078 client.post("/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response)
1079
1080 assert exc_info.value.response.status_code == 302
1081 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
1082
1083 def test_api_key_before_after_refresh_provider(self) -> None:
1084 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token")
1085
1086 assert client.api_key == ""
1087 assert "Authorization" not in client.auth_headers
1088
1089 client._refresh_api_key()
1090
1091 assert client.api_key == "test_bearer_token"
1092 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
1093
1094 def test_api_key_before_after_refresh_str(self) -> None:
1095 client = OpenAI(base_url=base_url, api_key="test_api_key")
1096
1097 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
1098 client._refresh_api_key()
1099
1100 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
1101
1102 @pytest.mark.respx()
1103 def test_api_key_refresh_on_retry(self, respx_mock: MockRouter) -> None:
1104 respx_mock.post(base_url + "/chat/completions").mock(
1105 side_effect=[
1106 httpx.Response(500, json={"error": "server error"}),
1107 httpx.Response(200, json={"foo": "bar"}),
1108 ]
1109 )
1110
1111 counter = 0
1112
1113 def token_provider() -> str:
1114 nonlocal counter
1115
1116 counter += 1
1117
1118 if counter == 1:
1119 return "first"
1120
1121 return "second"
1122
1123 client = OpenAI(base_url=base_url, api_key=token_provider)
1124 client.chat.completions.create(messages=[], model="gpt-4")
1125
1126 calls = cast("list[MockRequestCall]", respx_mock.calls)
1127 assert len(calls) == 2
1128
1129 assert calls[0].request.headers.get("Authorization") == "Bearer first"
1130 assert calls[1].request.headers.get("Authorization") == "Bearer second"
1131
1132 def test_copy_auth(self) -> None:
1133 client = OpenAI(base_url=base_url, api_key=lambda: "test_bearer_token_1").copy(
1134 api_key=lambda: "test_bearer_token_2"
1135 )
1136 client._refresh_api_key()
1137 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
1138
1139
1140class TestAsyncOpenAI:
1141 @pytest.mark.respx(base_url=base_url)
1142 async def test_raw_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1143 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1144
1145 response = await async_client.post("/foo", cast_to=httpx.Response)
1146 assert response.status_code == 200
1147 assert isinstance(response, httpx.Response)
1148 assert response.json() == {"foo": "bar"}
1149
1150 @pytest.mark.respx(base_url=base_url)
1151 async def test_raw_response_for_binary(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1152 respx_mock.post("/foo").mock(
1153 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
1154 )
1155
1156 response = await async_client.post("/foo", cast_to=httpx.Response)
1157 assert response.status_code == 200
1158 assert isinstance(response, httpx.Response)
1159 assert response.json() == {"foo": "bar"}
1160
1161 def test_copy(self, async_client: AsyncOpenAI) -> None:
1162 copied = async_client.copy()
1163 assert id(copied) != id(async_client)
1164
1165 copied = async_client.copy(api_key="another My API Key")
1166 assert copied.api_key == "another My API Key"
1167 assert async_client.api_key == "My API Key"
1168
1169 def test_copy_default_options(self, async_client: AsyncOpenAI) -> None:
1170 # options that have a default are overridden correctly
1171 copied = async_client.copy(max_retries=7)
1172 assert copied.max_retries == 7
1173 assert async_client.max_retries == 2
1174
1175 copied2 = copied.copy(max_retries=6)
1176 assert copied2.max_retries == 6
1177 assert copied.max_retries == 7
1178
1179 # timeout
1180 assert isinstance(async_client.timeout, httpx.Timeout)
1181 copied = async_client.copy(timeout=None)
1182 assert copied.timeout is None
1183 assert isinstance(async_client.timeout, httpx.Timeout)
1184
1185 async def test_copy_default_headers(self) -> None:
1186 client = AsyncOpenAI(
1187 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1188 )
1189 assert client.default_headers["X-Foo"] == "bar"
1190
1191 # does not override the already given value when not specified
1192 copied = client.copy()
1193 assert copied.default_headers["X-Foo"] == "bar"
1194
1195 # merges already given headers
1196 copied = client.copy(default_headers={"X-Bar": "stainless"})
1197 assert copied.default_headers["X-Foo"] == "bar"
1198 assert copied.default_headers["X-Bar"] == "stainless"
1199
1200 # uses new values for any already given headers
1201 copied = client.copy(default_headers={"X-Foo": "stainless"})
1202 assert copied.default_headers["X-Foo"] == "stainless"
1203
1204 # set_default_headers
1205
1206 # completely overrides already set values
1207 copied = client.copy(set_default_headers={})
1208 assert copied.default_headers.get("X-Foo") is None
1209
1210 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
1211 assert copied.default_headers["X-Bar"] == "Robert"
1212
1213 with pytest.raises(
1214 ValueError,
1215 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
1216 ):
1217 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1218 await client.close()
1219
1220 async def test_copy_default_query(self) -> None:
1221 client = AsyncOpenAI(
1222 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
1223 )
1224 assert _get_params(client)["foo"] == "bar"
1225
1226 # does not override the already given value when not specified
1227 copied = client.copy()
1228 assert _get_params(copied)["foo"] == "bar"
1229
1230 # merges already given params
1231 copied = client.copy(default_query={"bar": "stainless"})
1232 params = _get_params(copied)
1233 assert params["foo"] == "bar"
1234 assert params["bar"] == "stainless"
1235
1236 # uses new values for any already given headers
1237 copied = client.copy(default_query={"foo": "stainless"})
1238 assert _get_params(copied)["foo"] == "stainless"
1239
1240 # set_default_query
1241
1242 # completely overrides already set values
1243 copied = client.copy(set_default_query={})
1244 assert _get_params(copied) == {}
1245
1246 copied = client.copy(set_default_query={"bar": "Robert"})
1247 assert _get_params(copied)["bar"] == "Robert"
1248
1249 with pytest.raises(
1250 ValueError,
1251 # TODO: update
1252 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1253 ):
1254 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1255
1256 await client.close()
1257
1258 def test_copy_signature(self, async_client: AsyncOpenAI) -> None:
1259 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1260 init_signature = inspect.signature(
1261 # mypy doesn't like that we access the `__init__` property.
1262 async_client.__init__, # type: ignore[misc]
1263 )
1264 copy_signature = inspect.signature(async_client.copy)
1265 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1266
1267 for name in init_signature.parameters.keys():
1268 if name in exclude_params:
1269 continue
1270
1271 copy_param = copy_signature.parameters.get(name)
1272 assert copy_param is not None, f"copy() signature is missing the {name} param"
1273
    # NOTE(review): the skip condition (>= 3.10) does not match the reason text, which says the
    # leak started from 3.12 — confirm which version bound is intended.
    @pytest.mark.skipif(sys.version_info >= (3, 10), reason="fails because of a memory leak that started from 3.12")
    def test_copy_build_request(self, async_client: AsyncOpenAI) -> None:
        """Copying the client and building a request must not leak memory.

        Takes tracemalloc snapshots around ITERATIONS copy-and-build cycles. It fails if any
        allocation diff, outside a list of known-noisy files, has a non-zero block count in the
        final snapshot that is a multiple of ITERATIONS.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            # One cycle under test: copy the client, then build (but never send) a request.
            client_copy = async_client.copy()
            client_copy._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        # Store up to 1000 frames per allocation so leak tracebacks are deep enough to filter on.
        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        # Collect garbage first so unreferenced objects do not show up as leaks.
        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            """Append `diff` to `leaks` unless it is filtered out as a likely false positive."""
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            # Print every leak with its full traceback before failing, to make triage possible.
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()
1336
1337 async def test_request_timeout(self, async_client: AsyncOpenAI) -> None:
1338 request = async_client._build_request(FinalRequestOptions(method="get", url="/foo"))
1339 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1340 assert timeout == DEFAULT_TIMEOUT
1341
1342 request = async_client._build_request(
1343 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1344 )
1345 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1346 assert timeout == httpx.Timeout(100.0)
1347
1348 async def test_client_timeout_option(self) -> None:
1349 client = AsyncOpenAI(
1350 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
1351 )
1352
1353 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1354 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1355 assert timeout == httpx.Timeout(0)
1356
1357 await client.close()
1358
1359 async def test_http_client_timeout_option(self) -> None:
1360 # custom timeout given to the httpx client should be used
1361 async with httpx.AsyncClient(timeout=None) as http_client:
1362 client = AsyncOpenAI(
1363 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1364 )
1365
1366 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1367 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1368 assert timeout == httpx.Timeout(None)
1369
1370 await client.close()
1371
1372 # no timeout given to the httpx client should not use the httpx default
1373 async with httpx.AsyncClient() as http_client:
1374 client = AsyncOpenAI(
1375 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1376 )
1377
1378 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1379 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1380 assert timeout == DEFAULT_TIMEOUT
1381
1382 await client.close()
1383
1384 # explicitly passing the default timeout currently results in it being ignored
1385 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1386 client = AsyncOpenAI(
1387 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1388 )
1389
1390 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1391 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1392 assert timeout == DEFAULT_TIMEOUT # our default
1393
1394 await client.close()
1395
1396 def test_invalid_http_client(self) -> None:
1397 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1398 with httpx.Client() as http_client:
1399 AsyncOpenAI(
1400 base_url=base_url,
1401 api_key=api_key,
1402 _strict_response_validation=True,
1403 http_client=cast(Any, http_client),
1404 )
1405
1406 async def test_default_headers_option(self) -> None:
1407 test_client = AsyncOpenAI(
1408 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1409 )
1410 request = test_client._build_request(FinalRequestOptions(method="get", url="/foo"))
1411 assert request.headers.get("x-foo") == "bar"
1412 assert request.headers.get("x-stainless-lang") == "python"
1413
1414 test_client2 = AsyncOpenAI(
1415 base_url=base_url,
1416 api_key=api_key,
1417 _strict_response_validation=True,
1418 default_headers={
1419 "X-Foo": "stainless",
1420 "X-Stainless-Lang": "my-overriding-header",
1421 },
1422 )
1423 request = test_client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1424 assert request.headers.get("x-foo") == "stainless"
1425 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1426
1427 await test_client.close()
1428 await test_client2.close()
1429
1430 async def test_validate_headers(self) -> None:
1431 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1432 options = await client._prepare_options(FinalRequestOptions(method="get", url="/foo"))
1433 request = client._build_request(options)
1434 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1435
1436 with pytest.raises(OpenAIError):
1437 with update_env(**{"OPENAI_API_KEY": Omit()}):
1438 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1439 _ = client2
1440
1441 async def test_default_query_option(self) -> None:
1442 client = AsyncOpenAI(
1443 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1444 )
1445 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1446 url = httpx.URL(request.url)
1447 assert dict(url.params) == {"query_param": "bar"}
1448
1449 request = client._build_request(
1450 FinalRequestOptions(
1451 method="get",
1452 url="/foo",
1453 params={"foo": "baz", "query_param": "overridden"},
1454 )
1455 )
1456 url = httpx.URL(request.url)
1457 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1458
1459 await client.close()
1460
1461 def test_request_extra_json(self, client: OpenAI) -> None:
1462 request = client._build_request(
1463 FinalRequestOptions(
1464 method="post",
1465 url="/foo",
1466 json_data={"foo": "bar"},
1467 extra_json={"baz": False},
1468 ),
1469 )
1470 data = json.loads(request.content.decode("utf-8"))
1471 assert data == {"foo": "bar", "baz": False}
1472
1473 request = client._build_request(
1474 FinalRequestOptions(
1475 method="post",
1476 url="/foo",
1477 extra_json={"baz": False},
1478 ),
1479 )
1480 data = json.loads(request.content.decode("utf-8"))
1481 assert data == {"baz": False}
1482
1483 # `extra_json` takes priority over `json_data` when keys clash
1484 request = client._build_request(
1485 FinalRequestOptions(
1486 method="post",
1487 url="/foo",
1488 json_data={"foo": "bar", "baz": True},
1489 extra_json={"baz": None},
1490 ),
1491 )
1492 data = json.loads(request.content.decode("utf-8"))
1493 assert data == {"foo": "bar", "baz": None}
1494
1495 def test_request_extra_headers(self, client: OpenAI) -> None:
1496 request = client._build_request(
1497 FinalRequestOptions(
1498 method="post",
1499 url="/foo",
1500 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1501 ),
1502 )
1503 assert request.headers.get("X-Foo") == "Foo"
1504
1505 # `extra_headers` takes priority over `default_headers` when keys clash
1506 request = client.with_options(default_headers={"X-Bar": "true"})._build_request(
1507 FinalRequestOptions(
1508 method="post",
1509 url="/foo",
1510 **make_request_options(
1511 extra_headers={"X-Bar": "false"},
1512 ),
1513 ),
1514 )
1515 assert request.headers.get("X-Bar") == "false"
1516
1517 def test_request_extra_query(self, client: OpenAI) -> None:
1518 request = client._build_request(
1519 FinalRequestOptions(
1520 method="post",
1521 url="/foo",
1522 **make_request_options(
1523 extra_query={"my_query_param": "Foo"},
1524 ),
1525 ),
1526 )
1527 params = dict(request.url.params)
1528 assert params == {"my_query_param": "Foo"}
1529
1530 # if both `query` and `extra_query` are given, they are merged
1531 request = client._build_request(
1532 FinalRequestOptions(
1533 method="post",
1534 url="/foo",
1535 **make_request_options(
1536 query={"bar": "1"},
1537 extra_query={"foo": "2"},
1538 ),
1539 ),
1540 )
1541 params = dict(request.url.params)
1542 assert params == {"bar": "1", "foo": "2"}
1543
1544 # `extra_query` takes priority over `query` when keys clash
1545 request = client._build_request(
1546 FinalRequestOptions(
1547 method="post",
1548 url="/foo",
1549 **make_request_options(
1550 query={"foo": "1"},
1551 extra_query={"foo": "2"},
1552 ),
1553 ),
1554 )
1555 params = dict(request.url.params)
1556 assert params == {"foo": "2"}
1557
1558 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1559 request = async_client._build_request(
1560 FinalRequestOptions.construct(
1561 method="post",
1562 url="/foo",
1563 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1564 json_data={"array": ["foo", "bar"]},
1565 files=[("foo.txt", b"hello world")],
1566 )
1567 )
1568
1569 assert request.read().split(b"\r\n") == [
1570 b"--6b7ba517decee4a450543ea6ae821c82",
1571 b'Content-Disposition: form-data; name="array[]"',
1572 b"",
1573 b"foo",
1574 b"--6b7ba517decee4a450543ea6ae821c82",
1575 b'Content-Disposition: form-data; name="array[]"',
1576 b"",
1577 b"bar",
1578 b"--6b7ba517decee4a450543ea6ae821c82",
1579 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
1580 b"Content-Type: application/octet-stream",
1581 b"",
1582 b"hello world",
1583 b"--6b7ba517decee4a450543ea6ae821c82--",
1584 b"",
1585 ]
1586
1587 @pytest.mark.respx(base_url=base_url)
1588 async def test_binary_content_upload(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1589 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
1590
1591 file_content = b"Hello, this is a test file."
1592
1593 response = await async_client.post(
1594 "/upload",
1595 content=file_content,
1596 cast_to=httpx.Response,
1597 options={"headers": {"Content-Type": "application/octet-stream"}},
1598 )
1599
1600 assert response.status_code == 200
1601 assert response.request.headers["Content-Type"] == "application/octet-stream"
1602 assert response.content == file_content
1603
1604 async def test_binary_content_upload_with_asynciterator(self) -> None:
1605 file_content = b"Hello, this is a test file."
1606 counter = Counter()
1607 iterator = _make_async_iterator([file_content], counter=counter)
1608
1609 async def mock_handler(request: httpx.Request) -> httpx.Response:
1610 assert counter.value == 0, "the request body should not have been read"
1611 return httpx.Response(200, content=await request.aread())
1612
1613 async with AsyncOpenAI(
1614 base_url=base_url,
1615 api_key=api_key,
1616 _strict_response_validation=True,
1617 http_client=httpx.AsyncClient(transport=MockTransport(handler=mock_handler)),
1618 ) as client:
1619 response = await client.post(
1620 "/upload",
1621 content=iterator,
1622 cast_to=httpx.Response,
1623 options={"headers": {"Content-Type": "application/octet-stream"}},
1624 )
1625
1626 assert response.status_code == 200
1627 assert response.request.headers["Content-Type"] == "application/octet-stream"
1628 assert response.content == file_content
1629 assert counter.value == 1
1630
1631 @pytest.mark.respx(base_url=base_url)
1632 async def test_binary_content_upload_with_body_is_deprecated(
1633 self, respx_mock: MockRouter, async_client: AsyncOpenAI
1634 ) -> None:
1635 respx_mock.post("/upload").mock(side_effect=mirror_request_content)
1636
1637 file_content = b"Hello, this is a test file."
1638
1639 with pytest.deprecated_call(
1640 match="Passing raw bytes as `body` is deprecated and will be removed in a future version. Please pass raw bytes via the `content` parameter instead."
1641 ):
1642 response = await async_client.post(
1643 "/upload",
1644 body=file_content,
1645 cast_to=httpx.Response,
1646 options={"headers": {"Content-Type": "application/octet-stream"}},
1647 )
1648
1649 assert response.status_code == 200
1650 assert response.request.headers["Content-Type"] == "application/octet-stream"
1651 assert response.content == file_content
1652
1653 @pytest.mark.respx(base_url=base_url)
1654 async def test_basic_union_response(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1655 class Model1(BaseModel):
1656 name: str
1657
1658 class Model2(BaseModel):
1659 foo: str
1660
1661 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1662
1663 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1664 assert isinstance(response, Model2)
1665 assert response.foo == "bar"
1666
1667 @pytest.mark.respx(base_url=base_url)
1668 async def test_union_response_different_types(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1669 """Union of objects with the same field name using a different type"""
1670
1671 class Model1(BaseModel):
1672 foo: int
1673
1674 class Model2(BaseModel):
1675 foo: str
1676
1677 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1678
1679 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1680 assert isinstance(response, Model2)
1681 assert response.foo == "bar"
1682
1683 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1684
1685 response = await async_client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1686 assert isinstance(response, Model1)
1687 assert response.foo == 1
1688
1689 @pytest.mark.respx(base_url=base_url)
1690 async def test_non_application_json_content_type_for_json_data(
1691 self, respx_mock: MockRouter, async_client: AsyncOpenAI
1692 ) -> None:
1693 """
1694 Response that sets Content-Type to something other than application/json but returns json data
1695 """
1696
1697 class Model(BaseModel):
1698 foo: int
1699
1700 respx_mock.get("/foo").mock(
1701 return_value=httpx.Response(
1702 200,
1703 content=json.dumps({"foo": 2}),
1704 headers={"Content-Type": "application/text"},
1705 )
1706 )
1707
1708 response = await async_client.get("/foo", cast_to=Model)
1709 assert isinstance(response, Model)
1710 assert response.foo == 2
1711
1712 async def test_base_url_setter(self) -> None:
1713 client = AsyncOpenAI(
1714 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1715 )
1716 assert client.base_url == "https://example.com/from_init/"
1717
1718 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1719
1720 assert client.base_url == "https://example.com/from_setter/"
1721
1722 await client.close()
1723
1724 async def test_base_url_env(self) -> None:
1725 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1726 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1727 assert client.base_url == "http://localhost:5000/from/env/"
1728
1729 @pytest.mark.parametrize(
1730 "client",
1731 [
1732 AsyncOpenAI(
1733 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1734 ),
1735 AsyncOpenAI(
1736 base_url="http://localhost:5000/custom/path/",
1737 api_key=api_key,
1738 _strict_response_validation=True,
1739 http_client=httpx.AsyncClient(),
1740 ),
1741 ],
1742 ids=["standard", "custom http client"],
1743 )
1744 async def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1745 request = client._build_request(
1746 FinalRequestOptions(
1747 method="post",
1748 url="/foo",
1749 json_data={"foo": "bar"},
1750 ),
1751 )
1752 assert request.url == "http://localhost:5000/custom/path/foo"
1753 await client.close()
1754
1755 @pytest.mark.parametrize(
1756 "client",
1757 [
1758 AsyncOpenAI(
1759 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1760 ),
1761 AsyncOpenAI(
1762 base_url="http://localhost:5000/custom/path/",
1763 api_key=api_key,
1764 _strict_response_validation=True,
1765 http_client=httpx.AsyncClient(),
1766 ),
1767 ],
1768 ids=["standard", "custom http client"],
1769 )
1770 async def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1771 request = client._build_request(
1772 FinalRequestOptions(
1773 method="post",
1774 url="/foo",
1775 json_data={"foo": "bar"},
1776 ),
1777 )
1778 assert request.url == "http://localhost:5000/custom/path/foo"
1779 await client.close()
1780
1781 @pytest.mark.parametrize(
1782 "client",
1783 [
1784 AsyncOpenAI(
1785 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1786 ),
1787 AsyncOpenAI(
1788 base_url="http://localhost:5000/custom/path/",
1789 api_key=api_key,
1790 _strict_response_validation=True,
1791 http_client=httpx.AsyncClient(),
1792 ),
1793 ],
1794 ids=["standard", "custom http client"],
1795 )
1796 async def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1797 request = client._build_request(
1798 FinalRequestOptions(
1799 method="post",
1800 url="https://myapi.com/foo",
1801 json_data={"foo": "bar"},
1802 ),
1803 )
1804 assert request.url == "https://myapi.com/foo"
1805 await client.close()
1806
1807 async def test_copied_client_does_not_close_http(self) -> None:
1808 test_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1809 assert not test_client.is_closed()
1810
1811 copied = test_client.copy()
1812 assert copied is not test_client
1813
1814 del copied
1815
1816 await asyncio.sleep(0.2)
1817 assert not test_client.is_closed()
1818
1819 async def test_client_context_manager(self) -> None:
1820 test_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1821 async with test_client as c2:
1822 assert c2 is test_client
1823 assert not c2.is_closed()
1824 assert not test_client.is_closed()
1825 assert test_client.is_closed()
1826
1827 @pytest.mark.respx(base_url=base_url)
1828 async def test_client_response_validation_error(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1829 class Model(BaseModel):
1830 foo: str
1831
1832 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1833
1834 with pytest.raises(APIResponseValidationError) as exc:
1835 await async_client.get("/foo", cast_to=Model)
1836
1837 assert isinstance(exc.value.__cause__, ValidationError)
1838
1839 async def test_client_max_retries_validation(self) -> None:
1840 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1841 AsyncOpenAI(
1842 base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
1843 )
1844
1845 @pytest.mark.respx(base_url=base_url)
1846 async def test_default_stream_cls(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1847 class Model(BaseModel):
1848 name: str
1849
1850 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1851
1852 stream = await async_client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1853 assert isinstance(stream, AsyncStream)
1854 await stream.response.aclose()
1855
1856 @pytest.mark.respx(base_url=base_url)
1857 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1858 class Model(BaseModel):
1859 name: str
1860
1861 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1862
1863 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1864
1865 with pytest.raises(APIResponseValidationError):
1866 await strict_client.get("/foo", cast_to=Model)
1867
1868 non_strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1869
1870 response = await non_strict_client.get("/foo", cast_to=Model)
1871 assert isinstance(response, str) # type: ignore[unreachable]
1872
1873 await strict_client.close()
1874 await non_strict_client.close()
1875
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            # Numeric `retry-after` values: "20" and "60" are expected to be used as-is,
            # while "0", a negative value and "61" fall back to the default delay
            # (0.5s with 3 retries remaining).
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            # HTTP-date values, relative to the mocked clock below
            # (1696004797 == Fri, 29 Sep 2023 16:26:37 GMT): +20s and +60s are honored;
            # "now", a past date and +61s fall back to the default delay.
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            # Oversized, malformed or empty headers also yield the default delay.
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            # Without a usable header, the expected delay doubles as fewer retries remain...
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            # ...and stays at 8s even for an absurd count (presumably a maximum-delay cap).
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    async def test_parse_retry_after_header(
        self, remaining_retries: int, retry_after: str, timeout: float, async_client: AsyncOpenAI
    ) -> None:
        """Check `_calculate_retry_timeout` against the `retry-after` cases above.

        `time.time` is frozen so that HTTP-date headers resolve to fixed offsets, and the
        request options allow `max_retries=3`.
        """
        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = async_client._calculate_retry_timeout(remaining_retries, options, headers)
        # Relative tolerance of 0.4375 (0.5 * 0.875). NOTE(review): presumably wide enough
        # to absorb any jitter applied to the backoff; confirm against the implementation.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1905
1906 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1907 @pytest.mark.respx(base_url=base_url)
1908 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1909 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1910
1911 with pytest.raises(APITimeoutError):
1912 await async_client.chat.completions.with_streaming_response.create(
1913 messages=[
1914 {
1915 "content": "string",
1916 "role": "developer",
1917 }
1918 ],
1919 model="gpt-4o",
1920 ).__aenter__()
1921
1922 assert _get_open_connections(async_client) == 0
1923
1924 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1925 @pytest.mark.respx(base_url=base_url)
1926 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1927 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1928
1929 with pytest.raises(APIStatusError):
1930 await async_client.chat.completions.with_streaming_response.create(
1931 messages=[
1932 {
1933 "content": "string",
1934 "role": "developer",
1935 }
1936 ],
1937 model="gpt-4o",
1938 ).__aenter__()
1939 assert _get_open_connections(async_client) == 0
1940
1941 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1942 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1943 @pytest.mark.respx(base_url=base_url)
1944 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
1945 async def test_retries_taken(
1946 self,
1947 async_client: AsyncOpenAI,
1948 failures_before_success: int,
1949 failure_mode: Literal["status", "exception"],
1950 respx_mock: MockRouter,
1951 ) -> None:
1952 client = async_client.with_options(max_retries=4)
1953
1954 nb_retries = 0
1955
1956 def retry_handler(_request: httpx.Request) -> httpx.Response:
1957 nonlocal nb_retries
1958 if nb_retries < failures_before_success:
1959 nb_retries += 1
1960 if failure_mode == "exception":
1961 raise RuntimeError("oops")
1962 return httpx.Response(500)
1963 return httpx.Response(200)
1964
1965 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1966
1967 response = await client.chat.completions.with_raw_response.create(
1968 messages=[
1969 {
1970 "content": "string",
1971 "role": "developer",
1972 }
1973 ],
1974 model="gpt-4o",
1975 )
1976
1977 assert response.retries_taken == failures_before_success
1978 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1979
1980 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1981 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1982 @pytest.mark.respx(base_url=base_url)
1983 async def test_omit_retry_count_header(
1984 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1985 ) -> None:
1986 client = async_client.with_options(max_retries=4)
1987
1988 nb_retries = 0
1989
1990 def retry_handler(_request: httpx.Request) -> httpx.Response:
1991 nonlocal nb_retries
1992 if nb_retries < failures_before_success:
1993 nb_retries += 1
1994 return httpx.Response(500)
1995 return httpx.Response(200)
1996
1997 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1998
1999 response = await client.chat.completions.with_raw_response.create(
2000 messages=[
2001 {
2002 "content": "string",
2003 "role": "developer",
2004 }
2005 ],
2006 model="gpt-4o",
2007 extra_headers={"x-stainless-retry-count": Omit()},
2008 )
2009
2010 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
2011
2012 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2013 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2014 @pytest.mark.respx(base_url=base_url)
2015 async def test_overwrite_retry_count_header(
2016 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2017 ) -> None:
2018 client = async_client.with_options(max_retries=4)
2019
2020 nb_retries = 0
2021
2022 def retry_handler(_request: httpx.Request) -> httpx.Response:
2023 nonlocal nb_retries
2024 if nb_retries < failures_before_success:
2025 nb_retries += 1
2026 return httpx.Response(500)
2027 return httpx.Response(200)
2028
2029 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2030
2031 response = await client.chat.completions.with_raw_response.create(
2032 messages=[
2033 {
2034 "content": "string",
2035 "role": "developer",
2036 }
2037 ],
2038 model="gpt-4o",
2039 extra_headers={"x-stainless-retry-count": "42"},
2040 )
2041
2042 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
2043
2044 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
2045 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
2046 @pytest.mark.respx(base_url=base_url)
2047 async def test_retries_taken_new_response_class(
2048 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
2049 ) -> None:
2050 client = async_client.with_options(max_retries=4)
2051
2052 nb_retries = 0
2053
2054 def retry_handler(_request: httpx.Request) -> httpx.Response:
2055 nonlocal nb_retries
2056 if nb_retries < failures_before_success:
2057 nb_retries += 1
2058 return httpx.Response(500)
2059 return httpx.Response(200)
2060
2061 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
2062
2063 async with client.chat.completions.with_streaming_response.create(
2064 messages=[
2065 {
2066 "content": "string",
2067 "role": "developer",
2068 }
2069 ],
2070 model="gpt-4o",
2071 ) as response:
2072 assert response.retries_taken == failures_before_success
2073 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
2074
2075 async def test_get_platform(self) -> None:
2076 platform = await asyncify(get_platform)()
2077 assert isinstance(platform, (str, OtherPlatform))
2078
2079 async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
2080 # Test that the proxy environment variables are set correctly
2081 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
2082
2083 client = DefaultAsyncHttpxClient()
2084
2085 mounts = tuple(client._mounts.items())
2086 assert len(mounts) == 1
2087 assert mounts[0][0].pattern == "https://"
2088
2089 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
2090 async def test_default_client_creation(self) -> None:
2091 # Ensure that the client can be initialized without any exceptions
2092 DefaultAsyncHttpxClient(
2093 verify=True,
2094 cert=None,
2095 trust_env=True,
2096 http1=True,
2097 http2=False,
2098 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
2099 )
2100
2101 @pytest.mark.respx(base_url=base_url)
2102 async def test_follow_redirects(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2103 # Test that the default follow_redirects=True allows following redirects
2104 respx_mock.post("/redirect").mock(
2105 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2106 )
2107 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
2108
2109 response = await async_client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
2110 assert response.status_code == 200
2111 assert response.json() == {"status": "ok"}
2112
2113 @pytest.mark.respx(base_url=base_url)
2114 async def test_follow_redirects_disabled(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
2115 # Test that follow_redirects=False prevents following redirects
2116 respx_mock.post("/redirect").mock(
2117 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
2118 )
2119
2120 with pytest.raises(APIStatusError) as exc_info:
2121 await async_client.post(
2122 "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
2123 )
2124
2125 assert exc_info.value.response.status_code == 302
2126 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
2127
2128 @pytest.mark.asyncio
2129 async def test_api_key_before_after_refresh_provider(self) -> None:
2130 async def mock_api_key_provider():
2131 return "test_bearer_token"
2132
2133 client = AsyncOpenAI(base_url=base_url, api_key=mock_api_key_provider)
2134
2135 assert client.api_key == ""
2136 assert "Authorization" not in client.auth_headers
2137
2138 await client._refresh_api_key()
2139
2140 assert client.api_key == "test_bearer_token"
2141 assert client.auth_headers.get("Authorization") == "Bearer test_bearer_token"
2142
2143 @pytest.mark.asyncio
2144 async def test_api_key_before_after_refresh_str(self) -> None:
2145 client = AsyncOpenAI(base_url=base_url, api_key="test_api_key")
2146
2147 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2148 await client._refresh_api_key()
2149
2150 assert client.auth_headers.get("Authorization") == "Bearer test_api_key"
2151
2152 @pytest.mark.asyncio
2153 @pytest.mark.respx()
2154 async def test_bearer_token_refresh_async(self, respx_mock: MockRouter) -> None:
2155 respx_mock.post(base_url + "/chat/completions").mock(
2156 side_effect=[
2157 httpx.Response(500, json={"error": "server error"}),
2158 httpx.Response(200, json={"foo": "bar"}),
2159 ]
2160 )
2161
2162 counter = 0
2163
2164 async def token_provider() -> str:
2165 nonlocal counter
2166
2167 counter += 1
2168
2169 if counter == 1:
2170 return "first"
2171
2172 return "second"
2173
2174 client = AsyncOpenAI(base_url=base_url, api_key=token_provider)
2175 await client.chat.completions.create(messages=[], model="gpt-4")
2176
2177 calls = cast("list[MockRequestCall]", respx_mock.calls)
2178 assert len(calls) == 2
2179
2180 assert calls[0].request.headers.get("Authorization") == "Bearer first"
2181 assert calls[1].request.headers.get("Authorization") == "Bearer second"
2182
2183 @pytest.mark.asyncio
2184 async def test_copy_auth(self) -> None:
2185 async def token_provider_1() -> str:
2186 return "test_bearer_token_1"
2187
2188 async def token_provider_2() -> str:
2189 return "test_bearer_token_2"
2190
2191 client = AsyncOpenAI(base_url=base_url, api_key=token_provider_1).copy(api_key=token_provider_2)
2192 await client._refresh_api_key()
2193 assert client.auth_headers == {"Authorization": "Bearer test_bearer_token_2"}
2194