openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
ea4fa238abdcddbf14aa7d42b7689cdd00a2d290

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1110 lines · mode: blame

08b8179aDavid Schnurr2 years ago1# File generated from our OpenAPI spec by Stainless.
2
3from __future__ import annotations
4
5import os
6import json
7import asyncio
8import inspect
9from typing import Any, Dict, Union, cast
10from unittest import mock
11
12import httpx
13import pytest
14from respx import MockRouter
15from pydantic import ValidationError
16
17from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
18from openai._client import OpenAI, AsyncOpenAI
19from openai._models import BaseModel, FinalRequestOptions
20from openai._streaming import Stream, AsyncStream
21from openai._exceptions import APIResponseValidationError
22from openai._base_client import (
23DEFAULT_TIMEOUT,
24HTTPX_DEFAULT_TIMEOUT,
25BaseClient,
26make_request_options,
27)
28
# Base URL of the API server every test client points at; override with TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential shared by every client constructed in this module.
api_key = "My API Key"
31
32
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a throwaway ``GET /foo`` request with *client* and return its query parameters as a dict."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
37
38
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    Most tests inspect the ``httpx.Request`` produced by ``_build_request``. The tests decorated
    with ``@pytest.mark.respx`` have their HTTP traffic answered by ``respx`` mocks mounted at
    ``base_url``.
    """

    # Shared strict-validation client; tests that need other settings construct their own.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        # the source client keeps the key it was constructed with
        assert self.client.api_key == api_key

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_default_headers_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token, and a client cannot be built without any key."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # NOTE(review): per the openai-python docs the client falls back to the OPENAI_API_KEY
        # environment variable when `api_key` is None, so without isolation this check passes or
        # fails depending on the caller's environment. Hide the variable for the duration of the
        # check; `patch.dict` restores `os.environ` on exit.
        with mock.patch.dict(os.environ):
            os.environ.pop("OPENAI_API_KEY", None)
            with pytest.raises(Exception):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
                _ = client2

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overriden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = cast(Dict[str, str], dict(request.url.params))
        assert params == {"foo": "2"}

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # The base URL deliberately has NO trailing slash; with one, this test would just
            # duplicate `test_base_url_trailing_slash`.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A relative path is still nested under a base URL that lacks a trailing ``/``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_client_del(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        client.__del__()

        assert client.is_closed()

    def test_copied_client_does_not_close_http(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        copied.__del__()

        assert not copied.is_closed()
        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=Model, stream=True)
        assert isinstance(response, Stream)

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # `time.time()` is frozen below at 1696004797, i.e. Fri, 29 Sep 2023 16:26:37 GMT, so the
    # HTTP-date cases resolve to fixed offsets (16:26:57 -> 20s, 16:27:37 -> 60s). Per this table,
    # values outside (0, 60] or unparsable ones fall back to 0.5s, scaled by 2.0 / 4.0 as
    # `remaining_retries` drops.
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # relative tolerance of 0.4375 -- presumably to absorb retry jitter (TODO confirm)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]
566
567
568class TestAsyncOpenAI:
569client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
570
571@pytest.mark.respx(base_url=base_url)
572@pytest.mark.asyncio
573async def test_raw_response(self, respx_mock: MockRouter) -> None:
c5975bd0Stainless Bot2 years ago574respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
08b8179aDavid Schnurr2 years ago575
576response = await self.client.post("/foo", cast_to=httpx.Response)
577assert response.status_code == 200
578assert isinstance(response, httpx.Response)
c5975bd0Stainless Bot2 years ago579assert response.json() == {"foo": "bar"}
08b8179aDavid Schnurr2 years ago580
581@pytest.mark.respx(base_url=base_url)
582@pytest.mark.asyncio
583async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
584respx_mock.post("/foo").mock(
585return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
586)
587
588response = await self.client.post("/foo", cast_to=httpx.Response)
589assert response.status_code == 200
590assert isinstance(response, httpx.Response)
c5975bd0Stainless Bot2 years ago591assert response.json() == {"foo": "bar"}
08b8179aDavid Schnurr2 years ago592
593def test_copy(self) -> None:
594copied = self.client.copy()
595assert id(copied) != id(self.client)
596
597copied = self.client.copy(api_key="another My API Key")
598assert copied.api_key == "another My API Key"
599assert self.client.api_key == "My API Key"
600
601def test_copy_default_options(self) -> None:
602# options that have a default are overridden correctly
603copied = self.client.copy(max_retries=7)
604assert copied.max_retries == 7
605assert self.client.max_retries == 2
606
607copied2 = copied.copy(max_retries=6)
608assert copied2.max_retries == 6
609assert copied.max_retries == 7
610
611# timeout
612assert isinstance(self.client.timeout, httpx.Timeout)
613copied = self.client.copy(timeout=None)
614assert copied.timeout is None
615assert isinstance(self.client.timeout, httpx.Timeout)
616
617def test_copy_default_headers(self) -> None:
618client = AsyncOpenAI(
619base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
620)
621assert client.default_headers["X-Foo"] == "bar"
622
623# does not override the already given value when not specified
624copied = client.copy()
625assert copied.default_headers["X-Foo"] == "bar"
626
627# merges already given headers
628copied = client.copy(default_headers={"X-Bar": "stainless"})
629assert copied.default_headers["X-Foo"] == "bar"
630assert copied.default_headers["X-Bar"] == "stainless"
631
632# uses new values for any already given headers
633copied = client.copy(default_headers={"X-Foo": "stainless"})
634assert copied.default_headers["X-Foo"] == "stainless"
635
636# set_default_headers
637
638# completely overrides already set values
639copied = client.copy(set_default_headers={})
640assert copied.default_headers.get("X-Foo") is None
641
642copied = client.copy(set_default_headers={"X-Bar": "Robert"})
643assert copied.default_headers["X-Bar"] == "Robert"
644
645with pytest.raises(
646ValueError,
647match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
648):
649client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
650
651def test_copy_default_query(self) -> None:
652client = AsyncOpenAI(
653base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
654)
655assert _get_params(client)["foo"] == "bar"
656
657# does not override the already given value when not specified
658copied = client.copy()
659assert _get_params(copied)["foo"] == "bar"
660
661# merges already given params
662copied = client.copy(default_query={"bar": "stainless"})
663params = _get_params(copied)
664assert params["foo"] == "bar"
665assert params["bar"] == "stainless"
666
667# uses new values for any already given headers
668copied = client.copy(default_query={"foo": "stainless"})
669assert _get_params(copied)["foo"] == "stainless"
670
671# set_default_query
672
673# completely overrides already set values
674copied = client.copy(set_default_query={})
675assert _get_params(copied) == {}
676
677copied = client.copy(set_default_query={"bar": "Robert"})
678assert _get_params(copied)["bar"] == "Robert"
679
680with pytest.raises(
681ValueError,
682# TODO: update
683match="`default_query` and `set_default_query` arguments are mutually exclusive",
684):
685client.copy(set_default_query={}, default_query={"foo": "Bar"})
686
687def test_copy_signature(self) -> None:
688# ensure the same parameters that can be passed to the client are defined in the `.copy()` method
689init_signature = inspect.signature(
690# mypy doesn't like that we access the `__init__` property.
691self.client.__init__, # type: ignore[misc]
692)
693copy_signature = inspect.signature(self.client.copy)
694exclude_params = {"transport", "proxies", "_strict_response_validation"}
695
696for name in init_signature.parameters.keys():
697if name in exclude_params:
698continue
699
700copy_param = copy_signature.parameters.get(name)
701assert copy_param is not None, f"copy() signature is missing the {name} param"
702
703async def test_request_timeout(self) -> None:
704request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
705timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
706assert timeout == DEFAULT_TIMEOUT
707
708request = self.client._build_request(
709FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
710)
711timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
712assert timeout == httpx.Timeout(100.0)
713
714async def test_client_timeout_option(self) -> None:
715client = AsyncOpenAI(
716base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
717)
718
719request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
720timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
721assert timeout == httpx.Timeout(0)
722
723async def test_http_client_timeout_option(self) -> None:
724# custom timeout given to the httpx client should be used
725async with httpx.AsyncClient(timeout=None) as http_client:
726client = AsyncOpenAI(
727base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
728)
729
730request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
731timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
732assert timeout == httpx.Timeout(None)
733
734# no timeout given to the httpx client should not use the httpx default
735async with httpx.AsyncClient() as http_client:
736client = AsyncOpenAI(
737base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
738)
739
740request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
741timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
742assert timeout == DEFAULT_TIMEOUT
743
744# explicitly passing the default timeout currently results in it being ignored
745async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
746client = AsyncOpenAI(
747base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
748)
749
750request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
751timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
752assert timeout == DEFAULT_TIMEOUT # our default
753
754def test_default_headers_option(self) -> None:
755client = AsyncOpenAI(
756base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
757)
758request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
759assert request.headers.get("x-foo") == "bar"
760assert request.headers.get("x-stainless-lang") == "python"
761
762client2 = AsyncOpenAI(
763base_url=base_url,
764api_key=api_key,
765_strict_response_validation=True,
766default_headers={
767"X-Foo": "stainless",
768"X-Stainless-Lang": "my-overriding-header",
769},
770)
771request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
772assert request.headers.get("x-foo") == "stainless"
773assert request.headers.get("x-stainless-lang") == "my-overriding-header"
774
775def test_validate_headers(self) -> None:
776client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
777request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
778assert request.headers.get("Authorization") == f"Bearer {api_key}"
779
780with pytest.raises(Exception):
781client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
782_ = client2
783
784def test_default_query_option(self) -> None:
785client = AsyncOpenAI(
786base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
787)
788request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
789url = httpx.URL(request.url)
790assert dict(url.params) == {"query_param": "bar"}
791
792request = client._build_request(
793FinalRequestOptions(
794method="get",
795url="/foo",
796params={"foo": "baz", "query_param": "overriden"},
797)
798)
799url = httpx.URL(request.url)
800assert dict(url.params) == {"foo": "baz", "query_param": "overriden"}
801
802def test_request_extra_json(self) -> None:
803request = self.client._build_request(
804FinalRequestOptions(
805method="post",
806url="/foo",
807json_data={"foo": "bar"},
808extra_json={"baz": False},
809),
810)
811data = json.loads(request.content.decode("utf-8"))
812assert data == {"foo": "bar", "baz": False}
813
814request = self.client._build_request(
815FinalRequestOptions(
816method="post",
817url="/foo",
818extra_json={"baz": False},
819),
820)
821data = json.loads(request.content.decode("utf-8"))
822assert data == {"baz": False}
823
824# `extra_json` takes priority over `json_data` when keys clash
825request = self.client._build_request(
826FinalRequestOptions(
827method="post",
828url="/foo",
829json_data={"foo": "bar", "baz": True},
830extra_json={"baz": None},
831),
832)
833data = json.loads(request.content.decode("utf-8"))
834assert data == {"foo": "bar", "baz": None}
835
836def test_request_extra_headers(self) -> None:
837request = self.client._build_request(
838FinalRequestOptions(
839method="post",
840url="/foo",
841**make_request_options(extra_headers={"X-Foo": "Foo"}),
842),
843)
844assert request.headers.get("X-Foo") == "Foo"
845
846# `extra_headers` takes priority over `default_headers` when keys clash
847request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
848FinalRequestOptions(
849method="post",
850url="/foo",
851**make_request_options(
852extra_headers={"X-Bar": "false"},
853),
854),
855)
856assert request.headers.get("X-Bar") == "false"
857
858def test_request_extra_query(self) -> None:
859request = self.client._build_request(
860FinalRequestOptions(
861method="post",
862url="/foo",
863**make_request_options(
864extra_query={"my_query_param": "Foo"},
865),
866),
867)
868params = cast(Dict[str, str], dict(request.url.params))
869assert params == {"my_query_param": "Foo"}
870
871# if both `query` and `extra_query` are given, they are merged
872request = self.client._build_request(
873FinalRequestOptions(
874method="post",
875url="/foo",
876**make_request_options(
877query={"bar": "1"},
878extra_query={"foo": "2"},
879),
880),
881)
882params = cast(Dict[str, str], dict(request.url.params))
883assert params == {"bar": "1", "foo": "2"}
884
885# `extra_query` takes priority over `query` when keys clash
886request = self.client._build_request(
887FinalRequestOptions(
888method="post",
889url="/foo",
890**make_request_options(
891query={"foo": "1"},
892extra_query={"foo": "2"},
893),
894),
895)
896params = cast(Dict[str, str], dict(request.url.params))
897assert params == {"foo": "2"}
898
@pytest.mark.respx(base_url=base_url)
async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
    """A `Union` cast target resolves to the member whose fields fit the payload."""

    class Model1(BaseModel):
        name: str

    class Model2(BaseModel):
        foo: str

    union_target = cast(Any, Union[Model1, Model2])
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    parsed = await self.client.get("/foo", cast_to=union_target)
    assert isinstance(parsed, Model2)
    assert parsed.foo == "bar"
912
@pytest.mark.respx(base_url=base_url)
async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
    """Union members sharing a field name but differing in its type are told apart by value."""

    class Model1(BaseModel):
        foo: int

    class Model2(BaseModel):
        foo: str

    union_target = cast(Any, Union[Model1, Model2])

    # A string value cannot satisfy `foo: int`, so Model2 is expected.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
    as_text = await self.client.get("/foo", cast_to=union_target)
    assert isinstance(as_text, Model2)
    assert as_text.foo == "bar"

    # Re-mock the same route with an integer value; Model1 is expected this time.
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
    as_number = await self.client.get("/foo", cast_to=union_target)
    assert isinstance(as_number, Model1)
    assert as_number.foo == 1
934
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A leading-slash path joins onto a `base_url` ending in `/` without doubling it."""
    options = FinalRequestOptions(method="post", url="/foo", json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "http://localhost:5000/custom/path/foo"
959
@pytest.mark.parametrize(
    "client",
    [
        # `base_url` intentionally has NO trailing slash: that is the case under test.
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
    """A relative request path must join correctly onto a `base_url` lacking a trailing `/`.

    The expected URL keeps the `/custom/path` segment. The relative path is
    appended after it rather than replacing the last path segment.
    """
    request = client._build_request(
        FinalRequestOptions(
            method="post",
            url="/foo",
            json_data={"foo": "bar"},
        ),
    )
    assert request.url == "http://localhost:5000/custom/path/foo"
984
@pytest.mark.parametrize(
    "client",
    [
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
        ),
        AsyncOpenAI(
            base_url="http://localhost:5000/custom/path/",
            api_key=api_key,
            _strict_response_validation=True,
            http_client=httpx.AsyncClient(),
        ),
    ],
    ids=["standard", "custom http client"],
)
def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
    """An absolute request URL is used as-is instead of being joined onto `base_url`."""
    absolute_url = "https://myapi.com/foo"
    options = FinalRequestOptions(method="post", url=absolute_url, json_data={"foo": "bar"})
    built = client._build_request(options)
    assert built.url == "https://myapi.com/foo"
1009
async def test_client_del(self) -> None:
    """Invoking `__del__` on a client results in it reporting closed."""
    subject = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not subject.is_closed()

    subject.__del__()

    # The check is delayed, presumably because the close runs on the event loop
    # rather than synchronously inside `__del__`.
    await asyncio.sleep(0.2)
    assert subject.is_closed()
1018
async def test_copied_client_does_not_close_http(self) -> None:
    """Deleting a `copy()` of a client leaves both the copy and the original open."""
    original = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    assert not original.is_closed()

    duplicate = original.copy()
    assert duplicate is not original

    duplicate.__del__()
    await asyncio.sleep(0.2)

    # Check the copy first, then the original.
    for candidate in (duplicate, original):
        assert not candidate.is_closed()
1031
async def test_client_context_manager(self) -> None:
    """`async with` yields the client itself and closes it on exit."""
    outer = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    async with outer as entered:
        # Still open while inside the block.
        assert entered is outer
        assert not entered.is_closed()
        assert not outer.is_closed()
    assert outer.is_closed()
1039
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
    """A payload that fails model validation surfaces as `APIResponseValidationError`."""

    class Model(BaseModel):
        foo: str

    # `foo` is an object here, not the declared `str`.
    mismatched_payload = {"foo": {"invalid": True}}
    respx_mock.get("/foo").mock(return_value=httpx.Response(200, json=mismatched_payload))

    with pytest.raises(APIResponseValidationError) as excinfo:
        await self.client.get("/foo", cast_to=Model)

    # The pydantic error is chained as the cause.
    assert isinstance(excinfo.value.__cause__, ValidationError)
1052
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
    """`stream=True` without an explicit stream class returns an `AsyncStream`."""

    class Model(BaseModel):
        name: str

    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

    streamed = await self.client.post("/foo", cast_to=Model, stream=True)
    assert isinstance(streamed, AsyncStream)
1063
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
    """A non-JSON body raises under strict validation and comes back as `str` otherwise."""

    class Model(BaseModel):
        name: str

    respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

    strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
    with pytest.raises(APIResponseValidationError):
        await strict_client.get("/foo", cast_to=Model)

    lenient_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
    fallback = await lenient_client.get("/foo", cast_to=Model)
    assert isinstance(fallback, str)  # type: ignore[unreachable]
1081
@pytest.mark.parametrize(
    "remaining_retries,retry_after,timeout",
    [
        # Integer-seconds form. Only "20" and "60" are used as-is; "0", "-10" and "61"
        # expect the same 0.5 as the empty-header row, i.e. the header is ignored.
        [3, "20", 20],
        [3, "0", 0.5],
        [3, "-10", 0.5],
        [3, "60", 60],
        [3, "61", 0.5],
        # HTTP-date form, relative to the mocked clock (16:26:37 GMT): +20s -> 20,
        # +0s / in the past -> ignored, +60s -> 60, +61s -> ignored.
        [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
        [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
        [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
        [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
        # Unparseable/oversized values are ignored.
        [3, "99999999999999999999999999999999999", 0.5],
        [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
        # No usable header: the expected delay doubles as remaining retries drop
        # (3 -> 0.5, 2 -> 1.0, 1 -> 2.0) with max_retries=3.
        [3, "", 0.5],
        [2, "", 0.5 * 2.0],
        [1, "", 0.5 * 4.0],
    ],
)
# 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, the "now" the date rows above are relative to.
@mock.patch("time.time", mock.MagicMock(return_value=1696004797))
@pytest.mark.asyncio
async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
    """Check the retry delay computed from a `retry-after` header and the remaining retry count."""
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    headers = httpx.Headers({"retry-after": retry_after})
    options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
    calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
    # The second positional arg of pytest.approx is `rel` (0.4375 here). The loose
    # tolerance is presumably to absorb random jitter in the computed delay --
    # confirm against `_calculate_retry_timeout`.
    assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]