openai/openai-python

Public

mirrored from https://github.com/openai/openai-python — Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
a7ebc26085fa0bb5ceeadc8aebb72dd2dfc9d5a6

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1508 lines · mode: code

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import gc
6import os
7import json
8import asyncio
9import inspect
10import tracemalloc
11from typing import Any, Union, cast
12from unittest import mock
13
14import httpx
15import pytest
16from respx import MockRouter
17from pydantic import ValidationError
18
19from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
20from openai._client import OpenAI, AsyncOpenAI
21from openai._models import BaseModel, FinalRequestOptions
22from openai._streaming import Stream, AsyncStream
23from openai._exceptions import (
24 OpenAIError,
25 APIStatusError,
26 APITimeoutError,
27 APIConnectionError,
28 APIResponseValidationError,
29)
30from openai._base_client import (
31 DEFAULT_TIMEOUT,
32 HTTPX_DEFAULT_TIMEOUT,
33 BaseClient,
34 make_request_options,
35)
36
37from .utils import update_env
38
# Server the request-level tests point at; can be redirected with TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Dummy credential shared by every client built in this module.
api_key = "My API Key"
41
42
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query parameters as a dict."""
    options = FinalRequestOptions(method="get", url="/foo")
    built_url = httpx.URL(client._build_request(options).url)
    return dict(built_url.params)
47
48
# Keep a handle on the genuine ``httpx.Response.__init__`` before any test patches it,
# so patched replacements (e.g. ``_low_retry_response_init``) can delegate to it.
_original_response_init = cast(Any, getattr(httpx.Response, "__init__"))  # type: ignore
50
51
def _low_retry_response_init(*args: Any, **kwargs: Any) -> Any:
    """Drop-in replacement for ``httpx.Response.__init__`` that forces ``retry-after: 0.1``.

    The retry-leak tests patch this in so every retried request waits ~0.1s instead of the
    client's default backoff (0.5s and up when no ``retry-after`` header is present).

    ``headers`` may arrive as a list of byte pairs, a mapping, an ``httpx.Headers`` or not at
    all; it is normalised into a fresh ``httpx.Headers`` so the caller's object is never
    mutated and a missing argument no longer raises ``KeyError``.
    """
    headers = httpx.Headers(kwargs.get("headers"))
    # Item assignment replaces any existing value. Appending would produce a combined
    # "x, 0.1" value that the retry parser cannot read.
    headers["retry-after"] = "0.1"
    kwargs["headers"] = headers

    return _original_response_init(*args, **kwargs)
57
58
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many requests are currently tracked by *client*'s connection pool.

    Reaches into private attributes (``_client._transport``, ``_pool``, ``_requests``), so
    only the stock httpx HTTP transports are supported; anything else fails the assert.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
65
66
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client: copying, request building, retries, errors."""

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request repeatedly must not leak memory.

        One warm-up build runs before tracing. Then ``ITERATIONS`` builds run under
        ``tracemalloc``. Surviving allocations whose count is a multiple of ``ITERATIONS``,
        and that do not come from an allow-listed file, are reported as leaks.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        # 1000 frames per traceback so leak origins can be traced back to their caller.
        tracemalloc.start(1000)
        # Always stop tracing, even if a build raises. Otherwise every later test keeps
        # running under tracemalloc's overhead.
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)
                gc.collect()

            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks")

    def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token, and constructing without any key fails."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # With api_key=None the client presumably falls back to the OPENAI_API_KEY environment
        # variable, just as base_url falls back to OPENAI_BASE_URL in test_base_url_env.
        # Remove that variable for this check, otherwise the outcome depends on the
        # developer's shell. patch.dict restores os.environ on exit.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)
            with pytest.raises(OpenAIError):
                OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # No trailing slash here: this is the case under test (the previous version
            # passed ".../custom/path/" and so duplicated test_base_url_trailing_slash).
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        # The client appends the missing slash (see test_base_url_setter), so "/foo" is
        # joined beneath ".../custom/path" rather than replacing its last segment.
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=Model, stream=True)
        assert isinstance(response, Stream)

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    # time.time() is frozen at 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, so the
    # HTTP-date cases above are interpreted relative to that instant.
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
    def test_retrying_timeout_errors_doesnt_leak(self) -> None:
        def raise_for_status(response: httpx.Response) -> None:
            raise httpx.TimeoutException("Test timeout error", request=response.request)

        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
            with pytest.raises(APITimeoutError):
                self.client.post(
                    "/chat/completions",
                    body=dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                    cast_to=httpx.Response,
                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
                )

        assert _get_open_connections(self.client) == 0

    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
    def test_retrying_runtime_errors_doesnt_leak(self) -> None:
        def raise_for_status(_response: httpx.Response) -> None:
            raise RuntimeError("Test error")

        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
            with pytest.raises(APIConnectionError):
                self.client.post(
                    "/chat/completions",
                    body=dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                    cast_to=httpx.Response,
                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
                )

        assert _get_open_connections(self.client) == 0

    @mock.patch("httpx.Response.__init__", _low_retry_response_init)
    def test_retrying_status_errors_doesnt_leak(self) -> None:
        def raise_for_status(response: httpx.Response) -> None:
            response.status_code = 500
            raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)

        with mock.patch("httpx.Response.raise_for_status", raise_for_status):
            with pytest.raises(APIStatusError):
                self.client.post(
                    "/chat/completions",
                    body=dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-3.5-turbo",
                    ),
                    cast_to=httpx.Response,
                    options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
                )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.respx(base_url=base_url)
    def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        def on_response(response: httpx.Response) -> None:
            raise httpx.HTTPStatusError(
                "Simulating an error inside httpx",
                response=response,
                request=response.request,
            )

        client = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.Client(
                event_hooks={
                    "response": [on_response],
                }
            ),
            max_retries=0,
        )
        with pytest.raises(APIStatusError):
            client.post("/foo", cast_to=httpx.Response)
