openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
63de8ef4229ba92fb05ac15efb1a07b3ab2a79ee

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1449 lines · mode: code

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._constants import RAW_RESPONSE_HEADER
23from openai._streaming import Stream, AsyncStream
24from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
25from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
26
27from .utils import update_env
28
# Server the tests talk to; override with TEST_API_BASE_URL. The default is a
# local address (presumably a mock API server started for the test run).
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Dummy credential passed to every client constructed in these tests.
api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters *client* attaches to a bare ``GET /foo`` request."""
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
37
38
39def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
40 return 0.1
41
42
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the size of the private ``_requests`` list on *client*'s connection pool.

    Relies on httpx/httpcore internals (``_transport._pool._requests``); presumably
    these are the requests currently held by the pool — TODO confirm on upgrades.
    """
    transport = client._client._transport
    # Only the stock httpx transports are known to expose the `_pool` read below.
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
49
50
class TestOpenAI:
    """Behavioural tests for the synchronous ``OpenAI`` client.

    Most tests reuse the class-level ``client``; tests that need different
    constructor options build their own instance. Tests that actually send a
    request are marked with ``@pytest.mark.respx`` and stub the responses
    through the ``respx_mock`` fixture.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` hands back the underlying ``httpx.Response``."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A raw response is returned even when the body is labelled as binary."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a new client; overrides do not leak back into the source."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options that have defaults (``max_retries``, ``timeout``) are overridable per copy."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merge into a copy; ``set_default_headers`` replaces them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy; ``set_default_query`` replaces it."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every constructor parameter (minus a few exclusions) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request must not leak memory per iteration."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        # Store up to 1000 frames per allocation so leak tracebacks are attributable.
        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Always stop tracing, even if building a request raised, so later
            # tests do not keep running with tracemalloc enabled.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} per-iteration memory leak(s) detected; see captured stdout")

    def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the per-request options override it."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A ``timeout`` given to the client constructor applies to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """How the timeout of a user-supplied ``httpx.Client`` interacts with the SDK default."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent and may override SDK headers such as ``X-Stainless-Lang``."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token, and constructing without a key fails."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # With `api_key=None` the client may fall back to an `OPENAI_API_KEY`
        # environment variable; drop it for the duration of the check (patch.dict
        # restores the environment on exit) so the result doesn't depend on the
        # developer's shell.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)
            with pytest.raises(OpenAIError):
                OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)

    def test_default_query_option(self) -> None:
        """``default_query`` is sent and per-request ``params`` override it."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over ``default_headers``."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is sent, merged with ``query``, and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """List values in a multipart body are encoded as repeated ``name[]`` parts."""
        # `client` is a pytest fixture defined outside this file (presumably conftest.py).
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member model that matches the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """``base_url`` assigned after construction is normalised to end with a slash."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A relative path is joined onto a base URL that already ends with ``/``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # No trailing slash here: this is the case the test is named for.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A relative path is joined correctly onto a base URL without a trailing ``/``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to ``base_url``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy must leave the original client open."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Using the client as a context manager yields itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """An invalid payload raises ``APIResponseValidationError`` caused by a pydantic ``ValidationError``."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with a ``stream_cls`` returns a ``Stream`` instance."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Plain text raises in strict mode but is returned as ``str`` when validation is lax."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``retry-after`` values of up to 60s (seconds or HTTP-date) are honoured.

        Other values fall back to a delay that doubles as ``remaining_retries``
        decreases. ``time.time`` is pinned to 1696004797, which is
        Fri, 29 Sep 2023 16:26:37 GMT, so the HTTP-date cases are deterministic.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # Relative tolerance; presumably absorbs jitter applied to the backoff — confirm in _base_client.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries after timeouts leaves no requests tracked by the pool."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=dict(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model="gpt-3.5-turbo",
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries after 500 responses leaves no requests tracked by the pool."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=dict(
                    messages=[
                        {
                            "role": "user",
                            "content": "Say this is a test",
                        }
                    ],
                    model="gpt-3.5-turbo",
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0
741
742
743class TestAsyncOpenAI:
744 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
745
746 @pytest.mark.respx(base_url=base_url)
747 @pytest.mark.asyncio
748 async def test_raw_response(self, respx_mock: MockRouter) -> None:
749 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
750
751 response = await self.client.post("/foo", cast_to=httpx.Response)
752 assert response.status_code == 200
753 assert isinstance(response, httpx.Response)
754 assert response.json() == {"foo": "bar"}
755
756 @pytest.mark.respx(base_url=base_url)
757 @pytest.mark.asyncio
758 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
759 respx_mock.post("/foo").mock(
760 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
761 )
762
763 response = await self.client.post("/foo", cast_to=httpx.Response)
764 assert response.status_code == 200
765 assert isinstance(response, httpx.Response)
766 assert response.json() == {"foo": "bar"}
767
768 def test_copy(self) -> None:
769 copied = self.client.copy()
770 assert id(copied) != id(self.client)
771
772 copied = self.client.copy(api_key="another My API Key")
773 assert copied.api_key == "another My API Key"
774 assert self.client.api_key == "My API Key"
775
776 def test_copy_default_options(self) -> None:
777 # options that have a default are overridden correctly
778 copied = self.client.copy(max_retries=7)
779 assert copied.max_retries == 7
780 assert self.client.max_retries == 2
781
782 copied2 = copied.copy(max_retries=6)
783 assert copied2.max_retries == 6
784 assert copied.max_retries == 7
785
786 # timeout
787 assert isinstance(self.client.timeout, httpx.Timeout)
788 copied = self.client.copy(timeout=None)
789 assert copied.timeout is None
790 assert isinstance(self.client.timeout, httpx.Timeout)
791
792 def test_copy_default_headers(self) -> None:
793 client = AsyncOpenAI(
794 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
795 )
796 assert client.default_headers["X-Foo"] == "bar"
797
798 # does not override the already given value when not specified
799 copied = client.copy()
800 assert copied.default_headers["X-Foo"] == "bar"
801
802 # merges already given headers
803 copied = client.copy(default_headers={"X-Bar": "stainless"})
804 assert copied.default_headers["X-Foo"] == "bar"
805 assert copied.default_headers["X-Bar"] == "stainless"
806
807 # uses new values for any already given headers
808 copied = client.copy(default_headers={"X-Foo": "stainless"})
809 assert copied.default_headers["X-Foo"] == "stainless"
810
811 # set_default_headers
812
813 # completely overrides already set values
814 copied = client.copy(set_default_headers={})
815 assert copied.default_headers.get("X-Foo") is None
816
817 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
818 assert copied.default_headers["X-Bar"] == "Robert"
819
820 with pytest.raises(
821 ValueError,
822 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
823 ):
824 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
825
826 def test_copy_default_query(self) -> None:
827 client = AsyncOpenAI(
828 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
829 )
830 assert _get_params(client)["foo"] == "bar"
831
832 # does not override the already given value when not specified
833 copied = client.copy()
834 assert _get_params(copied)["foo"] == "bar"
835
836 # merges already given params
837 copied = client.copy(default_query={"bar": "stainless"})
838 params = _get_params(copied)
839 assert params["foo"] == "bar"
840 assert params["bar"] == "stainless"
841
842 # uses new values for any already given headers
843 copied = client.copy(default_query={"foo": "stainless"})
844 assert _get_params(copied)["foo"] == "stainless"
845
846 # set_default_query
847
848 # completely overrides already set values
849 copied = client.copy(set_default_query={})
850 assert _get_params(copied) == {}
851
852 copied = client.copy(set_default_query={"bar": "Robert"})
853 assert _get_params(copied)["bar"] == "Robert"
854
855 with pytest.raises(
856 ValueError,
857 # TODO: update
858 match="`default_query` and `set_default_query` arguments are mutually exclusive",
859 ):
860 client.copy(set_default_query={}, default_query={"foo": "Bar"})
861
862 def test_copy_signature(self) -> None:
863 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
864 init_signature = inspect.signature(
865 # mypy doesn't like that we access the `__init__` property.
866 self.client.__init__, # type: ignore[misc]
867 )
868 copy_signature = inspect.signature(self.client.copy)
869 exclude_params = {"transport", "proxies", "_strict_response_validation"}
870
871 for name in init_signature.parameters.keys():
872 if name in exclude_params:
873 continue
874
875 copy_param = copy_signature.parameters.get(name)
876 assert copy_param is not None, f"copy() signature is missing the {name} param"
877
def test_copy_build_request(self) -> None:
    """Check that building a request on a freshly copied client does not leak memory.

    The request is built once outside of tracing to warm up lazily initialised
    state. It is then built ``ITERATIONS`` more times while ``tracemalloc`` is
    active. Any traceback whose block count is non-zero and a multiple of
    ``ITERATIONS`` is reported as a leak, unless one of its frames comes from a
    known, accepted source listed in ``ignored_fragments``.

    Raises:
        AssertionError: if at least one leak is detected (details are printed).
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    ITERATIONS = 10

    tracemalloc.start(1000)
    # Stop tracing even if building a request raises, so a failure here does not
    # leave tracemalloc running (and slowing down) the rest of the test session.
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()
    finally:
        tracemalloc.stop()

    ignored_fragments = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_legacy_response.py",
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def is_leak(diff: tracemalloc.StatisticDiff) -> bool:
        # NOTE(review): `count` is the number of blocks in the *new* snapshot, not
        # the delta (`count_diff`); confirm the per-iteration check is meant to
        # use the absolute count.
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return False
        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return False
        return not any(frame.filename.endswith(ignored_fragments) for frame in diff.traceback)

    leaks = [diff for diff in snapshot_after.compare_to(snapshot_before, "traceback") if is_leak(diff)]
    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} memory leak(s) detected while building requests; see printed tracebacks")
