openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a3da0196ffb0ea31e70bf935193caa6f92a0a22b

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1527 lines · mode: code

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._streaming import Stream, AsyncStream
23from openai._exceptions import (
24 OpenAIError,
25 APIStatusError,
26 APITimeoutError,
27 APIConnectionError,
28 APIResponseValidationError,
29)
30from openai._base_client import (
31 DEFAULT_TIMEOUT,
32 HTTPX_DEFAULT_TIMEOUT,
33 BaseClient,
34 make_request_options,
35)
36
37from .utils import update_env
38
# API server the tests send requests to; override it via the TEST_API_BASE_URL env var.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder API key handed to every client constructed in this module.
api_key = "My API Key"
41
42
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query parameters.

    Handy for asserting which ``default_query`` values a client (or a copy of it)
    attaches to outgoing requests.
    """
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
47
48
# The real ``httpx.Response.__init__``, saved so the patched replacement below can delegate to it.
_original_response_init = cast(Any, httpx.Response.__init__)  # type: ignore


def _low_retry_response_init(*args: Any, **kwargs: Any) -> Any:
    """Stand-in for ``httpx.Response.__init__`` that adds a ``retry-after: 0.1`` header.

    Installed with ``mock.patch("httpx.Response.__init__", ...)`` so that retried
    requests only wait a fraction of a second. The ``headers`` keyword argument may
    be absent/``None``, a mapping, an ``httpx.Headers`` or a sequence of raw pairs.
    It is copied into a fresh list of ``(bytes, bytes)`` pairs before the extra header
    is appended, so the object the caller passed in is never mutated.

    Returns whatever the original ``__init__`` returns.
    """
    headers = kwargs.get("headers")
    # Normalising via httpx.Headers(...).raw accepts every header container httpx itself
    # accepts; the previous ``kwargs["headers"].append`` raised KeyError/AttributeError
    # for missing or non-list headers and mutated the caller's list in place.
    raw_headers: list[tuple[bytes, bytes]] = [] if headers is None else list(httpx.Headers(headers).raw)
    raw_headers.append((b"retry-after", b"0.1"))
    kwargs["headers"] = raw_headers

    return _original_response_init(*args, **kwargs)
57
58
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the length of the connection pool's private ``_requests`` list for *client*.

    Tests use this to check that no requests are still held by the pool after an
    error. It reaches into private attributes (``_client._transport._pool._requests``),
    so it only supports the default ``httpx.HTTPTransport`` / ``httpx.AsyncHTTPTransport``.
    The assertion fails loudly for any other transport.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    pool = transport._pool
    return len(pool._requests)
65
66
class TestOpenAI:
    """Tests for the synchronous :class:`OpenAI` client.

    ``client`` is a class-level instance shared by the tests below. Tests that need
    different constructor arguments build their own client instead of mutating it.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` hands back the HTTP response itself."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A raw response is returned as-is even when its content type is not JSON."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a new client; overrides on the copy leave the original untouched."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """``max_retries`` and ``timeout`` can be overridden per copy without affecting the source."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merge into a copy; ``set_default_headers`` replaces them outright."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy; ``set_default_query`` replaces it outright."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every ``__init__`` parameter, bar a few exclusions, is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request repeatedly must not leak memory.

        Allocations whose count grew by a multiple of ``ITERATIONS`` between the two
        snapshots are reported as leaks, except those from a few known, accepted sources.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        tracemalloc.start(1000)
        # Always stop tracing, even if building a request raises: otherwise tracemalloc
        # would stay enabled (and slow everything down) for every test that runs after this.
        try:
            snapshot_before = tracemalloc.take_snapshot()

            ITERATIONS = 10
            for _ in range(ITERATIONS):
                build_request(options)
                gc.collect()

            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks above")

    def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the request options specify a timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A ``timeout`` passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """A custom httpx client's timeout is used, unless it equals httpx's own default."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent and can override the built-in ``X-Stainless-Lang`` header."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a Bearer token; ``api_key=None`` raises ``OpenAIError``."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` params are sent, and per-request ``params`` override them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over ``default_headers``."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is merged with ``query`` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A ``Union`` ``cast_to`` resolves to the member model whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning ``base_url`` after construction normalises it with a trailing slash."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` supplies the base URL when none is passed explicitly."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A relative request path is joined onto a base URL that ends in a slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A relative request path is joined onto the client's base URL path."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used verbatim, ignoring the client's base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_client_del(self) -> None:
        """Calling ``__del__`` closes the client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        client.__del__()

        assert client.is_closed()

    def test_copied_client_does_not_close_http(self) -> None:
        """``__del__`` on a copy closes neither the copy nor the original client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        copied.__del__()

        assert not copied.is_closed()
        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """The client is its own context manager and is closed on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that fails model validation raises ``APIResponseValidationError``."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` without an explicit stream class yields a :class:`Stream`."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=Model, stream=True)
        assert isinstance(response, Stream)

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """A non-JSON body fails strict validation but is returned as ``str`` when lenient."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``retry-after`` (seconds or HTTP date) is honoured only when it is in the range 1-60s.

        Zero, negative, too-large or unparsable values fall back to a delay that doubles
        as ``remaining_retries`` decreases. ``time.time`` is pinned so HTTP dates are
        deterministic, and the tolerance absorbs the client's jitter.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
    def test_retrying_timeout_errors_doesnt_leak(self) -> None:
        """A timeout from ``raise_for_status`` surfaces as ``APITimeoutError`` with no pooled requests left.

        NOTE(review): no respx mock is active here, so this presumably needs a live server at ``base_url``.
        """

        def raise_for_status(response: httpx.Response) -> None:
            raise httpx.TimeoutException("Test timeout error", request=response.request)

        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
            with pytest.raises(APITimeoutError):
                self.client.post(
                    "/chat/completions",
                    body=dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                    cast_to=httpx.Response,
                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
                )

        assert _get_open_connections(self.client) == 0

    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
    def test_retrying_runtime_errors_doesnt_leak(self) -> None:
        """An arbitrary exception from ``raise_for_status`` surfaces as ``APIConnectionError`` with no pooled requests left.

        NOTE(review): no respx mock is active here, so this presumably needs a live server at ``base_url``.
        """

        def raise_for_status(_response: httpx.Response) -> None:
            raise RuntimeError("Test error")

        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
            with pytest.raises(APIConnectionError):
                self.client.post(
                    "/chat/completions",
                    body=dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                    cast_to=httpx.Response,
                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
                )

        assert _get_open_connections(self.client) == 0

    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
    def test_retrying_status_errors_doesnt_leak(self) -> None:
        """A 500 ``HTTPStatusError`` surfaces as ``APIStatusError`` with no pooled requests left.

        NOTE(review): no respx mock is active here, so this presumably needs a live server at ``base_url``.
        """

        def raise_for_status(response: httpx.Response) -> None:
            response.status_code = 500
            raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)

        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
            with pytest.raises(APIStatusError):
                self.client.post(
                    "/chat/completions",
                    body=dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                    cast_to=httpx.Response,
                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
                )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.respx(base_url=base_url)
    def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
        """An ``HTTPStatusError`` raised by an httpx response event hook surfaces as ``APIStatusError``."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        def on_response(response: httpx.Response) -> None:
            raise httpx.HTTPStatusError(
                "Simulating an error inside httpx",
                response=response,
                request=response.request,
            )

        client = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(
                event_hooks={
                    "response": [on_response],
                }
            ),
            max_retries=0,
        )
        with pytest.raises(APIStatusError):
            client.post("/foo", cast_to=httpx.Response)
