openai/openai-python

Public

Mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
af6a9437128fc64643178a12d3e700a962f08977

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1817 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import time
10import asyncio
11import inspect
12import subprocess
13import tracemalloc
14from typing import Any, Union, cast
15from textwrap import dedent
16from unittest import mock
17from typing_extensions import Literal
18
19import httpx
20import pytest
21from respx import MockRouter
22from pydantic import ValidationError
23
24from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
25from openai._types import Omit
26from openai._models import BaseModel, FinalRequestOptions
27from openai._constants import RAW_RESPONSE_HEADER
28from openai._streaming import Stream, AsyncStream
29from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
30from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
31
32from .utils import update_env
33
# Base URL of the API server the tests target; override with TEST_API_BASE_URL.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Placeholder API key passed to every client constructed in this module.
api_key = "My API Key"
36
37
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters *client* attaches to a bare ``GET /foo`` request.

    Used by the ``default_query`` / ``set_default_query`` tests to observe which
    query params a client (or a copy of it) adds by default.
    """
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
42
43
44def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
45 return 0.1
46
47
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return how many requests the client's httpx connection pool is still tracking.

    The retry "doesn't leak" tests assert this is 0 after all attempts have failed.
    Only the default ``httpx.HTTPTransport`` / ``httpx.AsyncHTTPTransport`` are
    supported; any other transport fails the assertion below.
    """
    transport = client._client._transport
    # One isinstance() call with a tuple of types replaces the chained `or` check and
    # also narrows the type for the private attribute access that follows.
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))

    pool = transport._pool
    # NOTE(review): relies on private httpx/httpcore internals (`_transport`, `_pool`,
    # `_requests`); these may change between library versions.
    return len(pool._requests)
54
55
class TestOpenAI:
    """Tests for the synchronous ``OpenAI`` client.

    Covers client construction and ``copy()``, request building (headers, query
    params, JSON bodies, timeouts, base URL handling), response parsing and
    validation, and retry behaviour against a ``respx``-mocked transport.
    """

    # Shared client for tests that don't need custom construction options.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        # Only stop tracing afterwards if this test started it, so an externally enabled
        # tracemalloc session (e.g. via PYTHONTRACEMALLOC) is left running.
        started_tracing = not tracemalloc.is_tracing()
        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if building a request raised, so later tests don't keep
            # running with allocation tracing (and its overhead) enabled.
            if started_tracing:
                tracemalloc.stop()

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            for frame in diff.traceback:
                if any(
                    frame.filename.endswith(fragment)
                    for fragment in [
                        # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
                        #
                        # removing the decorator fixes the leak for reasons we don't understand.
                        "openai/_legacy_response.py",
                        "openai/_response.py",
                        # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
                        "openai/_compat.py",
                        # Standard library leaks we don't care about.
                        "/logging/__init__.py",
                    ]
                ):
                    return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected while copying the client and building requests")

    def test_request_timeout(self) -> None:
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    # NOTE(review): async test without @pytest.mark.asyncio in a sync test class; presumably
    # relies on the asyncio plugin's auto mode being enabled — confirm in the pytest config.
    async def test_invalid_http_client(self) -> None:
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        # with no explicit key and OPENAI_API_KEY removed from the environment, construction must fail
        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    # NOTE(review): `client` here is a fixture defined outside this file (presumably conftest.py).
    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        # list values are encoded as one repeated `name[]` part per element
        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    # The base URLs here deliberately omit the trailing slash: the client is expected to
    # append it (as test_base_url_setter shows), so relative paths still join correctly.
    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        # dropping the copy must not close the HTTP client it shares with the original
        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        # with strict validation off, the raw text body is returned instead of raising
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    # time.time is frozen at 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, so HTTP-date
    # retry-after values below are relative to that instant.
    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        # fail the first `failures_before_success` attempts, then succeed
        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
