openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
dfdcf571ced31f7859cd1871be39e2fb3af6bafa

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1768 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13from typing_extensions import Literal
14
15import httpx
16import pytest
17from respx import MockRouter
18from pydantic import ValidationError
19
20from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
21from openai._types import Omit
22from openai._models import BaseModel, FinalRequestOptions
23from openai._constants import RAW_RESPONSE_HEADER
24from openai._streaming import Stream, AsyncStream
25from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
26from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
27
28from .utils import update_env
29
# Base URL shared by every client in this module; respx routes are registered against it.
# Can be pointed elsewhere via the TEST_API_BASE_URL environment variable.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Fixed dummy key passed to all test clients.
api_key = "My API Key"
32
33
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters *client* attaches to a bare ``GET /foo`` request.

    The request is only built, never sent; the ``default_query`` tests use this to see
    which params the client adds on its own.
    """
    probe = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(probe)
    return dict(httpx.URL(built.url).params)
38
39
40def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
41 return 0.1
42
43
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of entries in the connection pool's private ``_requests`` list.

    The retry/leak tests use this to check that nothing is left outstanding after the
    client gives up. It reaches into private httpx internals
    (``client._client._transport._pool``), so it only works when the client uses a plain
    ``httpx.HTTPTransport`` or ``httpx.AsyncHTTPTransport``. The assert guards that
    assumption.
    """
    transport = client._client._transport
    # A single isinstance call with a tuple of types covers both sync and async transports.
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    pool = transport._pool
    return len(pool._requests)
50
51
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    The shared ``client`` class attribute is built with strict response validation; tests
    that need other constructor options create their own instance. Tests marked with
    ``@pytest.mark.respx`` intercept HTTP traffic against ``base_url``.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` hands back the raw ``httpx.Response``."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A non-JSON ``Content-Type`` still yields the raw response for ``cast_to=httpx.Response``."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a distinct client, and overrides apply to the copy only."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options that have defaults (``max_retries``, ``timeout``) are overridden on the copy only."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merge into the copy; ``set_default_headers`` replaces them entirely."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into the copy; ``set_default_query`` replaces it entirely."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every ``__init__`` parameter (except a few excluded ones) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if building a request raises; otherwise allocation tracing
            # (with 1000-frame tracebacks) would stay enabled for every later test.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the request options carry their own timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A ``timeout`` passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """How the timeout of a user-supplied ``httpx.Client`` interacts with the SDK default."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    @pytest.mark.asyncio
    async def test_invalid_http_client(self) -> None:
        """Passing an ``httpx.AsyncClient`` to the sync client raises ``TypeError``."""
        # Marked explicitly, like the coroutine tests in TestAsyncOpenAI, so this test is
        # collected and awaited under pytest-asyncio's strict mode as well as auto mode.
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent and can override the SDK's own headers."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a Bearer token; a missing key raises ``OpenAIError``."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` params are sent, and per-request params override them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and win over ``default_headers`` on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is sent, merged with ``query``, and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """List values in multipart bodies are encoded as repeated ``name[]`` parts."""
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A ``Union`` cast resolves to the member model whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning ``base_url`` normalises it to end with a slash, as the constructor does."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A ``base_url`` ending in ``/`` is joined with a relative request path."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # base_url deliberately has NO trailing slash: the client must add one before
            # joining the relative request path, otherwise "/path" would be replaced.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A ``base_url`` without a trailing ``/`` still keeps its last path segment."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is, ignoring ``base_url``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Dropping a copy must not close the original client's HTTP client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """``with client`` yields the client itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that does not fit the model raises ``APIResponseValidationError`` chained from pydantic."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected at construction time."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with ``stream_cls`` returns a ``Stream``."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Strict clients reject a non-JSON body for a model cast; non-strict ones return it as ``str``."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 7.8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``retry-after`` parsing; ``time.time`` is frozen so HTTP-date values are deterministic."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After retries on timeouts are exhausted, no pooled requests remain."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After retries on 5xx responses are exhausted, no pooled requests remain."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """``retries_taken`` and ``x-stainless-retry-count`` equal the number of failed attempts."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for ``x-stainless-retry-count`` suppresses the header entirely."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """An explicit ``x-stainless-retry-count`` value is sent unchanged."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Same as ``test_retries_taken`` but through ``with_streaming_response``."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
