openai/openai-python

Public

Mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v2.32.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_auth.py

190 lines · mode · blame

5be95364Kar Petrosyan2 months ago1import json
2from typing import cast
3from pathlib import Path
4
5import httpx
6import respx
7import pytest
8from respx.models import Call
9from inline_snapshot import snapshot
10
11from openai import OpenAI, OAuthError
12from openai.auth._workload import (
13gcp_id_token_provider,
14k8s_service_account_token_provider,
15azure_managed_identity_token_provider,
16)
17
18
@respx.mock
def test_basic_auth() -> None:
    """Workload-identity auth performs a token exchange before the API call.

    The client is expected to POST to the OAuth token endpoint first, then send
    the returned access token as a Bearer credential on the API request.
    """
    token_route = respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            200,
            json={
                "access_token": "fake_access_token",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
    )

    api_route = respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "client_id": "client_123",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": lambda: "fake_subject_token",
                "token_type": "jwt",
            },
        },
    )

    client.models.list()

    # Exactly two requests overall, in order: the token exchange, then the API call.
    assert len(respx.calls) == 2
    token_call = cast(Call, respx.calls[0])
    api_call = cast(Call, respx.calls[1])

    assert token_call.request.method == "POST"
    assert token_call.request.url == "https://auth.openai.com/oauth/token"
    # Pin per-route counts as well, so a retried exchange cannot hide behind ordering.
    assert token_route.call_count == 1
    assert api_route.call_count == 1
    assert api_call.request.headers.get("Authorization") == "Bearer fake_access_token"
57
58
@respx.mock
def test_workload_identity_exchange_payload_and_cache() -> None:
    """The token-exchange payload is correct and the access token is reused.

    Two API calls should trigger exactly one subject-token fetch and one
    exchange, and both API requests should carry the same access token.
    """
    provider_call_count = 0

    def provider() -> str:
        # Counts invocations so the test can prove the subject token is fetched once.
        nonlocal provider_call_count
        provider_call_count += 1
        return "fake_subject_token"

    exchange_route = respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            200,
            json={
                "access_token": "fake_access_token",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
    )
    api_route = respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "client_id": "client_123",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": provider,
                "token_type": "jwt",
            },
        },
    )

    client.models.list()
    client.models.list()

    assert provider_call_count == 1
    assert exchange_route.call_count == 1
    assert api_route.call_count == 2

    # Use the directly imported ``Call`` type, consistent with the rest of the module.
    exchange_request = cast(Call, exchange_route.calls[0]).request
    assert json.loads(exchange_request.content) == snapshot(
        {
            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
            "client_id": "client_123",
            "subject_token": "fake_subject_token",
            "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
        }
    )

    # Every API request (two, per the count above) must carry the cached access token.
    for api_call in api_route.calls:
        assert cast(Call, api_call).request.headers.get("Authorization") == "Bearer fake_access_token"
120
121
@respx.mock
def test_workload_identity_exchange_error() -> None:
    """A 401 OAuth error from the token endpoint is raised as ``OAuthError``.

    The failed exchange must stop the request before any API call is sent.
    """
    token_route = respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            401,
            json={
                "error": "invalid_grant",
                "error_description": "No service account mapping found for the provided service_account_id.",
            },
        )
    )
    models_route = respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    def subject_token() -> str:
        return "fake_subject_token"

    client = OpenAI(
        workload_identity={
            "client_id": "client_123",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {"get_token": subject_token, "token_type": "jwt"},
        },
    )

    with pytest.raises(OAuthError) as exc_info:
        client.models.list()

    oauth_error = exc_info.value
    assert oauth_error.message == "No service account mapping found for the provided service_account_id."
    assert oauth_error.error == "invalid_grant"
    assert oauth_error.status_code == 401

    # One exchange attempt, and the API endpoint is never reached.
    assert token_route.call_count == 1
    assert models_route.call_count == 0
157
158
def test_k8s_service_account_token_provider(tmp_path: Path) -> None:
    """The Kubernetes provider yields a JWT read from the given token file."""
    token_path = tmp_path / "token"
    token_path.write_text("my-k8s-token")

    k8s_provider = k8s_service_account_token_provider(token_path)
    token_type, get_token = k8s_provider["token_type"], k8s_provider["get_token"]

    assert token_type == "jwt"
    assert get_token() == "my-k8s-token"
167
168
@respx.mock
def test_azure_managed_identity_token_provider() -> None:
    """The Azure provider returns ``access_token`` from the mocked metadata token response."""
    metadata_token_url = "http://169.254.169.254/metadata/identity/oauth2/token"
    metadata_response = httpx.Response(200, json={"access_token": "azure-token"})
    respx.get(metadata_token_url).mock(return_value=metadata_response)

    azure_provider = azure_managed_identity_token_provider()

    assert azure_provider["token_type"] == "jwt"
    assert azure_provider["get_token"]() == "azure-token"
179
180
@respx.mock
def test_gcp_id_token_provider() -> None:
    """The GCP provider returns the plain-text identity token as an ``id`` token."""
    identity_url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity"
    respx.get(identity_url).mock(return_value=httpx.Response(200, text="gcp-token"))

    gcp_provider = gcp_id_token_provider()
    fetched_token = gcp_provider["get_token"]

    assert gcp_provider["token_type"] == "id"
    assert fetched_token() == "gcp-token"