openai/openai-python

Public

mirrored from https://github.com/openai/openai-python Available

CodeCommitsIssuesPull requestsActionsInsightsSecurity
f6c2aedfe3dfceeabd551909b572d72c7cef4373

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1122lines · modecode

1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import os
6import json
7import asyncio
8import inspect
9from typing import Any, Dict, Union, cast
10from unittest import mock
11
12import httpx
13import pytest
14from respx import MockRouter
15from pydantic import ValidationError
16
17from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
18from openai._client import OpenAI, AsyncOpenAI
19from openai._models import BaseModel, FinalRequestOptions
20from openai._streaming import Stream, AsyncStream
21from openai._exceptions import APIResponseValidationError
22from openai._base_client import (
23 DEFAULT_TIMEOUT,
24 HTTPX_DEFAULT_TIMEOUT,
25 BaseClient,
26 make_request_options,
27)
28
29from .utils import update_env
30
# Base URL the test clients are pointed at; the TEST_API_BASE_URL environment
# variable overrides the local default.
base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder API key passed to the clients constructed in these tests.
api_key = "My API Key"
33
34
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters `client` attaches to a bare `GET /foo` request."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
39
40
class TestOpenAI:
    """Tests for the synchronous `OpenAI` client: copying, request building and response parsing."""

    # Shared client for tests that do not need custom construction options.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """`cast_to=httpx.Response` hands back the raw httpx response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """The raw response is also returned when the content type is `application/binary`."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """`copy()` returns a distinct client and leaves the original's API key untouched."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options overridden through `copy()` affect only the copy."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """`default_headers` merges into the copy; `set_default_headers` replaces; both together raise."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """`default_query` merges into the copy; `set_default_query` replaces; both together raise."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every `__init__` parameter (bar a small exclusion list) is also accepted by `copy()`."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_request_timeout(self) -> None:
        """A built request carries `DEFAULT_TIMEOUT` unless the request options supply a timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A timeout passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """Timeout resolution when the caller supplies their own `httpx.Client`."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        """`default_headers` are sent on requests and can override the `X-Stainless-Lang` header."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a Bearer token; constructing a client with `api_key=None` is expected to raise."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(Exception):
            client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """`default_query` is added to request URLs and per-request `params` win on key clashes."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        """`extra_json` is merged into the JSON body and takes priority over `json_data`."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """`extra_headers` are sent on the request and take priority over `default_headers`."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """`extra_query` is added to the URL, merged with `query`, and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"foo": "2"}

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A `Union` cast resolves to the member whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    def test_base_url_env(self) -> None:
        """`OPENAI_BASE_URL` supplies the base URL, which the client stores with a trailing slash."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base URL ending in `/` joins with a `/foo` request path without doubling the slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # No trailing slash here: this is the case the test is named for. The
            # original parametrization reused the trailing-slash URL, so it only
            # duplicated `test_base_url_trailing_slash`.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL without a trailing `/` still keeps its full path when joined with `/foo`."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined onto the base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_client_del(self) -> None:
        """Calling `__del__` on a client closes it."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        client.__del__()

        assert client.is_closed()

    def test_copied_client_does_not_close_http(self) -> None:
        """Calling `__del__` on a copy leaves both the copy and the original open."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        copied.__del__()

        assert not copied.is_closed()
        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Used as a context manager, the client yields itself and is closed on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that fails model validation raises `APIResponseValidationError` caused by a pydantic error."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """`stream=True` without an explicit stream class yields a `Stream`."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=Model, stream=True)
        assert isinstance(response, Stream)

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """A plain-text body for a model cast raises under strict validation and is returned as `str` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """Check the retry delay computed from a `retry-after` header value.

        `time.time` is patched to 1696004797 (Fri, 29 Sep 2023 16:26:37 GMT), so
        the date-form header values map to fixed offsets from "now".
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # The second positional argument to `approx` is a relative tolerance.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
573
574
575class TestAsyncOpenAI:
576 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
577
578 @pytest.mark.respx(base_url=base_url)
579 @pytest.mark.asyncio
580 async def test_raw_response(self, respx_mock: MockRouter) -> None:
581 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
582
583 response = await self.client.post("/foo", cast_to=httpx.Response)
584 assert response.status_code == 200
585 assert isinstance(response, httpx.Response)
586 assert response.json() == {"foo": "bar"}
587
588 @pytest.mark.respx(base_url=base_url)
589 @pytest.mark.asyncio
590 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
591 respx_mock.post("/foo").mock(
592 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
593 )
594
595 response = await self.client.post("/foo", cast_to=httpx.Response)
596 assert response.status_code == 200
597 assert isinstance(response, httpx.Response)
598 assert response.json() == {"foo": "bar"}
599
600 def test_copy(self) -> None:
601 copied = self.client.copy()
602 assert id(copied) != id(self.client)
603
604 copied = self.client.copy(api_key="another My API Key")
605 assert copied.api_key == "another My API Key"
606 assert self.client.api_key == "My API Key"
607
608 def test_copy_default_options(self) -> None:
609 # options that have a default are overridden correctly
610 copied = self.client.copy(max_retries=7)
611 assert copied.max_retries == 7
612 assert self.client.max_retries == 2
613
614 copied2 = copied.copy(max_retries=6)
615 assert copied2.max_retries == 6
616 assert copied.max_retries == 7
617
618 # timeout
619 assert isinstance(self.client.timeout, httpx.Timeout)
620 copied = self.client.copy(timeout=None)
621 assert copied.timeout is None
622 assert isinstance(self.client.timeout, httpx.Timeout)
623
624 def test_copy_default_headers(self) -> None:
625 client = AsyncOpenAI(
626 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
627 )
628 assert client.default_headers["X-Foo"] == "bar"
629
630 # does not override the already given value when not specified
631 copied = client.copy()
632 assert copied.default_headers["X-Foo"] == "bar"
633
634 # merges already given headers
635 copied = client.copy(default_headers={"X-Bar": "stainless"})
636 assert copied.default_headers["X-Foo"] == "bar"
637 assert copied.default_headers["X-Bar"] == "stainless"
638
639 # uses new values for any already given headers
640 copied = client.copy(default_headers={"X-Foo": "stainless"})
641 assert copied.default_headers["X-Foo"] == "stainless"
642
643 # set_default_headers
644
645 # completely overrides already set values
646 copied = client.copy(set_default_headers={})
647 assert copied.default_headers.get("X-Foo") is None
648
649 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
650 assert copied.default_headers["X-Bar"] == "Robert"
651
652 with pytest.raises(
653 ValueError,
654 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
655 ):
656 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
657
658 def test_copy_default_query(self) -> None:
659 client = AsyncOpenAI(
660 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
661 )
662 assert _get_params(client)["foo"] == "bar"
663
664 # does not override the already given value when not specified
665 copied = client.copy()
666 assert _get_params(copied)["foo"] == "bar"
667
668 # merges already given params
669 copied = client.copy(default_query={"bar": "stainless"})
670 params = _get_params(copied)
671 assert params["foo"] == "bar"
672 assert params["bar"] == "stainless"
673
674 # uses new values for any already given headers
675 copied = client.copy(default_query={"foo": "stainless"})
676 assert _get_params(copied)["foo"] == "stainless"
677
678 # set_default_query
679
680 # completely overrides already set values
681 copied = client.copy(set_default_query={})
682 assert _get_params(copied) == {}
683
684 copied = client.copy(set_default_query={"bar": "Robert"})
685 assert _get_params(copied)["bar"] == "Robert"
686
687 with pytest.raises(
688 ValueError,
689 # TODO: update
690 match="`default_query` and `set_default_query` arguments are mutually exclusive",
691 ):
692 client.copy(set_default_query={}, default_query={"foo": "Bar"})
693
694 def test_copy_signature(self) -> None:
695 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
696 init_signature = inspect.signature(
697 # mypy doesn't like that we access the `__init__` property.
698 self.client.__init__, # type: ignore[misc]
699 )
700 copy_signature = inspect.signature(self.client.copy)
701 exclude_params = {"transport", "proxies", "_strict_response_validation"}
702
703 for name in init_signature.parameters.keys():
704 if name in exclude_params:
705 continue
706
707 copy_param = copy_signature.parameters.get(name)
708 assert copy_param is not None, f"copy() signature is missing the {name} param"
709
710 async def test_request_timeout(self) -> None:
711 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
712 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
713 assert timeout == DEFAULT_TIMEOUT
714
715 request = self.client._build_request(
716 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
717 )
718 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
719 assert timeout == httpx.Timeout(100.0)
720
721 async def test_client_timeout_option(self) -> None:
722 client = AsyncOpenAI(
723 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
724 )
725
726 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
727 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
728 assert timeout == httpx.Timeout(0)
729
730 async def test_http_client_timeout_option(self) -> None:
731 # custom timeout given to the httpx client should be used
732 async with httpx.AsyncClient(timeout=None) as http_client:
733 client = AsyncOpenAI(
734 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
735 )
736
737 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
738 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
739 assert timeout == httpx.Timeout(None)
740
741 # no timeout given to the httpx client should not use the httpx default
742 async with httpx.AsyncClient() as http_client:
743 client = AsyncOpenAI(
744 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
745 )
746
747 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
748 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
749 assert timeout == DEFAULT_TIMEOUT
750
751 # explicitly passing the default timeout currently results in it being ignored
752 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
753 client = AsyncOpenAI(
754 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
755 )
756
757 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
758 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
759 assert timeout == DEFAULT_TIMEOUT # our default
760
761 def test_default_headers_option(self) -> None:
762 client = AsyncOpenAI(
763 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
764 )
765 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
766 assert request.headers.get("x-foo") == "bar"
767 assert request.headers.get("x-stainless-lang") == "python"
768
769 client2 = AsyncOpenAI(
770 base_url=base_url,
771 api_key=api_key,
772 _strict_response_validation=True,
773 default_headers={
774 "X-Foo": "stainless",
775 "X-Stainless-Lang": "my-overriding-header",
776 },
777 )
778 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
779 assert request.headers.get("x-foo") == "stainless"
780 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
781
782 def test_validate_headers(self) -> None:
783 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
784 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
785 assert request.headers.get("Authorization") == f"Bearer {api_key}"
786
787 with pytest.raises(Exception):
788 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
789 _ = client2
790
791 def test_default_query_option(self) -> None:
792 client = AsyncOpenAI(
793 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
794 )
795 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
796 url = httpx.URL(request.url)
797 assert dict(url.params) == {"query_param": "bar"}
798
799 request = client._build_request(
800 FinalRequestOptions(
801 method="get",
802 url="/foo",
803 params={"foo": "baz", "query_param": "overriden"},
804 )
805 )
806 url = httpx.URL(request.url)
807 assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
808
809 def test_request_extra_json(self) -> None:
810 request = self.client._build_request(
811 FinalRequestOptions(
812 method="post",
813 url="/foo",
814 json_data={"foo": "bar"},
815 extra_json={"baz": False},
816 ),
817 )
818 data = json.loads(request.content.decode("utf-8"))
819 assert data == {"foo": "bar", "baz": False}
820
821 request = self.client._build_request(
822 FinalRequestOptions(
823 method="post",
824 url="/foo",
825 extra_json={"baz": False},
826 ),
827 )
828 data = json.loads(request.content.decode("utf-8"))
829 assert data == {"baz": False}
830
831 # `extra_json` takes priority over `json_data` when keys clash
832 request = self.client._build_request(
833 FinalRequestOptions(
834 method="post",
835 url="/foo",
836 json_data={"foo": "bar", "baz": True},
837 extra_json={"baz": None},
838 ),
839 )
840 data = json.loads(request.content.decode("utf-8"))
841 assert data == {"foo": "bar", "baz": None}
842
843 def test_request_extra_headers(self) -> None:
844 request = self.client._build_request(
845 FinalRequestOptions(
846 method="post",
847 url="/foo",
848 **make_request_options(extra_headers={"X-Foo": "Foo"}),
849 ),
850 )
851 assert request.headers.get("X-Foo") == "Foo"
852
853 # `extra_headers` takes priority over `default_headers` when keys clash
854 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
855 FinalRequestOptions(
856 method="post",
857 url="/foo",
858 **make_request_options(
859 extra_headers={"X-Bar": "false"},
860 ),
861 ),
862 )
863 assert request.headers.get("X-Bar") == "false"
864
def test_request_extra_query(self) -> None:
    """Check how `query` and `extra_query` show up on the built request URL.

    Three cases are covered: `extra_query` on its own, `query` and
    `extra_query` with distinct keys, and both supplying the same key.
    """

    def query_params_for(**option_kwargs: Any) -> Dict[str, str]:
        # Build a POST /foo request from the given request options and
        # return the query parameters found on its URL.
        options = FinalRequestOptions(
            method="post",
            url="/foo",
            **make_request_options(**option_kwargs),
        )
        built = self.client._build_request(options)
        return cast(Dict[str, str], dict(built.url.params))

    assert query_params_for(extra_query={"my_query_param": "Foo"}) == {"my_query_param": "Foo"}

    # if both `query` and `extra_query` are given, they are merged
    assert query_params_for(query={"bar": "1"}, extra_query={"foo": "2"}) == {"bar": "1", "foo": "2"}

    # `extra_query` takes priority over `query` when keys clash
    assert query_params_for(query={"foo": "1"}, extra_query={"foo": "2"}) == {"foo": "2"}
