openai/openai-python

Public

Mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.49.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/lib/test_azure.py

150 lines · mode: code

1from typing import Union, cast
2from typing_extensions import Literal, Protocol
3
4import httpx
5import pytest
6from respx import MockRouter
7
8from openai._models import FinalRequestOptions
9from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
10
# Either flavour of Azure client, so one test body can be parametrized over both.
Client = Union[AzureOpenAI, AsyncAzureOpenAI]

# Shared construction settings for the two module-level clients below.
_API_VERSION = "2023-07-01"
_API_KEY = "example API key"
_AZURE_ENDPOINT = "https://example-resource.azure.openai.com"

# Module-level clients reused by the parametrized tests in this file. The tests
# here only call `_build_request`, `copy` and `with_options` on them, so no
# network traffic goes through these instances.
sync_client = AzureOpenAI(
    api_version=_API_VERSION,
    api_key=_API_KEY,
    azure_endpoint=_AZURE_ENDPOINT,
)

async_client = AsyncAzureOpenAI(
    api_version=_API_VERSION,
    api_key=_API_KEY,
    azure_endpoint=_AZURE_ENDPOINT,
)


class MockRequestCall(Protocol):
    """Structural type for one recorded entry of ``respx_mock.calls``.

    Only the captured ``request`` attribute is declared, because that is the
    only part the token-refresh tests in this file inspect (its headers).
    """

    request: httpx.Request


@pytest.mark.parametrize("client", [sync_client, async_client])
def test_implicit_deployment_path(client: Client) -> None:
    """A ``model`` in the JSON body is spliced into the URL as the deployment.

    The clients are constructed without an explicit deployment, so building a
    ``/chat/completions`` request must route it to
    ``/openai/deployments/<model>/chat/completions`` with the configured
    ``api-version`` query parameter appended.
    """
    options = FinalRequestOptions.construct(
        method="post",
        url="/chat/completions",
        json_data={"model": "my-deployment-model"},
    )
    request = client._build_request(options)

    expected_url = (
        "https://example-resource.azure.openai.com"
        "/openai/deployments/my-deployment-model/chat/completions"
        "?api-version=2023-07-01"
    )
    assert request.url == expected_url


@pytest.mark.parametrize(
    "client,method",
    [
        (client_under_test, copy_method)
        for client_under_test in (sync_client, async_client)
        for copy_method in ("copy", "with_options")
    ],
)
def test_client_copying(client: Client, method: Literal["copy", "with_options"]) -> None:
    """Copying a client, via either method name, keeps its ``api-version`` query."""
    make_copy = getattr(client, method)
    copied = make_copy()

    assert copied._custom_query == {"api-version": "2023-07-01"}


@pytest.mark.parametrize(
    "client",
    [sync_client, async_client],
)
def test_client_copying_override_options(client: Client) -> None:
    """``copy(api_version=...)`` overrides the version on the copy only.

    ``client`` is a module-level instance that other tests in this file also
    use. The override must therefore land on a new object and leave the
    source client's ``api-version`` query untouched; otherwise it would leak
    into whichever test runs next.
    """
    copied = client.copy(
        api_version="2022-05-01",
    )
    assert copied is not client
    assert copied._custom_query == {"api-version": "2022-05-01"}
    # The shared source client must keep its original version.
    assert client._custom_query == {"api-version": "2023-07-01"}


@pytest.mark.respx()
def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
    """The Azure AD token provider is re-invoked for each retried attempt.

    The first attempt receives a 500 and the client retries it. The retry must
    carry a freshly fetched token rather than reusing the one sent with the
    failed attempt.
    """
    respx_mock.post(
        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
    ).mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    counter = 0

    def token_provider() -> str:
        # Returns "first" on the initial call and "second" on every later one,
        # so each request's Authorization header reveals which call supplied it.
        nonlocal counter
        counter += 1
        return "first" if counter == 1 else "second"

    client = AzureOpenAI(
        api_version="2024-02-01",
        azure_ad_token_provider=token_provider,
        azure_endpoint="https://example-resource.azure.openai.com",
    )
    client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)

    assert len(calls) == 2
    # Exactly one provider call per HTTP attempt: neither cached nor over-called.
    assert counter == 2

    assert calls[0].request.headers.get("Authorization") == "Bearer first"
    assert calls[1].request.headers.get("Authorization") == "Bearer second"


@pytest.mark.asyncio
@pytest.mark.respx()
async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
    """Async variant: the token provider is re-invoked for each retried attempt.

    The first attempt receives a 500 and the client retries it. The retry must
    carry a freshly fetched token rather than reusing the one sent with the
    failed attempt. A plain (non-async) provider is passed to the async client.
    """
    respx_mock.post(
        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
    ).mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    counter = 0

    def token_provider() -> str:
        # Returns "first" on the initial call and "second" on every later one,
        # so each request's Authorization header reveals which call supplied it.
        nonlocal counter
        counter += 1
        return "first" if counter == 1 else "second"

    client = AsyncAzureOpenAI(
        api_version="2024-02-01",
        azure_ad_token_provider=token_provider,
        azure_endpoint="https://example-resource.azure.openai.com",
    )

    await client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)

    assert len(calls) == 2
    # Exactly one provider call per HTTP attempt: neither cached nor over-called.
    assert counter == 2

    assert calls[0].request.headers.get("Authorization") == "Bearer first"
    assert calls[1].request.headers.get("Authorization") == "Bearer second"