939
async def test_request_timeout(self) -> None:
    """Requests default to DEFAULT_TIMEOUT unless their options carry an explicit timeout."""

    def timeout_for(options: FinalRequestOptions) -> httpx.Timeout:
        built = self.client._build_request(options)
        return httpx.Timeout(**built.extensions["timeout"])  # type: ignore

    assert timeout_for(FinalRequestOptions(method="get", url="/foo")) == DEFAULT_TIMEOUT

    explicit = timeout_for(FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0)))
    assert explicit == httpx.Timeout(100.0)
950
async def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor is applied to the requests it builds."""
    zero_timeout_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    effective = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert effective == httpx.Timeout(0)
959
async def test_http_client_timeout_option(self) -> None:
    """The timeout of a user-supplied httpx client is honoured, except httpx's own default."""
    cases: list[tuple[dict[str, Any], httpx.Timeout]] = [
        # custom timeout given to the httpx client should be used
        ({"timeout": None}, httpx.Timeout(None)),
        # no timeout given to the httpx client should not use the httpx default
        ({}, DEFAULT_TIMEOUT),
        # explicitly passing the default timeout currently results in it being ignored
        ({"timeout": HTTPX_DEFAULT_TIMEOUT}, DEFAULT_TIMEOUT),  # our default
    ]

    for http_kwargs, expected in cases:
        async with httpx.AsyncClient(**http_kwargs) as http_client:
            wrapped = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            built = wrapped._build_request(FinalRequestOptions(method="get", url="/foo"))
            assert httpx.Timeout(**built.extensions["timeout"]) == expected  # type: ignore