778
779
780class TestAsyncOpenAI:
781 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
782
783 @pytest.mark.respx(base_url=base_url)
784 @pytest.mark.asyncio
785 async def test_raw_response(self, respx_mock: MockRouter) -> None:
786 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
787
788 response = await self.client.post("/foo", cast_to=httpx.Response)
789 assert response.status_code == 200
790 assert isinstance(response, httpx.Response)
791 assert response.json() == {"foo": "bar"}
792
793 @pytest.mark.respx(base_url=base_url)
794 @pytest.mark.asyncio
795 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
796 respx_mock.post("/foo").mock(
797 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
798 )
799
800 response = await self.client.post("/foo", cast_to=httpx.Response)
801 assert response.status_code == 200
802 assert isinstance(response, httpx.Response)
803 assert response.json() == {"foo": "bar"}
804
805 def test_copy(self) -> None:
806 copied = self.client.copy()
807 assert id(copied) != id(self.client)
808
809 copied = self.client.copy(api_key="another My API Key")
810 assert copied.api_key == "another My API Key"
811 assert self.client.api_key == "My API Key"
812
813 def test_copy_default_options(self) -> None:
814 # options that have a default are overridden correctly
815 copied = self.client.copy(max_retries=7)
816 assert copied.max_retries == 7
817 assert self.client.max_retries == 2
818
819 copied2 = copied.copy(max_retries=6)
820 assert copied2.max_retries == 6
821 assert copied.max_retries == 7
822
823 # timeout
824 assert isinstance(self.client.timeout, httpx.Timeout)
825 copied = self.client.copy(timeout=None)
826 assert copied.timeout is None
827 assert isinstance(self.client.timeout, httpx.Timeout)
828
829 def test_copy_default_headers(self) -> None:
830 client = AsyncOpenAI(
831 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
832 )
833 assert client.default_headers["X-Foo"] == "bar"
834
835 # does not override the already given value when not specified
836 copied = client.copy()
837 assert copied.default_headers["X-Foo"] == "bar"
838
839 # merges already given headers
840 copied = client.copy(default_headers={"X-Bar": "stainless"})
841 assert copied.default_headers["X-Foo"] == "bar"
842 assert copied.default_headers["X-Bar"] == "stainless"
843
844 # uses new values for any already given headers
845 copied = client.copy(default_headers={"X-Foo": "stainless"})
846 assert copied.default_headers["X-Foo"] == "stainless"
847
848 # set_default_headers
849
850 # completely overrides already set values
851 copied = client.copy(set_default_headers={})
852 assert copied.default_headers.get("X-Foo") is None
853
854 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
855 assert copied.default_headers["X-Bar"] == "Robert"
856
857 with pytest.raises(
858 ValueError,
859 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
860 ):
861 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
862
863 def test_copy_default_query(self) -> None:
864 client = AsyncOpenAI(
865 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
866 )
867 assert _get_params(client)["foo"] == "bar"
868
869 # does not override the already given value when not specified
870 copied = client.copy()
871 assert _get_params(copied)["foo"] == "bar"
872
873 # merges already given params
874 copied = client.copy(default_query={"bar": "stainless"})
875 params = _get_params(copied)
876 assert params["foo"] == "bar"
877 assert params["bar"] == "stainless"
878
879 # uses new values for any already given headers
880 copied = client.copy(default_query={"foo": "stainless"})
881 assert _get_params(copied)["foo"] == "stainless"
882
883 # set_default_query
884
885 # completely overrides already set values
886 copied = client.copy(set_default_query={})
887 assert _get_params(copied) == {}
888
889 copied = client.copy(set_default_query={"bar": "Robert"})
890 assert _get_params(copied)["bar"] == "Robert"
891
892 with pytest.raises(
893 ValueError,
894 # TODO: update
895 match="`default_query` and `set_default_query` arguments are mutually exclusive",
896 ):
897 client.copy(set_default_query={}, default_query={"foo": "Bar"})
898
899 def test_copy_signature(self) -> None:
900 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
901 init_signature = inspect.signature(
902 # mypy doesn't like that we access the `__init__` property.
903 self.client.__init__, # type: ignore[misc]
904 )
905 copy_signature = inspect.signature(self.client.copy)
906 exclude_params = {"transport", "proxies", "_strict_response_validation"}
907
908 for name in init_signature.parameters.keys():
909 if name in exclude_params:
910 continue
911
912 copy_param = copy_signature.parameters.get(name)
913 assert copy_param is not None, f"copy() signature is missing the {name} param"
914
def test_copy_build_request(self) -> None:
    """Check that copying the client and building a request does not leak memory.

    Allocations are traced across ``ITERATIONS`` copy + build cycles. An allocation
    site is reported when its block count in the later snapshot is a non-zero
    multiple of ``ITERATIONS``. Frames from the known sources listed in
    ``add_leak`` are ignored.
    """
    options = FinalRequestOptions(method="get", url="/foo")

    def build_request(options: FinalRequestOptions) -> None:
        client = self.client.copy()
        client._build_request(options)

    # ensure that the machinery is warmed up before tracing starts.
    build_request(options)
    gc.collect()

    tracemalloc.start(1000)
    try:
        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)
            gc.collect()

        snapshot_after = tracemalloc.take_snapshot()
    finally:
        # Stop tracing even if building a request raised. Otherwise tracemalloc
        # stays active (with its overhead) for every test that runs afterwards.
        tracemalloc.stop()

    def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
        # NOTE: `diff.count` is the block count in the later snapshot. Tracing starts
        # right before `snapshot_before`, so it roughly reflects growth during the loop.
        if diff.count == 0:
            # Avoid false positives by considering only leaks (i.e. allocations that persist).
            return

        if diff.count % ITERATIONS != 0:
            # Avoid false positives by considering only leaks that appear per iteration.
            return

        for frame in diff.traceback:
            if any(
                frame.filename.endswith(fragment)
                for fragment in [
                    # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                    #
                    # removing the decorator fixes the leak for reasons we don't understand.
                    "openai/_response.py",
                    # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                    "openai/_compat.py",
                    # Standard library leaks we don't care about.
                    "/logging/__init__.py",
                ]
            ):
                return

        leaks.append(diff)

    leaks: list[tracemalloc.StatisticDiff] = []
    for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
        add_leak(leaks, diff)
    if leaks:
        for leak in leaks:
            print("MEMORY LEAK:", leak)
            for frame in leak.traceback:
                print(frame)
        raise AssertionError(
            f"found {len(leaks)} allocation site(s) growing on every iteration; see printed tracebacks"
        )
