openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.83.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1885 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import time
10import asyncio
11import inspect
12import subprocess
13import tracemalloc
14from typing import Any, Union, cast
15from textwrap import dedent
16from unittest import mock
17from typing_extensions import Literal
18
19import httpx
20import pytest
21from respx import MockRouter
22from pydantic import ValidationError
23
24from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
25from openai._types import Omit
26from openai._utils import maybe_transform
27from openai._models import BaseModel, FinalRequestOptions
28from openai._constants import RAW_RESPONSE_HEADER
29from openai._streaming import Stream, AsyncStream
30from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
31from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
32from openai.types.chat.completion_create_params import CompletionCreateParamsNonStreaming
33
34from .utils import update_env
35
# Base URL the tests point clients at; the TEST_API_BASE_URL environment variable overrides the default.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential handed to the clients constructed in these tests.
api_key = "My API Key"
38
39
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a GET `/foo` request with `client` and return its query parameters as a dict."""
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    return dict(httpx.URL(built.url).params)
44
45
46def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
47 return 0.1
48
49
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of entries in the `_requests` list of the client's transport pool.

    Reaches into private httpx attributes, so the transport must be one of
    httpx's own HTTP transports.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    return len(transport._pool._requests)
56
57
class TestOpenAI:
    """Tests exercising the synchronous `OpenAI` client."""

    # Client shared by the tests in this class; strict response validation is switched on.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
60
61 @pytest.mark.respx(base_url=base_url)
62 def test_raw_response(self, respx_mock: MockRouter) -> None:
63 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
64
65 response = self.client.post("/foo", cast_to=httpx.Response)
66 assert response.status_code == 200
67 assert isinstance(response, httpx.Response)
68 assert response.json() == {"foo": "bar"}
69
70 @pytest.mark.respx(base_url=base_url)
71 def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
72 respx_mock.post("/foo").mock(
73 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
74 )
75
76 response = self.client.post("/foo", cast_to=httpx.Response)
77 assert response.status_code == 200
78 assert isinstance(response, httpx.Response)
79 assert response.json() == {"foo": "bar"}
80
81 def test_copy(self) -> None:
82 copied = self.client.copy()
83 assert id(copied) != id(self.client)
84
85 copied = self.client.copy(api_key="another My API Key")
86 assert copied.api_key == "another My API Key"
87 assert self.client.api_key == "My API Key"
88
89 def test_copy_default_options(self) -> None:
90 # options that have a default are overridden correctly
91 copied = self.client.copy(max_retries=7)
92 assert copied.max_retries == 7
93 assert self.client.max_retries == 2
94
95 copied2 = copied.copy(max_retries=6)
96 assert copied2.max_retries == 6
97 assert copied.max_retries == 7
98
99 # timeout
100 assert isinstance(self.client.timeout, httpx.Timeout)
101 copied = self.client.copy(timeout=None)
102 assert copied.timeout is None
103 assert isinstance(self.client.timeout, httpx.Timeout)
104
105 def test_copy_default_headers(self) -> None:
106 client = OpenAI(
107 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
108 )
109 assert client.default_headers["X-Foo"] == "bar"
110
111 # does not override the already given value when not specified
112 copied = client.copy()
113 assert copied.default_headers["X-Foo"] == "bar"
114
115 # merges already given headers
116 copied = client.copy(default_headers={"X-Bar": "stainless"})
117 assert copied.default_headers["X-Foo"] == "bar"
118 assert copied.default_headers["X-Bar"] == "stainless"
119
120 # uses new values for any already given headers
121 copied = client.copy(default_headers={"X-Foo": "stainless"})
122 assert copied.default_headers["X-Foo"] == "stainless"
123
124 # set_default_headers
125
126 # completely overrides already set values
127 copied = client.copy(set_default_headers={})
128 assert copied.default_headers.get("X-Foo") is None
129
130 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
131 assert copied.default_headers["X-Bar"] == "Robert"
132
133 with pytest.raises(
134 ValueError,
135 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
136 ):
137 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
138
139 def test_copy_default_query(self) -> None:
140 client = OpenAI(
141 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
142 )
143 assert _get_params(client)["foo"] == "bar"
144
145 # does not override the already given value when not specified
146 copied = client.copy()
147 assert _get_params(copied)["foo"] == "bar"
148
149 # merges already given params
150 copied = client.copy(default_query={"bar": "stainless"})
151 params = _get_params(copied)
152 assert params["foo"] == "bar"
153 assert params["bar"] == "stainless"
154
155 # uses new values for any already given headers
156 copied = client.copy(default_query={"foo": "stainless"})
157 assert _get_params(copied)["foo"] == "stainless"
158
159 # set_default_query
160
161 # completely overrides already set values
162 copied = client.copy(set_default_query={})
163 assert _get_params(copied) == {}
164
165 copied = client.copy(set_default_query={"bar": "Robert"})
166 assert _get_params(copied)["bar"] == "Robert"
167
168 with pytest.raises(
169 ValueError,
170 # TODO: update
171 match="`default_query` and `set_default_query` arguments are mutually exclusive",
172 ):
173 client.copy(set_default_query={}, default_query={"foo": "Bar"})
174
175 def test_copy_signature(self) -> None:
176 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
177 init_signature = inspect.signature(
178 # mypy doesn't like that we access the `__init__` property.
179 self.client.__init__, # type: ignore[misc]
180 )
181 copy_signature = inspect.signature(self.client.copy)
182 exclude_params = {"transport", "proxies", "_strict_response_validation"}
183
184 for name in init_signature.parameters.keys():
185 if name in exclude_params:
186 continue
187
188 copy_param = copy_signature.parameters.get(name)
189 assert copy_param is not None, f"copy() signature is missing the {name} param"
190
191 def test_copy_build_request(self) -> None:
192 options = FinalRequestOptions(method="get", url="/foo")
193
194 def build_request(options: FinalRequestOptions) -> None:
195 client = self.client.copy()
196 client._build_request(options)
197
198 # ensure that the machinery is warmed up before tracing starts.
199 build_request(options)
200 gc.collect()
201
202 tracemalloc.start(1000)
203
204 snapshot_before = tracemalloc.take_snapshot()
205
206 ITERATIONS = 10
207 for _ in range(ITERATIONS):
208 build_request(options)
209
210 gc.collect()
211 snapshot_after = tracemalloc.take_snapshot()
212
213 tracemalloc.stop()
214
215 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
216 if diff.count == 0:
217 # Avoid false positives by considering only leaks (i.e. allocations that persist).
218 return
219
220 if diff.count % ITERATIONS != 0:
221 # Avoid false positives by considering only leaks that appear per iteration.
222 return
223
224 for frame in diff.traceback:
225 if any(
226 frame.filename.endswith(fragment)
227 for fragment in [
228 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
229 #
230 # removing the decorator fixes the leak for reasons we don't understand.
231 "openai/_legacy_response.py",
232 "openai/_response.py",
233 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
234 "openai/_compat.py",
235 # Standard library leaks we don't care about.
236 "/logging/__init__.py",
237 ]
238 ):
239 return
240
241 leaks.append(diff)
242
243 leaks: list[tracemalloc.StatisticDiff] = []
244 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
245 add_leak(leaks, diff)
246 if leaks:
247 for leak in leaks:
248 print("MEMORY LEAK:", leak)
249 for frame in leak.traceback:
250 print(frame)
251 raise AssertionError()
252
253 def test_request_timeout(self) -> None:
254 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
255 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
256 assert timeout == DEFAULT_TIMEOUT
257
258 request = self.client._build_request(
259 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
260 )
261 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
262 assert timeout == httpx.Timeout(100.0)
263
264 def test_client_timeout_option(self) -> None:
265 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))
266
267 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
268 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
269 assert timeout == httpx.Timeout(0)
270
271 def test_http_client_timeout_option(self) -> None:
272 # custom timeout given to the httpx client should be used
273 with httpx.Client(timeout=None) as http_client:
274 client = OpenAI(
275 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
276 )
277
278 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
279 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
280 assert timeout == httpx.Timeout(None)
281
282 # no timeout given to the httpx client should not use the httpx default
283 with httpx.Client() as http_client:
284 client = OpenAI(
285 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
286 )
287
288 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
289 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
290 assert timeout == DEFAULT_TIMEOUT
291
292 # explicitly passing the default timeout currently results in it being ignored
293 with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
294 client = OpenAI(
295 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
296 )
297
298 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
299 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
300 assert timeout == DEFAULT_TIMEOUT # our default
301
302 async def test_invalid_http_client(self) -> None:
303 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
304 async with httpx.AsyncClient() as http_client:
305 OpenAI(
306 base_url=base_url,
307 api_key=api_key,
308 _strict_response_validation=True,
309 http_client=cast(Any, http_client),
310 )
311
312 def test_default_headers_option(self) -> None:
313 client = OpenAI(
314 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
315 )
316 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
317 assert request.headers.get("x-foo") == "bar"
318 assert request.headers.get("x-stainless-lang") == "python"
319
320 client2 = OpenAI(
321 base_url=base_url,
322 api_key=api_key,
323 _strict_response_validation=True,
324 default_headers={
325 "X-Foo": "stainless",
326 "X-Stainless-Lang": "my-overriding-header",
327 },
328 )
329 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
330 assert request.headers.get("x-foo") == "stainless"
331 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
332
333 def test_validate_headers(self) -> None:
334 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
335 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
336 assert request.headers.get("Authorization") == f"Bearer {api_key}"
337
338 with pytest.raises(OpenAIError):
339 with update_env(**{"OPENAI_API_KEY": Omit()}):
340 client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
341 _ = client2
342
343 def test_default_query_option(self) -> None:
344 client = OpenAI(
345 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
346 )
347 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
348 url = httpx.URL(request.url)
349 assert dict(url.params) == {"query_param": "bar"}
350
351 request = client._build_request(
352 FinalRequestOptions(
353 method="get",
354 url="/foo",
355 params={"foo": "baz", "query_param": "overridden"},
356 )
357 )
358 url = httpx.URL(request.url)
359 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
360
361 def test_request_extra_json(self) -> None:
362 request = self.client._build_request(
363 FinalRequestOptions(
364 method="post",
365 url="/foo",
366 json_data={"foo": "bar"},
367 extra_json={"baz": False},
368 ),
369 )
370 data = json.loads(request.content.decode("utf-8"))
371 assert data == {"foo": "bar", "baz": False}
372
373 request = self.client._build_request(
374 FinalRequestOptions(
375 method="post",
376 url="/foo",
377 extra_json={"baz": False},
378 ),
379 )
380 data = json.loads(request.content.decode("utf-8"))
381 assert data == {"baz": False}
382
383 # `extra_json` takes priority over `json_data` when keys clash
384 request = self.client._build_request(
385 FinalRequestOptions(
386 method="post",
387 url="/foo",
388 json_data={"foo": "bar", "baz": True},
389 extra_json={"baz": None},
390 ),
391 )
392 data = json.loads(request.content.decode("utf-8"))
393 assert data == {"foo": "bar", "baz": None}
394
395 def test_request_extra_headers(self) -> None:
396 request = self.client._build_request(
397 FinalRequestOptions(
398 method="post",
399 url="/foo",
400 **make_request_options(extra_headers={"X-Foo": "Foo"}),
401 ),
402 )
403 assert request.headers.get("X-Foo") == "Foo"
404
405 # `extra_headers` takes priority over `default_headers` when keys clash
406 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
407 FinalRequestOptions(
408 method="post",
409 url="/foo",
410 **make_request_options(
411 extra_headers={"X-Bar": "false"},
412 ),
413 ),
414 )
415 assert request.headers.get("X-Bar") == "false"
416
417 def test_request_extra_query(self) -> None:
418 request = self.client._build_request(
419 FinalRequestOptions(
420 method="post",
421 url="/foo",
422 **make_request_options(
423 extra_query={"my_query_param": "Foo"},
424 ),
425 ),
426 )
427 params = dict(request.url.params)
428 assert params == {"my_query_param": "Foo"}
429
430 # if both `query` and `extra_query` are given, they are merged
431 request = self.client._build_request(
432 FinalRequestOptions(
433 method="post",
434 url="/foo",
435 **make_request_options(
436 query={"bar": "1"},
437 extra_query={"foo": "2"},
438 ),
439 ),
440 )
441 params = dict(request.url.params)
442 assert params == {"bar": "1", "foo": "2"}
443
444 # `extra_query` takes priority over `query` when keys clash
445 request = self.client._build_request(
446 FinalRequestOptions(
447 method="post",
448 url="/foo",
449 **make_request_options(
450 query={"foo": "1"},
451 extra_query={"foo": "2"},
452 ),
453 ),
454 )
455 params = dict(request.url.params)
456 assert params == {"foo": "2"}
457
    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """A list in `json_data` is encoded as repeated `array[]` form fields in a multipart body.

        The expected value below is the complete body split on CRLF: one part per
        array item, then the file part, then the closing boundary.
        """
        request = client._build_request(
            # `construct` is used rather than the regular constructor — presumably to skip validation; TODO confirm
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                # fixed boundary so the encoded body can be compared exactly
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]
486
487 @pytest.mark.respx(base_url=base_url)
488 def test_basic_union_response(self, respx_mock: MockRouter) -> None:
489 class Model1(BaseModel):
490 name: str
491
492 class Model2(BaseModel):
493 foo: str
494
495 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
496
497 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
498 assert isinstance(response, Model2)
499 assert response.foo == "bar"
500
501 @pytest.mark.respx(base_url=base_url)
502 def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
503 """Union of objects with the same field name using a different type"""
504
505 class Model1(BaseModel):
506 foo: int
507
508 class Model2(BaseModel):
509 foo: str
510
511 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
512
513 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
514 assert isinstance(response, Model2)
515 assert response.foo == "bar"
516
517 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
518
519 response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
520 assert isinstance(response, Model1)
521 assert response.foo == 1
522
523 @pytest.mark.respx(base_url=base_url)
524 def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
525 """
526 Response that sets Content-Type to something other than application/json but returns json data
527 """
528
529 class Model(BaseModel):
530 foo: int
531
532 respx_mock.get("/foo").mock(
533 return_value=httpx.Response(
534 200,
535 content=json.dumps({"foo": 2}),
536 headers={"Content-Type": "application/text"},
537 )
538 )
539
540 response = self.client.get("/foo", cast_to=Model)
541 assert isinstance(response, Model)
542 assert response.foo == 2
543
544 def test_base_url_setter(self) -> None:
545 client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
546 assert client.base_url == "https://example.com/from_init/"
547
548 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
549
550 assert client.base_url == "https://example.com/from_setter/"
551
552 def test_base_url_env(self) -> None:
553 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
554 client = OpenAI(api_key=api_key, _strict_response_validation=True)
555 assert client.base_url == "http://localhost:5000/from/env/"
556
557 @pytest.mark.parametrize(
558 "client",
559 [
560 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
561 OpenAI(
562 base_url="http://localhost:5000/custom/path/",
563 api_key=api_key,
564 _strict_response_validation=True,
565 http_client=httpx.Client(),
566 ),
567 ],
568 ids=["standard", "custom http client"],
569 )
570 def test_base_url_trailing_slash(self, client: OpenAI) -> None:
571 request = client._build_request(
572 FinalRequestOptions(
573 method="post",
574 url="/foo",
575 json_data={"foo": "bar"},
576 ),
577 )
578 assert request.url == "http://localhost:5000/custom/path/foo"
579
580 @pytest.mark.parametrize(
581 "client",
582 [
583 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
584 OpenAI(
585 base_url="http://localhost:5000/custom/path/",
586 api_key=api_key,
587 _strict_response_validation=True,
588 http_client=httpx.Client(),
589 ),
590 ],
591 ids=["standard", "custom http client"],
592 )
593 def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
594 request = client._build_request(
595 FinalRequestOptions(
596 method="post",
597 url="/foo",
598 json_data={"foo": "bar"},
599 ),
600 )
601 assert request.url == "http://localhost:5000/custom/path/foo"
602
603 @pytest.mark.parametrize(
604 "client",
605 [
606 OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
607 OpenAI(
608 base_url="http://localhost:5000/custom/path/",
609 api_key=api_key,
610 _strict_response_validation=True,
611 http_client=httpx.Client(),
612 ),
613 ],
614 ids=["standard", "custom http client"],
615 )
616 def test_absolute_request_url(self, client: OpenAI) -> None:
617 request = client._build_request(
618 FinalRequestOptions(
619 method="post",
620 url="https://myapi.com/foo",
621 json_data={"foo": "bar"},
622 ),
623 )
624 assert request.url == "https://myapi.com/foo"
625
626 def test_copied_client_does_not_close_http(self) -> None:
627 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
628 assert not client.is_closed()
629
630 copied = client.copy()
631 assert copied is not client
632
633 del copied
634
635 assert not client.is_closed()
636
637 def test_client_context_manager(self) -> None:
638 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
639 with client as c2:
640 assert c2 is client
641 assert not c2.is_closed()
642 assert not client.is_closed()
643 assert client.is_closed()
644
645 @pytest.mark.respx(base_url=base_url)
646 def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
647 class Model(BaseModel):
648 foo: str
649
650 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
651
652 with pytest.raises(APIResponseValidationError) as exc:
653 self.client.get("/foo", cast_to=Model)
654
655 assert isinstance(exc.value.__cause__, ValidationError)
656
657 def test_client_max_retries_validation(self) -> None:
658 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
659 OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))
660
661 @pytest.mark.respx(base_url=base_url)
662 def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
663 class Model(BaseModel):
664 name: str
665
666 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
667
668 stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
669 assert isinstance(stream, Stream)
670 stream.response.close()
671
672 @pytest.mark.respx(base_url=base_url)
673 def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
674 class Model(BaseModel):
675 name: str
676
677 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
678
679 strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
680
681 with pytest.raises(APIResponseValidationError):
682 strict_client.get("/foo", cast_to=Model)
683
684 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
685
686 response = client.get("/foo", cast_to=Model)
687 assert isinstance(response, str) # type: ignore[unreachable]
688
689 @pytest.mark.parametrize(
690 "remaining_retries,retry_after,timeout",
691 [
692 [3, "20", 20],
693 [3, "0", 0.5],
694 [3, "-10", 0.5],
695 [3, "60", 60],
696 [3, "61", 0.5],
697 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
698 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
699 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
700 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
701 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
702 [3, "99999999999999999999999999999999999", 0.5],
703 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
704 [3, "", 0.5],
705 [2, "", 0.5 * 2.0],
706 [1, "", 0.5 * 4.0],
707 [-1100, "", 8], # test large number potentially overflowing
708 ],
709 )
710 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
711 def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
712 client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
713
714 headers = httpx.Headers({"retry-after": retry_after})
715 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
716 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
717 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
718
719 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
720 @pytest.mark.respx(base_url=base_url)
721 def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
722 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
723
724 with pytest.raises(APITimeoutError):
725 self.client.post(
726 "/chat/completions",
727 body=cast(
728 object,
729 maybe_transform(
730 dict(
731 messages=[
732 {
733 "role": "user",
734 "content": "Say this is a test",
735 }
736 ],
737 model="gpt-4o",
738 ),
739 CompletionCreateParamsNonStreaming,
740 ),
741 ),
742 cast_to=httpx.Response,
743 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
744 )
745
746 assert _get_open_connections(self.client) == 0
747
748 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
749 @pytest.mark.respx(base_url=base_url)
750 def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
751 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
752
753 with pytest.raises(APIStatusError):
754 self.client.post(
755 "/chat/completions",
756 body=cast(
757 object,
758 maybe_transform(
759 dict(
760 messages=[
761 {
762 "role": "user",
763 "content": "Say this is a test",
764 }
765 ],
766 model="gpt-4o",
767 ),
768 CompletionCreateParamsNonStreaming,
769 ),
770 ),
771 cast_to=httpx.Response,
772 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
773 )
774
775 assert _get_open_connections(self.client) == 0
776
777 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
778 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
779 @pytest.mark.respx(base_url=base_url)
780 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
781 def test_retries_taken(
782 self,
783 client: OpenAI,
784 failures_before_success: int,
785 failure_mode: Literal["status", "exception"],
786 respx_mock: MockRouter,
787 ) -> None:
788 client = client.with_options(max_retries=4)
789
790 nb_retries = 0
791
792 def retry_handler(_request: httpx.Request) -> httpx.Response:
793 nonlocal nb_retries
794 if nb_retries < failures_before_success:
795 nb_retries += 1
796 if failure_mode == "exception":
797 raise RuntimeError("oops")
798 return httpx.Response(500)
799 return httpx.Response(200)
800
801 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
802
803 response = client.chat.completions.with_raw_response.create(
804 messages=[
805 {
806 "content": "string",
807 "role": "developer",
808 }
809 ],
810 model="gpt-4o",
811 )
812
813 assert response.retries_taken == failures_before_success
814 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
815
816 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
817 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
818 @pytest.mark.respx(base_url=base_url)
819 def test_omit_retry_count_header(
820 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
821 ) -> None:
822 client = client.with_options(max_retries=4)
823
824 nb_retries = 0
825
826 def retry_handler(_request: httpx.Request) -> httpx.Response:
827 nonlocal nb_retries
828 if nb_retries < failures_before_success:
829 nb_retries += 1
830 return httpx.Response(500)
831 return httpx.Response(200)
832
833 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
834
835 response = client.chat.completions.with_raw_response.create(
836 messages=[
837 {
838 "content": "string",
839 "role": "developer",
840 }
841 ],
842 model="gpt-4o",
843 extra_headers={"x-stainless-retry-count": Omit()},
844 )
845
846 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
847
848 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
849 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
850 @pytest.mark.respx(base_url=base_url)
851 def test_overwrite_retry_count_header(
852 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
853 ) -> None:
854 client = client.with_options(max_retries=4)
855
856 nb_retries = 0
857
858 def retry_handler(_request: httpx.Request) -> httpx.Response:
859 nonlocal nb_retries
860 if nb_retries < failures_before_success:
861 nb_retries += 1
862 return httpx.Response(500)
863 return httpx.Response(200)
864
865 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
866
867 response = client.chat.completions.with_raw_response.create(
868 messages=[
869 {
870 "content": "string",
871 "role": "developer",
872 }
873 ],
874 model="gpt-4o",
875 extra_headers={"x-stainless-retry-count": "42"},
876 )
877
878 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
879
880 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
881 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
882 @pytest.mark.respx(base_url=base_url)
883 def test_retries_taken_new_response_class(
884 self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
885 ) -> None:
886 client = client.with_options(max_retries=4)
887
888 nb_retries = 0
889
890 def retry_handler(_request: httpx.Request) -> httpx.Response:
891 nonlocal nb_retries
892 if nb_retries < failures_before_success:
893 nb_retries += 1
894 return httpx.Response(500)
895 return httpx.Response(200)
896
897 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
898
899 with client.chat.completions.with_streaming_response.create(
900 messages=[
901 {
902 "content": "string",
903 "role": "developer",
904 }
905 ],
906 model="gpt-4o",
907 ) as response:
908 assert response.retries_taken == failures_before_success
909 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
910
911 @pytest.mark.respx(base_url=base_url)
912 def test_follow_redirects(self, respx_mock: MockRouter) -> None:
913 # Test that the default follow_redirects=True allows following redirects
914 respx_mock.post("/redirect").mock(
915 return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
916 )
917 respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
918
919 response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
920 assert response.status_code == 200
921 assert response.json() == {"status": "ok"}
922
@pytest.mark.respx(base_url=base_url)
def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
    """With ``follow_redirects=False`` the 302 surfaces as an ``APIStatusError``."""
    redirect_reply = httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
    respx_mock.post("/redirect").mock(return_value=redirect_reply)

    request_options = {"follow_redirects": False}
    with pytest.raises(APIStatusError) as exc_info:
        self.client.post("/redirect", body={"key": "value"}, options=request_options, cast_to=httpx.Response)

    # The error carries the untouched redirect response.
    raised_response = exc_info.value.response
    assert raised_response.status_code == 302
    assert raised_response.headers["Location"] == f"{base_url}/redirected"
937
938
class TestAsyncOpenAI:
    """Tests for the asynchronous ``AsyncOpenAI`` client.

    Covers client copying, request construction (headers, query params, JSON,
    multipart), response parsing, base URL handling, retry bookkeeping and
    redirect behaviour. Most tests share the class-level ``client``; tests
    that need different constructor options build their own instance.
    """

    # Shared strict-validation client used by tests that need no special options.
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` returns the raw httpx response object."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A JSON body served as ``application/binary`` is still returned as a raw response."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` yields a distinct client and applies overrides without touching the original."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """``max_retries`` and ``timeout`` overrides apply to the copy only."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merges into a copy while ``set_default_headers`` replaces outright."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy while ``set_default_query`` replaces outright."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every ``__init__`` parameter (bar a small exclusion set) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters:
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying the client and building a request repeatedly must not leak memory.

        Allocations are compared between two ``tracemalloc`` snapshots taken
        around ``ITERATIONS`` copy-and-build cycles. Only differences whose
        count is a non-zero multiple of ``ITERATIONS`` and whose traceback
        avoids a short allow-list of known-leaky files are reported.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if a build fails, so later tests do not run
            # with tracemalloc still enabled.
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(
                f"{len(leaks)} allocation site(s) grew on every copy()/_build_request() iteration; "
                "see the MEMORY LEAK output above"
            )

    async def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the request options carry their own timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    async def test_client_timeout_option(self) -> None:
        """A timeout passed to the client constructor is applied to built requests."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
        )

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    async def test_http_client_timeout_option(self) -> None:
        """Timeout resolution when a custom ``httpx.AsyncClient`` is supplied."""
        # custom timeout given to the httpx client should be used
        async with httpx.AsyncClient(timeout=None) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        async with httpx.AsyncClient() as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_invalid_http_client(self) -> None:
        """Passing a synchronous ``httpx.Client`` to the async client raises ``TypeError``."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            with httpx.Client() as http_client:
                AsyncOpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent on requests and may override built-in headers."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key becomes a Bearer ``Authorization`` header; a missing key raises ``OpenAIError``."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` params are sent, and per-request params override them."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the request body and wins over ``json_data`` on clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and win over ``default_headers`` on clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is merged with ``query`` and wins on clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        """List values in a multipart body are emitted as repeated ``name[]`` form fields."""
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = await self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning ``base_url`` updates the client and a trailing slash is appended."""
        client = AsyncOpenAI(
            base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
        )
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
        """A base URL ending in ``/`` joins with a relative path without doubling the slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # NOTE: these base URLs deliberately omit the trailing slash; that
            # is the case this test exists to cover.
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
        """A base URL without a trailing ``/`` still joins correctly with a relative path."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to the base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    async def test_copied_client_does_not_close_http(self) -> None:
        """Dropping a copy must not close the original client."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        # Yield to the event loop before re-checking the original client.
        await asyncio.sleep(0.2)
        assert not client.is_closed()

    async def test_client_context_manager(self) -> None:
        """``async with client`` yields the same client and closes it on exit."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        async with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """Under strict validation a mistyped payload raises ``APIResponseValidationError``."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            await self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    async def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected with ``TypeError``."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
            )

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with ``stream_cls=AsyncStream[Model]`` returns an ``AsyncStream``."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
        assert isinstance(stream, AsyncStream)
        await stream.response.aclose()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """A plain-text body fails strict validation but is returned as ``str`` when non-strict."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = await client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    @pytest.mark.asyncio
    async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``_calculate_retry_timeout`` result for a range of ``retry-after`` header values.

        ``time.time`` is patched to a fixed value so the HTTP-date cases are
        deterministic; results are compared with a relative tolerance.
        """
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Retried timeouts end in ``APITimeoutError`` and leave no open connections."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    maybe_transform(
                        dict(
                            messages=[
                                {
                                    "role": "user",
                                    "content": "Say this is a test",
                                }
                            ],
                            model="gpt-4o",
                        ),
                        CompletionCreateParamsNonStreaming,
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Retried HTTP 500s end in ``APIStatusError`` and leave no open connections."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    maybe_transform(
                        dict(
                            messages=[
                                {
                                    "role": "user",
                                    "content": "Say this is a test",
                                }
                            ],
                            model="gpt-4o",
                        ),
                        CompletionCreateParamsNonStreaming,
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    async def test_retries_taken(
        self,
        async_client: AsyncOpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """``retries_taken`` and the retry-count header equal the number of failed attempts.

        Failures are simulated either as HTTP 500 responses or as exceptions
        raised by the mocked transport, depending on ``failure_mode``.
        """
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_omit_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for ``x-stainless-retry-count`` removes the header from the request."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_overwrite_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """A caller-supplied ``x-stainless-retry-count`` value is sent unchanged."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken_new_response_class(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Streaming responses expose the same retry bookkeeping as raw responses."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        async with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    def test_get_platform(self) -> None:
        """``asyncify(get_platform)`` under ``nest_asyncio`` exits cleanly within ten seconds."""
        # A previous implementation of asyncify could leave threads unterminated when
        # used with nest_asyncio.
        #
        # Since nest_asyncio.apply() is global and cannot be un-applied, this
        # test is run in a separate process to avoid affecting other tests.
        test_code = dedent("""
        import asyncio
        import nest_asyncio
        import threading

        from openai._utils import asyncify
        from openai._base_client import get_platform

        async def test_main() -> None:
            result = await asyncify(get_platform)()
            print(result)
            for thread in threading.enumerate():
                print(thread.name)

        nest_asyncio.apply()
        asyncio.run(test_main())
        """)
        with subprocess.Popen(
            [sys.executable, "-c", test_code],
            text=True,
        ) as process:
            timeout = 10  # seconds

            start_time = time.monotonic()
            # Poll rather than block so a hung child can be killed and reported.
            while True:
                return_code = process.poll()
                if return_code is not None:
                    if return_code != 0:
                        raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")

                    # success
                    break

                if time.monotonic() - start_time > timeout:
                    process.kill()
                    raise AssertionError("calling get_platform using asyncify resulted in a hung process")

                time.sleep(0.1)

    @pytest.mark.respx(base_url=base_url)
    async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
        """By default a 302 from POST /redirect is followed to GET /redirected."""
        # Test that the default follow_redirects=True allows following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )
        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

        response = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    @pytest.mark.respx(base_url=base_url)
    async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
        """With ``follow_redirects=False`` the 302 surfaces as an ``APIStatusError``."""
        # Test that follow_redirects=False prevents following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )

        with pytest.raises(APIStatusError) as exc_info:
            await self.client.post(
                "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
            )

        assert exc_info.value.response.status_code == 302
        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
1886