openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.0.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1110 lines · mode: code

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import os
6import json
7import asyncio
8import inspect
9from typing import Any, Dict, Union, cast
10from unittest import mock
11
12import httpx
13import pytest
14from respx import MockRouter
15from pydantic import ValidationError
16
17from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
18from openai._client import OpenAI, AsyncOpenAI
19from openai._models import BaseModel, FinalRequestOptions
20from openai._streaming import Stream, AsyncStream
21from openai._exceptions import APIResponseValidationError
22from openai._base_client import (
23 DEFAULT_TIMEOUT,
24 HTTPX_DEFAULT_TIMEOUT,
25 BaseClient,
26 make_request_options,
27)
28
# Base URL the test clients talk to; override it with the TEST_API_BASE_URL
# environment variable, otherwise the local default below is used.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential handed to every client constructed in this module.
api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters ``client`` puts on a bare ``GET /foo`` request.

    The request is only built via ``client._build_request``, never sent; the
    parameters are read back from the built request's URL.
    """
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
37
38
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    Most tests either build a request with ``_build_request`` and inspect it
    without sending anything, or send a request against a ``respx`` mock router
    bound to ``base_url``.
    """

    # Shared client for tests that do not need custom construction options.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` hands back the raw httpx response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == '{"foo": "bar"}'

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A raw response is also returned for an ``application/binary`` content type."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == '{"foo": "bar"}'

    def test_copy(self) -> None:
        """``copy()`` returns a distinct client and can override ``api_key``."""
        copied = self.client.copy()
        assert copied is not self.client

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        # the source client must be left untouched
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options with defaults (``max_retries``, ``timeout``) are overridden per copy."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merges into a copy; ``set_default_headers`` replaces."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy; ``set_default_query`` replaces."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every ``__init__`` parameter (bar a few exclusions) is accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters:
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the request options carry a timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A ``timeout`` given to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """Timeout handling when the caller supplies their own ``httpx.Client``."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent and may override the ``X-Stainless-Lang`` header."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key becomes a bearer token; constructing without one raises."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # NOTE(review): the client presumably falls back to the OPENAI_API_KEY
        # environment variable when `api_key` is None - confirm against the client.
        # Hide it for the duration of the check so an exported key on the
        # developer/CI machine cannot make construction succeed; `patch.dict`
        # restores the environment afterwards.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)

            with pytest.raises(Exception):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
                _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` is applied to requests and overridden by request ``params``."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over ``default_headers``."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is merged with ``query`` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"foo": "2"}

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A ``Union`` ``cast_to`` yields the member whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base URL ending in ``/`` is joined with the request path without doubling it."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # These base URLs deliberately omit the trailing slash. Previously they
            # were identical to the ones in `test_base_url_trailing_slash`, so the
            # no-trailing-slash case was never actually exercised.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL without a trailing ``/`` still keeps its full path prefix."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to ``base_url``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_client_del(self) -> None:
        """Finalizing a client closes it."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        # invoke the finalizer directly rather than relying on garbage collection
        client.__del__()

        assert client.is_closed()

    def test_copied_client_does_not_close_http(self) -> None:
        """Finalizing a copy closes neither the copy nor the client it came from."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        copied.__del__()

        assert not copied.is_closed()
        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """``with client`` yields the same client and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload of the wrong shape raises ``APIResponseValidationError`` caused by pydantic."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` without an explicit stream class returns a ``Stream``."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=Model, stream=True)
        assert isinstance(response, Stream)

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Plain text where JSON is expected: strict clients raise, lax clients return the text."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    # Freeze the clock at Fri, 29 Sep 2023 16:26:37 GMT so the HTTP-date rows
    # above resolve to fixed offsets from "now".
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``_calculate_retry_timeout`` returns the expected delay for each ``Retry-After`` value.

        Each row gives the remaining retries, the raw ``retry-after`` header
        value, and the expected delay in seconds.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # the second positional argument of `approx` is the relative tolerance
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
567
568class TestAsyncOpenAI:
569 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
570
571 @pytest.mark.respx(base_url=base_url)
572 @pytest.mark.asyncio
573 async def test_raw_response(self, respx_mock: MockRouter) -> None:
574 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json='{"foo": "bar"}'))
575
576 response = await self.client.post("/foo", cast_to=httpx.Response)
577 assert response.status_code == 200
578 assert isinstance(response, httpx.Response)
579 assert response.json() == '{"foo": "bar"}'
580
581 @pytest.mark.respx(base_url=base_url)
582 @pytest.mark.asyncio
583 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
584 respx_mock.post("/foo").mock(
585 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
586 )
587
588 response = await self.client.post("/foo", cast_to=httpx.Response)
589 assert response.status_code == 200
590 assert isinstance(response, httpx.Response)
591 assert response.json() == '{"foo": "bar"}'
592
593 def test_copy(self) -> None:
594 copied = self.client.copy()
595 assert id(copied) != id(self.client)
596
597 copied = self.client.copy(api_key="another My API Key")
598 assert copied.api_key == "another My API Key"
599 assert self.client.api_key == "My API Key"
600
601 def test_copy_default_options(self) -> None:
602 # options that have a default are overridden correctly
603 copied = self.client.copy(max_retries=7)
604 assert copied.max_retries == 7
605 assert self.client.max_retries == 2
606
607 copied2 = copied.copy(max_retries=6)
608 assert copied2.max_retries == 6
609 assert copied.max_retries == 7
610
611 # timeout
612 assert isinstance(self.client.timeout, httpx.Timeout)
613 copied = self.client.copy(timeout=None)
614 assert copied.timeout is None
615 assert isinstance(self.client.timeout, httpx.Timeout)
616
617 def test_copy_default_headers(self) -> None:
618 client = AsyncOpenAI(
619 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
620 )
621 assert client.default_headers["X-Foo"] == "bar"
622
623 # does not override the already given value when not specified
624 copied = client.copy()
625 assert copied.default_headers["X-Foo"] == "bar"
626
627 # merges already given headers
628 copied = client.copy(default_headers={"X-Bar": "stainless"})
629 assert copied.default_headers["X-Foo"] == "bar"
630 assert copied.default_headers["X-Bar"] == "stainless"
631
632 # uses new values for any already given headers
633 copied = client.copy(default_headers={"X-Foo": "stainless"})
634 assert copied.default_headers["X-Foo"] == "stainless"
635
636 # set_default_headers
637
638 # completely overrides already set values
639 copied = client.copy(set_default_headers={})
640 assert copied.default_headers.get("X-Foo") is None
641
642 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
643 assert copied.default_headers["X-Bar"] == "Robert"
644
645 with pytest.raises(
646 ValueError,
647 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
648 ):
649 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
650
651 def test_copy_default_query(self) -> None:
652 client = AsyncOpenAI(
653 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
654 )
655 assert _get_params(client)["foo"] == "bar"
656
657 # does not override the already given value when not specified
658 copied = client.copy()
659 assert _get_params(copied)["foo"] == "bar"
660
661 # merges already given params
662 copied = client.copy(default_query={"bar": "stainless"})
663 params = _get_params(copied)
664 assert params["foo"] == "bar"
665 assert params["bar"] == "stainless"
666
667 # uses new values for any already given headers
668 copied = client.copy(default_query={"foo": "stainless"})
669 assert _get_params(copied)["foo"] == "stainless"
670
671 # set_default_query
672
673 # completely overrides already set values
674 copied = client.copy(set_default_query={})
675 assert _get_params(copied) == {}
676
677 copied = client.copy(set_default_query={"bar": "Robert"})
678 assert _get_params(copied)["bar"] == "Robert"
679
680 with pytest.raises(
681 ValueError,
682 # TODO: update
683 match="`default_query` and `set_default_query` arguments are mutually exclusive",
684 ):
685 client.copy(set_default_query={}, default_query={"foo": "Bar"})
686
687 def test_copy_signature(self) -> None:
688 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
689 init_signature = inspect.signature(
690 # mypy doesn't like that we access the `__init__` property.
691 self.client.__init__, # type: ignore[misc]
692 )
693 copy_signature = inspect.signature(self.client.copy)
694 exclude_params = {"transport", "proxies", "_strict_response_validation"}
695
696 for name in init_signature.parameters.keys():
697 if name in exclude_params:
698 continue
699
700 copy_param = copy_signature.parameters.get(name)
701 assert copy_param is not None, f"copy() signature is missing the {name} param"
702
703 async def test_request_timeout(self) -> None:
704 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
705 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
706 assert timeout == DEFAULT_TIMEOUT
707
708 request = self.client._build_request(
709 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
710 )
711 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
712 assert timeout == httpx.Timeout(100.0)
713
714 async def test_client_timeout_option(self) -> None:
715 client = AsyncOpenAI(
716 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
717 )
718
719 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
720 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
721 assert timeout == httpx.Timeout(0)
722
723 async def test_http_client_timeout_option(self) -> None:
724 # custom timeout given to the httpx client should be used
725 async with httpx.AsyncClient(timeout=None) as http_client:
726 client = AsyncOpenAI(
727 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
728 )
729
730 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
731 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
732 assert timeout == httpx.Timeout(None)
733
734 # no timeout given to the httpx client should not use the httpx default
735 async with httpx.AsyncClient() as http_client:
736 client = AsyncOpenAI(
737 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
738 )
739
740 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
741 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
742 assert timeout == DEFAULT_TIMEOUT
743
744 # explicitly passing the default timeout currently results in it being ignored
745 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
746 client = AsyncOpenAI(
747 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
748 )
749
750 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
751 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
752 assert timeout == DEFAULT_TIMEOUT # our default
753
754 def test_default_headers_option(self) -> None:
755 client = AsyncOpenAI(
756 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
757 )
758 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
759 assert request.headers.get("x-foo") == "bar"
760 assert request.headers.get("x-stainless-lang") == "python"
761
762 client2 = AsyncOpenAI(
763 base_url=base_url,
764 api_key=api_key,
765 _strict_response_validation=True,
766 default_headers={
767 "X-Foo": "stainless",
768 "X-Stainless-Lang": "my-overriding-header",
769 },
770 )
771 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
772 assert request.headers.get("x-foo") == "stainless"
773 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
774
775 def test_validate_headers(self) -> None:
776 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
777 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
778 assert request.headers.get("Authorization") == f"Bearer {api_key}"
779
780 with pytest.raises(Exception):
781 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
782 _ = client2
783
784 def test_default_query_option(self) -> None:
785 client = AsyncOpenAI(
786 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
787 )
788 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
789 url = httpx.URL(request.url)
790 assert dict(url.params) == {"query_param": "bar"}
791
792 request = client._build_request(
793 FinalRequestOptions(
794 method="get",
795 url="/foo",
796 params={"foo": "baz", "query_param": "overriden"},
797 )
798 )
799 url = httpx.URL(request.url)
800 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
801
802 def test_request_extra_json(self) -> None:
803 request = self.client._build_request(
804 FinalRequestOptions(
805 method="post",
806 url="/foo",
807 json_data={"foo": "bar"},
808 extra_json={"baz": False},
809 ),
810 )
811 data = json.loads(request.content.decode("utf-8"))
812 assert data == {"foo": "bar", "baz": False}
813
814 request = self.client._build_request(
815 FinalRequestOptions(
816 method="post",
817 url="/foo",
818 extra_json={"baz": False},
819 ),
820 )
821 data = json.loads(request.content.decode("utf-8"))
822 assert data == {"baz": False}
823
824 # `extra_json` takes priority over `json_data` when keys clash
825 request = self.client._build_request(
826 FinalRequestOptions(
827 method="post",
828 url="/foo",
829 json_data={"foo": "bar", "baz": True},
830 extra_json={"baz": None},
831 ),
832 )
833 data = json.loads(request.content.decode("utf-8"))
834 assert data == {"foo": "bar", "baz": None}
835
836 def test_request_extra_headers(self) -> None:
837 request = self.client._build_request(
838 FinalRequestOptions(
839 method="post",
840 url="/foo",
841 **make_request_options(extra_headers={"X-Foo": "Foo"}),
842 ),
843 )
844 assert request.headers.get("X-Foo") == "Foo"
845
846 # `extra_headers` takes priority over `default_headers` when keys clash
847 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
848 FinalRequestOptions(
849 method="post",
850 url="/foo",
851 **make_request_options(
852 extra_headers={"X-Bar": "false"},
853 ),
854 ),
855 )
856 assert request.headers.get("X-Bar") == "false"
857
def test_request_extra_query(self) -> None:
    """`extra_query` fills the query string, merging with `query` and overriding it on clashes."""

    def built_params(**option_kwargs: Any) -> Dict[str, str]:
        # build a POST /foo request from the given request options and
        # return the query parameters that ended up on its URL
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(**option_kwargs),
            ),
        )
        return cast(Dict[str, str], dict(request.url.params))

    # `extra_query` on its own
    only_extra = built_params(extra_query={"my_query_param": "Foo"})
    assert only_extra == {"my_query_param": "Foo"}

    # `query` and `extra_query` with distinct keys are merged
    merged = built_params(query={"bar": "1"}, extra_query={"foo": "2"})
    assert merged == {"bar": "1", "foo": "2"}

    # on a key clash the value from `extra_query` is the one that is kept
    clashing = built_params(query={"foo": "1"}, extra_query={"foo": "2"})
    assert clashing == {"foo": "2"}
