openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v1.45.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/lib/test_azure.py

150 lines · mode: blame

6b01ba6aRobert Craigie1 years ago1from typing import Union, cast
2from typing_extensions import Literal, Protocol
08b8179aDavid Schnurr2 years ago3
6b01ba6aRobert Craigie1 years ago4import httpx
08b8179aDavid Schnurr2 years ago5import pytest
6b01ba6aRobert Craigie1 years ago6from respx import MockRouter
08b8179aDavid Schnurr2 years ago7
8from openai._models import FinalRequestOptions
9from openai.lib.azure import AzureOpenAI, AsyncAzureOpenAI
10
# Either flavour of Azure client; lets the parametrized tests accept both.
Client = Union[AzureOpenAI, AsyncAzureOpenAI]

# Connection settings shared by the two module-level clients so that the sync
# and async variants are configured identically.
_API_VERSION = "2023-07-01"
_API_KEY = "example API key"
_AZURE_ENDPOINT = "https://example-resource.azure.openai.com"

sync_client = AzureOpenAI(
    api_version=_API_VERSION,
    api_key=_API_KEY,
    azure_endpoint=_AZURE_ENDPOINT,
)

async_client = AsyncAzureOpenAI(
    api_version=_API_VERSION,
    api_key=_API_KEY,
    azure_endpoint=_AZURE_ENDPOINT,
)
25
26
class MockRequestCall(Protocol):
    """Structural type for one recorded entry of ``respx_mock.calls``.

    Only ``request`` is declared because it is the only attribute the
    token-refresh tests read (to inspect the ``Authorization`` header).
    """

    request: httpx.Request
29
30
@pytest.mark.parametrize("client", [sync_client, async_client])
def test_implicit_deployment_path(client: Client) -> None:
    """The ``model`` from the JSON body becomes the Azure deployment in the URL.

    The module-level clients are built without an explicit deployment, so the
    built request URL is expected to contain the body's ``model`` value as the
    deployment segment and the client's ``api-version`` as a query parameter.
    """
    options = FinalRequestOptions.construct(
        method="post",
        url="/chat/completions",
        json_data={"model": "my-deployment-model"},
    )
    expected_url = "https://example-resource.azure.openai.com/openai/deployments/my-deployment-model/chat/completions?api-version=2023-07-01"

    request = client._build_request(options)

    assert request.url == expected_url
97a6895bStainless Bot2 years ago44
45
@pytest.mark.parametrize(
    "client,method",
    [
        (sync_client, "copy"),
        (sync_client, "with_options"),
        (async_client, "copy"),
        (async_client, "with_options"),
    ],
)
def test_client_copying(client: Client, method: Literal["copy", "with_options"]) -> None:
    """Copying a client through either entry point keeps its ``api-version``."""
    copied = client.copy() if method == "copy" else client.with_options()

    # The api-version lives in the client's custom query parameters; a copy
    # made without overrides must carry it over unchanged.
    assert copied._custom_query == {"api-version": "2023-07-01"}
62
63
@pytest.mark.parametrize("client", [sync_client, async_client])
def test_client_copying_override_options(client: Client) -> None:
    """An ``api_version`` passed to ``copy`` replaces the one inherited from the source client."""
    overridden_version = "2022-05-01"

    copied = client.copy(api_version=overridden_version)

    assert copied._custom_query == {"api-version": overridden_version}
6b01ba6aRobert Craigie1 years ago73
74
@pytest.mark.respx()
def test_client_token_provider_refresh_sync(respx_mock: MockRouter) -> None:
    """Each attempt of a retried request uses a freshly provided Azure AD token.

    The mocked endpoint answers 500 once and then 200, and the test expects
    exactly two recorded calls. The token provider returns a different value
    on each invocation, so the two attempts can be told apart by their
    ``Authorization`` header.
    """
    respx_mock.post(
        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
    ).mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    counter = 0

    def token_provider() -> str:
        # "first" on the first invocation, "second" on every later one.
        nonlocal counter

        counter += 1

        if counter == 1:
            return "first"

        return "second"

    # Use the client as a context manager so its underlying HTTP connection
    # pool is closed when the request is done instead of being leaked.
    with AzureOpenAI(
        api_version="2024-02-01",
        azure_ad_token_provider=token_provider,
        azure_endpoint="https://example-resource.azure.openai.com",
    ) as client:
        client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)

    assert len(calls) == 2

    assert calls[0].request.headers.get("Authorization") == "Bearer first"
    assert calls[1].request.headers.get("Authorization") == "Bearer second"
111
112
@pytest.mark.asyncio
@pytest.mark.respx()
async def test_client_token_provider_refresh_async(respx_mock: MockRouter) -> None:
    """Async variant: each attempt of a retried request uses a fresh Azure AD token.

    The mocked endpoint answers 500 once and then 200, and the test expects
    exactly two recorded calls. The (synchronous) token provider returns a
    different value on each invocation, so the two attempts can be told apart
    by their ``Authorization`` header.
    """
    respx_mock.post(
        "https://example-resource.azure.openai.com/openai/deployments/gpt-4/chat/completions?api-version=2024-02-01"
    ).mock(
        side_effect=[
            httpx.Response(500, json={"error": "server error"}),
            httpx.Response(200, json={"foo": "bar"}),
        ]
    )

    counter = 0

    def token_provider() -> str:
        # "first" on the first invocation, "second" on every later one.
        nonlocal counter

        counter += 1

        if counter == 1:
            return "first"

        return "second"

    # Use the client as an async context manager so its underlying HTTP
    # connection pool is closed when the request is done instead of leaked.
    async with AsyncAzureOpenAI(
        api_version="2024-02-01",
        azure_ad_token_provider=token_provider,
        azure_endpoint="https://example-resource.azure.openai.com",
    ) as client:
        await client.chat.completions.create(messages=[], model="gpt-4")

    calls = cast("list[MockRequestCall]", respx_mock.calls)

    assert len(calls) == 2

    assert calls[0].request.headers.get("Authorization") == "Bearer first"
    assert calls[1].request.headers.get("Authorization") == "Bearer second"