# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.

from __future__ import annotations

import os as _os

import httpx
import pytest
from httpx import URL

import openai
from openai import DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES


def reset_state() -> None:
    """Put the module-level `openai` configuration back to its defaults.

    `openai._reset_client()` is called first to reset the lazily created module
    client; every global these tests mutate is then reassigned. `api_type` is
    re-read from ``OPENAI_API_TYPE`` so a test can patch the environment and
    call this function to exercise environment-based provider selection.
    """
    openai._reset_client()

    # Credentials and account identifiers.
    openai.api_key = None
    openai.admin_api_key = None
    openai.organization = None
    openai.project = None
    openai.webhook_secret = None

    # Provider routing plus Azure / Bedrock specific settings.
    openai.api_type = _os.environ.get("OPENAI_API_TYPE")  # type: ignore
    openai.api_version = None
    openai.azure_endpoint = None
    openai.azure_ad_token = None
    openai.azure_ad_token_provider = None
    openai._bedrock_api_key = None
    openai.bedrock_token_provider = None

    # Transport and per-request defaults.
    openai.base_url = None
    openai.timeout = DEFAULT_TIMEOUT
    openai.max_retries = DEFAULT_MAX_RETRIES
    openai.default_headers = None
    openai.default_query = None
    openai.http_client = None


@pytest.fixture(autouse=True)
def reset_state_fixture() -> Iterator[None]:
    """Isolate every test in this module from module-level `openai` state.

    State is reset before each test so it starts from defaults. It is reset
    again afterwards so that whatever the final test configured (mock HTTP
    clients, a Bedrock/Azure ``api_type``, credentials) does not leak into
    tests in other modules that share the same interpreter.
    """
    reset_state()
    yield
    reset_state()


def test_base_url_option() -> None:
    """`openai.base_url` is unset by default and an assigned value reaches the client."""
    default_url = URL("https://api.openai.com/v1/")
    custom_url = URL("http://foo.com")

    assert openai.base_url is None
    assert openai.completions._client.base_url == default_url

    openai.base_url = "http://foo.com"

    assert openai.base_url == custom_url
    assert openai.completions._client.base_url == custom_url


def test_timeout_option() -> None:
    """`openai.timeout` starts at DEFAULT_TIMEOUT and overrides propagate to the client."""
    assert openai.timeout == openai.DEFAULT_TIMEOUT
    initial_client_timeout = openai.completions._client.timeout
    assert initial_client_timeout == openai.DEFAULT_TIMEOUT

    new_timeout = 3
    openai.timeout = new_timeout

    assert openai.timeout == new_timeout
    assert openai.completions._client.timeout == new_timeout


def test_max_retries_option() -> None:
    """`openai.max_retries` starts at DEFAULT_MAX_RETRIES and overrides reach the client."""
    assert openai.max_retries == openai.DEFAULT_MAX_RETRIES
    initial_client_retries = openai.completions._client.max_retries
    assert initial_client_retries == openai.DEFAULT_MAX_RETRIES

    new_retries = 1
    openai.max_retries = new_retries

    assert openai.max_retries == new_retries
    assert openai.completions._client.max_retries == new_retries


def test_default_headers_option() -> None:
    """`openai.default_headers` is unset by default and assigned headers reach the client."""
    # Identity check rather than `== None`: None is a singleton (PEP 8).
    assert openai.default_headers is None

    openai.default_headers = {"Foo": "Bar"}

    assert openai.default_headers["Foo"] == "Bar"
    assert openai.completions._client.default_headers["Foo"] == "Bar"


def test_default_query_option() -> None:
    """`openai.default_query` is unset by default and nested values reach the client."""
    assert openai.default_query is None
    initial_query = openai.completions._client._custom_query
    assert initial_query == {}

    openai.default_query = {"Foo": {"nested": 1}}

    assert openai.default_query["Foo"] == {"nested": 1}
    updated_query = openai.completions._client._custom_query
    assert updated_query["Foo"] == {"nested": 1}


def test_http_client_option() -> None:
    """Assigning `openai.http_client` replaces the HTTP client used by the module client."""
    assert openai.http_client is None

    original_http_client = openai.completions._client._client
    assert original_http_client is not None

    # Manage the replacement client with `with` so its transport is closed when
    # the test finishes instead of being left open until garbage collection.
    with httpx.Client() as new_client:
        openai.http_client = new_client

        assert openai.completions._client._client is new_client


import contextlib
from typing import Iterator

from openai.lib.azure import AzureOpenAI
from openai.lib.bedrock import BedrockOpenAI


@contextlib.contextmanager
def fresh_env() -> Iterator[None]:
    """Run the ``with`` body against an empty ``os.environ``.

    The current environment is snapshotted on entry. On exit it is restored
    verbatim, even if the body raised: variables the body added are removed
    and the original ones are put back.
    """
    snapshot = dict(_os.environ)
    try:
        _os.environ.clear()
        yield
    finally:
        # Drop whatever the body set, then reinstate the saved variables.
        _os.environ.clear()
        _os.environ.update(snapshot)