898
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A union `cast_to` yields the member model whose fields match the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    mocked_reply = httpx.Response(200, json={"foo": "bar"})
    respx_mock.get("/foo").mock(return_value=mocked_reply)

    either_model = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=either_model)

    # only Model2 declares a `foo` field, so that is the expected result type
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
912
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union members sharing a field name are told apart by that field's type."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    either_model = cast(Any, Union[Model1, Model2])

    # a string value for `foo` should resolve to the `str`-typed model
    string_reply = httpx.Response(200, json={"foo": "bar"})
    respx_mock.get("/foo").mock(return_value=string_reply)

    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"

    # an integer value for `foo` should resolve to the `int`-typed model
    integer_reply = httpx.Response(200, json={"foo": 1})
    respx_mock.get("/foo").mock(return_value=integer_reply)

    parsed = await self.client.get("/foo", cast_to=either_model)
    assert isinstance(parsed, Model1)
    assert parsed.foo == 1
934
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL ending in `/` is joined with the request path without doubling the slash."""
    options = FinalRequestOptions(
        method="post",
        url="/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
959
@pytest.mark.parametrize(
    "client",
    [
        # NOTE: these base URLs deliberately omit the trailing slash. With the
        # slash present this test would only repeat `test_base_url_trailing_slash`
        # and the slash-less configuration would go untested.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A base URL given *without* a trailing `/` still joins cleanly with the request path.

    The request path is appended beneath the full base path, so the final
    URL is expected to be identical to the one produced when the base URL
    does end in a slash.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
984
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used verbatim instead of being joined to the base URL."""
    options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "https://myapi.com/foo"
