openai/openai-python

Public

mirrored fromhttps://github.com/openai/openai-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
v1.91.0

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1887lines · modecode

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import time
10import asyncio
11import inspect
12import subprocess
13import tracemalloc
14from typing import Any, Union, cast
15from textwrap import dedent
16from unittest import mock
17from typing_extensions import Literal
18
19import httpx
20import pytest
21from respx import MockRouter
22from pydantic import ValidationError
23
24from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
25from openai._types import Omit
26from openai._models import BaseModel, FinalRequestOptions
27from openai._streaming import Stream, AsyncStream
28from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
29from openai._base_client import (
30 DEFAULT_TIMEOUT,
31 HTTPX_DEFAULT_TIMEOUT,
32 BaseClient,
33 DefaultHttpxClient,
34 DefaultAsyncHttpxClient,
35 make_request_options,
36)
37
38from .utils import update_env
39
# Shared credentials/endpoint for every client constructed in this module.
# The mock server address can be redirected via TEST_API_BASE_URL.
api_key = "My API Key"
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
42
43
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters *client* would attach to a plain ``GET /foo`` request.

    Only builds the request (no network I/O); used to inspect ``default_query`` handling.
    """
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
48
49
def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
    """Replacement for ``BaseClient._calculate_retry_timeout`` used via ``mock.patch``.

    Ignores all arguments and always yields a 0.1 second delay so retry tests run quickly.
    """
    fixed_delay = 0.1
    return fixed_delay
52
53
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Count the requests currently tracked by *client*'s underlying connection pool.

    Asserts that the client uses a stock httpx transport (sync or async), since only those
    expose the ``_pool._requests`` internals read here.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
60
61
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    HTTP traffic is intercepted with respx (the ``respx_mock`` fixture), so no real server is hit.
    ``client`` below is a shared strict-validation instance reached via ``self.client``; tests that
    take a ``client`` *argument* instead receive one injected by pytest (presumably a conftest
    fixture -- it is not defined in this file). Those two objects are different clients.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` hands back the ``httpx.Response`` itself."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """Same as ``test_raw_response`` when the server labels the body ``application/binary``."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a distinct client and leaves the original's options untouched."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options passed to ``copy()`` override the source client's values without mutating it."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merges into existing headers; ``set_default_headers`` replaces them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into existing params; ``set_default_query`` replaces them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every ``__init__`` parameter, except a few construction-only ones, is accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory.

        Allocations are compared with ``tracemalloc`` before and after ``ITERATIONS`` rounds;
        only diffs whose count is a non-zero multiple of ``ITERATIONS`` (i.e. that grow once per
        round) and that don't originate in the known-noisy files listed below count as leaks.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    def test_request_timeout(self) -> None:
        """Requests carry ``DEFAULT_TIMEOUT`` unless the request options specify their own."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A ``timeout`` passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """How a user-supplied ``http_client``'s timeout interacts with the SDK default."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    async def test_invalid_http_client(self) -> None:
        """Passing an ``httpx.AsyncClient`` to the sync client raises ``TypeError``."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent and may override SDK-provided headers such as ``X-Stainless-Lang``."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key becomes a Bearer token; with no key and no ``OPENAI_API_KEY`` env var, construction fails."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` params are sent and per-request ``params`` override them on clashes."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body, winning over ``json_data`` on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over the client's ``default_headers``."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is sent, merged with ``query``, and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """List values in a multipart body are encoded as one ``name[]`` part per element, before the files."""
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A ``Union`` cast target resolves to the member whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """``base_url`` gains a trailing slash both at construction and when reassigned."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """Without an explicit ``base_url`` the client reads ``OPENAI_BASE_URL`` from the environment."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base URL ending in ``/`` joins with a leading-slash path without doubling the slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # No trailing slash here: this is the case under test, and it must differ from
            # `test_base_url_trailing_slash` above.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL without a trailing ``/`` keeps its last path segment when joined with a path."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL bypasses the client's ``base_url`` entirely."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Discarding a copy must not close the original client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Using the client as a context manager yields itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that doesn't fit the model raises ``APIResponseValidationError`` chained from pydantic."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected at construction."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with ``stream_cls`` returns a ``Stream`` instance."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Non-JSON body: a strict client raises, a non-strict client returns the raw text."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    # time.time is pinned (1696004797 == Fri, 29 Sep 2023 16:26:37 GMT) so HTTP-date values
    # above resolve to fixed second offsets.
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``Retry-After`` values between 1 and 60 seconds (integer or HTTP-date) are honoured.

        Anything else (zero, negative, too large, unparseable, missing) falls back to a delay that
        grows as ``remaining_retries`` shrinks; per the cases above it does not exceed about 8s.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        """Exhausting retries on timeouts must leave no pending requests in the sending client's pool."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()

        # Inspect the client that actually issued the request; `self.client` is a separate,
        # unused instance whose pool would trivially be empty.
        assert _get_open_connections(client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, client: OpenAI) -> None:
        """Exhausting retries on 5xx responses must leave no pending requests in the sending client's pool."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            client.chat.completions.with_streaming_response.create(
                messages=[
                    {
                        "content": "string",
                        "role": "developer",
                    }
                ],
                model="gpt-4o",
            ).__enter__()

        # Inspect the client that actually issued the request; `self.client` is a separate,
        # unused instance whose pool would trivially be empty.
        assert _get_open_connections(client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """``retries_taken`` and the ``x-stainless-retry-count`` header report the failed attempts.

        Failures are simulated either as 500 responses or as exceptions raised by the transport.
        """
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for ``x-stainless-retry-count`` removes the header from every attempt."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """A caller-supplied ``x-stainless-retry-count`` is sent instead of the computed value."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Same as ``test_retries_taken`` (status failures only), via ``with_streaming_response``."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """With ``HTTPS_PROXY`` set, ``DefaultHttpxClient`` mounts exactly one transport, for ``https://``."""
        # Test that the proxy environment variables are set correctly
        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

        client = DefaultHttpxClient()

        mounts = tuple(client._mounts.items())
        assert len(mounts) == 1
        assert mounts[0][0].pattern == "https://"

    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
    def test_default_client_creation(self) -> None:
        """``DefaultHttpxClient`` accepts the usual ``httpx.Client`` keyword arguments."""
        # Ensure that the client can be initialized without any exceptions
        DefaultHttpxClient(
            verify=True,
            cert=None,
            trust_env=True,
            http1=True,
            http2=False,
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects(self, respx_mock: MockRouter) -> None:
        """By default a 302 is followed and the redirected response is returned."""
        # Test that the default follow_redirects=True allows following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )
        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

        response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
        """With ``follow_redirects=False`` a 302 surfaces as ``APIStatusError`` carrying the redirect."""
        # Test that follow_redirects=False prevents following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )

        with pytest.raises(APIStatusError) as exc_info:
            self.client.post(
                "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
            )

        assert exc_info.value.response.status_code == 302
        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
940
941
942class TestAsyncOpenAI:
943 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
944
945 @pytest.mark.respx(base_url=base_url)
946 @pytest.mark.asyncio
947 async def test_raw_response(self, respx_mock: MockRouter) -> None:
948 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
949
950 response = await self.client.post("/foo", cast_to=httpx.Response)
951 assert response.status_code == 200
952 assert isinstance(response, httpx.Response)
953 assert response.json() == {"foo": "bar"}
954
955 @pytest.mark.respx(base_url=base_url)
956 @pytest.mark.asyncio
957 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
958 respx_mock.post("/foo").mock(
959 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
960 )
961
962 response = await self.client.post("/foo", cast_to=httpx.Response)
963 assert response.status_code == 200
964 assert isinstance(response, httpx.Response)
965 assert response.json() == {"foo": "bar"}
966
967 def test_copy(self) -> None:
968 copied = self.client.copy()
969 assert id(copied) != id(self.client)
970
971 copied = self.client.copy(api_key="another My API Key")
972 assert copied.api_key == "another My API Key"
973 assert self.client.api_key == "My API Key"
974
975 def test_copy_default_options(self) -> None:
976 # options that have a default are overridden correctly
977 copied = self.client.copy(max_retries=7)
978 assert copied.max_retries == 7
979 assert self.client.max_retries == 2
980
981 copied2 = copied.copy(max_retries=6)
982 assert copied2.max_retries == 6
983 assert copied.max_retries == 7
984
985 # timeout
986 assert isinstance(self.client.timeout, httpx.Timeout)
987 copied = self.client.copy(timeout=None)
988 assert copied.timeout is None
989 assert isinstance(self.client.timeout, httpx.Timeout)
990
991 def test_copy_default_headers(self) -> None:
992 client = AsyncOpenAI(
993 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
994 )
995 assert client.default_headers["X-Foo"] == "bar"
996
997 # does not override the already given value when not specified
998 copied = client.copy()
999 assert copied.default_headers["X-Foo"] == "bar"
1000
1001 # merges already given headers
1002 copied = client.copy(default_headers={"X-Bar": "stainless"})
1003 assert copied.default_headers["X-Foo"] == "bar"
1004 assert copied.default_headers["X-Bar"] == "stainless"
1005
1006 # uses new values for any already given headers
1007 copied = client.copy(default_headers={"X-Foo": "stainless"})
1008 assert copied.default_headers["X-Foo"] == "stainless"
1009
1010 # set_default_headers
1011
1012 # completely overrides already set values
1013 copied = client.copy(set_default_headers={})
1014 assert copied.default_headers.get("X-Foo") is None
1015
1016 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
1017 assert copied.default_headers["X-Bar"] == "Robert"
1018
1019 with pytest.raises(
1020 ValueError,
1021 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
1022 ):
1023 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1024
1025 def test_copy_default_query(self) -> None:
1026 client = AsyncOpenAI(
1027 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
1028 )
1029 assert _get_params(client)["foo"] == "bar"
1030
1031 # does not override the already given value when not specified
1032 copied = client.copy()
1033 assert _get_params(copied)["foo"] == "bar"
1034
1035 # merges already given params
1036 copied = client.copy(default_query={"bar": "stainless"})
1037 params = _get_params(copied)
1038 assert params["foo"] == "bar"
1039 assert params["bar"] == "stainless"
1040
1041 # uses new values for any already given headers
1042 copied = client.copy(default_query={"foo": "stainless"})
1043 assert _get_params(copied)["foo"] == "stainless"
1044
1045 # set_default_query
1046
1047 # completely overrides already set values
1048 copied = client.copy(set_default_query={})
1049 assert _get_params(copied) == {}
1050
1051 copied = client.copy(set_default_query={"bar": "Robert"})
1052 assert _get_params(copied)["bar"] == "Robert"
1053
1054 with pytest.raises(
1055 ValueError,
1056 # TODO: update
1057 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1058 ):
1059 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1060
1061 def test_copy_signature(self) -> None:
1062 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1063 init_signature = inspect.signature(
1064 # mypy doesn't like that we access the `__init__` property.
1065 self.client.__init__, # type: ignore[misc]
1066 )
1067 copy_signature = inspect.signature(self.client.copy)
1068 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1069
1070 for name in init_signature.parameters.keys():
1071 if name in exclude_params:
1072 continue
1073
1074 copy_param = copy_signature.parameters.get(name)
1075 assert copy_param is not None, f"copy() signature is missing the {name} param"
1076
1077 def test_copy_build_request(self) -> None:
1078 options = FinalRequestOptions(method="get", url="/foo")
1079
1080 def build_request(options: FinalRequestOptions) -> None:
1081 client = self.client.copy()
1082 client._build_request(options)
1083
1084 # ensure that the machinery is warmed up before tracing starts.
1085 build_request(options)
1086 gc.collect()
1087
1088 tracemalloc.start(1000)
1089
1090 snapshot_before = tracemalloc.take_snapshot()
1091
1092 ITERATIONS = 10
1093 for _ in range(ITERATIONS):
1094 build_request(options)
1095
1096 gc.collect()
1097 snapshot_after = tracemalloc.take_snapshot()
1098
1099 tracemalloc.stop()
1100
1101 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
1102 if diff.count == 0:
1103 # Avoid false positives by considering only leaks (i.e. allocations that persist).
1104 return
1105
1106 if diff.count % ITERATIONS != 0:
1107 # Avoid false positives by considering only leaks that appear per iteration.
1108 return
1109
1110 for frame in diff.traceback:
1111 if any(
1112 frame.filename.endswith(fragment)
1113 for fragment in [
1114 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
1115 #
1116 # removing the decorator fixes the leak for reasons we don't understand.
1117 "openai/_legacy_response.py",
1118 "openai/_response.py",
1119 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
1120 "openai/_compat.py",
1121 # Standard library leaks we don't care about.
1122 "/logging/__init__.py",
1123 ]
1124 ):
1125 return
1126
1127 leaks.append(diff)
1128
1129 leaks: list[tracemalloc.StatisticDiff] = []
1130 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
1131 add_leak(leaks, diff)
1132 if leaks:
1133 for leak in leaks:
1134 print("MEMORY LEAK:", leak)
1135 for frame in leak.traceback:
1136 print(frame)
1137 raise AssertionError()
1138
1139 async def test_request_timeout(self) -> None:
1140 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
1141 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1142 assert timeout == DEFAULT_TIMEOUT
1143
1144 request = self.client._build_request(
1145 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1146 )
1147 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1148 assert timeout == httpx.Timeout(100.0)
1149
1150 async def test_client_timeout_option(self) -> None:
1151 client = AsyncOpenAI(
1152 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
1153 )
1154
1155 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1156 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1157 assert timeout == httpx.Timeout(0)
1158
1159 async def test_http_client_timeout_option(self) -> None:
1160 # custom timeout given to the httpx client should be used
1161 async with httpx.AsyncClient(timeout=None) as http_client:
1162 client = AsyncOpenAI(
1163 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1164 )
1165
1166 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1167 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1168 assert timeout == httpx.Timeout(None)
1169
1170 # no timeout given to the httpx client should not use the httpx default
1171 async with httpx.AsyncClient() as http_client:
1172 client = AsyncOpenAI(
1173 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1174 )
1175
1176 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1177 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1178 assert timeout == DEFAULT_TIMEOUT
1179
1180 # explicitly passing the default timeout currently results in it being ignored
1181 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1182 client = AsyncOpenAI(
1183 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1184 )
1185
1186 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1187 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1188 assert timeout == DEFAULT_TIMEOUT # our default
1189
1190 def test_invalid_http_client(self) -> None:
1191 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1192 with httpx.Client() as http_client:
1193 AsyncOpenAI(
1194 base_url=base_url,
1195 api_key=api_key,
1196 _strict_response_validation=True,
1197 http_client=cast(Any, http_client),
1198 )
1199
1200 def test_default_headers_option(self) -> None:
1201 client = AsyncOpenAI(
1202 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1203 )
1204 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1205 assert request.headers.get("x-foo") == "bar"
1206 assert request.headers.get("x-stainless-lang") == "python"
1207
1208 client2 = AsyncOpenAI(
1209 base_url=base_url,
1210 api_key=api_key,
1211 _strict_response_validation=True,
1212 default_headers={
1213 "X-Foo": "stainless",
1214 "X-Stainless-Lang": "my-overriding-header",
1215 },
1216 )
1217 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1218 assert request.headers.get("x-foo") == "stainless"
1219 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1220
1221 def test_validate_headers(self) -> None:
1222 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1223 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1224 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1225
1226 with pytest.raises(OpenAIError):
1227 with update_env(**{"OPENAI_API_KEY": Omit()}):
1228 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1229 _ = client2
1230
1231 def test_default_query_option(self) -> None:
1232 client = AsyncOpenAI(
1233 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1234 )
1235 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1236 url = httpx.URL(request.url)
1237 assert dict(url.params) == {"query_param": "bar"}
1238
1239 request = client._build_request(
1240 FinalRequestOptions(
1241 method="get",
1242 url="/foo",
1243 params={"foo": "baz", "query_param": "overridden"},
1244 )
1245 )
1246 url = httpx.URL(request.url)
1247 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1248
1249 def test_request_extra_json(self) -> None:
1250 request = self.client._build_request(
1251 FinalRequestOptions(
1252 method="post",
1253 url="/foo",
1254 json_data={"foo": "bar"},
1255 extra_json={"baz": False},
1256 ),
1257 )
1258 data = json.loads(request.content.decode("utf-8"))
1259 assert data == {"foo": "bar", "baz": False}
1260
1261 request = self.client._build_request(
1262 FinalRequestOptions(
1263 method="post",
1264 url="/foo",
1265 extra_json={"baz": False},
1266 ),
1267 )
1268 data = json.loads(request.content.decode("utf-8"))
1269 assert data == {"baz": False}
1270
1271 # `extra_json` takes priority over `json_data` when keys clash
1272 request = self.client._build_request(
1273 FinalRequestOptions(
1274 method="post",
1275 url="/foo",
1276 json_data={"foo": "bar", "baz": True},
1277 extra_json={"baz": None},
1278 ),
1279 )
1280 data = json.loads(request.content.decode("utf-8"))
1281 assert data == {"foo": "bar", "baz": None}
1282
1283 def test_request_extra_headers(self) -> None:
1284 request = self.client._build_request(
1285 FinalRequestOptions(
1286 method="post",
1287 url="/foo",
1288 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1289 ),
1290 )
1291 assert request.headers.get("X-Foo") == "Foo"
1292
1293 # `extra_headers` takes priority over `default_headers` when keys clash
1294 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
1295 FinalRequestOptions(
1296 method="post",
1297 url="/foo",
1298 **make_request_options(
1299 extra_headers={"X-Bar": "false"},
1300 ),
1301 ),
1302 )
1303 assert request.headers.get("X-Bar") == "false"
1304
1305 def test_request_extra_query(self) -> None:
1306 request = self.client._build_request(
1307 FinalRequestOptions(
1308 method="post",
1309 url="/foo",
1310 **make_request_options(
1311 extra_query={"my_query_param": "Foo"},
1312 ),
1313 ),
1314 )
1315 params = dict(request.url.params)
1316 assert params == {"my_query_param": "Foo"}
1317
1318 # if both `query` and `extra_query` are given, they are merged
1319 request = self.client._build_request(
1320 FinalRequestOptions(
1321 method="post",
1322 url="/foo",
1323 **make_request_options(
1324 query={"bar": "1"},
1325 extra_query={"foo": "2"},
1326 ),
1327 ),
1328 )
1329 params = dict(request.url.params)
1330 assert params == {"bar": "1", "foo": "2"}
1331
1332 # `extra_query` takes priority over `query` when keys clash
1333 request = self.client._build_request(
1334 FinalRequestOptions(
1335 method="post",
1336 url="/foo",
1337 **make_request_options(
1338 query={"foo": "1"},
1339 extra_query={"foo": "2"},
1340 ),
1341 ),
1342 )
1343 params = dict(request.url.params)
1344 assert params == {"foo": "2"}
1345
1346 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1347 request = async_client._build_request(
1348 FinalRequestOptions.construct(
1349 method="get",
1350 url="/foo",
1351 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1352 json_data={"array": ["foo", "bar"]},
1353 files=[("foo.txt", b"hello world")],
1354 )
1355 )
1356
1357 assert request.read().split(b"\r\n") == [
1358 b"--6b7ba517decee4a450543ea6ae821c82",
1359 b'Content-Disposition: form-data; name="array[]"',
1360 b"",
1361 b"foo",
1362 b"--6b7ba517decee4a450543ea6ae821c82",
1363 b'Content-Disposition: form-data; name="array[]"',
1364 b"",
1365 b"bar",
1366 b"--6b7ba517decee4a450543ea6ae821c82",
1367 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
1368 b"Content-Type: application/octet-stream",
1369 b"",
1370 b"hello world",
1371 b"--6b7ba517decee4a450543ea6ae821c82--",
1372 b"",
1373 ]
1374
1375 @pytest.mark.respx(base_url=base_url)
1376 async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
1377 class Model1(BaseModel):
1378 name: str
1379
1380 class Model2(BaseModel):
1381 foo: str
1382
1383 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1384
1385 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1386 assert isinstance(response, Model2)
1387 assert response.foo == "bar"
1388
1389 @pytest.mark.respx(base_url=base_url)
1390 async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
1391 """Union of objects with the same field name using a different type"""
1392
1393 class Model1(BaseModel):
1394 foo: int
1395
1396 class Model2(BaseModel):
1397 foo: str
1398
1399 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1400
1401 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1402 assert isinstance(response, Model2)
1403 assert response.foo == "bar"
1404
1405 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1406
1407 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1408 assert isinstance(response, Model1)
1409 assert response.foo == 1
1410
1411 @pytest.mark.respx(base_url=base_url)
1412 async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
1413 """
1414 Response that sets Content-Type to something other than application/json but returns json data
1415 """
1416
1417 class Model(BaseModel):
1418 foo: int
1419
1420 respx_mock.get("/foo").mock(
1421 return_value=httpx.Response(
1422 200,
1423 content=json.dumps({"foo": 2}),
1424 headers={"Content-Type": "application/text"},
1425 )
1426 )
1427
1428 response = await self.client.get("/foo", cast_to=Model)
1429 assert isinstance(response, Model)
1430 assert response.foo == 2
1431
1432 def test_base_url_setter(self) -> None:
1433 client = AsyncOpenAI(
1434 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1435 )
1436 assert client.base_url == "https://example.com/from_init/"
1437
1438 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1439
1440 assert client.base_url == "https://example.com/from_setter/"
1441
1442 def test_base_url_env(self) -> None:
1443 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1444 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1445 assert client.base_url == "http://localhost:5000/from/env/"
1446
1447 @pytest.mark.parametrize(
1448 "client",
1449 [
1450 AsyncOpenAI(
1451 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1452 ),
1453 AsyncOpenAI(
1454 base_url="http://localhost:5000/custom/path/",
1455 api_key=api_key,
1456 _strict_response_validation=True,
1457 http_client=httpx.AsyncClient(),
1458 ),
1459 ],
1460 ids=["standard", "custom http client"],
1461 )
1462 def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1463 request = client._build_request(
1464 FinalRequestOptions(
1465 method="post",
1466 url="/foo",
1467 json_data={"foo": "bar"},
1468 ),
1469 )
1470 assert request.url == "http://localhost:5000/custom/path/foo"
1471
1472 @pytest.mark.parametrize(
1473 "client",
1474 [
1475 AsyncOpenAI(
1476 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1477 ),
1478 AsyncOpenAI(
1479 base_url="http://localhost:5000/custom/path/",
1480 api_key=api_key,
1481 _strict_response_validation=True,
1482 http_client=httpx.AsyncClient(),
1483 ),
1484 ],
1485 ids=["standard", "custom http client"],
1486 )
1487 def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1488 request = client._build_request(
1489 FinalRequestOptions(
1490 method="post",
1491 url="/foo",
1492 json_data={"foo": "bar"},
1493 ),
1494 )
1495 assert request.url == "http://localhost:5000/custom/path/foo"
1496
1497 @pytest.mark.parametrize(
1498 "client",
1499 [
1500 AsyncOpenAI(
1501 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1502 ),
1503 AsyncOpenAI(
1504 base_url="http://localhost:5000/custom/path/",
1505 api_key=api_key,
1506 _strict_response_validation=True,
1507 http_client=httpx.AsyncClient(),
1508 ),
1509 ],
1510 ids=["standard", "custom http client"],
1511 )
1512 def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1513 request = client._build_request(
1514 FinalRequestOptions(
1515 method="post",
1516 url="https://myapi.com/foo",
1517 json_data={"foo": "bar"},
1518 ),
1519 )
1520 assert request.url == "https://myapi.com/foo"
1521
1522 async def test_copied_client_does_not_close_http(self) -> None:
1523 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1524 assert not client.is_closed()
1525
1526 copied = client.copy()
1527 assert copied is not client
1528
1529 del copied
1530
1531 await asyncio.sleep(0.2)
1532 assert not client.is_closed()
1533
1534 async def test_client_context_manager(self) -> None:
1535 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1536 async with client as c2:
1537 assert c2 is client
1538 assert not c2.is_closed()
1539 assert not client.is_closed()
1540 assert client.is_closed()
1541
1542 @pytest.mark.respx(base_url=base_url)
1543 @pytest.mark.asyncio
1544 async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
1545 class Model(BaseModel):
1546 foo: str
1547
1548 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1549
1550 with pytest.raises(APIResponseValidationError) as exc:
1551 await self.client.get("/foo", cast_to=Model)
1552
1553 assert isinstance(exc.value.__cause__, ValidationError)
1554
1555 async def test_client_max_retries_validation(self) -> None:
1556 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1557 AsyncOpenAI(
1558 base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
1559 )
1560
1561 @pytest.mark.respx(base_url=base_url)
1562 @pytest.mark.asyncio
1563 async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
1564 class Model(BaseModel):
1565 name: str
1566
1567 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1568
1569 stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1570 assert isinstance(stream, AsyncStream)
1571 await stream.response.aclose()
1572
1573 @pytest.mark.respx(base_url=base_url)
1574 @pytest.mark.asyncio
1575 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1576 class Model(BaseModel):
1577 name: str
1578
1579 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1580
1581 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1582
1583 with pytest.raises(APIResponseValidationError):
1584 await strict_client.get("/foo", cast_to=Model)
1585
1586 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1587
1588 response = await client.get("/foo", cast_to=Model)
1589 assert isinstance(response, str) # type: ignore[unreachable]
1590
1591 @pytest.mark.parametrize(
1592 "remaining_retries,retry_after,timeout",
1593 [
1594 [3, "20", 20],
1595 [3, "0", 0.5],
1596 [3, "-10", 0.5],
1597 [3, "60", 60],
1598 [3, "61", 0.5],
1599 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1600 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1601 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1602 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1603 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1604 [3, "99999999999999999999999999999999999", 0.5],
1605 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1606 [3, "", 0.5],
1607 [2, "", 0.5 * 2.0],
1608 [1, "", 0.5 * 4.0],
1609 [-1100, "", 8], # test large number potentially overflowing
1610 ],
1611 )
1612 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1613 @pytest.mark.asyncio
1614 async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
1615 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1616
1617 headers = httpx.Headers({"retry-after": retry_after})
1618 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1619 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
1620 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1621
1622 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1623 @pytest.mark.respx(base_url=base_url)
1624 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1625 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1626
1627 with pytest.raises(APITimeoutError):
1628 await async_client.chat.completions.with_streaming_response.create(
1629 messages=[
1630 {
1631 "content": "string",
1632 "role": "developer",
1633 }
1634 ],
1635 model="gpt-4o",
1636 ).__aenter__()
1637
1638 assert _get_open_connections(self.client) == 0
1639
1640 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1641 @pytest.mark.respx(base_url=base_url)
1642 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter, async_client: AsyncOpenAI) -> None:
1643 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1644
1645 with pytest.raises(APIStatusError):
1646 await async_client.chat.completions.with_streaming_response.create(
1647 messages=[
1648 {
1649 "content": "string",
1650 "role": "developer",
1651 }
1652 ],
1653 model="gpt-4o",
1654 ).__aenter__()
1655 assert _get_open_connections(self.client) == 0
1656
1657 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1658 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1659 @pytest.mark.respx(base_url=base_url)
1660 @pytest.mark.asyncio
1661 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
1662 async def test_retries_taken(
1663 self,
1664 async_client: AsyncOpenAI,
1665 failures_before_success: int,
1666 failure_mode: Literal["status", "exception"],
1667 respx_mock: MockRouter,
1668 ) -> None:
1669 client = async_client.with_options(max_retries=4)
1670
1671 nb_retries = 0
1672
1673 def retry_handler(_request: httpx.Request) -> httpx.Response:
1674 nonlocal nb_retries
1675 if nb_retries < failures_before_success:
1676 nb_retries += 1
1677 if failure_mode == "exception":
1678 raise RuntimeError("oops")
1679 return httpx.Response(500)
1680 return httpx.Response(200)
1681
1682 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1683
1684 response = await client.chat.completions.with_raw_response.create(
1685 messages=[
1686 {
1687 "content": "string",
1688 "role": "developer",
1689 }
1690 ],
1691 model="gpt-4o",
1692 )
1693
1694 assert response.retries_taken == failures_before_success
1695 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1696
1697 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1698 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1699 @pytest.mark.respx(base_url=base_url)
1700 @pytest.mark.asyncio
1701 async def test_omit_retry_count_header(
1702 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1703 ) -> None:
1704 client = async_client.with_options(max_retries=4)
1705
1706 nb_retries = 0
1707
1708 def retry_handler(_request: httpx.Request) -> httpx.Response:
1709 nonlocal nb_retries
1710 if nb_retries < failures_before_success:
1711 nb_retries += 1
1712 return httpx.Response(500)
1713 return httpx.Response(200)
1714
1715 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1716
1717 response = await client.chat.completions.with_raw_response.create(
1718 messages=[
1719 {
1720 "content": "string",
1721 "role": "developer",
1722 }
1723 ],
1724 model="gpt-4o",
1725 extra_headers={"x-stainless-retry-count": Omit()},
1726 )
1727
1728 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
1729
1730 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1731 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1732 @pytest.mark.respx(base_url=base_url)
1733 @pytest.mark.asyncio
1734 async def test_overwrite_retry_count_header(
1735 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1736 ) -> None:
1737 client = async_client.with_options(max_retries=4)
1738
1739 nb_retries = 0
1740
1741 def retry_handler(_request: httpx.Request) -> httpx.Response:
1742 nonlocal nb_retries
1743 if nb_retries < failures_before_success:
1744 nb_retries += 1
1745 return httpx.Response(500)
1746 return httpx.Response(200)
1747
1748 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1749
1750 response = await client.chat.completions.with_raw_response.create(
1751 messages=[
1752 {
1753 "content": "string",
1754 "role": "developer",
1755 }
1756 ],
1757 model="gpt-4o",
1758 extra_headers={"x-stainless-retry-count": "42"},
1759 )
1760
1761 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1762
1763 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1764 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1765 @pytest.mark.respx(base_url=base_url)
1766 @pytest.mark.asyncio
1767 async def test_retries_taken_new_response_class(
1768 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1769 ) -> None:
1770 client = async_client.with_options(max_retries=4)
1771
1772 nb_retries = 0
1773
1774 def retry_handler(_request: httpx.Request) -> httpx.Response:
1775 nonlocal nb_retries
1776 if nb_retries < failures_before_success:
1777 nb_retries += 1
1778 return httpx.Response(500)
1779 return httpx.Response(200)
1780
1781 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1782
1783 async with client.chat.completions.with_streaming_response.create(
1784 messages=[
1785 {
1786 "content": "string",
1787 "role": "developer",
1788 }
1789 ],
1790 model="gpt-4o",
1791 ) as response:
1792 assert response.retries_taken == failures_before_success
1793 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1794
1795 def test_get_platform(self) -> None:
1796 # A previous implementation of asyncify could leave threads unterminated when
1797 # used with nest_asyncio.
1798 #
1799 # Since nest_asyncio.apply() is global and cannot be un-applied, this
1800 # test is run in a separate process to avoid affecting other tests.
1801 test_code = dedent("""
1802 import asyncio
1803 import nest_asyncio
1804 import threading
1805
1806 from openai._utils import asyncify
1807 from openai._base_client import get_platform
1808
1809 async def test_main() -> None:
1810 result = await asyncify(get_platform)()
1811 print(result)
1812 for thread in threading.enumerate():
1813 print(thread.name)
1814
1815 nest_asyncio.apply()
1816 asyncio.run(test_main())
1817 """)
1818 with subprocess.Popen(
1819 [sys.executable, "-c", test_code],
1820 text=True,
1821 ) as process:
1822 timeout = 10 # seconds
1823
1824 start_time = time.monotonic()
1825 while True:
1826 return_code = process.poll()
1827 if return_code is not None:
1828 if return_code != 0:
1829 raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
1830
1831 # success
1832 break
1833
1834 if time.monotonic() - start_time > timeout:
1835 process.kill()
1836 raise AssertionError("calling get_platform using asyncify resulted in a hung process")
1837
1838 time.sleep(0.1)
1839
1840 async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
1841 # Test that the proxy environment variables are set correctly
1842 monkeypatch.setenv("HTTPS_PROXY", "https://example.org")
1843
1844 client = DefaultAsyncHttpxClient()
1845
1846 mounts = tuple(client._mounts.items())
1847 assert len(mounts) == 1
1848 assert mounts[0][0].pattern == "https://"
1849
1850 @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
1851 async def test_default_client_creation(self) -> None:
1852 # Ensure that the client can be initialized without any exceptions
1853 DefaultAsyncHttpxClient(
1854 verify=True,
1855 cert=None,
1856 trust_env=True,
1857 http1=True,
1858 http2=False,
1859 limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
1860 )
1861
1862 @pytest.mark.respx(base_url=base_url)
1863 async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
1864 # Test that the default follow_redirects=True allows following redirects
1865 respx_mock.post("/redirect").mock(
1866 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1867 )
1868 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
1869
1870 response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
1871 assert response.status_code == 200
1872 assert response.json() == {"status": "ok"}
1873
1874 @pytest.mark.respx(base_url=base_url)
1875 async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
1876 # Test that follow_redirects=False prevents following redirects
1877 respx_mock.post("/redirect").mock(
1878 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
1879 )
1880
1881 with pytest.raises(APIStatusError) as exc_info:
1882 await self.client.post(
1883 "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
1884 )
1885
1886 assert exc_info.value.response.status_code == 302
1887 assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
1888