diff --git a/saspt/dataset.py b/saspt/dataset.py
index 518c8b1..79cab2b 100644
--- a/saspt/dataset.py
+++ b/saspt/dataset.py
@@ -201,7 +201,14 @@ def raw_track_statistics(self) -> pd.DataFrame:
             pandas.DataFrame, where each row corresponds to one file
         """
         if not hasattr(self, "_raw_track_statistics"):
-            self._raw_track_statistics = self._get_raw_track_statistics()
+            if self.n_files > 0:
+                self._raw_track_statistics = self._get_raw_track_statistics()
+            else:
+                # Set empty stats with expected columns and metadata
+                self._raw_track_statistics = pd.DataFrame(
+                    columns=TrajectoryGroup.statistic_names)
+                for c in self.paths.columns:
+                    self._raw_track_statistics[c] = self.paths[c]
         return self._raw_track_statistics
 
     @property
@@ -217,7 +224,13 @@ def processed_track_statistics(self) -> pd.DataFrame:
             pandas.DataFrame, where each row corresponds to one file
         """
         if not hasattr(self, "_processed_track_statistics"):
-            self._processed_track_statistics = self._get_processed_track_statistics()
+            if self.n_files > 0:
+                self.calc_occs_and_stats_parallelized()
+            else:
+                self._processed_track_statistics = pd.DataFrame(
+                    columns=TrajectoryGroup.statistic_names)
+                for c in self.paths.columns:
+                    self._processed_track_statistics[c] = self.paths[c]
         return self._processed_track_statistics
 
     @property
@@ -234,11 +247,7 @@ def naive_occs(self) -> np.ndarray:
         """
         if not hasattr(self, "_naive_occs"):
             if self.n_files > 0:
-                self._naive_occs = np.asarray(self.parallel_map(
-                    self.calc_naive_occs,
-                    self.paths[self.path_col],
-                    progress_bar=self.progress_bar,
-                ))
+                self.calc_occs_and_stats_parallelized()
             else:
                 self._naive_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64)
         return self._naive_occs
@@ -257,11 +266,7 @@ def posterior_occs(self) -> np.ndarray:
         """
         if not hasattr(self, "_posterior_occs"):
             if self.n_files > 0:
-                self._posterior_occs = np.asarray(self.parallel_map(
-                    self.calc_posterior_occs,
-                    self.paths[self.path_col],
-                    progress_bar=self.progress_bar,
-                ))
+                self.calc_occs_and_stats_parallelized()
             else:
                 self._posterior_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64)
         return self._posterior_occs
@@ -348,6 +353,83 @@ def marginal_posterior_occs_dataframe(self) -> pd.DataFrame:
             self._marginal_posterior_occs_dataframe = df
         return self._marginal_posterior_occs_dataframe
 
+    #############
+    ## METHODS ##
+    #############
+
+    def clear(self):
+        """ Delete expensive cached attributes """
+        for attr in ["_n_files", "_naive_occs", "_posterior_occs",
+            "_processed_track_statistics", "_jumps_per_file"]:
+            if hasattr(self, attr):
+                delattr(self, attr)
+
+    def calc_occs_and_stats_parallelized(self) -> None:
+        """ Calculate the naive occupations, posterior occupations,
+        and processed track statistics for each file in this dataset
+        in a single parallelized pass. Computing all three in one pass
+        guarantees that they are derived from the same subsample of
+        trajectories whenever subsampling is enabled.
+ """ + @dask.delayed + def g(filepath: str) -> Tuple[np.ndarray, np.ndarray, dict]: + SA = self._init_state_array(filepath) + naive_occs = SA.n_jumps * SA.naive_occs + posterior_occs = SA.n_jumps * SA.posterior_occs + stats = SA.trajectories.processed_track_statistics + stats[self.path_col] = filepath + return naive_occs, posterior_occs, stats + + result = self.parallel_map( + g, self.paths[self.path_col], progress_bar=self.progress_bar) + naive_occs = np.asarray([r[0] for r in result]) + posterior_occs = np.asarray([r[1] for r in result]) + stats = [r[2] for r in result] + + # Test for empty stats dict + if not stats: + # Set empty occs + self._naive_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64) + self._posterior_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64) + # Set empty stats with expected columns and metadata + self._processed_track_statistics = pd.DataFrame( + columns=TrajectoryGroup.statistic_names) + for c in self.paths.columns: + self._raw_track_statistics[c] = self.paths[c] + return + + # Put stats into DF and sanity check + stats = pd.DataFrame(stats) + assert (stats[self.path_col] == self.paths[self.path_col]).all() + + # Map all metadata from the input paths DataFrame + # to the track statistics dataframe + for c in filter(lambda c: c!=self.path_col, self.paths.columns): + stats[c] = self.paths[c] + + self._processed_track_statistics = stats + self._naive_occs = np.asarray(naive_occs) + self._posterior_occs = np.asarray(posterior_occs) + + def calc_marginal_posterior_occs(self, *track_paths: str) -> np.ndarray: + """ Calculate the posterior mean state occupations for a particular + set of trajectories, marginalized on diffusion coefficient. + + args + ---- + track_paths : paths to files with trajectories readable + by saspt.utils.load_detections + + returns + ------- + numpy.ndarray of shape *n_diff_coefs*, occupations scaled + by the total number of jumps observed in this set of + trajectories + """ + SA = self._init_state_array(*track_paths) + return self.likelihood.marginalize_on_diff_coef( + SA.n_jumps * SA.posterior_occs) + def infer_posterior_by_condition(self, col: str, normalize: bool=False ) -> Tuple[np.ndarray, List[str]]: """ Aggregate trajectories across files by grouping on an arbitrary @@ -376,81 +458,6 @@ def infer_posterior_by_condition(self, col: str, normalize: bool=False posterior_occs = normalize_2d(posterior_occs, axis=1) return posterior_occs, conditions - ############# - ## METHODS ## - ############# - - def clear(self): - """ Delete expensive cached attributes """ - for attr in ["_n_files", "_naive_occs", "_posterior_occs"]: - if hasattr(self, attr): - delattr(self, attr) - - def calc_naive_occs(self, *track_paths: str) -> np.ndarray: - """ - args - ---- - track_paths : paths to files with trajectories, readable by - saspt.utils.load_detections - - returns - ------- - numpy.ndarray of shape *self.shape*, occupations scaled by the - total number of jumps observed for each SPT experiment - """ - SA = self._init_state_array(*track_paths) - return SA.naive_occs - - def calc_posterior_occs(self, *track_paths: str) -> np.ndarray: - """ - args - ---- - track_paths : paths to files with trajectories, readable by - saspt.utils.load_detections - - returns - ------- - numpy.ndarray of shape *self.shape*, mean posterior occupations - scaled by the total number of jumps observed for each SPT experiment - """ - SA = self._init_state_array(*track_paths) - return SA.n_jumps * SA.posterior_occs - - def calc_marginal_naive_occs(self, 
-        """ Calculate the likelihood function for a particular set of
-        trajectories, marginalized on the diffusion coefficient.
-
-        args
-        ----
-            track_paths :   paths to files with trajectories readable
-                            by saspt.utils.load_detections
-
-        returns
-        -------
-            numpy.ndarray of shape *n_diff_coefs*, occupations scaled by the
-            total number of jumps observed in these trajectories
-        """
-        return self.likelihood.marginalize_on_diff_coef(
-            self.calc_naive_occs(*track_paths))
-
-    def calc_marginal_posterior_occs(self, *track_paths: str) -> np.ndarray:
-        """ Calculate the posterior mean state occupations for a particular
-        set of trajectories, marginalized on diffusion coefficient.
-
-        args
-        ----
-            track_paths :   paths to files with trajectories readable
-                            by saspt.utils.load_detections
-
-        returns
-        -------
-            numpy.ndarray of shape *n_diff_coefs*, occupations scaled
-            by the total number of jumps observed in this set of
-            trajectories
-        """
-        return self.likelihood.marginalize_on_diff_coef(
-            self.calc_posterior_occs(*track_paths))
-
     ##############
     ## PLOTTING ##
     ##############
@@ -599,38 +606,6 @@ def _init_state_array(self, *track_paths: str) -> StateArray:
             StateArray over them
         """
         return StateArray(self._load_tracks(*track_paths), self.likelihood, self.params)
-    def _get_processed_track_statistics(self) -> pd.DataFrame:
-        """ Calculate some statistics on the preprocessed trajectories for each
-        file in this StateArrayDataset.
-
-        returns
-        -------
-            pandas.DataFrame with each row corresponding to one file. Columns
-            correspond to different statistics
-        """
-        @dask.delayed
-        def g(filepath: str) -> dict:
-            T = self._load_tracks(filepath)
-            stats = T.processed_track_statistics
-            stats[self.path_col] = filepath
-            return stats
-        result = pd.DataFrame(self.parallel_map(g, self.paths[self.path_col]))
-
-        # Conceivable that there are zero files in this dataset
-        if len(result) == 0:
-            result[self.path_col] = self.paths[self.path_col]
-            for stat in TrajectoryGroup.statistic_names:
-                result[stat] = pd.Series([], dtype=np.float64)
-
-        # Sanity check
-        assert (result[self.path_col] == self.paths[self.path_col]).all()
-
-        # Map all metadata from the input paths DataFrame to the track statistics dataframe
-        for c in filter(lambda c: c!=self.path_col, self.paths.columns):
-            result[c] = self.paths[c]
-
-        return result
-
     def _get_raw_track_statistics(self) -> pd.DataFrame:
         """ Calculated some statistics on the raw trajectories for each
         file in this StateArrayDataset.
@@ -664,8 +639,8 @@ def g(filepath: str) -> dict:
         return result
 
     def parallel_map(self, func, args, msg: str=None, progress_bar: bool=False):
-        """ Parallelize a function across multiple arguments using a process-based
-        dask scheduler.
+        """ Parallelize a function across multiple arguments using a
+        process-based dask scheduler.
 
         args
         ----
@@ -730,4 +705,4 @@ apply_by(self, col: str, func: Callable, is_variadic: bool=False,
         else:
             result = self.parallel_map(lambda paths: func(paths, **kwargs),
                 file_groups)
-        return result, conditions
+        return result, conditions
\ No newline at end of file
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index d80b60b..7815680 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -98,7 +98,7 @@ def test_marginal_naive_occs(self):
         ML = D.marginal_naive_occs
         assert isinstance(ML, np.ndarray)
         assert ML.shape == (len(self.paths), len(self.likelihood.diff_coefs))
-        assert (np.abs(ML.sum(axis=1) - 1.0) < 1.0e-6).all()
+        assert (np.abs(ML.sum(axis=1) - D.jumps_per_file) < 1.0e-6).all()
 
         # Make sure StateArrayDataset.clear works
         D.clear()
@@ -214,3 +214,40 @@ def test_posterior_line_plot(self):
             condition_col=self.condition_col)
         self.check_plot_func(D.posterior_line_plot,
             "_out_test_posterior_line_plot.png")
+
+    def test_subsampling(self):
+        # New params with a smaller sample size
+        sample_size = 10
+        params = StateArrayParameters(
+            pixel_size_um=0.16,
+            frame_interval=0.01,
+            focal_depth=0.7,
+            splitsize=10,
+            sample_size=sample_size,
+            start_frame=0,
+            max_iter=10,
+            conc_param=1.0,
+            progress_bar=False,
+            num_workers=2,
+        )
+        self.params = params
+        D = StateArrayDataset(self.paths, self.likelihood,
+            params=self.params, path_col=self.path_col,
+            condition_col=self.condition_col)
+
+        # Check that jumps_per_file and implied jumps are correct
+        assert np.allclose(D.jumps_per_file.astype(float), D.posterior_occs.sum(axis=(1,2)))
+        assert np.allclose(D.jumps_per_file.astype(float), D.naive_occs.sum(axis=(1,2)))
+
+        # Check that subsampling actually worked
+        n_trajs = D.processed_track_statistics['n_tracks']
+        assert (n_trajs <= sample_size).all()
+
+        # Clear and repeat tests
+        D.clear()
+        assert np.allclose(D.jumps_per_file.astype(float), D.posterior_occs.sum(axis=(1,2)))
+        assert np.allclose(D.jumps_per_file.astype(float), D.naive_occs.sum(axis=(1,2)))
+
+        # Check that subsampling actually worked
+        n_trajs = D.processed_track_statistics['n_tracks']
+        assert (n_trajs <= sample_size).all()
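
Reviewer note, not part of the patch: the Python sketch below restates the caching contract that calc_occs_and_stats_parallelized establishes and that test_subsampling exercises. Here `D` is a hypothetical StateArrayDataset built as in the test fixture above (sample_size=10); the assertions mirror the ones in the test rather than introducing new behavior.

    import numpy as np

    sample_size = 10  # hypothetical; matches the test fixture above

    # Accessing any one of these three attributes now runs a single
    # parallel pass (calc_occs_and_stats_parallelized) that caches all
    # three, so they are derived from the same subsample of trajectories.
    naive = D.naive_occs
    posterior = D.posterior_occs
    stats = D.processed_track_statistics

    # Both occupation arrays are scaled by the number of jumps per file,
    # so they sum to jumps_per_file rather than to 1
    assert np.allclose(D.jumps_per_file.astype(float), naive.sum(axis=(1, 2)))
    assert np.allclose(D.jumps_per_file.astype(float), posterior.sum(axis=(1, 2)))

    # Subsampling is applied consistently across all three attributes
    assert (stats["n_tracks"] <= sample_size).all()

    # clear() drops the related caches together, so the next access
    # recomputes everything in one pass rather than piecemeal
    D.clear()

The design point: before this patch, naive_occs, posterior_occs, and _get_processed_track_statistics each ran their own parallel_map pass, so with sample_size set each pass could draw a different subsample, leaving the reported track statistics inconsistent with the occupations they were meant to describe.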