1009
async def test_client_del(self) -> None:
    """Calling `__del__` on a client results in the client being closed."""
    doomed = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not doomed.is_closed()

    doomed.__del__()

    # presumably the close triggered by `__del__` completes asynchronously,
    # hence the short pause before checking the closed state
    await asyncio.sleep(0.2)
    assert doomed.is_closed()
1018
async def test_copied_client_does_not_close_http(self) -> None:
    """Finalizing a copy of a client leaves both the copy and the original open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    duplicate.__del__()

    # pause as long as the plain `__del__` test does, so a close triggered
    # by the copy's finalizer would have had time to take effect
    await asyncio.sleep(0.2)
    assert not duplicate.is_closed()
    assert not original.is_closed()
1031
async def test_client_context_manager(self) -> None:
    """`async with` yields the client itself and closes it when the block exits."""
    managed = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with managed as entered:
        # the context manager hands back the very same object, still open
        assert entered is managed
        assert not entered.is_closed()
        assert not managed.is_closed()
    # leaving the block closes the client
    assert managed.is_closed()
1039
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that fails model validation raises `APIResponseValidationError`."""

    class Model(BaseModel):
        foo: str

    # `foo` is declared as `str` but the mocked payload carries an object
    malformed_reply = httpx.Response(200, json={"foo": {"invalid": True}})
    respx_mock.get("/foo").mock(return_value=malformed_reply)

    with pytest.raises(APIResponseValidationError) as raised:
        await self.client.get("/foo", cast_to=Model)

    # the underlying pydantic error is chained as the cause
    underlying = raised.value.__cause__
    assert isinstance(underlying, ValidationError)
