openai/openai-python

Public

Mirrored from https://github.com/openai/openai-python · Available

Code · Commits · Issues · Pull requests · Actions · Insights · Security
v2.38.0

Branches

Tags

  • No tags available.
0 Branches · 0 Tags
Go to file
Add file
Code

Clone

HTTPS

Download ZIP

tests/test_auth.py

186 lines · mode: blame

5be95364Kar Petrosyan2 months ago1import json
2from typing import cast
3from pathlib import Path
4
5import httpx
6import respx
7import pytest
8from respx.models import Call
9from inline_snapshot import snapshot
10
11from openai import OpenAI, OAuthError
12from openai.auth._workload import (
13gcp_id_token_provider,
14k8s_service_account_token_provider,
15azure_managed_identity_token_provider,
16)
17
18
@respx.mock
def test_basic_auth() -> None:
    """Exchange a workload-identity subject token, then call the API with the issued token.

    The token endpoint is mocked to issue ``fake_access_token``. The test checks
    that the client hits the token endpoint first, then sends the issued token as
    a Bearer ``Authorization`` header on the ``/v1/models`` request.
    """
    token_route = respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            200,
            json={
                "access_token": "fake_access_token",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
    )

    api_route = respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": lambda: "fake_subject_token",
                "token_type": "jwt",
            },
        },
    )

    client.models.list()

    # Exactly one token exchange followed by one API request, in that order.
    assert len(respx.calls) == 2
    token_call = cast(Call, respx.calls[0])
    api_call = cast(Call, respx.calls[1])

    assert token_call.request.url == "https://auth.openai.com/oauth/token"
    assert api_call.request.headers.get("Authorization") == "Bearer fake_access_token"

    # Each mocked route served exactly one of the two recorded calls.
    assert token_route.call_count == 1
    assert api_route.call_count == 1
56
57
@respx.mock
def test_workload_identity_exchange_payload_and_cache() -> None:
    """The subject token is exchanged once and the access token is reused across requests.

    Two ``models.list()`` calls must trigger a single provider call and a single
    token exchange, and both API requests must carry the same Bearer token. The
    JSON body of the exchange request is pinned with an inline snapshot.
    """
    provider_call_count = 0

    def provider() -> str:
        # Count invocations so the test can prove the subject token is fetched only once.
        nonlocal provider_call_count
        provider_call_count += 1
        return "fake_subject_token"

    exchange_route = respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(
            200,
            json={
                "access_token": "fake_access_token",
                "issued_token_type": "urn:ietf:params:oauth:token-type:access_token",
                "token_type": "Bearer",
                "expires_in": 3600,
            },
        )
    )
    api_route = respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": provider,
                "token_type": "jwt",
            },
        },
    )

    client.models.list()
    client.models.list()

    assert provider_call_count == 1
    assert exchange_route.call_count == 1
    assert api_route.call_count == 2

    exchange_request = cast(Call, exchange_route.calls[0]).request
    assert json.loads(exchange_request.content) == snapshot(
        {
            "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange",
            "subject_token": "fake_subject_token",
            "subject_token_type": "urn:ietf:params:oauth:token-type:jwt",
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
        }
    )

    # Both API requests must reuse the single exchanged access token.
    for api_call in api_route.calls:
        assert cast(Call, api_call).request.headers.get("Authorization") == "Bearer fake_access_token"
117
118
@respx.mock
def test_workload_identity_exchange_error() -> None:
    """A rejected token exchange raises ``OAuthError`` and no API request is sent.

    The token endpoint answers 401 with an OAuth error body. The raised error must
    expose the body's ``error`` and ``error_description`` plus the HTTP status.
    """
    error_body = {
        "error": "invalid_grant",
        "error_description": "No service account mapping found for the provided service_account_id.",
    }
    token_endpoint = respx.post("https://auth.openai.com/oauth/token").mock(
        return_value=httpx.Response(401, json=error_body)
    )
    models_endpoint = respx.get("https://api.openai.com/v1/models").mock(
        return_value=httpx.Response(200, json={"data": [], "object": "list"})
    )

    client = OpenAI(
        workload_identity={
            "identity_provider_id": "idp_123",
            "service_account_id": "sa_123",
            "provider": {
                "get_token": lambda: "fake_subject_token",
                "token_type": "jwt",
            },
        },
    )

    with pytest.raises(OAuthError) as excinfo:
        client.models.list()

    err = excinfo.value
    assert err.message == "No service account mapping found for the provided service_account_id."
    assert err.error == "invalid_grant"
    assert err.status_code == 401

    # The exchange was attempted once; the models endpoint must never be reached.
    assert token_endpoint.call_count == 1
    assert models_endpoint.call_count == 0
153
154
def test_k8s_service_account_token_provider(tmp_path: Path) -> None:
    """The Kubernetes provider reports a ``jwt`` token type and returns the token file's text."""
    expected_token = "my-k8s-token"
    token_path = tmp_path / "token"
    token_path.write_text(expected_token)

    k8s_provider = k8s_service_account_token_provider(token_path)

    assert k8s_provider["token_type"] == "jwt"
    assert k8s_provider["get_token"]() == expected_token
163
164
@respx.mock
def test_azure_managed_identity_token_provider() -> None:
    """The Azure provider reports a ``jwt`` token type.

    It returns the ``access_token`` field served by the mocked metadata endpoint.
    """
    metadata_token_url = "http://169.254.169.254/metadata/identity/oauth2/token"
    respx.get(metadata_token_url).mock(return_value=httpx.Response(200, json={"access_token": "azure-token"}))

    azure_provider = azure_managed_identity_token_provider()

    assert azure_provider["token_type"] == "jwt"
    assert azure_provider["get_token"]() == "azure-token"
175
176
@respx.mock
def test_gcp_id_token_provider() -> None:
    """The GCP provider reports an ``id`` token type.

    It returns the raw text body served by the mocked identity endpoint.
    """
    identity_url = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/identity"
    respx.get(identity_url).mock(return_value=httpx.Response(200, text="gcp-token"))

    gcp_provider = gcp_id_token_provider()

    assert gcp_provider["token_type"] == "id"
    assert gcp_provider["get_token"]() == "gcp-token"