diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..e2697da915fbf 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -232,11 +232,16 @@ def log_hyperparams(self, params: Union[dict[str, Any], Namespace]) -> None: params = _convert_params(params) params = _flatten_dict(params) + import mlflow.utils.validation from mlflow.entities import Param - # Truncate parameter values to 250 characters. - # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0 - params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()] + try: # Check maximum param value length is available and use it + param_length_limit = mlflow.utils.validation.MAX_PARAM_VAL_LENGTH + except Exception: # Fallback (in case of MAX_PARAM_VAL_LENGTH not available) + param_length_limit = 250 # Historical default value + + # Use mlflow default limit or truncate parameter values to 250 characters if limit is not available + params_list = [Param(key=k, value=str(v)[:param_length_limit]) for k, v in params.items()] # Log in chunks of 100 parameters (the maximum allowed by MLflow). for idx in range(0, len(params_list), 100): diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index c7f9dbe1fe2c6..e0e07c6acb268 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -252,6 +252,14 @@ def test_mlflow_logger_experiment_calls(mlflow_mock, tmp_path): ) param.assert_called_with(key="test", value="test_param") + long_params = {"test": "test_param" * 50} + logger.log_hyperparams(long_params) + + logger.experiment.log_batch.assert_called_with( + run_id=logger.run_id, params=[param(key="test", value="test_param" * 50)] + ) + param.assert_called_with(key="test", value="test_param" * 50) + metrics = {"some_metric": 10} logger.log_metrics(metrics) @@ -317,12 +325,7 @@ def test_mlflow_logger_no_synchronous_support(mlflow_mock, tmp_path): @mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) def test_mlflow_logger_with_long_param_value(mlflow_mock, tmp_path): - """Test that long parameter values are truncated to 250 characters.""" - - def _check_value_length(value, *args, **kwargs): - assert len(value) <= 250 - - mlflow_mock.entities.Param.side_effect = _check_value_length + """Test that long parameter values are handled correctly.""" logger = MLFlowLogger("test", save_dir=str(tmp_path))