1052
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """With `stream=True` and no explicit stream class, the result is an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    mocked_reply = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=mocked_reply)

    streamed = await self.client.post("/foo", cast_to=Model, stream=True)
    assert isinstance(streamed, AsyncStream)
1063
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A plain-text body where a model was expected: strict mode raises, lenient mode returns a `str`."""

    class Model(BaseModel):
        name: str

    text_reply = httpx.Response(200, text="my-custom-format")
    respx_mock.get("/foo").mock(return_value=text_reply)

    # strict response validation rejects the non-JSON body
    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    # without strict validation the call succeeds and hands back a string
    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    result = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(result, str)  # type: ignore[unreachable]
1081
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        # `retry-after` given as a number of seconds
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        # `retry-after` given as an HTTP date; the patched clock below
        # (1696004797) corresponds to Fri, 29 Sep 2023 16:26:37 GMT
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        # unusable header values
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        # empty header: the expected delay doubles as fewer retries remain
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """The retry delay computed from a `retry-after` header matches the expected timeout."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
    response_headers = httpx.Headers({"retry-after": retry_after})

    delay = client._calculate_retry_timeout(remaining_retries, request_options, response_headers)

    # compared with a relative tolerance rather than exactly
    tolerance = 0.5 * 0.875
    assert delay == pytest.approx(timeout, tolerance)  # pyright: ignore[reportUnknownMemberType]
1111