787
788
789class TestAsyncOpenAI:
790 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
791
792 @pytest.mark.respx(base_url=base_url)
793 @pytest.mark.asyncio
794 async def test_raw_response(self, respx_mock: MockRouter) -> None:
795 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
796
797 response = await self.client.post("/foo", cast_to=httpx.Response)
798 assert response.status_code == 200
799 assert isinstance(response, httpx.Response)
800 assert response.json() == {"foo": "bar"}
801
802 @pytest.mark.respx(base_url=base_url)
803 @pytest.mark.asyncio
804 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
805 respx_mock.post("/foo").mock(
806 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
807 )
808
809 response = await self.client.post("/foo", cast_to=httpx.Response)
810 assert response.status_code == 200
811 assert isinstance(response, httpx.Response)
812 assert response.json() == {"foo": "bar"}
813
814 def test_copy(self) -> None:
815 copied = self.client.copy()
816 assert id(copied) != id(self.client)
817
818 copied = self.client.copy(api_key="another My API Key")
819 assert copied.api_key == "another My API Key"
820 assert self.client.api_key == "My API Key"
821
822 def test_copy_default_options(self) -> None:
823 # options that have a default are overridden correctly
824 copied = self.client.copy(max_retries=7)
825 assert copied.max_retries == 7
826 assert self.client.max_retries == 2
827
828 copied2 = copied.copy(max_retries=6)
829 assert copied2.max_retries == 6
830 assert copied.max_retries == 7
831
832 # timeout
833 assert isinstance(self.client.timeout, httpx.Timeout)
834 copied = self.client.copy(timeout=None)
835 assert copied.timeout is None
836 assert isinstance(self.client.timeout, httpx.Timeout)
837
838 def test_copy_default_headers(self) -> None:
839 client = AsyncOpenAI(
840 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
841 )
842 assert client.default_headers["X-Foo"] == "bar"
843
844 # does not override the already given value when not specified
845 copied = client.copy()
846 assert copied.default_headers["X-Foo"] == "bar"
847
848 # merges already given headers
849 copied = client.copy(default_headers={"X-Bar": "stainless"})
850 assert copied.default_headers["X-Foo"] == "bar"
851 assert copied.default_headers["X-Bar"] == "stainless"
852
853 # uses new values for any already given headers
854 copied = client.copy(default_headers={"X-Foo": "stainless"})
855 assert copied.default_headers["X-Foo"] == "stainless"
856
857 # set_default_headers
858
859 # completely overrides already set values
860 copied = client.copy(set_default_headers={})
861 assert copied.default_headers.get("X-Foo") is None
862
863 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
864 assert copied.default_headers["X-Bar"] == "Robert"
865
866 with pytest.raises(
867 ValueError,
868 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
869 ):
870 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
871
872 def test_copy_default_query(self) -> None:
873 client = AsyncOpenAI(
874 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
875 )
876 assert _get_params(client)["foo"] == "bar"
877
878 # does not override the already given value when not specified
879 copied = client.copy()
880 assert _get_params(copied)["foo"] == "bar"
881
882 # merges already given params
883 copied = client.copy(default_query={"bar": "stainless"})
884 params = _get_params(copied)
885 assert params["foo"] == "bar"
886 assert params["bar"] == "stainless"
887
888 # uses new values for any already given headers
889 copied = client.copy(default_query={"foo": "stainless"})
890 assert _get_params(copied)["foo"] == "stainless"
891
892 # set_default_query
893
894 # completely overrides already set values
895 copied = client.copy(set_default_query={})
896 assert _get_params(copied) == {}
897
898 copied = client.copy(set_default_query={"bar": "Robert"})
899 assert _get_params(copied)["bar"] == "Robert"
900
901 with pytest.raises(
902 ValueError,
903 # TODO: update
904 match="`default_query` and `set_default_query` arguments are mutually exclusive",
905 ):
906 client.copy(set_default_query={}, default_query={"foo": "Bar"})
907
908 def test_copy_signature(self) -> None:
909 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
910 init_signature = inspect.signature(
911 # mypy doesn't like that we access the `__init__` property.
912 self.client.__init__, # type: ignore[misc]
913 )
914 copy_signature = inspect.signature(self.client.copy)
915 exclude_params = {"transport", "proxies", "_strict_response_validation"}
916
917 for name in init_signature.parameters.keys():
918 if name in exclude_params:
919 continue
920
921 copy_param = copy_signature.parameters.get(name)
922 assert copy_param is not None, f"copy() signature is missing the {name} param"
923
def test_copy_build_request(self) -> None:
    """`copy()` followed by `_build_request()` must not leak memory.

    The copy/build cycle is warmed up once, then run ``ITERATIONS`` times between two
    tracemalloc snapshots. A persisting allocation whose count is a whole multiple of
    ``ITERATIONS`` (i.e. roughly one per cycle) is reported as a leak, unless one of its
    frames comes from a known, accepted source file.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    ITERATIONS = 10

    tracemalloc.start(1000)
    # Stop tracing even if a build fails, so a failure here does not leave tracemalloc
    # running (slowing down and skewing) every test that runs afterwards.
    try:
        snapshot_before = tracemalloc.take_snapshot()

        for _ in range(ITERATIONS):
            build_request(options)
        gc.collect()

        snapshot_after = tracemalloc.take_snapshot()
    finally:
        tracemalloc.stop()

    # Path suffixes (written with "/") of allocation sites whose growth we accept.
    ignored_suffixes = (
        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
        #
        # removing the decorator fixes the leak for reasons we don't understand.
        "openai/_response.py",
        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
        "openai/_compat.py",
        # Standard library leaks we don't care about.
        "/logging/__init__.py",
    )

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        for frame in diff.traceback:
            # tracemalloc reports native paths; normalise the separator so the
            # "/"-based suffixes above also match on Windows.
            if frame.filename.replace(os.sep, "/").endswith(ignored_suffixes):
                return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)
    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks above")
984
async def test_request_timeout(self) -> None:
    """Requests use `DEFAULT_TIMEOUT` unless the per-request `timeout` option overrides it."""

    def timeout_of(opts: FinalRequestOptions) -> httpx.Timeout:
        built = self.client._build_request(opts)
        return httpx.Timeout(**built.extensions["timeout"])  # type: ignore

    assert timeout_of(FinalRequestOptions(method="get", url="/foo")) == DEFAULT_TIMEOUT

    overridden = FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    assert timeout_of(overridden) == httpx.Timeout(100.0)
995
async def test_client_timeout_option(self) -> None:
    """A `timeout` passed to the client constructor is applied to the requests it builds."""
    zero_timeout = httpx.Timeout(0)
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=zero_timeout
    )

    extensions = client._build_request(FinalRequestOptions(method="get", url="/foo")).extensions
    assert httpx.Timeout(**extensions["timeout"]) == zero_timeout  # type: ignore
1004
async def test_http_client_timeout_option(self) -> None:
    """Check how a timeout configured on a custom `httpx.AsyncClient` affects built requests."""
    cases: list[tuple[dict[str, Any], httpx.Timeout]] = [
        # custom timeout given to the httpx client should be used
        ({"timeout": None}, httpx.Timeout(None)),
        # no timeout given to the httpx client should not use the httpx default
        ({}, DEFAULT_TIMEOUT),
        # explicitly passing the default timeout currently results in it being ignored
        # in favour of our own default
        ({"timeout": HTTPX_DEFAULT_TIMEOUT}, DEFAULT_TIMEOUT),
    ]

    for http_client_kwargs, expected in cases:
        async with httpx.AsyncClient(**http_client_kwargs) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            assert httpx.Timeout(**request.extensions["timeout"]) == expected  # type: ignore
1035
def test_default_headers_option(self) -> None:
    """`default_headers` are sent with requests and may override the SDK's own headers."""
    plain = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    headers = plain._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert headers.get("x-foo") == "bar"
    assert headers.get("x-stainless-lang") == "python"

    # A user-supplied value wins over the built-in `X-Stainless-Lang` header.
    overriding = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    headers = overriding._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert headers.get("x-foo") == "stainless"
    assert headers.get("x-stainless-lang") == "my-overriding-header"
