openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
tests/test_module_client.py
179 lines · mode: code
| 1 | # File generated from our OpenAPI spec by Stainless. |
| 2 | |
| 3 | from __future__ import annotations |
| 4 | |
| 5 | import os as _os |
| 6 | |
| 7 | import httpx |
| 8 | import pytest |
| 9 | from httpx import URL |
| 10 | |
| 11 | import openai |
| 12 | from openai import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES |
| 13 | |
| 14 | |
def reset_state() -> None:
    """Put the module-level ``openai`` configuration back to a known baseline.

    Calls ``openai._reset_client()`` and then reassigns each module-level
    option listed below to a fixed value.
    """
    openai._reset_client()

    baseline = {
        "api_key": "My API Key",
        "organization": None,
        "base_url": None,
        "timeout": DEFAULT_TIMEOUT,
        "max_retries": DEFAULT_MAX_RETRIES,
        "default_headers": None,
        "default_query": None,
        "http_client": None,
        # Read from the environment instead of being forced to None.
        "api_type": _os.environ.get("OPENAI_API_TYPE"),
        "api_version": None,
        "azure_endpoint": None,
        "azure_ad_token": None,
        "azure_ad_token_provider": None,
    }
    for option, value in baseline.items():
        setattr(openai, option, value)
| 30 | |
| 31 | |
@pytest.fixture(autouse=True)
def reset_state_fixture() -> None:
    """Autouse fixture that calls ``reset_state()`` for every test in this module."""
    reset_state()
| 35 | |
| 36 | |
def test_base_url_option() -> None:
    """Check ``openai.base_url`` before and after it is overridden.

    Asserts the module option starts as ``None`` with the client pointing at
    ``https://api.openai.com/v1/``, and that both compare equal to the override
    once it is assigned.
    """
    default_url = URL("https://api.openai.com/v1/")
    override_url = URL("http://foo.com")

    assert openai.base_url is None
    assert openai.completions._client.base_url == default_url

    openai.base_url = "http://foo.com"

    assert openai.base_url == override_url
    assert openai.completions._client.base_url == override_url
| 45 | |
| 46 | |
def test_timeout_option() -> None:
    """Check ``openai.timeout`` on the module and on the client, before and after an override."""
    default_timeout = openai.DEFAULT_TIMEOUT

    assert openai.timeout == default_timeout
    assert openai.completions._client.timeout == default_timeout

    overridden_timeout = 3
    openai.timeout = overridden_timeout

    assert openai.timeout == overridden_timeout
    assert openai.completions._client.timeout == overridden_timeout
| 55 | |
| 56 | |
def test_max_retries_option() -> None:
    """Check ``openai.max_retries`` on the module and on the client, before and after an override."""
    default_retries = openai.DEFAULT_MAX_RETRIES

    assert openai.max_retries == default_retries
    assert openai.completions._client.max_retries == default_retries

    overridden_retries = 1
    openai.max_retries = overridden_retries

    assert openai.max_retries == overridden_retries
    assert openai.completions._client.max_retries == overridden_retries
| 65 | |
| 66 | |
def test_default_headers_option() -> None:
    """Check that ``openai.default_headers`` starts unset and is visible on the client once assigned."""
    # Identity comparison with None, consistent with the other option tests
    # in this module (``== None`` can be hijacked by a custom ``__eq__``).
    assert openai.default_headers is None

    openai.default_headers = {"Foo": "Bar"}

    assert openai.default_headers["Foo"] == "Bar"
    assert openai.completions._client.default_headers["Foo"] == "Bar"
| 74 | |
| 75 | |
def test_default_query_option() -> None:
    """Check ``openai.default_query`` and the client's ``_custom_query`` before and after an override."""
    assert openai.default_query is None
    assert openai.completions._client._custom_query == {}

    nested_value = {"nested": 1}
    openai.default_query = {"Foo": {"nested": 1}}

    assert openai.default_query["Foo"] == nested_value
    assert openai.completions._client._custom_query["Foo"] == nested_value