975
async def test_request_timeout(self) -> None:
    """Requests use DEFAULT_TIMEOUT unless a per-request timeout is supplied."""
    default_request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert httpx.Timeout(**default_request.extensions["timeout"]) == DEFAULT_TIMEOUT  # type: ignore

    overridden_request = self.client._build_request(
        FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
    )
    overridden_timeout = httpx.Timeout(**overridden_request.extensions["timeout"])  # type: ignore
    assert overridden_timeout == httpx.Timeout(100.0)
986
async def test_client_timeout_option(self) -> None:
    """A timeout passed to the client constructor is applied to the requests it builds."""
    zero_timeout_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        timeout=httpx.Timeout(0),
    )

    built = zero_timeout_client._build_request(FinalRequestOptions(method="get", url="/foo"))
    applied = httpx.Timeout(**built.extensions["timeout"])  # type: ignore
    assert applied == httpx.Timeout(0)
995
async def test_http_client_timeout_option(self) -> None:
    """Check how the timeout of a user-supplied ``httpx.AsyncClient`` interacts with the SDK default."""

    def request_timeout(http_client: httpx.AsyncClient) -> httpx.Timeout:
        # build a request through a client wrapping `http_client` and read back its timeout
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        return httpx.Timeout(**request.extensions["timeout"])  # type: ignore

    # custom timeout given to the httpx client should be used
    async with httpx.AsyncClient(timeout=None) as http_client:
        assert request_timeout(http_client) == httpx.Timeout(None)

    # no timeout given to the httpx client should not use the httpx default
    async with httpx.AsyncClient() as http_client:
        assert request_timeout(http_client) == DEFAULT_TIMEOUT

    # explicitly passing the default timeout currently results in it being ignored
    async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
        assert request_timeout(http_client) == DEFAULT_TIMEOUT  # our default
