Initial commit
This commit is contained in:
@@ -0,0 +1,69 @@
|
||||
from .exceptions import SettingsError
|
||||
from .main import BaseSettings, CliApp, SettingsConfigDict
|
||||
from .sources import (
|
||||
CLI_SUPPRESS,
|
||||
AWSSecretsManagerSettingsSource,
|
||||
AzureKeyVaultSettingsSource,
|
||||
CliDualFlag,
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
CliToggleFlag,
|
||||
CliUnknownArgs,
|
||||
DotEnvSettingsSource,
|
||||
EnvSettingsSource,
|
||||
ForceDecode,
|
||||
GoogleSecretManagerSettingsSource,
|
||||
InitSettingsSource,
|
||||
JsonConfigSettingsSource,
|
||||
NestedSecretsSettingsSource,
|
||||
NoDecode,
|
||||
PydanticBaseSettingsSource,
|
||||
PyprojectTomlConfigSettingsSource,
|
||||
SecretsSettingsSource,
|
||||
TomlConfigSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .version import VERSION
|
||||
|
||||
__all__ = (
|
||||
'CLI_SUPPRESS',
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'BaseSettings',
|
||||
'CliApp',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliToggleFlag',
|
||||
'CliDualFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'CliUnknownArgs',
|
||||
'DotEnvSettingsSource',
|
||||
'EnvSettingsSource',
|
||||
'ForceDecode',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'NestedSecretsSettingsSource',
|
||||
'NoDecode',
|
||||
'PydanticBaseSettingsSource',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'SettingsConfigDict',
|
||||
'SettingsError',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
'__version__',
|
||||
'get_subcommand',
|
||||
)
|
||||
|
||||
__version__ = VERSION
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,4 @@
|
||||
class SettingsError(ValueError):
|
||||
"""Base exception for settings-related errors."""
|
||||
|
||||
pass
|
||||
901
venv/lib/python3.12/site-packages/pydantic_settings/main.py
Normal file
901
venv/lib/python3.12/site-packages/pydantic_settings/main.py
Normal file
@@ -0,0 +1,901 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import re
|
||||
import threading
|
||||
import warnings
|
||||
from argparse import Namespace
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, ClassVar, Literal, TextIO, TypeVar, cast
|
||||
|
||||
from pydantic import ConfigDict
|
||||
from pydantic._internal._config import config_keys
|
||||
from pydantic._internal._signature import _field_name_for_signature
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
from .exceptions import SettingsError
|
||||
from .sources import (
|
||||
ENV_FILE_SENTINEL,
|
||||
CliSettingsSource,
|
||||
DefaultSettingsSource,
|
||||
DotEnvSettingsSource,
|
||||
DotenvType,
|
||||
EnvPrefixTarget,
|
||||
EnvSettingsSource,
|
||||
InitSettingsSource,
|
||||
JsonConfigSettingsSource,
|
||||
PathType,
|
||||
PydanticBaseSettingsSource,
|
||||
PydanticModel,
|
||||
PyprojectTomlConfigSettingsSource,
|
||||
SecretsSettingsSource,
|
||||
TomlConfigSettingsSource,
|
||||
YamlConfigSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .sources.utils import _get_alias_names
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
class SettingsConfigDict(ConfigDict, total=False):
|
||||
case_sensitive: bool
|
||||
nested_model_default_partial_update: bool | None
|
||||
env_prefix: str
|
||||
env_prefix_target: EnvPrefixTarget
|
||||
env_file: DotenvType | None
|
||||
env_file_encoding: str | None
|
||||
env_ignore_empty: bool
|
||||
env_nested_delimiter: str | None
|
||||
env_nested_max_split: int | None
|
||||
env_parse_none_str: str | None
|
||||
env_parse_enums: bool | None
|
||||
cli_prog_name: str | None
|
||||
cli_parse_args: bool | list[str] | tuple[str, ...] | None
|
||||
cli_parse_none_str: str | None
|
||||
cli_hide_none_type: bool
|
||||
cli_avoid_json: bool
|
||||
cli_enforce_required: bool
|
||||
cli_use_class_docs_for_groups: bool
|
||||
cli_exit_on_error: bool
|
||||
cli_prefix: str
|
||||
cli_flag_prefix_char: str
|
||||
cli_implicit_flags: bool | Literal['dual', 'toggle'] | None
|
||||
cli_ignore_unknown_args: bool | None
|
||||
cli_kebab_case: bool | Literal['all', 'no_enums'] | None
|
||||
cli_shortcuts: Mapping[str, str | list[str]] | None
|
||||
secrets_dir: PathType | None
|
||||
json_file: PathType | None
|
||||
json_file_encoding: str | None
|
||||
yaml_file: PathType | None
|
||||
yaml_file_encoding: str | None
|
||||
yaml_config_section: str | None
|
||||
"""
|
||||
Specifies the section in a YAML file from which to load the settings.
|
||||
Supports dot-notation for nested paths (e.g., 'config.app.settings').
|
||||
If provided, the settings will be loaded from the specified section.
|
||||
This is useful when the YAML file contains multiple configuration sections
|
||||
and you only want to load a specific subset into your settings model.
|
||||
"""
|
||||
|
||||
pyproject_toml_depth: int
|
||||
"""
|
||||
Number of levels **up** from the current working directory to attempt to find a pyproject.toml
|
||||
file.
|
||||
|
||||
This is only used when a pyproject.toml file is not found in the current working directory.
|
||||
"""
|
||||
|
||||
pyproject_toml_table_header: tuple[str, ...]
|
||||
"""
|
||||
Header of the TOML table within a pyproject.toml file to use when filling variables.
|
||||
This is supplied as a `tuple[str, ...]` instead of a `str` to accommodate for headers
|
||||
containing a `.`.
|
||||
|
||||
For example, `toml_table_header = ("tool", "my.tool", "foo")` can be used to fill variable
|
||||
values from a table with header `[tool."my.tool".foo]`.
|
||||
|
||||
To use the root table, exclude this config setting or provide an empty tuple.
|
||||
"""
|
||||
|
||||
toml_file: PathType | None
|
||||
enable_decoding: bool
|
||||
|
||||
|
||||
# Extend `config_keys` by pydantic settings config keys to
|
||||
# support setting config through class kwargs.
|
||||
# Pydantic uses `config_keys` in `pydantic._internal._config.ConfigWrapper.for_model`
|
||||
# to extract config keys from model kwargs, So, by adding pydantic settings keys to
|
||||
# `config_keys`, they will be considered as valid config keys and will be collected
|
||||
# by Pydantic.
|
||||
config_keys |= set(SettingsConfigDict.__annotations__.keys())
|
||||
|
||||
|
||||
class BaseSettings(BaseModel):
|
||||
"""
|
||||
Base class for settings, allowing values to be overridden by environment variables.
|
||||
|
||||
This is useful in production for secrets you do not wish to save in code, it plays nicely with docker(-compose),
|
||||
Heroku and any 12 factor app design.
|
||||
|
||||
All the below attributes can be set via `model_config`.
|
||||
|
||||
Args:
|
||||
_case_sensitive: Whether environment and CLI variable names should be read with case-sensitivity.
|
||||
Defaults to `None`.
|
||||
_nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
|
||||
Defaults to `False`.
|
||||
_env_prefix: Prefix for all environment variables. Defaults to `None`.
|
||||
_env_prefix_target: Targets to which `_env_prefix` is applied. Default: `variable`.
|
||||
_env_file: The env file(s) to load settings values from. Defaults to `Path('')`, which
|
||||
means that the value from `model_config['env_file']` should be used. You can also pass
|
||||
`None` to indicate that environment variables should not be loaded from an env file.
|
||||
_env_file_encoding: The env file encoding, e.g. `'latin-1'`. Defaults to `None`.
|
||||
_env_ignore_empty: Ignore environment variables where the value is an empty string. Default to `False`.
|
||||
_env_nested_delimiter: The nested env values delimiter. Defaults to `None`.
|
||||
_env_nested_max_split: The nested env values maximum nesting. Defaults to `None`, which means no limit.
|
||||
_env_parse_none_str: The env string value that should be parsed (e.g. "null", "void", "None", etc.)
|
||||
into `None` type(None). Defaults to `None` type(None), which means no parsing should occur.
|
||||
_env_parse_enums: Parse enum field names to values. Defaults to `None.`, which means no parsing should occur.
|
||||
_cli_prog_name: The CLI program name to display in help text. Defaults to `None` if _cli_parse_args is `None`.
|
||||
Otherwise, defaults to sys.argv[0].
|
||||
_cli_parse_args: The list of CLI arguments to parse. Defaults to None.
|
||||
If set to `True`, defaults to sys.argv[1:].
|
||||
_cli_settings_source: Override the default CLI settings source with a user defined instance. Defaults to None.
|
||||
_cli_parse_none_str: The CLI string value that should be parsed (e.g. "null", "void", "None", etc.) into
|
||||
`None` type(None). Defaults to _env_parse_none_str value if set. Otherwise, defaults to "null" if
|
||||
_cli_avoid_json is `False`, and "None" if _cli_avoid_json is `True`.
|
||||
_cli_hide_none_type: Hide `None` values in CLI help text. Defaults to `False`.
|
||||
_cli_avoid_json: Avoid complex JSON objects in CLI help text. Defaults to `False`.
|
||||
_cli_enforce_required: Enforce required fields at the CLI. Defaults to `False`.
|
||||
_cli_use_class_docs_for_groups: Use class docstrings in CLI group help text instead of field descriptions.
|
||||
Defaults to `False`.
|
||||
_cli_exit_on_error: Determines whether or not the internal parser exits with error info when an error occurs.
|
||||
Defaults to `True`.
|
||||
_cli_prefix: The root parser command line arguments prefix. Defaults to "".
|
||||
_cli_flag_prefix_char: The flag prefix character to use for CLI optional arguments. Defaults to '-'.
|
||||
_cli_implicit_flags: Controls how `bool` fields are exposed as CLI flags.
|
||||
|
||||
- False (default): no implicit flags are generated; booleans must be set explicitly (e.g. --flag=true).
|
||||
- True / 'dual': optional boolean fields generate both positive and negative forms (--flag and --no-flag).
|
||||
- 'toggle': required boolean fields remain in 'dual' mode, while optional boolean fields generate a single
|
||||
flag aligned with the default value (if default=False, expose --flag; if default=True, expose --no-flag).
|
||||
_cli_ignore_unknown_args: Whether to ignore unknown CLI args and parse only known ones. Defaults to `False`.
|
||||
_cli_kebab_case: CLI args use kebab case. Defaults to `False`.
|
||||
_cli_shortcuts: Mapping of target field name to alias names. Defaults to `None`.
|
||||
_secrets_dir: The secret files directory or a sequence of directories. Defaults to `None`.
|
||||
_build_sources: Pre-initialized sources and init kwargs to use for building instantiation values.
|
||||
Defaults to `None`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
__pydantic_self__,
|
||||
_case_sensitive: bool | None = None,
|
||||
_nested_model_default_partial_update: bool | None = None,
|
||||
_env_prefix: str | None = None,
|
||||
_env_prefix_target: EnvPrefixTarget | None = None,
|
||||
_env_file: DotenvType | None = ENV_FILE_SENTINEL,
|
||||
_env_file_encoding: str | None = None,
|
||||
_env_ignore_empty: bool | None = None,
|
||||
_env_nested_delimiter: str | None = None,
|
||||
_env_nested_max_split: int | None = None,
|
||||
_env_parse_none_str: str | None = None,
|
||||
_env_parse_enums: bool | None = None,
|
||||
_cli_prog_name: str | None = None,
|
||||
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
|
||||
_cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
_cli_parse_none_str: str | None = None,
|
||||
_cli_hide_none_type: bool | None = None,
|
||||
_cli_avoid_json: bool | None = None,
|
||||
_cli_enforce_required: bool | None = None,
|
||||
_cli_use_class_docs_for_groups: bool | None = None,
|
||||
_cli_exit_on_error: bool | None = None,
|
||||
_cli_prefix: str | None = None,
|
||||
_cli_flag_prefix_char: str | None = None,
|
||||
_cli_implicit_flags: bool | Literal['dual', 'toggle'] | None = None,
|
||||
_cli_ignore_unknown_args: bool | None = None,
|
||||
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
|
||||
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
|
||||
_secrets_dir: PathType | None = None,
|
||||
_build_sources: tuple[tuple[PydanticBaseSettingsSource, ...], dict[str, Any]] | None = None,
|
||||
**values: Any,
|
||||
) -> None:
|
||||
sources, init_kwargs = (
|
||||
_build_sources
|
||||
if _build_sources is not None
|
||||
else __pydantic_self__.__class__._settings_init_sources(
|
||||
_case_sensitive=_case_sensitive,
|
||||
_nested_model_default_partial_update=_nested_model_default_partial_update,
|
||||
_env_prefix=_env_prefix,
|
||||
_env_prefix_target=_env_prefix_target,
|
||||
_env_file=_env_file,
|
||||
_env_file_encoding=_env_file_encoding,
|
||||
_env_ignore_empty=_env_ignore_empty,
|
||||
_env_nested_delimiter=_env_nested_delimiter,
|
||||
_env_nested_max_split=_env_nested_max_split,
|
||||
_env_parse_none_str=_env_parse_none_str,
|
||||
_env_parse_enums=_env_parse_enums,
|
||||
_cli_prog_name=_cli_prog_name,
|
||||
_cli_parse_args=_cli_parse_args,
|
||||
_cli_settings_source=_cli_settings_source,
|
||||
_cli_parse_none_str=_cli_parse_none_str,
|
||||
_cli_hide_none_type=_cli_hide_none_type,
|
||||
_cli_avoid_json=_cli_avoid_json,
|
||||
_cli_enforce_required=_cli_enforce_required,
|
||||
_cli_use_class_docs_for_groups=_cli_use_class_docs_for_groups,
|
||||
_cli_exit_on_error=_cli_exit_on_error,
|
||||
_cli_prefix=_cli_prefix,
|
||||
_cli_flag_prefix_char=_cli_flag_prefix_char,
|
||||
_cli_implicit_flags=_cli_implicit_flags,
|
||||
_cli_ignore_unknown_args=_cli_ignore_unknown_args,
|
||||
_cli_kebab_case=_cli_kebab_case,
|
||||
_cli_shortcuts=_cli_shortcuts,
|
||||
_secrets_dir=_secrets_dir,
|
||||
**values,
|
||||
)
|
||||
)
|
||||
|
||||
super().__init__(**__pydantic_self__.__class__._settings_build_values(sources, init_kwargs))
|
||||
|
||||
@classmethod
|
||||
def settings_customise_sources(
|
||||
cls,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_settings: PydanticBaseSettingsSource,
|
||||
env_settings: PydanticBaseSettingsSource,
|
||||
dotenv_settings: PydanticBaseSettingsSource,
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
"""
|
||||
Define the sources and their order for loading the settings values.
|
||||
|
||||
Args:
|
||||
settings_cls: The Settings class.
|
||||
init_settings: The `InitSettingsSource` instance.
|
||||
env_settings: The `EnvSettingsSource` instance.
|
||||
dotenv_settings: The `DotEnvSettingsSource` instance.
|
||||
file_secret_settings: The `SecretsSettingsSource` instance.
|
||||
|
||||
Returns:
|
||||
A tuple containing the sources and their order for loading the settings values.
|
||||
"""
|
||||
return init_settings, env_settings, dotenv_settings, file_secret_settings
|
||||
|
||||
@classmethod
|
||||
def _settings_init_sources(
|
||||
cls,
|
||||
_case_sensitive: bool | None = None,
|
||||
_nested_model_default_partial_update: bool | None = None,
|
||||
_env_prefix: str | None = None,
|
||||
_env_prefix_target: EnvPrefixTarget | None = None,
|
||||
_env_file: DotenvType | None = None,
|
||||
_env_file_encoding: str | None = None,
|
||||
_env_ignore_empty: bool | None = None,
|
||||
_env_nested_delimiter: str | None = None,
|
||||
_env_nested_max_split: int | None = None,
|
||||
_env_parse_none_str: str | None = None,
|
||||
_env_parse_enums: bool | None = None,
|
||||
_cli_prog_name: str | None = None,
|
||||
_cli_parse_args: bool | list[str] | tuple[str, ...] | None = None,
|
||||
_cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
_cli_parse_none_str: str | None = None,
|
||||
_cli_hide_none_type: bool | None = None,
|
||||
_cli_avoid_json: bool | None = None,
|
||||
_cli_enforce_required: bool | None = None,
|
||||
_cli_use_class_docs_for_groups: bool | None = None,
|
||||
_cli_exit_on_error: bool | None = None,
|
||||
_cli_prefix: str | None = None,
|
||||
_cli_flag_prefix_char: str | None = None,
|
||||
_cli_implicit_flags: bool | Literal['dual', 'toggle'] | None = None,
|
||||
_cli_ignore_unknown_args: bool | None = None,
|
||||
_cli_kebab_case: bool | Literal['all', 'no_enums'] | None = None,
|
||||
_cli_shortcuts: Mapping[str, str | list[str]] | None = None,
|
||||
_secrets_dir: PathType | None = None,
|
||||
**init_kwargs: dict[str, Any],
|
||||
) -> tuple[tuple[PydanticBaseSettingsSource, ...], dict[str, Any]]:
|
||||
# Determine settings config values
|
||||
case_sensitive = _case_sensitive if _case_sensitive is not None else cls.model_config.get('case_sensitive')
|
||||
env_prefix = _env_prefix if _env_prefix is not None else cls.model_config.get('env_prefix')
|
||||
env_prefix_target = (
|
||||
_env_prefix_target if _env_prefix_target is not None else cls.model_config.get('env_prefix_target')
|
||||
)
|
||||
nested_model_default_partial_update = (
|
||||
_nested_model_default_partial_update
|
||||
if _nested_model_default_partial_update is not None
|
||||
else cls.model_config.get('nested_model_default_partial_update')
|
||||
)
|
||||
env_file = _env_file if _env_file != ENV_FILE_SENTINEL else cls.model_config.get('env_file')
|
||||
env_file_encoding = (
|
||||
_env_file_encoding if _env_file_encoding is not None else cls.model_config.get('env_file_encoding')
|
||||
)
|
||||
env_ignore_empty = (
|
||||
_env_ignore_empty if _env_ignore_empty is not None else cls.model_config.get('env_ignore_empty')
|
||||
)
|
||||
env_nested_delimiter = (
|
||||
_env_nested_delimiter if _env_nested_delimiter is not None else cls.model_config.get('env_nested_delimiter')
|
||||
)
|
||||
env_nested_max_split = (
|
||||
_env_nested_max_split if _env_nested_max_split is not None else cls.model_config.get('env_nested_max_split')
|
||||
)
|
||||
env_parse_none_str = (
|
||||
_env_parse_none_str if _env_parse_none_str is not None else cls.model_config.get('env_parse_none_str')
|
||||
)
|
||||
env_parse_enums = _env_parse_enums if _env_parse_enums is not None else cls.model_config.get('env_parse_enums')
|
||||
|
||||
cli_prog_name = _cli_prog_name if _cli_prog_name is not None else cls.model_config.get('cli_prog_name')
|
||||
cli_parse_args = _cli_parse_args if _cli_parse_args is not None else cls.model_config.get('cli_parse_args')
|
||||
cli_settings_source = (
|
||||
_cli_settings_source if _cli_settings_source is not None else cls.model_config.get('cli_settings_source')
|
||||
)
|
||||
cli_parse_none_str = (
|
||||
_cli_parse_none_str if _cli_parse_none_str is not None else cls.model_config.get('cli_parse_none_str')
|
||||
)
|
||||
cli_parse_none_str = cli_parse_none_str if not env_parse_none_str else env_parse_none_str
|
||||
cli_hide_none_type = (
|
||||
_cli_hide_none_type if _cli_hide_none_type is not None else cls.model_config.get('cli_hide_none_type')
|
||||
)
|
||||
cli_avoid_json = _cli_avoid_json if _cli_avoid_json is not None else cls.model_config.get('cli_avoid_json')
|
||||
cli_enforce_required = (
|
||||
_cli_enforce_required if _cli_enforce_required is not None else cls.model_config.get('cli_enforce_required')
|
||||
)
|
||||
cli_use_class_docs_for_groups = (
|
||||
_cli_use_class_docs_for_groups
|
||||
if _cli_use_class_docs_for_groups is not None
|
||||
else cls.model_config.get('cli_use_class_docs_for_groups')
|
||||
)
|
||||
cli_exit_on_error = (
|
||||
_cli_exit_on_error if _cli_exit_on_error is not None else cls.model_config.get('cli_exit_on_error')
|
||||
)
|
||||
cli_prefix = _cli_prefix if _cli_prefix is not None else cls.model_config.get('cli_prefix')
|
||||
cli_flag_prefix_char = (
|
||||
_cli_flag_prefix_char if _cli_flag_prefix_char is not None else cls.model_config.get('cli_flag_prefix_char')
|
||||
)
|
||||
cli_implicit_flags = (
|
||||
_cli_implicit_flags if _cli_implicit_flags is not None else cls.model_config.get('cli_implicit_flags')
|
||||
)
|
||||
cli_ignore_unknown_args = (
|
||||
_cli_ignore_unknown_args
|
||||
if _cli_ignore_unknown_args is not None
|
||||
else cls.model_config.get('cli_ignore_unknown_args')
|
||||
)
|
||||
cli_kebab_case = _cli_kebab_case if _cli_kebab_case is not None else cls.model_config.get('cli_kebab_case')
|
||||
cli_shortcuts = _cli_shortcuts if _cli_shortcuts is not None else cls.model_config.get('cli_shortcuts')
|
||||
|
||||
secrets_dir = _secrets_dir if _secrets_dir is not None else cls.model_config.get('secrets_dir')
|
||||
|
||||
# Configure built-in sources
|
||||
default_settings = DefaultSettingsSource(
|
||||
cls, nested_model_default_partial_update=nested_model_default_partial_update
|
||||
)
|
||||
init_settings = InitSettingsSource(
|
||||
cls,
|
||||
init_kwargs=init_kwargs,
|
||||
nested_model_default_partial_update=nested_model_default_partial_update,
|
||||
)
|
||||
env_settings = EnvSettingsSource(
|
||||
cls,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_prefix_target=env_prefix_target,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_nested_max_split=env_nested_max_split,
|
||||
env_ignore_empty=env_ignore_empty,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
dotenv_settings = DotEnvSettingsSource(
|
||||
cls,
|
||||
env_file=env_file,
|
||||
env_file_encoding=env_file_encoding,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_prefix_target=env_prefix_target,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_nested_max_split=env_nested_max_split,
|
||||
env_ignore_empty=env_ignore_empty,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
file_secret_settings = SecretsSettingsSource(
|
||||
cls,
|
||||
secrets_dir=secrets_dir,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_prefix_target=env_prefix_target,
|
||||
)
|
||||
# Provide a hook to set built-in sources priority and add / remove sources
|
||||
sources = cls.settings_customise_sources(
|
||||
cls,
|
||||
init_settings=init_settings,
|
||||
env_settings=env_settings,
|
||||
dotenv_settings=dotenv_settings,
|
||||
file_secret_settings=file_secret_settings,
|
||||
) + (default_settings,)
|
||||
custom_cli_sources = [source for source in sources if isinstance(source, CliSettingsSource)]
|
||||
if not any(custom_cli_sources):
|
||||
if isinstance(cli_settings_source, CliSettingsSource):
|
||||
sources = (cli_settings_source,) + sources
|
||||
elif cli_parse_args is not None:
|
||||
cli_settings = CliSettingsSource[Any](
|
||||
cls,
|
||||
cli_prog_name=cli_prog_name,
|
||||
cli_parse_args=cli_parse_args,
|
||||
cli_parse_none_str=cli_parse_none_str,
|
||||
cli_hide_none_type=cli_hide_none_type,
|
||||
cli_avoid_json=cli_avoid_json,
|
||||
cli_enforce_required=cli_enforce_required,
|
||||
cli_use_class_docs_for_groups=cli_use_class_docs_for_groups,
|
||||
cli_exit_on_error=cli_exit_on_error,
|
||||
cli_prefix=cli_prefix,
|
||||
cli_flag_prefix_char=cli_flag_prefix_char,
|
||||
cli_implicit_flags=cli_implicit_flags,
|
||||
cli_ignore_unknown_args=cli_ignore_unknown_args,
|
||||
cli_kebab_case=cli_kebab_case,
|
||||
cli_shortcuts=cli_shortcuts,
|
||||
case_sensitive=case_sensitive,
|
||||
)
|
||||
sources = (cli_settings,) + sources
|
||||
# We ensure that if command line arguments haven't been parsed yet, we do so.
|
||||
elif cli_parse_args not in (None, False) and not custom_cli_sources[0].env_vars:
|
||||
custom_cli_sources[0](args=cli_parse_args) # type: ignore
|
||||
|
||||
cls._settings_warn_unused_config_keys(sources, cls.model_config)
|
||||
|
||||
return sources, init_kwargs
|
||||
|
||||
@classmethod
|
||||
def _settings_build_values(
|
||||
cls, sources: tuple[PydanticBaseSettingsSource, ...], init_kwargs: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if sources:
|
||||
state: dict[str, Any] = {}
|
||||
defaults: dict[str, Any] = {}
|
||||
states: dict[str, dict[str, Any]] = {}
|
||||
for source in sources:
|
||||
if isinstance(source, PydanticBaseSettingsSource):
|
||||
source._set_current_state(state)
|
||||
source._set_settings_sources_data(states)
|
||||
|
||||
source_name = source.__name__ if hasattr(source, '__name__') else type(source).__name__
|
||||
source_state = source()
|
||||
|
||||
if isinstance(source, DefaultSettingsSource):
|
||||
defaults = source_state
|
||||
|
||||
states[source_name] = source_state
|
||||
state = deep_update(source_state, state)
|
||||
|
||||
# Strip any default values not explicity set before returning final state
|
||||
state = {key: val for key, val in state.items() if key not in defaults or defaults[key] != val}
|
||||
cls._settings_restore_init_kwarg_names(cls, init_kwargs, state)
|
||||
|
||||
return state
|
||||
else:
|
||||
# no one should mean to do this, but I think returning an empty dict is marginally preferable
|
||||
# to an informative error and much better than a confusing error
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def _settings_restore_init_kwarg_names(
|
||||
settings_cls: type[BaseSettings], init_kwargs: dict[str, Any], state: dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Restore the init_kwarg key names to the final merged state dictionary.
|
||||
|
||||
This function renames keys in state to match the original init_kwargs key names,
|
||||
preserving the merged values from the source priority order.
|
||||
"""
|
||||
if init_kwargs and state:
|
||||
state_kwarg_names = set(state.keys())
|
||||
init_kwarg_names = set(init_kwargs.keys())
|
||||
for field_name, field_info in settings_cls.model_fields.items():
|
||||
alias_names, *_ = _get_alias_names(field_name, field_info)
|
||||
matchable_names = set(alias_names)
|
||||
include_name = settings_cls.model_config.get(
|
||||
'populate_by_name', False
|
||||
) or settings_cls.model_config.get('validate_by_name', False)
|
||||
if include_name:
|
||||
matchable_names.add(field_name)
|
||||
init_kwarg_name = init_kwarg_names & matchable_names
|
||||
state_kwarg_name = state_kwarg_names & matchable_names
|
||||
if init_kwarg_name and state_kwarg_name:
|
||||
# Use deterministic selection for both keys.
|
||||
# Target key: the key from init_kwargs that should be used in the final state.
|
||||
target_key = next(iter(init_kwarg_name))
|
||||
# Source key: prefer the alias (first in alias_names) if present in state,
|
||||
# as InitSettingsSource normalizes to the preferred alias.
|
||||
# This ensures we get the highest-priority value for this field.
|
||||
source_key = None
|
||||
for alias in alias_names:
|
||||
if alias in state_kwarg_name:
|
||||
source_key = alias
|
||||
break
|
||||
if source_key is None:
|
||||
# Fall back to field_name if no alias found in state
|
||||
source_key = field_name if field_name in state_kwarg_name else next(iter(state_kwarg_name))
|
||||
# Get the value from the source key and remove all matching keys
|
||||
value = state.pop(source_key)
|
||||
for key in state_kwarg_name - {source_key}:
|
||||
state.pop(key, None)
|
||||
state[target_key] = value
|
||||
|
||||
@staticmethod
|
||||
def _settings_warn_unused_config_keys(sources: tuple[object, ...], model_config: SettingsConfigDict) -> None:
|
||||
"""
|
||||
Warns if any values in model_config were set but the corresponding settings source has not been initialised.
|
||||
|
||||
The list alternative sources and their config keys can be found here:
|
||||
https://docs.pydantic.dev/latest/concepts/pydantic_settings/#other-settings-source
|
||||
|
||||
Args:
|
||||
sources: The tuple of configured sources
|
||||
model_config: The model config to check for unused config keys
|
||||
"""
|
||||
|
||||
def warn_if_not_used(source_type: type[PydanticBaseSettingsSource], keys: tuple[str, ...]) -> None:
|
||||
if not any(isinstance(source, source_type) for source in sources):
|
||||
for key in keys:
|
||||
if model_config.get(key) is not None:
|
||||
warnings.warn(
|
||||
f'Config key `{key}` is set in model_config but will be ignored because no '
|
||||
f'{source_type.__name__} source is configured. To use this config key, add a '
|
||||
f'{source_type.__name__} source to the settings sources via the '
|
||||
'settings_customise_sources hook.',
|
||||
UserWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
|
||||
warn_if_not_used(JsonConfigSettingsSource, ('json_file', 'json_file_encoding'))
|
||||
warn_if_not_used(PyprojectTomlConfigSettingsSource, ('pyproject_toml_depth', 'pyproject_toml_table_header'))
|
||||
warn_if_not_used(TomlConfigSettingsSource, ('toml_file',))
|
||||
warn_if_not_used(YamlConfigSettingsSource, ('yaml_file', 'yaml_file_encoding', 'yaml_config_section'))
|
||||
|
||||
model_config: ClassVar[SettingsConfigDict] = SettingsConfigDict(
|
||||
extra='forbid',
|
||||
arbitrary_types_allowed=True,
|
||||
validate_default=True,
|
||||
case_sensitive=False,
|
||||
env_prefix='',
|
||||
env_prefix_target='variable',
|
||||
nested_model_default_partial_update=False,
|
||||
env_file=None,
|
||||
env_file_encoding=None,
|
||||
env_ignore_empty=False,
|
||||
env_nested_delimiter=None,
|
||||
env_nested_max_split=None,
|
||||
env_parse_none_str=None,
|
||||
env_parse_enums=None,
|
||||
cli_prog_name=None,
|
||||
cli_parse_args=None,
|
||||
cli_parse_none_str=None,
|
||||
cli_hide_none_type=False,
|
||||
cli_avoid_json=False,
|
||||
cli_enforce_required=False,
|
||||
cli_use_class_docs_for_groups=False,
|
||||
cli_exit_on_error=True,
|
||||
cli_prefix='',
|
||||
cli_flag_prefix_char='-',
|
||||
cli_implicit_flags=False,
|
||||
cli_ignore_unknown_args=False,
|
||||
cli_kebab_case=False,
|
||||
cli_shortcuts=None,
|
||||
json_file=None,
|
||||
json_file_encoding=None,
|
||||
yaml_file=None,
|
||||
yaml_file_encoding=None,
|
||||
yaml_config_section=None,
|
||||
toml_file=None,
|
||||
secrets_dir=None,
|
||||
protected_namespaces=('model_validate', 'model_dump', 'settings_customise_sources'),
|
||||
enable_decoding=True,
|
||||
)
|
||||
|
||||
|
||||
class CliApp:
|
||||
"""
|
||||
A utility class for running Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as
|
||||
CLI applications.
|
||||
"""
|
||||
|
||||
_subcommand_stack: ClassVar[dict[int, tuple[CliSettingsSource[Any], Any, str]]] = {}
|
||||
_ansi_color: ClassVar[re.Pattern[str]] = re.compile(r'\x1b\[[0-9;]*m')
|
||||
|
||||
@staticmethod
|
||||
def _get_base_settings_cls(model_cls: type[Any]) -> type[BaseSettings]:
|
||||
if issubclass(model_cls, BaseSettings):
|
||||
return model_cls
|
||||
|
||||
class CliAppBaseSettings(BaseSettings, model_cls): # type: ignore
|
||||
__doc__ = model_cls.__doc__
|
||||
model_config = SettingsConfigDict(
|
||||
nested_model_default_partial_update=True,
|
||||
case_sensitive=True,
|
||||
cli_hide_none_type=True,
|
||||
cli_avoid_json=True,
|
||||
cli_enforce_required=True,
|
||||
cli_implicit_flags=True,
|
||||
cli_kebab_case=True,
|
||||
)
|
||||
|
||||
return CliAppBaseSettings
|
||||
|
||||
@staticmethod
|
||||
def _run_cli_cmd(model: Any, cli_cmd_method_name: str, is_required: bool) -> Any:
|
||||
command = getattr(type(model), cli_cmd_method_name, None)
|
||||
if command is None:
|
||||
if is_required:
|
||||
raise SettingsError(f'Error: {type(model).__name__} class is missing {cli_cmd_method_name} entrypoint')
|
||||
return model
|
||||
|
||||
# If the method is asynchronous, we handle its execution based on the current event loop status.
|
||||
if inspect.iscoroutinefunction(command):
|
||||
# For asynchronous methods, we have two execution scenarios:
|
||||
# 1. If no event loop is running in the current thread, run the coroutine directly with asyncio.run().
|
||||
# 2. If an event loop is already running in the current thread, run the coroutine in a separate thread to avoid conflicts.
|
||||
try:
|
||||
# Check if an event loop is currently running in this thread.
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = None
|
||||
|
||||
if loop and loop.is_running():
|
||||
# We're in a context with an active event loop (e.g., Jupyter Notebook).
|
||||
# Running asyncio.run() here would cause conflicts, so we use a separate thread.
|
||||
exception_container = []
|
||||
|
||||
def run_coro() -> None:
|
||||
try:
|
||||
# Execute the coroutine in a new event loop in this separate thread.
|
||||
asyncio.run(command(model))
|
||||
except Exception as e:
|
||||
exception_container.append(e)
|
||||
|
||||
thread = threading.Thread(target=run_coro)
|
||||
thread.start()
|
||||
thread.join()
|
||||
if exception_container:
|
||||
# Propagate exceptions from the separate thread.
|
||||
raise exception_container[0]
|
||||
else:
|
||||
# No event loop is running; safe to run the coroutine directly.
|
||||
asyncio.run(command(model))
|
||||
else:
|
||||
# For synchronous methods, call them directly.
|
||||
command(model)
|
||||
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def run(
|
||||
model_cls: type[T],
|
||||
cli_args: list[str] | Namespace | SimpleNamespace | dict[str, Any] | None = None,
|
||||
cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
cli_exit_on_error: bool | None = None,
|
||||
cli_cmd_method_name: str = 'cli_cmd',
|
||||
**model_init_data: Any,
|
||||
) -> T:
|
||||
"""
|
||||
Runs a Pydantic `BaseSettings`, `BaseModel`, or `pydantic.dataclasses.dataclass` as a CLI application.
|
||||
Running a model as a CLI application requires the `cli_cmd` method to be defined in the model class.
|
||||
|
||||
Args:
|
||||
model_cls: The model class to run as a CLI application.
|
||||
cli_args: The list of CLI arguments to parse. If `cli_settings_source` is specified, this may
|
||||
also be a namespace or dictionary of pre-parsed CLI arguments. Defaults to `sys.argv[1:]`.
|
||||
cli_settings_source: Override the default CLI settings source with a user defined instance.
|
||||
Defaults to `None`.
|
||||
cli_exit_on_error: Determines whether this function exits on error. If model is subclass of
|
||||
`BaseSettings`, defaults to BaseSettings `cli_exit_on_error` value. Otherwise, defaults to
|
||||
`True`.
|
||||
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
|
||||
model_init_data: The model init data.
|
||||
|
||||
Returns:
|
||||
The ran instance of model.
|
||||
|
||||
Raises:
|
||||
SettingsError: If model_cls is not subclass of `BaseModel` or `pydantic.dataclasses.dataclass`.
|
||||
SettingsError: If model_cls does not have a `cli_cmd` entrypoint defined.
|
||||
"""
|
||||
|
||||
if not (is_pydantic_dataclass(model_cls) or is_model_class(model_cls)):
|
||||
raise SettingsError(
|
||||
f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass'
|
||||
)
|
||||
|
||||
cli_settings = None
|
||||
cli_parse_args = True if cli_args is None else cli_args
|
||||
if cli_settings_source is not None:
|
||||
if isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
|
||||
cli_settings = cli_settings_source(parsed_args=cli_parse_args)
|
||||
else:
|
||||
cli_settings = cli_settings_source(args=cli_parse_args)
|
||||
elif isinstance(cli_parse_args, (Namespace, SimpleNamespace, dict)):
|
||||
raise SettingsError('Error: `cli_args` must be list[str] or None when `cli_settings_source` is not used')
|
||||
|
||||
model_init_data['_cli_parse_args'] = cli_parse_args
|
||||
model_init_data['_cli_exit_on_error'] = cli_exit_on_error
|
||||
model_init_data['_cli_settings_source'] = cli_settings
|
||||
if not issubclass(model_cls, BaseSettings):
|
||||
base_settings_cls = CliApp._get_base_settings_cls(model_cls)
|
||||
sources, init_kwargs = base_settings_cls._settings_init_sources(**model_init_data)
|
||||
model = base_settings_cls(**base_settings_cls._settings_build_values(sources, init_kwargs))
|
||||
model_init_data = {}
|
||||
for field_name, field_info in base_settings_cls.model_fields.items():
|
||||
model_init_data[_field_name_for_signature(field_name, field_info)] = getattr(model, field_name)
|
||||
command = model_cls(**model_init_data)
|
||||
else:
|
||||
sources, init_kwargs = model_cls._settings_init_sources(**model_init_data)
|
||||
command = model_cls(_build_sources=(sources, init_kwargs))
|
||||
|
||||
subcommand_dest = ':subcommand'
|
||||
cli_settings_source = [source for source in sources if isinstance(source, CliSettingsSource)][0]
|
||||
CliApp._subcommand_stack[id(command)] = (cli_settings_source, cli_settings_source.root_parser, subcommand_dest)
|
||||
try:
|
||||
data_model = CliApp._run_cli_cmd(command, cli_cmd_method_name, is_required=False)
|
||||
finally:
|
||||
del CliApp._subcommand_stack[id(command)]
|
||||
return data_model
|
||||
|
||||
@staticmethod
|
||||
def run_subcommand(
|
||||
model: PydanticModel, cli_exit_on_error: bool | None = None, cli_cmd_method_name: str = 'cli_cmd'
|
||||
) -> PydanticModel:
|
||||
"""
|
||||
Runs the model subcommand. Running a model subcommand requires the `cli_cmd` method to be defined in
|
||||
the nested model subcommand class.
|
||||
|
||||
Args:
|
||||
model: The model to run the subcommand from.
|
||||
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
|
||||
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
|
||||
cli_cmd_method_name: The CLI command method name to run. Defaults to "cli_cmd".
|
||||
|
||||
Returns:
|
||||
The ran subcommand model.
|
||||
|
||||
Raises:
|
||||
SystemExit: When no subcommand is found and cli_exit_on_error=`True` (the default).
|
||||
SettingsError: When no subcommand is found and cli_exit_on_error=`False`.
|
||||
"""
|
||||
|
||||
if id(model) in CliApp._subcommand_stack:
|
||||
cli_settings_source, parser, subcommand_dest = CliApp._subcommand_stack[id(model)]
|
||||
else:
|
||||
cli_settings_source = CliSettingsSource[Any](CliApp._get_base_settings_cls(type(model)))
|
||||
parser = cli_settings_source.root_parser
|
||||
subcommand_dest = ':subcommand'
|
||||
|
||||
cli_exit_on_error = cli_settings_source.cli_exit_on_error if cli_exit_on_error is None else cli_exit_on_error
|
||||
|
||||
errors: list[SettingsError | SystemExit] = []
|
||||
subcommand = get_subcommand(
|
||||
model, is_required=True, cli_exit_on_error=cli_exit_on_error, _suppress_errors=errors
|
||||
)
|
||||
if errors:
|
||||
err = errors[0]
|
||||
if err.__context__ is None and err.__cause__ is None and cli_settings_source._format_help is not None:
|
||||
error_message = f'{err}\n{cli_settings_source._format_help(parser)}'
|
||||
raise type(err)(error_message) from None
|
||||
else:
|
||||
raise err
|
||||
|
||||
subcommand_cls = cast(type[BaseModel], type(subcommand))
|
||||
subcommand_arg = cli_settings_source._parser_map[subcommand_dest][subcommand_cls]
|
||||
subcommand_alias = subcommand_arg.subcommand_alias(subcommand_cls)
|
||||
subcommand_dest = f'{subcommand_dest.split(":")[0]}{subcommand_alias}.:subcommand'
|
||||
subcommand_parser = subcommand_arg.parser
|
||||
CliApp._subcommand_stack[id(subcommand)] = (cli_settings_source, subcommand_parser, subcommand_dest)
|
||||
try:
|
||||
data_model = CliApp._run_cli_cmd(subcommand, cli_cmd_method_name, is_required=True)
|
||||
finally:
|
||||
del CliApp._subcommand_stack[id(subcommand)]
|
||||
return data_model
|
||||
|
||||
@staticmethod
|
||||
def serialize(
|
||||
model: PydanticModel,
|
||||
list_style: Literal['json', 'argparse', 'lazy'] = 'json',
|
||||
dict_style: Literal['json', 'env'] = 'json',
|
||||
positionals_first: bool = False,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Serializes the CLI arguments for a Pydantic data model.
|
||||
|
||||
Args:
|
||||
model: The data model to serialize.
|
||||
list_style:
|
||||
Controls how list-valued fields are serialized on the command line.
|
||||
- 'json' (default):
|
||||
Lists are encoded as a single JSON array.
|
||||
Example: `--tags '["a","b","c"]'`
|
||||
- 'argparse':
|
||||
Each list element becomes its own repeated flag, following
|
||||
typical `argparse` conventions.
|
||||
Example: `--tags a --tags b --tags c`
|
||||
- 'lazy':
|
||||
Lists are emitted as a single comma-separated string without JSON
|
||||
quoting or escaping.
|
||||
Example: `--tags a,b,c`
|
||||
dict_style:
|
||||
Controls how dictionary-valued fields are serialized.
|
||||
- 'json' (default):
|
||||
The entire dictionary is emitted as a single JSON object.
|
||||
Example: `--config '{"host": "localhost", "port": 5432}'`
|
||||
- 'env':
|
||||
The dictionary is flattened into multiple CLI flags using
|
||||
environment-variable-style assignement.
|
||||
Example: `--config host=localhost --config port=5432`
|
||||
positionals_first: Controls whether positional arguments should be serialized
|
||||
first compared to optional arguments. Defaults to `False`.
|
||||
|
||||
Returns:
|
||||
The serialized CLI arguments for the data model.
|
||||
"""
|
||||
|
||||
base_settings_cls = CliApp._get_base_settings_cls(type(model))
|
||||
serialized_args = CliSettingsSource[Any](base_settings_cls)._serialized_args(
|
||||
model,
|
||||
list_style=list_style,
|
||||
dict_style=dict_style,
|
||||
positionals_first=positionals_first,
|
||||
)
|
||||
return CliSettingsSource._flatten_serialized_args(serialized_args, positionals_first)
|
||||
|
||||
@staticmethod
|
||||
def format_help(
|
||||
model: PydanticModel | type[T],
|
||||
cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
strip_ansi_color: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
Return a string containing a help message for a Pydantic model.
|
||||
|
||||
Args:
|
||||
model: The model or model class.
|
||||
cli_settings_source: Override the default CLI settings source with a user defined instance.
|
||||
Defaults to `None`.
|
||||
strip_ansi_color: Strips ANSI color codes from the help message when set to `True`.
|
||||
|
||||
Returns:
|
||||
The help message string for the model.
|
||||
"""
|
||||
model_cls = model if isinstance(model, type) else type(model)
|
||||
if cli_settings_source is None:
|
||||
if not isinstance(model, type) and id(model) in CliApp._subcommand_stack:
|
||||
cli_settings_source, *_ = CliApp._subcommand_stack[id(model)]
|
||||
else:
|
||||
cli_settings_source = CliSettingsSource(CliApp._get_base_settings_cls(model_cls))
|
||||
help_message = cli_settings_source._format_help(cli_settings_source.root_parser)
|
||||
return help_message if not strip_ansi_color else CliApp._ansi_color.sub('', help_message)
|
||||
|
||||
@staticmethod
|
||||
def print_help(
|
||||
model: PydanticModel | type[T],
|
||||
cli_settings_source: CliSettingsSource[Any] | None = None,
|
||||
file: TextIO | None = None,
|
||||
strip_ansi_color: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Print a help message for a Pydantic model.
|
||||
|
||||
Args:
|
||||
model: The model or model class.
|
||||
cli_settings_source: Override the default CLI settings source with a user defined instance.
|
||||
Defaults to `None`.
|
||||
file: A text stream to which the help message is written. If `None`, the output is sent to sys.stdout.
|
||||
strip_ansi_color: Strips ANSI color codes from the help message when set to `True`.
|
||||
"""
|
||||
print(
|
||||
CliApp.format_help(
|
||||
model,
|
||||
cli_settings_source=cli_settings_source,
|
||||
strip_ansi_color=strip_ansi_color,
|
||||
),
|
||||
file=file,
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Package for handling configuration sources in pydantic-settings."""
|
||||
|
||||
from .base import (
|
||||
ConfigFileSourceMixin,
|
||||
DefaultSettingsSource,
|
||||
InitSettingsSource,
|
||||
PydanticBaseEnvSettingsSource,
|
||||
PydanticBaseSettingsSource,
|
||||
get_subcommand,
|
||||
)
|
||||
from .providers.aws import AWSSecretsManagerSettingsSource
|
||||
from .providers.azure import AzureKeyVaultSettingsSource
|
||||
from .providers.cli import (
|
||||
CLI_SUPPRESS,
|
||||
CliDualFlag,
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
CliToggleFlag,
|
||||
CliUnknownArgs,
|
||||
)
|
||||
from .providers.dotenv import DotEnvSettingsSource, read_env_file
|
||||
from .providers.env import EnvSettingsSource
|
||||
from .providers.gcp import GoogleSecretManagerSettingsSource
|
||||
from .providers.json import JsonConfigSettingsSource
|
||||
from .providers.nested_secrets import NestedSecretsSettingsSource
|
||||
from .providers.pyproject import PyprojectTomlConfigSettingsSource
|
||||
from .providers.secrets import SecretsSettingsSource
|
||||
from .providers.toml import TomlConfigSettingsSource
|
||||
from .providers.yaml import YamlConfigSettingsSource
|
||||
from .types import (
|
||||
DEFAULT_PATH,
|
||||
ENV_FILE_SENTINEL,
|
||||
DotenvType,
|
||||
EnvPrefixTarget,
|
||||
ForceDecode,
|
||||
NoDecode,
|
||||
PathType,
|
||||
PydanticModel,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'CLI_SUPPRESS',
|
||||
'ENV_FILE_SENTINEL',
|
||||
'DEFAULT_PATH',
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliToggleFlag',
|
||||
'CliDualFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'CliUnknownArgs',
|
||||
'DefaultSettingsSource',
|
||||
'DotEnvSettingsSource',
|
||||
'DotenvType',
|
||||
'EnvPrefixTarget',
|
||||
'EnvSettingsSource',
|
||||
'ForceDecode',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'NestedSecretsSettingsSource',
|
||||
'NoDecode',
|
||||
'PathType',
|
||||
'PydanticBaseEnvSettingsSource',
|
||||
'PydanticBaseSettingsSource',
|
||||
'ConfigFileSourceMixin',
|
||||
'PydanticModel',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
'get_subcommand',
|
||||
'read_env_file',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,579 @@
|
||||
"""Base classes and core functionality for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import asdict, is_dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, cast, get_args
|
||||
|
||||
from pydantic import AliasChoices, AliasPath, BaseModel, TypeAdapter
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
|
||||
get_origin,
|
||||
)
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ..exceptions import SettingsError
|
||||
from ..utils import _lenient_issubclass
|
||||
from .types import EnvNoneType, EnvPrefixTarget, ForceDecode, NoDecode, PathType, PydanticModel, _CliSubCommand
|
||||
from .utils import (
|
||||
_annotation_is_complex,
|
||||
_get_alias_names,
|
||||
_get_field_metadata,
|
||||
_get_model_fields,
|
||||
_resolve_type_alias,
|
||||
_strip_annotated,
|
||||
_union_is_complex,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
def get_subcommand(
|
||||
model: PydanticModel,
|
||||
is_required: bool = True,
|
||||
cli_exit_on_error: bool | None = None,
|
||||
_suppress_errors: list[SettingsError | SystemExit] | None = None,
|
||||
) -> PydanticModel | None:
|
||||
"""
|
||||
Get the subcommand from a model.
|
||||
|
||||
Args:
|
||||
model: The model to get the subcommand from.
|
||||
is_required: Determines whether a model must have subcommand set and raises error if not
|
||||
found. Defaults to `True`.
|
||||
cli_exit_on_error: Determines whether this function exits with error if no subcommand is found.
|
||||
Defaults to model_config `cli_exit_on_error` value if set. Otherwise, defaults to `True`.
|
||||
|
||||
Returns:
|
||||
The subcommand model if found, otherwise `None`.
|
||||
|
||||
Raises:
|
||||
SystemExit: When no subcommand is found and is_required=`True` and cli_exit_on_error=`True`
|
||||
(the default).
|
||||
SettingsError: When no subcommand is found and is_required=`True` and
|
||||
cli_exit_on_error=`False`.
|
||||
"""
|
||||
|
||||
model_cls = type(model)
|
||||
if cli_exit_on_error is None and is_model_class(model_cls):
|
||||
model_default = model_cls.model_config.get('cli_exit_on_error')
|
||||
if isinstance(model_default, bool):
|
||||
cli_exit_on_error = model_default
|
||||
if cli_exit_on_error is None:
|
||||
cli_exit_on_error = True
|
||||
|
||||
subcommands: list[str] = []
|
||||
for field_name, field_info in _get_model_fields(model_cls).items():
|
||||
if _CliSubCommand in field_info.metadata:
|
||||
if getattr(model, field_name) is not None:
|
||||
return getattr(model, field_name)
|
||||
subcommands.append(field_name)
|
||||
|
||||
if is_required:
|
||||
error_message = (
|
||||
f'Error: CLI subcommand is required {{{", ".join(subcommands)}}}'
|
||||
if subcommands
|
||||
else 'Error: CLI subcommand is required but no subcommands were found.'
|
||||
)
|
||||
err = SystemExit(error_message) if cli_exit_on_error else SettingsError(error_message)
|
||||
if _suppress_errors is None:
|
||||
raise err
|
||||
_suppress_errors.append(err)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class PydanticBaseSettingsSource(ABC):
|
||||
"""
|
||||
Abstract base class for settings sources, every settings source classes should inherit from it.
|
||||
"""
|
||||
|
||||
def __init__(self, settings_cls: type[BaseSettings]):
|
||||
self.settings_cls = settings_cls
|
||||
self.config = settings_cls.model_config
|
||||
self._current_state: dict[str, Any] = {}
|
||||
self._settings_sources_data: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def _set_current_state(self, state: dict[str, Any]) -> None:
|
||||
"""
|
||||
Record the state of settings from the previous settings sources. This should
|
||||
be called right before __call__.
|
||||
"""
|
||||
self._current_state = state
|
||||
|
||||
def _set_settings_sources_data(self, states: dict[str, dict[str, Any]]) -> None:
|
||||
"""
|
||||
Record the state of settings from all previous settings sources. This should
|
||||
be called right before __call__.
|
||||
"""
|
||||
self._settings_sources_data = states
|
||||
|
||||
@property
|
||||
def current_state(self) -> dict[str, Any]:
|
||||
"""
|
||||
The current state of the settings, populated by the previous settings sources.
|
||||
"""
|
||||
return self._current_state
|
||||
|
||||
@property
|
||||
def settings_sources_data(self) -> dict[str, dict[str, Any]]:
|
||||
"""
|
||||
The state of all previous settings sources.
|
||||
"""
|
||||
return self._settings_sources_data
|
||||
|
||||
@abstractmethod
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value, the key for model creation, and a flag to determine whether value is complex.
|
||||
|
||||
This is an abstract method that should be overridden in every settings source classes.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value, key and a flag to determine whether value is complex.
|
||||
"""
|
||||
pass
|
||||
|
||||
def field_is_complex(self, field: FieldInfo) -> bool:
|
||||
"""
|
||||
Checks whether a field is complex, in which case it will attempt to be parsed as JSON.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
|
||||
Returns:
|
||||
Whether the field is complex.
|
||||
"""
|
||||
return _annotation_is_complex(field.annotation, field.metadata)
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
"""
|
||||
Prepares the value of a field.
|
||||
|
||||
Args:
|
||||
field_name: The field name.
|
||||
field: The field.
|
||||
value: The value of the field that has to be prepared.
|
||||
value_is_complex: A flag to determine whether value is complex.
|
||||
|
||||
Returns:
|
||||
The prepared value.
|
||||
"""
|
||||
if value is not None and (self.field_is_complex(field) or value_is_complex):
|
||||
return self.decode_complex_value(field_name, field, value)
|
||||
return value
|
||||
|
||||
def decode_complex_value(self, field_name: str, field: FieldInfo, value: Any) -> Any:
|
||||
"""
|
||||
Decode the value for a complex field
|
||||
|
||||
Args:
|
||||
field_name: The field name.
|
||||
field: The field.
|
||||
value: The value of the field that has to be prepared.
|
||||
|
||||
Returns:
|
||||
The decoded value for further preparation
|
||||
"""
|
||||
if field and (
|
||||
NoDecode in _get_field_metadata(field)
|
||||
or (self.config.get('enable_decoding') is False and ForceDecode not in field.metadata)
|
||||
):
|
||||
return value
|
||||
|
||||
return json.loads(value)
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
class ConfigFileSourceMixin(ABC):
|
||||
def _read_files(self, files: PathType | None, deep_merge: bool = False) -> dict[str, Any]:
|
||||
if files is None:
|
||||
return {}
|
||||
if not isinstance(files, Sequence) or isinstance(files, str):
|
||||
files = [files]
|
||||
vars: dict[str, Any] = {}
|
||||
for file in files:
|
||||
if isinstance(file, str):
|
||||
file_path = Path(file)
|
||||
else:
|
||||
file_path = file
|
||||
if isinstance(file_path, Path):
|
||||
file_path = file_path.expanduser()
|
||||
|
||||
if not file_path.is_file():
|
||||
continue
|
||||
|
||||
updating_vars = self._read_file(file_path)
|
||||
if deep_merge:
|
||||
vars = deep_update(vars, updating_vars)
|
||||
else:
|
||||
vars.update(updating_vars)
|
||||
return vars
|
||||
|
||||
@abstractmethod
|
||||
def _read_file(self, path: Path) -> dict[str, Any]:
|
||||
pass
|
||||
|
||||
|
||||
class DefaultSettingsSource(PydanticBaseSettingsSource):
|
||||
"""
|
||||
Source class for loading default object values.
|
||||
|
||||
Args:
|
||||
settings_cls: The Settings class.
|
||||
nested_model_default_partial_update: Whether to allow partial updates on nested model default object fields.
|
||||
Defaults to `False`.
|
||||
"""
|
||||
|
||||
def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partial_update: bool | None = None):
|
||||
super().__init__(settings_cls)
|
||||
self.defaults: dict[str, Any] = {}
|
||||
self.nested_model_default_partial_update = (
|
||||
nested_model_default_partial_update
|
||||
if nested_model_default_partial_update is not None
|
||||
else self.config.get('nested_model_default_partial_update', False)
|
||||
)
|
||||
if self.nested_model_default_partial_update:
|
||||
for field_name, field_info in settings_cls.model_fields.items():
|
||||
alias_names, *_ = _get_alias_names(field_name, field_info)
|
||||
preferred_alias = alias_names[0]
|
||||
if is_dataclass(type(field_info.default)):
|
||||
self.defaults[preferred_alias] = asdict(field_info.default)
|
||||
elif is_model_class(type(field_info.default)):
|
||||
self.defaults[preferred_alias] = field_info.default.model_dump()
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
# Nothing to do here. Only implement the return statement to make mypy happy
|
||||
return None, '', False
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
return self.defaults
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(nested_model_default_partial_update={self.nested_model_default_partial_update})'
|
||||
)
|
||||
|
||||
|
||||
class InitSettingsSource(PydanticBaseSettingsSource):
|
||||
"""
|
||||
Source class for loading values provided during settings class initialization.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
init_kwargs: dict[str, Any],
|
||||
nested_model_default_partial_update: bool | None = None,
|
||||
):
|
||||
self.init_kwargs = {}
|
||||
init_kwarg_names = set(init_kwargs.keys())
|
||||
for field_name, field_info in settings_cls.model_fields.items():
|
||||
alias_names, *_ = _get_alias_names(field_name, field_info)
|
||||
# When populate_by_name is True, allow using the field name as an input key,
|
||||
# but normalize to the preferred alias to keep keys consistent across sources.
|
||||
matchable_names = set(alias_names)
|
||||
include_name = settings_cls.model_config.get('populate_by_name', False) or settings_cls.model_config.get(
|
||||
'validate_by_name', False
|
||||
)
|
||||
if include_name:
|
||||
matchable_names.add(field_name)
|
||||
init_kwarg_name = init_kwarg_names & matchable_names
|
||||
if init_kwarg_name:
|
||||
preferred_alias = alias_names[0] if alias_names else field_name
|
||||
# Choose provided key deterministically: prefer the first alias in alias_names order;
|
||||
# fall back to field_name if allowed and provided.
|
||||
provided_key = next((alias for alias in alias_names if alias in init_kwarg_names), None)
|
||||
if provided_key is None and include_name and field_name in init_kwarg_names:
|
||||
provided_key = field_name
|
||||
# provided_key should not be None here because init_kwarg_name is non-empty
|
||||
assert provided_key is not None
|
||||
init_kwarg_names -= init_kwarg_name
|
||||
self.init_kwargs[preferred_alias] = init_kwargs[provided_key]
|
||||
# Include any remaining init kwargs (e.g., extras) unchanged
|
||||
# Note: If populate_by_name is True and the provided key is the field name, but
|
||||
# no alias exists, we keep it as-is so it can be processed as extra if allowed.
|
||||
self.init_kwargs.update({key: val for key, val in init_kwargs.items() if key in init_kwarg_names})
|
||||
|
||||
super().__init__(settings_cls)
|
||||
self.nested_model_default_partial_update = (
|
||||
nested_model_default_partial_update
|
||||
if nested_model_default_partial_update is not None
|
||||
else self.config.get('nested_model_default_partial_update', False)
|
||||
)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
# Nothing to do here. Only implement the return statement to make mypy happy
|
||||
return None, '', False
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
return (
|
||||
TypeAdapter(dict[str, Any]).dump_python(self.init_kwargs)
|
||||
if self.nested_model_default_partial_update
|
||||
else self.init_kwargs
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(init_kwargs={self.init_kwargs!r})'
|
||||
|
||||
|
||||
class PydanticBaseEnvSettingsSource(PydanticBaseSettingsSource):
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_prefix_target: EnvPrefixTarget | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(settings_cls)
|
||||
self.case_sensitive = case_sensitive if case_sensitive is not None else self.config.get('case_sensitive', False)
|
||||
self.env_prefix = env_prefix if env_prefix is not None else self.config.get('env_prefix', '')
|
||||
self.env_prefix_target = (
|
||||
env_prefix_target if env_prefix_target is not None else self.config.get('env_prefix_target', 'variable')
|
||||
)
|
||||
self.env_ignore_empty = (
|
||||
env_ignore_empty if env_ignore_empty is not None else self.config.get('env_ignore_empty', False)
|
||||
)
|
||||
self.env_parse_none_str = (
|
||||
env_parse_none_str if env_parse_none_str is not None else self.config.get('env_parse_none_str')
|
||||
)
|
||||
self.env_parse_enums = env_parse_enums if env_parse_enums is not None else self.config.get('env_parse_enums')
|
||||
|
||||
def _apply_case_sensitive(self, value: str) -> str:
|
||||
return value.lower() if not self.case_sensitive else value
|
||||
|
||||
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
|
||||
"""
|
||||
Extracts field info. This info is used to get the value of field from environment variables.
|
||||
|
||||
It returns a list of tuples, each tuple contains:
|
||||
* field_key: The key of field that has to be used in model creation.
|
||||
* env_name: The environment variable name of the field.
|
||||
* value_is_complex: A flag to determine whether the value from environment variable
|
||||
is complex and has to be parsed.
|
||||
|
||||
Args:
|
||||
field (FieldInfo): The field.
|
||||
field_name (str): The field name.
|
||||
|
||||
Returns:
|
||||
list[tuple[str, str, bool]]: List of tuples, each tuple contains field_key, env_name, and value_is_complex.
|
||||
"""
|
||||
field_info: list[tuple[str, str, bool]] = []
|
||||
if isinstance(field.validation_alias, (AliasChoices, AliasPath)):
|
||||
v_alias: str | list[str | int] | list[list[str | int]] | None = field.validation_alias.convert_to_aliases()
|
||||
else:
|
||||
v_alias = field.validation_alias
|
||||
|
||||
if v_alias:
|
||||
env_prefix = self.env_prefix if self.env_prefix_target in ('alias', 'all') else ''
|
||||
if isinstance(v_alias, list): # AliasChoices, AliasPath
|
||||
for alias in v_alias:
|
||||
if isinstance(alias, str): # AliasPath
|
||||
field_info.append(
|
||||
(alias, self._apply_case_sensitive(env_prefix + alias), True if len(alias) > 1 else False)
|
||||
)
|
||||
elif isinstance(alias, list): # AliasChoices
|
||||
first_arg = cast(str, alias[0]) # first item of an AliasChoices must be a str
|
||||
field_info.append(
|
||||
(
|
||||
first_arg,
|
||||
self._apply_case_sensitive(env_prefix + first_arg),
|
||||
True if len(alias) > 1 else False,
|
||||
)
|
||||
)
|
||||
else: # string validation alias
|
||||
field_info.append((v_alias, self._apply_case_sensitive(env_prefix + v_alias), False))
|
||||
|
||||
if not v_alias or self.config.get('populate_by_name', False) or self.config.get('validate_by_name', False):
|
||||
annotation = _strip_annotated(_resolve_type_alias(field.annotation))
|
||||
env_prefix = self.env_prefix if self.env_prefix_target in ('variable', 'all') else ''
|
||||
if is_union_origin(get_origin(annotation)) and _union_is_complex(annotation, field.metadata):
|
||||
field_info.append((field_name, self._apply_case_sensitive(env_prefix + field_name), True))
|
||||
else:
|
||||
field_info.append((field_name, self._apply_case_sensitive(env_prefix + field_name), False))
|
||||
|
||||
return field_info
|
||||
|
||||
def _replace_field_names_case_insensitively(self, field: FieldInfo, field_values: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace field names in values dict by looking in models fields insensitively.
|
||||
|
||||
By having the following models:
|
||||
|
||||
```py
|
||||
class SubSubSub(BaseModel):
|
||||
VaL3: str
|
||||
|
||||
class SubSub(BaseModel):
|
||||
Val2: str
|
||||
SUB_sub_SuB: SubSubSub
|
||||
|
||||
class Sub(BaseModel):
|
||||
VAL1: str
|
||||
SUB_sub: SubSub
|
||||
|
||||
class Settings(BaseSettings):
|
||||
nested: Sub
|
||||
|
||||
model_config = SettingsConfigDict(env_nested_delimiter='__')
|
||||
```
|
||||
|
||||
Then:
|
||||
_replace_field_names_case_insensitively(
|
||||
field,
|
||||
{"val1": "v1", "sub_SUB": {"VAL2": "v2", "sub_SUB_sUb": {"vAl3": "v3"}}}
|
||||
)
|
||||
Returns {'VAL1': 'v1', 'SUB_sub': {'Val2': 'v2', 'SUB_sub_SuB': {'VaL3': 'v3'}}}
|
||||
"""
|
||||
values: dict[str, Any] = {}
|
||||
|
||||
for name, value in field_values.items():
|
||||
sub_model_field: FieldInfo | None = None
|
||||
|
||||
annotation = field.annotation
|
||||
|
||||
# If field is Optional, we need to find the actual type
|
||||
if is_union_origin(get_origin(field.annotation)):
|
||||
args = get_args(annotation)
|
||||
if len(args) == 2 and type(None) in args:
|
||||
for arg in args:
|
||||
if arg is not None:
|
||||
annotation = arg
|
||||
break
|
||||
|
||||
# This is here to make mypy happy
|
||||
# Item "None" of "Optional[Type[Any]]" has no attribute "model_fields"
|
||||
if not annotation or not hasattr(annotation, 'model_fields'):
|
||||
values[name] = value
|
||||
continue
|
||||
else:
|
||||
model_fields: dict[str, FieldInfo] = annotation.model_fields
|
||||
|
||||
# Find field in sub model by looking in fields case insensitively
|
||||
field_key: str | None = None
|
||||
for sub_model_field_name, sub_model_field in model_fields.items():
|
||||
aliases, _ = _get_alias_names(sub_model_field_name, sub_model_field)
|
||||
_search = (alias for alias in aliases if alias.lower() == name.lower())
|
||||
if field_key := next(_search, None):
|
||||
break
|
||||
|
||||
if not field_key:
|
||||
values[name] = value
|
||||
continue
|
||||
|
||||
if (
|
||||
sub_model_field is not None
|
||||
and _lenient_issubclass(sub_model_field.annotation, BaseModel)
|
||||
and isinstance(value, dict)
|
||||
):
|
||||
values[field_key] = self._replace_field_names_case_insensitively(sub_model_field, value)
|
||||
else:
|
||||
values[field_key] = value
|
||||
|
||||
return values
|
||||
|
||||
def _replace_env_none_type_values(self, field_value: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Recursively parse values that are of "None" type(EnvNoneType) to `None` type(None).
|
||||
"""
|
||||
values: dict[str, Any] = {}
|
||||
|
||||
for key, value in field_value.items():
|
||||
if not isinstance(value, EnvNoneType):
|
||||
values[key] = value if not isinstance(value, dict) else self._replace_env_none_type_values(value)
|
||||
else:
|
||||
values[key] = None
|
||||
|
||||
return values
|
||||
|
||||
def _get_resolved_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value, the preferred alias key for model creation, and a flag to determine whether value
|
||||
is complex.
|
||||
|
||||
Note:
|
||||
In V3, this method should either be made public, or, this method should be removed and the
|
||||
abstract method get_field_value should be updated to include a "use_preferred_alias" flag.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value, preferred key and a flag to determine whether value is complex.
|
||||
"""
|
||||
field_value, field_key, value_is_complex = self.get_field_value(field, field_name)
|
||||
# Only use preferred_key when no value was found; otherwise preserve the key that matched
|
||||
if field_value is None and not (
|
||||
value_is_complex
|
||||
or (
|
||||
(self.config.get('populate_by_name', False) or self.config.get('validate_by_name', False))
|
||||
and (field_key == field_name)
|
||||
)
|
||||
):
|
||||
field_infos = self._extract_field_info(field, field_name)
|
||||
preferred_key, *_ = field_infos[0]
|
||||
return field_value, preferred_key, value_is_complex
|
||||
return field_value, field_key, value_is_complex
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
data: dict[str, Any] = {}
|
||||
|
||||
for field_name, field in self.settings_cls.model_fields.items():
|
||||
try:
|
||||
field_value, field_key, value_is_complex = self._get_resolved_field_value(field, field_name)
|
||||
except Exception as e:
|
||||
raise SettingsError(
|
||||
f'error getting value for field "{field_name}" from source "{self.__class__.__name__}"'
|
||||
) from e
|
||||
|
||||
try:
|
||||
field_value = self.prepare_field_value(field_name, field, field_value, value_is_complex)
|
||||
except ValueError as e:
|
||||
raise SettingsError(
|
||||
f'error parsing value for field "{field_name}" from source "{self.__class__.__name__}"'
|
||||
) from e
|
||||
|
||||
if field_value is not None:
|
||||
if self.env_parse_none_str is not None:
|
||||
if isinstance(field_value, dict):
|
||||
field_value = self._replace_env_none_type_values(field_value)
|
||||
elif isinstance(field_value, EnvNoneType):
|
||||
field_value = None
|
||||
if (
|
||||
not self.case_sensitive
|
||||
# and _lenient_issubclass(field.annotation, BaseModel)
|
||||
and isinstance(field_value, dict)
|
||||
):
|
||||
data[field_key] = self._replace_field_names_case_insensitively(field, field_value)
|
||||
else:
|
||||
data[field_key] = field_value
|
||||
|
||||
return data
|
||||
|
||||
|
||||
__all__ = [
|
||||
'ConfigFileSourceMixin',
|
||||
'DefaultSettingsSource',
|
||||
'InitSettingsSource',
|
||||
'PydanticBaseEnvSettingsSource',
|
||||
'PydanticBaseSettingsSource',
|
||||
'SettingsError',
|
||||
]
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Package containing individual source implementations."""
|
||||
|
||||
from .aws import AWSSecretsManagerSettingsSource
|
||||
from .azure import AzureKeyVaultSettingsSource
|
||||
from .cli import (
|
||||
CliDualFlag,
|
||||
CliExplicitFlag,
|
||||
CliImplicitFlag,
|
||||
CliMutuallyExclusiveGroup,
|
||||
CliPositionalArg,
|
||||
CliSettingsSource,
|
||||
CliSubCommand,
|
||||
CliSuppress,
|
||||
CliToggleFlag,
|
||||
)
|
||||
from .dotenv import DotEnvSettingsSource
|
||||
from .env import EnvSettingsSource
|
||||
from .gcp import GoogleSecretManagerSettingsSource
|
||||
from .json import JsonConfigSettingsSource
|
||||
from .pyproject import PyprojectTomlConfigSettingsSource
|
||||
from .secrets import SecretsSettingsSource
|
||||
from .toml import TomlConfigSettingsSource
|
||||
from .yaml import YamlConfigSettingsSource
|
||||
|
||||
__all__ = [
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
'AzureKeyVaultSettingsSource',
|
||||
'CliExplicitFlag',
|
||||
'CliImplicitFlag',
|
||||
'CliToggleFlag',
|
||||
'CliDualFlag',
|
||||
'CliMutuallyExclusiveGroup',
|
||||
'CliPositionalArg',
|
||||
'CliSettingsSource',
|
||||
'CliSubCommand',
|
||||
'CliSuppress',
|
||||
'DotEnvSettingsSource',
|
||||
'EnvSettingsSource',
|
||||
'GoogleSecretManagerSettingsSource',
|
||||
'JsonConfigSettingsSource',
|
||||
'PyprojectTomlConfigSettingsSource',
|
||||
'SecretsSettingsSource',
|
||||
'TomlConfigSettingsSource',
|
||||
'YamlConfigSettingsSource',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,86 @@
|
||||
from __future__ import annotations as _annotations # important for BaseSettings import to work
|
||||
|
||||
import json
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..utils import parse_env_vars
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
boto3_client = None
|
||||
SecretsManagerClient = None
|
||||
|
||||
|
||||
def import_aws_secrets_manager() -> None:
|
||||
global boto3_client
|
||||
global SecretsManagerClient
|
||||
|
||||
try:
|
||||
from boto3 import client as boto3_client
|
||||
from mypy_boto3_secretsmanager.client import SecretsManagerClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'AWS Secrets Manager dependencies are not installed, run `pip install pydantic-settings[aws-secrets-manager]`'
|
||||
) from e
|
||||
|
||||
|
||||
class AWSSecretsManagerSettingsSource(EnvSettingsSource):
|
||||
_secret_id: str
|
||||
_secretsmanager_client: SecretsManagerClient # type: ignore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
secret_id: str,
|
||||
region_name: str | None = None,
|
||||
endpoint_url: str | None = None,
|
||||
case_sensitive: bool | None = True,
|
||||
env_prefix: str | None = None,
|
||||
env_nested_delimiter: str | None = '--',
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
version_id: str | None = None,
|
||||
) -> None:
|
||||
import_aws_secrets_manager()
|
||||
self._secretsmanager_client = boto3_client('secretsmanager', region_name=region_name, endpoint_url=endpoint_url) # type: ignore
|
||||
self._secret_id = secret_id
|
||||
self._version_id = version_id
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter=env_nested_delimiter,
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
request = {'SecretId': self._secret_id}
|
||||
|
||||
if self._version_id:
|
||||
request['VersionId'] = self._version_id
|
||||
|
||||
response = self._secretsmanager_client.get_secret_value(**request) # type: ignore
|
||||
|
||||
return parse_env_vars(
|
||||
json.loads(response['SecretString']),
|
||||
self.case_sensitive,
|
||||
self.env_ignore_empty,
|
||||
self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(secret_id={self._secret_id!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'AWSSecretsManagerSettingsSource',
|
||||
]
|
||||
@@ -0,0 +1,159 @@
|
||||
"""Azure Key Vault settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Iterator, Mapping
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic.alias_generators import to_snake
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
TokenCredential = None
|
||||
ResourceNotFoundError = None
|
||||
SecretClient = None
|
||||
|
||||
|
||||
def import_azure_key_vault() -> None:
|
||||
global TokenCredential
|
||||
global SecretClient
|
||||
global ResourceNotFoundError
|
||||
|
||||
try:
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.exceptions import ResourceNotFoundError
|
||||
from azure.keyvault.secrets import SecretClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'Azure Key Vault dependencies are not installed, run `pip install pydantic-settings[azure-key-vault]`'
|
||||
) from e
|
||||
|
||||
|
||||
class AzureKeyVaultMapping(Mapping[str, str | None]):
|
||||
_loaded_secrets: dict[str, str | None]
|
||||
_secret_client: SecretClient
|
||||
_secret_names: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
secret_client: SecretClient,
|
||||
case_sensitive: bool,
|
||||
snake_case_conversion: bool,
|
||||
env_prefix: str | None,
|
||||
) -> None:
|
||||
self._loaded_secrets = {}
|
||||
self._secret_client = secret_client
|
||||
self._case_sensitive = case_sensitive
|
||||
self._snake_case_conversion = snake_case_conversion
|
||||
self._env_prefix = env_prefix if env_prefix else ''
|
||||
self._secret_map: dict[str, str] = self._load_remote()
|
||||
|
||||
def _load_remote(self) -> dict[str, str]:
|
||||
secret_names: Iterator[str] = (
|
||||
secret.name for secret in self._secret_client.list_properties_of_secrets() if secret.name and secret.enabled
|
||||
)
|
||||
|
||||
if self._snake_case_conversion:
|
||||
name_map: dict[str, str] = {}
|
||||
for name in secret_names:
|
||||
if name.startswith(self._env_prefix):
|
||||
name_map[f'{self._env_prefix}{to_snake(name[len(self._env_prefix) :])}'] = name
|
||||
else:
|
||||
name_map[to_snake(name)] = name
|
||||
return name_map
|
||||
|
||||
if self._case_sensitive:
|
||||
return {name: name for name in secret_names}
|
||||
|
||||
return {name.lower(): name for name in secret_names}
|
||||
|
||||
def __getitem__(self, key: str) -> str | None:
|
||||
new_key = key
|
||||
|
||||
if self._snake_case_conversion:
|
||||
if key.startswith(self._env_prefix):
|
||||
new_key = f'{self._env_prefix}{to_snake(key[len(self._env_prefix) :])}'
|
||||
else:
|
||||
new_key = to_snake(key)
|
||||
|
||||
elif not self._case_sensitive:
|
||||
new_key = key.lower()
|
||||
|
||||
if new_key not in self._loaded_secrets:
|
||||
if new_key in self._secret_map:
|
||||
self._loaded_secrets[new_key] = self._secret_client.get_secret(self._secret_map[new_key]).value
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._loaded_secrets[new_key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._secret_map)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._secret_map.keys())
|
||||
|
||||
|
||||
class AzureKeyVaultSettingsSource(EnvSettingsSource):
|
||||
_url: str
|
||||
_credential: TokenCredential
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
url: str,
|
||||
credential: TokenCredential,
|
||||
dash_to_underscore: bool = False,
|
||||
case_sensitive: bool | None = None,
|
||||
snake_case_conversion: bool = False,
|
||||
env_prefix: str | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
import_azure_key_vault()
|
||||
self._url = url
|
||||
self._credential = credential
|
||||
self._dash_to_underscore = dash_to_underscore
|
||||
self._snake_case_conversion = snake_case_conversion
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=True if snake_case_conversion else case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_nested_delimiter='__' if snake_case_conversion else '--',
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
secret_client = SecretClient(vault_url=self._url, credential=self._credential)
|
||||
return AzureKeyVaultMapping(
|
||||
secret_client=secret_client,
|
||||
case_sensitive=self.case_sensitive,
|
||||
snake_case_conversion=self._snake_case_conversion,
|
||||
env_prefix=self.env_prefix,
|
||||
)
|
||||
|
||||
def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[str, str, bool]]:
|
||||
if self._snake_case_conversion:
|
||||
field_info = list((x[0], x[1], x[2]) for x in super()._extract_field_info(field, field_name))
|
||||
return field_info
|
||||
|
||||
if self._dash_to_underscore:
|
||||
return list((x[0], x[1].replace('_', '-'), x[2]) for x in super()._extract_field_info(field, field_name))
|
||||
|
||||
return super()._extract_field_info(field, field_name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(url={self._url!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
__all__ = ['AzureKeyVaultMapping', 'AzureKeyVaultSettingsSource']
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,170 @@
|
||||
"""Dotenv file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from dotenv import dotenv_values
|
||||
from pydantic._internal._typing_extra import ( # type: ignore[attr-defined]
|
||||
get_origin,
|
||||
)
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ..types import ENV_FILE_SENTINEL, DotenvType, EnvPrefixTarget
|
||||
from ..utils import (
|
||||
_annotation_is_complex,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class DotEnvSettingsSource(EnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from env files.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
env_file: DotenvType | None = ENV_FILE_SENTINEL,
|
||||
env_file_encoding: str | None = None,
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_prefix_target: EnvPrefixTarget | None = None,
|
||||
env_nested_delimiter: str | None = None,
|
||||
env_nested_max_split: int | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
self.env_file = env_file if env_file != ENV_FILE_SENTINEL else settings_cls.model_config.get('env_file')
|
||||
self.env_file_encoding = (
|
||||
env_file_encoding if env_file_encoding is not None else settings_cls.model_config.get('env_file_encoding')
|
||||
)
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive,
|
||||
env_prefix,
|
||||
env_prefix_target,
|
||||
env_nested_delimiter,
|
||||
env_nested_max_split,
|
||||
env_ignore_empty,
|
||||
env_parse_none_str,
|
||||
env_parse_enums,
|
||||
)
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return self._read_env_files()
|
||||
|
||||
@staticmethod
|
||||
def _static_read_env_file(
|
||||
file_path: Path,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
file_vars: dict[str, str | None] = dotenv_values(file_path, encoding=encoding or 'utf8')
|
||||
return parse_env_vars(file_vars, case_sensitive, ignore_empty, parse_none_str)
|
||||
|
||||
def _read_env_file(
|
||||
self,
|
||||
file_path: Path,
|
||||
) -> Mapping[str, str | None]:
|
||||
return self._static_read_env_file(
|
||||
file_path,
|
||||
encoding=self.env_file_encoding,
|
||||
case_sensitive=self.case_sensitive,
|
||||
ignore_empty=self.env_ignore_empty,
|
||||
parse_none_str=self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def _read_env_files(self) -> Mapping[str, str | None]:
|
||||
env_files = self.env_file
|
||||
if env_files is None:
|
||||
return {}
|
||||
|
||||
if isinstance(env_files, (str, os.PathLike)):
|
||||
env_files = [env_files]
|
||||
|
||||
dotenv_vars: dict[str, str | None] = {}
|
||||
for env_file in env_files:
|
||||
env_path = Path(env_file).expanduser()
|
||||
if env_path.is_file():
|
||||
dotenv_vars.update(self._read_env_file(env_path))
|
||||
|
||||
return dotenv_vars
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
data: dict[str, Any] = super().__call__()
|
||||
is_extra_allowed = self.config.get('extra') != 'forbid'
|
||||
|
||||
# As `extra` config is allowed in dotenv settings source, We have to
|
||||
# update data with extra env variables from dotenv file.
|
||||
for env_name, env_value in self.env_vars.items():
|
||||
if not env_value or env_name in data or (self.env_prefix and env_name in self.settings_cls.model_fields):
|
||||
continue
|
||||
env_used = False
|
||||
for field_name, field in self.settings_cls.model_fields.items():
|
||||
for _, field_env_name, _ in self._extract_field_info(field, field_name):
|
||||
if env_name == field_env_name or (
|
||||
(
|
||||
_annotation_is_complex(field.annotation, field.metadata)
|
||||
or (
|
||||
is_union_origin(get_origin(field.annotation))
|
||||
and _union_is_complex(field.annotation, field.metadata)
|
||||
)
|
||||
)
|
||||
and env_name.startswith(field_env_name)
|
||||
):
|
||||
env_used = True
|
||||
break
|
||||
if env_used:
|
||||
break
|
||||
if not env_used:
|
||||
if is_extra_allowed and env_name.startswith(self.env_prefix):
|
||||
# env_prefix should be respected and removed from the env_name
|
||||
normalized_env_name = env_name[len(self.env_prefix) :]
|
||||
data[normalized_env_name] = env_value
|
||||
else:
|
||||
data[env_name] = env_value
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(env_file={self.env_file!r}, env_file_encoding={self.env_file_encoding!r}, '
|
||||
f'env_nested_delimiter={self.env_nested_delimiter!r}, env_prefix_len={self.env_prefix_len!r})'
|
||||
)
|
||||
|
||||
|
||||
def read_env_file(
|
||||
file_path: Path,
|
||||
*,
|
||||
encoding: str | None = None,
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
warnings.warn(
|
||||
'read_env_file will be removed in the next version, use DotEnvSettingsSource._static_read_env_file if you must',
|
||||
DeprecationWarning,
|
||||
)
|
||||
return DotEnvSettingsSource._static_read_env_file(
|
||||
file_path,
|
||||
encoding=encoding,
|
||||
case_sensitive=case_sensitive,
|
||||
ignore_empty=ignore_empty,
|
||||
parse_none_str=parse_none_str,
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['DotEnvSettingsSource', 'read_env_file']
|
||||
@@ -0,0 +1,310 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
from collections.abc import Mapping
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from pydantic import Json, TypeAdapter, ValidationError
|
||||
from pydantic._internal._utils import deep_update, is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_inspection.introspection import is_union_origin
|
||||
|
||||
from ...utils import _lenient_issubclass
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import EnvNoneType, EnvPrefixTarget
|
||||
from ..utils import (
|
||||
_annotation_contains_types,
|
||||
_annotation_enum_name_to_val,
|
||||
_annotation_is_complex,
|
||||
_get_model_fields,
|
||||
_union_is_complex,
|
||||
parse_env_vars,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class EnvSettingsSource(PydanticBaseEnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from environment variables.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_prefix_target: EnvPrefixTarget | None = None,
|
||||
env_nested_delimiter: str | None = None,
|
||||
env_nested_max_split: int | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive,
|
||||
env_prefix,
|
||||
env_prefix_target,
|
||||
env_ignore_empty,
|
||||
env_parse_none_str,
|
||||
env_parse_enums,
|
||||
)
|
||||
self.env_nested_delimiter = (
|
||||
env_nested_delimiter if env_nested_delimiter is not None else self.config.get('env_nested_delimiter')
|
||||
)
|
||||
self.env_nested_max_split = (
|
||||
env_nested_max_split if env_nested_max_split is not None else self.config.get('env_nested_max_split')
|
||||
)
|
||||
self.maxsplit = (self.env_nested_max_split or 0) - 1
|
||||
self.env_prefix_len = len(self.env_prefix)
|
||||
|
||||
self.env_vars = self._load_env_vars()
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return parse_env_vars(os.environ, self.case_sensitive, self.env_ignore_empty, self.env_parse_none_str)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value for field from environment variables and a flag to determine whether value is complex.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value (`None` if not found), key, and
|
||||
a flag to determine whether value is complex.
|
||||
"""
|
||||
|
||||
env_val: str | None = None
|
||||
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
|
||||
env_val = self.env_vars.get(env_name)
|
||||
if env_val is not None:
|
||||
break
|
||||
|
||||
return env_val, field_key, value_is_complex
|
||||
|
||||
def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
|
||||
"""
|
||||
Prepare value for the field.
|
||||
|
||||
* Extract value for nested field.
|
||||
* Deserialize value to python object for complex field.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple contains prepared value for the field.
|
||||
|
||||
Raises:
|
||||
ValuesError: When There is an error in deserializing value for complex field.
|
||||
"""
|
||||
is_complex, allow_parse_failure = self._field_is_complex(field)
|
||||
if self.env_parse_enums:
|
||||
enum_val = _annotation_enum_name_to_val(field.annotation, value)
|
||||
value = value if enum_val is None else enum_val
|
||||
|
||||
if is_complex or value_is_complex:
|
||||
if isinstance(value, EnvNoneType):
|
||||
return value
|
||||
elif value is None:
|
||||
# field is complex but no value found so far, try explode_env_vars
|
||||
env_val_built = self.explode_env_vars(field_name, field, self.env_vars)
|
||||
if env_val_built:
|
||||
return env_val_built
|
||||
else:
|
||||
# field is complex and there's a value, decode that as JSON, then add explode_env_vars
|
||||
try:
|
||||
value = self.decode_complex_value(field_name, field, value)
|
||||
except ValueError as e:
|
||||
if not allow_parse_failure:
|
||||
raise e
|
||||
|
||||
if isinstance(value, dict):
|
||||
return deep_update(value, self.explode_env_vars(field_name, field, self.env_vars))
|
||||
else:
|
||||
return value
|
||||
elif value is not None:
|
||||
# simplest case, field is not complex, we only need to add the value if it was found
|
||||
return self._coerce_env_val_strict(field, value)
|
||||
|
||||
def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]:
|
||||
"""
|
||||
Find out if a field is complex, and if so whether JSON errors should be ignored
|
||||
"""
|
||||
if self.field_is_complex(field):
|
||||
allow_parse_failure = False
|
||||
elif is_union_origin(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata):
|
||||
allow_parse_failure = True
|
||||
else:
|
||||
return False, False
|
||||
|
||||
return True, allow_parse_failure
|
||||
|
||||
# Default value of `case_sensitive` is `None`, because we don't want to break existing behavior.
|
||||
# We have to change the method to a non-static method and use
|
||||
# `self.case_sensitive` instead in V3.
|
||||
def next_field(
|
||||
self, field: FieldInfo | Any | None, key: str, case_sensitive: bool | None = None
|
||||
) -> FieldInfo | None:
|
||||
"""
|
||||
Find the field in a sub model by key(env name)
|
||||
|
||||
By having the following models:
|
||||
|
||||
```py
|
||||
class SubSubModel(BaseSettings):
|
||||
dvals: Dict
|
||||
|
||||
class SubModel(BaseSettings):
|
||||
vals: list[str]
|
||||
sub_sub_model: SubSubModel
|
||||
|
||||
class Cfg(BaseSettings):
|
||||
sub_model: SubModel
|
||||
```
|
||||
|
||||
Then:
|
||||
next_field(sub_model, 'vals') Returns the `vals` field of `SubModel` class
|
||||
next_field(sub_model, 'sub_sub_model') Returns `sub_sub_model` field of `SubModel` class
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
key: The key (env name).
|
||||
case_sensitive: Whether to search for key case sensitively.
|
||||
|
||||
Returns:
|
||||
Field if it finds the next field otherwise `None`.
|
||||
"""
|
||||
if not field:
|
||||
return None
|
||||
|
||||
annotation = field.annotation if isinstance(field, FieldInfo) else field
|
||||
for type_ in get_args(annotation):
|
||||
type_has_key = self.next_field(type_, key, case_sensitive)
|
||||
if type_has_key:
|
||||
return type_has_key
|
||||
if _lenient_issubclass(get_origin(annotation), dict):
|
||||
# get value type if it's a dict
|
||||
return get_args(annotation)[-1]
|
||||
elif is_model_class(annotation) or is_pydantic_dataclass(annotation): # type: ignore[arg-type]
|
||||
fields = _get_model_fields(annotation)
|
||||
# `case_sensitive is None` is here to be compatible with the old behavior.
|
||||
# Has to be removed in V3.
|
||||
for field_name, f in fields.items():
|
||||
for _, env_name, _ in self._extract_field_info(f, field_name):
|
||||
if case_sensitive is None or case_sensitive:
|
||||
if field_name == key or env_name == key:
|
||||
return f
|
||||
elif field_name.lower() == key.lower() or env_name.lower() == key.lower():
|
||||
return f
|
||||
return None
|
||||
|
||||
def explode_env_vars(self, field_name: str, field: FieldInfo, env_vars: Mapping[str, str | None]) -> dict[str, Any]: # noqa: C901
|
||||
"""
|
||||
Process env_vars and extract the values of keys containing env_nested_delimiter into nested dictionaries.
|
||||
|
||||
This is applied to a single field, hence filtering by env_var prefix.
|
||||
|
||||
Args:
|
||||
field_name: The field name.
|
||||
field: The field.
|
||||
env_vars: Environment variables.
|
||||
|
||||
Returns:
|
||||
A dictionary contains extracted values from nested env values.
|
||||
"""
|
||||
if not self.env_nested_delimiter:
|
||||
return {}
|
||||
|
||||
ann = field.annotation
|
||||
is_dict = ann is dict or _lenient_issubclass(get_origin(ann), dict)
|
||||
|
||||
prefixes = [
|
||||
f'{env_name}{self.env_nested_delimiter}' for _, env_name, _ in self._extract_field_info(field, field_name)
|
||||
]
|
||||
result: dict[str, Any] = {}
|
||||
for env_name, env_val in env_vars.items():
|
||||
try:
|
||||
prefix = next(prefix for prefix in prefixes if env_name.startswith(prefix))
|
||||
except StopIteration:
|
||||
continue
|
||||
# we remove the prefix before splitting in case the prefix has characters in common with the delimiter
|
||||
env_name_without_prefix = env_name[len(prefix) :]
|
||||
*keys, last_key = env_name_without_prefix.split(self.env_nested_delimiter, self.maxsplit)
|
||||
env_var = result
|
||||
target_field: FieldInfo | None = field
|
||||
for key in keys:
|
||||
target_field = self.next_field(target_field, key, self.case_sensitive)
|
||||
if isinstance(env_var, dict):
|
||||
env_var = env_var.setdefault(key, {})
|
||||
|
||||
# get proper field with last_key
|
||||
target_field = self.next_field(target_field, last_key, self.case_sensitive)
|
||||
|
||||
# check if env_val maps to a complex field and if so, parse the env_val
|
||||
if (target_field or is_dict) and env_val:
|
||||
if isinstance(target_field, FieldInfo):
|
||||
is_complex, allow_json_failure = self._field_is_complex(target_field)
|
||||
if self.env_parse_enums:
|
||||
enum_val = _annotation_enum_name_to_val(target_field.annotation, env_val)
|
||||
env_val = env_val if enum_val is None else enum_val
|
||||
elif target_field:
|
||||
# target_field is a raw type (e.g. from dict value type annotation)
|
||||
is_complex = _annotation_is_complex(target_field, [])
|
||||
allow_json_failure = True
|
||||
else:
|
||||
# nested field type is dict
|
||||
is_complex, allow_json_failure = True, True
|
||||
if is_complex:
|
||||
try:
|
||||
field_info = target_field if isinstance(target_field, FieldInfo) else None
|
||||
env_val = self.decode_complex_value(last_key, field_info, env_val) # type: ignore
|
||||
except ValueError as e:
|
||||
if not allow_json_failure:
|
||||
raise e
|
||||
if isinstance(env_var, dict):
|
||||
if last_key not in env_var or not isinstance(env_val, EnvNoneType) or env_var[last_key] == {}:
|
||||
env_var[last_key] = self._coerce_env_val_strict(target_field, env_val)
|
||||
return result
|
||||
|
||||
def _coerce_env_val_strict(self, field: FieldInfo | None, value: Any) -> Any:
|
||||
"""
|
||||
Coerce environment string values based on field annotation if model config is `strict=True`.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
value: The value to coerce.
|
||||
|
||||
Returns:
|
||||
The coerced value if successful, otherwise the original value.
|
||||
"""
|
||||
try:
|
||||
if self.config.get('strict') and isinstance(value, str) and field is not None:
|
||||
if value == self.env_parse_none_str:
|
||||
return value
|
||||
if not _annotation_contains_types(field.annotation, (Json,), is_instance=True):
|
||||
return TypeAdapter(field.annotation).validate_python(value)
|
||||
except ValidationError:
|
||||
# Allow validation error to be raised at time of instatiation
|
||||
pass
|
||||
return value
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f'{self.__class__.__name__}(env_nested_delimiter={self.env_nested_delimiter!r}, '
|
||||
f'env_prefix_len={self.env_prefix_len!r})'
|
||||
)
|
||||
|
||||
|
||||
__all__ = ['EnvSettingsSource']
|
||||
@@ -0,0 +1,241 @@
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import warnings
|
||||
from collections.abc import Iterator, Mapping
|
||||
from functools import cached_property
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from ..types import SecretVersion
|
||||
from .env import EnvSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.auth import default as google_auth_default
|
||||
from google.auth.credentials import Credentials
|
||||
from google.cloud.secretmanager import SecretManagerServiceClient
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
Credentials = None
|
||||
SecretManagerServiceClient = None
|
||||
google_auth_default = None
|
||||
|
||||
|
||||
def import_gcp_secret_manager() -> None:
|
||||
global Credentials
|
||||
global SecretManagerServiceClient
|
||||
global google_auth_default
|
||||
|
||||
try:
|
||||
from google.auth import default as google_auth_default
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=FutureWarning)
|
||||
from google.cloud.secretmanager import SecretManagerServiceClient
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError(
|
||||
'GCP Secret Manager dependencies are not installed, run `pip install pydantic-settings[gcp-secret-manager]`'
|
||||
) from e
|
||||
|
||||
|
||||
class GoogleSecretManagerMapping(Mapping[str, str | None]):
|
||||
_loaded_secrets: dict[str, str | None]
|
||||
_secret_client: SecretManagerServiceClient
|
||||
|
||||
def __init__(self, secret_client: SecretManagerServiceClient, project_id: str, case_sensitive: bool) -> None:
|
||||
self._loaded_secrets = {}
|
||||
self._secret_client = secret_client
|
||||
self._project_id = project_id
|
||||
self._case_sensitive = case_sensitive
|
||||
|
||||
@property
|
||||
def _gcp_project_path(self) -> str:
|
||||
return self._secret_client.common_project_path(self._project_id)
|
||||
|
||||
def _select_case_insensitive_secret(self, lower_name: str, candidates: list[str]) -> str:
|
||||
if len(candidates) == 1:
|
||||
return candidates[0]
|
||||
|
||||
# Sort to ensure deterministic selection (prefer lowercase / ASCII last)
|
||||
candidates.sort()
|
||||
winner = candidates[-1]
|
||||
warnings.warn(
|
||||
f"Secret collision: Found multiple secrets {candidates} normalizing to '{lower_name}'. "
|
||||
f"Using '{winner}' for case-insensitive lookup.",
|
||||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return winner
|
||||
|
||||
@cached_property
|
||||
def _secret_name_map(self) -> dict[str, str]:
|
||||
mapping: dict[str, str] = {}
|
||||
# Group secrets by normalized name to detect collisions
|
||||
normalized_groups: dict[str, list[str]] = {}
|
||||
|
||||
secrets = self._secret_client.list_secrets(parent=self._gcp_project_path)
|
||||
for secret in secrets:
|
||||
name = self._secret_client.parse_secret_path(secret.name).get('secret', '')
|
||||
mapping[name] = name
|
||||
|
||||
if not self._case_sensitive:
|
||||
lower_name = name.lower()
|
||||
if lower_name not in normalized_groups:
|
||||
normalized_groups[lower_name] = []
|
||||
normalized_groups[lower_name].append(name)
|
||||
|
||||
if not self._case_sensitive:
|
||||
for lower_name, candidates in normalized_groups.items():
|
||||
mapping[lower_name] = self._select_case_insensitive_secret(lower_name, candidates)
|
||||
|
||||
return mapping
|
||||
|
||||
@property
|
||||
def _secret_names(self) -> list[str]:
|
||||
return list(self._secret_name_map.keys())
|
||||
|
||||
def _secret_version_path(self, key: str, version: str = 'latest') -> str:
|
||||
return self._secret_client.secret_version_path(self._project_id, key, version)
|
||||
|
||||
def _get_secret_value(self, gcp_secret_name: str, version: str = 'latest') -> str | None:
|
||||
try:
|
||||
return self._secret_client.access_secret_version(
|
||||
name=self._secret_version_path(gcp_secret_name, version)
|
||||
).payload.data.decode('UTF-8')
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def __getitem__(self, key: str) -> str | None:
|
||||
if key in self._loaded_secrets:
|
||||
return self._loaded_secrets[key]
|
||||
|
||||
gcp_secret_name = self._secret_name_map.get(key)
|
||||
if gcp_secret_name is None and not self._case_sensitive:
|
||||
gcp_secret_name = self._secret_name_map.get(key.lower())
|
||||
|
||||
if gcp_secret_name:
|
||||
self._loaded_secrets[key] = self._get_secret_value(gcp_secret_name)
|
||||
else:
|
||||
raise KeyError(key)
|
||||
|
||||
return self._loaded_secrets[key]
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._secret_names)
|
||||
|
||||
def __iter__(self) -> Iterator[str]:
|
||||
return iter(self._secret_names)
|
||||
|
||||
|
||||
class GoogleSecretManagerSettingsSource(EnvSettingsSource):
|
||||
_credentials: Credentials
|
||||
_secret_client: SecretManagerServiceClient
|
||||
_project_id: str
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
credentials: Credentials | None = None,
|
||||
project_id: str | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
secret_client: SecretManagerServiceClient | None = None,
|
||||
case_sensitive: bool | None = True,
|
||||
) -> None:
|
||||
# Import Google Packages if they haven't already been imported
|
||||
if SecretManagerServiceClient is None or Credentials is None or google_auth_default is None:
|
||||
import_gcp_secret_manager()
|
||||
|
||||
# If credentials or project_id are not passed, then
|
||||
# try to get them from the default function
|
||||
if not credentials or not project_id:
|
||||
_creds, _project_id = google_auth_default()
|
||||
|
||||
# Set the credentials and/or project id if they weren't specified
|
||||
if credentials is None:
|
||||
credentials = _creds
|
||||
|
||||
if project_id is None:
|
||||
if isinstance(_project_id, str):
|
||||
project_id = _project_id
|
||||
else:
|
||||
raise AttributeError(
|
||||
'project_id is required to be specified either as an argument or from the google.auth.default. See https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default'
|
||||
)
|
||||
|
||||
self._credentials: Credentials = credentials
|
||||
self._project_id: str = project_id
|
||||
|
||||
if secret_client:
|
||||
self._secret_client = secret_client
|
||||
else:
|
||||
self._secret_client = SecretManagerServiceClient(credentials=self._credentials)
|
||||
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=case_sensitive,
|
||||
env_prefix=env_prefix,
|
||||
env_ignore_empty=False,
|
||||
env_parse_none_str=env_parse_none_str,
|
||||
env_parse_enums=env_parse_enums,
|
||||
)
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""Override get_field_value to get the secret value from GCP Secret Manager.
|
||||
Look for a SecretVersion metadata field to specify a particular SecretVersion.
|
||||
|
||||
Args:
|
||||
field: The field to get the value for
|
||||
field_name: The declared name of the field
|
||||
|
||||
Returns:
|
||||
A tuple of (value, key, value_is_complex), where `key` is the identifier used
|
||||
to populate the model (either the field name or an alias, depending on
|
||||
configuration).
|
||||
"""
|
||||
|
||||
secret_version = next((m.version for m in field.metadata if isinstance(m, SecretVersion)), None)
|
||||
|
||||
# If a secret version is specified, try to get that specific version of the secret from
|
||||
# GCP Secret Manager via the GoogleSecretManagerMapping. This allows different versions
|
||||
# of the same secret name to be retrieved independently and cached in the GoogleSecretManagerMapping
|
||||
if secret_version and isinstance(self.env_vars, GoogleSecretManagerMapping):
|
||||
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
|
||||
gcp_secret_name = self.env_vars._secret_name_map.get(env_name)
|
||||
if gcp_secret_name is None and not self.case_sensitive:
|
||||
gcp_secret_name = self.env_vars._secret_name_map.get(env_name.lower())
|
||||
|
||||
if gcp_secret_name:
|
||||
env_val = self.env_vars._get_secret_value(gcp_secret_name, secret_version)
|
||||
if env_val is not None:
|
||||
# If populate_by_name is enabled, return field_name to allow multiple fields
|
||||
# with the same alias but different versions to be distinguished
|
||||
if self.settings_cls.model_config.get('populate_by_name'):
|
||||
return env_val, field_name, value_is_complex
|
||||
return env_val, field_key, value_is_complex
|
||||
|
||||
# If a secret version is specified but not found, we should not fall back to "latest" (default behavior)
|
||||
# as that would be incorrect. We return None to indicate the value was not found.
|
||||
return None, field_name, False
|
||||
|
||||
val, key, is_complex = super().get_field_value(field, field_name)
|
||||
|
||||
# If populate_by_name is enabled, we need to return the field_name as the key
|
||||
# without this being enabled, you cannot load two secrets with the same name but different versions
|
||||
if self.settings_cls.model_config.get('populate_by_name') and val is not None:
|
||||
return val, field_name, is_complex
|
||||
return val, key, is_complex
|
||||
|
||||
def _load_env_vars(self) -> Mapping[str, str | None]:
|
||||
return GoogleSecretManagerMapping(
|
||||
self._secret_client, project_id=self._project_id, case_sensitive=self.case_sensitive
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(project_id={self._project_id!r}, env_nested_delimiter={self.env_nested_delimiter!r})'
|
||||
|
||||
|
||||
__all__ = ['GoogleSecretManagerSettingsSource', 'GoogleSecretManagerMapping']
|
||||
@@ -0,0 +1,48 @@
|
||||
"""JSON file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class JsonConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a JSON file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
json_file: PathType | None = DEFAULT_PATH,
|
||||
json_file_encoding: str | None = None,
|
||||
deep_merge: bool = False,
|
||||
):
|
||||
self.json_file_path = json_file if json_file != DEFAULT_PATH else settings_cls.model_config.get('json_file')
|
||||
self.json_file_encoding = (
|
||||
json_file_encoding
|
||||
if json_file_encoding is not None
|
||||
else settings_cls.model_config.get('json_file_encoding')
|
||||
)
|
||||
self.json_data = self._read_files(self.json_file_path, deep_merge=deep_merge)
|
||||
super().__init__(settings_cls, self.json_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
with file_path.open(encoding=self.json_file_encoding) as json_file:
|
||||
return json.load(json_file)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(json_file={self.json_file_path})'
|
||||
|
||||
|
||||
__all__ = ['JsonConfigSettingsSource']
|
||||
@@ -0,0 +1,166 @@
|
||||
import os
|
||||
import warnings
|
||||
from functools import reduce
|
||||
from glob import iglob
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
from ...exceptions import SettingsError
|
||||
from ...utils import path_type_label
|
||||
from ..base import PydanticBaseSettingsSource
|
||||
from ..utils import parse_env_vars
|
||||
from .env import EnvSettingsSource
|
||||
from .secrets import SecretsSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...main import BaseSettings
|
||||
from ...sources import PathType
|
||||
|
||||
|
||||
SECRETS_DIR_MAX_SIZE = 16 * 2**20 # 16 MiB seems to be a reasonable default
|
||||
|
||||
|
||||
class NestedSecretsSettingsSource(EnvSettingsSource):
|
||||
def __init__(
|
||||
self,
|
||||
file_secret_settings: PydanticBaseSettingsSource | SecretsSettingsSource,
|
||||
secrets_dir: Optional['PathType'] = None,
|
||||
secrets_dir_missing: Literal['ok', 'warn', 'error'] | None = None,
|
||||
secrets_dir_max_size: int | None = None,
|
||||
secrets_case_sensitive: bool | None = None,
|
||||
secrets_prefix: str | None = None,
|
||||
secrets_nested_delimiter: str | None = None,
|
||||
secrets_nested_subdir: bool | None = None,
|
||||
# args for compatibility with SecretsSettingsSource, don't use directly
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
) -> None:
|
||||
# We allow the first argument to be settings_cls like original
|
||||
# SecretsSettingsSource. However, it is recommended to pass
|
||||
# SecretsSettingsSource instance instead (as it is shown in usage examples),
|
||||
# otherwise `_secrets_dir` arg passed to Settings() constructor will be ignored.
|
||||
settings_cls: type[BaseSettings] = getattr(
|
||||
file_secret_settings,
|
||||
'settings_cls',
|
||||
file_secret_settings, # type: ignore[arg-type]
|
||||
)
|
||||
# config options
|
||||
conf = settings_cls.model_config
|
||||
self.secrets_dir: PathType | None = first_not_none(
|
||||
getattr(file_secret_settings, 'secrets_dir', None),
|
||||
secrets_dir,
|
||||
conf.get('secrets_dir'),
|
||||
)
|
||||
self.secrets_dir_missing: Literal['ok', 'warn', 'error'] = first_not_none(
|
||||
secrets_dir_missing,
|
||||
conf.get('secrets_dir_missing'),
|
||||
'warn',
|
||||
)
|
||||
if self.secrets_dir_missing not in ('ok', 'warn', 'error'):
|
||||
raise SettingsError(f'invalid secrets_dir_missing value: {self.secrets_dir_missing}')
|
||||
self.secrets_dir_max_size: int = first_not_none(
|
||||
secrets_dir_max_size,
|
||||
conf.get('secrets_dir_max_size'),
|
||||
SECRETS_DIR_MAX_SIZE,
|
||||
)
|
||||
self.case_sensitive: bool = first_not_none(
|
||||
secrets_case_sensitive,
|
||||
conf.get('secrets_case_sensitive'),
|
||||
case_sensitive,
|
||||
conf.get('case_sensitive'),
|
||||
False,
|
||||
)
|
||||
self.secrets_prefix: str = first_not_none(
|
||||
secrets_prefix,
|
||||
conf.get('secrets_prefix'),
|
||||
env_prefix,
|
||||
conf.get('env_prefix'),
|
||||
'',
|
||||
)
|
||||
|
||||
# nested options
|
||||
self.secrets_nested_delimiter: str | None = first_not_none(
|
||||
secrets_nested_delimiter,
|
||||
conf.get('secrets_nested_delimiter'),
|
||||
conf.get('env_nested_delimiter'),
|
||||
)
|
||||
self.secrets_nested_subdir: bool = first_not_none(
|
||||
secrets_nested_subdir,
|
||||
conf.get('secrets_nested_subdir'),
|
||||
False,
|
||||
)
|
||||
if self.secrets_nested_subdir:
|
||||
if secrets_nested_delimiter or conf.get('secrets_nested_delimiter'):
|
||||
raise SettingsError('Options secrets_nested_delimiter and secrets_nested_subdir are mutually exclusive')
|
||||
else:
|
||||
self.secrets_nested_delimiter = os.sep
|
||||
|
||||
# ensure valid secrets_path
|
||||
if self.secrets_dir is None:
|
||||
paths = []
|
||||
elif isinstance(self.secrets_dir, (Path, str)):
|
||||
paths = [self.secrets_dir]
|
||||
else:
|
||||
paths = list(self.secrets_dir)
|
||||
self.secrets_paths: list[Path] = [Path(p).expanduser().resolve() for p in paths]
|
||||
for path in self.secrets_paths:
|
||||
self.validate_secrets_path(path)
|
||||
|
||||
# construct parent
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive=self.case_sensitive,
|
||||
env_prefix=self.secrets_prefix,
|
||||
env_nested_delimiter=self.secrets_nested_delimiter,
|
||||
env_ignore_empty=False, # match SecretsSettingsSource behaviour
|
||||
env_parse_enums=True, # we can pass everything here, it will still behave as "True"
|
||||
env_parse_none_str=None, # match SecretsSettingsSource behaviour
|
||||
)
|
||||
self.env_parse_none_str = None # update manually because of None
|
||||
|
||||
# update parent members
|
||||
if not len(self.secrets_paths):
|
||||
self.env_vars = {}
|
||||
else:
|
||||
secrets = reduce(
|
||||
lambda d1, d2: dict((*d1.items(), *d2.items())),
|
||||
(self.load_secrets(p) for p in self.secrets_paths),
|
||||
)
|
||||
self.env_vars = parse_env_vars(
|
||||
secrets,
|
||||
self.case_sensitive,
|
||||
self.env_ignore_empty,
|
||||
self.env_parse_none_str,
|
||||
)
|
||||
|
||||
def validate_secrets_path(self, path: Path) -> None:
|
||||
if not path.exists():
|
||||
if self.secrets_dir_missing == 'ok':
|
||||
pass
|
||||
elif self.secrets_dir_missing == 'warn':
|
||||
warnings.warn(f'directory "{path}" does not exist', stacklevel=2)
|
||||
elif self.secrets_dir_missing == 'error':
|
||||
raise SettingsError(f'directory "{path}" does not exist')
|
||||
else:
|
||||
raise ValueError # unreachable, checked before
|
||||
else:
|
||||
if not path.is_dir():
|
||||
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')
|
||||
secrets_dir_size = sum(f.stat().st_size for f in path.glob('**/*') if f.is_file())
|
||||
if secrets_dir_size > self.secrets_dir_max_size:
|
||||
raise SettingsError(f'secrets_dir size is above {self.secrets_dir_max_size} bytes')
|
||||
|
||||
@staticmethod
|
||||
def load_secrets(path: Path) -> dict[str, str]:
|
||||
return {
|
||||
str(p.relative_to(path)): p.read_text().strip()
|
||||
for p in map(Path, iglob(f'{path}/**/*', recursive=True))
|
||||
if p.is_file()
|
||||
}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'NestedSecretsSettingsSource(secrets_dir={self.secrets_dir!r})'
|
||||
|
||||
|
||||
def first_not_none(*objs: Any) -> Any:
|
||||
return next(filter(lambda o: o is not None, objs), None)
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Pyproject TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
|
||||
from .toml import TomlConfigSettingsSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class PyprojectTomlConfigSettingsSource(TomlConfigSettingsSource):
|
||||
"""
|
||||
A source class that loads variables from a `pyproject.toml` file.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
toml_file: Path | None = None,
|
||||
) -> None:
|
||||
self.toml_file_path = self._pick_pyproject_toml_file(
|
||||
toml_file, settings_cls.model_config.get('pyproject_toml_depth', 0)
|
||||
)
|
||||
self.toml_table_header: tuple[str, ...] = settings_cls.model_config.get(
|
||||
'pyproject_toml_table_header', ('tool', 'pydantic-settings')
|
||||
)
|
||||
self.toml_data = self._read_files(self.toml_file_path)
|
||||
for key in self.toml_table_header:
|
||||
self.toml_data = self.toml_data.get(key, {})
|
||||
super(TomlConfigSettingsSource, self).__init__(settings_cls, self.toml_data)
|
||||
|
||||
@staticmethod
|
||||
def _pick_pyproject_toml_file(provided: Path | None, depth: int) -> Path:
|
||||
"""Pick a `pyproject.toml` file path to use.
|
||||
|
||||
Args:
|
||||
provided: Explicit path provided when instantiating this class.
|
||||
depth: Number of directories up the tree to check of a pyproject.toml.
|
||||
|
||||
"""
|
||||
if provided:
|
||||
return provided.resolve()
|
||||
rv = Path.cwd() / 'pyproject.toml'
|
||||
count = 0
|
||||
if not rv.is_file():
|
||||
child = rv.parent.parent / 'pyproject.toml'
|
||||
while count < depth:
|
||||
if child.is_file():
|
||||
return child
|
||||
if str(child.parent) == rv.root:
|
||||
break # end discovery after checking system root once
|
||||
child = child.parent.parent / 'pyproject.toml'
|
||||
count += 1
|
||||
return rv
|
||||
|
||||
|
||||
__all__ = ['PyprojectTomlConfigSettingsSource']
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Secrets file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from pydantic.fields import FieldInfo
|
||||
|
||||
from pydantic_settings.utils import path_type_label
|
||||
|
||||
from ...exceptions import SettingsError
|
||||
from ..base import PydanticBaseEnvSettingsSource
|
||||
from ..types import EnvPrefixTarget, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
|
||||
class SecretsSettingsSource(PydanticBaseEnvSettingsSource):
|
||||
"""
|
||||
Source class for loading settings values from secret files.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
secrets_dir: PathType | None = None,
|
||||
case_sensitive: bool | None = None,
|
||||
env_prefix: str | None = None,
|
||||
env_prefix_target: EnvPrefixTarget | None = None,
|
||||
env_ignore_empty: bool | None = None,
|
||||
env_parse_none_str: str | None = None,
|
||||
env_parse_enums: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
settings_cls,
|
||||
case_sensitive,
|
||||
env_prefix,
|
||||
env_prefix_target,
|
||||
env_ignore_empty,
|
||||
env_parse_none_str,
|
||||
env_parse_enums,
|
||||
)
|
||||
self.secrets_dir = secrets_dir if secrets_dir is not None else self.config.get('secrets_dir')
|
||||
|
||||
def __call__(self) -> dict[str, Any]:
|
||||
"""
|
||||
Build fields from "secrets" files.
|
||||
"""
|
||||
secrets: dict[str, str | None] = {}
|
||||
|
||||
if self.secrets_dir is None:
|
||||
return secrets
|
||||
|
||||
secrets_dirs = [self.secrets_dir] if isinstance(self.secrets_dir, (str, os.PathLike)) else self.secrets_dir
|
||||
secrets_paths = [Path(p).expanduser() for p in secrets_dirs]
|
||||
self.secrets_paths = []
|
||||
|
||||
for path in secrets_paths:
|
||||
if not path.exists():
|
||||
warnings.warn(f'directory "{path}" does not exist')
|
||||
else:
|
||||
self.secrets_paths.append(path)
|
||||
|
||||
if not len(self.secrets_paths):
|
||||
return secrets
|
||||
|
||||
for path in self.secrets_paths:
|
||||
if not path.is_dir():
|
||||
raise SettingsError(f'secrets_dir must reference a directory, not a {path_type_label(path)}')
|
||||
|
||||
return super().__call__()
|
||||
|
||||
@classmethod
|
||||
def find_case_path(cls, dir_path: Path, file_name: str, case_sensitive: bool) -> Path | None:
|
||||
"""
|
||||
Find a file within path's directory matching filename, optionally ignoring case.
|
||||
|
||||
Args:
|
||||
dir_path: Directory path.
|
||||
file_name: File name.
|
||||
case_sensitive: Whether to search for file name case sensitively.
|
||||
|
||||
Returns:
|
||||
Whether file path or `None` if file does not exist in directory.
|
||||
"""
|
||||
for f in dir_path.iterdir():
|
||||
if f.name == file_name:
|
||||
return f
|
||||
elif not case_sensitive and f.name.lower() == file_name.lower():
|
||||
return f
|
||||
return None
|
||||
|
||||
def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
|
||||
"""
|
||||
Gets the value for field from secret file and a flag to determine whether value is complex.
|
||||
|
||||
Args:
|
||||
field: The field.
|
||||
field_name: The field name.
|
||||
|
||||
Returns:
|
||||
A tuple that contains the value (`None` if the file does not exist), key, and
|
||||
a flag to determine whether value is complex.
|
||||
"""
|
||||
|
||||
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
|
||||
# paths reversed to match the last-wins behaviour of `env_file`
|
||||
for secrets_path in reversed(self.secrets_paths):
|
||||
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
|
||||
if not path:
|
||||
# path does not exist, we currently don't return a warning for this
|
||||
continue
|
||||
|
||||
if path.is_file():
|
||||
return path.read_text().strip(), field_key, value_is_complex
|
||||
else:
|
||||
warnings.warn(
|
||||
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
|
||||
stacklevel=4,
|
||||
)
|
||||
|
||||
return None, field_key, value_is_complex
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
|
||||
@@ -0,0 +1,67 @@
|
||||
"""TOML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_settings.main import BaseSettings
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
else:
|
||||
tomllib = None
|
||||
import tomli
|
||||
else:
|
||||
tomllib = None
|
||||
tomli = None
|
||||
|
||||
|
||||
def import_toml() -> None:
|
||||
global tomli
|
||||
global tomllib
|
||||
if sys.version_info < (3, 11):
|
||||
if tomli is not None:
|
||||
return
|
||||
try:
|
||||
import tomli
|
||||
except ImportError as e: # pragma: no cover
|
||||
raise ImportError('tomli is not installed, run `pip install pydantic-settings[toml]`') from e
|
||||
else:
|
||||
if tomllib is not None:
|
||||
return
|
||||
import tomllib
|
||||
|
||||
|
||||
class TomlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a TOML file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
toml_file: PathType | None = DEFAULT_PATH,
|
||||
deep_merge: bool = False,
|
||||
):
|
||||
self.toml_file_path = toml_file if toml_file != DEFAULT_PATH else settings_cls.model_config.get('toml_file')
|
||||
self.toml_data = self._read_files(self.toml_file_path, deep_merge=deep_merge)
|
||||
super().__init__(settings_cls, self.toml_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
import_toml()
|
||||
with file_path.open(mode='rb') as toml_file:
|
||||
if sys.version_info < (3, 11):
|
||||
return tomli.load(toml_file)
|
||||
return tomllib.load(toml_file)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(toml_file={self.toml_file_path})'
|
||||
@@ -0,0 +1,130 @@
|
||||
"""YAML file settings source."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
)
|
||||
|
||||
from ..base import ConfigFileSourceMixin, InitSettingsSource
|
||||
from ..types import DEFAULT_PATH, PathType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import yaml
|
||||
|
||||
from pydantic_settings.main import BaseSettings
|
||||
else:
|
||||
yaml = None
|
||||
|
||||
|
||||
def import_yaml() -> None:
|
||||
global yaml
|
||||
if yaml is not None:
|
||||
return
|
||||
try:
|
||||
import yaml
|
||||
except ImportError as e:
|
||||
raise ImportError('PyYAML is not installed, run `pip install pydantic-settings[yaml]`') from e
|
||||
|
||||
|
||||
class YamlConfigSettingsSource(InitSettingsSource, ConfigFileSourceMixin):
|
||||
"""
|
||||
A source class that loads variables from a yaml file
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
settings_cls: type[BaseSettings],
|
||||
yaml_file: PathType | None = DEFAULT_PATH,
|
||||
yaml_file_encoding: str | None = None,
|
||||
yaml_config_section: str | None = None,
|
||||
deep_merge: bool = False,
|
||||
):
|
||||
self.yaml_file_path = yaml_file if yaml_file != DEFAULT_PATH else settings_cls.model_config.get('yaml_file')
|
||||
self.yaml_file_encoding = (
|
||||
yaml_file_encoding
|
||||
if yaml_file_encoding is not None
|
||||
else settings_cls.model_config.get('yaml_file_encoding')
|
||||
)
|
||||
self.yaml_config_section = (
|
||||
yaml_config_section
|
||||
if yaml_config_section is not None
|
||||
else settings_cls.model_config.get('yaml_config_section')
|
||||
)
|
||||
self.yaml_data = self._read_files(self.yaml_file_path, deep_merge=deep_merge)
|
||||
|
||||
if self.yaml_config_section is not None:
|
||||
self.yaml_data = self._traverse_nested_section(
|
||||
self.yaml_data, self.yaml_config_section, self.yaml_config_section
|
||||
)
|
||||
super().__init__(settings_cls, self.yaml_data)
|
||||
|
||||
def _read_file(self, file_path: Path) -> dict[str, Any]:
|
||||
import_yaml()
|
||||
with file_path.open(encoding=self.yaml_file_encoding) as yaml_file:
|
||||
return yaml.safe_load(yaml_file) or {}
|
||||
|
||||
def _traverse_nested_section(
|
||||
self, data: dict[str, Any], section_path: str, original_path: str | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Traverse nested YAML sections using dot-notation path.
|
||||
|
||||
This method tries to match the longest possible key first before splitting on dots,
|
||||
allowing access to YAML keys that contain literal dot characters.
|
||||
|
||||
For example, with section_path="a.b.c", it will try:
|
||||
1. "a.b.c" as a literal key
|
||||
2. "a.b" as a key, then traverse to "c"
|
||||
3. "a" as a key, then traverse to "b.c"
|
||||
4. "a" as a key, then "b" as a key, then "c" as a key
|
||||
"""
|
||||
# Track the original path for error messages
|
||||
if original_path is None:
|
||||
original_path = section_path
|
||||
|
||||
# Only reject truly empty paths
|
||||
if not section_path:
|
||||
raise ValueError('yaml_config_section cannot be empty')
|
||||
|
||||
# Try the full path as a literal key first (even with leading/trailing/consecutive dots)
|
||||
try:
|
||||
return data[section_path]
|
||||
except KeyError:
|
||||
pass # Not a literal key, try splitting
|
||||
except TypeError:
|
||||
raise TypeError(
|
||||
f'yaml_config_section path "{original_path}" cannot be traversed in {self.yaml_file_path}. '
|
||||
f'An intermediate value is not a dictionary.'
|
||||
)
|
||||
|
||||
# If path contains no dots, we already tried it as a literal key above
|
||||
if '.' not in section_path:
|
||||
raise KeyError(f'yaml_config_section key "{original_path}" not found in {self.yaml_file_path}')
|
||||
|
||||
# Try progressively shorter prefixes (greedy left-to-right approach)
|
||||
parts = section_path.split('.')
|
||||
for i in range(len(parts) - 1, 0, -1):
|
||||
prefix = '.'.join(parts[:i])
|
||||
suffix = '.'.join(parts[i:])
|
||||
|
||||
if prefix in data:
|
||||
# Found the prefix as a literal key, now recursively traverse the suffix
|
||||
try:
|
||||
return self._traverse_nested_section(data[prefix], suffix, original_path)
|
||||
except TypeError:
|
||||
raise TypeError(
|
||||
f'yaml_config_section path "{original_path}" cannot be traversed in {self.yaml_file_path}. '
|
||||
f'An intermediate value is not a dictionary.'
|
||||
)
|
||||
|
||||
# If we get here, no match was found
|
||||
raise KeyError(f'yaml_config_section key "{original_path}" not found in {self.yaml_file_path}')
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}(yaml_file={self.yaml_file_path})'
|
||||
|
||||
|
||||
__all__ = ['YamlConfigSettingsSource']
|
||||
@@ -0,0 +1,99 @@
|
||||
"""Type definitions for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic._internal._dataclasses import PydanticDataclass
|
||||
from pydantic.main import BaseModel
|
||||
|
||||
PydanticModel = PydanticDataclass | BaseModel
|
||||
else:
|
||||
PydanticModel = Any
|
||||
|
||||
|
||||
class EnvNoneType(str):
|
||||
pass
|
||||
|
||||
|
||||
class NoDecode:
|
||||
"""Annotation to prevent decoding of a field value."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ForceDecode:
|
||||
"""Annotation to force decoding of a field value."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
EnvPrefixTarget = Literal['variable', 'alias', 'all']
|
||||
DotenvType = Path | str | Sequence[Path | str]
|
||||
PathType = Path | str | Sequence[Path | str]
|
||||
DEFAULT_PATH: PathType = Path('')
|
||||
|
||||
# This is used as default value for `_env_file` in the `BaseSettings` class and
|
||||
# `env_file` in `DotEnvSettingsSource` so the default can be distinguished from `None`.
|
||||
# See the docstring of `BaseSettings` for more details.
|
||||
ENV_FILE_SENTINEL: DotenvType = Path('')
|
||||
|
||||
|
||||
class _CliSubCommand:
|
||||
pass
|
||||
|
||||
|
||||
class _CliPositionalArg:
|
||||
pass
|
||||
|
||||
|
||||
class _CliImplicitFlag:
|
||||
pass
|
||||
|
||||
|
||||
class _CliToggleFlag(_CliImplicitFlag):
|
||||
pass
|
||||
|
||||
|
||||
class _CliDualFlag(_CliImplicitFlag):
|
||||
pass
|
||||
|
||||
|
||||
class _CliExplicitFlag:
|
||||
pass
|
||||
|
||||
|
||||
class _CliUnknownArgs:
|
||||
pass
|
||||
|
||||
|
||||
class SecretVersion:
|
||||
def __init__(self, version: str) -> None:
|
||||
self.version = version
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.__class__.__name__}({self.version!r})'
|
||||
|
||||
|
||||
__all__ = [
|
||||
'DEFAULT_PATH',
|
||||
'ENV_FILE_SENTINEL',
|
||||
'EnvPrefixTarget',
|
||||
'DotenvType',
|
||||
'EnvNoneType',
|
||||
'ForceDecode',
|
||||
'NoDecode',
|
||||
'PathType',
|
||||
'PydanticModel',
|
||||
'SecretVersion',
|
||||
'_CliExplicitFlag',
|
||||
'_CliImplicitFlag',
|
||||
'_CliToggleFlag',
|
||||
'_CliDualFlag',
|
||||
'_CliPositionalArg',
|
||||
'_CliSubCommand',
|
||||
'_CliUnknownArgs',
|
||||
]
|
||||
@@ -0,0 +1,283 @@
|
||||
"""Utility functions for pydantic-settings sources."""
|
||||
|
||||
from __future__ import annotations as _annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Mapping, Sequence
|
||||
from dataclasses import is_dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, TypeVar, cast, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Json, RootModel, Secret
|
||||
from pydantic._internal._utils import is_model_class
|
||||
from pydantic.dataclasses import is_pydantic_dataclass
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_inspection import typing_objects
|
||||
|
||||
from ..exceptions import SettingsError
|
||||
from ..utils import _lenient_issubclass
|
||||
from .types import EnvNoneType
|
||||
|
||||
|
||||
def _get_env_var_key(key: str, case_sensitive: bool = False) -> str:
|
||||
return key if case_sensitive else key.lower()
|
||||
|
||||
|
||||
def _parse_env_none_str(value: str | None, parse_none_str: str | None = None) -> str | None | EnvNoneType:
|
||||
return value if not (value == parse_none_str and parse_none_str is not None) else EnvNoneType(value)
|
||||
|
||||
|
||||
def parse_env_vars(
|
||||
env_vars: Mapping[str, str | None],
|
||||
case_sensitive: bool = False,
|
||||
ignore_empty: bool = False,
|
||||
parse_none_str: str | None = None,
|
||||
) -> Mapping[str, str | None]:
|
||||
return {
|
||||
_get_env_var_key(k, case_sensitive): _parse_env_none_str(v, parse_none_str)
|
||||
for k, v in env_vars.items()
|
||||
if not (ignore_empty and v == '')
|
||||
}
|
||||
|
||||
|
||||
def _substitute_typevars(tp: Any, param_map: dict[Any, Any]) -> Any:
|
||||
"""Substitute TypeVars in a type annotation with concrete types from param_map."""
|
||||
if isinstance(tp, TypeVar) and tp in param_map:
|
||||
return param_map[tp]
|
||||
args = get_args(tp)
|
||||
if not args:
|
||||
return tp
|
||||
new_args = tuple(_substitute_typevars(arg, param_map) for arg in args)
|
||||
if new_args == args:
|
||||
return tp
|
||||
origin = get_origin(tp)
|
||||
if origin is not None:
|
||||
try:
|
||||
return origin[new_args]
|
||||
except TypeError:
|
||||
# types.UnionType and similar are not directly subscriptable,
|
||||
# reconstruct using | operator
|
||||
import functools
|
||||
import operator
|
||||
|
||||
return functools.reduce(operator.or_, new_args)
|
||||
return tp
|
||||
|
||||
|
||||
def _resolve_type_alias(annotation: Any) -> Any:
|
||||
"""Resolve a TypeAliasType to its underlying value, substituting type params if parameterized."""
|
||||
if typing_objects.is_typealiastype(annotation):
|
||||
return annotation.__value__
|
||||
origin = get_origin(annotation)
|
||||
if typing_objects.is_typealiastype(origin):
|
||||
type_params = getattr(origin, '__type_params__', ())
|
||||
type_args = get_args(annotation)
|
||||
value = origin.__value__
|
||||
if type_params and type_args:
|
||||
return _substitute_typevars(value, dict(zip(type_params, type_args)))
|
||||
return value
|
||||
return annotation
|
||||
|
||||
|
||||
def _annotation_is_complex(annotation: Any, metadata: list[Any]) -> bool:
|
||||
# If the model is a root model, the root annotation should be used to
|
||||
# evaluate the complexity.
|
||||
annotation = _resolve_type_alias(annotation)
|
||||
if annotation is not None and _lenient_issubclass(annotation, RootModel) and annotation is not RootModel:
|
||||
annotation = cast('type[RootModel[Any]]', annotation)
|
||||
root_annotation = annotation.model_fields['root'].annotation
|
||||
if root_annotation is not None: # pragma: no branch
|
||||
annotation = root_annotation
|
||||
|
||||
if any(isinstance(md, Json) for md in metadata): # type: ignore[misc]
|
||||
return False
|
||||
|
||||
origin = get_origin(annotation)
|
||||
|
||||
# Check if annotation is of the form Annotated[type, metadata].
|
||||
if typing_objects.is_annotated(origin):
|
||||
# Return result of recursive call on inner type.
|
||||
inner, *meta = get_args(annotation)
|
||||
return _annotation_is_complex(inner, meta)
|
||||
|
||||
if origin is Secret:
|
||||
return False
|
||||
|
||||
return (
|
||||
_annotation_is_complex_inner(annotation)
|
||||
or _annotation_is_complex_inner(origin)
|
||||
or hasattr(origin, '__pydantic_core_schema__')
|
||||
or hasattr(origin, '__get_pydantic_core_schema__')
|
||||
)
|
||||
|
||||
|
||||
def _get_field_metadata(field: FieldInfo) -> list[Any]:
|
||||
annotation = _resolve_type_alias(field.annotation)
|
||||
metadata = field.metadata
|
||||
origin = get_origin(annotation)
|
||||
if typing_objects.is_annotated(origin):
|
||||
_, *meta = get_args(annotation)
|
||||
metadata += meta
|
||||
return metadata
|
||||
|
||||
|
||||
def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool:
|
||||
if _lenient_issubclass(annotation, (str, bytes)):
|
||||
return False
|
||||
|
||||
return _lenient_issubclass(
|
||||
annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)
|
||||
) or is_dataclass(annotation)
|
||||
|
||||
|
||||
def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool:
|
||||
"""Check if a union type contains any complex types."""
|
||||
return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation))
|
||||
|
||||
|
||||
def _annotation_contains_types(
|
||||
annotation: type[Any] | None,
|
||||
types: tuple[Any, ...],
|
||||
is_include_origin: bool = True,
|
||||
is_strip_annotated: bool = False,
|
||||
is_instance: bool = False,
|
||||
collect: set[Any] | None = None,
|
||||
) -> bool:
|
||||
"""Check if a type annotation contains any of the specified types."""
|
||||
if is_strip_annotated:
|
||||
annotation = _strip_annotated(annotation)
|
||||
if is_include_origin is True:
|
||||
origin = get_origin(annotation)
|
||||
if origin in types:
|
||||
if collect is None:
|
||||
return True
|
||||
collect.add(annotation)
|
||||
if is_instance and any(isinstance(origin, type_) for type_ in types):
|
||||
if collect is None:
|
||||
return True
|
||||
collect.add(annotation)
|
||||
for type_ in get_args(annotation):
|
||||
if (
|
||||
_annotation_contains_types(
|
||||
type_,
|
||||
types,
|
||||
is_include_origin=True,
|
||||
is_strip_annotated=is_strip_annotated,
|
||||
is_instance=is_instance,
|
||||
collect=collect,
|
||||
)
|
||||
and collect is None
|
||||
):
|
||||
return True
|
||||
if is_instance and any(isinstance(annotation, type_) for type_ in types):
|
||||
if collect is None:
|
||||
return True
|
||||
collect.add(annotation)
|
||||
if annotation in types:
|
||||
if collect is not None:
|
||||
collect.add(annotation)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _strip_annotated(annotation: Any) -> Any:
|
||||
if typing_objects.is_annotated(get_origin(annotation)):
|
||||
return annotation.__origin__
|
||||
else:
|
||||
return annotation
|
||||
|
||||
|
||||
def _annotation_enum_val_to_name(annotation: type[Any] | None, value: Any) -> str | None:
|
||||
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
|
||||
if _lenient_issubclass(type_, Enum):
|
||||
if value in type_.__members__.values():
|
||||
return type_(value).name
|
||||
return None
|
||||
|
||||
|
||||
def _annotation_enum_name_to_val(annotation: type[Any] | None, name: Any) -> Any:
|
||||
for type_ in (annotation, get_origin(annotation), *get_args(annotation)):
|
||||
if _lenient_issubclass(type_, Enum):
|
||||
if name in type_.__members__.keys():
|
||||
return type_[name]
|
||||
return None
|
||||
|
||||
|
||||
def _get_model_fields(model_cls: type[Any]) -> dict[str, Any]:
|
||||
"""Get fields from a pydantic model or dataclass."""
|
||||
|
||||
if is_pydantic_dataclass(model_cls) and hasattr(model_cls, '__pydantic_fields__'):
|
||||
return model_cls.__pydantic_fields__
|
||||
if is_model_class(model_cls):
|
||||
return model_cls.model_fields
|
||||
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')
|
||||
|
||||
|
||||
def _get_alias_names(
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
alias_path_args: dict[str, int | None] | None = None,
|
||||
case_sensitive: bool = True,
|
||||
) -> tuple[tuple[str, ...], bool]:
|
||||
"""Get alias names for a field, handling alias paths and case sensitivity."""
|
||||
from pydantic import AliasChoices, AliasPath
|
||||
|
||||
alias_names: list[str] = []
|
||||
is_alias_path_only: bool = True
|
||||
if not any((field_info.alias, field_info.validation_alias)):
|
||||
alias_names += [field_name]
|
||||
is_alias_path_only = False
|
||||
else:
|
||||
new_alias_paths: list[AliasPath] = []
|
||||
for alias in (field_info.alias, field_info.validation_alias):
|
||||
if alias is None:
|
||||
continue
|
||||
elif isinstance(alias, str):
|
||||
alias_names.append(alias)
|
||||
is_alias_path_only = False
|
||||
elif isinstance(alias, AliasChoices):
|
||||
for name in alias.choices:
|
||||
if isinstance(name, str):
|
||||
alias_names.append(name)
|
||||
is_alias_path_only = False
|
||||
else:
|
||||
new_alias_paths.append(name)
|
||||
else:
|
||||
new_alias_paths.append(alias)
|
||||
for alias_path in new_alias_paths:
|
||||
name = cast(str, alias_path.path[0])
|
||||
name = name.lower() if not case_sensitive else name
|
||||
if alias_path_args is not None:
|
||||
alias_path_args[name] = (
|
||||
alias_path.path[1] if len(alias_path.path) > 1 and isinstance(alias_path.path[1], int) else None
|
||||
)
|
||||
if not alias_names and is_alias_path_only:
|
||||
alias_names.append(name)
|
||||
if not case_sensitive:
|
||||
alias_names = [alias_name.lower() for alias_name in alias_names]
|
||||
return tuple(dict.fromkeys(alias_names)), is_alias_path_only
|
||||
|
||||
|
||||
def _is_function(obj: Any) -> bool:
|
||||
"""Check if an object is a function."""
|
||||
from types import BuiltinFunctionType, FunctionType
|
||||
|
||||
return isinstance(obj, (FunctionType, BuiltinFunctionType))
|
||||
|
||||
|
||||
__all__ = [
|
||||
'_annotation_contains_types',
|
||||
'_annotation_enum_name_to_val',
|
||||
'_annotation_enum_val_to_name',
|
||||
'_annotation_is_complex',
|
||||
'_annotation_is_complex_inner',
|
||||
'_get_alias_names',
|
||||
'_get_env_var_key',
|
||||
'_get_model_fields',
|
||||
'_is_function',
|
||||
'_parse_env_none_str',
|
||||
'_resolve_type_alias',
|
||||
'_strip_annotated',
|
||||
'_union_is_complex',
|
||||
'parse_env_vars',
|
||||
]
|
||||
43
venv/lib/python3.12/site-packages/pydantic_settings/utils.py
Normal file
43
venv/lib/python3.12/site-packages/pydantic_settings/utils.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import types
|
||||
from pathlib import Path
|
||||
from typing import Any, _Final, _GenericAlias, get_origin # type: ignore [attr-defined]
|
||||
|
||||
_PATH_TYPE_LABELS = {
|
||||
Path.is_dir: 'directory',
|
||||
Path.is_file: 'file',
|
||||
Path.is_mount: 'mount point',
|
||||
Path.is_symlink: 'symlink',
|
||||
Path.is_block_device: 'block device',
|
||||
Path.is_char_device: 'char device',
|
||||
Path.is_fifo: 'FIFO',
|
||||
Path.is_socket: 'socket',
|
||||
}
|
||||
|
||||
|
||||
def path_type_label(p: Path) -> str:
|
||||
"""
|
||||
Find out what sort of thing a path is.
|
||||
"""
|
||||
assert p.exists(), 'path does not exist'
|
||||
for method, name in _PATH_TYPE_LABELS.items():
|
||||
if method(p):
|
||||
return name
|
||||
|
||||
return 'unknown' # pragma: no cover
|
||||
|
||||
|
||||
# TODO remove and replace usage by `isinstance(cls, type) and issubclass(cls, class_or_tuple)`
|
||||
# once we drop support for Python 3.10.
|
||||
def _lenient_issubclass(cls: Any, class_or_tuple: Any) -> bool: # pragma: no cover
|
||||
try:
|
||||
return isinstance(cls, type) and issubclass(cls, class_or_tuple)
|
||||
except TypeError:
|
||||
if get_origin(cls) is not None:
|
||||
# Up until Python 3.10, isinstance(<generic_alias>, type) is True
|
||||
# (e.g. list[int])
|
||||
return False
|
||||
raise
|
||||
|
||||
|
||||
_WithArgsTypes = (_GenericAlias, types.GenericAlias, types.UnionType)
|
||||
_typing_base: Any = _Final # pyright: ignore[reportAttributeAccessIssue]
|
||||
@@ -0,0 +1 @@
|
||||
VERSION = '2.13.1'
|
||||
Reference in New Issue
Block a user