1056
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token; constructing with `api_key=None` raises `OpenAIError`."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    authorization = client._build_request(FinalRequestOptions(method="get", url="/foo")).headers.get(
        "Authorization"
    )
    assert authorization == f"Bearer {api_key}"

    with pytest.raises(OpenAIError):
        AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1065
def test_default_query_option(self) -> None:
    """`default_query` is applied to requests, and per-request `params` override it."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    def params_for(opts: FinalRequestOptions) -> dict[str, str]:
        return dict(httpx.URL(client._build_request(opts).url).params)

    assert params_for(FinalRequestOptions(method="get", url="/foo")) == {"query_param": "bar"}

    overriding = FinalRequestOptions(
        method="get",
        url="/foo",
        params={"foo": "baz", "query_param": "overriden"},
    )
    assert params_for(overriding) == {"foo": "baz", "query_param": "overriden"}
1083
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` on key clashes."""

    def body_of(**option_fields: Any) -> Any:
        built = self.client._build_request(FinalRequestOptions(method="post", url="/foo", **option_fields))
        return json.loads(built.content.decode("utf-8"))

    assert body_of(json_data={"foo": "bar"}, extra_json={"baz": False}) == {"foo": "bar", "baz": False}
    assert body_of(extra_json={"baz": False}) == {"baz": False}
    # `extra_json` takes priority over `json_data` when keys clash
    assert body_of(json_data={"foo": "bar", "baz": True}, extra_json={"baz": None}) == {"foo": "bar", "baz": None}