990
def test_default_headers_option(self) -> None:
    """`default_headers` are sent with every request and may override SDK-provided headers."""
    get_foo = FinalRequestOptions(method="get", url="/foo")

    additive = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    sent = additive._build_request(get_foo).headers
    assert sent.get("x-foo") == "bar"
    assert sent.get("x-stainless-lang") == "python"

    # a default header can replace one the SDK would otherwise set itself
    overriding = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    sent = overriding._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert sent.get("x-foo") == "stainless"
    assert sent.get("x-stainless-lang") == "my-overriding-header"
1011
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and a missing key is rejected at construction.

    Raises:
        AssertionError: if the Authorization header is wrong or no OpenAIError is raised.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # With `api_key=None` the client falls back to the `OPENAI_API_KEY` environment
    # variable, so hide it; otherwise this check passes or fails depending on the
    # machine running the suite. `patch.dict` restores os.environ on exit.
    with mock.patch.dict(os.environ):
        os.environ.pop("OPENAI_API_KEY", None)
        with pytest.raises(OpenAIError):
            client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2
1020
def test_default_query_option(self) -> None:
    """`default_query` is merged into every request, and per-request params take precedence."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    def query_of(options: FinalRequestOptions) -> dict[str, str]:
        built = client._build_request(options)
        return dict(httpx.URL(built.url).params)

    assert query_of(FinalRequestOptions(method="get", url="/foo")) == {"query_param": "bar"}

    overridden = query_of(
        FinalRequestOptions(
            method="get",
            url="/foo",
            params={"foo": "baz", "query_param": "overriden"},
        )
    )
    assert overridden == {"foo": "baz", "query_param": "overriden"}
1038
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on key clashes."""
    scenarios: list[tuple[dict[str, Any], dict[str, Any]]] = [
        ({"json_data": {"foo": "bar"}, "extra_json": {"baz": False}}, {"foo": "bar", "baz": False}),
        ({"extra_json": {"baz": False}}, {"baz": False}),
        # `extra_json` takes priority over `json_data` when keys clash
        (
            {"json_data": {"foo": "bar", "baz": True}, "extra_json": {"baz": None}},
            {"foo": "bar", "baz": None},
        ),
    ]

    for option_kwargs, expected_body in scenarios:
        built = self.client._build_request(FinalRequestOptions(method="post", url="/foo", **option_kwargs))
        assert json.loads(built.content.decode("utf-8")) == expected_body
