openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
7319e6e3f139d68173e03033f077732c4c4bdfa5

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_client.py

1831 lines · mode: code

1# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
2
3from __future__ import annotations
4
5import gc
6import os
7import sys
8import json
9import time
10import asyncio
11import inspect
12import subprocess
13import tracemalloc
14from typing import Any, Union, cast
15from textwrap import dedent
16from unittest import mock
17from typing_extensions import Literal
18
19import httpx
20import pytest
21from respx import MockRouter
22from pydantic import ValidationError
23
24from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
25from openai._types import Omit
26from openai._utils import maybe_transform
27from openai._models import BaseModel, FinalRequestOptions
28from openai._constants import RAW_RESPONSE_HEADER
29from openai._streaming import Stream, AsyncStream
30from openai._exceptions import OpenAIError, APIStatusError, APITimeoutError, APIResponseValidationError
31from openai._base_client import DEFAULT_TIMEOUT, HTTPX_DEFAULT_TIMEOUT, BaseClient, make_request_options
32from openai.types.chat.completion_create_params import CompletionCreateParamsNonStreaming
33
34from .utils import update_env
35
# Base URL of the API server the tests talk to; override via the TEST_API_BASE_URL env var.
base_url = os.getenv("TEST_API_BASE_URL", "http://127.0.0.1:4010")
# Fixed dummy credential handed to every client constructed in this module.
api_key = "My API Key"
38
39
def _get_params(client: BaseClient[Any, Any]) -> dict[str, str]:
    """Return the query parameters *client* would attach to a GET ``/foo`` request.

    The request is only built via ``client._build_request``; nothing is sent.
    """
    built = client._build_request(FinalRequestOptions(method="get", url="/foo"))
    return dict(httpx.URL(built.url).params)
44
45
46def _low_retry_timeout(*_args: Any, **_kwargs: Any) -> float:
47 return 0.1
48
49
def _get_open_connections(client: OpenAI | AsyncOpenAI) -> int:
    """Return the number of requests held in *client*'s transport connection pool.

    Reaches into private attributes of the underlying httpx transport. The
    assertion fails if the client was configured with any transport other than
    the default sync/async HTTP transports.
    """
    transport = client._client._transport
    assert isinstance(transport, (httpx.HTTPTransport, httpx.AsyncHTTPTransport))
    return len(transport._pool._requests)
