From d0265bb37bcf25558d5f359076daffece00032d3 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 22 Dec 2023 17:51:34 +0900 Subject: [PATCH 01/47] refactor(Zettaset): remove mandatory `zettaset_path` requirement --- deepem/train/option.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepem/train/option.py b/deepem/train/option.py index a0aec25..c30af48 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -25,7 +25,7 @@ def initialize(self): self.parser.add_argument('--modifier_kwargs', type=json.loads, default={}) # zettasets - self.parser.add_argument('--zettaset_path', required=True, type=str, default=[], nargs='+') + self.parser.add_argument('--zettaset_path', type=str, default=[], nargs='+') self.parser.add_argument('--zettaset_lookup', type=json.loads, default=None) self.parser.add_argument('--zettaset_padding', type=vec3, default=(0, 0, 0)) self.parser.add_argument('--zettaset_no_mask', action='store_true') From 0a9bb822335504ffd7033649e364200db470eaf8 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 22 Dec 2023 17:52:41 +0900 Subject: [PATCH 02/47] feat(Zettaset): add zettaset-specific specs for separate and detailed specification of each zettaset --- deepem/data/dataset/multi_zettaset.py | 203 ++++++++++++++++++++++++++ deepem/data/dataset/zettaset.py | 31 ++-- deepem/train/option.py | 7 +- 3 files changed, 228 insertions(+), 13 deletions(-) create mode 100644 deepem/data/dataset/multi_zettaset.py diff --git a/deepem/data/dataset/multi_zettaset.py b/deepem/data/dataset/multi_zettaset.py new file mode 100644 index 0000000..4a70450 --- /dev/null +++ b/deepem/data/dataset/multi_zettaset.py @@ -0,0 +1,203 @@ +from __future__ import annotations + +import numpy as np +from numpy.typing import ArrayLike +import re + +from cloudvolume import CloudVolume, Bbox + +from zettasets.dataset import Dataset as Zettaset +from zettasets.sample import Sample + + +def is_valid_format(s: str) -> bool: + """Check if string has format 'A:B'.""" + return bool(re.match(r'^[^:]+:[^:]+$', s)) + + +def load_data( + zettaset_specs: dict[str, dict], + data_ids: list[str] | None = None, + **kwargs, +) -> dict[str, dict[str, np.ndarray]]: + """ + Load data from a zettaset. + + Parameters: + - zettaset_specs (dict[str, dict]): Specifications for the zettasets to load. + Each key is the name of the zettaset, and the value is a dictionary of + its specifications. + - data_ids (list[str], optional): Specific data identifiers to load. Defaults to None. + + Returns: + - dict[str, dict[str, np.ndarray]]: A dictionary containing the loaded zettaset data. + + Raises: + - ValueError: If `zettaset_specs` is empty or contains invalid specifications. + - KeyError: If any data id is not found in the loaded zettasets. + """ + + if not zettaset_specs: + raise ValueError("No zettaset specifications provided.") + + zettasets = _initialize_zettasets(zettaset_specs, **kwargs) + + data = {} + for data_id in data_ids or []: + if data_id in zettasets: + data.update(_process_zettaset(data_id, zettasets, zettaset_specs, **kwargs)) + else: + data.update(_process_sample(data_id, zettasets, zettaset_specs, **kwargs)) + + return data + + +def _initialize_zettasets( + zettaset_specs, + zettaset_resolution: tuple[float, float, float] | None = None, + **kwargs, +) -> dict[str, Zettaset]: + zettasets = {} + for name, spec in zettaset_specs.items(): + if "path" not in spec or not spec["path"].startswith("gs://"): + raise ValueError(f"Invalid zettaset specification for '{name}': missing or invalid 'path'.") + zettaset_path = spec["path"] + print(f"Zettaset {name} [{zettaset_path}]") + zettasets[name] = Zettaset(zettaset_path, "", zettaset_resolution) + return zettasets + + +def _process_zettaset( + zettaset_name: str, + zettasets: dict[str, Zettaset], + zettaset_specs: dict[str, dict], + **kwargs, +) -> dict[str, dict[str, dict[str, np.ndarray]]]: + """Processes a whole zettaset.""" + if zettaset_name not in zettasets: + raise KeyError(f"Zettaset '{zettaset_name}' not found.") + + data = {} + zettaset = zettasets[zettaset_name] + + for sample_name in zettaset.sample_names: + full_name = f"{zettaset_name}:{sample_name}" + data.update(_process_sample(full_name, zettasets, zettaset_specs, **kwargs)) + + return {zettaset_name: data} if data else {} + + +def _process_sample( + data_id: str, + zettasets: dict[str, Zettaset], + zettaset_specs: dict[str, dict], + zettaset_padding: tuple[int, int, int] = (0, 0, 0), + zettaset_padding_spec: dict[str, tuple[int, int, int]] = {}, + zettaset_mask: bool = True, + **kwargs, +) -> dict[str, dict[str, np.ndarray]]: + if not is_valid_format(data_id): + raise ValueError(f"Invalid format for data id '{data_id}'. Expected format 'A:B'.") + + zettaset_name, sample_name = data_id.split(':') + + if zettaset_name not in zettasets: + raise KeyError(f"Zettaset '{zettaset_name}' not found.") + + zettaset = zettasets[zettaset_name] + + if sample_name not in zettaset.sample_names: + raise KeyError(f"Sample name '{sample_name}' not found in zettaset '{zettaset_name}'.") + + assert zettaset_name in zettaset_specs + zettaset_spec = zettaset_specs[zettaset_name] + + # Determine padding: sample-specific overrides zettaset-specific + padding = zettaset_padding_spec.get(data_id, zettaset_spec.get("padding", zettaset_padding)) + no_mask = zettaset_spec.get("no_mask", not zettaset_mask) + + print(f"Sample [{data_id}]") + return {data_id: load_sample( + zettaset.samples[sample_name], + padding, + no_mask, + **kwargs, + )} + + +def load_sample( + sample: Sample, + padding: tuple[int, int, int] = (0, 0, 0), + no_mask: bool = False, + zettaset_lookup: dict[str, str] | None = None, + zettaset_resolution: tuple[int, int, int] | None = None, + requires_binarize: list[str] = [], + **kwargs +) -> dict[str, np.ndarray]: + """Load image and labels from a Sample.""" + + def convert_array(arr: ArrayLike) -> np.ndarray: + return np.array(arr).transpose(3, 2, 1, 0)[0, ...] + + dset: dict[str, np.ndarray] = {} + + # Bbox with padding + resolution = zettaset_resolution or sample.base_resolution + bbox = sample.bbox * (sample.base_resolution / np.array(resolution)) + xyz_padding = ( + tuple(reversed(padding)) + if padding != (0, 0, 0) + else (0, 0, 0) + ) + image_bbox = Bbox(bbox.minpt - xyz_padding, bbox.maxpt + xyz_padding) + + # Image + vol = CloudVolume( # pylint: disable=unsubscriptable-object + sample.src_image_path, + mip=resolution, + fill_missing=True, + bounded=False, + )[image_bbox.to_slices()] + dset["input"] = convert_array(vol) / 255.0 + print(f"\tinput: {dset['input'].shape}") + + # Assumes that zettaset's annotation names follow DeepEM's convention. + zettaset_lookup = zettaset_lookup or {x: x for x in sample.annotation_names} + + # Process annotations + for name, key in zettaset_lookup.items(): + + # Annotation + vol = sample.read(key)[key] + dset[name] = convert_array(vol) + anno_log = f"\t{name}: {dset[name].shape}" + + # Binarize + if name in requires_binarize: + dset[name] = (dset[name] > 0).astype("uint8") + + # Mask + mask_key = f"{name}_mask" + if (not no_mask) and (key in sample.masks): + mask_vol = sample.read_mask(key)[key] + dset[mask_key] = convert_array(mask_vol).astype("uint8") + else: + dset[mask_key] = np.ones_like(dset[name], dtype="uint8") + msk_log = f"\t{mask_key}: {dset[mask_key].shape}" + + # Padding + if padding != (0, 0, 0): + pad_width = tuple((p, p) for p in padding) + dset[name] = np.pad(dset[name], pad_width, "constant") + anno_log += f" -> {dset[name].shape}" + dset[mask_key] = np.pad(dset[mask_key], pad_width, "constant") + msk_log += f" -> {dset[mask_key].shape}" + + print(anno_log) + print(msk_log) + + return dset + + +if __name__ == "__main__": + pass diff --git a/deepem/data/dataset/zettaset.py b/deepem/data/dataset/zettaset.py index 2c7aca1..b4a7128 100644 --- a/deepem/data/dataset/zettaset.py +++ b/deepem/data/dataset/zettaset.py @@ -11,6 +11,7 @@ def load_data( zettaset_paths: list[str], data_ids: list[str] | None = None, + zettaset_resolution: tuple[float, float, float] | None = None, **kwargs ) -> dict[str, dict[str, np.ndarray]]: """ Load data from a zettaset.""" @@ -21,9 +22,8 @@ def load_data( zettasets = [] for zettaset_path in zettaset_paths: assert zettaset_path.startswith("gs://") - motivation = "Load a zettaset from DeepEM" print(f"Zettaset [{zettaset_path}]") - zettasets.append(Zettaset(zettaset_path, motivation)) + zettasets.append(Zettaset(zettaset_path, "", zettaset_resolution)) # Load data from a zettaset. data = {} @@ -31,7 +31,11 @@ def load_data( for zettaset in zettasets: if data_id in zettaset.sample_names: print(f"Sample [{data_id}]") - data[data_id] = load_sample(zettaset.samples[data_id], **kwargs) + data[data_id] = load_sample( + zettaset.samples[data_id], + zettaset_resolution, + **kwargs + ) break if data_id not in data: raise KeyError(f"Invalid data id:{data_id}") @@ -41,6 +45,7 @@ def load_data( def load_sample( sample: Sample, + zettaset_resolution: tuple[float, float, float] | None = None, zettaset_lookup: dict[str, str] | None = None, zettaset_padding: tuple[int, int, int] = (0, 0, 0), zettaset_mask: bool = True, @@ -54,18 +59,20 @@ def convert_array(arr: ArrayLike) -> np.ndarray: dset: dict[str, np.ndarray] = {} - # Image - if zettaset_padding == (0, 0, 0): - image_bbox = sample.bbox - else: - xyz_padding = tuple(reversed(zettaset_padding)) - image_bbox = Bbox( - sample.bbox.minpt - xyz_padding, sample.bbox.maxpt + xyz_padding - ) + # Bbox with padding + resolution = zettaset_resolution or sample.base_resolution + bbox = sample.bbox * (sample.base_resolution / np.array(resolution)) + xyz_padding = ( + tuple(reversed(zettaset_padding)) + if zettaset_padding != (0, 0, 0) + else (0, 0, 0) + ) + image_bbox = Bbox(bbox.minpt - xyz_padding, bbox.maxpt + xyz_padding) + # Image vol = CloudVolume( # pylint: disable=unsubscriptable-object sample.src_image_path, - mip=sample.base_resolution, + mip=resolution, fill_missing=True, bounded=False, )[image_bbox.to_slices()] diff --git a/deepem/train/option.py b/deepem/train/option.py index c30af48..3069a02 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -4,7 +4,7 @@ import numpy as np import samwise -from deepem.utils.py_utils import vec3 +from deepem.utils.py_utils import vec3, vec3f class Options(object): @@ -26,8 +26,11 @@ def initialize(self): # zettasets self.parser.add_argument('--zettaset_path', type=str, default=[], nargs='+') + self.parser.add_argument('--zettaset_specs', type=json.loads, default={}) self.parser.add_argument('--zettaset_lookup', type=json.loads, default=None) self.parser.add_argument('--zettaset_padding', type=vec3, default=(0, 0, 0)) + self.parser.add_argument('--zettaset_padding_spec', type=json.loads, default={}) + self.parser.add_argument('--zettaset_resolution', type=vec3f, default=None) self.parser.add_argument('--zettaset_no_mask', action='store_true') # file synchronization for spot/preemptible training @@ -309,6 +312,8 @@ def parse(self): glia_mask=opt.glia_mask, zettaset_lookup=opt.zettaset_lookup, zettaset_padding=opt.zettaset_padding, + zettaset_padding_spec=opt.zettaset_padding_spec, + zettaset_resolution=opt.zettaset_resolution, zettaset_mask=not opt.zettaset_no_mask, requires_binarize=requires_binarize, ) From dd8d81cab6379afca53f3d64552e553f4ad03d70 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Wed, 27 Dec 2023 17:31:27 +0900 Subject: [PATCH 03/47] feat(Zettaset): remove the requirement for specifying dataset module and auto-decide whether zettaset or multi-zettaset should be used. --- deepem/train/option.py | 1 - deepem/train/utils.py | 7 +++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/deepem/train/option.py b/deepem/train/option.py index 3069a02..8ea3ff2 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -18,7 +18,6 @@ def __init__(self): def initialize(self): self.parser.add_argument('--exp_name', required=True) self.parser.add_argument('--model', required=True) - self.parser.add_argument('--data', required=True) self.parser.add_argument('--sampler', required=True) self.parser.add_argument('--augment', default=None) self.parser.add_argument('--modifier', default=None) diff --git a/deepem/train/utils.py b/deepem/train/utils.py index be11580..d13f020 100644 --- a/deepem/train/utils.py +++ b/deepem/train/utils.py @@ -118,10 +118,13 @@ def save_chkpt(model, fpath, chkpt_num, optimizer): def load_data(opt): - mod = imp.load_source('data', opt.data) data_ids = list(set().union(opt.train_ids, opt.val_ids)) + if opt.zettaset_specs: + from deepem.data.dataset import multi_zettaset as mod + else: + from deepem.data.dataset import zettaset as mod data = mod.load_data( - opt.zettaset_path, + opt.zettaset_specs if opt.zettaset_specs else opt.zettaset_path, data_ids=data_ids, **opt.data_params ) From 87bf2b7bc9fde6d556ceecf9210c00882df6f0ae Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Wed, 27 Dec 2023 17:26:21 +0900 Subject: [PATCH 04/47] chore(dataset): delete unnecesarry dataset `fib25` --- deepem/data/dataset/fibsem/fib25/fib25_v0.py | 51 -------------------- 1 file changed, 51 deletions(-) delete mode 100644 deepem/data/dataset/fibsem/fib25/fib25_v0.py diff --git a/deepem/data/dataset/fibsem/fib25/fib25_v0.py b/deepem/data/dataset/fibsem/fib25/fib25_v0.py deleted file mode 100644 index 8b7ff8e..0000000 --- a/deepem/data/dataset/fibsem/fib25/fib25_v0.py +++ /dev/null @@ -1,51 +0,0 @@ -import numpy as np -import os - -import dataprovider3.emio as emio - - -fib25_dir = 'FIB-25' -data_keys = ['validation_sample'] -merger_ids = [2148] - - -def load_data(data_dir, data_ids=None, **kwargs): - if data_ids is None: - return {} - - base_dir = os.path.expanduser(base_dir) - data_dir = os.path.join(base_dir, fib25_dir) - - data = {} - for data_id in data_ids: - if data_id in data_keys: - dpath = os.path.join(data_dir, data_id) - assert os.path.exists(dpath) - data[data_id] = load_dataset(dpath, **kwargs) - return data - - -def load_dataset(dpath, **kwargs): - dset = {} - - # Image - fpath = os.path.join(dpath, "img.h5") - print(fpath) - dset['img'] = emio.imread(fpath).astype(np.float32) - dset['img'] /= 255.0 - - # Segmentation - fpath = os.path.join(dpath, "seg.h5") - print(fpath) - dset['seg'] = emio.imread(fpath).astype(np.uint8) - - # Additoinal info - dset['loc'] = True - - # Mask out mergers - dset['msk'] = np.ones(dset['seg'].shape, dtype=np.uint8) - if len(merger_ids) > 0: - idx = np.isin(dset['seg'], merger_ids) - dset['msk'][idx] = 0 - - return dset \ No newline at end of file From c1403031287702f8513e7fd8a4683457f9b74361 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Wed, 27 Dec 2023 23:31:34 +0900 Subject: [PATCH 05/47] refactor(augment): rename `ariadne_worm` to `isotropic` and add MIP1 augmentation --- .../{ariadne_worm => isotropic}/aug.py | 0 .../{ariadne_worm => isotropic}/aug_aniso.py | 0 deepem/data/augment/isotropic/aug_mip1.py | 85 +++++++++++++++++++ 3 files changed, 85 insertions(+) rename deepem/data/augment/{ariadne_worm => isotropic}/aug.py (100%) rename deepem/data/augment/{ariadne_worm => isotropic}/aug_aniso.py (100%) create mode 100644 deepem/data/augment/isotropic/aug_mip1.py diff --git a/deepem/data/augment/ariadne_worm/aug.py b/deepem/data/augment/isotropic/aug.py similarity index 100% rename from deepem/data/augment/ariadne_worm/aug.py rename to deepem/data/augment/isotropic/aug.py diff --git a/deepem/data/augment/ariadne_worm/aug_aniso.py b/deepem/data/augment/isotropic/aug_aniso.py similarity index 100% rename from deepem/data/augment/ariadne_worm/aug_aniso.py rename to deepem/data/augment/isotropic/aug_aniso.py diff --git a/deepem/data/augment/isotropic/aug_mip1.py b/deepem/data/augment/isotropic/aug_mip1.py new file mode 100644 index 0000000..2aded66 --- /dev/null +++ b/deepem/data/augment/isotropic/aug_mip1.py @@ -0,0 +1,85 @@ +"""Based on flyem/aug_mip1.py with some modifications. + +Mostly removing things that we might not need for faster +feedback cycles. +""" +from augmentor import * + + +def get_augmentation( + is_train, + recompute=False, + box=None, + blur=7, + random=False, + border=False, + **kwargs, +): + augs = list() + + # Flip & rotate + augs.append(FlipRotateIsotropic()) + + # Brightness & contrast perturbation + augs.append( + Grayscale3D( + contrast_factor=0.5, + brightness_factor=0.5, + skip=0.3, + ) + ) + # augs.append( + # MixedGrayscale2D( + # contrast_factor=0.5, + # brightness_factor=0.5, + # prob=1, + # skip=0.3, + # ) + # ) + + # Box + if is_train: + if box == "fill": + augs.append( + Blend([ + FillBox( + random=True, + dims=(3, 13), + aniso=1, + density=0.3, + margin=(3, 3, 3), + individual=True, + skip=0.1, + ), + FillBox( + random=True, + dims=(3, 13), + aniso=1, + density=0.3, + margin=(3, 3, 3), + individual=False, + skip=0.1, + ), + ]) + ) + + # Out-of-focus section + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + # Flip & rotate + augs.append(FlipRotateIsotropic()) + + # Create border + if border: + augs.append(Border()) + + return Compose(augs) From ec04cafbc2c0bf0000317e985b747c22350ca234 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 28 Dec 2023 00:51:14 +0900 Subject: [PATCH 06/47] chore(sampler): add trailing commas to last parameters in method signatures --- deepem/data/sampler/zettaset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepem/data/sampler/zettaset.py b/deepem/data/sampler/zettaset.py index d09501e..99321d0 100644 --- a/deepem/data/sampler/zettaset.py +++ b/deepem/data/sampler/zettaset.py @@ -29,7 +29,7 @@ def __init__( spec: dict[str, tuple[int, int, int]], is_train: bool, aug: Augment | None = None, - prob: dict[str, float] | None = None + prob: dict[str, float] | None = None, ): self.is_train = is_train self.build(data, spec, aug, prob) @@ -55,8 +55,8 @@ def build( self, data: dict[str, dict[str, np.ndarray]], spec: dict[str, tuple[int, int, int]], - aug: Augment | None, - prob: dict[str, float] | None + aug: Augment | None = None, + prob: dict[str, float] | None = None, ) -> None: dp = DataProvider(spec) keys = data.keys() @@ -74,7 +74,7 @@ def build_dataset( self, tag: str, data: dict[str, np.ndarray], - spec: dict[str, tuple[int, int, int]] + spec: dict[str, tuple[int, int, int]], ) -> Dataset: """Create a Dataset.""" dset = Dataset(tag=tag) From 9e6648c5594d2dd75efa033c8c936c68dd1fb952 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Wed, 27 Dec 2023 23:36:07 +0900 Subject: [PATCH 07/47] feat(sampler): introduce `DataSuperset` support and enhance code quality --- deepem/data/sampler/zettaset.py | 61 +++++++++++++++++++++++++-------- deepem/train/data.py | 3 +- 2 files changed, 48 insertions(+), 16 deletions(-) diff --git a/deepem/data/sampler/zettaset.py b/deepem/data/sampler/zettaset.py index 99321d0..e9ea5d0 100644 --- a/deepem/data/sampler/zettaset.py +++ b/deepem/data/sampler/zettaset.py @@ -3,7 +3,7 @@ import numpy as np from augmentor import Augment -from dataprovider3 import DataProvider, Dataset +from dataprovider3 import DataProvider, Dataset, DataSuperset def get_spec( @@ -30,9 +30,10 @@ def __init__( is_train: bool, aug: Augment | None = None, prob: dict[str, float] | None = None, + zettaset_specs: dict[str, dict] | None = None, ): self.is_train = is_train - self.build(data, spec, aug, prob) + self.build(data, spec, aug, prob, zettaset_specs) def __call__(self) -> dict[str, np.ndarray]: sample = self.dataprovider() @@ -57,18 +58,47 @@ def build( spec: dict[str, tuple[int, int, int]], aug: Augment | None = None, prob: dict[str, float] | None = None, + zettaset_specs: dict[str, dict] | None = None, ) -> None: - dp = DataProvider(spec) - keys = data.keys() - for key in keys: - dp.add_dataset(self.build_dataset(key, data[key], spec)) - dp.set_augment(aug) - dp.set_imgs(["input"]) - dp.set_segs(["affinity", "long_range", "embedding"]) - prob = [prob[k] for k in keys] if prob is not None else prob - dp.set_sampling_weights(p=prob) - self.dataprovider = dp - print(dp) + """ + Builds the data provider with datasets, augmentation, and sampling weights. + """ + self.dataprovider = DataProvider(spec) + + # Add datasets to the data provider + for key, value in data.items(): + build_method = ( + self.build_datasuperset + if zettaset_specs and key in zettaset_specs + else self.build_dataset + ) + self.dataprovider.add_dataset(build_method(key, value, spec)) + + # Set augmentation, image types, and segmentation types + self.dataprovider.set_augment(aug) + self.dataprovider.set_imgs(["input"]) + self.dataprovider.set_segs(["affinity", "long_range", "embedding"]) + + # Initialize sampling weights (even if prob is None) + sampling_weights = [prob[k] for k in data.keys()] if prob else None + self.dataprovider.set_sampling_weights(p=sampling_weights) + + print(self.dataprovider) + + def build_datasuperset( + self, + tag: str, + data: dict[str, dict[str, np.ndarray]], + spec: dict[str, tuple[int, int, int]], + ) -> DataSuperset: + """Create a DataSuperset from the given data.""" + dset = DataSuperset(tag=tag) + + # Add datasets to the DataSuperset + for key, value in data.items(): + dset.add_dataset(self.build_dataset(key, value, spec)) + + return dset def build_dataset( self, @@ -76,10 +106,11 @@ def build_dataset( data: dict[str, np.ndarray], spec: dict[str, tuple[int, int, int]], ) -> Dataset: - """Create a Dataset.""" + """Create a Dataset from the given data and specification.""" dset = Dataset(tag=tag) - for key in spec.keys(): + # Iterate over the spec dictionary to add data and masks to the dataset + for key, _ in spec.items(): if key.endswith("_mask"): dset.add_mask(key=key, data=data[key], loc=True) else: diff --git a/deepem/train/data.py b/deepem/train/data.py index cfdb0f1..e320411 100644 --- a/deepem/train/data.py +++ b/deepem/train/data.py @@ -53,7 +53,8 @@ def build(self, opt, data, is_train, prob): # Data sampler mod = imp.load_source('sampler', opt.sampler) spec = mod.get_spec(opt.in_spec, opt.out_spec) - sampler = mod.Sampler(data, spec, is_train, aug, prob=prob) + zspecs = opt.zettaset_specs + sampler = mod.Sampler(data, spec, is_train, aug, prob, zspecs) # Sample modifier if opt.modifier: From 54f0d5f0fe973b4b66729d0b0d031eab90a8f98e Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 29 Dec 2023 01:18:30 +0900 Subject: [PATCH 08/47] refactor(loss): use `reduction` instead of the deprecated `size_average` and `reduce` --- deepem/loss/loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepem/loss/loss.py b/deepem/loss/loss.py index 1ea8c83..74064b1 100644 --- a/deepem/loss/loss.py +++ b/deepem/loss/loss.py @@ -43,7 +43,7 @@ def forward(self, input, target, mask): m_ext = torch.le(activ, m0) * torch.eq(target, 0) mask *= 1 - (m_int + m_ext).type(mask.dtype) - loss = self.bce(input, target, weight=mask, size_average=False) + loss = self.bce(input, target, weight=mask, reduction='sum') if self.size_average: loss = loss / nmsk.item() @@ -87,7 +87,7 @@ def forward(self, input, target, mask): m_ext = torch.le(activ, m0) * torch.eq(target, 0) mask *= 1 - (m_int + m_ext).type(mask.dtype) - loss = self.mse(activ, target, reduce=False) + loss = self.mse(activ, target, reduction='none') loss = (loss * mask).sum() if self.size_average: From 1a60c74bb57087c833891ed0d4c7df09b8e1c9f1 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Wed, 3 Jan 2024 17:11:12 +0900 Subject: [PATCH 09/47] fix(option): The parameter `data` was previously hardcoded as a required input in the training loop. This update allows users to specify `data`, even if it won't be utilized. --- deepem/train/option.py | 1 + 1 file changed, 1 insertion(+) diff --git a/deepem/train/option.py b/deepem/train/option.py index 8ea3ff2..12294b1 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -19,6 +19,7 @@ def initialize(self): self.parser.add_argument('--exp_name', required=True) self.parser.add_argument('--model', required=True) self.parser.add_argument('--sampler', required=True) + self.parser.add_argument('--data', default=None) self.parser.add_argument('--augment', default=None) self.parser.add_argument('--modifier', default=None) self.parser.add_argument('--modifier_kwargs', type=json.loads, default={}) From 4cac78177a1e9cbb45903f35fdc9da1b6e0250b2 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 18 Jan 2024 12:17:29 +0900 Subject: [PATCH 10/47] feat(option): binarize blood vessel label --- deepem/train/option.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepem/train/option.py b/deepem/train/option.py index 12294b1..3f108e3 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -289,6 +289,9 @@ def parse(self): "soma", ] + if opt.blv_num_channels == 1: + requires_binarize.append("blood_vessel") + # Test training if opt.test: opt.eval_intv = 100 From 4c3ae03f75c8308634df3032f67a717c500b212a Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 18 Jan 2024 20:43:39 +0900 Subject: [PATCH 11/47] feat(Zettaset): allow sharing of a user-specified mask volume across different annotations --- deepem/data/dataset/multi_zettaset.py | 15 ++++++++++++++- deepem/train/option.py | 2 ++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/deepem/data/dataset/multi_zettaset.py b/deepem/data/dataset/multi_zettaset.py index 4a70450..7399b28 100644 --- a/deepem/data/dataset/multi_zettaset.py +++ b/deepem/data/dataset/multi_zettaset.py @@ -132,6 +132,7 @@ def load_sample( zettaset_lookup: dict[str, str] | None = None, zettaset_resolution: tuple[int, int, int] | None = None, requires_binarize: list[str] = [], + zettaset_share_mask: str | None = None, **kwargs ) -> dict[str, np.ndarray]: """Load image and labels from a Sample.""" @@ -164,6 +165,16 @@ def convert_array(arr: ArrayLike) -> np.ndarray: # Assumes that zettaset's annotation names follow DeepEM's convention. zettaset_lookup = zettaset_lookup or {x: x for x in sample.annotation_names} + # Shared mask + shared_mask = None + if (not no_mask) and zettaset_share_mask: + key = zettaset_share_mask + mask_key = f"{zettaset_share_mask}_mask" + if key not in sample.masks: + raise KeyError(f"Mask '{mask_key}' not found.") + mask_vol = sample.read_mask(key)[key] + shared_mask = convert_array(mask_vol).astype("uint8") + # Process annotations for name, key in zettaset_lookup.items(): @@ -178,7 +189,9 @@ def convert_array(arr: ArrayLike) -> np.ndarray: # Mask mask_key = f"{name}_mask" - if (not no_mask) and (key in sample.masks): + if shared_mask is not None: + dset[mask_key] = shared_mask + elif (not no_mask) and (key in sample.masks): mask_vol = sample.read_mask(key)[key] dset[mask_key] = convert_array(mask_vol).astype("uint8") else: diff --git a/deepem/train/option.py b/deepem/train/option.py index 3f108e3..c97a998 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -32,6 +32,7 @@ def initialize(self): self.parser.add_argument('--zettaset_padding_spec', type=json.loads, default={}) self.parser.add_argument('--zettaset_resolution', type=vec3f, default=None) self.parser.add_argument('--zettaset_no_mask', action='store_true') + self.parser.add_argument('--zettaset_share_mask', type=str, default=None) # file synchronization for spot/preemptible training self.parser.add_argument('--samwise_map', nargs='*', default=None) @@ -319,6 +320,7 @@ def parse(self): zettaset_resolution=opt.zettaset_resolution, zettaset_mask=not opt.zettaset_no_mask, requires_binarize=requires_binarize, + zettaset_share_mask=opt.zettaset_share_mask, ) # ONNX From bc1bd35af30d6b0c607ce15180868b019e0aa430 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 23 Jan 2024 00:42:08 +0900 Subject: [PATCH 12/47] feat(option): enhance error handling by including list details in ValueError message --- deepem/train/option.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/deepem/train/option.py b/deepem/train/option.py index c97a998..cc11d47 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -190,7 +190,12 @@ def parse(self): if (not opt.train_ids) or (not opt.val_ids): raise ValueError("Train/validation IDs unspecified") if opt.train_prob: - assert len(opt.train_ids) == len(opt.train_prob) + if len(opt.train_ids) != len(opt.train_prob): + error_message = ( + "The lengths of 'train_ids' and 'train_prob' must be the same. " + f"train_ids: {opt.train_ids}, train_prob: {opt.train_prob}" + ) + raise ValueError(error_message) if opt.val_prob: assert len(opt.val_ids) == len(opt.val_prob) From 6cbe10bf711ddb8180b5c63eac5c3f2854d3b294 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 23 Jan 2024 22:13:55 +0900 Subject: [PATCH 13/47] feat(augment): add `recompute` and `border` --- deepem/data/augment/flyem/aug_mip1.py | 21 +++++++++++++++++-- .../data/augment/pinky_basil/aug_mip1_v3.py | 21 +++++++++++++++++-- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/deepem/data/augment/flyem/aug_mip1.py b/deepem/data/augment/flyem/aug_mip1.py index 9f3eee1..cde9d28 100644 --- a/deepem/data/augment/flyem/aug_mip1.py +++ b/deepem/data/augment/flyem/aug_mip1.py @@ -1,8 +1,17 @@ from augmentor import * -def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, - random=False, **kwargs): +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=False, + **kwargs +): augs = list() # Brightness & contrast purterbation @@ -79,7 +88,15 @@ def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + # Recompute connected components + if recompute: + augs.append(Label()) + # Flip & rotate augs.append(FlipRotate()) + # Create border + if border: + augs.append(Border()) + return Compose(augs) diff --git a/deepem/data/augment/pinky_basil/aug_mip1_v3.py b/deepem/data/augment/pinky_basil/aug_mip1_v3.py index f6105d0..21b6afe 100644 --- a/deepem/data/augment/pinky_basil/aug_mip1_v3.py +++ b/deepem/data/augment/pinky_basil/aug_mip1_v3.py @@ -1,8 +1,17 @@ from augmentor import * -def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, - random=False, **kwargs): +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=False, + **kwargs +): augs = list() # Box @@ -72,7 +81,15 @@ def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + # Recompute connected components + if recompute: + augs.append(Label()) + # Flip & rotate augs.append(FlipRotate()) + # Create border + if border: + augs.append(Border()) + return Compose(augs) From 64db722851f121a77f03190493a0d48333772e77 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 23 Jan 2024 23:33:32 +0900 Subject: [PATCH 14/47] feat(augment): add augmentation for 16nm data --- deepem/data/augment/pinky_basil/aug_mip2.py | 95 +++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 deepem/data/augment/pinky_basil/aug_mip2.py diff --git a/deepem/data/augment/pinky_basil/aug_mip2.py b/deepem/data/augment/pinky_basil/aug_mip2.py new file mode 100644 index 0000000..7916fd3 --- /dev/null +++ b/deepem/data/augment/pinky_basil/aug_mip2.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0,13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0,13), interp=True, margin=1)]) + to_blend.append(Blend([trans,slip], props=[0.7,0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2,8), value=0, random=random), + MisalignPlusMissing((2,8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2,8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border()) + + return Compose(augs) From c8ace88dbff6e4d9c7ecd9e3a6928de39f0d3c4d Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 25 Jan 2024 11:36:52 +0900 Subject: [PATCH 15/47] feat(loss): implement option to toggle background masking for means-based loss --- deepem/loss/mean.py | 12 ++++++++++-- deepem/train/option.py | 2 ++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/deepem/loss/mean.py b/deepem/loss/mean.py index 1856d9c..45181ab 100644 --- a/deepem/loss/mean.py +++ b/deepem/loss/mean.py @@ -77,6 +77,7 @@ def __init__( delta_v: float = 0.0, delta_d: float = 1.5, recompute_ext: bool = False, + mask_background: bool = True, **kwargs, ): super().__init__() @@ -86,6 +87,7 @@ def __init__( self.delta_v = delta_v # Variance (intra-cluster pull force) hinge self.delta_d = delta_d # Distance (inter-cluster push force) hinge self.recompute_ext = recompute_ext + self.mask_background = mask_background def forward( self, @@ -114,9 +116,15 @@ def forward( trgt = trgt.to(torch.int) trgt *= (mask > 0).to(torch.int) - # Unique nonzero IDs + # Extract unique IDs ids = np.unique(trgt.cpu().numpy()) - ids = ids[ids != 0].tolist() + + # Remove 0s from the IDs if `mask_background` is True + if self.mask_background: + ids = ids[ids != 0] + + # Convert numpy array to a Python list + ids = ids.tolist() # Recompute external matrix mext = self.compute_ext_matrix(ids, groups, self.recompute_ext, device) diff --git a/deepem/train/option.py b/deepem/train/option.py index cc11d47..81ace1e 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -82,6 +82,7 @@ def initialize(self): self.parser.add_argument('--delta_v', type=float, default=0.0) self.parser.add_argument('--delta_d', type=float, default=1.5) self.parser.add_argument('--recompute_ext', action='store_true') + self.parser.add_argument('--no_mask_background', action='store_true') # Optimizer self.parser.add_argument('--optim', default='Adam') @@ -214,6 +215,7 @@ def parse(self): opt.metric_params['delta_v'] = opt.delta_v opt.metric_params['delta_d'] = opt.delta_d opt.metric_params['recompute_ext'] = opt.recompute_ext + opt.metric_params['mask_background'] = not opt.no_mask_background # Optimizer if opt.optim == 'Adam': From b6178bdbbc96cd6cb1a42d13341aa9e5baff8494 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 1 Mar 2024 15:47:24 +0900 Subject: [PATCH 16/47] feat(augment): add 16nm data augmentation for fly --- deepem/data/augment/flyem/aug_mip2.py | 102 ++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 deepem/data/augment/flyem/aug_mip2.py diff --git a/deepem/data/augment/flyem/aug_mip2.py b/deepem/data/augment/flyem/aug_mip2.py new file mode 100644 index 0000000..1223e4c --- /dev/null +++ b/deepem/data/augment/flyem/aug_mip2.py @@ -0,0 +1,102 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=False, + **kwargs +): + augs = list() + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Mutually exclusive augmentations + mutex = list() + + # (1) Misalingment + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + mutex.append(Blend([trans, slip], props=[0.7, 0.3])) + + # (2) Misalignment + missing section + if is_train: + mutex.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + mutex.append(MisalignPlusMissing((2, 8), value=0, random=False)) + + # (3) Missing section + if missing > 0: + if is_train: + mutex.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + mutex.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + + # (4) Lost section + if lost: + if is_train: + mutex.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + + # Mutually exclusive augmentations + augs.append(Blend(mutex)) + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Out-of-focus section + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border()) + + return Compose(augs) From d0a6dfe2d8eea6dc401496e707735ed97c29b2d6 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 1 Mar 2024 15:59:27 +0900 Subject: [PATCH 17/47] feat(zettaset): enable zettaset-wise resolution --- deepem/data/dataset/multi_zettaset.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/deepem/data/dataset/multi_zettaset.py b/deepem/data/dataset/multi_zettaset.py index 7399b28..cc97795 100644 --- a/deepem/data/dataset/multi_zettaset.py +++ b/deepem/data/dataset/multi_zettaset.py @@ -62,8 +62,10 @@ def _initialize_zettasets( if "path" not in spec or not spec["path"].startswith("gs://"): raise ValueError(f"Invalid zettaset specification for '{name}': missing or invalid 'path'.") zettaset_path = spec["path"] - print(f"Zettaset {name} [{zettaset_path}]") - zettasets[name] = Zettaset(zettaset_path, "", zettaset_resolution) + print(f"Zettaset `{name}` from [{zettaset_path}]") + resolution = tuple(spec.get("resolution", zettaset_resolution)) + print(f"{resolution=}") + zettasets[name] = Zettaset(zettaset_path, "", resolution) return zettasets @@ -94,6 +96,7 @@ def _process_sample( zettaset_padding: tuple[int, int, int] = (0, 0, 0), zettaset_padding_spec: dict[str, tuple[int, int, int]] = {}, zettaset_mask: bool = True, + zettaset_resolution: tuple[int, int, int] | None = None, **kwargs, ) -> dict[str, dict[str, np.ndarray]]: if not is_valid_format(data_id): @@ -116,11 +119,15 @@ def _process_sample( padding = zettaset_padding_spec.get(data_id, zettaset_spec.get("padding", zettaset_padding)) no_mask = zettaset_spec.get("no_mask", not zettaset_mask) + # Determine resolution: zettaset-specific overrides zettaset_resolution + resolution = tuple(zettaset_spec.get("resolution", zettaset_resolution)) + print(f"Sample [{data_id}]") return {data_id: load_sample( zettaset.samples[sample_name], padding, no_mask, + resolution, **kwargs, )} @@ -129,8 +136,8 @@ def load_sample( sample: Sample, padding: tuple[int, int, int] = (0, 0, 0), no_mask: bool = False, + resolution: tuple[int, int, int] | None = None, zettaset_lookup: dict[str, str] | None = None, - zettaset_resolution: tuple[int, int, int] | None = None, requires_binarize: list[str] = [], zettaset_share_mask: str | None = None, **kwargs @@ -143,7 +150,7 @@ def convert_array(arr: ArrayLike) -> np.ndarray: dset: dict[str, np.ndarray] = {} # Bbox with padding - resolution = zettaset_resolution or sample.base_resolution + resolution = resolution or sample.base_resolution bbox = sample.bbox * (sample.base_resolution / np.array(resolution)) xyz_padding = ( tuple(reversed(padding)) From 45fdd2edb23cb2012ddb7669ce3d16a314f1504c Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 1 Mar 2024 16:01:13 +0900 Subject: [PATCH 18/47] chore(Dockerfile): default to PyTorch 1.11.0 --- docker/zettasets/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/zettasets/Dockerfile b/docker/zettasets/Dockerfile index 345e12a..ee4cf48 100644 --- a/docker/zettasets/Dockerfile +++ b/docker/zettasets/Dockerfile @@ -8,9 +8,9 @@ RUN apt-get --allow-releaseinfo-change update && \ git clone https://${GIT_ACCESS_TOKEN}@github.com/ZettaAI/zettasets.git && \ rm -rf /var/lib/apt/lists/* -# FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime +FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime # FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-runtime -FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime +# FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime # FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime ARG DEBIAN_FRONTEND=noninteractive From 1fae11ff9bdf6eae5fbbac37ffe23b1fead27be6 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 1 Mar 2024 22:10:27 +0900 Subject: [PATCH 19/47] fix(Dockerfile): install `cffi` and `brotli` to address dependency issues with `cloud-files` --- docker/zettasets/Dockerfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker/zettasets/Dockerfile b/docker/zettasets/Dockerfile index ee4cf48..01ff233 100644 --- a/docker/zettasets/Dockerfile +++ b/docker/zettasets/Dockerfile @@ -38,6 +38,9 @@ RUN apt-get update \ # pypi packages && pip install --no-cache-dir --upgrade \ numpy cloud-volume task-queue tensorboardX imgaug wandb \ + # address cloud-files import issue + && pip install --no-cache-dir --upgrade \ + cffi brotli \ # github packages && pip install --no-cache-dir \ git+https://github.com/seung-lab/DataTools \ From 324873be04d7ad2aedc8964a1907a6a4e8b2d240 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Sat, 2 Mar 2024 00:26:33 +0900 Subject: [PATCH 20/47] chore(zettaset): add conversion of padding to tuple --- deepem/data/dataset/multi_zettaset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepem/data/dataset/multi_zettaset.py b/deepem/data/dataset/multi_zettaset.py index cc97795..328a468 100644 --- a/deepem/data/dataset/multi_zettaset.py +++ b/deepem/data/dataset/multi_zettaset.py @@ -116,7 +116,7 @@ def _process_sample( zettaset_spec = zettaset_specs[zettaset_name] # Determine padding: sample-specific overrides zettaset-specific - padding = zettaset_padding_spec.get(data_id, zettaset_spec.get("padding", zettaset_padding)) + padding = tuple(zettaset_padding_spec.get(data_id, zettaset_spec.get("padding", zettaset_padding))) no_mask = zettaset_spec.get("no_mask", not zettaset_mask) # Determine resolution: zettaset-specific overrides zettaset_resolution From f48ec5d308e00d97f13995bb2a956ac5e3aa7d1c Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 28 Mar 2024 14:48:21 +0900 Subject: [PATCH 21/47] feat: data augmentation for dacey human retina --- deepem/data/augment/retina/aug_20nm.py | 102 +++++++++++++++++++++++++ deepem/data/augment/retina/aug_mip2.py | 95 +++++++++++++++++++++++ 2 files changed, 197 insertions(+) create mode 100644 deepem/data/augment/retina/aug_20nm.py create mode 100644 deepem/data/augment/retina/aug_mip2.py diff --git a/deepem/data/augment/retina/aug_20nm.py b/deepem/data/augment/retina/aug_20nm.py new file mode 100644 index 0000000..db70eed --- /dev/null +++ b/deepem/data/augment/retina/aug_20nm.py @@ -0,0 +1,102 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Mutually-exclusive data augmentations + to_blend = list() + + # Step + step = Compose([Misalign((0, 5), margin=1), + Misalign((0,15), margin=1), + Misalign((0,25), margin=1)]) + + # Slip + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0,15), interp=True, margin=1), + SlipMisalign((0,25), interp=True, margin=1)]) + + # Step + slip + to_blend.append(Blend([step, slip], props=[0.7, 0.3])) + + # Missing sections + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + + # Lost sections + if lost: + if is_train: + to_blend.append( + Blend( + [ + LostSection(1), + LostSection(2), + LostSection(3), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False), + ], + props=[0.4, 0.3, 0.2, 0.05, 0.05], + ) + ) + augs.append(Blend(to_blend)) + + # Out-of-focus sections + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + return Compose(augs) diff --git a/deepem/data/augment/retina/aug_mip2.py b/deepem/data/augment/retina/aug_mip2.py new file mode 100644 index 0000000..0f3da1c --- /dev/null +++ b/deepem/data/augment/retina/aug_mip2.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0,13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0,13), interp=True, margin=1)]) + to_blend.append(Blend([trans,slip], props=[0.7,0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2,8), value=0, random=random), + MisalignPlusMissing((2,8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2,8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + return Compose(augs) From 63770e29e54bd772c4affd4d4f167833b316f163 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 12 Apr 2024 21:25:52 +0900 Subject: [PATCH 22/47] feat: data augmentation for ariadne mouse cerebellum --- deepem/data/augment/isotropic/aug.py | 17 ++++------------- deepem/data/augment/isotropic/aug_aniso.py | 17 ++++------------- deepem/data/augment/isotropic/aug_mip1.py | 21 ++++++--------------- 3 files changed, 14 insertions(+), 41 deletions(-) diff --git a/deepem/data/augment/isotropic/aug.py b/deepem/data/augment/isotropic/aug.py index 6f1504e..b7bb4fa 100644 --- a/deepem/data/augment/isotropic/aug.py +++ b/deepem/data/augment/isotropic/aug.py @@ -11,7 +11,6 @@ def get_augmentation( recompute=False, box=None, blur=7, - random=False, border=[], **kwargs, ): @@ -28,14 +27,6 @@ def get_augmentation( skip=0.3, ) ) - # augs.append( - # MixedGrayscale2D( - # contrast_factor=0.5, - # brightness_factor=0.5, - # prob=1, - # skip=0.3, - # ) - # ) # Box if is_train: @@ -71,10 +62,6 @@ def get_augmentation( if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) - # Recompute connected components - if recompute: - augs.append(Label()) - # Flip & rotate augs.append(FlipRotateIsotropic()) @@ -82,4 +69,8 @@ def get_augmentation( if border: augs.append(Border(targets=border)) + # Recompute connected components + if recompute: + augs.append(Label()) + return Compose(augs) diff --git a/deepem/data/augment/isotropic/aug_aniso.py b/deepem/data/augment/isotropic/aug_aniso.py index 364eb68..38b7043 100644 --- a/deepem/data/augment/isotropic/aug_aniso.py +++ b/deepem/data/augment/isotropic/aug_aniso.py @@ -11,7 +11,6 @@ def get_augmentation( recompute=False, box=None, blur=7, - random=False, border=[], **kwargs ): @@ -28,14 +27,6 @@ def get_augmentation( skip=0.3, ) ) - # augs.append( - # MixedGrayscale2D( - # contrast_factor=0.5, - # brightness_factor=0.5, - # prob=1, - # skip=0.3, - # ) - # ) # Box if is_train: @@ -71,10 +62,6 @@ def get_augmentation( if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) - # Recompute connected components - if recompute: - augs.append(Label()) - # Flip & rotate augs.append(FlipRotate()) @@ -82,4 +69,8 @@ def get_augmentation( if border: augs.append(Border(targets=border)) + # Recompute connected components + if recompute: + augs.append(Label()) + return Compose(augs) diff --git a/deepem/data/augment/isotropic/aug_mip1.py b/deepem/data/augment/isotropic/aug_mip1.py index 2aded66..91913b8 100644 --- a/deepem/data/augment/isotropic/aug_mip1.py +++ b/deepem/data/augment/isotropic/aug_mip1.py @@ -11,8 +11,7 @@ def get_augmentation( recompute=False, box=None, blur=7, - random=False, - border=False, + border=[], **kwargs, ): augs = list() @@ -28,14 +27,6 @@ def get_augmentation( skip=0.3, ) ) - # augs.append( - # MixedGrayscale2D( - # contrast_factor=0.5, - # brightness_factor=0.5, - # prob=1, - # skip=0.3, - # ) - # ) # Box if is_train: @@ -71,15 +62,15 @@ def get_augmentation( if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) - # Recompute connected components - if recompute: - augs.append(Label()) - # Flip & rotate augs.append(FlipRotateIsotropic()) # Create border if border: - augs.append(Border()) + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label()) return Compose(augs) From 694b9c92ffd9d32615dc0dfbfab362a7be3afa86 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 25 Apr 2024 15:40:07 +0900 Subject: [PATCH 23/47] WIP(mean): fix a bug when mask_background is False --- deepem/loss/mean.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/deepem/loss/mean.py b/deepem/loss/mean.py index 45181ab..1fc252a 100644 --- a/deepem/loss/mean.py +++ b/deepem/loss/mean.py @@ -114,10 +114,9 @@ def forward( trgt = splt trgt = trgt.to(torch.int) - trgt *= (mask > 0).to(torch.int) # Extract unique IDs - ids = np.unique(trgt.cpu().numpy()) + ids = np.unique(trgt[mask > 0].cpu().numpy()) # Remove 0s from the IDs if `mask_background` is True if self.mask_background: @@ -128,7 +127,7 @@ def forward( # Recompute external matrix mext = self.compute_ext_matrix(ids, groups, self.recompute_ext, device) - vecs = self.generate_vecs(embd, trgt, ids) + vecs = self.generate_vecs(embd, trgt, mask, ids) means = [torch.mean(vec, dim=0) for vec in vecs] weights = [1.0] * len(vecs) @@ -194,17 +193,29 @@ def generate_vecs( self, embd: torch.Tensor, trgt: torch.Tensor, + mask: torch.Tensor, ids: Sequence[int], ) -> list[torch.Tensor]: """ Generate a list of vectorized embeddings for each ground truth object. """ + if self.mask_background and 0 in ids: + raise ValueError("ID '0' is not allowed when mask_background is enabled.") + + mask_bool = mask.bool() if not self.mask_background else None result = [] + for obj_id in ids: - obj = torch.nonzero(trgt == int(obj_id)) - z, y, x = obj[:, -3], obj[:, -2], obj[:, -1] - vec = embd[0, :, z, y, x].transpose(0, 1) # Count x Dim + obj_mask = (trgt == int(obj_id)) & mask_bool if mask_bool is not None else (trgt == int(obj_id)) + idx = torch.nonzero(obj_mask, as_tuple=True) + + if idx[0].numel() == 0: + # If there are no indices for this ID, skip to the next one + continue + + vec = embd[0, :, idx[-3], idx[-2], idx[-1]].transpose(0, 1) # Count x Dim result.append(vec) + return result def compute_ext_matrix( From 7c9f370292eda628e4f08d997c6d711d09aa4521 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 25 Apr 2024 18:28:44 +0900 Subject: [PATCH 24/47] chore(Dockerfile): update --- docker/zettasets/Dockerfile | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docker/zettasets/Dockerfile b/docker/zettasets/Dockerfile index 01ff233..b8d7888 100644 --- a/docker/zettasets/Dockerfile +++ b/docker/zettasets/Dockerfile @@ -12,7 +12,7 @@ FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime # FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-runtime # FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime # FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime -ARG DEBIAN_FRONTEND=noninteractive +ENV DEBIAN_FRONTEND=noninteractive COPY --from=intermediate /tmp/zettasets /tmp/zettasets @@ -22,15 +22,16 @@ RUN apt-get update \ libboost-all-dev \ # gcloud cli (for samwise) curl apt-transport-https ca-certificates gnupg \ - && rm -rf /var/lib/apt/lists/* \ # gcloud cli (for samwise) && echo "deb https://packages.cloud.google.com/apt cloud-sdk main" \ | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \ | apt-key add - \ - && apt-get update && apt-get install google-cloud-cli \ + && apt-get update && apt-get install google-cloud-cli -y \ # handle imgaug issue && apt-get install ffmpeg libsm6 libxext6 -y \ + # Cleanup + && rm -rf /var/lib/apt/lists/* \ # python requirements && conda install h5py cython matplotlib \ scikit-image scikit-learn \ From dd399e0af01207a6d29c9b9e7bfdaf4a14dc4922 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 10 May 2024 09:22:01 +0900 Subject: [PATCH 25/47] refactor: update Iterable import for Python 3.9+ compatibility and remove deprecated future import --- deepem/models/pinky_mito.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepem/models/pinky_mito.py b/deepem/models/pinky_mito.py index 036a54a..4c37021 100644 --- a/deepem/models/pinky_mito.py +++ b/deepem/models/pinky_mito.py @@ -1,4 +1,3 @@ -from __future__ import print_function from itertools import repeat import collections import math @@ -26,7 +25,7 @@ def _ntuple(n): Copied from the PyTorch source code (https://github.com/pytorch). """ def parse(x): - if isinstance(x, collections.Iterable): + if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse From 7f1c1fa22484ea76bef3372e52de66f1b0696bb1 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 14 May 2024 13:21:18 +0900 Subject: [PATCH 26/47] feat: 3x up-down net --- deepem/models/updown3x_act.py | 78 +++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 deepem/models/updown3x_act.py diff --git a/deepem/models/updown3x_act.py b/deepem/models/updown3x_act.py new file mode 100644 index 0000000..439ec4e --- /dev/null +++ b/deepem/models/updown3x_act.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn + +import emvision +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16,32,64,128,256,512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size): + super(OutputBlock, self).__init__() + for k, v in out_spec.items(): + out_channels = v[-4] + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Sequential): + def __init__(self, scale_factor=(1,3,3)): + super(DownBlock, self).__init__() + self.add_module('down', nn.AvgPool3d(scale_factor)) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, scale_factor=(1,3,3)): + super(UpBlock, self).__init__() + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample(scale_factor=scale_factor, mode='trilinear')) + + def forward(self, x): + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + scale_factor=(1,3,3), crop=None): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + + self.add_module('down', DownBlock(scale_factor=scale_factor)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) + self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor)) + if crop is not None: + self.add_module('crop', Crop(crop)) From 275b6a7542ad1bd6b13d99041e14d4a9a00b5064 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 14 May 2024 14:16:00 +0900 Subject: [PATCH 27/47] feat: up-down by an arbitrary factor --- deepem/models/updown_act_interpolate.py | 88 +++++++++++++++++++++++++ deepem/test/option.py | 1 + deepem/train/option.py | 1 + 3 files changed, 90 insertions(+) create mode 100644 deepem/models/updown_act_interpolate.py diff --git a/deepem/models/updown_act_interpolate.py b/deepem/models/updown_act_interpolate.py new file mode 100644 index 0000000..8e21a3e --- /dev/null +++ b/deepem/models/updown_act_interpolate.py @@ -0,0 +1,88 @@ +import torch.nn as nn +import torch.nn.functional as F + +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16, 32, 64, 128, 256, 512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + scale_factor=opt.updown_scale_factor) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size): + super(OutputBlock, self).__init__() + for k, v in out_spec.items(): + out_channels = v[-4] + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Module): + def __init__(self, size): + super(DownBlock, self).__init__() + self.size = size + + def forward(self, x): + return F.interpolate(x, size=self.size, mode='trilinear', align_corners=False) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, size): + super(UpBlock, self).__init__() + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample( + size=size, + mode='trilinear', + recompute_scale_factor=False, + )) + + def forward(self, x): + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + scale_factor=(1, 2, 2), crop=None): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + in_size = in_spec['input'][-3:] + assert all(s % f == 0 for s, f in zip(in_size, scale_factor)) + new_size = tuple(int(s / f) for s, f in zip(in_size, scale_factor)) + + self.add_module('down', DownBlock(size=new_size)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) + self.add_module('up', UpBlock(out_spec, size=in_size)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/test/option.py b/deepem/test/option.py index 03e7544..b39593f 100644 --- a/deepem/test/option.py +++ b/deepem/test/option.py @@ -38,6 +38,7 @@ def initialize(self): self.parser.add_argument('--width', type=int, default=None, nargs='+') self.parser.add_argument('--group', type=int, default=0) self.parser.add_argument('--act', default='ReLU') + self.parser.add_argument('--updown_scale_factor', type=vec3f, default=None) # Tilt-series electron tomography self.parser.add_argument('--tilt_series', type=int, default=0) diff --git a/deepem/train/option.py b/deepem/train/option.py index 81ace1e..2e70ec5 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -104,6 +104,7 @@ def initialize(self): self.parser.add_argument('--width', type=int, default=None, nargs='+') self.parser.add_argument('--group', type=int, default=0) self.parser.add_argument('--act', default='ReLU') + self.parser.add_argument('--updown_scale_factor', type=vec3f, default=None) # Data augmentation self.parser.add_argument('--recompute', action='store_true') From 5dfb9eae475bccda1dfebbf102647f79bede3c17 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 14 May 2024 14:31:08 +0900 Subject: [PATCH 28/47] feat: cortex augmentation --- deepem/data/augment/cortex/aug_16nm.py | 95 ++++++++++++++++++++++++++ deepem/data/augment/cortex/aug_4nm.py | 95 ++++++++++++++++++++++++++ deepem/data/augment/cortex/aug_8nm.py | 95 ++++++++++++++++++++++++++ 3 files changed, 285 insertions(+) create mode 100644 deepem/data/augment/cortex/aug_16nm.py create mode 100644 deepem/data/augment/cortex/aug_4nm.py create mode 100644 deepem/data/augment/cortex/aug_8nm.py diff --git a/deepem/data/augment/cortex/aug_16nm.py b/deepem/data/augment/cortex/aug_16nm.py new file mode 100644 index 0000000..3bac7bb --- /dev/null +++ b/deepem/data/augment/cortex/aug_16nm.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2, 8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_4nm.py b/deepem/data/augment/cortex/aug_4nm.py new file mode 100644 index 0000000..94a62f6 --- /dev/null +++ b/deepem/data/augment/cortex/aug_4nm.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 10), margin=1), + Misalign((0, 30), margin=1), + Misalign((0, 50), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 10), interp=True, margin=1), + SlipMisalign((0, 30), interp=True, margin=1), + SlipMisalign((0, 50), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((5, 30), value=0, random=random), + MisalignPlusMissing((5, 30), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((5, 30), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_8nm.py b/deepem/data/augment/cortex/aug_8nm.py new file mode 100644 index 0000000..48c4450 --- /dev/null +++ b/deepem/data/augment/cortex/aug_8nm.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 5), margin=1), + Misalign((0, 15), margin=1), + Misalign((0, 25), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0, 15), interp=True, margin=1), + SlipMisalign((0, 25), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((3, 15), value=0, random=random), + MisalignPlusMissing((3, 15), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((3, 15), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + return Compose(augs) From 104740c83137090f595fa959321ae31b7fc4d12f Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 16 May 2024 01:47:01 +0900 Subject: [PATCH 29/47] feat: support semantic segmentation --- deepem/data/dataset/multi_zettaset.py | 7 ++- deepem/test/option.py | 71 ++++++++++++++++++++++++++- deepem/train/option.py | 34 +++++++++++-- 3 files changed, 105 insertions(+), 7 deletions(-) diff --git a/deepem/data/dataset/multi_zettaset.py b/deepem/data/dataset/multi_zettaset.py index 328a468..b5474e1 100644 --- a/deepem/data/dataset/multi_zettaset.py +++ b/deepem/data/dataset/multi_zettaset.py @@ -140,6 +140,7 @@ def load_sample( zettaset_lookup: dict[str, str] | None = None, requires_binarize: list[str] = [], zettaset_share_mask: str | None = None, + semantic_mapping: dict[str, int] = {}, **kwargs ) -> dict[str, np.ndarray]: """Load image and labels from a Sample.""" @@ -190,8 +191,10 @@ def convert_array(arr: ArrayLike) -> np.ndarray: dset[name] = convert_array(vol) anno_log = f"\t{name}: {dset[name].shape}" - # Binarize - if name in requires_binarize: + # Semantic mapping or binarize + if name in semantic_mapping: + dset[name] = (dset[name] == semantic_mapping[name]).astype("uint8") + elif name in requires_binarize: dset[name] = (dset[name] > 0).astype("uint8") # Mask diff --git a/deepem/test/option.py b/deepem/test/option.py index b39593f..70dd9d6 100644 --- a/deepem/test/option.py +++ b/deepem/test/option.py @@ -1,4 +1,5 @@ import argparse +from collections import OrderedDict import json import math import os @@ -64,11 +65,20 @@ def initialize(self): self.parser.add_argument('--mye', action='store_true') self.parser.add_argument('--mye_thresh', type=float, default=0.5) self.parser.add_argument('--blv', action='store_true') - self.parser.add_argument('--blv_num_channels', type=int, default=2) - self.parser.add_argument('--glia', action='store_true') + self.parser.add_argument('--blv_num_channels', type=int, default=1) + self.parser.add_argument('--glia', action='store_true') self.parser.add_argument('--sem', action='store_true') self.parser.add_argument('--img', action='store_true') + # Semantic segmentation + self.parser.add_argument('--semantic', action='store_true') + self.parser.add_argument('--dend', action='store_true') # Dendrite + self.parser.add_argument('--axon', action='store_true') # Axon + self.parser.add_argument('--soma', action='store_true') # Soma + self.parser.add_argument('--nucl', action='store_true') # Nucleus + self.parser.add_argument('--ecs', action='store_true') # Extracellular space + self.parser.add_argument('--other', action='store_true') # Other class + # Test-time augmentation self.parser.add_argument('--test_aug', type=int, default=None, nargs='+') self.parser.add_argument('--test_aug16', action='store_true') @@ -204,6 +214,35 @@ def parse(self): opt.out_spec['bvessel'] = (1,) + opt.outputsz if opt.img: opt.out_spec['image'] = (1,) + opt.outputsz + if opt.dend: + opt.out_spec['dendrite'] = (1,) + opt.outputsz + if opt.axon: + opt.out_spec['axon'] = (1,) + opt.outputsz + if opt.soma: + opt.out_spec['soma'] = (1,) + opt.outputsz + if opt.nucl: + opt.out_spec['nucleus'] = (1,) + opt.outputsz + if opt.ecs: + opt.out_spec['extracellular_space'] = (1,) + opt.outputsz + if opt.other: + opt.out_spec['other_class'] = (1,) + opt.outputsz + + # Semantic segmentation + if opt.semantic: + required_keys = ['soma', 'axon', 'dendrite', 'glia', 'blood_vessel'] + + # Ensure all required keys are present in the opt.out_spec + assert all(key in opt.out_spec for key in required_keys) + + # Use OrderedDict to maintain order of required keys followed by other keys + out_spec_new = OrderedDict((key, opt.out_spec[key]) for key in required_keys) + + # Add remaining keys to out_spec_new + out_spec_new.update((key, opt.out_spec[key]) for key in opt.out_spec if key not in required_keys) + + # Convert back to standard dict if necessary + opt.out_spec = dict(out_spec_new) + assert(len(opt.out_spec) > 0) # Scan spec @@ -241,6 +280,34 @@ def parse(self): opt.scan_spec['bvessel'] = (1,) + opt.outputsz if opt.img: opt.scan_spec['image'] = (1,) + opt.outputsz + if opt.dend: + opt.scan_spec['dendrite'] = (1,) + opt.outputsz + if opt.axon: + opt.scan_spec['axon'] = (1,) + opt.outputsz + if opt.soma: + opt.scan_spec['soma'] = (1,) + opt.outputsz + if opt.nucl: + opt.scan_spec['nucleus'] = (1,) + opt.outputsz + if opt.ecs: + opt.scan_spec['extracellular_space'] = (1,) + opt.outputsz + if opt.other: + opt.scan_spec['other_class'] = (1,) + opt.outputsz + + # Semantic segmentation + if opt.semantic: + required_keys = ['soma', 'axon', 'dendrite', 'glia', 'blood_vessel'] + + # Ensure all required keys are present in the opt.scan_spec + assert all(key in opt.scan_spec for key in required_keys) + + # Use OrderedDict to maintain order of required keys followed by other keys + scan_spec_new = OrderedDict((key, opt.scan_spec[key]) for key in required_keys) + + # Add remaining keys to scan_spec_new + scan_spec_new.update((key, opt.scan_spec[key]) for key in opt.scan_spec if key not in required_keys) + + # Convert back to standard dict if necessary + opt.scan_spec = dict(scan_spec_new) # Test-time augmentation if opt.test_aug16: diff --git a/deepem/train/option.py b/deepem/train/option.py index 2e70ec5..e4ddeb3 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -144,12 +144,20 @@ def initialize(self): self.parser.add_argument('--mye', type=float, default=0) # Myelin self.parser.add_argument('--fld', type=float, default=0) # Fold self.parser.add_argument('--blv', type=float, default=0) # Blood vessel - self.parser.add_argument('--blv_num_channels', type=int, default=2) - self.parser.add_argument('--glia', type=float, default=0) # Glia + self.parser.add_argument('--blv_num_channels', type=int, default=1) + self.parser.add_argument('--glia', type=float, default=0) # Glia self.parser.add_argument('--glia_mask', action='store_true') - self.parser.add_argument('--soma', type=float, default=0) # Soma self.parser.add_argument('--img', type=float, default=0) # Image + # Semantic segmentation + self.parser.add_argument('--sem', action='store_true') + self.parser.add_argument('--dend', type=float, default=0) # Dendrite + self.parser.add_argument('--axon', type=float, default=0) # Axon + self.parser.add_argument('--soma', type=float, default=0) # Soma + self.parser.add_argument('--nucl', type=float, default=0) # Nucleus + self.parser.add_argument('--ecs', type=float, default=0) # Extracellular space + self.parser.add_argument('--other', type=float, default=0) # Other class + # Metric learning self.parser.add_argument('--vec', type=float, default=0) self.parser.add_argument('--embed_dim', type=int, default=12) @@ -287,6 +295,22 @@ def parse(self): 'soma': ('soma', 1), 'img': ('image', 1), 'vec': ('embedding', opt.embed_dim), + 'dend': ('dendrite', 1), + 'axon': ('axon', 1), + 'nucl': ('nucleus', 1), + 'ecs': ('extracellular_space', 1), + 'other': ('other_class', 1), + } + + semantic_mapping = { + 'dendrite': 1, + 'axon': 2, + 'soma': 3, + 'nucleus': 4, + 'glia': 5, + 'extracellular_space': 6, + 'blood_vessel': 7, + 'other_class': 10, } requires_binarize = [ @@ -301,6 +325,9 @@ def parse(self): if opt.blv_num_channels == 1: requires_binarize.append("blood_vessel") + if opt.sem: + requires_binarize = [x for x in requires_binarize if x not in semantic_mapping] + # Test training if opt.test: opt.eval_intv = 100 @@ -329,6 +356,7 @@ def parse(self): zettaset_mask=not opt.zettaset_no_mask, requires_binarize=requires_binarize, zettaset_share_mask=opt.zettaset_share_mask, + semantic_mapping=semantic_mapping if opt.sem else {}, ) # ONNX From f812fce33d689ad553ae3d4e256dc567969932d1 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 21 May 2024 10:29:34 -0400 Subject: [PATCH 30/47] fix: apply correct weights for static class balancing --- deepem/loss/affinity.py | 13 ------------- deepem/train/utils.py | 11 ++++++++++- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/deepem/loss/affinity.py b/deepem/loss/affinity.py index d7325d2..060962a 100644 --- a/deepem/loss/affinity.py +++ b/deepem/loss/affinity.py @@ -56,19 +56,6 @@ def forward(self, preds, targets, masks): return loss, nmsk - # def class_balancing(self, target, mask): - # if not self.balancing: - # return mask - # dtype = mask.type() - # m_int = mask * torch.eq(target, 1).type(dtype) - # m_ext = mask * torch.eq(target, 0).type(dtype) - # n_int = m_int.sum().item() - # n_ext = m_ext.sum().item() - # if n_int > 0 and n_ext > 0: - # m_int *= n_ext/(n_int + n_ext) - # m_ext *= n_int/(n_int + n_ext) - # return (m_int + m_ext).type(dtype) - class AffinityLoss(nn.Module): def __init__(self, edges, criterion, size_average=False, diff --git a/deepem/train/utils.py b/deepem/train/utils.py index d13f020..7a7c447 100644 --- a/deepem/train/utils.py +++ b/deepem/train/utils.py @@ -20,6 +20,8 @@ def get_criteria(opt): weight1=opt.class_weight1, ) if opt.class_balancing else None + is_dynamic = (opt.class_weight0 is None) and (opt.class_weight1 is None) + for k in opt.out_spec: if k == 'affinity' or k == 'long_range': if k == 'affinity': @@ -42,7 +44,14 @@ def get_criteria(opt): params['margin0'] = 0 params['margin1'] = 0 params['inverse'] = False - params['class_balancer'] = balancer + params['class_balancer'] = ( + balancer + if is_dynamic + else BinaryWeightBalancer( + weight0=opt.class_weight1, + weight1=opt.class_weight0, + ) + ) criteria[k] = getattr(loss, 'BCELoss')(**params) return criteria From cf140b624b89c221a0b67497c2b3416925ae891c Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Wed, 22 May 2024 11:30:40 -0400 Subject: [PATCH 31/47] fix: update Label augmentation to include recompute targets --- deepem/data/augment/cortex/aug_16nm.py | 2 +- deepem/data/augment/cortex/aug_4nm.py | 2 +- deepem/data/augment/cortex/aug_8nm.py | 2 +- deepem/train/option.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/deepem/data/augment/cortex/aug_16nm.py b/deepem/data/augment/cortex/aug_16nm.py index 3bac7bb..7cb2309 100644 --- a/deepem/data/augment/cortex/aug_16nm.py +++ b/deepem/data/augment/cortex/aug_16nm.py @@ -90,6 +90,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_4nm.py b/deepem/data/augment/cortex/aug_4nm.py index 94a62f6..c61d7c9 100644 --- a/deepem/data/augment/cortex/aug_4nm.py +++ b/deepem/data/augment/cortex/aug_4nm.py @@ -90,6 +90,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_8nm.py b/deepem/data/augment/cortex/aug_8nm.py index 48c4450..684debf 100644 --- a/deepem/data/augment/cortex/aug_8nm.py +++ b/deepem/data/augment/cortex/aug_8nm.py @@ -90,6 +90,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/train/option.py b/deepem/train/option.py index e4ddeb3..c4a0229 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -107,7 +107,7 @@ def initialize(self): self.parser.add_argument('--updown_scale_factor', type=vec3f, default=None) # Data augmentation - self.parser.add_argument('--recompute', action='store_true') + self.parser.add_argument('--recompute', type=str, default=[], nargs='+') self.parser.add_argument('--border', type=str, default=[], nargs='+') self.parser.add_argument('--flip', action='store_true') self.parser.add_argument('--grayscale', action='store_true') From f2e7c010908cb915edee1d2c7713d5073b6ea46a Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 24 May 2024 11:10:08 -0400 Subject: [PATCH 32/47] feat: section gap augmentation --- deepem/data/augment/cortex/aug_16nm.py | 6 + deepem/data/augment/cortex/aug_16nm_gap3-5.py | 105 ++++++++++++++++++ deepem/data/augment/cortex/aug_4nm.py | 6 + deepem/data/augment/cortex/aug_4nm_gap3-5.py | 105 ++++++++++++++++++ deepem/data/augment/cortex/aug_8nm.py | 6 + deepem/train/option.py | 5 +- 6 files changed, 232 insertions(+), 1 deletion(-) create mode 100644 deepem/data/augment/cortex/aug_16nm_gap3-5.py create mode 100644 deepem/data/augment/cortex/aug_4nm_gap3-5.py diff --git a/deepem/data/augment/cortex/aug_16nm.py b/deepem/data/augment/cortex/aug_16nm.py index 7cb2309..a919cb5 100644 --- a/deepem/data/augment/cortex/aug_16nm.py +++ b/deepem/data/augment/cortex/aug_16nm.py @@ -10,6 +10,8 @@ def get_augmentation( random=False, recompute=False, border=[], + section_gap=0, + mask_section_gap=False, **kwargs ): augs = list() @@ -92,4 +94,8 @@ def get_augmentation( if recompute: augs.append(Label(targets=recompute)) + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_16nm_gap3-5.py b/deepem/data/augment/cortex/aug_16nm_gap3-5.py new file mode 100644 index 0000000..b0f6a02 --- /dev/null +++ b/deepem/data/augment/cortex/aug_16nm_gap3-5.py @@ -0,0 +1,105 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + mask_section_gap=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2, 8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + augs.append( + Blend([ + SectionGap(num_secs=3, masked=mask_section_gap), + SectionGap(num_secs=4, masked=mask_section_gap), + SectionGap(num_secs=5, masked=mask_section_gap), + ]) + ) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_4nm.py b/deepem/data/augment/cortex/aug_4nm.py index c61d7c9..01e6d2e 100644 --- a/deepem/data/augment/cortex/aug_4nm.py +++ b/deepem/data/augment/cortex/aug_4nm.py @@ -10,6 +10,8 @@ def get_augmentation( random=False, recompute=False, border=[], + section_gap=0, + mask_section_gap=False, **kwargs ): augs = list() @@ -92,4 +94,8 @@ def get_augmentation( if recompute: augs.append(Label(targets=recompute)) + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_4nm_gap3-5.py b/deepem/data/augment/cortex/aug_4nm_gap3-5.py new file mode 100644 index 0000000..39c0c55 --- /dev/null +++ b/deepem/data/augment/cortex/aug_4nm_gap3-5.py @@ -0,0 +1,105 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + mask_section_gap=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 10), margin=1), + Misalign((0, 30), margin=1), + Misalign((0, 50), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 10), interp=True, margin=1), + SlipMisalign((0, 30), interp=True, margin=1), + SlipMisalign((0, 50), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((5, 30), value=0, random=random), + MisalignPlusMissing((5, 30), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((5, 30), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + augs.append( + Blend([ + SectionGap(num_secs=3, masked=mask_section_gap), + SectionGap(num_secs=4, masked=mask_section_gap), + SectionGap(num_secs=5, masked=mask_section_gap), + ]) + ) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_8nm.py b/deepem/data/augment/cortex/aug_8nm.py index 684debf..6e10a25 100644 --- a/deepem/data/augment/cortex/aug_8nm.py +++ b/deepem/data/augment/cortex/aug_8nm.py @@ -10,6 +10,8 @@ def get_augmentation( random=False, recompute=False, border=[], + section_gap=0, + mask_section_gap=False, **kwargs ): augs = list() @@ -92,4 +94,8 @@ def get_augmentation( if recompute: augs.append(Label(targets=recompute)) + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + return Compose(augs) diff --git a/deepem/train/option.py b/deepem/train/option.py index c4a0229..ce87e73 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -124,6 +124,8 @@ def initialize(self): self.parser.add_argument('--noise_min', type=float, default=0.01) self.parser.add_argument('--noise_max', type=float, default=0.1) self.parser.add_argument('--noise_per_channel', action='store_true') + self.parser.add_argument('--section_gap', type=int, default=0) + self.parser.add_argument('--mask_section_gap', action='store_true') # Tilt-series electron tomography self.parser.add_argument('--tilt_series', type=int, default=0) @@ -237,7 +239,8 @@ def parse(self): # Data augmentation aug_keys = ['recompute', 'border', 'flip','grayscale','warping','misalign', - 'interp','missing','blur','box','mip','lost','random'] + 'interp','missing','blur','box','mip','lost','random', + 'section_gap', 'mask_section_gap'] opt.aug_params = {k: args[k] for k in aug_keys} # Noise From 6a470039cf2ae271d5702abf9f95458abdf854b9 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 27 Jun 2024 16:24:29 +0900 Subject: [PATCH 33/47] feat: add fly augmentation --- deepem/data/augment/flyem/aug_16nm.py | 101 ++++++++++++++++++++++++++ deepem/data/augment/flyem/aug_8nm.py | 101 ++++++++++++++++++++++++++ 2 files changed, 202 insertions(+) create mode 100644 deepem/data/augment/flyem/aug_16nm.py create mode 100644 deepem/data/augment/flyem/aug_8nm.py diff --git a/deepem/data/augment/flyem/aug_16nm.py b/deepem/data/augment/flyem/aug_16nm.py new file mode 100644 index 0000000..3aff864 --- /dev/null +++ b/deepem/data/augment/flyem/aug_16nm.py @@ -0,0 +1,101 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + section_gap=0, + mask_section_gap=False, + **kwargs +): + augs = list() + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2, 8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + + return Compose(augs) diff --git a/deepem/data/augment/flyem/aug_8nm.py b/deepem/data/augment/flyem/aug_8nm.py new file mode 100644 index 0000000..9c328a1 --- /dev/null +++ b/deepem/data/augment/flyem/aug_8nm.py @@ -0,0 +1,101 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + section_gap=0, + mask_section_gap=False, + **kwargs +): + augs = list() + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 5), margin=1), + Misalign((0, 15), margin=1), + Misalign((0, 25), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0, 15), interp=True, margin=1), + SlipMisalign((0, 25), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((3, 15), value=0, random=random), + MisalignPlusMissing((3, 15), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((3, 15), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + + return Compose(augs) From 6df0bef48b5153cc1eb15d9214e36aaf5211b6ef Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 23 Aug 2024 13:33:50 +0900 Subject: [PATCH 34/47] refactor(mean): minor code modifications to improve efficiency --- deepem/loss/mean.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/deepem/loss/mean.py b/deepem/loss/mean.py index 1fc252a..4c0c787 100644 --- a/deepem/loss/mean.py +++ b/deepem/loss/mean.py @@ -107,23 +107,19 @@ def forward( groups = None if self.recompute_ext: assert splt is not None - trgt_np = np.squeeze(trgt.cpu().numpy()) - splt_np = np.squeeze(splt.cpu().numpy()) - mask_np = np.squeeze(mask.cpu().numpy()) - groups = create_mapping(trgt_np, splt_np, mask_np) + trgt = torch.squeeze(trgt) + splt = torch.squeeze(splt) + mask = torch.squeeze(mask) + groups = create_mapping(trgt.cpu().numpy(), splt.cpu().numpy(), mask.cpu().numpy()) trgt = splt trgt = trgt.to(torch.int) - # Extract unique IDs - ids = np.unique(trgt[mask > 0].cpu().numpy()) - - # Remove 0s from the IDs if `mask_background` is True + # Filter out background and get unique IDs + masked_trgt = trgt[mask > 0] if self.mask_background: - ids = ids[ids != 0] - - # Convert numpy array to a Python list - ids = ids.tolist() + masked_trgt = masked_trgt[masked_trgt != 0] + ids = torch.unique(masked_trgt).tolist() # Recompute external matrix mext = self.compute_ext_matrix(ids, groups, self.recompute_ext, device) From ff5a7de1fa0e0e28a24c326f8f707e5e3a5950fc Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Fri, 23 Aug 2024 13:32:29 +0900 Subject: [PATCH 35/47] feat(mean): add downsampled loss computation to improve efficiency --- deepem/loss/mean.py | 10 ++++++++++ deepem/train/option.py | 2 ++ 2 files changed, 12 insertions(+) diff --git a/deepem/loss/mean.py b/deepem/loss/mean.py index 4c0c787..72f7eda 100644 --- a/deepem/loss/mean.py +++ b/deepem/loss/mean.py @@ -78,6 +78,7 @@ def __init__( delta_d: float = 1.5, recompute_ext: bool = False, mask_background: bool = True, + loss_scale_factor: tuple[float, float, float] | None = None, **kwargs, ): super().__init__() @@ -88,6 +89,7 @@ def __init__( self.delta_d = delta_d # Distance (inter-cluster push force) hinge self.recompute_ext = recompute_ext self.mask_background = mask_background + self.loss_scale_factor = loss_scale_factor def forward( self, @@ -104,6 +106,14 @@ def forward( """ device = embd.device + # Downsample if enabled + if self.loss_scale_factor is not None: + embd = F.interpolate(embd, scale_factor=self.loss_scale_factor, mode='trilinear', align_corners=False) + trgt = F.interpolate(trgt, scale_factor=self.loss_scale_factor, mode='nearest') + mask = F.interpolate(mask, scale_factor=self.loss_scale_factor, mode='nearest') + if splt is not None: + splt = F.interpolate(splt, scale_factor=self.loss_scale_factor, mode='nearest') + groups = None if self.recompute_ext: assert splt is not None diff --git a/deepem/train/option.py b/deepem/train/option.py index ce87e73..ad53172 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -83,6 +83,7 @@ def initialize(self): self.parser.add_argument('--delta_d', type=float, default=1.5) self.parser.add_argument('--recompute_ext', action='store_true') self.parser.add_argument('--no_mask_background', action='store_true') + self.parser.add_argument('--loss_scale_factor', type=vec3f, default=None) # Optimizer self.parser.add_argument('--optim', default='Adam') @@ -227,6 +228,7 @@ def parse(self): opt.metric_params['delta_d'] = opt.delta_d opt.metric_params['recompute_ext'] = opt.recompute_ext opt.metric_params['mask_background'] = not opt.no_mask_background + opt.metric_params['loss_scale_factor'] = opt.loss_scale_factor # Optimizer if opt.optim == 'Adam': From e988f8bbc48ced809be499be30500b6c7c9d6b49 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 12 Dec 2024 00:45:08 +0900 Subject: [PATCH 36/47] feat: add option for split boundary --- deepem/loss/affinity.py | 14 +++++++++----- deepem/train/option.py | 1 + deepem/train/utils.py | 1 + 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/deepem/loss/affinity.py b/deepem/loss/affinity.py index 060962a..da00607 100644 --- a/deepem/loss/affinity.py +++ b/deepem/loss/affinity.py @@ -7,15 +7,19 @@ class EdgeSampler(object): - def __init__(self, edges): + def __init__(self, edges, split_boundary=True): self.edges = list(edges) + self.split_boundary = split_boundary def generate_edges(self): return list(self.edges) def generate_true_aff(self, obj, edge): o1, o2 = torch_utils.get_pair(obj, edge) - ret = ((o1 == o2) & (o1 != 0) & (o2 != 0)) + if self.split_boundary: + ret = ((o1 == o2) & (o1 != 0) & (o2 != 0)) + else: + ret = (o1 == o2) return ret.type(obj.type()) def generate_mask_aff(self, mask, edge): @@ -58,10 +62,10 @@ def forward(self, preds, targets, masks): class AffinityLoss(nn.Module): - def __init__(self, edges, criterion, size_average=False, - class_balancer=None): + def __init__(self, edges, criterion, split_boundary=True, + size_average=False, class_balancer=None): super(AffinityLoss, self).__init__() - self.sampler = EdgeSampler(edges) + self.sampler = EdgeSampler(edges, split_boundary=split_boundary) self.decoder = AffinityLoss.Decoder(edges) self.criterion = EdgeCRF( criterion, diff --git a/deepem/train/option.py b/deepem/train/option.py index ad53172..5800f88 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -64,6 +64,7 @@ def initialize(self): # Loss self.parser.add_argument('--loss', default='BCELoss') + self.parser.add_argument('--no_split_boundary', action='store_true') self.parser.add_argument('--size_average', action='store_true') self.parser.add_argument('--margin0', type=float, default=0) self.parser.add_argument('--margin1', type=float, default=0) diff --git a/deepem/train/utils.py b/deepem/train/utils.py index 7a7c447..c38869d 100644 --- a/deepem/train/utils.py +++ b/deepem/train/utils.py @@ -33,6 +33,7 @@ def get_criteria(opt): params['size_average'] = False criteria[k] = loss.AffinityLoss(edges, criterion=getattr(loss, opt.loss)(**params), + split_boundary=not opt.no_split_boundary, size_average=opt.size_average, class_balancer=balancer, ) From fb43e6d9d3b2892a9e94bf6aaeadeeb4e1fdc125 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Mon, 6 Jan 2025 19:19:51 -0800 Subject: [PATCH 37/47] feat: add mip1 aug to dacey human retina --- deepem/data/augment/retina/aug_mip1.py | 95 ++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 deepem/data/augment/retina/aug_mip1.py diff --git a/deepem/data/augment/retina/aug_mip1.py b/deepem/data/augment/retina/aug_mip1.py new file mode 100644 index 0000000..45b1b02 --- /dev/null +++ b/deepem/data/augment/retina/aug_mip1.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=False, + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 5), margin=1), + Misalign((0,15), margin=1), + Misalign((0,25), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0,15), interp=True, margin=1), + SlipMisalign((0,25), interp=True, margin=1)]) + to_blend.append(Blend([trans,slip], props=[0.7,0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((3,15), value=0, random=random), + MisalignPlusMissing((3,15), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((3,15), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label()) + + return Compose(augs) From 50e63b18f41c4f5be0b2acbac8f6d2dbb3320d9b Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Mon, 6 Jan 2025 20:31:44 -0800 Subject: [PATCH 38/47] feat: relabel augmentation takes target list --- deepem/data/augment/cortex/aug_16nm.py | 2 +- deepem/data/augment/cortex/aug_16nm_gap3-5.py | 2 +- deepem/data/augment/cortex/aug_4nm.py | 2 +- deepem/data/augment/cortex/aug_4nm_gap3-5.py | 2 +- deepem/data/augment/cortex/aug_8nm.py | 2 +- deepem/data/augment/flip_rotate.py | 4 ++-- deepem/data/augment/flip_rotate_iso.py | 4 ++-- deepem/data/augment/flip_rotate_iso_v2.py | 4 ++-- deepem/data/augment/flyem/aug_16nm.py | 2 +- deepem/data/augment/flyem/aug_8nm.py | 2 +- deepem/data/augment/flyem/aug_mip1.py | 4 ++-- deepem/data/augment/flyem/aug_mip2.py | 4 ++-- deepem/data/augment/grayscale_warping.py | 4 ++-- deepem/data/augment/isotropic/aug.py | 4 ++-- deepem/data/augment/isotropic/aug_aniso.py | 4 ++-- deepem/data/augment/isotropic/aug_mip1.py | 4 ++-- deepem/data/augment/kasthuri11/aug_v2.py | 4 ++-- deepem/data/augment/kasthuri11/aug_v2_valid.py | 4 ++-- deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py | 4 ++-- deepem/data/augment/no_aug.py | 4 ++-- deepem/data/augment/pinky_basil/aug_mip1_v3.py | 4 ++-- deepem/data/augment/pinky_basil/aug_mip2.py | 4 ++-- deepem/data/augment/retina/aug_20nm.py | 4 ++-- deepem/data/augment/retina/aug_mip1.py | 4 ++-- deepem/data/augment/retina/aug_mip2.py | 4 ++-- deepem/data/augment/tilt_series/aug_v0.py | 4 ++-- deepem/data/augment/tilt_series/aug_v1_val.py | 4 ++-- deepem/data/augment/tilt_series/no_aug_val.py | 4 ++-- deepem/data/augment/tilt_series/noise_v0.py | 4 ++-- 29 files changed, 51 insertions(+), 51 deletions(-) diff --git a/deepem/data/augment/cortex/aug_16nm.py b/deepem/data/augment/cortex/aug_16nm.py index a919cb5..7a5f7a2 100644 --- a/deepem/data/augment/cortex/aug_16nm.py +++ b/deepem/data/augment/cortex/aug_16nm.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], section_gap=0, mask_section_gap=False, diff --git a/deepem/data/augment/cortex/aug_16nm_gap3-5.py b/deepem/data/augment/cortex/aug_16nm_gap3-5.py index b0f6a02..cb61c80 100644 --- a/deepem/data/augment/cortex/aug_16nm_gap3-5.py +++ b/deepem/data/augment/cortex/aug_16nm_gap3-5.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], mask_section_gap=False, **kwargs diff --git a/deepem/data/augment/cortex/aug_4nm.py b/deepem/data/augment/cortex/aug_4nm.py index 01e6d2e..3fe0706 100644 --- a/deepem/data/augment/cortex/aug_4nm.py +++ b/deepem/data/augment/cortex/aug_4nm.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], section_gap=0, mask_section_gap=False, diff --git a/deepem/data/augment/cortex/aug_4nm_gap3-5.py b/deepem/data/augment/cortex/aug_4nm_gap3-5.py index 39c0c55..aaf890e 100644 --- a/deepem/data/augment/cortex/aug_4nm_gap3-5.py +++ b/deepem/data/augment/cortex/aug_4nm_gap3-5.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], mask_section_gap=False, **kwargs diff --git a/deepem/data/augment/cortex/aug_8nm.py b/deepem/data/augment/cortex/aug_8nm.py index 6e10a25..2067c5a 100644 --- a/deepem/data/augment/cortex/aug_8nm.py +++ b/deepem/data/augment/cortex/aug_8nm.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], section_gap=0, mask_section_gap=False, diff --git a/deepem/data/augment/flip_rotate.py b/deepem/data/augment/flip_rotate.py index 04b899a..fd329c9 100644 --- a/deepem/data/augment/flip_rotate.py +++ b/deepem/data/augment/flip_rotate.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, **kwargs): +def get_augmentation(is_train, recompute=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotate()) diff --git a/deepem/data/augment/flip_rotate_iso.py b/deepem/data/augment/flip_rotate_iso.py index e849358..532fcc9 100644 --- a/deepem/data/augment/flip_rotate_iso.py +++ b/deepem/data/augment/flip_rotate_iso.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, border=[], **kwargs): +def get_augmentation(is_train, recompute=[], border=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotateIsotropic()) diff --git a/deepem/data/augment/flip_rotate_iso_v2.py b/deepem/data/augment/flip_rotate_iso_v2.py index 955fced..3e6b2c6 100644 --- a/deepem/data/augment/flip_rotate_iso_v2.py +++ b/deepem/data/augment/flip_rotate_iso_v2.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, border=[], **kwargs): +def get_augmentation(is_train, recompute=[], border=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Create border if border: diff --git a/deepem/data/augment/flyem/aug_16nm.py b/deepem/data/augment/flyem/aug_16nm.py index 3aff864..9d2887d 100644 --- a/deepem/data/augment/flyem/aug_16nm.py +++ b/deepem/data/augment/flyem/aug_16nm.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], section_gap=0, mask_section_gap=False, diff --git a/deepem/data/augment/flyem/aug_8nm.py b/deepem/data/augment/flyem/aug_8nm.py index 9c328a1..5e203e8 100644 --- a/deepem/data/augment/flyem/aug_8nm.py +++ b/deepem/data/augment/flyem/aug_8nm.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], section_gap=0, mask_section_gap=False, diff --git a/deepem/data/augment/flyem/aug_mip1.py b/deepem/data/augment/flyem/aug_mip1.py index cde9d28..237f40c 100644 --- a/deepem/data/augment/flyem/aug_mip1.py +++ b/deepem/data/augment/flyem/aug_mip1.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=False, **kwargs ): @@ -90,7 +90,7 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotate()) diff --git a/deepem/data/augment/flyem/aug_mip2.py b/deepem/data/augment/flyem/aug_mip2.py index 1223e4c..b5f3300 100644 --- a/deepem/data/augment/flyem/aug_mip2.py +++ b/deepem/data/augment/flyem/aug_mip2.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=False, **kwargs ): @@ -90,7 +90,7 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotate()) diff --git a/deepem/data/augment/grayscale_warping.py b/deepem/data/augment/grayscale_warping.py index d377ae1..5c9b840 100644 --- a/deepem/data/augment/grayscale_warping.py +++ b/deepem/data/augment/grayscale_warping.py @@ -1,13 +1,13 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, warping=False, +def get_augmentation(is_train, recompute=[], grayscale=False, warping=False, **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Brightness & contrast purterbation if is_train and grayscale: diff --git a/deepem/data/augment/isotropic/aug.py b/deepem/data/augment/isotropic/aug.py index b7bb4fa..66b25d9 100644 --- a/deepem/data/augment/isotropic/aug.py +++ b/deepem/data/augment/isotropic/aug.py @@ -8,7 +8,7 @@ def get_augmentation( is_train, - recompute=False, + recompute=[], box=None, blur=7, border=[], @@ -71,6 +71,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/isotropic/aug_aniso.py b/deepem/data/augment/isotropic/aug_aniso.py index 38b7043..2f2edd1 100644 --- a/deepem/data/augment/isotropic/aug_aniso.py +++ b/deepem/data/augment/isotropic/aug_aniso.py @@ -8,7 +8,7 @@ def get_augmentation( is_train, - recompute=False, + recompute=[], box=None, blur=7, border=[], @@ -71,6 +71,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/isotropic/aug_mip1.py b/deepem/data/augment/isotropic/aug_mip1.py index 91913b8..0d6084e 100644 --- a/deepem/data/augment/isotropic/aug_mip1.py +++ b/deepem/data/augment/isotropic/aug_mip1.py @@ -8,7 +8,7 @@ def get_augmentation( is_train, - recompute=False, + recompute=[], box=None, blur=7, border=[], @@ -71,6 +71,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/kasthuri11/aug_v2.py b/deepem/data/augment/kasthuri11/aug_v2.py index 98f0ad4..20d9917 100644 --- a/deepem/data/augment/kasthuri11/aug_v2.py +++ b/deepem/data/augment/kasthuri11/aug_v2.py @@ -1,7 +1,7 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, +def get_augmentation(is_train, recompute=[], grayscale=False, missing=0, blur=0, warping=False, misalign=0, box=None, mip=0, random=False, **kwargs): augs = list() @@ -62,6 +62,6 @@ def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/kasthuri11/aug_v2_valid.py b/deepem/data/augment/kasthuri11/aug_v2_valid.py index c9b567a..456db2c 100644 --- a/deepem/data/augment/kasthuri11/aug_v2_valid.py +++ b/deepem/data/augment/kasthuri11/aug_v2_valid.py @@ -1,7 +1,7 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, +def get_augmentation(is_train, recompute=[], grayscale=False, missing=0, blur=0, warping=False, misalign=0, box=None, mip=0, random=False, **kwargs): augs = list() @@ -64,6 +64,6 @@ def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py b/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py index 9c2ebe8..4a6e36e 100644 --- a/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py +++ b/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py @@ -1,7 +1,7 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, +def get_augmentation(is_train, recompute=[], grayscale=False, missing=0, blur=0, warping=False, misalign=0, box=None, mip=0, random=False, **kwargs): augs = list() @@ -64,6 +64,6 @@ def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/no_aug.py b/deepem/data/augment/no_aug.py index 20bafd1..b1b4fa3 100644 --- a/deepem/data/augment/no_aug.py +++ b/deepem/data/augment/no_aug.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, **kwargs): +def get_augmentation(is_train, recompute=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate if not is_train: diff --git a/deepem/data/augment/pinky_basil/aug_mip1_v3.py b/deepem/data/augment/pinky_basil/aug_mip1_v3.py index 21b6afe..b5f4b7c 100644 --- a/deepem/data/augment/pinky_basil/aug_mip1_v3.py +++ b/deepem/data/augment/pinky_basil/aug_mip1_v3.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=False, **kwargs ): @@ -83,7 +83,7 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotate()) diff --git a/deepem/data/augment/pinky_basil/aug_mip2.py b/deepem/data/augment/pinky_basil/aug_mip2.py index 7916fd3..00adeb1 100644 --- a/deepem/data/augment/pinky_basil/aug_mip2.py +++ b/deepem/data/augment/pinky_basil/aug_mip2.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=False, **kwargs ): @@ -83,7 +83,7 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotate()) diff --git a/deepem/data/augment/retina/aug_20nm.py b/deepem/data/augment/retina/aug_20nm.py index db70eed..58c0a93 100644 --- a/deepem/data/augment/retina/aug_20nm.py +++ b/deepem/data/augment/retina/aug_20nm.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], **kwargs ): @@ -97,6 +97,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/retina/aug_mip1.py b/deepem/data/augment/retina/aug_mip1.py index 45b1b02..9a9cf21 100644 --- a/deepem/data/augment/retina/aug_mip1.py +++ b/deepem/data/augment/retina/aug_mip1.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], **kwargs ): @@ -90,6 +90,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/retina/aug_mip2.py b/deepem/data/augment/retina/aug_mip2.py index 0f3da1c..bac9eec 100644 --- a/deepem/data/augment/retina/aug_mip2.py +++ b/deepem/data/augment/retina/aug_mip2.py @@ -8,7 +8,7 @@ def get_augmentation( blur=7, lost=True, random=False, - recompute=False, + recompute=[], border=[], **kwargs ): @@ -90,6 +90,6 @@ def get_augmentation( # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/tilt_series/aug_v0.py b/deepem/data/augment/tilt_series/aug_v0.py index c6451fe..47eace8 100644 --- a/deepem/data/augment/tilt_series/aug_v0.py +++ b/deepem/data/augment/tilt_series/aug_v0.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, flip=False, noise=None, **kwargs): + recompute=[], flip=False, noise=None, **kwargs): augs = [] # Flip & rotate (isotropic) @@ -38,6 +38,6 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/tilt_series/aug_v1_val.py b/deepem/data/augment/tilt_series/aug_v1_val.py index 54d8386..d47c30d 100644 --- a/deepem/data/augment/tilt_series/aug_v1_val.py +++ b/deepem/data/augment/tilt_series/aug_v1_val.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, box=None, missing=7, blur=7, random=False, + recompute=[], box=None, missing=7, blur=7, random=False, **kwargs): augs = [] @@ -28,7 +28,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Box if is_train: diff --git a/deepem/data/augment/tilt_series/no_aug_val.py b/deepem/data/augment/tilt_series/no_aug_val.py index 156c545..f3b005f 100644 --- a/deepem/data/augment/tilt_series/no_aug_val.py +++ b/deepem/data/augment/tilt_series/no_aug_val.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, **kwargs): + recompute=[], **kwargs): augs = [] # Flip & rotate (isotropic) @@ -27,6 +27,6 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/tilt_series/noise_v0.py b/deepem/data/augment/tilt_series/noise_v0.py index 5e1db52..bcb857e 100644 --- a/deepem/data/augment/tilt_series/noise_v0.py +++ b/deepem/data/augment/tilt_series/noise_v0.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, flip=False, noise=None, **kwargs): + recompute=[], flip=False, noise=None, **kwargs): augs = [] # Flip & rotate (isotropic) @@ -38,6 +38,6 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) From 30b425e780724bae9df95a66752477b0074e02b7 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Mon, 6 Jan 2025 20:33:23 -0800 Subject: [PATCH 39/47] feat: border augmentation takes target list --- deepem/data/augment/flyem/aug_mip1.py | 4 ++-- deepem/data/augment/flyem/aug_mip2.py | 4 ++-- deepem/data/augment/pinky_basil/aug_mip1_v3.py | 4 ++-- deepem/data/augment/pinky_basil/aug_mip2.py | 4 ++-- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/deepem/data/augment/flyem/aug_mip1.py b/deepem/data/augment/flyem/aug_mip1.py index 237f40c..bfee674 100644 --- a/deepem/data/augment/flyem/aug_mip1.py +++ b/deepem/data/augment/flyem/aug_mip1.py @@ -9,7 +9,7 @@ def get_augmentation( lost=True, random=False, recompute=[], - border=False, + border=[], **kwargs ): augs = list() @@ -97,6 +97,6 @@ def get_augmentation( # Create border if border: - augs.append(Border()) + augs.append(Border(targets=border)) return Compose(augs) diff --git a/deepem/data/augment/flyem/aug_mip2.py b/deepem/data/augment/flyem/aug_mip2.py index b5f3300..25c9795 100644 --- a/deepem/data/augment/flyem/aug_mip2.py +++ b/deepem/data/augment/flyem/aug_mip2.py @@ -9,7 +9,7 @@ def get_augmentation( lost=True, random=False, recompute=[], - border=False, + border=[], **kwargs ): augs = list() @@ -97,6 +97,6 @@ def get_augmentation( # Create border if border: - augs.append(Border()) + augs.append(Border(targets=border)) return Compose(augs) diff --git a/deepem/data/augment/pinky_basil/aug_mip1_v3.py b/deepem/data/augment/pinky_basil/aug_mip1_v3.py index b5f4b7c..d069650 100644 --- a/deepem/data/augment/pinky_basil/aug_mip1_v3.py +++ b/deepem/data/augment/pinky_basil/aug_mip1_v3.py @@ -9,7 +9,7 @@ def get_augmentation( lost=True, random=False, recompute=[], - border=False, + border=[], **kwargs ): augs = list() @@ -90,6 +90,6 @@ def get_augmentation( # Create border if border: - augs.append(Border()) + augs.append(Border(targets=border)) return Compose(augs) diff --git a/deepem/data/augment/pinky_basil/aug_mip2.py b/deepem/data/augment/pinky_basil/aug_mip2.py index 00adeb1..0441dbd 100644 --- a/deepem/data/augment/pinky_basil/aug_mip2.py +++ b/deepem/data/augment/pinky_basil/aug_mip2.py @@ -9,7 +9,7 @@ def get_augmentation( lost=True, random=False, recompute=[], - border=False, + border=[], **kwargs ): augs = list() @@ -90,6 +90,6 @@ def get_augmentation( # Create border if border: - augs.append(Border()) + augs.append(Border(targets=border)) return Compose(augs) From 7626fcbb5cd192ba8e55aef59106bd040709a6ce Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 7 Jan 2025 12:21:38 -0800 Subject: [PATCH 40/47] feat(docker): upgrade cloud-volume and imagecodecs to resolve JXL encoding issue --- docker/zettasets/Dockerfile | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docker/zettasets/Dockerfile b/docker/zettasets/Dockerfile index b8d7888..e7d197b 100644 --- a/docker/zettasets/Dockerfile +++ b/docker/zettasets/Dockerfile @@ -53,6 +53,9 @@ RUN apt-get update \ RUN cd /tmp/zettasets && \ pip install -e . +# Install the latest version of cloud-volume +RUN pip install --no-cache-dir --upgrade cloud-volume imagecodecs + RUN mkdir -p /workspace ENV PYTHONPATH /DeepEM:${PYTHONPATH} From 4a47015fa350a4c9329a63d7c20b1832c86494e7 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 23 Jan 2025 17:19:24 +0900 Subject: [PATCH 41/47] fix: update UpBlock of up-down net to support ONNX export --- deepem/models/updown_act.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/deepem/models/updown_act.py b/deepem/models/updown_act.py index 25a02f5..2b33f38 100644 --- a/deepem/models/updown_act.py +++ b/deepem/models/updown_act.py @@ -20,7 +20,7 @@ def create_model(opt): else: # Batch normalization core = rsunet_act(width=width[:depth], act=opt.act) - return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, onnx=opt.onnx) class InputBlock(nn.Sequential): @@ -30,15 +30,19 @@ def __init__(self, in_channels, out_channels, kernel_size): class OutputBlock(nn.Module): - def __init__(self, in_channels, out_spec, kernel_size): + def __init__(self, in_channels, out_spec, kernel_size, onnx=False): super(OutputBlock, self).__init__() + self.onnx = onnx for k, v in out_spec.items(): out_channels = v[-4] self.add_module(k, Conv(in_channels, out_channels, kernel_size, bias=True)) def forward(self, x): - return {k: m(x) for k, m in self.named_children()} + if self.onnx: + return tuple(m(x) for k, m in self.named_children()) + else: + return {k: m(x) for k, m in self.named_children()} class DownBlock(nn.Sequential): @@ -48,14 +52,18 @@ def __init__(self, scale_factor=(1,2,2)): class UpBlock(nn.Module): - def __init__(self, out_spec, scale_factor=(1,2,2)): + def __init__(self, out_spec, scale_factor=(1,2,2), onnx=False): super(UpBlock, self).__init__() + self.onnx = onnx for k, v in out_spec.items(): self.add_module(k, nn.Upsample(scale_factor=scale_factor, mode='trilinear')) def forward(self, x): - return {k: m(x[k]) for k, m in self.named_children()} + if self.onnx: + return tuple(m(x[i]) for i, (_, m) in enumerate(self.named_children())) + else: + return {k: m(x[k]) for k, m in self.named_children()} class Model(nn.Sequential): @@ -63,7 +71,7 @@ class Model(nn.Sequential): Residual Symmetric U-Net with down/upsampling in/output. """ def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), - scale_factor=(1,2,2), crop=None): + scale_factor=(1,2,2), crop=None, onnx=False): super(Model, self).__init__() assert len(in_spec)==1, "model takes a single input" @@ -72,7 +80,7 @@ def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), self.add_module('down', DownBlock(scale_factor=scale_factor)) self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) self.add_module('core', core) - self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) - self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor)) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel, onnx=onnx)) + self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor, onnx=onnx)) if crop is not None: self.add_module('crop', Crop(crop)) From 646fff97905d57541bc63ba6f36b0ec89f5efb9c Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 4 Feb 2025 10:56:23 +0900 Subject: [PATCH 42/47] feat: update class balancing for auxiliary tasks --- deepem/train/utils.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/deepem/train/utils.py b/deepem/train/utils.py index c38869d..b09cf47 100644 --- a/deepem/train/utils.py +++ b/deepem/train/utils.py @@ -20,8 +20,6 @@ def get_criteria(opt): weight1=opt.class_weight1, ) if opt.class_balancing else None - is_dynamic = (opt.class_weight0 is None) and (opt.class_weight1 is None) - for k in opt.out_spec: if k == 'affinity' or k == 'long_range': if k == 'affinity': @@ -41,18 +39,20 @@ def get_criteria(opt): criteria[k] = getattr(loss, opt.metric_loss)(**opt.metric_params) else: params = dict(opt.loss_params) + + if ('affinity' in opt.out_spec) or ('long_range' in opt.out_spec): + balancer = BinaryWeightBalancer( + weight0=opt.class_weight1, + weight1=opt.class_weight0, + ) if opt.class_balancing else None + params['class_balancer'] = balancer + if opt.default_aux: params['margin0'] = 0 params['margin1'] = 0 params['inverse'] = False - params['class_balancer'] = ( - balancer - if is_dynamic - else BinaryWeightBalancer( - weight0=opt.class_weight1, - weight1=opt.class_weight0, - ) - ) + params['class_balancer'] = None + criteria[k] = getattr(loss, 'BCELoss')(**params) return criteria From ec6452bdb6e017084490a536ca6ca0171ff65b77 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Tue, 4 Feb 2025 16:54:07 +0900 Subject: [PATCH 43/47] fix(cv_utils): resolve potential mismatch between coord_mip and in_mip related to CloudVolume.voxel_offset --- deepem/test/cv_utils.py | 88 +++++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 38 deletions(-) diff --git a/deepem/test/cv_utils.py b/deepem/test/cv_utils.py index 786b7d0..c9443b1 100644 --- a/deepem/test/cv_utils.py +++ b/deepem/test/cv_utils.py @@ -16,20 +16,25 @@ def make_info(num_channels, layer_type, dtype, shape, resolution, chunk_size=chunk_size) def get_coord_bbox(cvol, opt): - # Cutout offset = cvol.voxel_offset + volume_shape = cvol.shape[:3] + if opt.center is not None: assert opt.size is not None - opt.begin = tuple(x - (y//2) for x, y in zip(opt.center, opt.size)) + opt.begin = tuple(x - (y // 2) for x, y in zip(opt.center, opt.size)) opt.end = tuple(x + y for x, y in zip(opt.begin, opt.size)) else: if opt.begin is None: - opt.begin = offset + opt.begin = offset # Default to voxel_offset if not provided + if opt.end is None: if opt.size is None: - opt.end = offset + cvol.shape[:3] + # When size is not specified, set end to the end of the dataset + opt.end = tuple(o + s for o, s in zip(offset, volume_shape)) else: - opt.end = tuple(x + y for x, y in zip(opt.begin, opt.size)) + # When size is specified, calculate end based on begin + size + opt.end = tuple(b + s for b, s in zip(opt.begin, opt.size)) + return Bbox(opt.begin, opt.end) def cutout(opt, gs_path, dtype='uint8', channels=0): @@ -37,23 +42,30 @@ def cutout(opt, gs_path, dtype='uint8', channels=0): gs_path = gs_path.format(*opt.keywords) print(gs_path) - # CloudVolume. - cvol = cv.CloudVolume(gs_path, mip=opt.in_mip, cache=opt.cache, - fill_missing=True, parallel=opt.parallel) + # CloudVolume for coordinate handling (coord_mip) + coord_cvol = cv.CloudVolume(gs_path, mip=opt.coord_mip) + + # CloudVolume for data fetching (in_mip) + data_cvol = cv.CloudVolume(gs_path, mip=opt.in_mip, cache=opt.cache, + fill_missing=True, parallel=opt.parallel) + + # Get bounding box based on coord_mip + coord_bbox = get_coord_bbox(coord_cvol, opt) - # Based on MIP level args are specified in - coord_bbox = get_coord_bbox(cvol, opt) if opt.coord_mip != opt.in_mip: print(f"mip {opt.coord_mip} = {coord_bbox}") - # Based on in_mip - in_bbox = cvol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) + + # Convert bbox to in_mip coordinates if needed + in_bbox = coord_cvol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) print(f"mip {opt.in_mip} = {in_bbox}") - cutout = cvol[in_bbox.to_slices()] + + # Data cutout from in_mip + cutout = data_cvol[in_bbox.to_slices()] # Transpose & squeeze - cutout = cutout.transpose([3,2,1,0]) + cutout = cutout.transpose([3, 2, 1, 0]) - # Slice channel + # Slice channels if specified if channels > 0: cutout = cutout[:channels, ...] @@ -63,42 +75,46 @@ def cutout(opt, gs_path, dtype='uint8', channels=0): def ingest(data, opt, tag=None): # Neuroglancer format data = py_utils.to_tensor(data) - data = data.transpose((3,2,1,0)) + data = data.transpose((3, 2, 1, 0)) num_channels = data.shape[-1] shape = data.shape[:-1] - # Use CloudVolume to make sure the output bbox matches the input. - # MIP hierarchies are not guaranteed to be powers of 2 (especially - # with float resolutions), so need to use the info file and be - # consistent computing the offset between input and output. + # Use CloudVolume with coord_mip for coordinate handling gs_path = opt.gs_input if '{}' in gs_path: gs_path = gs_path.format(*opt.keywords) - in_vol = cv.CloudVolume(gs_path, mip=opt.in_mip, cache=opt.cache, - fill_missing=True, parallel=opt.parallel) - coord_bbox = get_coord_bbox(in_vol, opt) + + coord_cvol = cv.CloudVolume(gs_path, mip=opt.coord_mip) + + # Get bounding box in coord_mip + coord_bbox = get_coord_bbox(coord_cvol, opt) + # Offset is defined at coord_mip, so adjust coord_bbox first if opt.offset: start_adjust = coord_bbox.minpt - opt.offset coord_bbox -= start_adjust - in_bbox = in_vol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) + + # Convert bbox to in_mip coordinates + in_bbox = coord_cvol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) # Patch offset correction (when output patch is smaller than input patch) - patch_offset = (0,0,0) + patch_offset = (0, 0, 0) if opt.tilt_series > 0: if opt.tilt_series_crop is not None: outputsz = np.array(opt.fov) * np.array(opt.scale) patch_offset = (outputsz - np.array(opt.tilt_series_crop)) // 2 else: patch_offset = (np.array(opt.inputsz) - np.array(opt.outputsz)) // 2 + patch_offset = Vec(*np.flip(patch_offset, 0)) - in_bbox.minpt += patch_offset - # in_bbox.stop -= 2*patch_offset # using the data to define shape - # Create info + # Create info using the adjusted offset + offset = in_bbox.minpt + patch_offset info = make_info(num_channels, 'image', str(data.dtype), shape, - opt.resolution, offset=in_bbox.minpt, chunk_size=opt.chunk_size) + opt.resolution, offset=offset, chunk_size=opt.chunk_size) print(info) + + # Output path formatting gs_path = opt.gs_output if '{}' in opt.gs_output: if opt.keywords: @@ -108,21 +124,17 @@ def ingest(data, opt, tag=None): coord = "x{}_y{}_z{}".format(*opt.center) coord += "_s{}-{}-{}".format(*opt.size) else: - coord = '_'.join([f"{b}-{e}" for b,e in zip(opt.begin,opt.end)]) + coord = '_'.join([f"{b}-{e}" for b, e in zip(opt.begin, opt.end)]) gs_path = gs_path.format(coord) # Tagging if tag is not None: - if gs_path[-1] == '/': - gs_path += tag - else: - gs_path += ('/' + tag) + gs_path = gs_path.rstrip('/') + '/' + tag print(f"gs_output:\n{gs_path}") - cvol = cv.CloudVolume(gs_path, mip=0, info=info, - parallel=opt.parallel) - cvol[:,:,:,:] = data - cvol.commit_info() + data_vol = cv.CloudVolume(gs_path, mip=0, info=info, parallel=opt.parallel) + data_vol[:, :, :, :] = data + data_vol.commit_info() # Downsample if opt.downsample: From 5e6c584a6811d2ba0cec924c3c358def035aedc6 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Thu, 13 Feb 2025 11:40:33 +0900 Subject: [PATCH 44/47] feat: add anisotropic models with `scale_init` --- deepem/models/rsunet_act_scale.py | 74 ++++++++++++++ deepem/models/updown_act_scale.py | 96 ++++++++++++++++++ deepem/models/updown_act_scale_interpolate.py | 97 +++++++++++++++++++ 3 files changed, 267 insertions(+) create mode 100644 deepem/models/rsunet_act_scale.py create mode 100644 deepem/models/updown_act_scale.py create mode 100644 deepem/models/updown_act_scale_interpolate.py diff --git a/deepem/models/rsunet_act_scale.py b/deepem/models/rsunet_act_scale.py new file mode 100644 index 0000000..751255e --- /dev/null +++ b/deepem/models/rsunet_act_scale.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + +import emvision +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop, Scale + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16,32,64,128,256,512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + onnx=opt.onnx, scale_init=opt.scale_init) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size, onnx=False, scale_init=1.0): + super(OutputBlock, self).__init__() + self.onnx = onnx + for k, v in out_spec.items(): + out_channels = v[-4] + if k == 'embedding': + self.add_module( + k, + nn.Sequential( + Conv(in_channels, out_channels, kernel_size, bias=True), + Scale(init_value=scale_init), + ), + ) + else: + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + if self.onnx: + return tuple(m(x) for k, m in self.named_children()) + else: + return {k: m(x) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + crop=None, onnx=False, scale_init=1.0): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', + OutputBlock(out_channels, out_spec, io_kernel, onnx=onnx, scale_init=scale_init)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/models/updown_act_scale.py b/deepem/models/updown_act_scale.py new file mode 100644 index 0000000..888371f --- /dev/null +++ b/deepem/models/updown_act_scale.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn + +import emvision +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop, Scale + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16,32,64,128,256,512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + onnx=opt.onnx, scale_init=opt.scale_init) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size, onnx=False, scale_init=1.0): + super(OutputBlock, self).__init__() + self.onnx = onnx + for k, v in out_spec.items(): + out_channels = v[-4] + if k == 'embedding': + self.add_module( + k, + nn.Sequential( + Conv(in_channels, out_channels, kernel_size, bias=True), + Scale(init_value=scale_init), + ), + ) + else: + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + if self.onnx: + return tuple(m(x) for k, m in self.named_children()) + else: + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Sequential): + def __init__(self, scale_factor=(1,2,2)): + super(DownBlock, self).__init__() + self.add_module('down', nn.AvgPool3d(scale_factor)) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, scale_factor=(1,2,2), onnx=False): + super(UpBlock, self).__init__() + self.onnx = onnx + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample(scale_factor=scale_factor, mode='trilinear')) + + def forward(self, x): + if self.onnx: + return tuple(m(x[i]) for i, (_, m) in enumerate(self.named_children())) + else: + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + scale_factor=(1,2,2), crop=None, onnx=False, scale_init=1.0): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + + self.add_module('down', DownBlock(scale_factor=scale_factor)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel, onnx=onnx, scale_init=scale_init)) + self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor, onnx=onnx)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/models/updown_act_scale_interpolate.py b/deepem/models/updown_act_scale_interpolate.py new file mode 100644 index 0000000..571a6e5 --- /dev/null +++ b/deepem/models/updown_act_scale_interpolate.py @@ -0,0 +1,97 @@ +import torch.nn as nn +import torch.nn.functional as F + +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop, Scale + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16, 32, 64, 128, 256, 512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + scale_init=opt.scale_init, scale_factor=opt.updown_scale_factor) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size, scale_init=1.0): + super(OutputBlock, self).__init__() + for k, v in out_spec.items(): + out_channels = v[-4] + if k == 'embedding': + self.add_module( + k, + nn.Sequential( + Conv(in_channels, out_channels, kernel_size, bias=True), + Scale(init_value=scale_init), + ), + ) + else: + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Module): + def __init__(self, size): + super(DownBlock, self).__init__() + self.size = size + + def forward(self, x): + return F.interpolate(x, size=self.size, mode='trilinear', align_corners=False) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, size): + super(UpBlock, self).__init__() + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample( + size=size, + mode='trilinear', + recompute_scale_factor=False, + )) + + def forward(self, x): + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + crop=None, scale_init=1.0, scale_factor=(1, 2, 2)): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + in_size = in_spec['input'][-3:] + assert all(s % f == 0 for s, f in zip(in_size, scale_factor)) + new_size = tuple(int(s / f) for s, f in zip(in_size, scale_factor)) + + self.add_module('down', DownBlock(size=new_size)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel, scale_init=scale_init)) + self.add_module('up', UpBlock(out_spec, size=in_size)) + if crop is not None: + self.add_module('crop', Crop(crop)) From 05d432c485ef259e40a18d1920b7336c583bf44a Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Mon, 24 Feb 2025 19:30:05 +0900 Subject: [PATCH 45/47] refactor(Dockerfile): restructure and improve package installation --- docker/zettasets/Dockerfile | 87 ++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 45 deletions(-) diff --git a/docker/zettasets/Dockerfile b/docker/zettasets/Dockerfile index e7d197b..a431787 100644 --- a/docker/zettasets/Dockerfile +++ b/docker/zettasets/Dockerfile @@ -1,64 +1,61 @@ FROM mambaorg/micromamba AS intermediate -# multi-stage build to pull in zettasets (private repo) +# Multi-stage build to pull in zettasets (private repo) USER root ARG GIT_ACCESS_TOKEN -RUN apt-get --allow-releaseinfo-change update && \ - apt-get install -y build-essential git && \ - git clone https://${GIT_ACCESS_TOKEN}@github.com/ZettaAI/zettasets.git && \ + +RUN apt-get update --allow-releaseinfo-change && \ + apt-get install -y --no-install-recommends build-essential git && \ + git clone https://${GIT_ACCESS_TOKEN}@github.com/ZettaAI/zettasets.git /tmp/zettasets && \ rm -rf /var/lib/apt/lists/* -FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime -# FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-runtime -# FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime -# FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime +# Use an optimized PyTorch base image +# FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime +FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime + ENV DEBIAN_FRONTEND=noninteractive COPY --from=intermediate /tmp/zettasets /tmp/zettasets -RUN apt-get update \ - && apt-get install -y --no-install-recommends \ - build-essential git \ - libboost-all-dev \ - # gcloud cli (for samwise) +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + build-essential git libboost-all-dev \ curl apt-transport-https ca-certificates gnupg \ - # gcloud cli (for samwise) - && echo "deb https://packages.cloud.google.com/apt cloud-sdk main" \ - | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ - && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \ - | apt-key add - \ - && apt-get update && apt-get install google-cloud-cli -y \ - # handle imgaug issue - && apt-get install ffmpeg libsm6 libxext6 -y \ - # Cleanup - && rm -rf /var/lib/apt/lists/* \ - # python requirements - && conda install h5py cython matplotlib \ - scikit-image scikit-learn \ - && conda clean -t -p \ - # pypi packages - && pip install --no-cache-dir --upgrade \ - numpy cloud-volume task-queue tensorboardX imgaug wandb \ - # address cloud-files import issue - && pip install --no-cache-dir --upgrade \ - cffi brotli \ - # github packages - && pip install --no-cache-dir \ - git+https://github.com/seung-lab/DataTools \ - git+https://github.com/ZettaAI/DataProvider3 \ - git+https://github.com/ZettaAI/Augmentor \ - git+https://github.com/seung-lab/pytorch-emvision \ - git+https://github.com/ZettaAI/samwise + ffmpeg libsm6 libxext6 && \ + # Install Google Cloud CLI + echo "deb https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \ + curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - && \ + apt-get update && apt-get install -y google-cloud-cli && \ + rm -rf /var/lib/apt/lists/* -RUN cd /tmp/zettasets && \ - pip install -e . +# Install Python dependencies using Conda (now with faster solving) +RUN conda install -y h5py cython matplotlib scikit-image scikit-learn && \ + conda clean -a + +# Explicitly pin NumPy to <2 to avoid ABI issues +RUN pip install --no-cache-dir --upgrade \ + "numpy<2" cloud-volume task-queue imgaug wandb \ + cffi brotli fastremap imagecodecs + +RUN pip install --no-cache-dir \ + git+https://github.com/seung-lab/DataTools && \ + pip install --no-cache-dir \ + git+https://github.com/ZettaAI/DataProvider3 && \ + pip install --no-cache-dir \ + git+https://github.com/ZettaAI/Augmentor && \ + pip install --no-cache-dir \ + git+https://github.com/seung-lab/pytorch-emvision && \ + pip install --no-cache-dir \ + git+https://github.com/ZettaAI/samwise -# Install the latest version of cloud-volume -RUN pip install --no-cache-dir --upgrade cloud-volume imagecodecs +WORKDIR /tmp/zettasets +RUN pip install --no-cache-dir -e . +# Set up working directory and environment RUN mkdir -p /workspace -ENV PYTHONPATH /DeepEM:${PYTHONPATH} +ENV PYTHONPATH=/DeepEM:${PYTHONPATH} COPY . /DeepEM/ WORKDIR /workspace + ENTRYPOINT ["python", "/DeepEM/deepem/train/run.py"] From 2dc4d030c49a89b9a958276dcd4699df8454ee24 Mon Sep 17 00:00:00 2001 From: Kisuk Lee Date: Mon, 24 Feb 2025 19:33:02 +0900 Subject: [PATCH 46/47] refactor(logger): remove TensorBoard support --- deepem/train/logger.py | 108 +---------------------------------------- deepem/train/option.py | 3 -- deepem/train/run.py | 2 - requirements.txt | 3 -- 4 files changed, 1 insertion(+), 115 deletions(-) diff --git a/deepem/train/logger.py b/deepem/train/logger.py index 8fb4069..2fd7a5c 100644 --- a/deepem/train/logger.py +++ b/deepem/train/logger.py @@ -2,14 +2,8 @@ import sys import datetime from collections import OrderedDict -import numpy as np import torch -from torchvision.utils import make_grid -from tensorboardX import SummaryWriter - -from deepem.loss.mean import vec2aff -from deepem.utils import torch_utils, py_utils class Logger(object): @@ -21,9 +15,6 @@ def __init__(self, opt): self.outputsz = opt.outputsz self.lr = opt.lr - # TensorBoard logging - self.writer = SummaryWriter(opt.log_dir) if opt.tensorboard else None - # Metric learning self.delta_d = opt.delta_d @@ -40,8 +31,7 @@ def __enter__(self): return self def __exit__(self, type, value, traceback): - if self.writer: - self.writer.close() + pass def record(self, phase, loss, nmsk, **kwargs): monitor = self.monitor[phase] @@ -58,15 +48,9 @@ def record(self, phase, loss, nmsk, **kwargs): def check(self, phase, iter_num): stats = self.monitor[phase].flush() - self.log(phase, iter_num, stats) self.display(phase, iter_num, stats) return stats - def log(self, phase, iter_num, stats): - if self.writer: - for k, v in stats.items(): - self.writer.add_scalar(f"{phase}/{k}", v, iter_num) - def display(self, phase, iter_num, stats): disp = "[%s] Iter: %8d, " % (phase, iter_num) for k, v in stats.items(): @@ -95,96 +79,6 @@ def flush(self): self.norm = OrderedDict() return ret - def log_images(self, phase, iter_num, preds, sample): - if self.writer is None: - return - - # Peep output size - key = sorted(self.out_spec)[0] - cropsz = sample[key].shape[-3:] - for k in sorted(self.out_spec): - outsz = sample[k].shape[-3:] - assert np.array_equal(outsz, cropsz) - - # Inputs - for k in sorted(self.in_spec): - tag = f"{phase}/images/{k}" - tensor = sample[k][0,...].cpu() - tensor = torch_utils.crop_center_no_strict(tensor, cropsz) - num_channels = tensor.shape[-4] - if num_channels > 3: - self.log_image(tag, tensor[0:3,...], iter_num) - else: - self.log_image(tag, tensor, iter_num) - - # Outputs - for k in sorted(self.out_spec): - - if k == 'affinity': - - # Prediction - tag = f"{phase}/images/{k}" - tensor = torch.sigmoid(preds[k][0,0:3,...]).cpu() - self.log_image(tag, tensor, iter_num) - - # Mask - tag = f"{phase}/masks/{k}" - msk = sample[k + '_mask'][0,...].cpu() - self.log_image(tag, msk, iter_num) - - # Target - tag = f"{phase}/labels/{k}" - seg = sample[k][0,0,...].cpu().numpy().astype('uint32') - rgb = torch.from_numpy(py_utils.seg2rgb(seg)) - self.log_image(tag, rgb, iter_num) - - elif k == 'embedding': - - vec = preds[k][0, ...] - - # Metric graph - tag = f"{phase}/images/metric_graph" - aff = vec2aff(vec, delta_d=self.delta_d) - self.log_image(tag, aff.cpu(), iter_num) - - # Embedding - tag = f"{phase}/images/{k}" - vec = preds[k][[0],...].cpu() # 1, c, z, y, x - vec = torch_utils.vec2pca(vec) - vec = vec.select(0, 0) - self.log_image(tag, vec, iter_num) - - # Target - tag = f"{phase}/labels/{k}" - seg = sample[k][0,0,...].cpu().numpy().astype('uint32') - rgb = torch.from_numpy(py_utils.seg2rgb(seg)) - self.log_image(tag, rgb, iter_num) - - else: - - # Prediction - tag = f"{phase}/images/{k}" - pred = torch.sigmoid(preds[k][0,...]).cpu() - self.log_image(tag, pred, iter_num) - - # Mask - tag = f"{phase}/masks/{k}" - msk = sample[k + '_mask'][0,...].cpu() - self.log_image(tag, msk, iter_num) - - # Target - tag = f"{phase}/labels/{k}" - target = sample[k][0,...].cpu() - self.log_image(tag, target, iter_num) - - def log_image(self, tag, tensor, iter_num): - if self.writer: - assert(torch.is_tensor(tensor)) - depth = tensor.shape[-3] - imgs = [tensor[:,z,:,:] for z in range(depth)] - img = make_grid(imgs, nrow=depth, padding=0) - self.writer.add_image(tag, img, iter_num) - def log_params(self, params): fname = os.path.join(self.log_dir, f"{self.timestamp}_params.csv") with open(fname, "w+") as f: diff --git a/deepem/train/option.py b/deepem/train/option.py index 5800f88..615b1db 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -176,9 +176,6 @@ def initialize(self): self.parser.add_argument('--export_onnx', action='store_true') self.parser.add_argument('--opset_version', type=int, default=10) - # TensorBoard logging - self.parser.add_argument('--tensorboard', action='store_true') - self.initialized = True def parse(self): diff --git a/deepem/train/run.py b/deepem/train/run.py index 6be9570..ec488e5 100644 --- a/deepem/train/run.py +++ b/deepem/train/run.py @@ -77,7 +77,6 @@ def train(opt): # Image logging if (i+1) % opt.imgs_intv == 0: - logger.log_images('train', i+1, preds, sample) wandb_logger.log_images('train', i+1, preds, sample) # Evaluation loop @@ -126,7 +125,6 @@ def eval_loop(iter_num, model, data_loader, opt, logger, wandb_logger): # Image logging if iter_num % opt.imgs_intv == 0: - logger.log_images('test', iter_num, preds, sample) wandb_logger.log_images('test', iter_num, preds, sample) print("-------------------------------------------") diff --git a/requirements.txt b/requirements.txt index 58d965a..0567101 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,5 @@ igneous-pipeline scikit-image scikit-learn task-queue -tensorflow -tensorboard -tensorboardX wandb git+https://github.com/seung-lab/DataTools.git From aae325ae2808f7070fcf8ac23bbfbdf5d1f0f6b0 Mon Sep 17 00:00:00 2001 From: Ran Lu Date: Wed, 29 Jan 2025 14:06:18 -0500 Subject: [PATCH 47/47] Replace imp module with importlib imp was removed since python 3.12 --- deepem/test/utils.py | 3 +-- deepem/train/data.py | 9 +++++---- deepem/train/utils.py | 6 ++++-- deepem/utils/py_utils.py | 10 ++++++++++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/deepem/test/utils.py b/deepem/test/utils.py index 58d773b..893f40e 100644 --- a/deepem/test/utils.py +++ b/deepem/test/utils.py @@ -1,4 +1,3 @@ -import imp import numpy as np import os from types import SimpleNamespace @@ -11,7 +10,7 @@ def load_model(opt): # Create a model. - mod = imp.load_source('model', opt.model) + mod = py_utils.load_module('model', opt.model) if opt.onnx: model = OnnxModel(mod.create_model(opt), opt) else: diff --git a/deepem/train/data.py b/deepem/train/data.py index e320411..df05c3a 100644 --- a/deepem/train/data.py +++ b/deepem/train/data.py @@ -1,4 +1,5 @@ -import imp +from deepem.utils.py_utils import load_module + import numpy as np import torch @@ -45,20 +46,20 @@ def requires_grad(self, key): def build(self, opt, data, is_train, prob): # Data augmentation if opt.augment: - mod = imp.load_source('augment', opt.augment) + mod = load_module('augment', opt.augment) aug = mod.get_augmentation(is_train, **opt.aug_params) else: aug = None # Data sampler - mod = imp.load_source('sampler', opt.sampler) + mod = load_module('sampler', opt.sampler) spec = mod.get_spec(opt.in_spec, opt.out_spec) zspecs = opt.zettaset_specs sampler = mod.Sampler(data, spec, is_train, aug, prob, zspecs) # Sample modifier if opt.modifier: - mod = imp.load_source('modifier', opt.modifier) + mod = load_module('modifier', opt.modifier) self.modifier = mod.Modifier(**opt.modifier_kwargs) else: def default_modifier(x, **kwargs): diff --git a/deepem/train/utils.py b/deepem/train/utils.py index b09cf47..fb6bf4f 100644 --- a/deepem/train/utils.py +++ b/deepem/train/utils.py @@ -1,4 +1,3 @@ -import imp import os import glob @@ -9,6 +8,7 @@ from deepem.train.data import Data from deepem.train.model import Model, AmpModel from deepem.loss.utils import BinaryWeightBalancer +from deepem.utils.py_utils import load_module def get_criteria(opt): @@ -59,7 +59,9 @@ def get_criteria(opt): def load_model(opt): # Create a model. - mod = imp.load_source('model', opt.model) + + mod = load_module("model", opt.model) + if opt.mixed_precision: model = AmpModel(mod.create_model(opt), get_criteria(opt), opt) else: diff --git a/deepem/utils/py_utils.py b/deepem/utils/py_utils.py index 5f676e2..c7f5a37 100644 --- a/deepem/utils/py_utils.py +++ b/deepem/utils/py_utils.py @@ -4,6 +4,16 @@ from sklearn.decomposition import PCA +import importlib +import importlib.util + + +def load_module(name, path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + def dict2tuple(d): return namedtuple('GenericDict', d.keys())(**d)