openai/openai-python
Public · mirrored from https://github.com/openai/openai-python · Available
src/openai/_compat.py
222 lines · mode: code
| 1 | from __future__ import annotations |
| 2 | |
| 3 | from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload |
| 4 | from datetime import date, datetime |
| 5 | from typing_extensions import Self |
| 6 | |
| 7 | import pydantic |
| 8 | from pydantic.fields import FieldInfo |
| 9 | |
| 10 | from ._types import StrBytesIntFloat |
| 11 | |
_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)

# --------------- Pydantic v2 compatibility ---------------

# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false

# True for every Pydantic release that is not 1.x. Keying on the "1." prefix (instead of
# "2.") means a newer major version takes the modern (v2-style) code paths in this module
# rather than silently falling back to the legacy v1 APIs (e.g. `pydantic.typing`).
PYDANTIC_V2 = not pydantic.VERSION.startswith("1.")
| 21 | |
# v1 re-exports
#
# Type checkers only see the stub signatures below. At runtime the real helpers are
# imported from `pydantic.v1` when running on Pydantic 2, and from the top-level
# `pydantic` package otherwise.
if TYPE_CHECKING:

    def get_args(t: type[Any]) -> tuple[Any, ...]:  # noqa: ARG001
        """Stub for Pydantic v1's `get_args` typing helper."""
        ...

    def is_union(tp: type[Any] | None) -> bool:  # noqa: ARG001
        """Stub for Pydantic v1's `is_union` typing helper."""
        ...

    def get_origin(t: type[Any]) -> type[Any] | None:  # noqa: ARG001
        """Stub for Pydantic v1's `get_origin` typing helper."""
        ...

    def is_typeddict(type_: type[Any]) -> bool:  # noqa: ARG001
        """Stub for Pydantic v1's `is_typeddict` typing helper."""
        ...

    def is_literal_type(type_: type[Any]) -> bool:  # noqa: ARG001
        """Stub for Pydantic v1's `is_literal_type` typing helper."""
        ...

    def parse_date(value: Union[date, StrBytesIntFloat]) -> date:  # noqa: ARG001
        """Stub for Pydantic v1's `parse_date` datetime-parsing helper."""
        ...

    def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime:  # noqa: ARG001
        """Stub for Pydantic v1's `parse_datetime` datetime-parsing helper."""
        ...

elif PYDANTIC_V2:
    from pydantic.v1.typing import (
        get_args as get_args,
        is_union as is_union,
        get_origin as get_origin,
        is_typeddict as is_typeddict,
        is_literal_type as is_literal_type,
    )
    from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
else:
    from pydantic.typing import (
        get_args as get_args,
        is_union as is_union,
        get_origin as get_origin,
        is_typeddict as is_typeddict,
        is_literal_type as is_literal_type,
    )
    from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime
| 65 | |
| 66 | |
# refactored config
#
# `ConfigDict` is imported from Pydantic only when running on v2; otherwise the
# name is bound to `None` as a placeholder.
if TYPE_CHECKING:
    from pydantic import ConfigDict as ConfigDict
elif PYDANTIC_V2:
    from pydantic import ConfigDict
else:
    # TODO: provide an error message here?
    ConfigDict = None
| 76 | |
| 77 | |
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
    """Validate `value` into an instance of the `model` class.

    Uses `model_validate` on Pydantic 2 and the legacy `parse_obj` classmethod on
    Pydantic 1.
    """
    if not PYDANTIC_V2:
        return cast(_ModelT, model.parse_obj(value))  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
    return model.model_validate(value)
| 84 | |
| 85 | |
def field_is_required(field: FieldInfo) -> bool:
    """Report whether Pydantic considers `field` required.

    Calls `is_required()` on Pydantic 2 and reads the `required` attribute on Pydantic 1.
    """
    if not PYDANTIC_V2:
        return field.required  # type: ignore
    return field.is_required()
| 90 | |
| 91 | |
def field_get_default(field: FieldInfo) -> Any:
    """Return the default value reported by `field.get_default()`.

    On Pydantic 2 a field without a default reports the ``PydanticUndefined``
    sentinel; that sentinel is translated to ``None``. On Pydantic 1 the value from
    `get_default()` is returned unchanged.
    """
    value = field.get_default()
    if PYDANTIC_V2:
        from pydantic_core import PydanticUndefined

        # `PydanticUndefined` is a singleton sentinel, so compare by identity. `==`
        # would dispatch to the default value's own `__eq__`, which may be overridden:
        # e.g. array-like defaults return an element-wise result whose truth value is
        # ambiguous and raises.
        if value is PydanticUndefined:
            return None
    return value
