openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
550b855fe51dd3fad4e1c5c796d7c35b8e603d2f

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1751 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._types import Omit
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
# Base URL for every test client in this module; defaults to a local server and
# can be redirected with the TEST_API_BASE_URL environment variable.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential shared by all clients constructed in these tests.
api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query params as a dict."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of requests still tracked by *client*'s connection pool.

    Relies on private attributes of the transport (``_transport``, ``_pool``,
    ``_requests``); the client must be using the default httpx transport.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
49
50
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    Covers raw responses, ``copy()``/``with_options()`` semantics, request building
    (timeouts, headers, query params, JSON and multipart bodies), response parsing
    and validation, base-URL handling, client lifecycle, ``retry-after`` parsing
    and retry bookkeeping.
    """

    # Shared strict client for tests that only build requests or hit respx-mocked routes.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` returns the underlying HTTP response unparsed."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A raw response is returned as-is even when the content type is not JSON."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a new client, and overrides do not leak back into the original."""
        copied = self.client.copy()
        assert copied is not self.client

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options with defaults (``max_retries``, ``timeout``) can be overridden per copy."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merges into a copy; ``set_default_headers`` replaces them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy; ``set_default_query`` replaces it."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every constructor parameter (except a few excluded ones) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Building requests from fresh copies must not leave allocations behind on each iteration."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if building a request fails, so allocation tracing
            # is not left enabled for the rest of the test session.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the request options supply their own."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A ``timeout`` passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """How a custom ``http_client``'s timeout interacts with the client default."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    async def test_invalid_http_client(self) -> None:
        """Passing an ``httpx.AsyncClient`` to the sync client raises ``TypeError``."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent with requests and can override built-in headers."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a Bearer token; a missing key raises ``OpenAIError``."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` params are sent, and per-request ``params`` override them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over ``default_headers``."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is sent, merged with ``query``, and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """List values in a multipart body are encoded as repeated ``name[]`` parts."""
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member model whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning ``base_url`` normalizes it with a trailing slash, as the constructor does."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base URL ending in ``/`` joins with a relative request path without a doubled slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # NOTE: deliberately *no* trailing slash here, unlike test_base_url_trailing_slash.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL without a trailing ``/`` keeps its last path segment when joined."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used verbatim and ignores the client's base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy does not close the original client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Using the client as a context manager yields itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that does not fit the model raises ``APIResponseValidationError``."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected at construction time."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with an explicit ``stream_cls`` returns a ``Stream``."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Non-JSON text fails strict validation but is returned as ``str`` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 7.8],  # test large number potentially overflowing
        ],
    )
    # time.time() is pinned to 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT so that the
    # HTTP-date cases above resolve to fixed offsets (e.g. 16:26:57 -> 20 seconds).
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``retry-after`` values from 1 to 60 seconds are honoured; others fall back to backoff.

        The fallback delay grows as ``remaining_retries`` shrinks, per the table above.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # Loose relative tolerance; presumably absorbs randomized jitter in the backoff (verify).
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on timeouts leaves no open requests in the connection pool."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on 500 responses leaves no open requests in the connection pool."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken(self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter) -> None:
        """``retries_taken`` and the ``x-stainless-retry-count`` header count the failed attempts."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail with 500 for the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for ``x-stainless-retry-count`` suppresses the header entirely."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail with 500 for the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """A caller-supplied ``x-stainless-retry-count`` value replaces the computed one."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail with 500 for the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Retry bookkeeping is also exposed through ``with_streaming_response``."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail with 500 for the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
888
889
class TestAsyncOpenAI:
    """Tests for the asynchronous `AsyncOpenAI` client.

    Covers request building (headers, query params, JSON bodies, multipart, timeouts,
    base URL handling), client copying, response parsing and validation, Retry-After
    parsing, and retry bookkeeping. Every coroutine test is explicitly marked with
    `@pytest.mark.asyncio` so it runs whether pytest-asyncio is in strict or auto mode.
    """

    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response(self, respx_mock: MockRouter) -> None:
        """Casting to `httpx.Response` returns the raw HTTP response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """The raw response is returned even when the Content-Type is a binary type."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """`copy()` returns a distinct client, and overrides only affect the copy."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options that have defaults (`max_retries`, `timeout`) can be overridden via `copy()`."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """`default_headers` merge into a copy; `set_default_headers` replaces them; both together is an error."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """`default_query` merges into a copy; `set_default_query` replaces it; both together is an error."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every constructor parameter, apart from the excluded ones, is also accepted by `copy()`."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory.

        Allocations are traced across a fixed number of iterations. A before/after
        difference counts as a leak only if it persists, its count is a whole multiple
        of the iteration count, and no frame comes from a known-noisy module.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        # keep up to 1000 frames per traced allocation so a leak can be traced back to its origin.
        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Always stop tracing, even if building a request raised. Otherwise tracing (and its
            # overhead) would stay enabled for every test that runs afterwards.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected; see captured stdout for tracebacks")

    @pytest.mark.asyncio
    async def test_request_timeout(self) -> None:
        """Requests use `DEFAULT_TIMEOUT` unless the request options specify their own timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    @pytest.mark.asyncio
    async def test_client_timeout_option(self) -> None:
        """A `timeout` passed to the client constructor is applied to built requests."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
        )

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    @pytest.mark.asyncio
    async def test_http_client_timeout_option(self) -> None:
        """How the request timeout is resolved when a custom `httpx.AsyncClient` is supplied."""
        # custom timeout given to the httpx client should be used
        async with httpx.AsyncClient(timeout=None) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        async with httpx.AsyncClient() as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_invalid_http_client(self) -> None:
        """A synchronous `httpx.Client` is rejected as the async client's `http_client`."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            with httpx.Client() as http_client:
                AsyncOpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """`default_headers` are sent on requests and can override the built-in `X-Stainless-Lang` header."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a Bearer token; with no key and no `OPENAI_API_KEY`, construction fails."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """`default_query` params are added to requests, and per-request params override them."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """`extra_json` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """`extra_headers` are sent and take priority over `default_headers`."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """`extra_query` is merged with `query` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        """Multipart bodies encode list values as repeated `name[]` parts, followed by the file parts."""
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A `Union` `cast_to` resolves to the member model whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = await self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning `base_url` replaces it, and a trailing slash is appended."""
        client = AsyncOpenAI(
            base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
        )
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """`OPENAI_BASE_URL` is used when no `base_url` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
        """A base URL ending in `/` joins with a leading-slash request path without doubling the slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # These base URLs deliberately omit the trailing slash. That is the case under test.
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
        """A base URL given without a trailing slash still joins correctly with relative request paths."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to the base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    @pytest.mark.asyncio
    async def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy of the client must leave the original client open."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        # yield to the event loop so any cleanup triggered by deleting `copied` gets a chance to run.
        await asyncio.sleep(0.2)
        assert not client.is_closed()

    @pytest.mark.asyncio
    async def test_client_context_manager(self) -> None:
        """`async with` yields the client itself and closes it on exit."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        async with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A strict client raises `APIResponseValidationError`, caused by pydantic's `ValidationError`."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            await self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.asyncio
    async def test_client_max_retries_validation(self) -> None:
        """`max_retries=None` is rejected with a `TypeError`."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
            )

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """`stream=True` with an explicit `stream_cls` returns an `AsyncStream`."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
        assert isinstance(stream, AsyncStream)
        await stream.response.aclose()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """A non-JSON body raises under strict validation, and is returned as a plain `str` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = await client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 7.8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    @pytest.mark.asyncio
    async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """`Retry-After` (seconds or HTTP date) is honoured inside the accepted window; otherwise backoff is used.

        `time.time` is pinned to 1696004797 (Fri, 29 Sep 2023 16:26:37 GMT) so HTTP-date
        values resolve to fixed offsets. The table expects values in (0, 60] seconds to be
        used directly. Zero, negative, out-of-range and unparsable values fall back to a
        delay that grows as `remaining_retries` decreases.
        """
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on timeouts leaves no open connections in the pool."""
        # `_low_retry_timeout` keeps every retry delay at 0.1s so the test stays fast.
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on 500 responses leaves no open connections in the pool."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """`retries_taken` and `x-stainless-retry-count` both equal the number of failed attempts."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # fail with a 500 until `failures_before_success` attempts have been rejected.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_omit_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing `Omit()` for `x-stainless-retry-count` suppresses the header entirely."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_overwrite_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """An explicit `x-stainless-retry-count` value is sent instead of the computed one."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken_new_response_class(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Retry bookkeeping is also exposed through the `with_streaming_response` wrapper."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        async with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "system",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1752