openai/openai-python

Public

mirrored from https://github.com/openai/openai-pythonAvailable

CodeCommitsIssuesPull requestsActionsInsightsSecurity
8ade764fc124bee145990ad59d0d7c4bbe27a754

Branches

Tags

  • No tags available.
0Branches0Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1936lines · modecode

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import time
10import asyncio
11import inspect
12import subprocess
13import tracemalloc
14from typing import Any, Union, cast
15from textwrap import dedent
16from unittest import mock
17from typing_extensions import Literal
18
19import httpx
20import pytest
21from respx import MockRouter
22from pydantic import ValidationError
23
24from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
25from openai._types import Omit
26from openai._utils import maybe_transform
27from openai._models import BaseModel, FinalRequestOptions
28from openai._constants import RAW_RESPONSE_HEADER
29from openai._streaming import Stream, AsyncStream
30from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
31from openai._base_client import (
32 DEFAULT_TIMEOUT,
33 HTTPX_DEFAULT_TIMEOUT,
34 BaseClient,
35 DefaultHttpxClient,
36 DefaultAsyncHttpxClient,
37 make_request_options,
38)
39from openai.types.chat.completion_create_params import CompletionCreateParamsNonStreaming
40
41from .utils import update_env
42
# Base URL every test client targets; can be redirected with the TEST_API_BASE_URL environment variable.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder credential handed to every client constructed in this module.
api_key = "My API Key"
45
46
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query parameters.

    Useful for checking which default query params a client (or a copy of it) injects.
    """
    options = FinalRequestOptions(method="get", url="/foo")
    built = client._build_request(options)
    query = httpx.URL(built.url).params
    return dict(query)
51
52
53def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
54 return 0.1
55
56
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of requests tracked by the connection pool behind *client*.

    Only the default httpx transports (sync or async) are supported; any other
    transport type fails the assertion. Reads private httpx/httpcore attributes,
    so this is a test-only helper.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
63
64
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    Most tests share the class-level ``client`` (strict response validation on) and
    mock HTTP traffic with ``respx`` mounted at ``base_url``. Tests that take a
    ``client`` argument receive it from a fixture or ``parametrize``.
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` returns the raw response object."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A raw response with a binary content type can still be decoded as JSON by the caller."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a new client and overrides leave the original untouched."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given query params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory per iteration."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        tracemalloc.start(1000)

        snapshot_before = tracemalloc.take_snapshot()

        ITERATIONS = 10
        for _ in range(ITERATIONS):
            build_request(options)

        gc.collect()
        snapshot_after = tracemalloc.take_snapshot()

        tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError()

    def test_request_timeout(self) -> None:
        """Requests use DEFAULT_TIMEOUT unless the request options supply their own timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    async def test_invalid_http_client(self) -> None:
        """The sync client rejects an async httpx client."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """User-supplied default headers are sent and may override the SDK's own headers."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token, and a missing key raises OpenAIError."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """Default query params are sent, and per-request params override them."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """Array fields in multipart bodies are encoded as repeated ``name[]`` parts."""
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Both the constructor and the ``base_url`` setter append a trailing slash."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base_url ending in "/" joins cleanly with a leading-slash request path."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # Deliberately no trailing slash: this is the case under test (previously the
            # parameters duplicated the trailing-slash test above).
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base_url without a trailing "/" keeps its last path segment when joined with a request path."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL bypasses the client's base_url entirely."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Dropping a copy must not close the HTTP client shared with the original."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A body that fails model validation raises APIResponseValidationError chained from pydantic."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Non-JSON text raises under strict validation and is returned as ``str`` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """Retry-After values (seconds or HTTP dates, relative to a frozen clock) drive the retry delay.

        Out-of-range or unparseable values fall back to the computed backoff.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on timeouts leaves no requests open in the connection pool."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    maybe_transform(
                        dict(
                            messages=[
                                {
                                    "role": "user",
                                    "content": "Say this is a test",
                                }
                            ],
                            model="gpt-4o",
                        ),
                        CompletionCreateParamsNonStreaming,
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on 500 responses leaves no requests open in the connection pool."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    maybe_transform(
                        dict(
                            messages=[
                                {
                                    "role": "user",
                                    "content": "Say this is a test",
                                }
                            ],
                            model="gpt-4o",
                        ),
                        CompletionCreateParamsNonStreaming,
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """``retries_taken`` and the retry-count header reflect how many attempts failed first."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for the retry-count header suppresses it entirely."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """An explicit retry-count header value is sent as-is instead of the computed count."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Same as ``test_retries_taken`` but via the streaming-response wrapper."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
        # Test that the proxy environment variables are set correctly
        monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

        # Use the client as a context manager so its transports are closed after inspection.
        with DefaultHttpxClient() as client:
            mounts = tuple(client._mounts.items())
            assert len(mounts) == 1
            assert mounts[0][0].pattern == "https://"

    @pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
    def test_default_client_creation(self) -> None:
        # Ensure that the client can be initialized without any exceptions
        client = DefaultHttpxClient(
            verify=True,
            cert=None,
            trust_env=True,
            http1=True,
            http2=False,
            limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
        )
        # Release the transport instead of leaving it to garbage collection.
        client.close()

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects(self, respx_mock: MockRouter) -> None:
        # Test that the default follow_redirects=True allows following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )
        respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))

        response = self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)
        assert response.status_code == 200
        assert response.json() == {"status": "ok"}

    @pytest.mark.respx(base_url=base_url)
    def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
        # Test that follow_redirects=False prevents following redirects
        respx_mock.post("/redirect").mock(
            return_value=httpx.Response(302, headers={"Location": f"{base_url}/redirected"})
        )

        with pytest.raises(APIStatusError) as exc_info:
            self.client.post(
                "/redirect", body={"key": "value"}, options={"follow_redirects": False}, cast_to=httpx.Response
            )

        assert exc_info.value.response.status_code == 302
        assert exc_info.value.response.headers["Location"] == f"{base_url}/redirected"
966
967
968class TestAsyncOpenAI:
969 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
970
971 @pytest.mark.respx(base_url=base_url)
972 @pytest.mark.asyncio
973 async def test_raw_response(self, respx_mock: MockRouter) -> None:
974 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
975
976 response = await self.client.post("/foo", cast_to=httpx.Response)
977 assert response.status_code == 200
978 assert isinstance(response, httpx.Response)
979 assert response.json() == {"foo": "bar"}
980
981 @pytest.mark.respx(base_url=base_url)
982 @pytest.mark.asyncio
983 async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
984 respx_mock.post("/foo").mock(
985 return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
986 )
987
988 response = await self.client.post("/foo", cast_to=httpx.Response)
989 assert response.status_code == 200
990 assert isinstance(response, httpx.Response)
991 assert response.json() == {"foo": "bar"}
992
993 def test_copy(self) -> None:
994 copied = self.client.copy()
995 assert id(copied) != id(self.client)
996
997 copied = self.client.copy(api_key="another My API Key")
998 assert copied.api_key == "another My API Key"
999 assert self.client.api_key == "My API Key"
1000
1001 def test_copy_default_options(self) -> None:
1002 # options that have a default are overridden correctly
1003 copied = self.client.copy(max_retries=7)
1004 assert copied.max_retries == 7
1005 assert self.client.max_retries == 2
1006
1007 copied2 = copied.copy(max_retries=6)
1008 assert copied2.max_retries == 6
1009 assert copied.max_retries == 7
1010
1011 # timeout
1012 assert isinstance(self.client.timeout, httpx.Timeout)
1013 copied = self.client.copy(timeout=None)
1014 assert copied.timeout is None
1015 assert isinstance(self.client.timeout, httpx.Timeout)
1016
1017 def test_copy_default_headers(self) -> None:
1018 client = AsyncOpenAI(
1019 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1020 )
1021 assert client.default_headers["X-Foo"] == "bar"
1022
1023 # does not override the already given value when not specified
1024 copied = client.copy()
1025 assert copied.default_headers["X-Foo"] == "bar"
1026
1027 # merges already given headers
1028 copied = client.copy(default_headers={"X-Bar": "stainless"})
1029 assert copied.default_headers["X-Foo"] == "bar"
1030 assert copied.default_headers["X-Bar"] == "stainless"
1031
1032 # uses new values for any already given headers
1033 copied = client.copy(default_headers={"X-Foo": "stainless"})
1034 assert copied.default_headers["X-Foo"] == "stainless"
1035
1036 # set_default_headers
1037
1038 # completely overrides already set values
1039 copied = client.copy(set_default_headers={})
1040 assert copied.default_headers.get("X-Foo") is None
1041
1042 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
1043 assert copied.default_headers["X-Bar"] == "Robert"
1044
1045 with pytest.raises(
1046 ValueError,
1047 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
1048 ):
1049 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
1050
1051 def test_copy_default_query(self) -> None:
1052 client = AsyncOpenAI(
1053 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
1054 )
1055 assert _get_params(client)["foo"] == "bar"
1056
1057 # does not override the already given value when not specified
1058 copied = client.copy()
1059 assert _get_params(copied)["foo"] == "bar"
1060
1061 # merges already given params
1062 copied = client.copy(default_query={"bar": "stainless"})
1063 params = _get_params(copied)
1064 assert params["foo"] == "bar"
1065 assert params["bar"] == "stainless"
1066
1067 # uses new values for any already given headers
1068 copied = client.copy(default_query={"foo": "stainless"})
1069 assert _get_params(copied)["foo"] == "stainless"
1070
1071 # set_default_query
1072
1073 # completely overrides already set values
1074 copied = client.copy(set_default_query={})
1075 assert _get_params(copied) == {}
1076
1077 copied = client.copy(set_default_query={"bar": "Robert"})
1078 assert _get_params(copied)["bar"] == "Robert"
1079
1080 with pytest.raises(
1081 ValueError,
1082 # TODO: update
1083 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1084 ):
1085 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1086
1087 def test_copy_signature(self) -> None:
1088 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1089 init_signature = inspect.signature(
1090 # mypy doesn't like that we access the `__init__` property.
1091 self.client.__init__, # type: ignore[misc]
1092 )
1093 copy_signature = inspect.signature(self.client.copy)
1094 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1095
1096 for name in init_signature.parameters.keys():
1097 if name in exclude_params:
1098 continue
1099
1100 copy_param = copy_signature.parameters.get(name)
1101 assert copy_param is not None, f"copy() signature is missing the {name} param"
1102
1103 def test_copy_build_request(self) -> None:
1104 options = FinalRequestOptions(method="get", url="/foo")
1105
1106 def build_request(options: FinalRequestOptions) -> None:
1107 client = self.client.copy()
1108 client._build_request(options)
1109
1110 # ensure that the machinery is warmed up before tracing starts.
1111 build_request(options)
1112 gc.collect()
1113
1114 tracemalloc.start(1000)
1115
1116 snapshot_before = tracemalloc.take_snapshot()
1117
1118 ITERATIONS = 10
1119 for _ in range(ITERATIONS):
1120 build_request(options)
1121
1122 gc.collect()
1123 snapshot_after = tracemalloc.take_snapshot()
1124
1125 tracemalloc.stop()
1126
1127 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
1128 if diff.count == 0:
1129 # Avoid false positives by considering only leaks (i.e. allocations that persist).
1130 return
1131
1132 if diff.count % ITERATIONS != 0:
1133 # Avoid false positives by considering only leaks that appear per iteration.
1134 return
1135
1136 for frame in diff.traceback:
1137 if any(
1138 frame.filename.endswith(fragment)
1139 for fragment in [
1140 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
1141 #
1142 # removing the decorator fixes the leak for reasons we don't understand.
1143 "openai/_legacy_response.py",
1144 "openai/_response.py",
1145 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
1146 "openai/_compat.py",
1147 # Standard library leaks we don't care about.
1148 "/logging/__init__.py",
1149 ]
1150 ):
1151 return
1152
1153 leaks.append(diff)
1154
1155 leaks: list[tracemalloc.StatisticDiff] = []
1156 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
1157 add_leak(leaks, diff)
1158 if leaks:
1159 for leak in leaks:
1160 print("MEMORY LEAK:", leak)
1161 for frame in leak.traceback:
1162 print(frame)
1163 raise AssertionError()
1164
1165 async def test_request_timeout(self) -> None:
1166 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
1167 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1168 assert timeout == DEFAULT_TIMEOUT
1169
1170 request = self.client._build_request(
1171 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1172 )
1173 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1174 assert timeout == httpx.Timeout(100.0)
1175
1176 async def test_client_timeout_option(self) -> None:
1177 client = AsyncOpenAI(
1178 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
1179 )
1180
1181 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1182 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1183 assert timeout == httpx.Timeout(0)
1184
1185 async def test_http_client_timeout_option(self) -> None:
1186 # custom timeout given to the httpx client should be used
1187 async with httpx.AsyncClient(timeout=None) as http_client:
1188 client = AsyncOpenAI(
1189 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1190 )
1191
1192 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1193 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1194 assert timeout == httpx.Timeout(None)
1195
1196 # no timeout given to the httpx client should not use the httpx default
1197 async with httpx.AsyncClient() as http_client:
1198 client = AsyncOpenAI(
1199 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1200 )
1201
1202 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1203 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1204 assert timeout == DEFAULT_TIMEOUT
1205
1206 # explicitly passing the default timeout currently results in it being ignored
1207 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1208 client = AsyncOpenAI(
1209 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1210 )
1211
1212 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1213 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1214 assert timeout == DEFAULT_TIMEOUT # our default
1215
1216 def test_invalid_http_client(self) -> None:
1217 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1218 with httpx.Client() as http_client:
1219 AsyncOpenAI(
1220 base_url=base_url,
1221 api_key=api_key,
1222 _strict_response_validation=True,
1223 http_client=cast(Any, http_client),
1224 )
1225
1226 def test_default_headers_option(self) -> None:
1227 client = AsyncOpenAI(
1228 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1229 )
1230 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1231 assert request.headers.get("x-foo") == "bar"
1232 assert request.headers.get("x-stainless-lang") == "python"
1233
1234 client2 = AsyncOpenAI(
1235 base_url=base_url,
1236 api_key=api_key,
1237 _strict_response_validation=True,
1238 default_headers={
1239 "X-Foo": "stainless",
1240 "X-Stainless-Lang": "my-overriding-header",
1241 },
1242 )
1243 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1244 assert request.headers.get("x-foo") == "stainless"
1245 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1246
1247 def test_validate_headers(self) -> None:
1248 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1249 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1250 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1251
1252 with pytest.raises(OpenAIError):
1253 with update_env(**{"OPENAI_API_KEY": Omit()}):
1254 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1255 _ = client2
1256
1257 def test_default_query_option(self) -> None:
1258 client = AsyncOpenAI(
1259 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1260 )
1261 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1262 url = httpx.URL(request.url)
1263 assert dict(url.params) == {"query_param": "bar"}
1264
1265 request = client._build_request(
1266 FinalRequestOptions(
1267 method="get",
1268 url="/foo",
1269 params={"foo": "baz", "query_param": "overridden"},
1270 )
1271 )
1272 url = httpx.URL(request.url)
1273 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1274
1275 def test_request_extra_json(self) -> None:
1276 request = self.client._build_request(
1277 FinalRequestOptions(
1278 method="post",
1279 url="/foo",
1280 json_data={"foo": "bar"},
1281 extra_json={"baz": False},
1282 ),
1283 )
1284 data = json.loads(request.content.decode("utf-8"))
1285 assert data == {"foo": "bar", "baz": False}
1286
1287 request = self.client._build_request(
1288 FinalRequestOptions(
1289 method="post",
1290 url="/foo",
1291 extra_json={"baz": False},
1292 ),
1293 )
1294 data = json.loads(request.content.decode("utf-8"))
1295 assert data == {"baz": False}
1296
1297 # `extra_json` takes priority over `json_data` when keys clash
1298 request = self.client._build_request(
1299 FinalRequestOptions(
1300 method="post",
1301 url="/foo",
1302 json_data={"foo": "bar", "baz": True},
1303 extra_json={"baz": None},
1304 ),
1305 )
1306 data = json.loads(request.content.decode("utf-8"))
1307 assert data == {"foo": "bar", "baz": None}
1308
1309 def test_request_extra_headers(self) -> None:
1310 request = self.client._build_request(
1311 FinalRequestOptions(
1312 method="post",
1313 url="/foo",
1314 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1315 ),
1316 )
1317 assert request.headers.get("X-Foo") == "Foo"
1318
1319 # `extra_headers` takes priority over `default_headers` when keys clash
1320 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
1321 FinalRequestOptions(
1322 method="post",
1323 url="/foo",
1324 **make_request_options(
1325 extra_headers={"X-Bar": "false"},
1326 ),
1327 ),
1328 )
1329 assert request.headers.get("X-Bar") == "false"
1330
1331 def test_request_extra_query(self) -> None:
1332 request = self.client._build_request(
1333 FinalRequestOptions(
1334 method="post",
1335 url="/foo",
1336 **make_request_options(
1337 extra_query={"my_query_param": "Foo"},
1338 ),
1339 ),
1340 )
1341 params = dict(request.url.params)
1342 assert params == {"my_query_param": "Foo"}
1343
1344 # if both `query` and `extra_query` are given, they are merged
1345 request = self.client._build_request(
1346 FinalRequestOptions(
1347 method="post",
1348 url="/foo",
1349 **make_request_options(
1350 query={"bar": "1"},
1351 extra_query={"foo": "2"},
1352 ),
1353 ),
1354 )
1355 params = dict(request.url.params)
1356 assert params == {"bar": "1", "foo": "2"}
1357
1358 # `extra_query` takes priority over `query` when keys clash
1359 request = self.client._build_request(
1360 FinalRequestOptions(
1361 method="post",
1362 url="/foo",
1363 **make_request_options(
1364 query={"foo": "1"},
1365 extra_query={"foo": "2"},
1366 ),
1367 ),
1368 )
1369 params = dict(request.url.params)
1370 assert params == {"foo": "2"}
1371
1372 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1373 request = async_client._build_request(
1374 FinalRequestOptions.construct(
1375 method="get",
1376 url="/foo",
1377 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1378 json_data={"array": ["foo", "bar"]},
1379 files=[("foo.txt", b"hello world")],
1380 )
1381 )
1382
1383 assert request.read().split(b"\r\n") == [
1384 b"--6b7ba517decee4a450543ea6ae821c82",
1385 b'Content-Disposition: form-data; name="array[]"',
1386 b"",
1387 b"foo",
1388 b"--6b7ba517decee4a450543ea6ae821c82",
1389 b'Content-Disposition: form-data; name="array[]"',
1390 b"",
1391 b"bar",
1392 b"--6b7ba517decee4a450543ea6ae821c82",
1393 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
1394 b"Content-Type: application/octet-stream",
1395 b"",
1396 b"hello world",
1397 b"--6b7ba517decee4a450543ea6ae821c82--",
1398 b"",
1399 ]
1400
1401 @pytest.mark.respx(base_url=base_url)
1402 async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
1403 class Model1(BaseModel):
1404 name: str
1405
1406 class Model2(BaseModel):
1407 foo: str
1408
1409 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1410
1411 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1412 assert isinstance(response, Model2)
1413 assert response.foo == "bar"
1414
1415 @pytest.mark.respx(base_url=base_url)
1416 async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
1417 """Union of objects with the same field name using a different type"""
1418
1419 class Model1(BaseModel):
1420 foo: int
1421
1422 class Model2(BaseModel):
1423 foo: str
1424
1425 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1426
1427 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1428 assert isinstance(response, Model2)
1429 assert response.foo == "bar"
1430
1431 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1432
1433 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1434 assert isinstance(response, Model1)
1435 assert response.foo == 1
1436
1437 @pytest.mark.respx(base_url=base_url)
1438 async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
1439 """
1440 Response that sets Content-Type to something other than application/json but returns json data
1441 """
1442
1443 class Model(BaseModel):
1444 foo: int
1445
1446 respx_mock.get("/foo").mock(
1447 return_value=httpx.Response(
1448 200,
1449 content=json.dumps({"foo": 2}),
1450 headers={"Content-Type": "application/text"},
1451 )
1452 )
1453
1454 response = await self.client.get("/foo", cast_to=Model)
1455 assert isinstance(response, Model)
1456 assert response.foo == 2
1457
1458 def test_base_url_setter(self) -> None:
1459 client = AsyncOpenAI(
1460 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1461 )
1462 assert client.base_url == "https://example.com/from_init/"
1463
1464 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1465
1466 assert client.base_url == "https://example.com/from_setter/"
1467
1468 def test_base_url_env(self) -> None:
1469 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1470 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1471 assert client.base_url == "http://localhost:5000/from/env/"
1472
1473 @pytest.mark.parametrize(
1474 "client",
1475 [
1476 AsyncOpenAI(
1477 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1478 ),
1479 AsyncOpenAI(
1480 base_url="http://localhost:5000/custom/path/",
1481 api_key=api_key,
1482 _strict_response_validation=True,
1483 http_client=httpx.AsyncClient(),
1484 ),
1485 ],
1486 ids=["standard", "custom http client"],
1487 )
1488 def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1489 request = client._build_request(
1490 FinalRequestOptions(
1491 method="post",
1492 url="/foo",
1493 json_data={"foo": "bar"},
1494 ),
1495 )
1496 assert request.url == "http://localhost:5000/custom/path/foo"
1497
1498 @pytest.mark.parametrize(
1499 "client",
1500 [
1501 AsyncOpenAI(
1502 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1503 ),
1504 AsyncOpenAI(
1505 base_url="http://localhost:5000/custom/path/",
1506 api_key=api_key,
1507 _strict_response_validation=True,
1508 http_client=httpx.AsyncClient(),
1509 ),
1510 ],
1511 ids=["standard", "custom http client"],
1512 )
1513 def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1514 request = client._build_request(
1515 FinalRequestOptions(
1516 method="post",
1517 url="/foo",
1518 json_data={"foo": "bar"},
1519 ),
1520 )
1521 assert request.url == "http://localhost:5000/custom/path/foo"
1522
1523 @pytest.mark.parametrize(
1524 "client",
1525 [
1526 AsyncOpenAI(
1527 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1528 ),
1529 AsyncOpenAI(
1530 base_url="http://localhost:5000/custom/path/",
1531 api_key=api_key,
1532 _strict_response_validation=True,
1533 http_client=httpx.AsyncClient(),
1534 ),
1535 ],
1536 ids=["standard", "custom http client"],
1537 )
1538 def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1539 request = client._build_request(
1540 FinalRequestOptions(
1541 method="post",
1542 url="https://myapi.com/foo",
1543 json_data={"foo": "bar"},
1544 ),
1545 )
1546 assert request.url == "https://myapi.com/foo"
1547
1548 async def test_copied_client_does_not_close_http(self) -> None:
1549 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1550 assert not client.is_closed()
1551
1552 copied = client.copy()
1553 assert copied is not client
1554
1555 del copied
1556
1557 await asyncio.sleep(0.2)
1558 assert not client.is_closed()
1559
1560 async def test_client_context_manager(self) -> None:
1561 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1562 async with client as c2:
1563 assert c2 is client
1564 assert not c2.is_closed()
1565 assert not client.is_closed()
1566 assert client.is_closed()
1567
1568 @pytest.mark.respx(base_url=base_url)
1569 @pytest.mark.asyncio
1570 async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
1571 class Model(BaseModel):
1572 foo: str
1573
1574 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1575
1576 with pytest.raises(APIResponseValidationError) as exc:
1577 await self.client.get("/foo", cast_to=Model)
1578
1579 assert isinstance(exc.value.__cause__, ValidationError)
1580
1581 async def test_client_max_retries_validation(self) -> None:
1582 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1583 AsyncOpenAI(
1584 base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
1585 )
1586
1587 @pytest.mark.respx(base_url=base_url)
1588 @pytest.mark.asyncio
1589 async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
1590 class Model(BaseModel):
1591 name: str
1592
1593 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1594
1595 stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1596 assert isinstance(stream, AsyncStream)
1597 await stream.response.aclose()
1598
1599 @pytest.mark.respx(base_url=base_url)
1600 @pytest.mark.asyncio
1601 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1602 class Model(BaseModel):
1603 name: str
1604
1605 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1606
1607 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1608
1609 with pytest.raises(APIResponseValidationError):
1610 await strict_client.get("/foo", cast_to=Model)
1611
1612 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1613
1614 response = await client.get("/foo", cast_to=Model)
1615 assert isinstance(response, str) # type: ignore[unreachable]
1616
1617 @pytest.mark.parametrize(
1618 "remaining_retries,retry_after,timeout",
1619 [
1620 [3, "20", 20],
1621 [3, "0", 0.5],
1622 [3, "-10", 0.5],
1623 [3, "60", 60],
1624 [3, "61", 0.5],
1625 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1626 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1627 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1628 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1629 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1630 [3, "99999999999999999999999999999999999", 0.5],
1631 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1632 [3, "", 0.5],
1633 [2, "", 0.5 * 2.0],
1634 [1, "", 0.5 * 4.0],
1635 [-1100, "", 8], # test large number potentially overflowing
1636 ],
1637 )
1638 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1639 @pytest.mark.asyncio
1640 async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
1641 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1642
1643 headers = httpx.Headers({"retry-after": retry_after})
1644 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1645 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
1646 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1647
1648 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1649 @pytest.mark.respx(base_url=base_url)
1650 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1651 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1652
1653 with pytest.raises(APITimeoutError):
1654 await self.client.post(
1655 "/chat/completions",
1656 body=cast(
1657 object,
1658 maybe_transform(
1659 dict(
1660 messages=[
1661 {
1662 "role": "user",
1663 "content": "Say this is a test",
1664 }
1665 ],
1666 model="gpt-4o",
1667 ),
1668 CompletionCreateParamsNonStreaming,
1669 ),
1670 ),
1671 cast_to=httpx.Response,
1672 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1673 )
1674
1675 assert _get_open_connections(self.client) == 0
1676
1677 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1678 @pytest.mark.respx(base_url=base_url)
1679 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1680 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1681
1682 with pytest.raises(APIStatusError):
1683 await self.client.post(
1684 "/chat/completions",
1685 body=cast(
1686 object,
1687 maybe_transform(
1688 dict(
1689 messages=[
1690 {
1691 "role": "user",
1692 "content": "Say this is a test",
1693 }
1694 ],
1695 model="gpt-4o",
1696 ),
1697 CompletionCreateParamsNonStreaming,
1698 ),
1699 ),
1700 cast_to=httpx.Response,
1701 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1702 )
1703
1704 assert _get_open_connections(self.client) == 0
1705
1706 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1707 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1708 @pytest.mark.respx(base_url=base_url)
1709 @pytest.mark.asyncio
1710 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
1711 async def test_retries_taken(
1712 self,
1713 async_client: AsyncOpenAI,
1714 failures_before_success: int,
1715 failure_mode: Literal["status", "exception"],
1716 respx_mock: MockRouter,
1717 ) -> None:
1718 client = async_client.with_options(max_retries=4)
1719
1720 nb_retries = 0
1721
1722 def retry_handler(_request: httpx.Request) -> httpx.Response:
1723 nonlocal nb_retries
1724 if nb_retries < failures_before_success:
1725 nb_retries += 1
1726 if failure_mode == "exception":
1727 raise RuntimeError("oops")
1728 return httpx.Response(500)
1729 return httpx.Response(200)
1730
1731 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1732
1733 response = await client.chat.completions.with_raw_response.create(
1734 messages=[
1735 {
1736 "content": "string",
1737 "role": "developer",
1738 }
1739 ],
1740 model="gpt-4o",
1741 )
1742
1743 assert response.retries_taken == failures_before_success
1744 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1745
1746 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1747 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1748 @pytest.mark.respx(base_url=base_url)
1749 @pytest.mark.asyncio
1750 async def test_omit_retry_count_header(
1751 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1752 ) -> None:
1753 client = async_client.with_options(max_retries=4)
1754
1755 nb_retries = 0
1756
1757 def retry_handler(_request: httpx.Request) -> httpx.Response:
1758 nonlocal nb_retries
1759 if nb_retries < failures_before_success:
1760 nb_retries += 1
1761 return httpx.Response(500)
1762 return httpx.Response(200)
1763
1764 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1765
1766 response = await client.chat.completions.with_raw_response.create(
1767 messages=[
1768 {
1769 "content": "string",
1770 "role": "developer",
1771 }
1772 ],
1773 model="gpt-4o",
1774 extra_headers={"x-stainless-retry-count": Omit()},
1775 )
1776
1777 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
1778
1779 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1780 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1781 @pytest.mark.respx(base_url=base_url)
1782 @pytest.mark.asyncio
1783 async def test_overwrite_retry_count_header(
1784 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1785 ) -> None:
1786 client = async_client.with_options(max_retries=4)
1787
1788 nb_retries = 0
1789
1790 def retry_handler(_request: httpx.Request) -> httpx.Response:
1791 nonlocal nb_retries
1792 if nb_retries < failures_before_success:
1793 nb_retries += 1
1794 return httpx.Response(500)
1795 return httpx.Response(200)
1796
1797 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1798
1799 response = await client.chat.completions.with_raw_response.create(
1800 messages=[
1801 {
1802 "content": "string",
1803 "role": "developer",
1804 }
1805 ],
1806 model="gpt-4o",
1807 extra_headers={"x-stainless-retry-count": "42"},
1808 )
1809
1810 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1811
1812 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1813 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1814 @pytest.mark.respx(base_url=base_url)
1815 @pytest.mark.asyncio
1816 async def test_retries_taken_new_response_class(
1817 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1818 ) -> None:
1819 client = async_client.with_options(max_retries=4)
1820
1821 nb_retries = 0
1822
1823 def retry_handler(_request: httpx.Request) -> httpx.Response:
1824 nonlocal nb_retries
1825 if nb_retries < failures_before_success:
1826 nb_retries += 1
1827 return httpx.Response(500)
1828 return httpx.Response(200)
1829
1830 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1831
1832 async with client.chat.completions.with_streaming_response.create(
1833 messages=[
1834 {
1835 "content": "string",
1836 "role": "developer",
1837 }
1838 ],
1839 model="gpt-4o",
1840 ) as response:
1841 assert response.retries_taken == failures_before_success
1842 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1843
def test_get_platform(self) -> None:
    """Calling ``asyncify(get_platform)`` under nest_asyncio must neither hang nor fail.

    A previous implementation of asyncify could leave threads unterminated when
    used with nest_asyncio. ``nest_asyncio.apply()`` is global and cannot be
    un-applied, so the scenario runs in a separate interpreter to avoid affecting
    other tests. A child that does not exit within the timeout is treated as hung.
    """
    test_code = dedent("""
        import asyncio
        import nest_asyncio
        import threading

        from openai._utils import asyncify
        from openai._base_client import get_platform

        async def test_main() -> None:
            result = await asyncify(get_platform)()
            print(result)
            for thread in threading.enumerate():
                print(thread.name)

        nest_asyncio.apply()
        asyncio.run(test_main())
    """)
    try:
        # subprocess.run blocks until the child exits (no busy polling). On timeout
        # it kills and reaps the child before raising TimeoutExpired.
        completed = subprocess.run(
            [sys.executable, "-c", test_code],
            text=True,
            timeout=10,  # seconds
        )
    except subprocess.TimeoutExpired:
        raise AssertionError("calling get_platform using asyncify resulted in a hung process") from None

    if completed.returncode != 0:
        raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
1888
async def test_proxy_environment_variables(self, monkeypatch: pytest.MonkeyPatch) -> None:
    """``DefaultAsyncHttpxClient`` honours an ``HTTPS_PROXY`` environment variable.

    With ``HTTPS_PROXY`` set, the freshly built client must expose exactly one
    mount, keyed on the ``https://`` URL pattern.
    """
    monkeypatch.setenv("HTTPS_PROXY", "https://example.org")

    client = DefaultAsyncHttpxClient()
    try:
        mounts = tuple(client._mounts.items())
        assert len(mounts) == 1
        assert mounts[0][0].pattern == "https://"
    finally:
        # Release the client's transports even if an assertion above fails;
        # previously the client was never closed.
        await client.aclose()
1898
@pytest.mark.filterwarnings("ignore:.*deprecated.*:DeprecationWarning")
async def test_default_client_creation(self) -> None:
    """``DefaultAsyncHttpxClient`` can be built with common httpx options without raising.

    DeprecationWarnings mentioning "deprecated" are ignored by the marker on this
    test, so only real construction errors fail it.
    """
    client = DefaultAsyncHttpxClient(
        verify=True,
        cert=None,
        trust_env=True,
        http1=True,
        http2=False,
        limits=httpx.Limits(max_connections=100, max_keepalive_connections=20),
    )
    # Construction succeeding is the whole check; close the client so its
    # transport is released instead of leaking past the test.
    await client.aclose()
1910
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects(self, respx_mock: MockRouter) -> None:
    """With the default ``follow_redirects=True``, a 302 is followed to its target.

    POST ``/redirect`` answers 302 pointing at ``/redirected``, which answers 200
    with a JSON body. The client must surface that final 200 response.
    """
    redirect_target = f"{base_url}/redirected"
    respx_mock.get("/redirected").mock(return_value=httpx.Response(200, json={"status": "ok"}))
    respx_mock.post("/redirect").mock(
        return_value=httpx.Response(302, headers={"Location": redirect_target})
    )

    final = await self.client.post("/redirect", body={"key": "value"}, cast_to=httpx.Response)

    assert final.status_code == 200
    assert final.json() == {"status": "ok"}
1922
@pytest.mark.respx(base_url=base_url)
async def test_follow_redirects_disabled(self, respx_mock: MockRouter) -> None:
    """With ``follow_redirects=False``, a 302 surfaces as an ``APIStatusError``.

    The raised error must carry the untouched 302 response, including its
    ``Location`` header.
    """
    redirect_target = f"{base_url}/redirected"
    respx_mock.post("/redirect").mock(
        return_value=httpx.Response(302, headers={"Location": redirect_target})
    )

    with pytest.raises(APIStatusError) as exc_info:
        await self.client.post(
            "/redirect",
            body={"key": "value"},
            options={"follow_redirects": False},
            cast_to=httpx.Response,
        )

    rejected = exc_info.value.response
    assert rejected.status_code == 302
    assert rejected.headers["Location"] == redirect_target
1937