898
899
class TestAsyncOpenAI:
    """Tests for the asynchronous ``AsyncOpenAI`` client.

    Covers request construction via ``_build_request``, client copying and
    option merging, timeout handling, base-URL handling, response parsing and
    retry behaviour. HTTP traffic is intercepted with ``respx`` mocks.
    """

    # Shared client for tests that do not need custom construction options.
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` returns the raw HTTP response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A non-JSON content type is still returned as the raw HTTP response."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a new client and only overrides the options given."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options with defaults can be overridden on copies without touching the source client."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merges into existing headers; ``set_default_headers`` replaces them."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into existing params; ``set_default_query`` replaces them."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given query params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """``copy()`` accepts every constructor parameter except the explicitly excluded ones."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        # Always stop tracing, even if building a request raises, so tracemalloc's
        # overhead does not carry over into later tests.
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            """Record ``diff`` as a leak unless it matches a known false positive."""
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"detected {len(leaks)} memory leak(s); see printed tracebacks")

    async def test_request_timeout(self) -> None:
        """Requests use the default timeout unless a per-request timeout is given."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    async def test_client_timeout_option(self) -> None:
        """A client-level ``timeout`` is applied to built requests."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
        )

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    async def test_http_client_timeout_option(self) -> None:
        """A timeout set on a custom httpx client is used, except for httpx's own default."""
        # custom timeout given to the httpx client should be used
        async with httpx.AsyncClient(timeout=None) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        async with httpx.AsyncClient() as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_invalid_http_client(self) -> None:
        """Passing a synchronous ``httpx.Client`` to the async client raises ``TypeError``."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            with httpx.Client() as http_client:
                AsyncOpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """Custom default headers are sent and can override built-in ``x-stainless-*`` headers."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token, and a missing key raises ``OpenAIError``."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """Default query params are sent and can be overridden per request."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over ``default_headers``."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is merged with ``query`` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        """List values in a multipart body are encoded as repeated ``name[]`` form fields."""
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member model that matches the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = await self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning ``base_url`` after construction normalises it with a trailing slash."""
        client = AsyncOpenAI(
            base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
        )
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is passed."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
        """A base URL ending in ``/`` joins with the request path without losing segments."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # Deliberately no trailing slash here; the trailing-slash case is covered above.
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
        """A base URL without a trailing ``/`` still keeps its last path segment."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        # The client appends the missing "/" (see test_base_url_setter), so "path" is preserved.
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to ``base_url``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    async def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy must not close the original client."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        # give any finalizer-triggered cleanup a chance to run before checking
        await asyncio.sleep(0.2)
        assert not client.is_closed()

    async def test_client_context_manager(self) -> None:
        """``async with`` yields the same client and closes it on exit."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        async with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """Strict validation wraps pydantic's ``ValidationError`` in ``APIResponseValidationError``."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            await self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    async def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected with ``TypeError``."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
            )

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` returns an instance of the requested stream class."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
        assert isinstance(stream, AsyncStream)
        await stream.response.aclose()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """A non-JSON body fails under strict validation and is returned as ``str`` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = await client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 7.8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    @pytest.mark.asyncio
    async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """Retry-After values of 1-60s are honoured; anything else falls back to a growing backoff."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # relative tolerance allows for variation in the computed backoff (presumably jitter)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on timeouts leaves no pending pooled connections."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on 500 responses leaves no pending pooled connections."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    async def test_retries_taken(
        self,
        async_client: AsyncOpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """``retries_taken`` and the retry-count header equal the number of failed attempts."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` attempts, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_omit_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for the retry-count header removes it from the request."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` attempts, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_overwrite_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """An explicitly supplied retry-count header value is sent unchanged."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` attempts, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken_new_response_class(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Like ``test_retries_taken``, but through ``with_streaming_response``."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` attempts, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        async with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1769