905
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """Check that a union `cast_to` resolves to the member matching the payload.

    The mocked payload only carries `foo`, so the response is expected to be
    parsed as `Model2`.
    """

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    payload = {"foo": "bar"}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=payload))

    union_type = cast(Any, Union[Model1, Model2])
    parsed = await self.client.get("/foo", cast_to=union_type)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
919
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union of models that share a field name but declare different types.

    A string `foo` is expected to be parsed as `Model2`, an integer `foo`
    as `Model1`.
    """

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_type = cast(Any, Union[Model1, Model2])

    # (mocked `foo` value, model the response should be parsed into)
    cases = (("bar", Model2), (1, Model1))
    for foo_value, expected_model in cases:
        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": foo_value}))

        parsed = await self.client.get("/foo", cast_to=union_type)
        assert isinstance(parsed, expected_model)
        assert parsed.foo == foo_value
941
def test_base_url_env(self) -> None:
    """Check that the client picks up its base URL from `OPENAI_BASE_URL`.

    The environment value has no trailing slash; the client's `base_url`
    is expected to compare equal to the same URL with one appended.
    """
    expected_base_url = "http://localhost:5000/from/env/"
    with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
        env_client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
        assert env_client.base_url == expected_base_url
946
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """Check URL joining when the base URL ends with a slash.

    Runs against a default client and one with a caller-supplied
    `httpx.AsyncClient`; both should build `<base>/foo` for the path `/foo`.
    """
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
971
@pytest.mark.parametrize(
    "client",
    [
        # NOTE: the base URLs below deliberately omit the trailing slash.
        # With a trailing slash this test would be an exact duplicate of
        # `test_base_url_trailing_slash` and would never exercise the case
        # it is named after.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """Check URL joining when the base URL does NOT end with a slash.

    Runs against a default client and one with a caller-supplied
    `httpx.AsyncClient`. The request path `/foo` must still be appended
    beneath the full custom path rather than replacing its last segment.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
