
Fix subsampling issues within StateArrayDataset #12

Open · wants to merge 13 commits into main
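Previously, naive_occs, posterior_occs, and processed_track_statistics were each computed in a separate parallel_map pass over the input files. When sample_size triggered trajectory subsampling, each pass could draw a different subsample, so the cached attributes could describe different trajectories. This change computes all three quantities in a single parallelized pass (calc_occs_and_stats_parallelized), so they always describe the same, possibly subsampled, set of trajectories, and the naive occupations are now scaled by the number of jumps per file, matching the posterior occupations.

A minimal usage sketch of the intended behavior (parameter values mirror the new test; the import path and the construction of paths, likelihood, path_col, and condition_col are assumed here, not part of this diff):

import numpy as np
from saspt import StateArrayDataset, StateArrayParameters  # import path assumed

sample_size = 10
params = StateArrayParameters(
    pixel_size_um=0.16, frame_interval=0.01, focal_depth=0.7,
    splitsize=10, sample_size=sample_size, start_frame=0,
    max_iter=10, conc_param=1.0, progress_bar=False, num_workers=2,
)
D = StateArrayDataset(paths, likelihood, params=params,
    path_col=path_col, condition_col=condition_col)

# One parallelized pass now populates all three cached attributes, so the
# implied jump counts agree and per-file track counts respect sample_size:
assert np.allclose(D.jumps_per_file.astype(float), D.naive_occs.sum(axis=(1, 2)))
assert np.allclose(D.jumps_per_file.astype(float), D.posterior_occs.sum(axis=(1, 2)))
assert (D.processed_track_statistics["n_tracks"] <= sample_size).all()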
219 changes: 97 additions & 122 deletions saspt/dataset.py
@@ -201,7 +201,14 @@ def raw_track_statistics(self) -> pd.DataFrame:
pandas.DataFrame, where each row corresponds to one file
"""
if not hasattr(self, "_raw_track_statistics"):
self._raw_track_statistics = self._get_raw_track_statistics()
if self.n_files > 0:
self._raw_track_statistics = self._get_raw_track_statistics()
else:
# Set empty stats with expected columns and metadata
self._raw_track_statistics = pd.DataFrame(
columns=TrajectoryGroup.statistic_names)
for c in self.paths.columns:
self._raw_track_statistics[c] = self.paths[c]
return self._raw_track_statistics

@property
@@ -217,7 +224,13 @@ def processed_track_statistics(self) -> pd.DataFrame:
pandas.DataFrame, where each row corresponds to one file
"""
if not hasattr(self, "_processed_track_statistics"):
self._processed_track_statistics = self._get_processed_track_statistics()
if self.n_files > 0:
self.calc_occs_and_stats_parallelized()
else:
self._processed_track_statistics = pd.DataFrame(
columns=TrajectoryGroup.statistic_names)
for c in self.paths.columns:
self._processed_track_statistics[c] = self.paths[c]
return self._processed_track_statistics

@property
@@ -234,11 +247,7 @@ def naive_occs(self) -> np.ndarray:
"""
if not hasattr(self, "_naive_occs"):
if self.n_files > 0:
self._naive_occs = np.asarray(self.parallel_map(
self.calc_naive_occs,
self.paths[self.path_col],
progress_bar=self.progress_bar,
))
self.calc_occs_and_stats_parallelized()
else:
self._naive_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64)
return self._naive_occs
@@ -257,11 +266,7 @@ def posterior_occs(self) -> np.ndarray:
"""
if not hasattr(self, "_posterior_occs"):
if self.n_files > 0:
self._posterior_occs = np.asarray(self.parallel_map(
self.calc_posterior_occs,
self.paths[self.path_col],
progress_bar=self.progress_bar,
))
self.calc_occs_and_stats_parallelized()
else:
self._posterior_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64)
return self._posterior_occs
@@ -348,6 +353,83 @@ def marginal_posterior_occs_dataframe(self) -> pd.DataFrame:
self._marginal_posterior_occs_dataframe = df
return self._marginal_posterior_occs_dataframe

#############
## METHODS ##
#############

def clear(self):
""" Delete expensive cached attributes """
for attr in ["_n_files", "_naive_occs", "_posterior_occs",
"_processed_track_statistics", "_jumps_per_file"]:
if hasattr(self, attr):
delattr(self, attr)

def calc_occs_and_stats_parallelized(self) -> None:
""" Calculate naive occupations, posterior occupations, and processed
track statistics in a single parallelized pass over the files in this
dataset. Deriving all three from the same StateArray per file means
that any trajectory subsampling is applied consistently to all three.
Results are cached as the _naive_occs, _posterior_occs, and
_processed_track_statistics attributes.
"""
@dask.delayed
def g(filepath: str) -> Tuple[np.ndarray, np.ndarray, dict]:
SA = self._init_state_array(filepath)
naive_occs = SA.n_jumps * SA.naive_occs
posterior_occs = SA.n_jumps * SA.posterior_occs
stats = SA.trajectories.processed_track_statistics
stats[self.path_col] = filepath
return naive_occs, posterior_occs, stats

result = self.parallel_map(
g, self.paths[self.path_col], progress_bar=self.progress_bar)
naive_occs = np.asarray([r[0] for r in result])
posterior_occs = np.asarray([r[1] for r in result])
stats = [r[2] for r in result]

# Handle the case of zero input files (empty results list)
if not stats:
# Set empty occs
self._naive_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64)
self._posterior_occs = np.zeros((self.n_files, *self.shape), dtype=np.float64)
# Set empty stats with expected columns and metadata
self._processed_track_statistics = pd.DataFrame(
columns=TrajectoryGroup.statistic_names)
for c in self.paths.columns:
self._processed_track_statistics[c] = self.paths[c]
return

# Put stats into DF and sanity check
stats = pd.DataFrame(stats)
assert (stats[self.path_col] == self.paths[self.path_col]).all()

# Map all metadata from the input paths DataFrame
# to the track statistics dataframe
for c in filter(lambda c: c!=self.path_col, self.paths.columns):
stats[c] = self.paths[c]

self._processed_track_statistics = stats
self._naive_occs = np.asarray(naive_occs)
self._posterior_occs = np.asarray(posterior_occs)

def calc_marginal_posterior_occs(self, *track_paths: str) -> np.ndarray:
""" Calculate the posterior mean state occupations for a particular
set of trajectories, marginalized on diffusion coefficient.

args
----
track_paths : paths to files with trajectories readable
by saspt.utils.load_detections

returns
-------
numpy.ndarray of shape *n_diff_coefs*, occupations scaled
by the total number of jumps observed in this set of
trajectories
"""
SA = self._init_state_array(*track_paths)
return self.likelihood.marginalize_on_diff_coef(
SA.n_jumps * SA.posterior_occs)

def infer_posterior_by_condition(self, col: str, normalize: bool=False
) -> Tuple[np.ndarray, List[str]]:
""" Aggregate trajectories across files by grouping on an arbitrary
@@ -376,81 +458,6 @@ def infer_posterior_by_condition(self, col: str, normalize: bool=False
posterior_occs = normalize_2d(posterior_occs, axis=1)
return posterior_occs, conditions

#############
## METHODS ##
#############

def clear(self):
""" Delete expensive cached attributes """
for attr in ["_n_files", "_naive_occs", "_posterior_occs"]:
if hasattr(self, attr):
delattr(self, attr)

def calc_naive_occs(self, *track_paths: str) -> np.ndarray:
"""
args
----
track_paths : paths to files with trajectories, readable by
saspt.utils.load_detections

returns
-------
numpy.ndarray of shape *self.shape*, occupations scaled by the
total number of jumps observed for each SPT experiment
"""
SA = self._init_state_array(*track_paths)
return SA.naive_occs

def calc_posterior_occs(self, *track_paths: str) -> np.ndarray:
"""
args
----
track_paths : paths to files with trajectories, readable by
saspt.utils.load_detections

returns
-------
numpy.ndarray of shape *self.shape*, mean posterior occupations
scaled by the total number of jumps observed for each SPT experiment
"""
SA = self._init_state_array(*track_paths)
return SA.n_jumps * SA.posterior_occs

def calc_marginal_naive_occs(self, *track_paths: str) -> np.ndarray:
""" Calculate the likelihood function for a particular set of
trajectories, marginalized on the diffusion coefficient.

args
----
track_paths : paths to files with trajectories readable
by saspt.utils.load_detections

returns
-------
numpy.ndarray of shape *n_diff_coefs*, occupations scaled by the
total number of jumps observed in these trajectories
"""
return self.likelihood.marginalize_on_diff_coef(
self.calc_naive_occs(*track_paths))

def calc_marginal_posterior_occs(self, *track_paths: str) -> np.ndarray:
""" Calculate the posterior mean state occupations for a particular
set of trajectories, marginalized on diffusion coefficient.

args
----
track_paths : paths to files with trajectories readable
by saspt.utils.load_detections

returns
-------
numpy.ndarray of shape *n_diff_coefs*, occupations scaled
by the total number of jumps observed in this set of
trajectories
"""
return self.likelihood.marginalize_on_diff_coef(
self.calc_posterior_occs(*track_paths))

##############
## PLOTTING ##
##############
@@ -599,38 +606,6 @@ def _init_state_array(self, *track_paths: str) -> StateArray:
StateArray over them """
return StateArray(self._load_tracks(*track_paths), self.likelihood, self.params)

def _get_processed_track_statistics(self) -> pd.DataFrame:
""" Calculate some statistics on the preprocessed trajectories for each
file in this StateArrayDataset.

returns
-------
pandas.DataFrame with each row corresponding to one file. Columns
correspond to different statistics
"""
@dask.delayed
def g(filepath: str) -> dict:
T = self._load_tracks(filepath)
stats = T.processed_track_statistics
stats[self.path_col] = filepath
return stats
result = pd.DataFrame(self.parallel_map(g, self.paths[self.path_col]))

# Conceivable that there are zero files in this dataset
if len(result) == 0:
result[self.path_col] = self.paths[self.path_col]
for stat in TrajectoryGroup.statistic_names:
result[stat] = pd.Series([], dtype=np.float64)

# Sanity check
assert (result[self.path_col] == self.paths[self.path_col]).all()

# Map all metadata from the input paths DataFrame to the track statistics dataframe
for c in filter(lambda c: c!=self.path_col, self.paths.columns):
result[c] = self.paths[c]

return result

def _get_raw_track_statistics(self) -> pd.DataFrame:
""" Calculated some statistics on the raw trajectories for each file in
this StateArrayDataset.
@@ -664,8 +639,8 @@ def g(filepath: str) -> dict:
return result

def parallel_map(self, func, args, msg: str=None, progress_bar: bool=False):
""" Parallelize a function across multiple arguments using a process-based
dask scheduler.
""" Parallelize a function across multiple arguments using a
process-based dask scheduler.

args
----
@@ -730,4 +705,4 @@ def apply_by(self, col: str, func: Callable, is_variadic: bool=False,
else:
result = self.parallel_map(lambda paths: func(paths, **kwargs), file_groups)

return result, conditions
return result, conditions
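The test change below reflects the new scaling convention: rows of marginal_naive_occs now sum to the number of jumps in each file (jumps_per_file) rather than to 1. A short sketch of recovering normalized occupations from the scaled ones (D is assumed to be a populated StateArrayDataset; plain numpy is used here rather than any saspt helper):

import numpy as np

scaled = D.marginal_naive_occs                  # shape (n_files, n_diff_coefs)
row_sums = scaled.sum(axis=1, keepdims=True)    # equals jumps_per_file, row by row
normalized = np.divide(scaled, row_sums,
    out=np.zeros_like(scaled), where=row_sums > 0)
# rows with at least one jump now sum to 1 again, matching the old convention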
39 changes: 38 additions & 1 deletion tests/test_dataset.py
@@ -98,7 +98,7 @@ def test_marginal_naive_occs(self):
ML = D.marginal_naive_occs
assert isinstance(ML, np.ndarray)
assert ML.shape == (len(self.paths), len(self.likelihood.diff_coefs))
assert (np.abs(ML.sum(axis=1) - 1.0) < 1.0e-6).all()
assert (np.abs(ML.sum(axis=1) - D.jumps_per_file) < 1.0e-6).all()

# Make sure StateArrayDataset.clear works
D.clear()
@@ -214,3 +214,40 @@ def test_posterior_line_plot(self):
condition_col=self.condition_col)
self.check_plot_func(D.posterior_line_plot,
"_out_test_posterior_line_plot.png")

def test_subsampling(self):
# New params with a smaller sample size
sample_size = 10
params = StateArrayParameters(
pixel_size_um=0.16,
frame_interval=0.01,
focal_depth=0.7,
splitsize=10,
sample_size=sample_size,
start_frame=0,
max_iter=10,
conc_param=1.0,
progress_bar=False,
num_workers=2,
)
self.params = params
D = StateArrayDataset(self.paths, self.likelihood,
params=self.params, path_col=self.path_col,
condition_col=self.condition_col)

# Check that jumps_per_file and implied jumps are correct
assert np.allclose(D.jumps_per_file.astype(float), D.posterior_occs.sum(axis=(1,2)))
assert np.allclose(D.jumps_per_file.astype(float), D.naive_occs.sum(axis=(1,2)))

# Check that subsampling actually worked
n_trajs = D.processed_track_statistics['n_tracks']
assert (n_trajs <= sample_size).all()

# Clear and repeat tests
D.clear()
assert np.allclose(D.jumps_per_file.astype(float), D.posterior_occs.sum(axis=(1,2)))
assert np.allclose(D.jumps_per_file.astype(float), D.naive_occs.sum(axis=(1,2)))

# Check that subsampling actually worked
n_trajs = D.processed_track_statistics['n_tracks']
assert (n_trajs <= sample_size).all()
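For reference, the per-condition aggregation shown in the dataset.py diff above can be exercised on the same dataset; a minimal sketch (the column name "condition" is illustrative, and D is the dataset constructed in test_subsampling):

occs_by_condition, conditions = D.infer_posterior_by_condition("condition", normalize=True)
# one row of occs_by_condition per entry in conditions; with normalize=True each
# row is rescaled to sum to 1 via normalize_2d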