openai/openai-python

Public

mirrored from https://github.com/openai/openai-python Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.63.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/lib/test_azure.py

241lines · modecode

1import logging
2from typing import Union, cast
3from typing_extensions import Literal, Protocol
4
5import httpx
6import pytest
7from respx import MockRouter
8
9from openai._utils import SensitiveHeadersFilter, is_dict
10from openai._models import FinalRequestOptions
11from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
12
# Either flavour of the Azure client; lets the parametrized tests accept both.
Client = Union[AzureOpenAI, AsyncAzureOpenAI]


# Module-level clients reused by the parametrized tests. The sync and async
# instances are built from the same endpoint, key and API version.
sync_client = AzureOpenAI(
    azure_endpoint="https://example-resource.azure.openai.com",
    api_key="example API key",
    api_version="2023-07-01",
)

async_client = AsyncAzureOpenAI(
    azure_endpoint="https://example-resource.azure.openai.com",
    api_key="example API key",
    api_version="2023-07-01",
)
27
28
class MockRequestCall(Protocol):
    """Structural type for a single call recorded by the respx router.

    Only ``request`` is declared: it is the one attribute the tests in this
    module read after casting ``respx_mock.calls``.
    """

    request: httpx.Request
31
32
@pytest.mark.parametrize("client", [sync_client, async_client])
def test_implicit_deployment_path(client: Client) -> None:
    """Check the URL built for a chat-completions request that names a model.

    The expected URL places the ``model`` value from the JSON body under
    ``/openai/deployments/`` and appends the client's ``api-version`` query.
    """
    expected_url = "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01"
    options = FinalRequestOptions.construct(
        method="post",
        url="/chat/completions",
        json_data={"model": "my-deployment-model"},
    )

    request = client._build_request(options)

    assert request.url == expected_url
46
47
@pytest.mark.parametrize(
    "client,method",
    [
        (sync_client, "copy"),
        (sync_client, "with_options"),
        (async_client, "copy"),
        (async_client, "with_options"),
    ],
)
def test_client_copying(client: Client, method: Literal["copy", "with_options"]) -> None:
    """Check that ``copy()`` and ``with_options()`` keep the ``api-version`` query."""
    duplicate = client.copy() if method == "copy" else client.with_options()

    assert duplicate._custom_query == {"api-version": "2023-07-01"}
64
65
@pytest.mark.parametrize("client", [sync_client, async_client])
def test_client_copying_override_options(client: Client) -> None:
    """Check that ``copy(api_version=...)`` replaces the ``api-version`` query value."""
    overridden = client.copy(api_version="2022-05-01")

    assert overridden._custom_query == {"api-version": "2022-05-01"}
75
76
@pytest.mark.respx()
def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
    """Check that each request sent by the sync client carries a freshly fetched token.

    The mocked endpoint answers 500 and then 200. The test expects exactly two
    recorded requests, the first authorised with the provider's first token
    and the second with its next one.
    """
    route = respx_mock.post(
        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
    )
    route.mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    invocations = 0

    def token_provider() -> str:
        # First invocation yields "first"; every later one yields "second".
        nonlocal invocations
        invocations += 1
        return "first" if invocations == 1 else "second"

    client = AzureOpenAI(
        azure_endpoint="https://example-resource.azure.openai.com",
        azure_ad_token_provider=token_provider,
        api_version="2024-02-01",
    )
    client.chat.completions.create(messages=[], model="gpt-4")

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 2

    auth_headers = [recorded[i].request.headers.get("Authorization") for i in range(2)]
    assert auth_headers[0] == "Bearer first"
    assert auth_headers[1] == "Bearer second"
113
114
@pytest.mark.asyncio
@pytest.mark.respx()
async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
    """Check that each request sent by the async client carries a freshly fetched token.

    The mocked endpoint answers 500 and then 200. The test expects exactly two
    recorded requests, the first authorised with the provider's first token
    and the second with its next one.
    """
    route = respx_mock.post(
        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
    )
    route.mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    invocations = 0

    def token_provider() -> str:
        # First invocation yields "first"; every later one yields "second".
        nonlocal invocations
        invocations += 1
        return "first" if invocations == 1 else "second"

    client = AsyncAzureOpenAI(
        azure_endpoint="https://example-resource.azure.openai.com",
        azure_ad_token_provider=token_provider,
        api_version="2024-02-01",
    )

    await client.chat.completions.create(messages=[], model="gpt-4")

    recorded = cast("list[MockRequestCall]", respx_mock.calls)
    assert len(recorded) == 2

    auth_headers = [recorded[i].request.headers.get("Authorization") for i in range(2)]
    assert auth_headers[0] == "Bearer first"
    assert auth_headers[1] == "Bearer second"
