Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 48 additions & 1 deletion dandischema/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
StringConstraints,
model_validator,
)
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic.fields import FieldInfo
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
from typing_extensions import Self

_MODELS_MODULE_NAME = "dandischema.models"
Expand Down Expand Up @@ -186,6 +191,48 @@ def _ensure_non_none_instance_identifier_if_non_none_doi_prefix(
)
return self

# This is a workaround for the limitation imposed by the bug at
# https://github.com/pydantic/pydantic/issues/12191 mentioned above.
# TODO: This will no longer be needed once that bug is fixed and
# should be removed along with other workarounds in this model because
# of that bug.
@classmethod
def settings_customise_sources(
    cls,
    settings_cls: type[BaseSettings],
    init_settings: PydanticBaseSettingsSource,
    env_settings: PydanticBaseSettingsSource,
    dotenv_settings: PydanticBaseSettingsSource,
    file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]:
    """
    Re-key init values given by field name so they are keyed by alias.

    Only ``init_settings`` is wrapped: every key it yields that matches a
    field name of this model is replaced by ``dandi_<field name>``. The
    environment, dotenv, and file-secret sources are passed through
    untouched, and the four sources are returned in the order received.

    :param settings_cls: the settings class the sources are built for
    :param init_settings: source of the values passed at initialization
    :param env_settings: source of the values from environment variables
    :param dotenv_settings: source of the values from a dotenv file
    :param file_secret_settings: source of the values from secret files
    :return: the sources to use, with the init source wrapped
    """

    def wrap(source: PydanticBaseSettingsSource) -> PydanticBaseSettingsSource:
        """Return a source that yields ``source``'s values keyed by alias."""

        class Wrapped(PydanticBaseSettingsSource):
            """Source delegating to ``source``; only ``__call__`` is usable."""

            def get_field_value(
                self, field: FieldInfo, field_name: str
            ) -> tuple[Any, str, bool]:
                # `__call__` below builds the result from `source` directly
                # and never consults this per-field hook.
                raise NotImplementedError(
                    "If this method is ever called, there is a bug"
                )

            def __call__(self) -> dict[str, Any]:
                # Work on a shallow copy so the dict produced by `source`
                # is left as it was.
                values = dict(source())
                for name in cls.model_fields:
                    if name not in values:
                        continue
                    # Moving the value replaces whatever may already be
                    # stored under the alias key.
                    values[f"dandi_{name}"] = values.pop(name)
                return values

        return Wrapped(settings_cls)

    sources = (
        wrap(init_settings),
        env_settings,
        dotenv_settings,
        file_secret_settings,
    )
    return sources


_instance_config = Config() # Initial value is set by env vars alone
"""
Expand Down
134 changes: 115 additions & 19 deletions dandischema/tests/test_conf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
from pathlib import Path
from typing import Optional, Union

from pydantic import ValidationError
Expand All @@ -17,18 +18,20 @@ def test_get_instance_config() -> None:
), "`get_instance_config` should return a copy of the instance config"


_FOO_CONFIG_DICT_BY_FIELD_NAME = {
FOO_CONFIG_DICT = {
"instance_name": "FOO",
"instance_identifier": "RRID:ABC_123456",
"instance_url": "https://dandiarchive.org/",
"doi_prefix": "10.1234",
"licenses": ["spdx:AdaCore-doc", "spdx:AGPL-3.0-or-later", "spdx:NBPL-1.0"],
}

FOO_CONFIG_DICT = {f"dandi_{k}": v for k, v in _FOO_CONFIG_DICT_BY_FIELD_NAME.items()}
# Same as `FOO_CONFIG_DICT` but with the field aliases instead of the field names being
# the keys
FOO_CONFIG_DICT_WITH_ALIASES = {f"dandi_{k}": v for k, v in FOO_CONFIG_DICT.items()}

FOO_CONFIG_ENV_VARS = {
k: v if k != "licenses" else json.dumps(v)
for k, v in _FOO_CONFIG_DICT_BY_FIELD_NAME.items()
k: v if k != "licenses" else json.dumps(v) for k, v in FOO_CONFIG_DICT.items()
}


Expand All @@ -47,7 +50,7 @@ def test_valid_instance_name(self, instance_name: str) -> None:
"""
from dandischema.conf import Config

Config(dandi_instance_name=instance_name)
Config(instance_name=instance_name)