1026
def test_default_headers_option(self) -> None:
    """`default_headers` are sent on every request and may override SDK-provided headers."""
    foo_client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
    )
    headers = foo_client._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert headers.get("x-foo") == "bar"
    assert headers.get("x-stainless-lang") == "python"

    overriding_client = AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        _strict_response_validation=True,
        default_headers={
            "X-Foo": "stainless",
            "X-Stainless-Lang": "my-overriding-header",
        },
    )
    headers = overriding_client._build_request(FinalRequestOptions(method="get", url="/foo")).headers
    assert headers.get("x-foo") == "stainless"
    assert headers.get("x-stainless-lang") == "my-overriding-header"
1047
def test_validate_headers(self) -> None:
    """The API key is sent as a bearer token, and a missing key is rejected at construction."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert request.headers.get("Authorization") == f"Bearer {api_key}"

    # With `api_key=None` the client falls back to the OPENAI_API_KEY environment
    # variable. Hide it for this check; otherwise construction succeeds (and the
    # test fails) on any machine where the key is set. `patch.dict` restores
    # the original environment on exit.
    with mock.patch.dict(os.environ):
        os.environ.pop("OPENAI_API_KEY", None)
        with pytest.raises(OpenAIError):
            client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2
1056
def test_default_query_option(self) -> None:
    """`default_query` params are added to every request; per-request params win on a clash."""
    client = AsyncOpenAI(
        base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
    )

    plain = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    assert dict(httpx.URL(plain.url).params) == {"query_param": "bar"}

    with_params = client._build_request(
        FinalRequestOptions(
            method="get",
            url="/foo",
            params={"foo": "baz", "query_param": "overriden"},
        )
    )
    assert dict(httpx.URL(with_params.url).params) == {"foo": "baz", "query_param": "overriden"}
1074
def test_request_extra_json(self) -> None:
    """`extra_json` is merged into the JSON body and wins over `json_data` when keys clash."""
    # (extra FinalRequestOptions kwargs, expected decoded request body)
    cases: list[tuple[dict[str, Any], dict[str, Any]]] = [
        ({"json_data": {"foo": "bar"}, "extra_json": {"baz": False}}, {"foo": "bar", "baz": False}),
        ({"extra_json": {"baz": False}}, {"baz": False}),
        # `extra_json` takes priority over `json_data` when keys clash
        (
            {"json_data": {"foo": "bar", "baz": True}, "extra_json": {"baz": None}},
            {"foo": "bar", "baz": None},
        ),
    ]

    for option_kwargs, expected_body in cases:
        request = self.client._build_request(
            FinalRequestOptions(method="post", url="/foo", **option_kwargs),
        )
        assert json.loads(request.content.decode("utf-8")) == expected_body
1108
def test_request_extra_headers(self) -> None:
    """`extra_headers` are sent, and override `default_headers` with the same name."""
    plain = self.client._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Foo": "Foo"})),
    )
    assert plain.headers.get("X-Foo") == "Foo"

    # `extra_headers` takes priority over `default_headers` when keys clash
    client_with_defaults = self.client.with_options(default_headers={"X-Bar": "true"})
    clashing = client_with_defaults._build_request(
        FinalRequestOptions(method="post", url="/foo", **make_request_options(extra_headers={"X-Bar": "false"})),
    )
    assert clashing.headers.get("X-Bar") == "false"
1130
def test_request_extra_query(self) -> None:
    """`extra_query` is merged with `query` and takes priority when keys clash."""
    # (make_request_options kwargs, expected query params)
    scenarios: list[tuple[dict[str, Any], dict[str, str]]] = [
        ({"extra_query": {"my_query_param": "Foo"}}, {"my_query_param": "Foo"}),
        # if both `query` and `extra_query` are given, they are merged
        ({"query": {"bar": "1"}, "extra_query": {"foo": "2"}}, {"bar": "1", "foo": "2"}),
        # `extra_query` takes priority over `query` when keys clash
        ({"query": {"foo": "1"}, "extra_query": {"foo": "2"}}, {"foo": "2"}),
    ]

    for request_option_kwargs, expected_params in scenarios:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(**request_option_kwargs),
            ),
        )
        assert dict(request.url.params) == expected_params
1171
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A union `cast_to` resolves to the member model whose fields match the body."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    result = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(result, Model2)
    assert result.foo == "bar"
1185
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union of objects with the same field name using a different type"""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])

    # (mocked response body, model expected to be chosen, expected `foo` value)
    for body, expected_model, expected_foo in (
        ({"foo": "bar"}, Model2, "bar"),
        ({"foo": 1}, Model1, 1),
    ):
        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=body))

        response = await self.client.get("/foo", cast_to=union_type)
        assert isinstance(response, expected_model)
        assert response.foo == expected_foo