1072
def test_request_extra_headers(self) -> None:
    """`extra_headers` are attached to the request and beat `default_headers` on clashes."""
    foo_options = FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Foo": "Foo"}))
    assert self.client._build_request(foo_options).headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_defaults = self.client.with_options(default_headers={"X-Bar": "true"})
    bar_options = FinalRequestOptions(
        method="post",
        url="/foo",
        **make_request_options(extra_headers={"X-Bar": "false"}),
    )
    assert client_with_defaults._build_request(bar_options).headers.get("X-Bar") == "false"
1094
def test_request_extra_query(self) -> None:
    """`extra_query` is merged with `query`, and wins when the same key appears in both."""

    def params_for(**request_kwargs: Any) -> dict[str, str]:
        built = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **make_request_options(**request_kwargs)),
        )
        return dict(built.url.params)

    assert params_for(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert params_for(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert params_for(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1135
def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
    """List values are encoded as repeated `name[]` parts in a multipart/form-data body."""
    options = FinalRequestOptions.construct(
        method="get",
        url="/foo",
        headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
        json_data={"array": ["foo", "bar"]},
        files=[("foo.txt", b"hello world")],
    )
    body_lines = async_client._build_request(options).read().split(b"\r\n")

    delimiter = b"--6b7ba517decee4a450543ea6ae821c82"
    array_part = b'Content-Disposition: form-data; name="array[]"'
    expected = [
        delimiter, array_part, b"", b"foo",
        delimiter, array_part, b"", b"bar",
        delimiter,
        b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
        b"Content-Type: application/octet-stream",
        b"",
        b"hello world",
        delimiter + b"--",
        b"",
    ]
    assert body_lines == expected
1164
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A union `cast_to` resolves to the member model whose fields fit the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    union_target = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=union_target)

    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1178
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union of objects with the same field name using a different type"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_target = cast(Any, Union[Model1, Model2])
    # each payload should be parsed into the model whose `foo` type matches it
    expectations = [
        ({"foo": "bar"}, Model2, "bar"),
        ({"foo": 1}, Model1, 1),
    ]

    for payload, expected_cls, expected_foo in expectations:
        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=payload))

        parsed = await self.client.get("/foo", cast_to=union_target)
        assert isinstance(parsed, expected_cls)
        assert parsed.foo == expected_foo