def test_only_api_key_results_in_openai_api() -> None:
    """With an empty environment and only `openai.api_key` set, a `_ModuleClient` is built."""
    with fresh_env():
        openai.api_type = None
        openai.api_key = "example API key"

        client_class_name = type(openai.completions._client).__name__
        assert client_class_name == "_ModuleClient"


def test_azure_api_key_env_without_api_version() -> None:
    """An Azure key in the environment without an API version errors when the client is built."""
    expected_message = r"Must provide either the `api_version` argument or the `OPENAI_API_VERSION` environment variable"

    with fresh_env():
        openai.api_type = None
        _os.environ["AZURE_OPENAI_API_KEY"] = "example API key"

        with pytest.raises(ValueError, match=expected_message):
            _ = openai.completions._client


def test_azure_api_key_and_version_env() -> None:
    """An Azure key and API version without any endpoint errors when the client is built."""
    expected_message = r"Must provide one of the `base_url` or `azure_endpoint` arguments, or the `AZURE_OPENAI_ENDPOINT` environment variable"

    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "AZURE_OPENAI_API_KEY": "example API key",
                "OPENAI_API_VERSION": "example-version",
            }
        )

        with pytest.raises(ValueError, match=expected_message):
            _ = openai.completions._client


def test_azure_api_key_version_and_endpoint_env() -> None:
    """An Azure key, API version and endpoint in the environment select the Azure client."""
    with fresh_env():
        openai.api_type = None
        _os.environ["AZURE_OPENAI_API_KEY"] = "example API key"
        _os.environ["OPENAI_API_VERSION"] = "example-version"
        _os.environ["AZURE_OPENAI_ENDPOINT"] = "https://www.example"

        client = openai.completions._client

        # Check the concrete client type as well as the inferred api_type,
        # matching the assertions made by the other Azure env-var tests.
        assert isinstance(client, AzureOpenAI)
        assert openai.api_type == "azure"


def test_azure_azure_ad_token_version_and_endpoint_env() -> None:
    """An Azure AD token plus version/endpoint env vars builds an AzureOpenAI client with that token."""
    azure_env = {
        "AZURE_OPENAI_AD_TOKEN": "example AD token",
        "OPENAI_API_VERSION": "example-version",
        "AZURE_OPENAI_ENDPOINT": "https://www.example",
    }

    with fresh_env():
        openai.api_type = None
        _os.environ.update(azure_env)

        azure_client = openai.completions._client
        assert isinstance(azure_client, AzureOpenAI)
        assert azure_client._azure_ad_token == "example AD token"


def test_azure_azure_ad_token_provider_version_and_endpoint_env() -> None:
    """A module-level `azure_ad_token_provider` plus version/endpoint env vars selects Azure."""

    def provide_token() -> str:
        return "token"

    with fresh_env():
        openai.api_type = None
        _os.environ.update(
            {
                "OPENAI_API_VERSION": "example-version",
                "AZURE_OPENAI_ENDPOINT": "https://www.example",
            }
        )
        openai.azure_ad_token_provider = provide_token

        azure_client = openai.completions._client
        assert isinstance(azure_client, AzureOpenAI)
        token_provider = azure_client._azure_ad_token_provider
        assert token_provider is not None
        assert token_provider() == "token"


def test_bedrock_token_and_region_env() -> None:
    """Bedrock api_type with a bearer token and region builds a region-specific BedrockOpenAI client."""
    expected_base_url = URL("https://bedrock-mantle.us-west-2.api.aws/openai/v1/")

    with fresh_env():
        openai.api_type = "amazon-bedrock"
        _os.environ.update(
            {
                "AWS_BEARER_TOKEN_BEDROCK": "example Bedrock token",
                "AWS_REGION": "us-west-2",
            }
        )

        bedrock_client = openai.responses._client
        assert isinstance(bedrock_client, BedrockOpenAI)
        assert bedrock_client.base_url == expected_base_url


def test_bedrock_api_type_env() -> None:
    """OPENAI_API_TYPE=amazon-bedrock is picked up by `reset_state()` and selects Bedrock."""
    with fresh_env():
        _os.environ.update(
            {
                "OPENAI_API_TYPE": "amazon-bedrock",
                "AWS_BEARER_TOKEN_BEDROCK": "example Bedrock token",
                "AWS_REGION": "us-west-2",
            }
        )
        # Re-read module state now that OPENAI_API_TYPE is present in the environment.
        reset_state()

        bedrock_client = openai.responses._client
        assert isinstance(bedrock_client, BedrockOpenAI)
        assert openai.api_type == "amazon-bedrock"


def test_bedrock_api_type_uses_bedrock_credentials() -> None:
    """Under Bedrock, AWS_BEARER_TOKEN_BEDROCK is the key even when OPENAI_API_KEY is set."""
    with fresh_env():
        openai.api_type = "amazon-bedrock"
        _os.environ.update(
            {
                "OPENAI_API_KEY": "openai api key",
                "AWS_BEARER_TOKEN_BEDROCK": "example Bedrock token",
                "AWS_REGION": "us-west-2",
            }
        )

        bedrock_client = openai.responses._client
        assert isinstance(bedrock_client, BedrockOpenAI)
        assert bedrock_client.api_key == "example Bedrock token"
        # The module-level api_key must not be populated from either env var.
        assert openai.api_key is None


