openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
src/openai/cli/_cli.py
233 lines · mode: code
| 1 | from __future__ import annotations |
| 2 | |
| 3 | import sys |
| 4 | import logging |
| 5 | import argparse |
| 6 | from typing import Any, List, Type, Optional |
| 7 | from typing_extensions import ClassVar |
| 8 | |
| 9 | import httpx |
| 10 | import pydantic |
| 11 | |
| 12 | import openai |
| 13 | |
| 14 | from . import _tools |
| 15 | from .. import _ApiType, __version__ |
| 16 | from ._api import register_commands |
| 17 | from ._utils import can_use_http2 |
| 18 | from ._errors import CLIError, display_error |
| 19 | from .._compat import PYDANTIC_V2, ConfigDict, model_parse |
| 20 | from .._models import BaseModel |
| 21 | from .._exceptions import APIError |
| 22 | |
# Send log records to stderr with a timestamp prefix. Note that the handler is
# attached to the *root* logger (getLogger() with no name), at import time.
formatter = logging.Formatter("[%(asctime)s] %(message)s")
handler = logging.StreamHandler(sys.stderr)
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
| 28 | |
| 29 | |
class Arguments(BaseModel):
    """Global CLI options, validated from the top-level argparse namespace.

    `_parse_args` builds this with `model_parse(Arguments, vars(parsed))`, so
    the field names line up with the `dest` names registered in `_build_parser`.
    """

    # Unknown keys in the namespace (e.g. the `func` default and options that
    # belong to a subcommand) are ignored instead of failing validation. The
    # setting is spelled differently in Pydantic v1 and v2, hence the branch.
    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="ignore",
        )
    else:

        class Config(pydantic.BaseConfig):  # type: ignore
            extra: Any = pydantic.Extra.ignore  # type: ignore

    # Number of -v/--verbose flags given (argparse default 0).
    verbosity: int
    # -V/--version uses argparse's "version" action, which prints and exits,
    # so this is presumably never populated from the namespace.
    version: Optional[str] = None

    # Connection options; None when the corresponding flag was not given.
    api_key: Optional[str]
    api_base: Optional[str]
    organization: Optional[str]
    # One or more proxy URLs from -p/--proxy (nargs="+").
    proxy: Optional[List[str]]
    api_type: Optional[_ApiType] = None
    api_version: Optional[str] = None

    # azure
    azure_endpoint: Optional[str] = None
    azure_ad_token: Optional[str] = None

    # internal, set by subparsers to parse their specific args; when set,
    # `_main` validates the namespace against this model and passes the
    # result to the subcommand's `func`.
    args_model: Optional[Type[BaseModel]] = None

    # internal, used so that subparsers can forward unknown arguments; when
    # `allow_unknown_args` is False, `_parse_args` re-parses argv strictly so
    # argparse rejects anything it does not recognise.
    unknown_args: List[str] = []
    allow_unknown_args: bool = False