1117
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent and take priority over the client's `default_headers`."""
    plain = self.client._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Foo": "Foo"}))
    )
    assert plain.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    with_defaults = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing = with_defaults._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Bar": "false"}))
    )
    assert clashing.headers.get("X-Bar") == "false"
1139
def test_request_extra_query(self) -> None:
    """`extra_query` is added to the URL, merged with `query`, and wins on key clashes."""

    def params_for(**request_options: Any) -> dict[str, str]:
        built = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **make_request_options(**request_options))
        )
        return dict(built.url.params)

    assert params_for(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}
    # if both `query` and `extra_query` are given, they are merged
    assert params_for(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}
    # `extra_query` takes priority over `query` when keys clash
    assert params_for(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
1180
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A union `cast_to` resolves to the member whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
1194
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union members sharing a field name but not its type are chosen by the payload value's type."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])
    # Each payload should parse into the model whose `foo` annotation it satisfies.
    for payload_value, expected_model in (("bar", Model2), (1, Model1)):
        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": payload_value}))

        parsed = await self.client.get("/foo", cast_to=union_type)
        assert isinstance(parsed, expected_model)
        assert parsed.foo == payload_value
1216
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """A JSON body is still parsed when the response's Content-Type is not `application/json`."""

    class Model(BaseModel):
        foo: int

    mocked = httpx.Response(
        200,
        content=json.dumps({"foo": 2}),
        headers={"Content-Type": "application/text"},
    )
    respx_mock.get("/foo").mock(return_value=mocked)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1237
def test_base_url_setter(self) -> None:
    """`base_url` gains a trailing slash both when given to `__init__` and when assigned later."""
    client = AsyncOpenAI(
        base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
    )
    assert client.base_url == "https://example.com/from_init/"

    client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert client.base_url == "https://example.com/from_setter/"
1247
def test_base_url_env(self) -> None:
    """Without an explicit `base_url`, the client reads `OPENAI_BASE_URL` from the environment."""
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == "http://localhost:5000/from/env/"
1252
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A relative request path is appended to a `base_url` that already ends in `/`."""
    opts = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    assert client._build_request(opts).url == "http://localhost:5000/custom/path/foo"
1277
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A `base_url` given *without* a trailing `/` still has relative paths appended under it.

    The parametrised base URLs intentionally omit the trailing slash. If they kept it, this
    test would only duplicate `test_base_url_trailing_slash` and never exercise the
    normalisation path.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1302
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used verbatim instead of being joined onto `base_url`."""
    opts = FinalRequestOptions(method="post", url="https://myapi.com/foo", json_data={"foo": "bar"})
    assert client._build_request(opts).url == "https://myapi.com/foo"
1327
async def test_client_del(self) -> None:
    """Calling `__del__` on a freshly built client leads to it reporting itself closed."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not client.is_closed()

    client.__del__()
    # Give any asynchronously scheduled close a moment to complete before checking.
    await asyncio.sleep(0.2)
    assert client.is_closed()
1336
async def test_copied_client_does_not_close_http(self) -> None:
    """Calling `__del__` on a copy must leave both the copy and the original open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    duplicate.__del__()
    # Give any asynchronously scheduled close a moment to run before checking.
    await asyncio.sleep(0.2)
    assert not duplicate.is_closed()
    assert not original.is_closed()
1349
async def test_client_context_manager(self) -> None:
    """`async with` yields the same client object and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()
    assert client.is_closed()
1357
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload not matching `cast_to` raises `APIResponseValidationError` caused by pydantic's error."""

    class Model(BaseModel):
        foo: str

    bad_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=bad_payload))

    with pytest.raises(APIResponseValidationError) as excinfo:
        await self.client.get("/foo", cast_to=Model)

    assert isinstance(excinfo.value.__cause__, ValidationError)
