diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index 754f916bb..86ea93b90 100755 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -7,6 +7,7 @@ ## Bug fixes and other changes - Fixed `PartitionedDataset` to reliably load newly created partitions, particularly with `ParallelRunner`, by ensuring `load()` always re-scans the filesystem . +- Fixed `StudyDataset` to properly propagate a RDB password through the dataset's `credentials` ## Breaking changes @@ -16,7 +17,7 @@ Many thanks to the following Kedroids for contributing PRs to this release: -- ... +- [Guillaume Tauzin](https://github.com/gtauzin) # Release 7.0.0 diff --git a/kedro-datasets/kedro_datasets_experimental/optuna/study_dataset.py b/kedro-datasets/kedro_datasets_experimental/optuna/study_dataset.py index 9472e7ada..c420e4e3d 100644 --- a/kedro-datasets/kedro_datasets_experimental/optuna/study_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/optuna/study_dataset.py @@ -116,13 +116,13 @@ def __init__( # noqa: PLR0913 self._study_name = self._validate_study_name(study_name=study_name) credentials = self._validate_credentials(backend=backend, credentials=credentials) - storage = URL.create( + storage_url = URL.create( drivername=backend, database=database, **credentials, ) - self._storage = str(storage) + self._storage_url = storage_url self.metadata = metadata filepath = None @@ -286,8 +286,10 @@ def load(self) -> optuna.Study: pruner_config = load_args.pop("pruner") pruner = self._get_pruner(pruner_config) + storage_url_str = self._storage_url.render_as_string(hide_password=False) + storage = optuna.storages.RDBStorage(url=storage_url_str) study = optuna.load_study( - storage=self._storage, + storage=storage, study_name=self._get_load_study_name(), sampler=sampler, pruner=pruner, @@ -297,25 +299,29 @@ def load(self) -> optuna.Study: def save(self, study: optuna.Study) -> None: save_study_name = self._get_save_study_name() + + storage_url_str = self._storage_url.render_as_string(hide_password=False) if self._backend == "sqlite": os.makedirs(os.path.dirname(self._filepath), exist_ok=True) if not os.path.isfile(self._filepath): optuna.create_study( - storage=self._storage, + storage=storage_url_str, ) + storage = optuna.storages.RDBStorage(url=storage_url_str) + # To overwrite an existing study, we need to first delete it if it exists if self._study_name_exists(save_study_name): optuna.delete_study( - storage=self._storage, + storage=storage, study_name=save_study_name, ) optuna.copy_study( from_study_name=study.study_name, from_storage=study._storage, - to_storage=self._storage, + to_storage=storage, to_study_name=save_study_name, ) @@ -323,11 +329,15 @@ def _study_name_exists(self, study_name) -> bool: if self._backend == "sqlite" and not os.path.isfile(self._database): return False - study_names = optuna.study.get_all_study_names(storage=self._storage) + storage_url_str = self._storage_url.render_as_string(hide_password=False) + storage = optuna.storages.RDBStorage(url=storage_url_str) + study_names = optuna.study.get_all_study_names(storage=storage) return study_name in study_names def _study_name_glob(self, pattern): - study_names = optuna.study.get_all_study_names(storage=self._storage) + storage_url_str = self._storage_url.render_as_string(hide_password=False) + storage = optuna.storages.RDBStorage(url=storage_url_str) + study_names = optuna.study.get_all_study_names(storage=storage) for study_name in study_names: if fnmatch.fnmatch(study_name, pattern): yield study_name diff --git a/kedro-datasets/kedro_datasets_experimental/tests/optuna/test_study_dataset.py b/kedro-datasets/kedro_datasets_experimental/tests/optuna/test_study_dataset.py index c0ddec2ac..71309b792 100644 --- a/kedro-datasets/kedro_datasets_experimental/tests/optuna/test_study_dataset.py +++ b/kedro-datasets/kedro_datasets_experimental/tests/optuna/test_study_dataset.py @@ -105,6 +105,26 @@ def test_invalid_credentials(self): credentials={"username": "user", "pwd": "pass"}, # pragma: allowlist secret ) + def test_study_existence(self): + """Test invalid credentials raise ValueError.""" + study_dataset = StudyDataset( + study_name="test", + backend="postgresql", + database="optuna_db", + credentials={"username": "user", "password": "pass"}, # pragma: allowlist secret + ) + + # Test that RDB storage can be created but DB access module cannot be imported + with pytest.raises(ImportError, match="Failed to import DB access module"): + study_dataset._study_name_exists(study_name="blah") + + study_dataset = StudyDataset( + study_name="test", + backend="sqlite", + database="optuna.db", + ) + assert study_dataset._study_name_exists(study_name="blah") is False + def test_study_name_exists(self, study_dataset, dummy_study): """Test `_study_name_exists` method.""" assert not study_dataset._study_name_exists("test")