| 60 | |
| 61 | |
def _build_parser() -> argparse.ArgumentParser:
    """Build the top-level ``openai`` argument parser.

    Registers the global options (mirrored by the fields of ``Arguments``), a
    ``-V/--version`` flag, and the ``api`` and ``tools`` subcommand groups.
    The parser's default ``func`` prints the help text; subcommands are
    presumably expected to override it with their own handler.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(description=None, prog="openai")
    parser.add_argument(
        "-v",
        "--verbose",
        action="count",
        dest="verbosity",
        default=0,
        help="Set verbosity.",
    )
    parser.add_argument("-b", "--api-base", help="What API base url to use.")
    parser.add_argument("-k", "--api-key", help="What API key to use.")
    parser.add_argument("-p", "--proxy", nargs="+", help="What proxy to use.")
    parser.add_argument(
        "-o",
        "--organization",
        help="Which organization to run as (will use your default organization if not specified)",
    )
    parser.add_argument(
        "-t",
        "--api-type",
        type=str,
        choices=("openai", "azure"),
        help="The backend API to call, must be `openai` or `azure`",
    )
    parser.add_argument(
        "--api-version",
        # The URL documents the available versions; it is not itself an example value.
        help="The Azure API version, see 'https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#rest-api-versioning'",
    )

    # azure
    parser.add_argument(
        "--azure-endpoint",
        help="The Azure endpoint, e.g. 'https://endpoint.openai.azure.com'",
    )
    parser.add_argument(
        "--azure-ad-token",
        help="A token from Azure Active Directory, https://www.microsoft.com/en-us/security/business/identity-access/microsoft-entra-id",
    )

    # prints the package version (argparse's "version" action exits afterwards)
    parser.add_argument(
        "-V",
        "--version",
        action="version",
        version="%(prog)s " + __version__,
    )

    # Named so it does not shadow the builtin `help`.
    def _print_help() -> None:
        parser.print_help()

    # Fallback handler invoked by `_main` via `parsed.func()`.
    parser.set_defaults(func=_print_help)

    subparsers = parser.add_subparsers()
    sub_api = subparsers.add_parser("api", help="Direct API calls")

    register_commands(sub_api)

    sub_tools = subparsers.add_parser("tools", help="Client side tools for convenience")
    _tools.register_commands(sub_tools, subparsers)

    return parser
| 124 | |
| 125 | |
def main() -> int:
    """CLI entry point; returns a process exit code.

    Runs ``_main``. API errors, CLI errors and pydantic validation errors are
    reported via ``display_error``, and a Ctrl-C interrupt just writes a
    newline to stderr. Both of those outcomes yield exit code 1; success
    yields 0. Any other exception propagates.
    """
    exit_code = 0
    try:
        _main()
    except (APIError, CLIError, pydantic.ValidationError) as err:
        display_error(err)
        exit_code = 1
    except KeyboardInterrupt:
        # Move the shell prompt off the line containing the echoed ^C.
        sys.stderr.write("\n")
        exit_code = 1
    return exit_code
| 136 | |
| 137 | |
def _parse_args(parser: argparse.ArgumentParser) -> tuple[argparse.Namespace, Arguments, list[str]]:
    """Parse ``sys.argv`` for the CLI.

    Everything from the first ``--`` onwards (the ``--`` included) is withheld
    from argparse, because argparse would otherwise strip the separator. It is
    then appended to the arguments argparse did not recognise, so a subcommand
    can forward it verbatim.

    If the parsed ``Arguments`` do not set ``allow_unknown_args``, the full
    command line is parsed again with ``parse_args``. Argparse then reports any
    unrecognised argument as an error, which exits.

    Returns:
        ``(namespace, validated global arguments, leftover arguments)``.
    """
    argv = sys.argv
    try:
        separator_at = argv.index("--")
    except ValueError:
        to_parse, passthrough = argv[1:], []
    else:
        to_parse, passthrough = argv[1:separator_at], argv[separator_at:]

    namespace, leftovers = parser.parse_known_args(to_parse)
    leftovers.extend(passthrough)

    global_args = model_parse(Arguments, vars(namespace))
    if not global_args.allow_unknown_args:
        # Second, strict pass: only here to make argparse raise on unknown args.
        parser.parse_args()

    return namespace, global_args, leftovers
| 160 | |
| 161 | |
def _build_http_client(args: Arguments) -> httpx.Client:
    """Create the CLI's HTTP client, mounting one transport per ``--proxy``.

    A proxy whose own URL starts with ``https`` is mounted for ``https://``
    traffic; any other proxy is mounted for ``http://`` traffic.

    Raises:
        CLIError: if two proxies map to the same scheme key.
    """
    proxies: dict[str, httpx.BaseTransport] = {}
    for proxy in args.proxy or []:
        key = "https://" if proxy.startswith("https") else "http://"
        if key in proxies:
            raise CLIError(f"Multiple {key} proxies given - only the last one would be used")

        # NOTE(review): mounted transports are used as-is, so they presumably do
        # not pick up the client's `http2` setting (proxied traffic stays on
        # HTTP/1.1) — confirm whether that is intended.
        proxies[key] = httpx.HTTPTransport(proxy=httpx.Proxy(httpx.URL(proxy)))

    return httpx.Client(
        mounts=proxies or None,
        http2=can_use_http2(),
    )


def _apply_global_options(args: Arguments) -> None:
    """Copy connection options from the CLI onto the module-level ``openai`` settings.

    Organization, API key and base URL are applied only when truthy. The Azure
    options (api type, endpoint, api version, AD token) are applied whenever
    they are not None.
    """
    if args.organization:
        openai.organization = args.organization

    if args.api_key:
        openai.api_key = args.api_key

    if args.api_base:
        openai.base_url = args.api_base

    # azure
    if args.api_type is not None:
        openai.api_type = args.api_type

    if args.azure_endpoint is not None:
        openai.azure_endpoint = args.azure_endpoint

    if args.api_version is not None:
        openai.api_version = args.api_version

    if args.azure_ad_token is not None:
        openai.azure_ad_token = args.azure_ad_token


def _main() -> None:
    """Run the CLI: parse arguments, configure the global ``openai`` client, dispatch.

    If the selected command declares an ``args_model``, the namespace is
    validated against it and the result is passed to the command's ``func``.
    Otherwise ``func`` is called with no arguments.

    Raises:
        CLIError: if more than one proxy is given for the same scheme key.
        Exceptions from the selected command handler propagate; ``main``
        converts the expected ones into an exit code.
    """
    parser = _build_parser()
    parsed, args, unknown = _parse_args(parser)

    if args.verbosity != 0:
        # The flag is spelled -v/--verbose (dest="verbosity"), so name that flag.
        sys.stderr.write("Warning: --verbose isn't supported yet\n")

    http_client = _build_http_client(args)
    openai.http_client = http_client

    _apply_global_options(args)

    try:
        if args.args_model:
            # we omit None values so that they can be defaulted to `NotGiven`
            # and we'll strip it from the API request
            fields = {key: value for key, value in vars(parsed).items() if value is not None}
            # Always wins over any same-named namespace entry.
            fields["unknown_args"] = unknown
            parsed.func(model_parse(args.args_model, fields))
        else:
            parsed.func()
    finally:
        # Best-effort cleanup: an error while closing must not replace the
        # command's own result or exception.
        try:
            http_client.close()
        except Exception:
            pass
| 230 | |
| 231 | |
# Allow running this module directly as a script; the exit code comes from main().
if __name__ == "__main__":
    raise SystemExit(main())
| 234 | |