1200
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    Response that sets Content-Type to something other than application/json but returns json data
    """

    class Model(BaseModel):
        foo: int

    mislabelled = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mislabelled)

    parsed = await self.client.get("/foo", cast_to=Model)

    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1221
def test_base_url_setter(self) -> None:
    """`base_url` can be reassigned after construction and is normalised with a trailing `/`."""
    init_url = "https://example.com/from_init"
    setter_url = "https://example.com/from_setter"

    client = AsyncOpenAI(base_url=init_url, api_key=api_key, _strict_response_validation=True)
    assert client.base_url == init_url + "/"

    client.base_url = setter_url  # type: ignore[assignment]
    assert client.base_url == setter_url + "/"
1231
def test_base_url_env(self) -> None:
    """Without an explicit `base_url`, the client reads it from `OPENAI_BASE_URL`."""
    env_url = "http://localhost:5000/from/env"
    with update_env(OPENAI_BASE_URL=env_url):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == f"{env_url}/"
1236
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A relative request path is appended to a base URL that already ends in `/`."""
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1261
@pytest.mark.parametrize(
    "client",
    [
        # NOTE: these base URLs intentionally omit the trailing `/`; that is the
        # case this test exists to cover.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL without a trailing `/` must still have relative request paths appended to it.

    The last path segment (`path`) must be kept rather than replaced by `foo`.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1286
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used verbatim instead of being joined onto `base_url`."""
    absolute = "https://myapi.com/foo"
    built = client._build_request(FinalRequestOptions(method="post", url=absolute, json_data={"foo": "bar"}))
    assert built.url == absolute
1311
async def test_copied_client_does_not_close_http(self) -> None:
    """Discarding a copy of a client must not close the original client's HTTP connection state."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    # drop the only reference to the copy, then yield to the event loop for a
    # moment so any cleanup triggered by that has a chance to run
    del duplicate
    await asyncio.sleep(0.2)

    assert not original.is_closed()
1323
async def test_client_context_manager(self) -> None:
    """`async with` yields the client itself, keeps it open inside, and closes it on exit."""
    managed = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    async with managed as entered:
        assert entered is managed
        assert not entered.is_closed()
        assert not managed.is_closed()

    assert managed.is_closed()
1331
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that does not fit `cast_to` raises APIResponseValidationError chained to pydantic's error."""

    class Model(BaseModel):
        foo: str

    bad_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=bad_payload))

    with pytest.raises(APIResponseValidationError) as exc_info:
        await self.client.get("/foo", cast_to=Model)

    underlying = exc_info.value.__cause__
    assert isinstance(underlying, ValidationError)
1344
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """A streaming POST returns an instance of the requested `stream_cls`."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    result = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
    assert isinstance(result, AsyncStream)

    # release the underlying response once the type check is done
    await result.response.aclose()
1356
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A non-JSON body is rejected under strict validation and returned as text otherwise."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    strict = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict.get("/foo", cast_to=Model)

    lenient = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    raw = await lenient.get("/foo", cast_to=Model)
    assert isinstance(raw, str)  # type: ignore[unreachable]
1374
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """The `retry-after` header (seconds or HTTP-date) drives the computed retry delay.

    `time.time` is patched to 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT, so
    HTTP-date cases are relative to that instant. Per the cases above, delays of
    0 or less, over 60 seconds, or unparseable values fall back to the 0.5s base,
    which doubles as fewer retries remain.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    calculated = client._calculate_retry_timeout(
        remaining_retries,
        FinalRequestOptions(method="get", url="/foo", max_retries=3),
        httpx.Headers({"retry-after": retry_after}),
    )
    # relative tolerance, presumably to absorb jitter applied to the delay
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1404
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """Exhausting retries on timeouts must leave no pending request in the connection pool."""
    respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

    payload = {
        "messages": [{"role": "user", "content": "Say this is a test"}],
        "model": "gpt-3.5-turbo",
    }
    with pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1427
@mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
@pytest.mark.respx(base_url=base_url)
async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
    """Exhausting retries on 500 responses must leave no pending request in the connection pool."""
    respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

    payload = {
        "messages": [{"role": "user", "content": "Say this is a test"}],
        "model": "gpt-3.5-turbo",
    }
    with pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
        )

    assert _get_open_connections(self.client) == 0
1450