def test_bedrock_api_type_uses_explicit_module_api_key() -> None:
    """An explicitly assigned `openai.api_key` wins over AWS_BEARER_TOKEN_BEDROCK."""
    explicit_token = "explicit Bedrock token"

    with fresh_env():
        openai.api_type = "amazon-bedrock"
        openai.api_key = explicit_token
        _os.environ.update(
            {
                "AWS_BEARER_TOKEN_BEDROCK": "env Bedrock token",
                "AWS_REGION": "us-west-2",
            }
        )

        bedrock_client = openai.responses._client
        assert isinstance(bedrock_client, BedrockOpenAI)
        assert bedrock_client.api_key == explicit_token
        assert openai.api_key == explicit_token


def test_bedrock_module_api_key_overrides_cached_env_token_after_load() -> None:
    """Setting `openai.api_key` after the Bedrock client was built from env is seen by that client."""
    with fresh_env():
        openai.api_type = "amazon-bedrock"
        _os.environ.update(
            {
                "AWS_BEARER_TOKEN_BEDROCK": "env Bedrock token",
                "AWS_REGION": "us-west-2",
            }
        )

        bedrock_client = openai.responses._client
        assert isinstance(bedrock_client, BedrockOpenAI)
        assert bedrock_client.api_key == "env Bedrock token"

        # Mutate the module global; the already-built client must reflect it.
        openai.api_key = "new Bedrock token"
        assert bedrock_client.api_key == "new Bedrock token"


def test_bedrock_module_api_key_switches_cached_aws_client_to_bearer() -> None:
    """A Bedrock client built with AWS credentials switches to bearer auth once `openai.api_key` is set."""
    requests: list[httpx.Request] = []

    def handler(request: httpx.Request) -> httpx.Response:
        # Record every outgoing request and answer with an empty JSON body.
        requests.append(request)
        return httpx.Response(200, request=request, json={})

    # The mock-backed httpx client is closed when the block exits instead of leaking.
    with fresh_env(), httpx.Client(
        transport=httpx.MockTransport(handler), trust_env=False
    ) as http_client:
        openai.api_type = "amazon-bedrock"
        openai.http_client = http_client
        _os.environ["AWS_ACCESS_KEY_ID"] = "access key"
        _os.environ["AWS_SECRET_ACCESS_KEY"] = "secret key"
        _os.environ["AWS_REGION"] = "us-west-2"

        client = openai.responses._client
        assert isinstance(client, BedrockOpenAI)
        assert client._uses_aws_auth()

        openai.api_key = "new Bedrock token"
        client.get("/models", cast_to=httpx.Response)

        # Fail with a clear message rather than an IndexError if nothing was sent.
        assert len(requests) == 1
        assert requests[0].headers["Authorization"] == "Bearer new Bedrock token"


def test_bedrock_api_type_uses_token_provider_without_mutating_module_api_key() -> None:
    """A `bedrock_token_provider` supplies the key via `_refresh_api_key()` without touching `openai.api_key`."""

    def provide_bedrock_token() -> str:
        return "provider Bedrock token"

    with fresh_env():
        openai.api_type = "amazon-bedrock"
        openai.bedrock_token_provider = provide_bedrock_token
        _os.environ["AWS_REGION"] = "us-west-2"

        bedrock_client = openai.responses._client
        assert isinstance(bedrock_client, BedrockOpenAI)
        refreshed_key = bedrock_client._refresh_api_key()
        assert refreshed_key == "provider Bedrock token"
        assert openai.api_key is None


def test_bedrock_module_api_key_overrides_cached_token_provider() -> None:
    """Setting `openai.api_key` after load takes precedence over the cached Bedrock token provider."""
    requests: list[httpx.Request] = []
    provider_calls = 0

    def token_provider() -> str:
        # Count calls so the test also fails if the exception is swallowed upstream.
        nonlocal provider_calls
        provider_calls += 1
        raise AssertionError("the replaced token provider must not be called")

    def handler(request: httpx.Request) -> httpx.Response:
        # Record every outgoing request and answer with an empty JSON body.
        requests.append(request)
        return httpx.Response(200, request=request, json={})

    # The mock-backed httpx client is closed when the block exits instead of leaking.
    with fresh_env(), httpx.Client(
        transport=httpx.MockTransport(handler), trust_env=False
    ) as http_client:
        openai.api_type = "amazon-bedrock"
        openai.bedrock_token_provider = token_provider
        openai.http_client = http_client
        _os.environ["AWS_REGION"] = "us-west-2"

        client = openai.responses._client
        assert isinstance(client, BedrockOpenAI)

        openai.api_key = "new Bedrock token"
        client.get("/models", cast_to=httpx.Response)

    assert provider_calls == 0
    # Fail with a clear message rather than an IndexError if nothing was sent.
    assert len(requests) == 1
    assert requests[0].headers["Authorization"] == "Bearer new Bedrock token"