996
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """Check that an absolute request URL is used as-is.

    The client's own base URL must not be prepended when the request
    options already carry a full `https://` URL.
    """
    options = FinalRequestOptions(
        method="post",
        url="https://myapi.com/foo",
        json_data={"foo": "bar"},
    )
    built = client._build_request(options)
    assert built.url == "https://myapi.com/foo"
1021
async def test_client_del(self) -> None:
    """Check that invoking `__del__` on a client leaves it closed."""
    doomed = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not doomed.is_closed()

    doomed.__del__()

    # NOTE(review): the pause presumably lets an asynchronously scheduled
    # close finish before the state is checked — confirm against `__del__`.
    await asyncio.sleep(0.2)
    assert doomed.is_closed()
1030
async def test_copied_client_does_not_close_http(self) -> None:
    """Check that finalizing a copied client closes neither copy nor source."""
    source = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not source.is_closed()

    duplicate = source.copy()
    assert duplicate is not source

    duplicate.__del__()

    # Same pause as used when checking `__del__` on a standalone client.
    await asyncio.sleep(0.2)
    assert not duplicate.is_closed()
    assert not source.is_closed()
1043
async def test_client_context_manager(self) -> None:
    """Check `async with` on a client: same object inside, closed afterwards."""
    outer = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with outer as inner:
        # The context manager hands back the very same client, still open.
        assert inner is outer
        assert not inner.is_closed()
        assert not outer.is_closed()
    # Leaving the block must have closed it.
    assert outer.is_closed()
