openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
dev/codex/pin-setup-rye-workflows-main

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_auth.py

190 lines · mode: code

1import json
2from typing import cast
3from pathlib import Path
4
5import httpx
6import respx
7import pytest
8from respx.models import Call
9from inline_snapshot import snapshot
10
11from openai import OpenAI, OAuthError
12from openai.auth._workload import (
13 gcp_id_token_provider,
14 k8s_service_account_token_provider,
15 azure_managed_identity_token_provider,
16)
17
18
@respx.mock
def test_basic_auth() -> None:
    """A workload-identity client exchanges its subject token before calling the API.

    Mocks the OAuth token endpoint and ``GET /v1/models``, performs a single
    ``client.models.list()``, and checks that exactly two requests were made:
    first the token exchange, then the API request carrying the exchanged
    access token as a Bearer credential.
    """
    respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            200,
            json={
                "access_token": "fake_access_token",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
    )

    respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "client_id": "client_123",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": lambda: "fake_subject_token",
                "token_type": "jwt",
            },
        },
    )

    client.models.list()

    # Exactly one token exchange followed by one API request, in that order.
    assert len(respx.calls) == 2
    token_call = cast(Call, respx.calls[0])
    api_call = cast(Call, respx.calls[1])

    assert token_call.request.url == "https://auth.openai.com/oauth/token"
    # Pin the second call to the API endpoint so the ordering assumption above is verified.
    assert api_call.request.url == "https://api.openai.com/v1/models"
    assert api_call.request.headers.get("Authorization") == "Bearer fake_access_token"
57
58
@respx.mock
def test_workload_identity_exchange_payload_and_cache() -> None:
    """Verify the token-exchange request body and that the access token is reused.

    Two consecutive ``models.list()`` calls must invoke the subject-token
    provider once, hit the token endpoint once, and hit the API twice, each
    time with the same Bearer token. The JSON body sent to the token endpoint
    is compared against an inline snapshot.
    """
    invocations = {"provider": 0}

    def provider() -> str:
        # Count invocations so the test can prove the subject token is fetched only once.
        invocations["provider"] += 1
        return "fake_subject_token"

    token_endpoint = respx.post("https://auth.openai.com/oauth/token")
    exchange_route = token_endpoint.mock(
        return_value=httpx.Response(
            200,
            json={
                "access_token": "fake_access_token",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
    )
    models_endpoint = respx.get("https://api.openai.com/v1/models")
    api_route = models_endpoint.mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "client_id": "client_123",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": provider,
                "token_type": "jwt",
            },
        },
    )

    for _ in range(2):
        client.models.list()

    assert invocations["provider"] == 1
    assert exchange_route.call_count == 1
    assert api_route.call_count == 2

    first_exchange = cast(Call, exchange_route.calls[0])
    sent_payload = json.loads(first_exchange.request.content)
    assert sent_payload == snapshot(
        {
            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
            "client_id": "client_123",
            "subject_token": "fake_subject_token",
            "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
        }
    )

    # Both API requests (count asserted above) must carry the single cached access token.
    for api_call in api_route.calls:
        assert cast(Call, api_call).request.headers.get("Authorization") == "Bearer fake_access_token"
120
121
@respx.mock
def test_workload_identity_exchange_error() -> None:
    """A failed token exchange surfaces as ``OAuthError`` and blocks the API call.

    The token endpoint returns 401 with an OAuth error body. The resulting
    ``OAuthError`` must expose that body's description, error code and status,
    and the ``/v1/models`` route must never be reached.
    """
    description = "No service account mapping found for the provided service_account_id."

    exchange_route = respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            401,
            json={
                "error": "invalid_grant",
                "error_description": description,
            },
        )
    )
    api_route = respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "client_id": "client_123",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": lambda: "fake_subject_token",
                "token_type": "jwt",
            },
        },
    )

    with pytest.raises(OAuthError) as exc_info:
        client.models.list()

    oauth_error = exc_info.value
    assert oauth_error.message == description
    assert oauth_error.error == "invalid_grant"
    assert oauth_error.status_code == 401

    # The exchange was attempted once; the API request was never sent.
    assert exchange_route.call_count == 1
    assert api_route.call_count == 0
157
158
def test_k8s_service_account_token_provider(tmp_path: Path) -> None:
    """The Kubernetes provider reports a JWT token type and returns the token file's contents."""
    token_path = tmp_path / "token"
    token_path.write_text("my-k8s-token", encoding="utf-8")

    k8s_provider = k8s_service_account_token_provider(token_path)

    assert k8s_provider["token_type"] == "jwt"
    token_reader = k8s_provider["get_token"]
    assert token_reader() == "my-k8s-token"
167
168
@respx.mock
def test_azure_managed_identity_token_provider() -> None:
    """The Azure managed-identity provider returns ``access_token`` from the mocked metadata endpoint."""
    metadata_url = "http://169.254.169.254/metadata/identity/oauth2/token"
    canned_reply = httpx.Response(200, json={"access_token": "azure-token"})
    respx.get(metadata_url).mock(return_value=canned_reply)

    azure_provider = azure_managed_identity_token_provider()

    assert azure_provider["token_type"] == "jwt"
    fetch_token = azure_provider["get_token"]
    assert fetch_token() == "azure-token"
179
180
@respx.mock
def test_gcp_id_token_provider() -> None:
    """The GCP provider reports an ``id`` token type and returns the metadata server's plain-text body."""
    identity_url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity"
    canned_reply = httpx.Response(200, text="gcp-token")
    respx.get(identity_url).mock(return_value=canned_reply)

    gcp_provider = gcp_id_token_provider()

    assert gcp_provider["token_type"] == "id"
    fetch_token = gcp_provider["get_token"]
    assert fetch_token() == "gcp-token"
191