1207
@pytest.mark.respx(base_url=base_url)
async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
    """
    Response that sets Content-Type to something other than application/json but returns json data
    """

    class Model(BaseModel):
        foo: int

    payload = json.dumps({"foo": 2})
    mocked = httpx.Response(200, content=payload, headers={"Content-Type": "application/text"})
    respx_mock.get("/foo").mock(return_value=mocked)

    parsed = await self.client.get("/foo", cast_to=Model)
    assert isinstance(parsed, Model)
    assert parsed.foo == 2
1228
def test_base_url_setter(self) -> None:
    """`base_url` set via the constructor or via the setter ends up with a trailing "/"."""
    api_client = AsyncOpenAI(
        base_url="https://example.com/from_init",
        api_key=api_key,
        _strict_response_validation=True,
    )
    assert api_client.base_url == "https://example.com/from_init/"

    # reassigning the property applies the same trailing-slash normalization
    api_client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]
    assert api_client.base_url == "https://example.com/from_setter/"
1238
def test_base_url_env(self) -> None:
    """Without an explicit `base_url`, the client takes it from OPENAI_BASE_URL."""
    env_base_url = "http://localhost:5000/from/env"
    with update_env(OPENAI_BASE_URL=env_base_url):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == "http://localhost:5000/from/env/"
1243
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL ending in "/" is joined with a relative path without doubling the slash."""
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
1268
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL *without* a trailing "/" still gets the request path appended as a sub-path.

    The parametrized clients deliberately omit the trailing slash. Previously they
    used ".../custom/path/", which only duplicated the trailing-slash test and left
    this case untested.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
1293
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is rather than being joined onto `base_url`."""
    absolute_options = FinalRequestOptions(method="post", url="https://myapi.com/foo", json_data={"foo": "bar"})
    built = client._build_request(absolute_options)
    assert built.url == "https://myapi.com/foo"
1318
async def test_copied_client_does_not_close_http(self) -> None:
    """Dropping a `.copy()` of a client must leave the original client open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    # drop the only reference to the copy, then yield to the event loop so that any
    # cleanup triggered by the copy going away has a chance to run before re-checking
    del duplicate
    await asyncio.sleep(0.2)

    assert not original.is_closed()
1330
async def test_client_context_manager(self) -> None:
    """`async with` yields the client itself, keeps it open inside the block and closes it on exit."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with client as entered:
        assert entered is client
        assert not entered.is_closed()
        assert not client.is_closed()
    assert client.is_closed()