902
903
class TestAsyncOpenAI:
    """Tests for the asynchronous ``AsyncOpenAI`` client.

    Covers ``copy()``/``with_options()``, timeout/header/query configuration, request
    building, response casting and validation, retry behaviour and connection cleanup.
    """

    client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response(self, respx_mock: MockRouter) -> None:
        """Casting to ``httpx.Response`` returns the raw response object."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A non-JSON content type is still returned unparsed when casting to ``httpx.Response``."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = await self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """``copy()`` returns a new client and only overrides the options it is given."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options with defaults (``max_retries``, ``timeout``) can be overridden on copies only."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """``default_headers`` are merged on copy; ``set_default_headers`` replaces them entirely."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """``default_query`` is merged on copy; ``set_default_query`` replaces it entirely."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            # TODO: update
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every constructor parameter (except a few excluded ones) is also accepted by ``copy()``."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Repeatedly copying the client and building a request must not leak memory."""
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Always stop tracing, even if building a request raised, so tracemalloc's
            # overhead is not left enabled for the rest of the test session.
            tracemalloc.stop()

        # Allocation sites (file path suffixes) that are known to "leak" and are not reported.
        ignored_fragments = (
            # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
            #
            # removing the decorator fixes the leak for reasons we don't understand.
            "openai/_legacy_response.py",
            "openai/_response.py",
            # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
            "openai/_compat.py",
            # Standard library leaks we don't care about.
            "/logging/__init__.py",
        )

        def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
            # NOTE(review): `count` is the block count in the *newer* snapshot (not `count_diff`);
            # since tracing starts right before `snapshot_before`, it approximates growth per run.
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return

            if any(frame.filename.endswith(ignored_fragments) for frame in diff.traceback):
                return

            leaks.append(diff)

        leaks: list[tracemalloc.StatisticDiff] = []
        for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
            add_leak(leaks, diff)
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"detected {len(leaks)} memory leak(s); see the printed tracebacks above")

    async def test_request_timeout(self) -> None:
        """Requests use ``DEFAULT_TIMEOUT`` unless a per-request timeout is given."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    async def test_client_timeout_option(self) -> None:
        """A timeout passed to the client constructor is applied to built requests."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
        )

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    async def test_http_client_timeout_option(self) -> None:
        """A custom ``httpx.AsyncClient`` timeout is used unless it equals httpx's own default."""
        # custom timeout given to the httpx client should be used
        async with httpx.AsyncClient(timeout=None) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        async with httpx.AsyncClient() as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    def test_invalid_http_client(self) -> None:
        """A synchronous ``httpx.Client`` is rejected as the async client's ``http_client``."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            with httpx.Client() as http_client:
                AsyncOpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """Default headers are sent and may override the built-in ``X-Stainless-Lang`` header."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a bearer token; a missing key raises ``OpenAIError``."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """Default query params are sent, and per-request params override them."""
        client = AsyncOpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        """``extra_json`` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """``extra_headers`` are sent and take priority over default headers."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """``extra_query`` is merged with ``query`` and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
        """List values in a multipart body are encoded as repeated ``name[]`` form fields."""
        request = async_client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union ``cast_to`` resolves to the member model that matches the response body."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = await self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning ``base_url`` normalizes it to end with a slash."""
        client = AsyncOpenAI(
            base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
        )
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """``OPENAI_BASE_URL`` is used when no ``base_url`` argument is passed."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
        """Relative request paths are appended to a base URL that ends with a slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # Deliberately *no* trailing slash here; that is the case this test exists to cover.
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
        """Relative request paths are appended under a base URL given without a trailing slash."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
            ),
            AsyncOpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.AsyncClient(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
        """An absolute request URL is used as-is instead of being joined to the base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    async def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy of a client does not close the original client."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        # yield to the event loop so any cleanup triggered by the deletion has a chance to run
        await asyncio.sleep(0.2)
        assert not client.is_closed()

    async def test_client_context_manager(self) -> None:
        """Using the client as an async context manager closes it on exit."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        async with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A body that does not match the model raises ``APIResponseValidationError`` from a pydantic error."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            await self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    async def test_client_max_retries_validation(self) -> None:
        """``max_retries=None`` is rejected with a ``TypeError``."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            AsyncOpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
            )

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """``stream=True`` with an ``AsyncStream`` ``stream_cls`` returns an ``AsyncStream``."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
        assert isinstance(stream, AsyncStream)
        await stream.response.aclose()

    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Plain text fails strict validation but is returned as ``str`` when validation is lax."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            await strict_client.get("/foo", cast_to=Model)

        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = await client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    # 1696004797 == Fri, 29 Sep 2023 16:26:37 GMT, the "now" the HTTP-date cases above are relative to.
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    @pytest.mark.asyncio
    async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """In-range ``Retry-After`` values (seconds or HTTP-date) are honoured; others fall back to backoff."""
        client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on timeouts leaves no open connections in the pool."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """Exhausting retries on 500 responses leaves no open connections in the pool."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            await self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    dict(
                        messages=[
                            {
                                "role": "user",
                                "content": "Say this is a test",
                            }
                        ],
                        model="gpt-4o",
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    async def test_retries_taken(
        self,
        async_client: AsyncOpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """``retries_taken`` and the retry-count header both equal the number of failed attempts."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_omit_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing ``Omit()`` for ``x-stainless-retry-count`` removes the header entirely."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_overwrite_retry_count_header(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """An explicit ``x-stainless-retry-count`` value is sent unchanged, whatever the retry count."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = await client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.asyncio
    async def test_retries_taken_new_response_class(
        self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Streaming responses also report ``retries_taken`` and the retry-count header."""
        client = async_client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        async with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    def test_get_platform(self) -> None:
        """``asyncify(get_platform)`` under ``nest_asyncio`` exits cleanly and promptly."""
        # A previous implementation of asyncify could leave threads unterminated when
        # used with nest_asyncio.
        #
        # Since nest_asyncio.apply() is global and cannot be un-applied, this
        # test is run in a separate process to avoid affecting other tests.
        test_code = dedent("""
        import asyncio
        import nest_asyncio
        import threading

        from openai._utils import asyncify
        from openai._base_client import get_platform

        async def test_main() -> None:
            result = await asyncify(get_platform)()
            print(result)
            for thread in threading.enumerate():
                print(thread.name)

        nest_asyncio.apply()
        asyncio.run(test_main())
        """)
        timeout = 10  # seconds

        with subprocess.Popen(
            [sys.executable, "-c", test_code],
            text=True,
        ) as process:
            try:
                return_code = process.wait(timeout=timeout)
            except subprocess.TimeoutExpired:
                # A hang is exactly the regression being guarded against; don't leave the child running.
                process.kill()
                raise AssertionError("calling get_platform using asyncify resulted in a hung process") from None

            if return_code != 0:
                raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
1818