diff --git a/src/disp_s1/_ps.py b/src/disp_s1/_ps.py index eca9e72..c251991 100644 --- a/src/disp_s1/_ps.py +++ b/src/disp_s1/_ps.py @@ -203,20 +203,7 @@ def run_combine( reader_mean = io.RasterReader.from_file(cur_mean, band=1) reader_dispersion = io.RasterReader.from_file(cur_dispersion, band=1) - num_images = 1 + len(compressed_slc_files) - if weight_scheme == WeightScheme.LINEAR: - # Increase the weights from older to newer. - N = np.linspace(0, 1, num=num_images) * num_slc - elif weight_scheme == WeightScheme.EQUAL: - # Increase the weights from older to newer. - N = num_slc * np.ones((num_images,)) - elif weight_scheme == WeightScheme.EXPONENTIAL: - alpha = 0.5 - weights = np.exp(alpha * np.arange(num_images)) - weights /= weights.max() - N = weights.round().astype(int) - else: - raise ValueError(f"Unrecognized {weight_scheme = }") + N = _get_weighting(len(compressed_slc_files), weight_scheme, num_slc) io.write_arr(arr=None, output_name=out_dispersion, like_filename=cur_dispersion) io.write_arr(arr=None, output_name=out_mean, like_filename=cur_mean) @@ -262,3 +249,30 @@ def run_combine( ) return (out_dispersion, out_mean) + + +def _get_weighting( + num_compressed_slc_files: int, + weight_scheme: WeightScheme, + num_slc: int, +) -> np.ndarray: + # The total number of *mean*/dispersion images that we'll be averaging + # All current real SLCs get averaged into 1, so it's 1 more than the old means + num_images = num_compressed_slc_files + 1 + + if weight_scheme == WeightScheme.LINEAR: + # Increase the weights from older to newer. + N = np.linspace(0, 1, num=num_images) * num_slc + elif weight_scheme == WeightScheme.EQUAL: + # All images get equal weight.
+ N = num_slc * np.ones((num_images,)) + elif weight_scheme == WeightScheme.EXPONENTIAL: + alpha = 0.5 + weights = np.exp(alpha * np.arange(num_images)) + # Normalize weights so that the oldest image has weight 1 + # More recent images count for exponentially more than the oldest + weights /= weights.min() + N = weights.round().astype(int) + else: + raise ValueError(f"Unrecognized {weight_scheme = }") + return N diff --git a/src/disp_s1/main.py b/src/disp_s1/main.py index 4ca6cfc..12f0b8a 100644 --- a/src/disp_s1/main.py +++ b/src/disp_s1/main.py @@ -23,7 +23,7 @@ from disp_s1 import __version__, product from disp_s1._masking import create_layover_shadow_masks, create_mask_from_distance from disp_s1._ps import precompute_ps -from disp_s1.pge_runconfig import AlgorithmParameters, RunConfig, StaticLayersRunConfig +from disp_s1.pge_runconfig import AlgorithmParameters, RunConfig from ._reference import ReferencePoint, read_reference_point from ._utils import ( @@ -39,7 +39,7 @@ @log_runtime def run( cfg: DisplacementWorkflow, - pge_runconfig: RunConfig | StaticLayersRunConfig, + pge_runconfig: RunConfig, debug: bool = False, ) -> None: """Run the displacement workflow on a stack of SLCs. diff --git a/tests/test_ps.py b/tests/test_ps.py new file mode 100644 index 0000000..6c5190b --- /dev/null +++ b/tests/test_ps.py @@ -0,0 +1,21 @@ +import numpy as np + +from disp_s1._ps import WeightScheme, _get_weighting + + +def test_get_weighting(): + expected_ns = [ + [1], + [1, 2], + [1, 2, 3], + [1, 2, 3, 4], + [1, 2, 3, 4, 7], + [1, 2, 3, 4, 7, 12], + ] + for num_comp, expected_N in enumerate(expected_ns): + N = _get_weighting( + num_compressed_slc_files=num_comp, + weight_scheme=WeightScheme.EXPONENTIAL, + num_slc=15, + ) + assert np.allclose(N, expected_N)