1338
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A body that does not fit `cast_to` raises APIResponseValidationError chained from pydantic."""

    class Model(BaseModel):
        foo: str

    # `foo` should be a string, but the mocked server returns an object
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

    with pytest.raises(APIResponseValidationError) as exc_info:
        await self.client.get("/foo", cast_to=Model)

    cause = exc_info.value.__cause__
    assert isinstance(cause, ValidationError)
1351
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """`stream=True` without an explicit stream class returns an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    stream = await self.client.post("/foo", cast_to=Model, stream=True)
    assert isinstance(stream, AsyncStream)
1362
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A plain-text body for a JSON model is rejected in strict mode and returned as `str` otherwise."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    # strict validation: the text body cannot be turned into `Model`
    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    # lenient validation: a `str` is handed back instead of raising
    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    response = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(response, str)  # type: ignore[unreachable]
1380
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Retry delay computed from a `retry-after` header (seconds or HTTP-date) or the fallback backoff.

    `time.time` is frozen at 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT. The
    HTTP-date cases therefore mirror the integer-second cases: +20s, +0s, -10s,
    +60s and +61s.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    calculated = client._calculate_retry_timeout(
        remaining_retries,
        FinalRequestOptions(method="get", url="/foo", max_retries=3),
        httpx.Headers({"retry-after": retry_after}),
    )
    # NOTE(review): the loose relative tolerance presumably absorbs jitter in the
    # fallback backoff — confirm against `_calculate_retry_timeout`.
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
1410
@mock.patch("httpx.Response.__init__", _low_retry_response_init)
async def test_retrying_timeout_errors_doesnt_leak(self) -> None:
    """Every attempt times out; once APITimeoutError is raised, no connection may remain open."""

    def raise_for_status(response: httpx.Response) -> None:
        raise httpx.TimeoutException("Test timeout error", request=response.request)

    chat_body = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }

    with mock.patch("httpx.Response.raise_for_status", raise_for_status), pytest.raises(APITimeoutError):
        await self.client.post(
            "/chat/completions",
            body=chat_body,
            cast_to=httpx.Response,
            options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
        )

    assert _get_open_connections(self.client) == 0
1434
@mock.patch("httpx.Response.__init__", _low_retry_response_init)
async def test_retrying_runtime_errors_doesnt_leak(self) -> None:
    """Every attempt raises RuntimeError; once APIConnectionError is raised, no connection may remain open."""

    def raise_for_status(_response: httpx.Response) -> None:
        raise RuntimeError("Test error")

    chat_body = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }

    with mock.patch("httpx.Response.raise_for_status", raise_for_status), pytest.raises(APIConnectionError):
        await self.client.post(
            "/chat/completions",
            body=chat_body,
            cast_to=httpx.Response,
            options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
        )

    assert _get_open_connections(self.client) == 0
1458
@mock.patch("httpx.Response.__init__", _low_retry_response_init)
async def test_retrying_status_errors_doesnt_leak(self) -> None:
    """Every attempt fails with a 500; once APIStatusError is raised, no connection may remain open."""

    def raise_for_status(response: httpx.Response) -> None:
        # turn the response into a 500 before raising, as a real server error would
        response.status_code = 500
        raise httpx.HTTPStatusError("Test 500 error", response=response, request=response.request)

    chat_body = {
        "messages": [
            {
                "role": "user",
                "content": "Say this is a test",
            }
        ],
        "model": "gpt-3.5-turbo",
    }

    with mock.patch("httpx.Response.raise_for_status", raise_for_status), pytest.raises(APIStatusError):
        await self.client.post(
            "/chat/completions",
            body=chat_body,
            cast_to=httpx.Response,
            options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}},
        )

    assert _get_open_connections(self.client) == 0
1483
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_status_error_within_httpx(self, respx_mock: MockRouter) -> None:
    """An HTTPStatusError raised inside an httpx response hook surfaces as APIStatusError."""
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    # `httpx.AsyncClient` awaits its event hooks, so the hook must be a coroutine
    # function. A plain function only appeared to work because it raised before
    # its (None) return value could be awaited.
    async def on_response(response: httpx.Response) -> None:
        raise httpx.HTTPStatusError(
            "Simulating an error inside httpx",
            response=response,
            request=response.request,
        )

    # close the custom httpx client when done instead of leaking it past the test
    async with httpx.AsyncClient(event_hooks={"response": [on_response]}) as http_client:
        client = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            http_client=http_client,
            max_retries=0,
        )
        with pytest.raises(APIStatusError):
            await client.post("/foo", cast_to=httpx.Response)
1509