153
154
def _assert_header_redacted(records: "list[logging.LogRecord]", header: str) -> None:
    """Assert that *header* is masked in every record that logged request headers.

    A record is inspected only when its ``args`` is a dict holding a non-empty
    dict under ``"headers"``; all other records are skipped.

    NOTE(review): if no captured record carries headers, this passes without
    asserting anything. Consider requiring at least one match once it is
    confirmed that the client always logs request headers at DEBUG.
    """
    for record in records:
        args = record.args
        if not is_dict(args):
            continue
        headers = args.get("headers")
        if headers and is_dict(headers):
            assert headers[header] == "<redacted>"


class TestAzureLogging:
    """Checks that Azure credentials are masked in the client's debug log records."""

    @pytest.fixture(autouse=True)
    def logger_with_filter(self, request: pytest.FixtureRequest) -> logging.Logger:
        """Install a ``SensitiveHeadersFilter`` on the ``"openai"`` logger for one test.

        The logger is switched to DEBUG for the duration of the test. A
        finalizer removes the filter and restores the previous level so that
        filters do not pile up on the process-wide logger across tests.
        """
        logger = logging.getLogger("openai")
        previous_level = logger.level
        sensitive_filter = SensitiveHeadersFilter()

        logger.setLevel(logging.DEBUG)
        logger.addFilter(sensitive_filter)

        def restore() -> None:
            logger.removeFilter(sensitive_filter)
            logger.setLevel(previous_level)

        request.addfinalizer(restore)
        return logger

    @pytest.mark.respx()
    def test_azure_api_key_redacted(self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture) -> None:
        """Sync client with ``api_key``: the ``api-key`` header must be logged as redacted."""
        respx_mock.post(
            "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
        ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"}))

        client = AzureOpenAI(
            api_version="2024-06-01",
            api_key="example_api_key",
            azure_endpoint="https://example-resource.azure.openai.com",
        )

        with caplog.at_level(logging.DEBUG):
            client.chat.completions.create(messages=[], model="gpt-4")

        _assert_header_redacted(caplog.records, "api-key")

    @pytest.mark.respx()
    def test_azure_bearer_token_redacted(self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture) -> None:
        """Sync client with ``azure_ad_token``: ``Authorization`` must be logged as redacted."""
        respx_mock.post(
            "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
        ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"}))

        client = AzureOpenAI(
            api_version="2024-06-01",
            azure_ad_token="example_token",
            azure_endpoint="https://example-resource.azure.openai.com",
        )

        with caplog.at_level(logging.DEBUG):
            client.chat.completions.create(messages=[], model="gpt-4")

        _assert_header_redacted(caplog.records, "Authorization")

    @pytest.mark.asyncio
    @pytest.mark.respx()
    async def test_azure_api_key_redacted_async(self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture) -> None:
        """Async client with ``api_key``: the ``api-key`` header must be logged as redacted."""
        respx_mock.post(
            "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
        ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"}))

        client = AsyncAzureOpenAI(
            api_version="2024-06-01",
            api_key="example_api_key",
            azure_endpoint="https://example-resource.azure.openai.com",
        )

        with caplog.at_level(logging.DEBUG):
            await client.chat.completions.create(messages=[], model="gpt-4")

        _assert_header_redacted(caplog.records, "api-key")

    @pytest.mark.asyncio
    @pytest.mark.respx()
    async def test_azure_bearer_token_redacted_async(
        self, respx_mock: MockRouter, caplog: pytest.LogCaptureFixture
    ) -> None:
        """Async client with ``azure_ad_token``: ``Authorization`` must be logged as redacted."""
        respx_mock.post(
            "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-06-01"
        ).mock(return_value=httpx.Response(200, json={"model": "gpt-4"}))

        client = AsyncAzureOpenAI(
            api_version="2024-06-01",
            azure_ad_token="example_token",
            azure_endpoint="https://example-resource.azure.openai.com",
        )

        with caplog.at_level(logging.DEBUG):
            await client.chat.completions.create(messages=[], model="gpt-4")

        _assert_header_redacted(caplog.records, "Authorization")
242