1051
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """Check that a payload not matching the model raises a validation error.

    The raised `APIResponseValidationError` is expected to carry the
    underlying pydantic `ValidationError` as its `__cause__`.
    """

    class Model(BaseModel):
        foo: str

    # `foo` is declared as `str` but the server answers with an object.
    mismatched_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=mismatched_payload))

    with pytest.raises(APIResponseValidationError) as raised:
        await self.client.get("/foo", cast_to=Model)

    cause = raised.value.__cause__
    assert isinstance(cause, ValidationError)
1064
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """Check that `stream=True` without a stream class yields an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    mocked_reply = httpx.Response(200, json={"foo": "bar"})
    respx_mock.post("/foo").mock(return_value=mocked_reply)

    streamed = await self.client.post("/foo", cast_to=Model, stream=True)
    assert isinstance(streamed, AsyncStream)
1075
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """Check handling of a plain-text body when a model was requested.

    With strict response validation the client must raise
    `APIResponseValidationError`; without it the call is expected to
    return a `str`.
    """

    class Model(BaseModel):
        name: str

    text_reply = httpx.Response(200, text="my-custom-format")
    respx_mock.get("/foo").mock(return_value=text_reply)

    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    raw = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(raw, str)  # type: ignore[unreachable]
1093
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay computed from a `retry-after` header.

    Each case supplies the retries left, the raw header value (seconds,
    an HTTP date, an empty string, or a malformed value) and the delay
    the client is expected to compute. `time.time` is patched to a fixed
    instant so the HTTP-date cases are deterministic.
    """
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    request_options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
    response_headers = httpx.Headers({"retry-after": retry_after})

    actual = client._calculate_retry_timeout(remaining_retries, request_options, response_headers)
    # Compared with a relative tolerance rather than exactly.
    expected = pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
    assert actual == expected
1123