56
57
class TestOpenAI:
    """Tests for the synchronous `OpenAI` client.

    Tests either build requests locally with `_build_request` (nothing is sent),
    or send them through the client to routes mocked with `respx` under `base_url`.
    """

    # Shared client reused by tests that don't need custom construction options.
    client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response(self, respx_mock: MockRouter) -> None:
        """`cast_to=httpx.Response` hands back the raw httpx response."""
        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    @pytest.mark.respx(base_url=base_url)
    def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
        """A raw response is returned as-is even when its content type is not JSON."""
        respx_mock.post("/foo").mock(
            return_value=httpx.Response(200, headers={"Content-Type": "application/binary"}, content='{"foo": "bar"}')
        )

        response = self.client.post("/foo", cast_to=httpx.Response)
        assert response.status_code == 200
        assert isinstance(response, httpx.Response)
        assert response.json() == {"foo": "bar"}

    def test_copy(self) -> None:
        """`copy()` returns a distinct client; overridden options don't affect the source."""
        copied = self.client.copy()
        assert id(copied) != id(self.client)

        copied = self.client.copy(api_key="another My API Key")
        assert copied.api_key == "another My API Key"
        assert self.client.api_key == "My API Key"

    def test_copy_default_options(self) -> None:
        """Options that have defaults (`max_retries`, `timeout`) can be overridden on copies."""
        # options that have a default are overridden correctly
        copied = self.client.copy(max_retries=7)
        assert copied.max_retries == 7
        assert self.client.max_retries == 2

        copied2 = copied.copy(max_retries=6)
        assert copied2.max_retries == 6
        assert copied.max_retries == 7

        # timeout
        assert isinstance(self.client.timeout, httpx.Timeout)
        copied = self.client.copy(timeout=None)
        assert copied.timeout is None
        assert isinstance(self.client.timeout, httpx.Timeout)

    def test_copy_default_headers(self) -> None:
        """`default_headers` merges into, and `set_default_headers` replaces, the source headers."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        assert client.default_headers["X-Foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert copied.default_headers["X-Foo"] == "bar"

        # merges already given headers
        copied = client.copy(default_headers={"X-Bar": "stainless"})
        assert copied.default_headers["X-Foo"] == "bar"
        assert copied.default_headers["X-Bar"] == "stainless"

        # uses new values for any already given headers
        copied = client.copy(default_headers={"X-Foo": "stainless"})
        assert copied.default_headers["X-Foo"] == "stainless"

        # set_default_headers

        # completely overrides already set values
        copied = client.copy(set_default_headers={})
        assert copied.default_headers.get("X-Foo") is None

        copied = client.copy(set_default_headers={"X-Bar": "Robert"})
        assert copied.default_headers["X-Bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
        ):
            client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})

    def test_copy_default_query(self) -> None:
        """`default_query` merges into, and `set_default_query` replaces, the source query params."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
        )
        assert _get_params(client)["foo"] == "bar"

        # does not override the already given value when not specified
        copied = client.copy()
        assert _get_params(copied)["foo"] == "bar"

        # merges already given params
        copied = client.copy(default_query={"bar": "stainless"})
        params = _get_params(copied)
        assert params["foo"] == "bar"
        assert params["bar"] == "stainless"

        # uses new values for any already given query params
        copied = client.copy(default_query={"foo": "stainless"})
        assert _get_params(copied)["foo"] == "stainless"

        # set_default_query

        # completely overrides already set values
        copied = client.copy(set_default_query={})
        assert _get_params(copied) == {}

        copied = client.copy(set_default_query={"bar": "Robert"})
        assert _get_params(copied)["bar"] == "Robert"

        with pytest.raises(
            ValueError,
            match="`default_query` and `set_default_query` arguments are mutually exclusive",
        ):
            client.copy(set_default_query={}, default_query={"foo": "Bar"})

    def test_copy_signature(self) -> None:
        """Every `__init__` parameter (except a few excluded ones) is also accepted by `copy()`."""
        # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
        init_signature = inspect.signature(
            # mypy doesn't like that we access the `__init__` property.
            self.client.__init__,  # type: ignore[misc]
        )
        copy_signature = inspect.signature(self.client.copy)
        exclude_params = {"transport", "proxies", "_strict_response_validation"}

        for name in init_signature.parameters.keys():
            if name in exclude_params:
                continue

            copy_param = copy_signature.parameters.get(name)
            assert copy_param is not None, f"copy() signature is missing the {name} param"

    def test_copy_build_request(self) -> None:
        """Copying a client and building a request does not retain memory across iterations.

        Uses `tracemalloc` snapshots taken before and after `ITERATIONS` copy+build
        cycles. Allocations whose block count is a multiple of `ITERATIONS` are
        reported as leaks, unless they originate from a known/accepted source file.
        """
        options = FinalRequestOptions(method="get", url="/foo")

        def build_request(options: FinalRequestOptions) -> None:
            client = self.client.copy()
            client._build_request(options)

        # ensure that the machinery is warmed up before tracing starts.
        build_request(options)
        gc.collect()

        ITERATIONS = 10

        tracemalloc.start(1000)
        try:
            snapshot_before = tracemalloc.take_snapshot()

            for _ in range(ITERATIONS):
                build_request(options)

            gc.collect()
            snapshot_after = tracemalloc.take_snapshot()
        finally:
            # Stop tracing even if building a request raised, so tracemalloc stays
            # off for the remaining tests.
            tracemalloc.stop()

        # Traceback frames ending in any of these paths are known/accepted allocations.
        ignored_fragments = (
            # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
            #
            # removing the decorator fixes the leak for reasons we don't understand.
            "openai/_legacy_response.py",
            "openai/_response.py",
            # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
            "openai/_compat.py",
            # Standard library leaks we don't care about.
            "/logging/__init__.py",
        )

        def is_leak(diff: tracemalloc.StatisticDiff) -> bool:
            if diff.count == 0:
                # Avoid false positives by considering only leaks (i.e. allocations that persist).
                return False

            if diff.count % ITERATIONS != 0:
                # Avoid false positives by considering only leaks that appear per iteration.
                return False

            return not any(frame.filename.endswith(ignored_fragments) for frame in diff.traceback)

        leaks = [diff for diff in snapshot_after.compare_to(snapshot_before, "traceback") if is_leak(diff)]
        if leaks:
            for leak in leaks:
                print("MEMORY LEAK:", leak)
                for frame in leak.traceback:
                    print(frame)
            raise AssertionError(f"{len(leaks)} memory leak(s) detected while copying the client and building requests")

    def test_request_timeout(self) -> None:
        """Requests use `DEFAULT_TIMEOUT` unless the request options specify a timeout."""
        request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == DEFAULT_TIMEOUT

        request = self.client._build_request(
            FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
        )
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(100.0)

    def test_client_timeout_option(self) -> None:
        """A `timeout` passed to the client constructor is applied to built requests."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0))

        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
        assert timeout == httpx.Timeout(0)

    def test_http_client_timeout_option(self) -> None:
        """How a custom `http_client`'s timeout interacts with the SDK default timeout."""
        # custom timeout given to the httpx client should be used
        with httpx.Client(timeout=None) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == httpx.Timeout(None)

        # no timeout given to the httpx client should not use the httpx default
        with httpx.Client() as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT

        # explicitly passing the default timeout currently results in it being ignored
        with httpx.Client(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
            client = OpenAI(
                base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
            )

            request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
            timeout = httpx.Timeout(**request.extensions["timeout"])  # type: ignore
            assert timeout == DEFAULT_TIMEOUT  # our default

    async def test_invalid_http_client(self) -> None:
        """Passing an `httpx.AsyncClient` to the sync client raises `TypeError`."""
        with pytest.raises(TypeError, match="Invalid `http_client` arg"):
            async with httpx.AsyncClient() as http_client:
                OpenAI(
                    base_url=base_url,
                    api_key=api_key,
                    _strict_response_validation=True,
                    http_client=cast(Any, http_client),
                )

    def test_default_headers_option(self) -> None:
        """`default_headers` are sent and can override the SDK's own `X-Stainless-*` headers."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "bar"
        assert request.headers.get("x-stainless-lang") == "python"

        client2 = OpenAI(
            base_url=base_url,
            api_key=api_key,
            _strict_response_validation=True,
            default_headers={
                "X-Foo": "stainless",
                "X-Stainless-Lang": "my-overriding-header",
            },
        )
        request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("x-foo") == "stainless"
        assert request.headers.get("x-stainless-lang") == "my-overriding-header"

    def test_validate_headers(self) -> None:
        """The API key is sent as a Bearer token; a missing key raises `OpenAIError`."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        assert request.headers.get("Authorization") == f"Bearer {api_key}"

        with pytest.raises(OpenAIError):
            with update_env(**{"OPENAI_API_KEY": Omit()}):
                client2 = OpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
            _ = client2

    def test_default_query_option(self) -> None:
        """`default_query` params are sent; request-level params override them on clash."""
        client = OpenAI(
            base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
        )
        request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
        url = httpx.URL(request.url)
        assert dict(url.params) == {"query_param": "bar"}

        request = client._build_request(
            FinalRequestOptions(
                method="get",
                url="/foo",
                params={"foo": "baz", "query_param": "overridden"},
            )
        )
        url = httpx.URL(request.url)
        assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}

    def test_request_extra_json(self) -> None:
        """`extra_json` is merged into the JSON body and wins on key clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": False}

        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                extra_json={"baz": False},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"baz": False}

        # `extra_json` takes priority over `json_data` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar", "baz": True},
                extra_json={"baz": None},
            ),
        )
        data = json.loads(request.content.decode("utf-8"))
        assert data == {"foo": "bar", "baz": None}

    def test_request_extra_headers(self) -> None:
        """`extra_headers` are sent and take priority over `default_headers`."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(extra_headers={"X-Foo": "Foo"}),
            ),
        )
        assert request.headers.get("X-Foo") == "Foo"

        # `extra_headers` takes priority over `default_headers` when keys clash
        request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_headers={"X-Bar": "false"},
                ),
            ),
        )
        assert request.headers.get("X-Bar") == "false"

    def test_request_extra_query(self) -> None:
        """`extra_query` is merged with `query` and takes priority on clashes."""
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    extra_query={"my_query_param": "Foo"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"my_query_param": "Foo"}

        # if both `query` and `extra_query` are given, they are merged
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"bar": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"bar": "1", "foo": "2"}

        # `extra_query` takes priority over `query` when keys clash
        request = self.client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                **make_request_options(
                    query={"foo": "1"},
                    extra_query={"foo": "2"},
                ),
            ),
        )
        params = dict(request.url.params)
        assert params == {"foo": "2"}

    def test_multipart_repeating_array(self, client: OpenAI) -> None:
        """Array fields in multipart bodies are encoded as repeated `name[]` parts."""
        request = client._build_request(
            FinalRequestOptions.construct(
                method="get",
                url="/foo",
                headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
                json_data={"array": ["foo", "bar"]},
                files=[("foo.txt", b"hello world")],
            )
        )

        assert request.read().split(b"\r\n") == [
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"foo",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="array[]"',
            b"",
            b"bar",
            b"--6b7ba517decee4a450543ea6ae821c82",
            b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
            b"Content-Type: application/octet-stream",
            b"",
            b"hello world",
            b"--6b7ba517decee4a450543ea6ae821c82--",
            b"",
        ]

    @pytest.mark.respx(base_url=base_url)
    def test_basic_union_response(self, respx_mock: MockRouter) -> None:
        """A union `cast_to` resolves to the member model whose fields match the payload."""

        class Model1(BaseModel):
            name: str

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

    @pytest.mark.respx(base_url=base_url)
    def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
        """Union of objects with the same field name using a different type"""

        class Model1(BaseModel):
            foo: int

        class Model2(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model2)
        assert response.foo == "bar"

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))

        response = self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
        assert isinstance(response, Model1)
        assert response.foo == 1

    @pytest.mark.respx(base_url=base_url)
    def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
        """
        Response that sets Content-Type to something other than application/json but returns json data
        """

        class Model(BaseModel):
            foo: int

        respx_mock.get("/foo").mock(
            return_value=httpx.Response(
                200,
                content=json.dumps({"foo": 2}),
                headers={"Content-Type": "application/text"},
            )
        )

        response = self.client.get("/foo", cast_to=Model)
        assert isinstance(response, Model)
        assert response.foo == 2

    def test_base_url_setter(self) -> None:
        """Assigning `base_url` replaces it, normalized to end with a trailing slash."""
        client = OpenAI(base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True)
        assert client.base_url == "https://example.com/from_init/"

        client.base_url = "https://example.com/from_setter"  # type: ignore[assignment]

        assert client.base_url == "https://example.com/from_setter/"

    def test_base_url_env(self) -> None:
        """`OPENAI_BASE_URL` is used when no `base_url` argument is given."""
        with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
            client = OpenAI(api_key=api_key, _strict_response_validation=True)
            assert client.base_url == "http://localhost:5000/from/env/"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_trailing_slash(self, client: OpenAI) -> None:
        """A base URL ending in `/` is joined with the request path without doubling slashes."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            # Deliberately WITHOUT a trailing slash: that is the case this test covers.
            OpenAI(base_url="http://localhost:5000/custom/path", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_base_url_no_trailing_slash(self, client: OpenAI) -> None:
        """A base URL without a trailing `/` still keeps its last path segment when joined."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "http://localhost:5000/custom/path/foo"

    @pytest.mark.parametrize(
        "client",
        [
            OpenAI(base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True),
            OpenAI(
                base_url="http://localhost:5000/custom/path/",
                api_key=api_key,
                _strict_response_validation=True,
                http_client=httpx.Client(),
            ),
        ],
        ids=["standard", "custom http client"],
    )
    def test_absolute_request_url(self, client: OpenAI) -> None:
        """An absolute request URL is used as-is, ignoring the client's base URL."""
        request = client._build_request(
            FinalRequestOptions(
                method="post",
                url="https://myapi.com/foo",
                json_data={"foo": "bar"},
            ),
        )
        assert request.url == "https://myapi.com/foo"

    def test_copied_client_does_not_close_http(self) -> None:
        """Deleting a copy does not close the original client's HTTP client."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        assert not client.is_closed()

        copied = client.copy()
        assert copied is not client

        del copied

        assert not client.is_closed()

    def test_client_context_manager(self) -> None:
        """Using the client as a context manager yields itself and closes it on exit."""
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
        with client as c2:
            assert c2 is client
            assert not c2.is_closed()
            assert not client.is_closed()
        assert client.is_closed()

    @pytest.mark.respx(base_url=base_url)
    def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
        """A payload that doesn't match the model raises `APIResponseValidationError` chained from pydantic."""

        class Model(BaseModel):
            foo: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))

        with pytest.raises(APIResponseValidationError) as exc:
            self.client.get("/foo", cast_to=Model)

        assert isinstance(exc.value.__cause__, ValidationError)

    def test_client_max_retries_validation(self) -> None:
        """`max_retries=None` is rejected with `TypeError`."""
        with pytest.raises(TypeError, match=r"max_retries cannot be None"):
            OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None))

    @pytest.mark.respx(base_url=base_url)
    def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
        """`stream=True` with `stream_cls` returns an instance of that stream class."""

        class Model(BaseModel):
            name: str

        respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))

        stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model])
        assert isinstance(stream, Stream)
        stream.response.close()

    @pytest.mark.respx(base_url=base_url)
    def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
        """Non-JSON text raises under strict validation and is returned as `str` otherwise."""

        class Model(BaseModel):
            name: str

        respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))

        strict_client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        with pytest.raises(APIResponseValidationError):
            strict_client.get("/foo", cast_to=Model)

        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)

        response = client.get("/foo", cast_to=Model)
        assert isinstance(response, str)  # type: ignore[unreachable]

    @pytest.mark.parametrize(
        "remaining_retries,retry_after,timeout",
        [
            [3, "20", 20],
            [3, "0", 0.5],
            [3, "-10", 0.5],
            [3, "60", 60],
            [3, "61", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
            [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
            [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
            [3, "99999999999999999999999999999999999", 0.5],
            [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
            [3, "", 0.5],
            [2, "", 0.5 * 2.0],
            [1, "", 0.5 * 4.0],
            [-1100, "", 8],  # test large number potentially overflowing
        ],
    )
    @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
    def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
        """`retry-after` values (seconds or HTTP dates) map to the expected retry delay.

        `time.time` is frozen so HTTP-date values resolve to fixed offsets.
        """
        client = OpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)

        headers = httpx.Headers({"retry-after": retry_after})
        options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
        calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
        assert calculated == pytest.approx(timeout, 0.5 * 0.875)  # pyright: ignore[reportUnknownMemberType]

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After retries on timeouts are exhausted, no pooled requests remain open."""
        respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))

        with pytest.raises(APITimeoutError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    maybe_transform(
                        dict(
                            messages=[
                                {
                                    "role": "user",
                                    "content": "Say this is a test",
                                }
                            ],
                            model="gpt-4o",
                        ),
                        CompletionCreateParamsNonStreaming,
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
        """After retries on 5xx responses are exhausted, no pooled requests remain open."""
        respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))

        with pytest.raises(APIStatusError):
            self.client.post(
                "/chat/completions",
                body=cast(
                    object,
                    maybe_transform(
                        dict(
                            messages=[
                                {
                                    "role": "user",
                                    "content": "Say this is a test",
                                }
                            ],
                            model="gpt-4o",
                        ),
                        CompletionCreateParamsNonStreaming,
                    ),
                ),
                cast_to=httpx.Response,
                options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
            )

        assert _get_open_connections(self.client) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    @pytest.mark.parametrize("failure_mode", ["status", "exception"])
    def test_retries_taken(
        self,
        client: OpenAI,
        failures_before_success: int,
        failure_mode: Literal["status", "exception"],
        respx_mock: MockRouter,
    ) -> None:
        """`retries_taken` and the `x-stainless-retry-count` header reflect the retries performed."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` calls, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                if failure_mode == "exception":
                    raise RuntimeError("oops")
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        )

        assert response.retries_taken == failures_before_success
        assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_omit_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """Passing `Omit()` for `x-stainless-retry-count` suppresses the header entirely."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` calls with a 500, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": Omit()},
        )

        assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_overwrite_retry_count_header(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """An explicit `x-stainless-retry-count` value is sent instead of the computed count."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` calls with a 500, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        response = client.chat.completions.with_raw_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
            extra_headers={"x-stainless-retry-count": "42"},
        )

        assert response.http_request.headers.get("x-stainless-retry-count") == "42"

    @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
    @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
    @pytest.mark.respx(base_url=base_url)
    def test_retries_taken_new_response_class(
        self, client: OpenAI, failures_before_success: int, respx_mock: MockRouter
    ) -> None:
        """`with_streaming_response` responses also report `retries_taken` and the retry-count header."""
        client = client.with_options(max_retries=4)

        nb_retries = 0

        def retry_handler(_request: httpx.Request) -> httpx.Response:
            # Fail the first `failures_before_success` calls with a 500, then succeed.
            nonlocal nb_retries
            if nb_retries < failures_before_success:
                nb_retries += 1
                return httpx.Response(500)
            return httpx.Response(200)

        respx_mock.post("/chat/completions").mock(side_effect=retry_handler)

        with client.chat.completions.with_streaming_response.create(
            messages=[
                {
                    "content": "string",
                    "role": "developer",
                }
            ],
            model="gpt-4o",
        ) as response:
            assert response.retries_taken == failures_before_success
            assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
910
911
912class TestAsyncOpenAI:
913 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
914
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response(self, respx_mock: MockRouter) -> None:
    """Awaiting a POST with `cast_to=httpx.Response` yields the untouched httpx response."""
    payload = {"foo": "bar"}
    respx_mock.post("/foo").mock(return_value=httpx.Response(200, json=payload))

    raw = await self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == payload
924
@pytest.mark.respx(base_url=base_url)
@pytest.mark.asyncio
async def test_raw_response_for_binary(self, respx_mock: MockRouter) -> None:
    """A raw response whose content type is `application/binary` is still returned as-is."""
    body = '{"foo": "bar"}'
    binary_headers = {"Content-Type": "application/binary"}
    route = respx_mock.post("/foo")
    route.mock(return_value=httpx.Response(200, headers=binary_headers, content=body))

    raw = await self.client.post("/foo", cast_to=httpx.Response)

    assert raw.status_code == 200
    assert isinstance(raw, httpx.Response)
    assert raw.json() == {"foo": "bar"}
936
937 def test_copy(self) -> None:
938 copied = self.client.copy()
939 assert id(copied) != id(self.client)
940
941 copied = self.client.copy(api_key="another My API Key")
942 assert copied.api_key == "another My API Key"
943 assert self.client.api_key == "My API Key"
944
945 def test_copy_default_options(self) -> None:
946 # options that have a default are overridden correctly
947 copied = self.client.copy(max_retries=7)
948 assert copied.max_retries == 7
949 assert self.client.max_retries == 2
950
951 copied2 = copied.copy(max_retries=6)
952 assert copied2.max_retries == 6
953 assert copied.max_retries == 7
954
955 # timeout
956 assert isinstance(self.client.timeout, httpx.Timeout)
957 copied = self.client.copy(timeout=None)
958 assert copied.timeout is None
959 assert isinstance(self.client.timeout, httpx.Timeout)
960
961 def test_copy_default_headers(self) -> None:
962 client = AsyncOpenAI(
963 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
964 )
965 assert client.default_headers["X-Foo"] == "bar"
966
967 # does not override the already given value when not specified
968 copied = client.copy()
969 assert copied.default_headers["X-Foo"] == "bar"
970
971 # merges already given headers
972 copied = client.copy(default_headers={"X-Bar": "stainless"})
973 assert copied.default_headers["X-Foo"] == "bar"
974 assert copied.default_headers["X-Bar"] == "stainless"
975
976 # uses new values for any already given headers
977 copied = client.copy(default_headers={"X-Foo": "stainless"})
978 assert copied.default_headers["X-Foo"] == "stainless"
979
980 # set_default_headers
981
982 # completely overrides already set values
983 copied = client.copy(set_default_headers={})
984 assert copied.default_headers.get("X-Foo") is None
985
986 copied = client.copy(set_default_headers={"X-Bar": "Robert"})
987 assert copied.default_headers["X-Bar"] == "Robert"
988
989 with pytest.raises(
990 ValueError,
991 match="`default_headers` and `set_default_headers` arguments are mutually exclusive",
992 ):
993 client.copy(set_default_headers={}, default_headers={"X-Foo": "Bar"})
994
995 def test_copy_default_query(self) -> None:
996 client = AsyncOpenAI(
997 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"foo": "bar"}
998 )
999 assert _get_params(client)["foo"] == "bar"
1000
1001 # does not override the already given value when not specified
1002 copied = client.copy()
1003 assert _get_params(copied)["foo"] == "bar"
1004
1005 # merges already given params
1006 copied = client.copy(default_query={"bar": "stainless"})
1007 params = _get_params(copied)
1008 assert params["foo"] == "bar"
1009 assert params["bar"] == "stainless"
1010
1011 # uses new values for any already given headers
1012 copied = client.copy(default_query={"foo": "stainless"})
1013 assert _get_params(copied)["foo"] == "stainless"
1014
1015 # set_default_query
1016
1017 # completely overrides already set values
1018 copied = client.copy(set_default_query={})
1019 assert _get_params(copied) == {}
1020
1021 copied = client.copy(set_default_query={"bar": "Robert"})
1022 assert _get_params(copied)["bar"] == "Robert"
1023
1024 with pytest.raises(
1025 ValueError,
1026 # TODO: update
1027 match="`default_query` and `set_default_query` arguments are mutually exclusive",
1028 ):
1029 client.copy(set_default_query={}, default_query={"foo": "Bar"})
1030
1031 def test_copy_signature(self) -> None:
1032 # ensure the same parameters that can be passed to the client are defined in the `.copy()` method
1033 init_signature = inspect.signature(
1034 # mypy doesn't like that we access the `__init__` property.
1035 self.client.__init__, # type: ignore[misc]
1036 )
1037 copy_signature = inspect.signature(self.client.copy)
1038 exclude_params = {"transport", "proxies", "_strict_response_validation"}
1039
1040 for name in init_signature.parameters.keys():
1041 if name in exclude_params:
1042 continue
1043
1044 copy_param = copy_signature.parameters.get(name)
1045 assert copy_param is not None, f"copy() signature is missing the {name} param"
1046
1047 def test_copy_build_request(self) -> None:
1048 options = FinalRequestOptions(method="get", url="/foo")
1049
1050 def build_request(options: FinalRequestOptions) -> None:
1051 client = self.client.copy()
1052 client._build_request(options)
1053
1054 # ensure that the machinery is warmed up before tracing starts.
1055 build_request(options)
1056 gc.collect()
1057
1058 tracemalloc.start(1000)
1059
1060 snapshot_before = tracemalloc.take_snapshot()
1061
1062 ITERATIONS = 10
1063 for _ in range(ITERATIONS):
1064 build_request(options)
1065
1066 gc.collect()
1067 snapshot_after = tracemalloc.take_snapshot()
1068
1069 tracemalloc.stop()
1070
1071 def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.StatisticDiff) -> None:
1072 if diff.count == 0:
1073 # Avoid false positives by considering only leaks (i.e. allocations that persist).
1074 return
1075
1076 if diff.count % ITERATIONS != 0:
1077 # Avoid false positives by considering only leaks that appear per iteration.
1078 return
1079
1080 for frame in diff.traceback:
1081 if any(
1082 frame.filename.endswith(fragment)
1083 for fragment in [
1084 # to_raw_response_wrapper leaks through the @functools.wraps() decorator.
1085 #
1086 # removing the decorator fixes the leak for reasons we don't understand.
1087 "openai/_legacy_response.py",
1088 "openai/_response.py",
1089 # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason.
1090 "openai/_compat.py",
1091 # Standard library leaks we don't care about.
1092 "/logging/__init__.py",
1093 ]
1094 ):
1095 return
1096
1097 leaks.append(diff)
1098
1099 leaks: list[tracemalloc.StatisticDiff] = []
1100 for diff in snapshot_after.compare_to(snapshot_before, "traceback"):
1101 add_leak(leaks, diff)
1102 if leaks:
1103 for leak in leaks:
1104 print("MEMORY LEAK:", leak)
1105 for frame in leak.traceback:
1106 print(frame)
1107 raise AssertionError()
1108
1109 async def test_request_timeout(self) -> None:
1110 request = self.client._build_request(FinalRequestOptions(method="get", url="/foo"))
1111 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1112 assert timeout == DEFAULT_TIMEOUT
1113
1114 request = self.client._build_request(
1115 FinalRequestOptions(method="get", url="/foo", timeout=httpx.Timeout(100.0))
1116 )
1117 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1118 assert timeout == httpx.Timeout(100.0)
1119
1120 async def test_client_timeout_option(self) -> None:
1121 client = AsyncOpenAI(
1122 base_url=base_url, api_key=api_key, _strict_response_validation=True, timeout=httpx.Timeout(0)
1123 )
1124
1125 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1126 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1127 assert timeout == httpx.Timeout(0)
1128
1129 async def test_http_client_timeout_option(self) -> None:
1130 # custom timeout given to the httpx client should be used
1131 async with httpx.AsyncClient(timeout=None) as http_client:
1132 client = AsyncOpenAI(
1133 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1134 )
1135
1136 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1137 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1138 assert timeout == httpx.Timeout(None)
1139
1140 # no timeout given to the httpx client should not use the httpx default
1141 async with httpx.AsyncClient() as http_client:
1142 client = AsyncOpenAI(
1143 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1144 )
1145
1146 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1147 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1148 assert timeout == DEFAULT_TIMEOUT
1149
1150 # explicitly passing the default timeout currently results in it being ignored
1151 async with httpx.AsyncClient(timeout=HTTPX_DEFAULT_TIMEOUT) as http_client:
1152 client = AsyncOpenAI(
1153 base_url=base_url, api_key=api_key, _strict_response_validation=True, http_client=http_client
1154 )
1155
1156 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1157 timeout = httpx.Timeout(**request.extensions["timeout"]) # type: ignore
1158 assert timeout == DEFAULT_TIMEOUT # our default
1159
1160 def test_invalid_http_client(self) -> None:
1161 with pytest.raises(TypeError, match="Invalid `http_client` arg"):
1162 with httpx.Client() as http_client:
1163 AsyncOpenAI(
1164 base_url=base_url,
1165 api_key=api_key,
1166 _strict_response_validation=True,
1167 http_client=cast(Any, http_client),
1168 )
1169
1170 def test_default_headers_option(self) -> None:
1171 client = AsyncOpenAI(
1172 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_headers={"X-Foo": "bar"}
1173 )
1174 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1175 assert request.headers.get("x-foo") == "bar"
1176 assert request.headers.get("x-stainless-lang") == "python"
1177
1178 client2 = AsyncOpenAI(
1179 base_url=base_url,
1180 api_key=api_key,
1181 _strict_response_validation=True,
1182 default_headers={
1183 "X-Foo": "stainless",
1184 "X-Stainless-Lang": "my-overriding-header",
1185 },
1186 )
1187 request = client2._build_request(FinalRequestOptions(method="get", url="/foo"))
1188 assert request.headers.get("x-foo") == "stainless"
1189 assert request.headers.get("x-stainless-lang") == "my-overriding-header"
1190
1191 def test_validate_headers(self) -> None:
1192 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1193 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1194 assert request.headers.get("Authorization") == f"Bearer {api_key}"
1195
1196 with pytest.raises(OpenAIError):
1197 with update_env(**{"OPENAI_API_KEY": Omit()}):
1198 client2 = AsyncOpenAI(base_url=base_url, api_key=None, _strict_response_validation=True)
1199 _ = client2
1200
1201 def test_default_query_option(self) -> None:
1202 client = AsyncOpenAI(
1203 base_url=base_url, api_key=api_key, _strict_response_validation=True, default_query={"query_param": "bar"}
1204 )
1205 request = client._build_request(FinalRequestOptions(method="get", url="/foo"))
1206 url = httpx.URL(request.url)
1207 assert dict(url.params) == {"query_param": "bar"}
1208
1209 request = client._build_request(
1210 FinalRequestOptions(
1211 method="get",
1212 url="/foo",
1213 params={"foo": "baz", "query_param": "overridden"},
1214 )
1215 )
1216 url = httpx.URL(request.url)
1217 assert dict(url.params) == {"foo": "baz", "query_param": "overridden"}
1218
1219 def test_request_extra_json(self) -> None:
1220 request = self.client._build_request(
1221 FinalRequestOptions(
1222 method="post",
1223 url="/foo",
1224 json_data={"foo": "bar"},
1225 extra_json={"baz": False},
1226 ),
1227 )
1228 data = json.loads(request.content.decode("utf-8"))
1229 assert data == {"foo": "bar", "baz": False}
1230
1231 request = self.client._build_request(
1232 FinalRequestOptions(
1233 method="post",
1234 url="/foo",
1235 extra_json={"baz": False},
1236 ),
1237 )
1238 data = json.loads(request.content.decode("utf-8"))
1239 assert data == {"baz": False}
1240
1241 # `extra_json` takes priority over `json_data` when keys clash
1242 request = self.client._build_request(
1243 FinalRequestOptions(
1244 method="post",
1245 url="/foo",
1246 json_data={"foo": "bar", "baz": True},
1247 extra_json={"baz": None},
1248 ),
1249 )
1250 data = json.loads(request.content.decode("utf-8"))
1251 assert data == {"foo": "bar", "baz": None}
1252
1253 def test_request_extra_headers(self) -> None:
1254 request = self.client._build_request(
1255 FinalRequestOptions(
1256 method="post",
1257 url="/foo",
1258 **make_request_options(extra_headers={"X-Foo": "Foo"}),
1259 ),
1260 )
1261 assert request.headers.get("X-Foo") == "Foo"
1262
1263 # `extra_headers` takes priority over `default_headers` when keys clash
1264 request = self.client.with_options(default_headers={"X-Bar": "true"})._build_request(
1265 FinalRequestOptions(
1266 method="post",
1267 url="/foo",
1268 **make_request_options(
1269 extra_headers={"X-Bar": "false"},
1270 ),
1271 ),
1272 )
1273 assert request.headers.get("X-Bar") == "false"
1274
1275 def test_request_extra_query(self) -> None:
1276 request = self.client._build_request(
1277 FinalRequestOptions(
1278 method="post",
1279 url="/foo",
1280 **make_request_options(
1281 extra_query={"my_query_param": "Foo"},
1282 ),
1283 ),
1284 )
1285 params = dict(request.url.params)
1286 assert params == {"my_query_param": "Foo"}
1287
1288 # if both `query` and `extra_query` are given, they are merged
1289 request = self.client._build_request(
1290 FinalRequestOptions(
1291 method="post",
1292 url="/foo",
1293 **make_request_options(
1294 query={"bar": "1"},
1295 extra_query={"foo": "2"},
1296 ),
1297 ),
1298 )
1299 params = dict(request.url.params)
1300 assert params == {"bar": "1", "foo": "2"}
1301
1302 # `extra_query` takes priority over `query` when keys clash
1303 request = self.client._build_request(
1304 FinalRequestOptions(
1305 method="post",
1306 url="/foo",
1307 **make_request_options(
1308 query={"foo": "1"},
1309 extra_query={"foo": "2"},
1310 ),
1311 ),
1312 )
1313 params = dict(request.url.params)
1314 assert params == {"foo": "2"}
1315
1316 def test_multipart_repeating_array(self, async_client: AsyncOpenAI) -> None:
1317 request = async_client._build_request(
1318 FinalRequestOptions.construct(
1319 method="get",
1320 url="/foo",
1321 headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"},
1322 json_data={"array": ["foo", "bar"]},
1323 files=[("foo.txt", b"hello world")],
1324 )
1325 )
1326
1327 assert request.read().split(b"\r\n") == [
1328 b"--6b7ba517decee4a450543ea6ae821c82",
1329 b'Content-Disposition: form-data; name="array[]"',
1330 b"",
1331 b"foo",
1332 b"--6b7ba517decee4a450543ea6ae821c82",
1333 b'Content-Disposition: form-data; name="array[]"',
1334 b"",
1335 b"bar",
1336 b"--6b7ba517decee4a450543ea6ae821c82",
1337 b'Content-Disposition: form-data; name="foo.txt"; filename="upload"',
1338 b"Content-Type: application/octet-stream",
1339 b"",
1340 b"hello world",
1341 b"--6b7ba517decee4a450543ea6ae821c82--",
1342 b"",
1343 ]
1344
1345 @pytest.mark.respx(base_url=base_url)
1346 async def test_basic_union_response(self, respx_mock: MockRouter) -> None:
1347 class Model1(BaseModel):
1348 name: str
1349
1350 class Model2(BaseModel):
1351 foo: str
1352
1353 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1354
1355 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1356 assert isinstance(response, Model2)
1357 assert response.foo == "bar"
1358
1359 @pytest.mark.respx(base_url=base_url)
1360 async def test_union_response_different_types(self, respx_mock: MockRouter) -> None:
1361 """Union of objects with the same field name using a different type"""
1362
1363 class Model1(BaseModel):
1364 foo: int
1365
1366 class Model2(BaseModel):
1367 foo: str
1368
1369 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1370
1371 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1372 assert isinstance(response, Model2)
1373 assert response.foo == "bar"
1374
1375 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": 1}))
1376
1377 response = await self.client.get("/foo", cast_to=cast(Any, Union[Model1, Model2]))
1378 assert isinstance(response, Model1)
1379 assert response.foo == 1
1380
1381 @pytest.mark.respx(base_url=base_url)
1382 async def test_non_application_json_content_type_for_json_data(self, respx_mock: MockRouter) -> None:
1383 """
1384 Response that sets Content-Type to something other than application/json but returns json data
1385 """
1386
1387 class Model(BaseModel):
1388 foo: int
1389
1390 respx_mock.get("/foo").mock(
1391 return_value=httpx.Response(
1392 200,
1393 content=json.dumps({"foo": 2}),
1394 headers={"Content-Type": "application/text"},
1395 )
1396 )
1397
1398 response = await self.client.get("/foo", cast_to=Model)
1399 assert isinstance(response, Model)
1400 assert response.foo == 2
1401
1402 def test_base_url_setter(self) -> None:
1403 client = AsyncOpenAI(
1404 base_url="https://example.com/from_init", api_key=api_key, _strict_response_validation=True
1405 )
1406 assert client.base_url == "https://example.com/from_init/"
1407
1408 client.base_url = "https://example.com/from_setter" # type: ignore[assignment]
1409
1410 assert client.base_url == "https://example.com/from_setter/"
1411
1412 def test_base_url_env(self) -> None:
1413 with update_env(OPENAI_BASE_URL="http://localhost:5000/from/env"):
1414 client = AsyncOpenAI(api_key=api_key, _strict_response_validation=True)
1415 assert client.base_url == "http://localhost:5000/from/env/"
1416
1417 @pytest.mark.parametrize(
1418 "client",
1419 [
1420 AsyncOpenAI(
1421 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1422 ),
1423 AsyncOpenAI(
1424 base_url="http://localhost:5000/custom/path/",
1425 api_key=api_key,
1426 _strict_response_validation=True,
1427 http_client=httpx.AsyncClient(),
1428 ),
1429 ],
1430 ids=["standard", "custom http client"],
1431 )
1432 def test_base_url_trailing_slash(self, client: AsyncOpenAI) -> None:
1433 request = client._build_request(
1434 FinalRequestOptions(
1435 method="post",
1436 url="/foo",
1437 json_data={"foo": "bar"},
1438 ),
1439 )
1440 assert request.url == "http://localhost:5000/custom/path/foo"
1441
1442 @pytest.mark.parametrize(
1443 "client",
1444 [
1445 AsyncOpenAI(
1446 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1447 ),
1448 AsyncOpenAI(
1449 base_url="http://localhost:5000/custom/path/",
1450 api_key=api_key,
1451 _strict_response_validation=True,
1452 http_client=httpx.AsyncClient(),
1453 ),
1454 ],
1455 ids=["standard", "custom http client"],
1456 )
1457 def test_base_url_no_trailing_slash(self, client: AsyncOpenAI) -> None:
1458 request = client._build_request(
1459 FinalRequestOptions(
1460 method="post",
1461 url="/foo",
1462 json_data={"foo": "bar"},
1463 ),
1464 )
1465 assert request.url == "http://localhost:5000/custom/path/foo"
1466
1467 @pytest.mark.parametrize(
1468 "client",
1469 [
1470 AsyncOpenAI(
1471 base_url="http://localhost:5000/custom/path/", api_key=api_key, _strict_response_validation=True
1472 ),
1473 AsyncOpenAI(
1474 base_url="http://localhost:5000/custom/path/",
1475 api_key=api_key,
1476 _strict_response_validation=True,
1477 http_client=httpx.AsyncClient(),
1478 ),
1479 ],
1480 ids=["standard", "custom http client"],
1481 )
1482 def test_absolute_request_url(self, client: AsyncOpenAI) -> None:
1483 request = client._build_request(
1484 FinalRequestOptions(
1485 method="post",
1486 url="https://myapi.com/foo",
1487 json_data={"foo": "bar"},
1488 ),
1489 )
1490 assert request.url == "https://myapi.com/foo"
1491
1492 async def test_copied_client_does_not_close_http(self) -> None:
1493 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1494 assert not client.is_closed()
1495
1496 copied = client.copy()
1497 assert copied is not client
1498
1499 del copied
1500
1501 await asyncio.sleep(0.2)
1502 assert not client.is_closed()
1503
1504 async def test_client_context_manager(self) -> None:
1505 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1506 async with client as c2:
1507 assert c2 is client
1508 assert not c2.is_closed()
1509 assert not client.is_closed()
1510 assert client.is_closed()
1511
1512 @pytest.mark.respx(base_url=base_url)
1513 @pytest.mark.asyncio
1514 async def test_client_response_validation_error(self, respx_mock: MockRouter) -> None:
1515 class Model(BaseModel):
1516 foo: str
1517
1518 respx_mock.get("/foo").mock(return_value=httpx.Response(200, json={"foo": {"invalid": True}}))
1519
1520 with pytest.raises(APIResponseValidationError) as exc:
1521 await self.client.get("/foo", cast_to=Model)
1522
1523 assert isinstance(exc.value.__cause__, ValidationError)
1524
1525 async def test_client_max_retries_validation(self) -> None:
1526 with pytest.raises(TypeError, match=r"max_retries cannot be None"):
1527 AsyncOpenAI(
1528 base_url=base_url, api_key=api_key, _strict_response_validation=True, max_retries=cast(Any, None)
1529 )
1530
1531 @pytest.mark.respx(base_url=base_url)
1532 @pytest.mark.asyncio
1533 async def test_default_stream_cls(self, respx_mock: MockRouter) -> None:
1534 class Model(BaseModel):
1535 name: str
1536
1537 respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"}))
1538
1539 stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model])
1540 assert isinstance(stream, AsyncStream)
1541 await stream.response.aclose()
1542
1543 @pytest.mark.respx(base_url=base_url)
1544 @pytest.mark.asyncio
1545 async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
1546 class Model(BaseModel):
1547 name: str
1548
1549 respx_mock.get("/foo").mock(return_value=httpx.Response(200, text="my-custom-format"))
1550
1551 strict_client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1552
1553 with pytest.raises(APIResponseValidationError):
1554 await strict_client.get("/foo", cast_to=Model)
1555
1556 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=False)
1557
1558 response = await client.get("/foo", cast_to=Model)
1559 assert isinstance(response, str) # type: ignore[unreachable]
1560
1561 @pytest.mark.parametrize(
1562 "remaining_retries,retry_after,timeout",
1563 [
1564 [3, "20", 20],
1565 [3, "0", 0.5],
1566 [3, "-10", 0.5],
1567 [3, "60", 60],
1568 [3, "61", 0.5],
1569 [3, "Fri, 29 Sep 2023 16:26:57 GMT", 20],
1570 [3, "Fri, 29 Sep 2023 16:26:37 GMT", 0.5],
1571 [3, "Fri, 29 Sep 2023 16:26:27 GMT", 0.5],
1572 [3, "Fri, 29 Sep 2023 16:27:37 GMT", 60],
1573 [3, "Fri, 29 Sep 2023 16:27:38 GMT", 0.5],
1574 [3, "99999999999999999999999999999999999", 0.5],
1575 [3, "Zun, 29 Sep 2023 16:26:27 GMT", 0.5],
1576 [3, "", 0.5],
1577 [2, "", 0.5 * 2.0],
1578 [1, "", 0.5 * 4.0],
1579 [-1100, "", 8], # test large number potentially overflowing
1580 ],
1581 )
1582 @mock.patch("time.time", mock.MagicMock(return_value=1696004797))
1583 @pytest.mark.asyncio
1584 async def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str, timeout: float) -> None:
1585 client = AsyncOpenAI(base_url=base_url, api_key=api_key, _strict_response_validation=True)
1586
1587 headers = httpx.Headers({"retry-after": retry_after})
1588 options = FinalRequestOptions(method="get", url="/foo", max_retries=3)
1589 calculated = client._calculate_retry_timeout(remaining_retries, options, headers)
1590 assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType]
1591
1592 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1593 @pytest.mark.respx(base_url=base_url)
1594 async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1595 respx_mock.post("/chat/completions").mock(side_effect=httpx.TimeoutException("Test timeout error"))
1596
1597 with pytest.raises(APITimeoutError):
1598 await self.client.post(
1599 "/chat/completions",
1600 body=cast(
1601 object,
1602 maybe_transform(
1603 dict(
1604 messages=[
1605 {
1606 "role": "user",
1607 "content": "Say this is a test",
1608 }
1609 ],
1610 model="gpt-4o",
1611 ),
1612 CompletionCreateParamsNonStreaming,
1613 ),
1614 ),
1615 cast_to=httpx.Response,
1616 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1617 )
1618
1619 assert _get_open_connections(self.client) == 0
1620
1621 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1622 @pytest.mark.respx(base_url=base_url)
1623 async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None:
1624 respx_mock.post("/chat/completions").mock(return_value=httpx.Response(500))
1625
1626 with pytest.raises(APIStatusError):
1627 await self.client.post(
1628 "/chat/completions",
1629 body=cast(
1630 object,
1631 maybe_transform(
1632 dict(
1633 messages=[
1634 {
1635 "role": "user",
1636 "content": "Say this is a test",
1637 }
1638 ],
1639 model="gpt-4o",
1640 ),
1641 CompletionCreateParamsNonStreaming,
1642 ),
1643 ),
1644 cast_to=httpx.Response,
1645 options={"headers": {RAW_RESPONSE_HEADER: "stream"}},
1646 )
1647
1648 assert _get_open_connections(self.client) == 0
1649
1650 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1651 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1652 @pytest.mark.respx(base_url=base_url)
1653 @pytest.mark.asyncio
1654 @pytest.mark.parametrize("failure_mode", ["status", "exception"])
1655 async def test_retries_taken(
1656 self,
1657 async_client: AsyncOpenAI,
1658 failures_before_success: int,
1659 failure_mode: Literal["status", "exception"],
1660 respx_mock: MockRouter,
1661 ) -> None:
1662 client = async_client.with_options(max_retries=4)
1663
1664 nb_retries = 0
1665
1666 def retry_handler(_request: httpx.Request) -> httpx.Response:
1667 nonlocal nb_retries
1668 if nb_retries < failures_before_success:
1669 nb_retries += 1
1670 if failure_mode == "exception":
1671 raise RuntimeError("oops")
1672 return httpx.Response(500)
1673 return httpx.Response(200)
1674
1675 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1676
1677 response = await client.chat.completions.with_raw_response.create(
1678 messages=[
1679 {
1680 "content": "string",
1681 "role": "developer",
1682 }
1683 ],
1684 model="gpt-4o",
1685 )
1686
1687 assert response.retries_taken == failures_before_success
1688 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1689
1690 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1691 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1692 @pytest.mark.respx(base_url=base_url)
1693 @pytest.mark.asyncio
1694 async def test_omit_retry_count_header(
1695 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1696 ) -> None:
1697 client = async_client.with_options(max_retries=4)
1698
1699 nb_retries = 0
1700
1701 def retry_handler(_request: httpx.Request) -> httpx.Response:
1702 nonlocal nb_retries
1703 if nb_retries < failures_before_success:
1704 nb_retries += 1
1705 return httpx.Response(500)
1706 return httpx.Response(200)
1707
1708 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1709
1710 response = await client.chat.completions.with_raw_response.create(
1711 messages=[
1712 {
1713 "content": "string",
1714 "role": "developer",
1715 }
1716 ],
1717 model="gpt-4o",
1718 extra_headers={"x-stainless-retry-count": Omit()},
1719 )
1720
1721 assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0
1722
1723 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1724 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1725 @pytest.mark.respx(base_url=base_url)
1726 @pytest.mark.asyncio
1727 async def test_overwrite_retry_count_header(
1728 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1729 ) -> None:
1730 client = async_client.with_options(max_retries=4)
1731
1732 nb_retries = 0
1733
1734 def retry_handler(_request: httpx.Request) -> httpx.Response:
1735 nonlocal nb_retries
1736 if nb_retries < failures_before_success:
1737 nb_retries += 1
1738 return httpx.Response(500)
1739 return httpx.Response(200)
1740
1741 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1742
1743 response = await client.chat.completions.with_raw_response.create(
1744 messages=[
1745 {
1746 "content": "string",
1747 "role": "developer",
1748 }
1749 ],
1750 model="gpt-4o",
1751 extra_headers={"x-stainless-retry-count": "42"},
1752 )
1753
1754 assert response.http_request.headers.get("x-stainless-retry-count") == "42"
1755
1756 @pytest.mark.parametrize("failures_before_success", [0, 2, 4])
1757 @mock.patch("openai._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout)
1758 @pytest.mark.respx(base_url=base_url)
1759 @pytest.mark.asyncio
1760 async def test_retries_taken_new_response_class(
1761 self, async_client: AsyncOpenAI, failures_before_success: int, respx_mock: MockRouter
1762 ) -> None:
1763 client = async_client.with_options(max_retries=4)
1764
1765 nb_retries = 0
1766
1767 def retry_handler(_request: httpx.Request) -> httpx.Response:
1768 nonlocal nb_retries
1769 if nb_retries < failures_before_success:
1770 nb_retries += 1
1771 return httpx.Response(500)
1772 return httpx.Response(200)
1773
1774 respx_mock.post("/chat/completions").mock(side_effect=retry_handler)
1775
1776 async with client.chat.completions.with_streaming_response.create(
1777 messages=[
1778 {
1779 "content": "string",
1780 "role": "developer",
1781 }
1782 ],
1783 model="gpt-4o",
1784 ) as response:
1785 assert response.retries_taken == failures_before_success
1786 assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success
1787
1788 def test_get_platform(self) -> None:
1789 # A previous implementation of asyncify could leave threads unterminated when
1790 # used with nest_asyncio.
1791 #
1792 # Since nest_asyncio.apply() is global and cannot be un-applied, this
1793 # test is run in a separate process to avoid affecting other tests.
1794 test_code = dedent("""
1795 import asyncio
1796 import nest_asyncio
1797 import threading
1798
1799 from openai._utils import asyncify
1800 from openai._base_client import get_platform
1801
1802 async def test_main() -> None:
1803 result = await asyncify(get_platform)()
1804 print(result)
1805 for thread in threading.enumerate():
1806 print(thread.name)
1807
1808 nest_asyncio.apply()
1809 asyncio.run(test_main())
1810 """)
1811 with subprocess.Popen(
1812 [sys.executable, "-c", test_code],
1813 text=True,
1814 ) as process:
1815 timeout = 10 # seconds
1816
1817 start_time = time.monotonic()
1818 while True:
1819 return_code = process.poll()
1820 if return_code is not None:
1821 if return_code != 0:
1822 raise AssertionError("calling get_platform using asyncify resulted in a non-zero exit code")
1823
1824 # success
1825 break
1826
1827 if time.monotonic() - start_time > timeout:
1828 process.kill()
1829 raise AssertionError("calling get_platform using asyncify resulted in a hung process")
1830
1831 time.sleep(0.1)
1832