openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.50.2

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1749 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._types import Omit
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
# Server every test client talks to; the TEST_API_BASE_URL environment variable overrides the default.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential shared by the clients built in this module.
api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a plain ``GET /foo`` request with *client* and return its query string as a dict."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of requests currently tracked by *client*'s connection pool.

    This reaches into private httpx internals (``_client._transport._pool._requests``). It is
    therefore only valid when the client uses a default ``httpx.HTTPTransport`` or
    ``httpx.AsyncHTTPTransport``, which the assertion below enforces.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport)), (
        f"expected a default httpx transport, got {type(transport).__name__}"
    )

    pool = transport._pool
    return len(pool._requests)
49
50
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    ``client`` is one shared client with strict response validation. Tests that only need a
    default configuration use it. Tests that accept a ``client`` parameter receive it from a
    pytest fixture or a ``parametrize`` list instead. The fixture is presumably defined in
    ``conftest.py``; confirm there.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @staticmethod
    def _fail_then_succeed(failures_before_success: int) -> Any:
        """Build a respx ``side_effect`` that answers HTTP 500 to the first
        ``failures_before_success`` requests and HTTP 200 to every request after that."""
        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        return retry_handler

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert copied is not self.client

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given query params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request must not leak memory on every iteration."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if a build raised, so tracemalloc's overhead and state do not
            # carry over into the rest of the test session.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks")

    def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    # Marked explicitly, like every async test in TestAsyncOpenAI, so that it also runs when
    # pytest-asyncio is in strict mode rather than being collected and never awaited.
    @pytest.mark.asyncio
    async def test_invalid_http_client(self) -> None:
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # With no explicit key and OPENAI_API_KEY removed from the environment, construction fails.
        with pytest.raises(OpenAIError):
            with update_env(OPENAI_API_KEY=Omit()):
                OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        # List values are sent as repeated `name[]` form fields, followed by the file part.
        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    # The base URLs here deliberately omit the trailing slash. The client is expected to append
    # one (as test_base_url_setter shows), so the joined URL matches the trailing-slash case.
    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        # Dropping the copy must not close the HTTP client it shares with the original.
        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        # Without strict validation the raw text is returned instead of raising.
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # time.time() is frozen at 1696004797, which is Fri, 29 Sep 2023 16:26:37 GMT, so HTTP-date
    # values are measured from that instant. In this table, Retry-After values up to 60 seconds
    # are honoured. Values that are <= 0, > 60 or unparseable fall back to a backoff of
    # 0.5 * 2 ** (3 - remaining_retries).
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # Relative tolerance of 0.4375, presumably to absorb random jitter in the computed
        # delay. Confirm against _calculate_retry_timeout.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        # After every retry has failed, no request may still be held by the connection pool.
        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        # After every retry has failed, no request may still be held by the connection pool.
        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken(self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter) -> None:
        client = client.with_options(max_retries=4)

        respx_mock.post("/chat/completions").mock(side_effect=self._fail_then_succeed(failures_before_success))

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        respx_mock.post("/chat/completions").mock(side_effect=self._fail_then_succeed(failures_before_success))

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        respx_mock.post("/chat/completions").mock(side_effect=self._fail_then_succeed(failures_before_success))

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        respx_mock.post("/chat/completions").mock(side_effect=self._fail_then_succeed(failures_before_success))

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
887
888
class TestAsyncOpenAI:
    """Tests for the asynchronous `AsyncOpenAI` client.

    Covers raw responses, `copy()` semantics, timeout resolution, header/query/body
    merging, base URL handling, response validation, Retry-After parsing, retry
    bookkeeping (`retries_taken` / `x-stainless-retry-count`) and connection cleanup
    after exhausted retries.
    """

    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        # constructor-only params that `.copy()` intentionally does not accept
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request must not leak memory per iteration."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Always stop tracing, even if building a request fails, so tracing does not
            # stay enabled for (and slow down) the rest of the test session.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected while copying the client")

    async def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    async def test_client_timeout_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
        )

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    async def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        async with httpx.AsyncClient(timeout=None) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        async with httpx.AsyncClient() as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_invalid_http_client(self) -> None:
        # a synchronous httpx.Client must be rejected by the async client
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            with httpx.Client() as http_client:
                AsyncOpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        # user-supplied default headers take precedence over the SDK's own headers
        client2 = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # with neither an explicit key nor OPENAI_API_KEY in the environment, construction fails
        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        # per-request params are merged with, and override, the default query
        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        # list values are encoded as repeated `name[]` form fields; the boundary is pinned via the header
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = await self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = AsyncOpenAI(
            base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
        )
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # NOTE: base_url intentionally has NO trailing slash here; the client is expected
            # to normalize it so relative request paths are still joined under `/custom/path/`.
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
        # an absolute request URL bypasses the configured base_url entirely
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    async def test_copied_client_does_not_close_http(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        # dropping the copy must not close the HTTP client it shares with the original
        del copied

        await asyncio.sleep(0.2)
        assert not client.is_closed()

    async def test_client_context_manager(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        async with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            await self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    async def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
            )

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
        assert isinstance(stream, AsyncStream)
        await stream.response.aclose()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        # strict validation rejects a non-JSON body when a model is expected...
        strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

        # ...while lenient validation hands back the raw text instead
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = await client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    # Freeze "now" at 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT so HTTP-date values
    # above resolve to fixed deltas (e.g. 16:26:57 -> 20s, 16:27:37 -> 60s).
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    @pytest.mark.asyncio
    async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # generous relative tolerance, presumably to absorb jitter in the computed backoff
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        # every retried attempt must have released its pooled connection
        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        # every retried attempt must have released its pooled connection
        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        # respond 500 for the first `failures_before_success` calls, then 200
        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_omit_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        # respond 500 for the first `failures_before_success` calls, then 200
        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        # passing Omit() for the header suppresses it entirely
        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_overwrite_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        # respond 500 for the first `failures_before_success` calls, then 200
        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        # a user-supplied value wins over the SDK-computed retry count
        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken_new_response_class(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        # respond 500 for the first `failures_before_success` calls, then 200
        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        async with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1750