| 84 | |
| 85 | |
def test_http_client_option() -> None:
    """Check that assigning ``openai.http_client`` replaces the client's underlying ``_client``."""
    assert openai.http_client is None

    # With no override, the client still carries some HTTP client object.
    default_http_client = openai.completions._client._client
    assert default_http_client is not None

    replacement = httpx.Client()
    openai.http_client = replacement

    assert openai.completions._client._client is replacement
| 96 | |
| 97 | |
| 98 | import contextlib |
| 99 | from typing import Iterator |
| 100 | |
| 101 | from openai.lib.azure import AzureOpenAI |
| 102 | |
| 103 | |
@contextlib.contextmanager
def fresh_env() -> Iterator[None]:
    """Run the ``with`` body against an empty ``os.environ``, then restore it.

    A snapshot of the environment is taken on entry and the live environment
    is cleared. On exit (normal or via an exception) the environment is
    cleared again before the snapshot is written back, so variables that were
    set inside the block do not leak out of it.
    """
    snapshot = _os.environ.copy()

    try:
        _os.environ.clear()
        yield
    finally:
        # ``update`` alone would keep any keys added inside the block; wipe
        # them first so the environment matches the snapshot exactly.
        _os.environ.clear()
        _os.environ.update(snapshot)
| 113 | |
| 114 | |
def test_only_api_key_results_in_openai_api() -> None:
    """With an empty environment and only ``openai.api_key`` set, the client class is ``_ModuleClient``."""
    with fresh_env():
        openai.api_type = None
        openai.api_key = "example API key"

        client = openai.completions._client
        assert type(client).__name__ == "_ModuleClient"
| 121 | |
| 122 | |
def test_azure_api_key_env_without_api_version() -> None:
    """Only ``AZURE_OPENAI_API_KEY`` set: accessing the client raises ``ValueError`` about ``api_version``."""
    expected_message = r"Expected `api_version` to be given for the Azure client"

    with fresh_env():
        openai.api_type = None
        _os.environ.update({"AZURE_OPENAI_API_KEY": "example API key"})

        with pytest.raises(ValueError, match=expected_message):
            openai.completions._client
| 130 | |
| 131 | |
def test_azure_api_key_and_version_env() -> None:
    """Azure key and API version set, no endpoint: accessing the client raises ``ValueError``."""
    expected_message = r"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `OPENAI_BASE_URL`"

    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "AZURE_OPENAI_API_KEY": "example API key",
                "OPENAI_API_VERSION": "example-version",
            }
        )

        with pytest.raises(ValueError, match=expected_message):
            openai.completions._client
| 143 | |
| 144 | |
def test_azure_api_key_version_and_endpoint_env() -> None:
    """Azure key, API version and endpoint set: after client access ``openai.api_type`` is ``"azure"``."""
    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "AZURE_OPENAI_API_KEY": "example API key",
                "OPENAI_API_VERSION": "example-version",
                "AZURE_OPENAI_ENDPOINT": "https://www.example",
            }
        )

        # The attribute access itself is the action under test; the value is unused.
        _ = openai.completions._client

        assert openai.api_type == "azure"
| 155 | |
| 156 | |
def test_azure_azure_ad_token_version_and_endpoint_env() -> None:
    """AD token, API version and endpoint set via env: the client is an ``AzureOpenAI`` carrying that token."""
    ad_token = "example AD token"

    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "AZURE_OPENAI_AD_TOKEN": ad_token,
                "OPENAI_API_VERSION": "example-version",
                "AZURE_OPENAI_ENDPOINT": "https://www.example",
            }
        )

        azure_client = openai.completions._client
        assert isinstance(azure_client, AzureOpenAI)
        assert azure_client._azure_ad_token == ad_token
| 167 | |
| 168 | |
def test_azure_azure_ad_token_provider_version_and_endpoint_env() -> None:
    """Token provider set on the module plus version/endpoint env vars: the client is an ``AzureOpenAI`` using that provider."""
    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "OPENAI_API_VERSION": "example-version",
                "AZURE_OPENAI_ENDPOINT": "https://www.example",
            }
        )
        openai.azure_ad_token_provider = lambda: "token"

        azure_client = openai.completions._client
        assert isinstance(azure_client, AzureOpenAI)

        provider = azure_client._azure_ad_token_provider
        assert provider is not None
        assert provider() == "token"
| 180 | |