openai/openai-python

Public

mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
codex/remove-scheduled-release-run

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_auth.py

186 lines · mode: code

1import json
2from typing import cast
3from pathlib import Path
4
5import httpx
6import respx
7import pytest
8from respx.models import Call
9from inline_snapshot import snapshot
10
11from openai import OpenAI, OAuthError
12from openai.auth._workload import (
13 gcp_id_token_provider,
14 k8s_service_account_token_provider,
15 azure_managed_identity_token_provider,
16)
17
18
@respx.mock
def test_basic_auth() -> None:
    """A workload-identity client exchanges its subject token, then calls the API.

    Both the token endpoint and the models endpoint are mocked. The test
    checks that exactly two HTTP calls are recorded, that the first one goes
    to the token endpoint, and that the second one carries the access token
    returned by the exchange as a Bearer credential.
    """
    # Mocked token endpoint: hands back a one-hour Bearer access token.
    respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            200,
            json={
                "access_token": "fake_access_token",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
    )

    # Mocked API endpoint: an empty model list is enough for this test.
    respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": lambda: "fake_subject_token",
                "token_type": "jwt",
            },
        },
    )

    client.models.list()

    # respx.calls records requests in the order they were sent.
    assert len(respx.calls) == 2
    token_call = cast(Call, respx.calls[0])
    api_call = cast(Call, respx.calls[1])

    assert token_call.request.url == "https://auth.openai.com/oauth/token"
    assert api_call.request.headers.get("Authorization") == "Bearer fake_access_token"
56
57
@respx.mock
def test_workload_identity_exchange_payload_and_cache() -> None:
    """Two API calls trigger one token exchange, and the exchange body matches the snapshot.

    A counting token provider and mocked routes are used to check that the
    subject token is fetched once, the exchange endpoint is hit once, and both
    API requests carry the same Bearer access token.
    """
    subject_token_fetches = 0

    def fetch_subject_token() -> str:
        # Count every invocation so the test can prove the token is reused.
        nonlocal subject_token_fetches
        subject_token_fetches += 1
        return "fake_subject_token"

    token_response = httpx.Response(
        200,
        json={
            "access_token": "fake_access_token",
            "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
            "token_type": "Bearer",
            "expires_in": 3600,
        },
    )
    exchange_route = respx.post("https://auth.openai.com/oauth/token").mock(return_value=token_response)

    models_response = httpx.Response(200, json={"data": [], "object": "list"})
    api_route = respx.get("https://api.openai.com/v1/models").mock(return_value=models_response)

    workload_identity_config = {
        "identity_provider_id": "idp_123",
        "service_account_id": "sa_123",
        "provider": {
            "get_token": fetch_subject_token,
            "token_type": "jwt",
        },
    }
    client = OpenAI(workload_identity=workload_identity_config)

    # Two back-to-back API calls; only the first should need an exchange.
    client.models.list()
    client.models.list()

    assert subject_token_fetches == 1
    assert exchange_route.call_count == 1
    assert api_route.call_count == 2

    exchange_request = cast(Call, exchange_route.calls[0]).request
    sent_payload = json.loads(exchange_request.content)
    assert sent_payload == snapshot(
        {
            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
            "subject_token": "fake_subject_token",
            "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
        }
    )

    # Both recorded API requests must present the exchanged access token.
    for position in (0, 1):
        api_request = cast(Call, api_route.calls[position]).request
        assert api_request.headers.get("Authorization") == "Bearer fake_access_token"
117
118
@respx.mock
def test_workload_identity_exchange_error() -> None:
    """A 401 from the token endpoint raises OAuthError and the API is never called.

    The raised error is expected to expose the endpoint's ``error``,
    ``error_description`` (as ``message``) and HTTP status code.
    """
    rejection = httpx.Response(
        401,
        json={
            "error": "invalid_grant",
            "error_description": "No service account mapping found for the provided service_account_id.",
        },
    )
    exchange_route = respx.post("https://auth.openai.com/oauth/token").mock(return_value=rejection)

    models_response = httpx.Response(200, json={"data": [], "object": "list"})
    api_route = respx.get("https://api.openai.com/v1/models").mock(return_value=models_response)

    workload_identity_config = {
        "identity_provider_id": "idp_123",
        "service_account_id": "sa_123",
        "provider": {
            "get_token": lambda: "fake_subject_token",
            "token_type": "jwt",
        },
    }
    client = OpenAI(workload_identity=workload_identity_config)

    with pytest.raises(OAuthError) as raised:
        client.models.list()

    failure = raised.value
    assert failure.message == "No service account mapping found for the provided service_account_id."
    assert failure.error == "invalid_grant"
    assert failure.status_code == 401

    # The exchange was attempted once; the models endpoint was never reached.
    assert exchange_route.call_count == 1
    assert api_route.call_count == 0
153
154
def test_k8s_service_account_token_provider(tmp_path: Path) -> None:
    """The k8s provider reports a ``jwt`` token type and returns the token file's contents."""
    token_path = tmp_path / "token"
    token_path.write_text("my-k8s-token")

    k8s_provider = k8s_service_account_token_provider(token_path)

    assert k8s_provider["token_type"] == "jwt"
    fetched_token = k8s_provider["get_token"]()
    assert fetched_token == "my-k8s-token"
163
164
@respx.mock
def test_azure_managed_identity_token_provider() -> None:
    """The Azure provider reports ``jwt`` and returns the mocked endpoint's ``access_token``."""
    metadata_reply = httpx.Response(200, json={"access_token": "azure-token"})
    respx.get("http://169.254.169.254/metadata/identity/oauth2/token").mock(return_value=metadata_reply)

    azure_provider = azure_managed_identity_token_provider()

    assert azure_provider["token_type"] == "jwt"
    fetched_token = azure_provider["get_token"]()
    assert fetched_token == "azure-token"
175
176
@respx.mock
def test_gcp_id_token_provider() -> None:
    """The GCP provider reports an ``id`` token type and returns the mocked endpoint's text body."""
    metadata_reply = httpx.Response(200, text="gcp-token")
    respx.get("http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity").mock(
        return_value=metadata_reply
    )

    gcp_provider = gcp_id_token_provider()

    assert gcp_provider["token_type"] == "id"
    fetched_token = gcp_provider["get_token"]()
    assert fetched_token == "gcp-token"
187