diff --git a/CHANGELOG.md b/CHANGELOG.md
index 958d99a..8c1bf30 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,13 @@
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [PEP 440](https://www.python.org/dev/peps/pep-0440/) and uses [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## [0.0.11] - 2025-05-22
+
+### Added
+- Updated workflows.py, processing.py, and runconfig_model.py to accept the
+  stride_for_norm_param_estimation, batch_size_for_norm_param_estimation, and
+  optimize parameters.
+
 ## [0.0.10] - 2025-05-19
 
 ### Added
diff --git a/src/dist_s1/__main__.py b/src/dist_s1/__main__.py
index 2647701..6b98845 100644
--- a/src/dist_s1/__main__.py
+++ b/src/dist_s1/__main__.py
@@ -158,6 +158,29 @@ def common_options(func: Callable) -> Callable:
         required=False,
         help='Path to Transformer model weights file.',
     )
+    @click.option(
+        '--stride_for_norm_param_estimation',
+        type=int,
+        default=16,
+        required=False,
+        help='Stride for norm param estimation; number of pixels the'
+        ' convolutional filter moves across the input image at'
+        ' each step.',
+    )
+    @click.option(
+        '--batch_size_for_norm_param_estimation',
+        type=int,
+        default=32,
+        required=False,
+        help='Batch size for norm param estimation; tune it according to available resources, i.e. memory.',
+    )
+    @click.option(
+        '--optimize',
+        type=bool,
+        default=True,
+        required=False,
+        help='Flag to enable model compilation during execution.',
+    )
     @functools.wraps(func)
     def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
         return func(*args, **kwargs)
@@ -200,6 +223,9 @@ def run_sas_prep(
     model_source: str | None,
     model_cfg_path: str | Path | None,
     model_wts_path: str | Path | None,
+    stride_for_norm_param_estimation: int = 16,
+    batch_size_for_norm_param_estimation: int = 32,
+    optimize: bool = True,
 ) -> None:
     """Run SAS prep workflow."""
     run_config = run_dist_s1_sas_prep_workflow(
@@ -226,6 +252,9 @@ def run_sas_prep(
         model_source=model_source,
         model_cfg_path=model_cfg_path,
         model_wts_path=model_wts_path,
+        stride_for_norm_param_estimation=stride_for_norm_param_estimation,
+        batch_size_for_norm_param_estimation=batch_size_for_norm_param_estimation,
+        optimize=optimize,
     )
 
     run_config.to_yaml(runconfig_path)
@@ -266,6 +295,9 @@ def run(
     model_source: str | None,
     model_cfg_path: str | Path | None,
     model_wts_path: str | Path | None,
+    stride_for_norm_param_estimation: int = 16,
+    batch_size_for_norm_param_estimation: int = 32,
+    optimize: bool = True,
 ) -> str:
     """Localize data and run dist_s1_workflow."""
     return run_dist_s1_workflow(
@@ -292,6 +324,9 @@ def run(
         model_source=model_source,
         model_cfg_path=model_cfg_path,
         model_wts_path=model_wts_path,
+        stride_for_norm_param_estimation=stride_for_norm_param_estimation,
+        batch_size_for_norm_param_estimation=batch_size_for_norm_param_estimation,
+        optimize=optimize,
    )
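The three new `common_options` entries above are plain click options rather than boolean flags, so `--optimize` takes an explicit value on the command line. Below is a minimal, self-contained sketch (a toy command, not part of dist-s1; the `show` command name is made up) that mirrors the option declarations from this diff to illustrate how the values parse:

```python
# Illustrative toy CLI mirroring the new dist-s1 options; only the option names,
# types, and defaults are taken from the diff above.
import click


@click.command()
@click.option('--stride_for_norm_param_estimation', type=int, default=16)
@click.option('--batch_size_for_norm_param_estimation', type=int, default=32)
@click.option('--optimize', type=bool, default=True)
def show(stride_for_norm_param_estimation: int, batch_size_for_norm_param_estimation: int, optimize: bool) -> None:
    # Echo the parsed values so the defaults/overrides are visible.
    click.echo((stride_for_norm_param_estimation, batch_size_for_norm_param_estimation, optimize))


if __name__ == '__main__':
    # e.g. `python toy_cli.py --stride_for_norm_param_estimation 4 --optimize False`
    # click's BOOL type accepts true/false, yes/no, 1/0 and similar (case-insensitive).
    show()
```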
diff --git a/src/dist_s1/data_models/runconfig_model.py b/src/dist_s1/data_models/runconfig_model.py
index 68fc0e4..724bdc2 100644
--- a/src/dist_s1/data_models/runconfig_model.py
+++ b/src/dist_s1/data_models/runconfig_model.py
@@ -133,14 +133,25 @@ class RunConfigData(BaseModel):
         pattern='^(high|low)$',
     )
     tqdm_enabled: bool = Field(default=True)
-    batch_size_for_despeckling: int = Field(
-        default=25,
-        ge=1,
-    )
     n_workers_for_norm_param_estimation: int = Field(
         default=8,
         ge=1,
     )
+    # Batch size for the transformer model.
+    batch_size_for_norm_param_estimation: int = Field(
+        default=32,
+        ge=1,
+    )
+    # Stride for the transformer model.
+    stride_for_norm_param_estimation: int = Field(
+        default=16,
+        ge=1,
+        le=16,
+    )
+    batch_size_for_despeckling: int = Field(
+        default=25,
+        ge=1,
+    )
     n_workers_for_despeckling: int = Field(
         default=8,
         ge=1,
@@ -149,6 +160,11 @@ class RunConfigData(BaseModel):
     # This is where default thresholds are set!
     moderate_confidence_threshold: float = Field(default=3.5, ge=0.0, le=15.0)
     high_confidence_threshold: float = Field(default=5.5, ge=0.0, le=15.0)
+
+    # Flag to enable optimizations: when False, load the model and use it as-is;
+    # when True, load the model and compile it for CPU or GPU.
+    optimize: bool = Field(default=True)
+
     product_dst_dir: Path | str | None = None
     bucket: str | None = None
     bucket_prefix: str = ''
diff --git a/src/dist_s1/processing.py b/src/dist_s1/processing.py
index 946607d..c6ccfcf 100644
--- a/src/dist_s1/processing.py
+++ b/src/dist_s1/processing.py
@@ -53,7 +53,10 @@ def compute_normal_params_per_burst_and_serialize(
     out_path_sigma_copol: Path,
     out_path_sigma_crosspol: Path,
     memory_strategy: str = 'high',
+    stride: int = 2,
+    batch_size: int = 32,
     device: str = 'best',
+    optimize: bool = False,
     model_source: str | None = None,
     model_cfg_path: Path | None = None,
     model_wts_path: Path | None = None,
@@ -63,12 +66,14 @@ def compute_normal_params_per_burst_and_serialize(
     # For distmetrics, None is how we choose the "best" available device
     if device == 'best':
         device = None
+
     if model_source == 'external':
         model = load_transformer_model(
-            model_token=model_source, model_cfg_path=model_cfg_path, model_wts_path=model_wts_path, device=device
+            model_token=model_source, model_cfg_path=model_cfg_path, model_wts_path=model_wts_path,
+            device=device, optimize=optimize, batch_size=batch_size
         )
     else:
-        model = load_transformer_model(device=device)
+        model = load_transformer_model(device=device, optimize=optimize, batch_size=batch_size)
 
     copol_data = [open_one_ds(path) for path in pre_copol_paths_dskpl_paths]
     crosspol_data = [open_one_ds(path) for path in pre_crosspol_paths_dskpl_paths]
@@ -83,7 +88,8 @@ def compute_normal_params_per_burst_and_serialize(
     check_profiles_match(p_ref, p_crosspol)
 
     logits_mu, logits_sigma = estimate_normal_params_of_logits(
-        model, arrs_copol, arrs_crosspol, memory_strategy=memory_strategy, device=device
+        model, arrs_copol, arrs_crosspol, memory_strategy=memory_strategy, device=device, stride=stride,
+        batch_size=batch_size,
     )
     logits_mu_copol, logits_mu_crosspol = logits_mu[0, ...], logits_mu[1, ...]
     logits_sigma_copol, logits_sigma_crosspol = logits_sigma[0, ...], logits_sigma[1, ...]
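For reference, the new pydantic fields constrain the stride to the 1–16 range and the batch size to be at least 1, with compilation enabled by default. A standalone sketch (a stripped-down stand-in for `RunConfigData`, using the hypothetical name `_NormParamConfig`) showing how those constraints behave:

```python
# Minimal sketch of the new Field constraints; only the field names, defaults,
# and bounds come from the diff above.
from pydantic import BaseModel, Field, ValidationError


class _NormParamConfig(BaseModel):
    batch_size_for_norm_param_estimation: int = Field(default=32, ge=1)
    stride_for_norm_param_estimation: int = Field(default=16, ge=1, le=16)
    optimize: bool = Field(default=True)


# Defaults: batch_size=32, stride=16, optimize=True.
print(_NormParamConfig())

try:
    _NormParamConfig(stride_for_norm_param_estimation=32)
except ValidationError as err:
    # A stride larger than the 16-pixel upper bound is rejected.
    print(err.errors()[0]['msg'])
```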
diff --git a/src/dist_s1/workflows.py b/src/dist_s1/workflows.py
index ced3016..2936670 100644
--- a/src/dist_s1/workflows.py
+++ b/src/dist_s1/workflows.py
@@ -181,8 +181,16 @@ def run_despeckle_workflow(run_config: RunConfigData) -> None:
 
 
 def _process_normal_params(
-    path_data: dict, memory_strategy: str, device: str, model_source: str, model_cfg_path: Path, model_wts_path: Path
-) -> None:
+    path_data: dict,
+    memory_strategy: str,
+    device: str,
+    model_source: str,
+    model_cfg_path: Path,
+    model_wts_path: Path,
+    batch_size: int,
+    stride: int,
+    optimize: bool,
+) -> None:
     return compute_normal_params_per_burst_and_serialize(
         path_data['copol_paths_pre'],
         path_data['crosspol_paths_pre'],
@@ -195,6 +203,9 @@ def _process_normal_params(
         model_source=model_source,
         model_cfg_path=model_cfg_path,
         model_wts_path=model_wts_path,
+        batch_size=batch_size,
+        stride=stride,
+        optimize=optimize,
     )
 
 
@@ -239,6 +250,9 @@ def run_normal_param_estimation_workflow(run_config: RunConfigData) -> None:
             model_source=run_config.model_source,
             model_cfg_path=run_config.model_cfg_path,
             model_wts_path=run_config.model_wts_path,
+            stride=run_config.stride_for_norm_param_estimation,
+            batch_size=run_config.batch_size_for_norm_param_estimation,
+            optimize=run_config.optimize,
         )
     else:
         if run_config.device in ('cuda', 'mps'):
@@ -252,6 +266,9 @@ def run_normal_param_estimation_workflow(run_config: RunConfigData) -> None:
             model_source=run_config.model_source,
            model_cfg_path=run_config.model_cfg_path,
            model_wts_path=run_config.model_wts_path,
+            stride=run_config.stride_for_norm_param_estimation,
+            batch_size=run_config.batch_size_for_norm_param_estimation,
+            optimize=run_config.optimize,
         )
 
     # Start a pool of workers
@@ -397,6 +414,9 @@ def run_dist_s1_sas_prep_workflow(
     model_source: str | None = None,
     model_cfg_path: str | Path | None = None,
     model_wts_path: str | Path | None = None,
+    stride_for_norm_param_estimation: int = 16,
+    batch_size_for_norm_param_estimation: int = 32,
+    optimize: bool = True,
 ) -> RunConfigData:
     run_config = run_dist_s1_localization_workflow(
         mgrs_tile_id,
@@ -425,6 +445,9 @@ def run_dist_s1_sas_prep_workflow(
     run_config.model_source = model_source
     run_config.model_cfg_path = model_cfg_path
     run_config.model_wts_path = model_wts_path
+    run_config.stride_for_norm_param_estimation = stride_for_norm_param_estimation
+    run_config.batch_size_for_norm_param_estimation = batch_size_for_norm_param_estimation
+    run_config.optimize = optimize
 
     return run_config
@@ -462,6 +485,9 @@ def run_dist_s1_workflow(
     model_source: str | None = None,
     model_cfg_path: str | Path | None = None,
     model_wts_path: str | Path | None = None,
+    stride_for_norm_param_estimation: int = 16,
+    batch_size_for_norm_param_estimation: int = 32,
+    optimize: bool = True,
 ) -> Path:
     run_config = run_dist_s1_sas_prep_workflow(
         mgrs_tile_id,
@@ -487,6 +513,9 @@ def run_dist_s1_workflow(
         model_source=model_source,
         model_cfg_path=model_cfg_path,
         model_wts_path=model_wts_path,
+        stride_for_norm_param_estimation=stride_for_norm_param_estimation,
+        batch_size_for_norm_param_estimation=batch_size_for_norm_param_estimation,
+        optimize=optimize,
     )
 
     _ = run_dist_s1_sas_workflow(run_config)
@@ -543,6 +572,9 @@ def run_dist_s1_sas_prep_runconfig_yml(run_config_template_yml_path: Path | str)
     batch_size_for_despeckling = rc_data.get('batch_size_for_despeckling', 25)
     n_workers_for_norm_param_estimation = rc_data.get('n_workers_for_norm_param_estimation', 1)
     device = rc_data.get('device', 'cpu')
+    stride_for_norm_param_estimation = rc_data.get('stride_for_norm_param_estimation', 16)
+    batch_size_for_norm_param_estimation = rc_data.get('batch_size_for_norm_param_estimation', 32)
+    optimize = rc_data.get('optimize', True)
 
     run_config = run_dist_s1_localization_workflow(
         mgrs_tile_id,
@@ -572,5 +604,8 @@ def run_dist_s1_sas_prep_runconfig_yml(run_config_template_yml_path: Path | str)
     run_config.model_cfg_path = model_cfg_path
     run_config.model_wts_path = model_wts_path
     run_config.device = device
+    run_config.stride_for_norm_param_estimation = stride_for_norm_param_estimation
+    run_config.batch_size_for_norm_param_estimation = batch_size_for_norm_param_estimation
+    run_config.optimize = optimize
 
     return run_config
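A sketch of how `run_dist_s1_sas_prep_runconfig_yml` is expected to pick the new fields out of a runconfig YAML. Only the three key names and their fallback defaults come from the diff above; the `run_config:` nesting and the direct `yaml.safe_load` call here are illustrative assumptions, not the actual SAS runconfig schema.

```python
# Illustrative only: mirrors the rc_data.get(...) calls added to workflows.py.
import yaml

rc_yaml = """
run_config:
  stride_for_norm_param_estimation: 16
  batch_size_for_norm_param_estimation: 32
  optimize: true
"""

rc_data = yaml.safe_load(rc_yaml)['run_config']

stride_for_norm_param_estimation = rc_data.get('stride_for_norm_param_estimation', 16)
batch_size_for_norm_param_estimation = rc_data.get('batch_size_for_norm_param_estimation', 32)
optimize = rc_data.get('optimize', True)
print(stride_for_norm_param_estimation, batch_size_for_norm_param_estimation, optimize)
```

If a key is absent from the YAML, the `.get(...)` fallbacks reproduce the same defaults as the CLI options and the `RunConfigData` fields (16, 32, and True).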