openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
5fdba4864b5a9dc00f50a939fcf40b992a550db9

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1806 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import asyncio
10import inspect
11import subprocess
12import tracemalloc
13from typing import Any, Union, cast
14from textwrap import dedent
15from unittest import mock
16from typing_extensions import Literal
17
18import httpx
19import pytest
20from respx import MockRouter
21from pydantic import ValidationError
22
23from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
24from openai._types import Omit
25from openai._models import BaseModel, FinalRequestOptions
26from openai._constants import RAW_RESPONSE_HEADER
27from openai._streaming import Stream, AsyncStream
28from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
29from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
30
31from .utils import update_env
32
# Server the clients point at; defaults to a local mock server, overridable via TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Dummy credential; tests only check that it is forwarded (e.g. as a Bearer token) and copied.
api_key = "My API Key"
35
36
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Build a bare ``GET /foo`` request with *client* and return its query parameters.

    Lets the ``default_query`` tests inspect what the client would put on the URL
    without sending anything over the network.
    """
    options = FinalRequestOptions(method="get", url="/foo")
    query = httpx.URL(client._build_request(options).url).params
    return dict(query)
41
42
43def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
44 return 0.1
45
46
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return ``len(_pool._requests)`` for *client*'s underlying httpx transport.

    The retry tests call this after exhausting retries to check that the pool is no longer
    tracking any request. ``_pool._requests`` is a private attribute of the transport's pool
    (presumably httpcore's list of in-flight requests -- TODO re-check on httpx upgrades).

    Raises:
        AssertionError: if the client is not backed by a stock httpx HTTP transport.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
53
54
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    Most tests share the class-level ``client`` below (strict response validation enabled);
    tests that need other construction options build their own client. Tests taking a
    ``client`` argument get it from ``pytest.mark.parametrize`` or, when no parametrize is
    present, from a fixture defined outside this file (presumably conftest.py -- TODO confirm).
    """

    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """``cast_to=httpx.Response`` returns the raw HTTP response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """The raw response is returned even when the content type is not JSON."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a distinct client; overrides do not leak back into the original."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options with defaults (``max_retries``, ``timeout``) can be overridden on copies."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` merges into a copy; ``set_default_headers`` replaces them outright."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` merges into a copy; ``set_default_query`` replaces it outright."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given query params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every ``__init__`` parameter (bar a few exclusions) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not retain memory per call."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        # Always stop tracing, even if building a request raises: otherwise tracemalloc stays
        # active for every test that runs afterwards, slowing them down.
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        # Allocations whose traceback passes through any of these files are known and tolerated.
        # Kept as a tuple so it can be handed straight to `str.endswith`.
        ignored_fragments = (
            # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
            #
            # removing the decorator fixes the leak for reasons we don't understand.
            "openai/_legacy_response.py",
            "openai/_response.py",
            # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
            "openai/_compat.py",
            # Standard library leaks we don't care about.
            "/logging/__init__.py",
        )

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            if any(frame.filename.endswith(ignored_fragments) for frame in diff.traceback):
                return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected; see the printed tracebacks")

    def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless the request options carry their own timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A ``timeout`` passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """Timeout precedence when a custom ``httpx.Client`` is supplied."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    # Explicit marker, matching the async tests in TestAsyncOpenAI: without it, pytest-asyncio's
    # strict mode would collect this coroutine but never await it, so the test would not run.
    @pytest.mark.asyncio
    async def test_invalid_http_client(self) -> None:
        """The sync client rejects an ``httpx.AsyncClient`` passed as ``http_client``."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """``default_headers`` are sent and may override the SDK's own ``X-Stainless-*`` headers."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a Bearer token; a missing key raises ``OpenAIError``."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            # Remove OPENAI_API_KEY from the environment so the client has no key to fall back on.
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """``default_query`` is applied, and per-request ``params`` override it key by key."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over ``default_headers``."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` merges with ``query`` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """List values in a multipart body are encoded as repeated ``name[]`` parts."""
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """``base_url`` is normalized with a trailing slash, both at init and via the setter."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base URL ending in ``/`` joins with a leading-slash path without doubling the slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    # Unlike the test above, these base URLs deliberately have NO trailing slash; the client
    # normalizes them (see test_base_url_setter), so the joined URL must be the same.
    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL without a trailing slash still keeps its last path segment when joined."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to ``base_url``."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy leaves the original client open."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Using the client as a context manager yields the client itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that fails model validation raises ``APIResponseValidationError``."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        # The pydantic error is chained as the cause.
        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected at construction time."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with a ``stream_cls`` returns a ``Stream``."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Non-JSON text fails under strict validation and is passed through as ``str`` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """``Retry-After`` (seconds or HTTP-date) drives the retry delay, with fallbacks for bad values."""
        # time.time() is frozen at 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, so each HTTP-date
        # case above is that many seconds after "now" (e.g. 16:26:57 -> 20s).
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        # Relative tolerance, presumably to absorb jitter applied by the client's backoff.
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After retries on timeouts are exhausted, the pool tracks no leftover requests."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After retries on 500 responses are exhausted, the pool tracks no leftover requests."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """``retries_taken`` and the ``x-stainless-retry-count`` header count the failed attempts."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for ``x-stainless-retry-count`` removes the header entirely."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """An explicit ``x-stainless-retry-count`` value is sent unchanged regardless of retries."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Same as ``test_retries_taken`` (status failures only) but via ``with_streaming_response``."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
901
902
class TestAsyncOpenAI:
    """Tests for `AsyncOpenAI`: request building, `copy()` semantics, timeouts, response parsing and retries."""

    # Shared strict-validation client for tests that do not need custom construction options.
    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given query params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        # Keep up to 1000 frames per allocation traceback so the filter below can
        # inspect every frame of the allocating call stack.
        tracemalloc.start(1000)
        # Always stop tracing, even if building a request raises, so a failure here
        # does not leave tracemalloc enabled for the rest of the test session.
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected while building requests from copied clients")

    async def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    async def test_client_timeout_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
        )

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    async def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        async with httpx.AsyncClient(timeout=None) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        async with httpx.AsyncClient() as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_invalid_http_client(self) -> None:
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            with httpx.Client() as http_client:
                AsyncOpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # With no explicit key and OPENAI_API_KEY removed from the environment,
        # constructing the client must fail.
        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        # List values are encoded as repeated `name[]` form parts, followed by the file part.
        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = await self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = AsyncOpenAI(
            base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
        )
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    # The base URLs below deliberately omit the trailing slash; the client is expected
    # to add one (see `test_base_url_setter`) so relative paths still join under `/custom/path/`.
    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    async def test_copied_client_does_not_close_http(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        # Dropping the copy must not close the HTTP client it shares with the original.
        del copied

        await asyncio.sleep(0.2)
        assert not client.is_closed()

    async def test_client_context_manager(self) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        async with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            await self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    async def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
            )

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
        assert isinstance(stream, AsyncStream)
        await stream.response.aclose()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = await client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    # Freeze the clock so the HTTP-date cases above resolve to fixed offsets from "now".
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    @pytest.mark.asyncio
    async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    async def test_retries_taken(
        self,
        async_client: AsyncOpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_omit_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_overwrite_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken_new_response_class(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        async with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    def test_get_platform(self) -> None:
        # A previous implementation of asyncify could leave threads unterminated when
        # used with nest_asyncio.
        #
        # Since nest_asyncio.apply() is global and cannot be un-applied, this
        # test is run in a separate process to avoid affecting other tests.
        test_code = dedent("""
        import asyncio
        import nest_asyncio
        import threading

        from openai._utils import asyncify
        from openai._base_client import get_platform

        async def test_main() -> None:
            result = await asyncify(get_platform)()
            print(result)
            for thread in threading.enumerate():
                print(thread.name)

        nest_asyncio.apply()
        asyncio.run(test_main())
        """)
        # Starting a fresh interpreter and importing `openai` may take several seconds
        # on a loaded CI machine; a short timeout would report a false "hung process".
        timeout = 10  # seconds
        with subprocess.Popen(
            [sys.executable, "-c", test_code],
            text=True,
        ) as process:
            try:
                process.wait(timeout)
                if process.returncode:
                    raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
            except subprocess.TimeoutExpired as e:
                process.kill()
                raise AssertionError("calling get_platform using asyncify resulted in a hung process") from e
1778 # test is run in a separate process to avoid affecting other tests.
1779 test_code = dedent("""
1780 import asyncio
1781 import nest_asyncio
1782 import threading
1783
1784 from openai._utils import asyncify
1785 from openai._base_client import get_platform
1786
1787 async def test_main() -> None:
1788 result = await asyncify(get_platform)()
1789 print(result)
1790 for thread in threading.enumerate():
1791 print(thread.name)
1792
1793 nest_asyncio.apply()
1794 asyncio.run(test_main())
1795 """)
1796 with subprocess.Popen(
1797 [sys.executable, "-c", test_code],
1798 text=True,
1799 ) as process:
1800 try:
1801 process.wait(2)
1802 if process.returncode:
1803 raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
1804 except subprocess.TimeoutExpired as e:
1805 process.kill()
1806 raise AssertionError("calling get_platform using asyncify resulted in a hung process") from e