Initial commit
This commit is contained in:
@@ -0,0 +1,40 @@
|
||||
from .shared import PYDANTIC_VERSION_MINOR_TUPLE as PYDANTIC_VERSION_MINOR_TUPLE
|
||||
from .shared import annotation_is_pydantic_v1 as annotation_is_pydantic_v1
|
||||
from .shared import field_annotation_is_scalar as field_annotation_is_scalar
|
||||
from .shared import (
|
||||
field_annotation_is_scalar_sequence as field_annotation_is_scalar_sequence,
|
||||
)
|
||||
from .shared import field_annotation_is_sequence as field_annotation_is_sequence
|
||||
from .shared import (
|
||||
is_bytes_or_nonable_bytes_annotation as is_bytes_or_nonable_bytes_annotation,
|
||||
)
|
||||
from .shared import is_bytes_sequence_annotation as is_bytes_sequence_annotation
|
||||
from .shared import is_pydantic_v1_model_instance as is_pydantic_v1_model_instance
|
||||
from .shared import (
|
||||
is_uploadfile_or_nonable_uploadfile_annotation as is_uploadfile_or_nonable_uploadfile_annotation,
|
||||
)
|
||||
from .shared import (
|
||||
is_uploadfile_sequence_annotation as is_uploadfile_sequence_annotation,
|
||||
)
|
||||
from .shared import lenient_issubclass as lenient_issubclass
|
||||
from .shared import sequence_types as sequence_types
|
||||
from .shared import value_is_sequence as value_is_sequence
|
||||
from .v2 import ModelField as ModelField
|
||||
from .v2 import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from .v2 import RequiredParam as RequiredParam
|
||||
from .v2 import Undefined as Undefined
|
||||
from .v2 import Url as Url
|
||||
from .v2 import copy_field_info as copy_field_info
|
||||
from .v2 import create_body_model as create_body_model
|
||||
from .v2 import evaluate_forwardref as evaluate_forwardref
|
||||
from .v2 import get_cached_model_fields as get_cached_model_fields
|
||||
from .v2 import get_definitions as get_definitions
|
||||
from .v2 import get_flat_models_from_fields as get_flat_models_from_fields
|
||||
from .v2 import get_missing_field_error as get_missing_field_error
|
||||
from .v2 import get_model_name_map as get_model_name_map
|
||||
from .v2 import get_schema_from_model_field as get_schema_from_model_field
|
||||
from .v2 import is_scalar_field as is_scalar_field
|
||||
from .v2 import serialize_sequence_value as serialize_sequence_value
|
||||
from .v2 import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
214
venv/lib/python3.12/site-packages/fastapi/_compat/shared.py
Normal file
214
venv/lib/python3.12/site-packages/fastapi/_compat/shared.py
Normal file
@@ -0,0 +1,214 @@
|
||||
import types
|
||||
import typing
|
||||
import warnings
|
||||
from collections import deque
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import is_dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
TypeGuard,
|
||||
TypeVar,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from fastapi.types import UnionType
|
||||
from pydantic import BaseModel
|
||||
from pydantic.version import VERSION as PYDANTIC_VERSION
|
||||
from starlette.datastructures import UploadFile
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# Copy from Pydantic: pydantic/_internal/_typing_extra.py
|
||||
WithArgsTypes: tuple[Any, ...] = (
|
||||
typing._GenericAlias, # type: ignore[attr-defined]
|
||||
types.GenericAlias,
|
||||
types.UnionType,
|
||||
) # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
PYDANTIC_VERSION_MINOR_TUPLE = tuple(int(x) for x in PYDANTIC_VERSION.split(".")[:2])
|
||||
|
||||
|
||||
sequence_annotation_to_type = {
|
||||
Sequence: list,
|
||||
list: list,
|
||||
tuple: tuple,
|
||||
set: set,
|
||||
frozenset: frozenset,
|
||||
deque: deque,
|
||||
}
|
||||
|
||||
sequence_types: tuple[type[Any], ...] = tuple(sequence_annotation_to_type.keys())
|
||||
|
||||
|
||||
# Copy of Pydantic: pydantic/_internal/_utils.py with added TypeGuard
|
||||
def lenient_issubclass(
|
||||
cls: Any, class_or_tuple: type[_T] | tuple[type[_T], ...] | None
|
||||
) -> TypeGuard[type[_T]]:
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple) # type: ignore[arg-type]
|
||||
except TypeError: # pragma: no cover
|
||||
if isinstance(cls, WithArgsTypes):
|
||||
return False
|
||||
raise # pragma: no cover
|
||||
|
||||
|
||||
def _annotation_is_sequence(annotation: type[Any] | None) -> bool:
|
||||
if lenient_issubclass(annotation, (str, bytes)):
|
||||
return False
|
||||
return lenient_issubclass(annotation, sequence_types)
|
||||
|
||||
|
||||
def field_annotation_is_sequence(annotation: type[Any] | None) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_sequence(arg):
|
||||
return True
|
||||
return False
|
||||
return _annotation_is_sequence(annotation) or _annotation_is_sequence(
|
||||
get_origin(annotation)
|
||||
)
|
||||
|
||||
|
||||
def value_is_sequence(value: Any) -> bool:
|
||||
return isinstance(value, sequence_types) and not isinstance(value, (str, bytes))
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: type[Any] | None) -> bool:
|
||||
return (
|
||||
lenient_issubclass(annotation, (BaseModel, Mapping, UploadFile))
|
||||
or _annotation_is_sequence(annotation)
|
||||
or is_dataclass(annotation)
|
||||
)
|
||||
|
||||
|
||||
def field_annotation_is_complex(annotation: type[Any] | None) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
return any(field_annotation_is_complex(arg) for arg in get_args(annotation))
|
||||
|
||||
if origin is Annotated:
|
||||
return field_annotation_is_complex(get_args(annotation)[0])
|
||||
|
||||
return (
|
||||
_annotation_is_complex(annotation)
|
||||
or _annotation_is_complex(origin)
|
||||
or hasattr(origin, "__pydantic_core_schema__")
|
||||
or hasattr(origin, "__get_pydantic_core_schema__")
|
||||
)
|
||||
|
||||
|
||||
def field_annotation_is_scalar(annotation: Any) -> bool:
|
||||
# handle Ellipsis here to make tuple[int, ...] work nicely
|
||||
return annotation is Ellipsis or not field_annotation_is_complex(annotation)
|
||||
|
||||
|
||||
def field_annotation_is_scalar_sequence(annotation: type[Any] | None) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one_scalar_sequence = False
|
||||
for arg in get_args(annotation):
|
||||
if field_annotation_is_scalar_sequence(arg):
|
||||
at_least_one_scalar_sequence = True
|
||||
continue
|
||||
elif not field_annotation_is_scalar(arg):
|
||||
return False
|
||||
return at_least_one_scalar_sequence
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
field_annotation_is_scalar(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_bytes_or_nonable_bytes_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, bytes):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, bytes):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_uploadfile_or_nonable_uploadfile_annotation(annotation: Any) -> bool:
|
||||
if lenient_issubclass(annotation, UploadFile):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, UploadFile):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_bytes_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_bytes_sequence_annotation(arg):
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_bytes_or_nonable_bytes_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_uploadfile_sequence_annotation(annotation: Any) -> bool:
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
at_least_one = False
|
||||
for arg in get_args(annotation):
|
||||
if is_uploadfile_sequence_annotation(arg):
|
||||
at_least_one = True
|
||||
continue
|
||||
return at_least_one
|
||||
return field_annotation_is_sequence(annotation) and all(
|
||||
is_uploadfile_or_nonable_uploadfile_annotation(sub_annotation)
|
||||
for sub_annotation in get_args(annotation)
|
||||
)
|
||||
|
||||
|
||||
def is_pydantic_v1_model_instance(obj: Any) -> bool:
|
||||
# TODO: remove this function once the required version of Pydantic fully
|
||||
# removes pydantic.v1
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
from pydantic import v1
|
||||
except ImportError: # pragma: no cover
|
||||
return False
|
||||
return isinstance(obj, v1.BaseModel)
|
||||
|
||||
|
||||
def is_pydantic_v1_model_class(cls: Any) -> bool:
|
||||
# TODO: remove this function once the required version of Pydantic fully
|
||||
# removes pydantic.v1
|
||||
try:
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
from pydantic import v1
|
||||
except ImportError: # pragma: no cover
|
||||
return False
|
||||
return lenient_issubclass(cls, v1.BaseModel)
|
||||
|
||||
|
||||
def annotation_is_pydantic_v1(annotation: Any) -> bool:
|
||||
if is_pydantic_v1_model_class(annotation):
|
||||
return True
|
||||
origin = get_origin(annotation)
|
||||
if origin is Union or origin is UnionType:
|
||||
for arg in get_args(annotation):
|
||||
if is_pydantic_v1_model_class(arg):
|
||||
return True
|
||||
if field_annotation_is_sequence(annotation):
|
||||
for sub_annotation in get_args(annotation):
|
||||
if annotation_is_pydantic_v1(sub_annotation):
|
||||
return True
|
||||
return False
|
||||
480
venv/lib/python3.12/site-packages/fastapi/_compat/v2.py
Normal file
480
venv/lib/python3.12/site-packages/fastapi/_compat/v2.py
Normal file
@@ -0,0 +1,480 @@
|
||||
import re
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, is_dataclass
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Literal,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from fastapi._compat import lenient_issubclass, shared
|
||||
from fastapi.openapi.constants import REF_TEMPLATE
|
||||
from fastapi.types import IncEx, ModelNameMap, UnionType
|
||||
from pydantic import BaseModel, ConfigDict, Field, TypeAdapter, create_model
|
||||
from pydantic import PydanticSchemaGenerationError as PydanticSchemaGenerationError
|
||||
from pydantic import PydanticUndefinedAnnotation as PydanticUndefinedAnnotation
|
||||
from pydantic import ValidationError as ValidationError
|
||||
from pydantic._internal._schema_generation_shared import ( # type: ignore[attr-defined]
|
||||
GetJsonSchemaHandler as GetJsonSchemaHandler,
|
||||
)
|
||||
from pydantic._internal._typing_extra import eval_type_lenient
|
||||
from pydantic.fields import FieldInfo as FieldInfo
|
||||
from pydantic.json_schema import GenerateJsonSchema as _GenerateJsonSchema
|
||||
from pydantic.json_schema import JsonSchemaValue as JsonSchemaValue
|
||||
from pydantic_core import CoreSchema as CoreSchema
|
||||
from pydantic_core import PydanticUndefined
|
||||
from pydantic_core import Url as Url
|
||||
from pydantic_core.core_schema import (
|
||||
with_info_plain_validator_function as with_info_plain_validator_function,
|
||||
)
|
||||
|
||||
RequiredParam = PydanticUndefined
|
||||
Undefined = PydanticUndefined
|
||||
evaluate_forwardref = eval_type_lenient
|
||||
|
||||
|
||||
class GenerateJsonSchema(_GenerateJsonSchema):
|
||||
# TODO: remove when this is merged (or equivalent): https://github.com/pydantic/pydantic/pull/12841
|
||||
# and dropping support for any version of Pydantic before that one (so, in a very long time)
|
||||
def bytes_schema(self, schema: CoreSchema) -> JsonSchemaValue:
|
||||
json_schema = {"type": "string", "contentMediaType": "application/octet-stream"}
|
||||
bytes_mode = (
|
||||
self._config.ser_json_bytes
|
||||
if self.mode == "serialization"
|
||||
else self._config.val_json_bytes
|
||||
)
|
||||
if bytes_mode == "base64":
|
||||
json_schema["contentEncoding"] = "base64"
|
||||
self.update_with_validations(json_schema, schema, self.ValidationsMapping.bytes)
|
||||
return json_schema
|
||||
|
||||
|
||||
# TODO: remove when dropping support for Pydantic < v2.12.3
|
||||
_Attrs = {
|
||||
"default": ...,
|
||||
"default_factory": None,
|
||||
"alias": None,
|
||||
"alias_priority": None,
|
||||
"validation_alias": None,
|
||||
"serialization_alias": None,
|
||||
"title": None,
|
||||
"field_title_generator": None,
|
||||
"description": None,
|
||||
"examples": None,
|
||||
"exclude": None,
|
||||
"exclude_if": None,
|
||||
"discriminator": None,
|
||||
"deprecated": None,
|
||||
"json_schema_extra": None,
|
||||
"frozen": None,
|
||||
"validate_default": None,
|
||||
"repr": True,
|
||||
"init": None,
|
||||
"init_var": None,
|
||||
"kw_only": None,
|
||||
}
|
||||
|
||||
|
||||
# TODO: remove when dropping support for Pydantic < v2.12.3
|
||||
def asdict(field_info: FieldInfo) -> dict[str, Any]:
|
||||
attributes = {}
|
||||
for attr in _Attrs:
|
||||
value = getattr(field_info, attr, Undefined)
|
||||
if value is not Undefined:
|
||||
attributes[attr] = value
|
||||
return {
|
||||
"annotation": field_info.annotation,
|
||||
"metadata": field_info.metadata,
|
||||
"attributes": attributes,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ModelField:
|
||||
field_info: FieldInfo
|
||||
name: str
|
||||
mode: Literal["validation", "serialization"] = "validation"
|
||||
config: ConfigDict | None = None
|
||||
|
||||
@property
|
||||
def alias(self) -> str:
|
||||
a = self.field_info.alias
|
||||
return a if a is not None else self.name
|
||||
|
||||
@property
|
||||
def validation_alias(self) -> str | None:
|
||||
va = self.field_info.validation_alias
|
||||
if isinstance(va, str) and va:
|
||||
return va
|
||||
return None
|
||||
|
||||
@property
|
||||
def serialization_alias(self) -> str | None:
|
||||
sa = self.field_info.serialization_alias
|
||||
return sa or None
|
||||
|
||||
@property
|
||||
def default(self) -> Any:
|
||||
return self.get_default()
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
with warnings.catch_warnings():
|
||||
# Pydantic >= 2.12.0 warns about field specific metadata that is unused
|
||||
# (e.g. `TypeAdapter(Annotated[int, Field(alias='b')])`). In some cases, we
|
||||
# end up building the type adapter from a model field annotation so we
|
||||
# need to ignore the warning:
|
||||
if shared.PYDANTIC_VERSION_MINOR_TUPLE >= (2, 12):
|
||||
from pydantic.warnings import UnsupportedFieldAttributeWarning
|
||||
|
||||
warnings.simplefilter(
|
||||
"ignore", category=UnsupportedFieldAttributeWarning
|
||||
)
|
||||
# TODO: remove after setting the min Pydantic to v2.12.3
|
||||
# that adds asdict(), and use self.field_info.asdict() instead
|
||||
field_dict = asdict(self.field_info)
|
||||
annotated_args = (
|
||||
field_dict["annotation"],
|
||||
*field_dict["metadata"],
|
||||
# this FieldInfo needs to be created again so that it doesn't include
|
||||
# the old field info metadata and only the rest of the attributes
|
||||
Field(**field_dict["attributes"]),
|
||||
)
|
||||
self._type_adapter: TypeAdapter[Any] = TypeAdapter(
|
||||
Annotated[annotated_args],
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
def get_default(self) -> Any:
|
||||
if self.field_info.is_required():
|
||||
return Undefined
|
||||
return self.field_info.get_default(call_default_factory=True)
|
||||
|
||||
def validate(
|
||||
self,
|
||||
value: Any,
|
||||
values: dict[str, Any] = {}, # noqa: B006
|
||||
*,
|
||||
loc: tuple[int | str, ...] = (),
|
||||
) -> tuple[Any, list[dict[str, Any]]]:
|
||||
try:
|
||||
return (
|
||||
self._type_adapter.validate_python(value, from_attributes=True),
|
||||
[],
|
||||
)
|
||||
except ValidationError as exc:
|
||||
return None, _regenerate_error_with_loc(
|
||||
errors=exc.errors(include_url=False), loc_prefix=loc
|
||||
)
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
mode: Literal["json", "python"] = "json",
|
||||
include: IncEx | None = None,
|
||||
exclude: IncEx | None = None,
|
||||
by_alias: bool = True,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
) -> Any:
|
||||
# What calls this code passes a value that already called
|
||||
# self._type_adapter.validate_python(value)
|
||||
return self._type_adapter.dump_python(
|
||||
value,
|
||||
mode=mode,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
|
||||
def serialize_json(
|
||||
self,
|
||||
value: Any,
|
||||
*,
|
||||
include: IncEx | None = None,
|
||||
exclude: IncEx | None = None,
|
||||
by_alias: bool = True,
|
||||
exclude_unset: bool = False,
|
||||
exclude_defaults: bool = False,
|
||||
exclude_none: bool = False,
|
||||
) -> bytes:
|
||||
# What calls this code passes a value that already called
|
||||
# self._type_adapter.validate_python(value)
|
||||
# This uses Pydantic's dump_json() which serializes directly to JSON
|
||||
# bytes in one pass (via Rust), avoiding the intermediate Python dict
|
||||
# step of dump_python(mode="json") + json.dumps().
|
||||
return self._type_adapter.dump_json(
|
||||
value,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
by_alias=by_alias,
|
||||
exclude_unset=exclude_unset,
|
||||
exclude_defaults=exclude_defaults,
|
||||
exclude_none=exclude_none,
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Each ModelField is unique for our purposes, to allow making a dict from
|
||||
# ModelField to its JSON Schema.
|
||||
return id(self)
|
||||
|
||||
|
||||
def _has_computed_fields(field: ModelField) -> bool:
|
||||
computed_fields = field._type_adapter.core_schema.get("schema", {}).get(
|
||||
"computed_fields", []
|
||||
)
|
||||
return len(computed_fields) > 0
|
||||
|
||||
|
||||
def get_schema_from_model_field(
|
||||
*,
|
||||
field: ModelField,
|
||||
model_name_map: ModelNameMap,
|
||||
field_mapping: dict[
|
||||
tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue
|
||||
],
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
override_mode: Literal["validation"] | None = (
|
||||
None
|
||||
if (separate_input_output_schemas or _has_computed_fields(field))
|
||||
else "validation"
|
||||
)
|
||||
field_alias = (
|
||||
(field.validation_alias or field.alias)
|
||||
if field.mode == "validation"
|
||||
else (field.serialization_alias or field.alias)
|
||||
)
|
||||
|
||||
# This expects that GenerateJsonSchema was already used to generate the definitions
|
||||
json_schema = field_mapping[(field, override_mode or field.mode)]
|
||||
if "$ref" not in json_schema:
|
||||
# TODO remove when deprecating Pydantic v1
|
||||
# Ref: https://github.com/pydantic/pydantic/blob/d61792cc42c80b13b23e3ffa74bc37ec7c77f7d1/pydantic/schema.py#L207
|
||||
json_schema["title"] = field.field_info.title or field_alias.title().replace(
|
||||
"_", " "
|
||||
)
|
||||
return json_schema
|
||||
|
||||
|
||||
def get_definitions(
|
||||
*,
|
||||
fields: Sequence[ModelField],
|
||||
model_name_map: ModelNameMap,
|
||||
separate_input_output_schemas: bool = True,
|
||||
) -> tuple[
|
||||
dict[tuple[ModelField, Literal["validation", "serialization"]], JsonSchemaValue],
|
||||
dict[str, dict[str, Any]],
|
||||
]:
|
||||
schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE)
|
||||
validation_fields = [field for field in fields if field.mode == "validation"]
|
||||
serialization_fields = [field for field in fields if field.mode == "serialization"]
|
||||
flat_validation_models = get_flat_models_from_fields(
|
||||
validation_fields, known_models=set()
|
||||
)
|
||||
flat_serialization_models = get_flat_models_from_fields(
|
||||
serialization_fields, known_models=set()
|
||||
)
|
||||
flat_validation_model_fields = [
|
||||
ModelField(
|
||||
field_info=FieldInfo(annotation=model),
|
||||
name=model.__name__,
|
||||
mode="validation",
|
||||
)
|
||||
for model in flat_validation_models
|
||||
]
|
||||
flat_serialization_model_fields = [
|
||||
ModelField(
|
||||
field_info=FieldInfo(annotation=model),
|
||||
name=model.__name__,
|
||||
mode="serialization",
|
||||
)
|
||||
for model in flat_serialization_models
|
||||
]
|
||||
flat_model_fields = flat_validation_model_fields + flat_serialization_model_fields
|
||||
input_types = {f.field_info.annotation for f in fields}
|
||||
unique_flat_model_fields = {
|
||||
f for f in flat_model_fields if f.field_info.annotation not in input_types
|
||||
}
|
||||
inputs = [
|
||||
(
|
||||
field,
|
||||
(
|
||||
field.mode
|
||||
if (separate_input_output_schemas or _has_computed_fields(field))
|
||||
else "validation"
|
||||
),
|
||||
field._type_adapter.core_schema,
|
||||
)
|
||||
for field in list(fields) + list(unique_flat_model_fields)
|
||||
]
|
||||
field_mapping, definitions = schema_generator.generate_definitions(inputs=inputs)
|
||||
for item_def in cast(dict[str, dict[str, Any]], definitions).values():
|
||||
if "description" in item_def:
|
||||
item_description = cast(str, item_def["description"]).split("\f")[0]
|
||||
item_def["description"] = item_description
|
||||
# definitions: dict[DefsRef, dict[str, Any]]
|
||||
# but mypy complains about general str in other places that are not declared as
|
||||
# DefsRef, although DefsRef is just str:
|
||||
# DefsRef = NewType('DefsRef', str)
|
||||
# So, a cast to simplify the types here
|
||||
return field_mapping, cast(dict[str, dict[str, Any]], definitions)
|
||||
|
||||
|
||||
def is_scalar_field(field: ModelField) -> bool:
|
||||
from fastapi import params
|
||||
|
||||
return shared.field_annotation_is_scalar(
|
||||
field.field_info.annotation
|
||||
) and not isinstance(field.field_info, params.Body)
|
||||
|
||||
|
||||
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
|
||||
cls = type(field_info)
|
||||
merged_field_info = cls.from_annotation(annotation)
|
||||
new_field_info = copy(field_info)
|
||||
new_field_info.metadata = merged_field_info.metadata
|
||||
new_field_info.annotation = merged_field_info.annotation
|
||||
return new_field_info
|
||||
|
||||
|
||||
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
|
||||
origin_type = get_origin(field.field_info.annotation) or field.field_info.annotation
|
||||
if origin_type is Union or origin_type is UnionType: # Handle optional sequences
|
||||
union_args = get_args(field.field_info.annotation)
|
||||
for union_arg in union_args:
|
||||
if union_arg is type(None):
|
||||
continue
|
||||
origin_type = get_origin(union_arg) or union_arg
|
||||
break
|
||||
assert issubclass(origin_type, shared.sequence_types) # type: ignore[arg-type]
|
||||
return shared.sequence_annotation_to_type[origin_type](value) # type: ignore[no-any-return,index]
|
||||
|
||||
|
||||
def get_missing_field_error(loc: tuple[int | str, ...]) -> dict[str, Any]:
|
||||
error = ValidationError.from_exception_data(
|
||||
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
|
||||
).errors(include_url=False)[0]
|
||||
error["input"] = None
|
||||
return error # type: ignore[return-value]
|
||||
|
||||
|
||||
def create_body_model(
|
||||
*, fields: Sequence[ModelField], model_name: str
|
||||
) -> type[BaseModel]:
|
||||
field_params = {f.name: (f.field_info.annotation, f.field_info) for f in fields}
|
||||
BodyModel: type[BaseModel] = create_model(model_name, **field_params) # type: ignore[call-overload]
|
||||
return BodyModel
|
||||
|
||||
|
||||
def get_model_fields(model: type[BaseModel]) -> list[ModelField]:
|
||||
model_fields: list[ModelField] = []
|
||||
for name, field_info in model.model_fields.items():
|
||||
type_ = field_info.annotation
|
||||
if lenient_issubclass(type_, (BaseModel, dict)) or is_dataclass(type_):
|
||||
model_config = None
|
||||
else:
|
||||
model_config = model.model_config
|
||||
model_fields.append(
|
||||
ModelField(
|
||||
field_info=field_info,
|
||||
name=name,
|
||||
config=model_config,
|
||||
)
|
||||
)
|
||||
return model_fields
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_cached_model_fields(model: type[BaseModel]) -> list[ModelField]:
|
||||
return get_model_fields(model)
|
||||
|
||||
|
||||
# Duplicate of several schema functions from Pydantic v1 to make them compatible with
|
||||
# Pydantic v2 and allow mixing the models
|
||||
|
||||
TypeModelOrEnum = type["BaseModel"] | type[Enum]
|
||||
TypeModelSet = set[TypeModelOrEnum]
|
||||
|
||||
|
||||
def normalize_name(name: str) -> str:
|
||||
return re.sub(r"[^a-zA-Z0-9.\-_]", "_", name)
|
||||
|
||||
|
||||
def get_model_name_map(unique_models: TypeModelSet) -> dict[TypeModelOrEnum, str]:
|
||||
name_model_map = {}
|
||||
for model in unique_models:
|
||||
model_name = normalize_name(model.__name__)
|
||||
name_model_map[model_name] = model
|
||||
return {v: k for k, v in name_model_map.items()}
|
||||
|
||||
|
||||
def get_flat_models_from_model(
|
||||
model: type["BaseModel"], known_models: TypeModelSet | None = None
|
||||
) -> TypeModelSet:
|
||||
known_models = known_models or set()
|
||||
fields = get_model_fields(model)
|
||||
get_flat_models_from_fields(fields, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def get_flat_models_from_annotation(
|
||||
annotation: Any, known_models: TypeModelSet
|
||||
) -> TypeModelSet:
|
||||
origin = get_origin(annotation)
|
||||
if origin is not None:
|
||||
for arg in get_args(annotation):
|
||||
if lenient_issubclass(arg, (BaseModel, Enum)):
|
||||
if arg not in known_models:
|
||||
known_models.add(arg) # type: ignore[arg-type]
|
||||
if lenient_issubclass(arg, BaseModel):
|
||||
get_flat_models_from_model(arg, known_models=known_models)
|
||||
else:
|
||||
get_flat_models_from_annotation(arg, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def get_flat_models_from_field(
|
||||
field: ModelField, known_models: TypeModelSet
|
||||
) -> TypeModelSet:
|
||||
field_type = field.field_info.annotation
|
||||
if lenient_issubclass(field_type, BaseModel):
|
||||
if field_type in known_models:
|
||||
return known_models
|
||||
known_models.add(field_type)
|
||||
get_flat_models_from_model(field_type, known_models=known_models)
|
||||
elif lenient_issubclass(field_type, Enum):
|
||||
known_models.add(field_type)
|
||||
else:
|
||||
get_flat_models_from_annotation(field_type, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def get_flat_models_from_fields(
|
||||
fields: Sequence[ModelField], known_models: TypeModelSet
|
||||
) -> TypeModelSet:
|
||||
for field in fields:
|
||||
get_flat_models_from_field(field, known_models=known_models)
|
||||
return known_models
|
||||
|
||||
|
||||
def _regenerate_error_with_loc(
|
||||
*, errors: Sequence[Any], loc_prefix: tuple[str | int, ...]
|
||||
) -> list[dict[str, Any]]:
|
||||
updated_loc_errors: list[Any] = [
|
||||
{**err, "loc": loc_prefix + err.get("loc", ())} for err in errors
|
||||
]
|
||||
|
||||
return updated_loc_errors
|
||||
Reference in New Issue
Block a user