1370
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """`stream=True` without an explicit stream class yields an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    mocked = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=mocked)

    stream = await self.client.post("/foo", cast_to=Model, stream=True)
    assert isinstance(stream, AsyncStream)
1381
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """Non-JSON text fails under strict validation and is returned as a raw `str` otherwise."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    raw = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(raw, str)  # type: ignore[unreachable]
1399
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        (3, "20", 20),
        (3, "0", 0.5),
        (3, "-10", 0.5),
        (3, "60", 60),
        (3, "61", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:57 GMT", 20),
        (3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "Fri, 29 Sep 2023 16:27:37 GMT", 60),
        (3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5),
        (3, "99999999999999999999999999999999999", 0.5),
        (3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5),
        (3, "", 0.5),
        (2, "", 0.5 * 2.0),
        (1, "", 0.5 * 4.0),
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """`Retry-After` (seconds or an HTTP date) is honoured only within a bounded window.

    Per the table above, values in (0, 60] seconds are used as given. Zero, negative,
    too large, or unparsable values fall back to a default backoff, which doubles as
    `remaining_retries` decreases. `time.time` is frozen at Fri, 29 Sep 2023 16:26:37 GMT,
    so the date-based cases are deterministic.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    calculated = client._calculate_retry_timeout(
        remaining_retries,
        FinalRequestOptions(method="get", url="/foo", max_retries=3),
        httpx.Headers({"retry-after": retry_after}),
    )
    # NOTE(review): the loose relative tolerance presumably absorbs random jitter in the
    # fallback backoff; confirm against `_calculate_retry_timeout`.
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1429
@mock.patch("httpx.Response.__init__", _low_retry_response_init)
async def test_retrying_timeout_errors_doesnt_leak(self) -> None:
    """With `raise_for_status` patched to time out, `post` raises `APITimeoutError` and no connections stay open.

    The patched `httpx.Response.__init__` injects `retry-after: 0.1` so retries are fast.
    """

    def raise_for_status(response: httpx.Response) -> None:
        raise httpx.TimeoutException("Test timeout error", request=response.request)

    payload = dict(
        messages=[{"role": "user", "content": "Say this is a test"}],
        model="gpt-3.5-turbo",
    )
    with mock.patch("httpx.Response.raise_for_status", raise_for_status), pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
        )

    assert _get_open_connections(self.client) == 0
1453
@mock.patch("httpx.Response.__init__", _low_retry_response_init)
async def test_retrying_runtime_errors_doesnt_leak(self) -> None:
    """With `raise_for_status` patched to raise `RuntimeError`, `post` raises `APIConnectionError` and no connections stay open.

    The patched `httpx.Response.__init__` injects `retry-after: 0.1` so retries are fast.
    """

    def raise_for_status(_response: httpx.Response) -> None:
        raise RuntimeError("Test error")

    payload = dict(
        messages=[{"role": "user", "content": "Say this is a test"}],
        model="gpt-3.5-turbo",
    )
    with mock.patch("httpx.Response.raise_for_status", raise_for_status), pytest.raises(APIConnectionError):
        await self.client.post(
            "/chat/completions",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
        )

    assert _get_open_connections(self.client) == 0
1477
@mock.patch("httpx.Response.__init__", _low_retry_response_init)
async def test_retrying_status_errors_doesnt_leak(self) -> None:
    """With `raise_for_status` patched to report a 500, `post` raises `APIStatusError` and no connections stay open.

    The patched `httpx.Response.__init__` injects `retry-after: 0.1` so retries are fast.
    """

    def raise_for_status(response: httpx.Response) -> None:
        response.status_code = 500
        raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)

    payload = dict(
        messages=[{"role": "user", "content": "Say this is a test"}],
        model="gpt-3.5-turbo",
    )
    with mock.patch("httpx.Response.raise_for_status", raise_for_status), pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=payload,
            cast_to=httpx.Response,
            options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
        )

    assert _get_open_connections(self.client) == 0
1502
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
    """An `httpx.HTTPStatusError` raised from an httpx response hook surfaces as `APIStatusError`.

    Retries are disabled (`max_retries=0`), so the error comes from the first attempt.
    """
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    # httpx requires event hooks registered on an `AsyncClient` to be async functions.
    async def on_response(response: httpx.Response) -> None:
        raise httpx.HTTPStatusError(
            "Simulating an error inside httpx",
            response=response,
            request=response.request,
        )

    # Close the custom httpx client afterwards so its connection pool is not leaked
    # past the end of the test.
    async with httpx.AsyncClient(event_hooks={"response": [on_response]}) as http_client:
        client = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            http_client=http_client,
            max_retries=0,
        )
        with pytest.raises(APIStatusError):
            await client.post("/foo", cast_to=httpx.Response)
1528