@pytest.mark.parametrize(
"clear_dandischema_modules_and_set_env_vars", [{}], indirect=True
Expand All @@ -61,7 +64,7 @@ def test_invalid_instance_name(self, instance_name: str) -> None:
from dandischema.conf import Config

with pytest.raises(ValidationError) as exc_info:
Config(dandi_instance_name=instance_name)
Config(instance_name=instance_name)

assert len(exc_info.value.errors()) == 1
assert exc_info.value.errors()[0]["loc"] == ("dandi_instance_name",)
Expand All @@ -81,7 +84,7 @@ def test_valid_instance_identifier(
"""
from dandischema.conf import Config

Config(dandi_instance_identifier=instance_identifier)
Config(instance_identifier=instance_identifier)

@pytest.mark.parametrize(
"clear_dandischema_modules_and_set_env_vars", [{}], indirect=True
Expand All @@ -95,7 +98,7 @@ def test_invalid_instance_identifier(self, instance_identifier: str) -> None:
from dandischema.conf import Config

with pytest.raises(ValidationError) as exc_info:
Config(dandi_instance_identifier=instance_identifier)
Config(instance_identifier=instance_identifier)

assert len(exc_info.value.errors()) == 1
assert exc_info.value.errors()[0]["loc"] == ("dandi_instance_identifier",)
Expand All @@ -114,7 +117,7 @@ def test_without_instance_identifier_with_doi_prefix(self) -> None:
with pytest.raises(
ValidationError, match="`instance_identifier` must also be set."
):
Config(dandi_doi_prefix="10.1234")
Config(doi_prefix="10.1234")

@pytest.mark.parametrize(
"clear_dandischema_modules_and_set_env_vars", [{}], indirect=True
Expand All @@ -131,8 +134,8 @@ def test_valid_doi_prefix(self, doi_prefix: str) -> None:

Config(
# Instance identifier must be provided if doi_prefix is provided
dandi_instance_identifier="RRID:SCR_017571",
dandi_doi_prefix=doi_prefix,
instance_identifier="RRID:SCR_017571",
doi_prefix=doi_prefix,
)

@pytest.mark.parametrize(
Expand All @@ -149,8 +152,8 @@ def test_invalid_doi_prefix(self, doi_prefix: str) -> None:
with pytest.raises(ValidationError) as exc_info:
Config(
# Instance identifier must be provided if doi_prefix is provided
dandi_instance_identifier="RRID:SCR_017571",
dandi_doi_prefix=doi_prefix,
instance_identifier="RRID:SCR_017571",
doi_prefix=doi_prefix,
)

assert len(exc_info.value.errors()) == 1
Expand Down Expand Up @@ -179,7 +182,7 @@ def test_valid_licenses_by_args(self, licenses: Union[list[str], set[str]]) -> N
from dandischema.conf import Config, License

# noinspection PyTypeChecker
config = Config(dandi_licenses=licenses)
config = Config(licenses=licenses)

assert config.licenses == {License(license_) for license_ in set(licenses)}

Expand Down Expand Up @@ -234,20 +237,113 @@ def test_invalid_licenses_by_args(self, licenses: set[str]) -> None:

with pytest.raises(ValidationError) as exc_info:
# noinspection PyTypeChecker
Config(dandi_licenses=licenses)
Config(licenses=licenses)

assert len(exc_info.value.errors()) == 1
assert exc_info.value.errors()[0]["loc"][:-1] == ("dandi_licenses",)

@pytest.mark.parametrize(
    "clear_dandischema_modules_and_set_env_vars",
    [
        {},
        {"instance_name": "BAR"},
        {"instance_name": "BAZ", "instance_url": "https://www.example.com/"},
    ],
    indirect=True,
)
@pytest.mark.parametrize(
    "config_dict", [FOO_CONFIG_DICT, FOO_CONFIG_DICT_WITH_ALIASES]
)
def test_init_by_kwargs(
    self, clear_dandischema_modules_and_set_env_vars: None, config_dict: dict
) -> None:
    """
    Check that `Config` built from a dict of values dumps back to the
    content of `FOO_CONFIG_DICT`

    The dict is keyed either by field names or by field aliases, and the
    values it carries are expected to win over those set through
    environment variables.
    """
    from dandischema.conf import Config

    dumped = Config.model_validate(config_dict).model_dump(mode="json")

    # The dump is keyed by field names whichever form the input keys took
    assert dumped.keys() == FOO_CONFIG_DICT.keys()

    for key, expected in FOO_CONFIG_DICT.items():
        actual = dumped[key]
        if key != "licenses":
            assert actual == expected
            continue
        # Licenses are compared without regard to their order
        assert sorted(actual) == sorted(expected)

@pytest.mark.parametrize(
    "clear_dandischema_modules_and_set_env_vars", [{}], indirect=True
)
def test_init_by_field_names_through_dotenv(
    self,
    clear_dandischema_modules_and_set_env_vars: None,
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """
    Check that a dotenv file keyed by a field name is rejected

    Environment variables and dotenv files must use the field aliases as
    keys, so loading a file that uses a bare field name is expected to
    raise a `ValidationError` carrying a single `extra_forbidden` error.
    """
    from dandischema.conf import Config

    env_file = "test.env"

    # The only entry of the file uses a field name, not an alias, as its key
    (tmp_path / env_file).write_text("instance_name=DANDI-TEST")

    # The dotenv file is referred to by a relative name, so run from the
    # directory that holds it
    monkeypatch.chdir(tmp_path)

    with pytest.raises(ValidationError) as exc_info:
        # noinspection PyArgumentList
        Config(_env_file=env_file)

    reported = exc_info.value.errors()
    assert len(reported) == 1
    assert reported[0]["type"] == "extra_forbidden"

@pytest.mark.parametrize(
    "clear_dandischema_modules_and_set_env_vars", [{}], indirect=True
)
def test_round_trip(self, clear_dandischema_modules_and_set_env_vars: None) -> None:
    """
    Check that dumping a `Config` instance to JSON and validating that JSON
    again yields an instance equal to the one dumped.
    """
    from dandischema.conf import Config

    before = Config.model_validate(FOO_CONFIG_DICT)

    # Serialize to a JSON string, then rebuild an instance from that string
    after = Config.model_validate_json(before.model_dump_json())

    assert after == before, "Round-trip of `Config` instance failed"


class TestSetInstanceConfig:
@pytest.mark.parametrize(
("arg", "kwargs"),
[
(FOO_CONFIG_DICT, {"dandi_instance_name": "BAR"}),
(FOO_CONFIG_DICT, {"instance_name": "BAR"}),
(
FOO_CONFIG_DICT,
{"dandi_instance_name": "Baz", "key": "value"},
{"instance_name": "Baz", "key": "value"},
),
],
)
Expand Down Expand Up @@ -356,8 +452,8 @@ def test_after_models_import_different_config(
import dandischema.models # noqa: F401

new_config_dict = {
"dandi_instance_name": "BAR",
"dandi_doi_prefix": "10.5678",
"instance_name": "BAR",
"doi_prefix": "10.5678",
}

# noinspection DuplicatedCode
Expand Down
Loading