| 101 | |
| 102 | |
def field_outer_type(field: FieldInfo) -> Any:
    """Return the declared type of `field`.

    Reads `annotation` on Pydantic 2 and `outer_type_` on Pydantic 1.
    """
    if not PYDANTIC_V2:
        return field.outer_type_  # type: ignore
    return field.annotation
| 107 | |
| 108 | |
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
    """Return the configuration attached to the `model` class.

    Reads `model_config` on Pydantic 2 and the legacy `__config__` attribute on Pydantic 1.
    """
    if not PYDANTIC_V2:
        return model.__config__  # type: ignore
    return model.model_config
| 113 | |
| 114 | |
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
    """Return the field mapping of the `model` class.

    Reads `model_fields` on Pydantic 2 and `__fields__` on Pydantic 1.
    """
    if not PYDANTIC_V2:
        return model.__fields__  # type: ignore
    return model.model_fields
| 119 | |
| 120 | |
def model_copy(model: _ModelT, *, deep: bool = False) -> _ModelT:
    """Return a copy of the `model` instance.

    Args:
        model: the model instance to copy.
        deep: forwarded to Pydantic's `deep` copy option. When true, a deep copy is
            requested instead of a shallow one. Defaults to ``False``, which matches
            the behaviour of calling this helper without the argument.

    Uses `model_copy` on Pydantic 2 and the legacy `copy` method on Pydantic 1.
    Both accept a `deep` keyword.
    """
    if PYDANTIC_V2:
        return model.model_copy(deep=deep)
    return model.copy(deep=deep)  # type: ignore
| 125 | |
| 126 | |
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
    """Serialize `model` to a JSON string; `indent` is forwarded unchanged to Pydantic.

    Uses `model_dump_json` on Pydantic 2 and the legacy `json` method on Pydantic 1.
    """
    if not PYDANTIC_V2:
        return model.json(indent=indent)  # type: ignore
    return model.model_dump_json(indent=indent)
| 131 | |
| 132 | |
def model_dump(
    model: pydantic.BaseModel,
    *,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
) -> dict[str, Any]:
    """Convert `model` into a dictionary.

    Args:
        model: the model instance to dump.
        exclude_unset: forwarded to Pydantic's `exclude_unset` option.
        exclude_defaults: forwarded to Pydantic's `exclude_defaults` option.

    Uses `model_dump` on Pydantic 2 and the legacy `dict` method on Pydantic 1.
    """
    if not PYDANTIC_V2:
        # The legacy `.dict()` call is flagged as deprecated by the type checker,
        # hence the ignore comment alongside the cast.
        return cast(
            "dict[str, Any]",
            model.dict(  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
            ),
        )
    return model.model_dump(exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
| 151 | |
| 152 | |
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
    """Validate `data` into an instance of the `model` class.

    Uses `model_validate` on Pydantic 2 and the legacy `parse_obj` classmethod on
    Pydantic 1.
    """
    if not PYDANTIC_V2:
        return model.parse_obj(data)  # pyright: ignore[reportDeprecated]
    return model.model_validate(data)
| 157 | |
| 158 | |
# generic models
if TYPE_CHECKING:

    class GenericModel(pydantic.BaseModel):
        ...

elif PYDANTIC_V2:
    # Pydantic 2 no longer needs a separate generic base, but we still define our
    # own subclass so that inheriting from it avoids inconsistent MRO ordering errors.
    class GenericModel(pydantic.BaseModel):
        ...

else:
    import pydantic.generics

    class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):
        ...
| 178 | |
| 179 | |
# cached properties
if TYPE_CHECKING:
    cached_property = property

    # `typed_cached_property` gets its own type (copied from typeshed) to express that,
    # unlike `@property`, a `cached_property` can be assigned to at runtime.
    #
    # It is kept distinct from `cached_property` (typed as `property` above) because
    # editors likely special-case `@property`, and more helpful internal types are not
    # worth causing issues there.

    class typed_cached_property(Generic[_T]):
        """Type-checking-only description of the runtime `cached_property` descriptor, including assignment."""

        func: Callable[[Any], _T]
        attrname: str | None

        def __init__(self, func: Callable[[Any], _T]) -> None:
            ...

        def __set_name__(self, owner: type[Any], name: str) -> None:
            ...

        # Not defined at runtime, but `@cached_property` is designed to be settable.
        def __set__(self, instance: object, value: _T) -> None:
            ...

        @overload
        def __get__(self, instance: None, owner: type[Any] | None = None) -> Self:
            ...

        @overload
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
            ...

        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
            raise NotImplementedError()

else:
    try:
        from functools import cached_property as cached_property
    except ImportError:
        # Fallback for interpreters whose `functools` lacks `cached_property`.
        from cached_property import cached_property as cached_property

    # At runtime the "typed" variant is just an alias of the real descriptor.
    typed_cached_property = cached_property
| 223 | |