diff --git a/deepem/data/augment/cortex/aug_16nm.py b/deepem/data/augment/cortex/aug_16nm.py new file mode 100644 index 0000000..7a5f7a2 --- /dev/null +++ b/deepem/data/augment/cortex/aug_16nm.py @@ -0,0 +1,101 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + section_gap=0, + mask_section_gap=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2, 8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_16nm_gap3-5.py b/deepem/data/augment/cortex/aug_16nm_gap3-5.py new file mode 100644 index 0000000..cb61c80 --- /dev/null +++ b/deepem/data/augment/cortex/aug_16nm_gap3-5.py @@ -0,0 +1,105 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + mask_section_gap=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2, 8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + augs.append( + Blend([ + SectionGap(num_secs=3, masked=mask_section_gap), + SectionGap(num_secs=4, masked=mask_section_gap), + SectionGap(num_secs=5, masked=mask_section_gap), + ]) + ) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_4nm.py b/deepem/data/augment/cortex/aug_4nm.py new file mode 100644 index 0000000..3fe0706 --- /dev/null +++ b/deepem/data/augment/cortex/aug_4nm.py @@ -0,0 +1,101 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + section_gap=0, + mask_section_gap=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 10), margin=1), + Misalign((0, 30), margin=1), + Misalign((0, 50), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 10), interp=True, margin=1), + SlipMisalign((0, 30), interp=True, margin=1), + SlipMisalign((0, 50), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((5, 30), value=0, random=random), + MisalignPlusMissing((5, 30), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((5, 30), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_4nm_gap3-5.py b/deepem/data/augment/cortex/aug_4nm_gap3-5.py new file mode 100644 index 0000000..aaf890e --- /dev/null +++ b/deepem/data/augment/cortex/aug_4nm_gap3-5.py @@ -0,0 +1,105 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + mask_section_gap=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(10, 50), margin=(1, 10, 10), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 10), margin=1), + Misalign((0, 30), margin=1), + Misalign((0, 50), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 10), interp=True, margin=1), + SlipMisalign((0, 30), interp=True, margin=1), + SlipMisalign((0, 50), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((5, 30), value=0, random=random), + MisalignPlusMissing((5, 30), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((5, 30), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + augs.append( + Blend([ + SectionGap(num_secs=3, masked=mask_section_gap), + SectionGap(num_secs=4, masked=mask_section_gap), + SectionGap(num_secs=5, masked=mask_section_gap), + ]) + ) + + return Compose(augs) diff --git a/deepem/data/augment/cortex/aug_8nm.py b/deepem/data/augment/cortex/aug_8nm.py new file mode 100644 index 0000000..2067c5a --- /dev/null +++ b/deepem/data/augment/cortex/aug_8nm.py @@ -0,0 +1,101 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + section_gap=0, + mask_section_gap=False, + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 5), margin=1), + Misalign((0, 15), margin=1), + Misalign((0, 25), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0, 15), interp=True, margin=1), + SlipMisalign((0, 25), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((3, 15), value=0, random=random), + MisalignPlusMissing((3, 15), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((3, 15), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + + return Compose(augs) diff --git a/deepem/data/augment/flip_rotate.py b/deepem/data/augment/flip_rotate.py index 04b899a..fd329c9 100644 --- a/deepem/data/augment/flip_rotate.py +++ b/deepem/data/augment/flip_rotate.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, **kwargs): +def get_augmentation(is_train, recompute=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotate()) diff --git a/deepem/data/augment/flip_rotate_iso.py b/deepem/data/augment/flip_rotate_iso.py index e849358..532fcc9 100644 --- a/deepem/data/augment/flip_rotate_iso.py +++ b/deepem/data/augment/flip_rotate_iso.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, border=[], **kwargs): +def get_augmentation(is_train, recompute=[], border=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate augs.append(FlipRotateIsotropic()) diff --git a/deepem/data/augment/flip_rotate_iso_v2.py b/deepem/data/augment/flip_rotate_iso_v2.py index 955fced..3e6b2c6 100644 --- a/deepem/data/augment/flip_rotate_iso_v2.py +++ b/deepem/data/augment/flip_rotate_iso_v2.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, border=[], **kwargs): +def get_augmentation(is_train, recompute=[], border=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Create border if border: diff --git a/deepem/data/augment/flyem/aug_16nm.py b/deepem/data/augment/flyem/aug_16nm.py new file mode 100644 index 0000000..9d2887d --- /dev/null +++ b/deepem/data/augment/flyem/aug_16nm.py @@ -0,0 +1,101 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + section_gap=0, + mask_section_gap=False, + **kwargs +): + augs = list() + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2, 8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + + return Compose(augs) diff --git a/deepem/data/augment/flyem/aug_8nm.py b/deepem/data/augment/flyem/aug_8nm.py new file mode 100644 index 0000000..5e203e8 --- /dev/null +++ b/deepem/data/augment/flyem/aug_8nm.py @@ -0,0 +1,101 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + section_gap=0, + mask_section_gap=False, + **kwargs +): + augs = list() + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 5), margin=1), + Misalign((0, 15), margin=1), + Misalign((0, 25), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0, 15), interp=True, margin=1), + SlipMisalign((0, 25), interp=True, margin=1)]) + to_blend.append(Blend([trans, slip], props=[0.7, 0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((3, 15), value=0, random=random), + MisalignPlusMissing((3, 15), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((3, 15), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1, 3), dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5, 25), margin=(1, 5, 5), + density=0.3, skip=0.1) + ) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Section gap + if section_gap > 0: + augs.append(SectionGap(num_secs=section_gap, masked=mask_section_gap)) + + return Compose(augs) diff --git a/deepem/data/augment/flyem/aug_mip1.py b/deepem/data/augment/flyem/aug_mip1.py index 9f3eee1..bfee674 100644 --- a/deepem/data/augment/flyem/aug_mip1.py +++ b/deepem/data/augment/flyem/aug_mip1.py @@ -1,8 +1,17 @@ from augmentor import * -def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, - random=False, **kwargs): +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + **kwargs +): augs = list() # Brightness & contrast purterbation @@ -79,7 +88,15 @@ def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + # Flip & rotate augs.append(FlipRotate()) + # Create border + if border: + augs.append(Border(targets=border)) + return Compose(augs) diff --git a/deepem/data/augment/flyem/aug_mip2.py b/deepem/data/augment/flyem/aug_mip2.py new file mode 100644 index 0000000..25c9795 --- /dev/null +++ b/deepem/data/augment/flyem/aug_mip2.py @@ -0,0 +1,102 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + **kwargs +): + augs = list() + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Mutually exclusive augmentations + mutex = list() + + # (1) Misalingment + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0, 13), margin=1)]) + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0, 13), interp=True, margin=1)]) + mutex.append(Blend([trans, slip], props=[0.7, 0.3])) + + # (2) Misalignment + missing section + if is_train: + mutex.append(Blend([ + MisalignPlusMissing((2, 8), value=0, random=random), + MisalignPlusMissing((2, 8), value=0, random=False) + ])) + else: + mutex.append(MisalignPlusMissing((2, 8), value=0, random=False)) + + # (3) Missing section + if missing > 0: + if is_train: + mutex.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + mutex.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + + # (4) Lost section + if lost: + if is_train: + mutex.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + + # Mutually exclusive augmentations + augs.append(Blend(mutex)) + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3, 13), margin=(1, 3, 3), + density=0.3, skip=0.1) + ) + + # Out-of-focus section + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + return Compose(augs) diff --git a/deepem/data/augment/grayscale_warping.py b/deepem/data/augment/grayscale_warping.py index d377ae1..5c9b840 100644 --- a/deepem/data/augment/grayscale_warping.py +++ b/deepem/data/augment/grayscale_warping.py @@ -1,13 +1,13 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, warping=False, +def get_augmentation(is_train, recompute=[], grayscale=False, warping=False, **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Brightness & contrast purterbation if is_train and grayscale: diff --git a/deepem/data/augment/ariadne_worm/aug.py b/deepem/data/augment/isotropic/aug.py similarity index 87% rename from deepem/data/augment/ariadne_worm/aug.py rename to deepem/data/augment/isotropic/aug.py index 6f1504e..66b25d9 100644 --- a/deepem/data/augment/ariadne_worm/aug.py +++ b/deepem/data/augment/isotropic/aug.py @@ -8,10 +8,9 @@ def get_augmentation( is_train, - recompute=False, + recompute=[], box=None, blur=7, - random=False, border=[], **kwargs, ): @@ -28,14 +27,6 @@ def get_augmentation( skip=0.3, ) ) - # augs.append( - # MixedGrayscale2D( - # contrast_factor=0.5, - # brightness_factor=0.5, - # prob=1, - # skip=0.3, - # ) - # ) # Box if is_train: @@ -71,10 +62,6 @@ def get_augmentation( if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) - # Recompute connected components - if recompute: - augs.append(Label()) - # Flip & rotate augs.append(FlipRotateIsotropic()) @@ -82,4 +69,8 @@ def get_augmentation( if border: augs.append(Border(targets=border)) + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + return Compose(augs) diff --git a/deepem/data/augment/ariadne_worm/aug_aniso.py b/deepem/data/augment/isotropic/aug_aniso.py similarity index 87% rename from deepem/data/augment/ariadne_worm/aug_aniso.py rename to deepem/data/augment/isotropic/aug_aniso.py index 364eb68..2f2edd1 100644 --- a/deepem/data/augment/ariadne_worm/aug_aniso.py +++ b/deepem/data/augment/isotropic/aug_aniso.py @@ -8,10 +8,9 @@ def get_augmentation( is_train, - recompute=False, + recompute=[], box=None, blur=7, - random=False, border=[], **kwargs ): @@ -28,14 +27,6 @@ def get_augmentation( skip=0.3, ) ) - # augs.append( - # MixedGrayscale2D( - # contrast_factor=0.5, - # brightness_factor=0.5, - # prob=1, - # skip=0.3, - # ) - # ) # Box if is_train: @@ -71,10 +62,6 @@ def get_augmentation( if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) - # Recompute connected components - if recompute: - augs.append(Label()) - # Flip & rotate augs.append(FlipRotate()) @@ -82,4 +69,8 @@ def get_augmentation( if border: augs.append(Border(targets=border)) + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + return Compose(augs) diff --git a/deepem/data/augment/isotropic/aug_mip1.py b/deepem/data/augment/isotropic/aug_mip1.py new file mode 100644 index 0000000..0d6084e --- /dev/null +++ b/deepem/data/augment/isotropic/aug_mip1.py @@ -0,0 +1,76 @@ +"""Based on flyem/aug_mip1.py with some modifications. + +Mostly removing things that we might not need for faster +feedback cycles. +""" +from augmentor import * + + +def get_augmentation( + is_train, + recompute=[], + box=None, + blur=7, + border=[], + **kwargs, +): + augs = list() + + # Flip & rotate + augs.append(FlipRotateIsotropic()) + + # Brightness & contrast perturbation + augs.append( + Grayscale3D( + contrast_factor=0.5, + brightness_factor=0.5, + skip=0.3, + ) + ) + + # Box + if is_train: + if box == "fill": + augs.append( + Blend([ + FillBox( + random=True, + dims=(3, 13), + aniso=1, + density=0.3, + margin=(3, 3, 3), + individual=True, + skip=0.1, + ), + FillBox( + random=True, + dims=(3, 13), + aniso=1, + density=0.3, + margin=(3, 3, 3), + individual=False, + skip=0.1, + ), + ]) + ) + + # Out-of-focus section + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotateIsotropic()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + return Compose(augs) diff --git a/deepem/data/augment/kasthuri11/aug_v2.py b/deepem/data/augment/kasthuri11/aug_v2.py index 98f0ad4..20d9917 100644 --- a/deepem/data/augment/kasthuri11/aug_v2.py +++ b/deepem/data/augment/kasthuri11/aug_v2.py @@ -1,7 +1,7 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, +def get_augmentation(is_train, recompute=[], grayscale=False, missing=0, blur=0, warping=False, misalign=0, box=None, mip=0, random=False, **kwargs): augs = list() @@ -62,6 +62,6 @@ def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/kasthuri11/aug_v2_valid.py b/deepem/data/augment/kasthuri11/aug_v2_valid.py index c9b567a..456db2c 100644 --- a/deepem/data/augment/kasthuri11/aug_v2_valid.py +++ b/deepem/data/augment/kasthuri11/aug_v2_valid.py @@ -1,7 +1,7 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, +def get_augmentation(is_train, recompute=[], grayscale=False, missing=0, blur=0, warping=False, misalign=0, box=None, mip=0, random=False, **kwargs): augs = list() @@ -64,6 +64,6 @@ def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py b/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py index 9c2ebe8..4a6e36e 100644 --- a/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py +++ b/deepem/data/augment/kasthuri11/aug_v2_valid_no-interp.py @@ -1,7 +1,7 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, +def get_augmentation(is_train, recompute=[], grayscale=False, missing=0, blur=0, warping=False, misalign=0, box=None, mip=0, random=False, **kwargs): augs = list() @@ -64,6 +64,6 @@ def get_augmentation(is_train, recompute=False, grayscale=False, missing=0, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/no_aug.py b/deepem/data/augment/no_aug.py index 20bafd1..b1b4fa3 100644 --- a/deepem/data/augment/no_aug.py +++ b/deepem/data/augment/no_aug.py @@ -1,12 +1,12 @@ from augmentor import * -def get_augmentation(is_train, recompute=False, **kwargs): +def get_augmentation(is_train, recompute=[], **kwargs): augs = list() # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Flip & rotate if not is_train: diff --git a/deepem/data/augment/pinky_basil/aug_mip1_v3.py b/deepem/data/augment/pinky_basil/aug_mip1_v3.py index f6105d0..d069650 100644 --- a/deepem/data/augment/pinky_basil/aug_mip1_v3.py +++ b/deepem/data/augment/pinky_basil/aug_mip1_v3.py @@ -1,8 +1,17 @@ from augmentor import * -def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, - random=False, **kwargs): +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + **kwargs +): augs = list() # Box @@ -72,7 +81,15 @@ def get_augmentation(is_train, box=None, missing=7, blur=7, lost=True, if is_train: augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + # Flip & rotate augs.append(FlipRotate()) + # Create border + if border: + augs.append(Border(targets=border)) + return Compose(augs) diff --git a/deepem/data/augment/pinky_basil/aug_mip2.py b/deepem/data/augment/pinky_basil/aug_mip2.py new file mode 100644 index 0000000..0441dbd --- /dev/null +++ b/deepem/data/augment/pinky_basil/aug_mip2.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0,13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0,13), interp=True, margin=1)]) + to_blend.append(Blend([trans,slip], props=[0.7,0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2,8), value=0, random=random), + MisalignPlusMissing((2,8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2,8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + return Compose(augs) diff --git a/deepem/data/augment/retina/aug_20nm.py b/deepem/data/augment/retina/aug_20nm.py new file mode 100644 index 0000000..58c0a93 --- /dev/null +++ b/deepem/data/augment/retina/aug_20nm.py @@ -0,0 +1,102 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Mutually-exclusive data augmentations + to_blend = list() + + # Step + step = Compose([Misalign((0, 5), margin=1), + Misalign((0,15), margin=1), + Misalign((0,25), margin=1)]) + + # Slip + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0,15), interp=True, margin=1), + SlipMisalign((0,25), interp=True, margin=1)]) + + # Step + slip + to_blend.append(Blend([step, slip], props=[0.7, 0.3])) + + # Missing sections + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + + # Lost sections + if lost: + if is_train: + to_blend.append( + Blend( + [ + LostSection(1), + LostSection(2), + LostSection(3), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False), + ], + props=[0.4, 0.3, 0.2, 0.05, 0.05], + ) + ) + augs.append(Blend(to_blend)) + + # Out-of-focus sections + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + return Compose(augs) diff --git a/deepem/data/augment/retina/aug_mip1.py b/deepem/data/augment/retina/aug_mip1.py new file mode 100644 index 0000000..9a9cf21 --- /dev/null +++ b/deepem/data/augment/retina/aug_mip1.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(5,25), margin=(1,5,5), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 5), margin=1), + Misalign((0,15), margin=1), + Misalign((0,25), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 5), interp=True, margin=1), + SlipMisalign((0,15), interp=True, margin=1), + SlipMisalign((0,25), interp=True, margin=1)]) + to_blend.append(Blend([trans,slip], props=[0.7,0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((3,15), value=0, random=random), + MisalignPlusMissing((3,15), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((3,15), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + return Compose(augs) diff --git a/deepem/data/augment/retina/aug_mip2.py b/deepem/data/augment/retina/aug_mip2.py new file mode 100644 index 0000000..bac9eec --- /dev/null +++ b/deepem/data/augment/retina/aug_mip2.py @@ -0,0 +1,95 @@ +from augmentor import * + + +def get_augmentation( + is_train, + box=None, + missing=7, + blur=7, + lost=True, + random=False, + recompute=[], + border=[], + **kwargs +): + augs = list() + + # Box + if is_train: + if box == 'noise': + augs.append( + NoiseBox(sigma=(1,3), dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + elif box == 'fill': + augs.append( + FillBox(dims=(3,13), margin=(1,3,3), + density=0.3, skip=0.1) + ) + + # Brightness & contrast purterbation + augs.append( + MixedGrayscale2D( + contrast_factor=0.5, + brightness_factor=0.5, + prob=1, skip=0.3)) + + # Missing section & misalignment + to_blend = list() + # Misalingments + trans = Compose([Misalign((0, 3), margin=1), + Misalign((0, 8), margin=1), + Misalign((0,13), margin=1)]) + + # Out-of-alignments + slip = Compose([SlipMisalign((0, 3), interp=True, margin=1), + SlipMisalign((0, 8), interp=True, margin=1), + SlipMisalign((0,13), interp=True, margin=1)]) + to_blend.append(Blend([trans,slip], props=[0.7,0.3])) + if is_train: + to_blend.append(Blend([ + MisalignPlusMissing((2,8), value=0, random=random), + MisalignPlusMissing((2,8), value=0, random=False) + ])) + else: + to_blend.append(MisalignPlusMissing((2,8), value=0, random=False)) + if missing > 0: + if is_train: + to_blend.append(Blend([ + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False), + MixedMissingSection(maxsec=missing, individual=True, value=0, random=random), + MissingSection(maxsec=missing, individual=False, value=0, random=random), + ])) + else: + to_blend.append( + MixedMissingSection(maxsec=missing, individual=True, value=0, random=False) + ) + if lost: + if is_train: + to_blend.append(Blend([ + LostSection(1), + LostPlusMissing(value=0, random=random), + LostPlusMissing(value=0, random=False) + ])) + augs.append(Blend(to_blend)) + + # Out-of-focus + if blur > 0: + augs.append(MixedBlurrySection(maxsec=blur)) + + # Warping + if is_train: + augs.append(Warp(skip=0.3, do_twist=False, rot_max=45.0, scale_max=1.1)) + + # Flip & rotate + augs.append(FlipRotate()) + + # Create border + if border: + augs.append(Border(targets=border)) + + # Recompute connected components + if recompute: + augs.append(Label(targets=recompute)) + + return Compose(augs) diff --git a/deepem/data/augment/tilt_series/aug_v0.py b/deepem/data/augment/tilt_series/aug_v0.py index c6451fe..47eace8 100644 --- a/deepem/data/augment/tilt_series/aug_v0.py +++ b/deepem/data/augment/tilt_series/aug_v0.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, flip=False, noise=None, **kwargs): + recompute=[], flip=False, noise=None, **kwargs): augs = [] # Flip & rotate (isotropic) @@ -38,6 +38,6 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/tilt_series/aug_v1_val.py b/deepem/data/augment/tilt_series/aug_v1_val.py index 54d8386..d47c30d 100644 --- a/deepem/data/augment/tilt_series/aug_v1_val.py +++ b/deepem/data/augment/tilt_series/aug_v1_val.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, box=None, missing=7, blur=7, random=False, + recompute=[], box=None, missing=7, blur=7, random=False, **kwargs): augs = [] @@ -28,7 +28,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) # Box if is_train: diff --git a/deepem/data/augment/tilt_series/no_aug_val.py b/deepem/data/augment/tilt_series/no_aug_val.py index 156c545..f3b005f 100644 --- a/deepem/data/augment/tilt_series/no_aug_val.py +++ b/deepem/data/augment/tilt_series/no_aug_val.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, **kwargs): + recompute=[], **kwargs): augs = [] # Flip & rotate (isotropic) @@ -27,6 +27,6 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/augment/tilt_series/noise_v0.py b/deepem/data/augment/tilt_series/noise_v0.py index 5e1db52..bcb857e 100644 --- a/deepem/data/augment/tilt_series/noise_v0.py +++ b/deepem/data/augment/tilt_series/noise_v0.py @@ -2,7 +2,7 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, - recompute=False, flip=False, noise=None, **kwargs): + recompute=[], flip=False, noise=None, **kwargs): augs = [] # Flip & rotate (isotropic) @@ -38,6 +38,6 @@ def get_augmentation(is_train, tilt_series=(0,0,0), tilt_series_crop=None, # Recompute connected components if recompute: - augs.append(Label()) + augs.append(Label(targets=recompute)) return Compose(augs) diff --git a/deepem/data/dataset/fibsem/fib25/fib25_v0.py b/deepem/data/dataset/fibsem/fib25/fib25_v0.py deleted file mode 100644 index 8b7ff8e..0000000 --- a/deepem/data/dataset/fibsem/fib25/fib25_v0.py +++ /dev/null @@ -1,51 +0,0 @@ -import numpy as np -import os - -import dataprovider3.emio as emio - - -fib25_dir = 'FIB-25' -data_keys = ['validation_sample'] -merger_ids = [2148] - - -def load_data(data_dir, data_ids=None, **kwargs): - if data_ids is None: - return {} - - base_dir = os.path.expanduser(base_dir) - data_dir = os.path.join(base_dir, fib25_dir) - - data = {} - for data_id in data_ids: - if data_id in data_keys: - dpath = os.path.join(data_dir, data_id) - assert os.path.exists(dpath) - data[data_id] = load_dataset(dpath, **kwargs) - return data - - -def load_dataset(dpath, **kwargs): - dset = {} - - # Image - fpath = os.path.join(dpath, "img.h5") - print(fpath) - dset['img'] = emio.imread(fpath).astype(np.float32) - dset['img'] /= 255.0 - - # Segmentation - fpath = os.path.join(dpath, "seg.h5") - print(fpath) - dset['seg'] = emio.imread(fpath).astype(np.uint8) - - # Additoinal info - dset['loc'] = True - - # Mask out mergers - dset['msk'] = np.ones(dset['seg'].shape, dtype=np.uint8) - if len(merger_ids) > 0: - idx = np.isin(dset['seg'], merger_ids) - dset['msk'][idx] = 0 - - return dset \ No newline at end of file diff --git a/deepem/data/dataset/multi_zettaset.py b/deepem/data/dataset/multi_zettaset.py new file mode 100644 index 0000000..b5474e1 --- /dev/null +++ b/deepem/data/dataset/multi_zettaset.py @@ -0,0 +1,226 @@ +from __future__ import annotations + +import numpy as np +from numpy.typing import ArrayLike +import re + +from cloudvolume import CloudVolume, Bbox + +from zettasets.dataset import Dataset as Zettaset +from zettasets.sample import Sample + + +def is_valid_format(s: str) -> bool: + """Check if string has format 'A:B'.""" + return bool(re.match(r'^[^:]+:[^:]+$', s)) + + +def load_data( + zettaset_specs: dict[str, dict], + data_ids: list[str] | None = None, + **kwargs, +) -> dict[str, dict[str, np.ndarray]]: + """ + Load data from a zettaset. + + Parameters: + - zettaset_specs (dict[str, dict]): Specifications for the zettasets to load. + Each key is the name of the zettaset, and the value is a dictionary of + its specifications. + - data_ids (list[str], optional): Specific data identifiers to load. Defaults to None. + + Returns: + - dict[str, dict[str, np.ndarray]]: A dictionary containing the loaded zettaset data. + + Raises: + - ValueError: If `zettaset_specs` is empty or contains invalid specifications. + - KeyError: If any data id is not found in the loaded zettasets. + """ + + if not zettaset_specs: + raise ValueError("No zettaset specifications provided.") + + zettasets = _initialize_zettasets(zettaset_specs, **kwargs) + + data = {} + for data_id in data_ids or []: + if data_id in zettasets: + data.update(_process_zettaset(data_id, zettasets, zettaset_specs, **kwargs)) + else: + data.update(_process_sample(data_id, zettasets, zettaset_specs, **kwargs)) + + return data + + +def _initialize_zettasets( + zettaset_specs, + zettaset_resolution: tuple[float, float, float] | None = None, + **kwargs, +) -> dict[str, Zettaset]: + zettasets = {} + for name, spec in zettaset_specs.items(): + if "path" not in spec or not spec["path"].startswith("gs://"): + raise ValueError(f"Invalid zettaset specification for '{name}': missing or invalid 'path'.") + zettaset_path = spec["path"] + print(f"Zettaset `{name}` from [{zettaset_path}]") + resolution = tuple(spec.get("resolution", zettaset_resolution)) + print(f"{resolution=}") + zettasets[name] = Zettaset(zettaset_path, "", resolution) + return zettasets + + +def _process_zettaset( + zettaset_name: str, + zettasets: dict[str, Zettaset], + zettaset_specs: dict[str, dict], + **kwargs, +) -> dict[str, dict[str, dict[str, np.ndarray]]]: + """Processes a whole zettaset.""" + if zettaset_name not in zettasets: + raise KeyError(f"Zettaset '{zettaset_name}' not found.") + + data = {} + zettaset = zettasets[zettaset_name] + + for sample_name in zettaset.sample_names: + full_name = f"{zettaset_name}:{sample_name}" + data.update(_process_sample(full_name, zettasets, zettaset_specs, **kwargs)) + + return {zettaset_name: data} if data else {} + + +def _process_sample( + data_id: str, + zettasets: dict[str, Zettaset], + zettaset_specs: dict[str, dict], + zettaset_padding: tuple[int, int, int] = (0, 0, 0), + zettaset_padding_spec: dict[str, tuple[int, int, int]] = {}, + zettaset_mask: bool = True, + zettaset_resolution: tuple[int, int, int] | None = None, + **kwargs, +) -> dict[str, dict[str, np.ndarray]]: + if not is_valid_format(data_id): + raise ValueError(f"Invalid format for data id '{data_id}'. Expected format 'A:B'.") + + zettaset_name, sample_name = data_id.split(':') + + if zettaset_name not in zettasets: + raise KeyError(f"Zettaset '{zettaset_name}' not found.") + + zettaset = zettasets[zettaset_name] + + if sample_name not in zettaset.sample_names: + raise KeyError(f"Sample name '{sample_name}' not found in zettaset '{zettaset_name}'.") + + assert zettaset_name in zettaset_specs + zettaset_spec = zettaset_specs[zettaset_name] + + # Determine padding: sample-specific overrides zettaset-specific + padding = tuple(zettaset_padding_spec.get(data_id, zettaset_spec.get("padding", zettaset_padding))) + no_mask = zettaset_spec.get("no_mask", not zettaset_mask) + + # Determine resolution: zettaset-specific overrides zettaset_resolution + resolution = tuple(zettaset_spec.get("resolution", zettaset_resolution)) + + print(f"Sample [{data_id}]") + return {data_id: load_sample( + zettaset.samples[sample_name], + padding, + no_mask, + resolution, + **kwargs, + )} + + +def load_sample( + sample: Sample, + padding: tuple[int, int, int] = (0, 0, 0), + no_mask: bool = False, + resolution: tuple[int, int, int] | None = None, + zettaset_lookup: dict[str, str] | None = None, + requires_binarize: list[str] = [], + zettaset_share_mask: str | None = None, + semantic_mapping: dict[str, int] = {}, + **kwargs +) -> dict[str, np.ndarray]: + """Load image and labels from a Sample.""" + + def convert_array(arr: ArrayLike) -> np.ndarray: + return np.array(arr).transpose(3, 2, 1, 0)[0, ...] + + dset: dict[str, np.ndarray] = {} + + # Bbox with padding + resolution = resolution or sample.base_resolution + bbox = sample.bbox * (sample.base_resolution / np.array(resolution)) + xyz_padding = ( + tuple(reversed(padding)) + if padding != (0, 0, 0) + else (0, 0, 0) + ) + image_bbox = Bbox(bbox.minpt - xyz_padding, bbox.maxpt + xyz_padding) + + # Image + vol = CloudVolume( # pylint: disable=unsubscriptable-object + sample.src_image_path, + mip=resolution, + fill_missing=True, + bounded=False, + )[image_bbox.to_slices()] + dset["input"] = convert_array(vol) / 255.0 + print(f"\tinput: {dset['input'].shape}") + + # Assumes that zettaset's annotation names follow DeepEM's convention. + zettaset_lookup = zettaset_lookup or {x: x for x in sample.annotation_names} + + # Shared mask + shared_mask = None + if (not no_mask) and zettaset_share_mask: + key = zettaset_share_mask + mask_key = f"{zettaset_share_mask}_mask" + if key not in sample.masks: + raise KeyError(f"Mask '{mask_key}' not found.") + mask_vol = sample.read_mask(key)[key] + shared_mask = convert_array(mask_vol).astype("uint8") + + # Process annotations + for name, key in zettaset_lookup.items(): + + # Annotation + vol = sample.read(key)[key] + dset[name] = convert_array(vol) + anno_log = f"\t{name}: {dset[name].shape}" + + # Semantic mapping or binarize + if name in semantic_mapping: + dset[name] = (dset[name] == semantic_mapping[name]).astype("uint8") + elif name in requires_binarize: + dset[name] = (dset[name] > 0).astype("uint8") + + # Mask + mask_key = f"{name}_mask" + if shared_mask is not None: + dset[mask_key] = shared_mask + elif (not no_mask) and (key in sample.masks): + mask_vol = sample.read_mask(key)[key] + dset[mask_key] = convert_array(mask_vol).astype("uint8") + else: + dset[mask_key] = np.ones_like(dset[name], dtype="uint8") + msk_log = f"\t{mask_key}: {dset[mask_key].shape}" + + # Padding + if padding != (0, 0, 0): + pad_width = tuple((p, p) for p in padding) + dset[name] = np.pad(dset[name], pad_width, "constant") + anno_log += f" -> {dset[name].shape}" + dset[mask_key] = np.pad(dset[mask_key], pad_width, "constant") + msk_log += f" -> {dset[mask_key].shape}" + + print(anno_log) + print(msk_log) + + return dset + + +if __name__ == "__main__": + pass diff --git a/deepem/data/dataset/zettaset.py b/deepem/data/dataset/zettaset.py index 2c7aca1..b4a7128 100644 --- a/deepem/data/dataset/zettaset.py +++ b/deepem/data/dataset/zettaset.py @@ -11,6 +11,7 @@ def load_data( zettaset_paths: list[str], data_ids: list[str] | None = None, + zettaset_resolution: tuple[float, float, float] | None = None, **kwargs ) -> dict[str, dict[str, np.ndarray]]: """ Load data from a zettaset.""" @@ -21,9 +22,8 @@ def load_data( zettasets = [] for zettaset_path in zettaset_paths: assert zettaset_path.startswith("gs://") - motivation = "Load a zettaset from DeepEM" print(f"Zettaset [{zettaset_path}]") - zettasets.append(Zettaset(zettaset_path, motivation)) + zettasets.append(Zettaset(zettaset_path, "", zettaset_resolution)) # Load data from a zettaset. data = {} @@ -31,7 +31,11 @@ def load_data( for zettaset in zettasets: if data_id in zettaset.sample_names: print(f"Sample [{data_id}]") - data[data_id] = load_sample(zettaset.samples[data_id], **kwargs) + data[data_id] = load_sample( + zettaset.samples[data_id], + zettaset_resolution, + **kwargs + ) break if data_id not in data: raise KeyError(f"Invalid data id:{data_id}") @@ -41,6 +45,7 @@ def load_data( def load_sample( sample: Sample, + zettaset_resolution: tuple[float, float, float] | None = None, zettaset_lookup: dict[str, str] | None = None, zettaset_padding: tuple[int, int, int] = (0, 0, 0), zettaset_mask: bool = True, @@ -54,18 +59,20 @@ def convert_array(arr: ArrayLike) -> np.ndarray: dset: dict[str, np.ndarray] = {} - # Image - if zettaset_padding == (0, 0, 0): - image_bbox = sample.bbox - else: - xyz_padding = tuple(reversed(zettaset_padding)) - image_bbox = Bbox( - sample.bbox.minpt - xyz_padding, sample.bbox.maxpt + xyz_padding - ) + # Bbox with padding + resolution = zettaset_resolution or sample.base_resolution + bbox = sample.bbox * (sample.base_resolution / np.array(resolution)) + xyz_padding = ( + tuple(reversed(zettaset_padding)) + if zettaset_padding != (0, 0, 0) + else (0, 0, 0) + ) + image_bbox = Bbox(bbox.minpt - xyz_padding, bbox.maxpt + xyz_padding) + # Image vol = CloudVolume( # pylint: disable=unsubscriptable-object sample.src_image_path, - mip=sample.base_resolution, + mip=resolution, fill_missing=True, bounded=False, )[image_bbox.to_slices()] diff --git a/deepem/data/sampler/zettaset.py b/deepem/data/sampler/zettaset.py index d09501e..e9ea5d0 100644 --- a/deepem/data/sampler/zettaset.py +++ b/deepem/data/sampler/zettaset.py @@ -3,7 +3,7 @@ import numpy as np from augmentor import Augment -from dataprovider3 import DataProvider, Dataset +from dataprovider3 import DataProvider, Dataset, DataSuperset def get_spec( @@ -29,10 +29,11 @@ def __init__( spec: dict[str, tuple[int, int, int]], is_train: bool, aug: Augment | None = None, - prob: dict[str, float] | None = None + prob: dict[str, float] | None = None, + zettaset_specs: dict[str, dict] | None = None, ): self.is_train = is_train - self.build(data, spec, aug, prob) + self.build(data, spec, aug, prob, zettaset_specs) def __call__(self) -> dict[str, np.ndarray]: sample = self.dataprovider() @@ -55,31 +56,61 @@ def build( self, data: dict[str, dict[str, np.ndarray]], spec: dict[str, tuple[int, int, int]], - aug: Augment | None, - prob: dict[str, float] | None + aug: Augment | None = None, + prob: dict[str, float] | None = None, + zettaset_specs: dict[str, dict] | None = None, ) -> None: - dp = DataProvider(spec) - keys = data.keys() - for key in keys: - dp.add_dataset(self.build_dataset(key, data[key], spec)) - dp.set_augment(aug) - dp.set_imgs(["input"]) - dp.set_segs(["affinity", "long_range", "embedding"]) - prob = [prob[k] for k in keys] if prob is not None else prob - dp.set_sampling_weights(p=prob) - self.dataprovider = dp - print(dp) + """ + Builds the data provider with datasets, augmentation, and sampling weights. + """ + self.dataprovider = DataProvider(spec) + + # Add datasets to the data provider + for key, value in data.items(): + build_method = ( + self.build_datasuperset + if zettaset_specs and key in zettaset_specs + else self.build_dataset + ) + self.dataprovider.add_dataset(build_method(key, value, spec)) + + # Set augmentation, image types, and segmentation types + self.dataprovider.set_augment(aug) + self.dataprovider.set_imgs(["input"]) + self.dataprovider.set_segs(["affinity", "long_range", "embedding"]) + + # Initialize sampling weights (even if prob is None) + sampling_weights = [prob[k] for k in data.keys()] if prob else None + self.dataprovider.set_sampling_weights(p=sampling_weights) + + print(self.dataprovider) + + def build_datasuperset( + self, + tag: str, + data: dict[str, dict[str, np.ndarray]], + spec: dict[str, tuple[int, int, int]], + ) -> DataSuperset: + """Create a DataSuperset from the given data.""" + dset = DataSuperset(tag=tag) + + # Add datasets to the DataSuperset + for key, value in data.items(): + dset.add_dataset(self.build_dataset(key, value, spec)) + + return dset def build_dataset( self, tag: str, data: dict[str, np.ndarray], - spec: dict[str, tuple[int, int, int]] + spec: dict[str, tuple[int, int, int]], ) -> Dataset: - """Create a Dataset.""" + """Create a Dataset from the given data and specification.""" dset = Dataset(tag=tag) - for key in spec.keys(): + # Iterate over the spec dictionary to add data and masks to the dataset + for key, _ in spec.items(): if key.endswith("_mask"): dset.add_mask(key=key, data=data[key], loc=True) else: diff --git a/deepem/loss/affinity.py b/deepem/loss/affinity.py index d7325d2..da00607 100644 --- a/deepem/loss/affinity.py +++ b/deepem/loss/affinity.py @@ -7,15 +7,19 @@ class EdgeSampler(object): - def __init__(self, edges): + def __init__(self, edges, split_boundary=True): self.edges = list(edges) + self.split_boundary = split_boundary def generate_edges(self): return list(self.edges) def generate_true_aff(self, obj, edge): o1, o2 = torch_utils.get_pair(obj, edge) - ret = ((o1 == o2) & (o1 != 0) & (o2 != 0)) + if self.split_boundary: + ret = ((o1 == o2) & (o1 != 0) & (o2 != 0)) + else: + ret = (o1 == o2) return ret.type(obj.type()) def generate_mask_aff(self, mask, edge): @@ -56,25 +60,12 @@ def forward(self, preds, targets, masks): return loss, nmsk - # def class_balancing(self, target, mask): - # if not self.balancing: - # return mask - # dtype = mask.type() - # m_int = mask * torch.eq(target, 1).type(dtype) - # m_ext = mask * torch.eq(target, 0).type(dtype) - # n_int = m_int.sum().item() - # n_ext = m_ext.sum().item() - # if n_int > 0 and n_ext > 0: - # m_int *= n_ext/(n_int + n_ext) - # m_ext *= n_int/(n_int + n_ext) - # return (m_int + m_ext).type(dtype) - class AffinityLoss(nn.Module): - def __init__(self, edges, criterion, size_average=False, - class_balancer=None): + def __init__(self, edges, criterion, split_boundary=True, + size_average=False, class_balancer=None): super(AffinityLoss, self).__init__() - self.sampler = EdgeSampler(edges) + self.sampler = EdgeSampler(edges, split_boundary=split_boundary) self.decoder = AffinityLoss.Decoder(edges) self.criterion = EdgeCRF( criterion, diff --git a/deepem/loss/loss.py b/deepem/loss/loss.py index 1ea8c83..74064b1 100644 --- a/deepem/loss/loss.py +++ b/deepem/loss/loss.py @@ -43,7 +43,7 @@ def forward(self, input, target, mask): m_ext = torch.le(activ, m0) * torch.eq(target, 0) mask *= 1 - (m_int + m_ext).type(mask.dtype) - loss = self.bce(input, target, weight=mask, size_average=False) + loss = self.bce(input, target, weight=mask, reduction='sum') if self.size_average: loss = loss / nmsk.item() @@ -87,7 +87,7 @@ def forward(self, input, target, mask): m_ext = torch.le(activ, m0) * torch.eq(target, 0) mask *= 1 - (m_int + m_ext).type(mask.dtype) - loss = self.mse(activ, target, reduce=False) + loss = self.mse(activ, target, reduction='none') loss = (loss * mask).sum() if self.size_average: diff --git a/deepem/loss/mean.py b/deepem/loss/mean.py index 1856d9c..72f7eda 100644 --- a/deepem/loss/mean.py +++ b/deepem/loss/mean.py @@ -77,6 +77,8 @@ def __init__( delta_v: float = 0.0, delta_d: float = 1.5, recompute_ext: bool = False, + mask_background: bool = True, + loss_scale_factor: tuple[float, float, float] | None = None, **kwargs, ): super().__init__() @@ -86,6 +88,8 @@ def __init__( self.delta_v = delta_v # Variance (intra-cluster pull force) hinge self.delta_d = delta_d # Distance (inter-cluster push force) hinge self.recompute_ext = recompute_ext + self.mask_background = mask_background + self.loss_scale_factor = loss_scale_factor def forward( self, @@ -102,25 +106,34 @@ def forward( """ device = embd.device + # Downsample if enabled + if self.loss_scale_factor is not None: + embd = F.interpolate(embd, scale_factor=self.loss_scale_factor, mode='trilinear', align_corners=False) + trgt = F.interpolate(trgt, scale_factor=self.loss_scale_factor, mode='nearest') + mask = F.interpolate(mask, scale_factor=self.loss_scale_factor, mode='nearest') + if splt is not None: + splt = F.interpolate(splt, scale_factor=self.loss_scale_factor, mode='nearest') + groups = None if self.recompute_ext: assert splt is not None - trgt_np = np.squeeze(trgt.cpu().numpy()) - splt_np = np.squeeze(splt.cpu().numpy()) - mask_np = np.squeeze(mask.cpu().numpy()) - groups = create_mapping(trgt_np, splt_np, mask_np) + trgt = torch.squeeze(trgt) + splt = torch.squeeze(splt) + mask = torch.squeeze(mask) + groups = create_mapping(trgt.cpu().numpy(), splt.cpu().numpy(), mask.cpu().numpy()) trgt = splt trgt = trgt.to(torch.int) - trgt *= (mask > 0).to(torch.int) - # Unique nonzero IDs - ids = np.unique(trgt.cpu().numpy()) - ids = ids[ids != 0].tolist() + # Filter out background and get unique IDs + masked_trgt = trgt[mask > 0] + if self.mask_background: + masked_trgt = masked_trgt[masked_trgt != 0] + ids = torch.unique(masked_trgt).tolist() # Recompute external matrix mext = self.compute_ext_matrix(ids, groups, self.recompute_ext, device) - vecs = self.generate_vecs(embd, trgt, ids) + vecs = self.generate_vecs(embd, trgt, mask, ids) means = [torch.mean(vec, dim=0) for vec in vecs] weights = [1.0] * len(vecs) @@ -186,17 +199,29 @@ def generate_vecs( self, embd: torch.Tensor, trgt: torch.Tensor, + mask: torch.Tensor, ids: Sequence[int], ) -> list[torch.Tensor]: """ Generate a list of vectorized embeddings for each ground truth object. """ + if self.mask_background and 0 in ids: + raise ValueError("ID '0' is not allowed when mask_background is enabled.") + + mask_bool = mask.bool() if not self.mask_background else None result = [] + for obj_id in ids: - obj = torch.nonzero(trgt == int(obj_id)) - z, y, x = obj[:, -3], obj[:, -2], obj[:, -1] - vec = embd[0, :, z, y, x].transpose(0, 1) # Count x Dim + obj_mask = (trgt == int(obj_id)) & mask_bool if mask_bool is not None else (trgt == int(obj_id)) + idx = torch.nonzero(obj_mask, as_tuple=True) + + if idx[0].numel() == 0: + # If there are no indices for this ID, skip to the next one + continue + + vec = embd[0, :, idx[-3], idx[-2], idx[-1]].transpose(0, 1) # Count x Dim result.append(vec) + return result def compute_ext_matrix( diff --git a/deepem/models/pinky_mito.py b/deepem/models/pinky_mito.py index 036a54a..4c37021 100644 --- a/deepem/models/pinky_mito.py +++ b/deepem/models/pinky_mito.py @@ -1,4 +1,3 @@ -from __future__ import print_function from itertools import repeat import collections import math @@ -26,7 +25,7 @@ def _ntuple(n): Copied from the PyTorch source code (https://github.com/pytorch). """ def parse(x): - if isinstance(x, collections.Iterable): + if isinstance(x, collections.abc.Iterable): return x return tuple(repeat(x, n)) return parse diff --git a/deepem/models/rsunet_act_scale.py b/deepem/models/rsunet_act_scale.py new file mode 100644 index 0000000..751255e --- /dev/null +++ b/deepem/models/rsunet_act_scale.py @@ -0,0 +1,74 @@ +import torch +import torch.nn as nn + +import emvision +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop, Scale + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16,32,64,128,256,512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + onnx=opt.onnx, scale_init=opt.scale_init) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size, onnx=False, scale_init=1.0): + super(OutputBlock, self).__init__() + self.onnx = onnx + for k, v in out_spec.items(): + out_channels = v[-4] + if k == 'embedding': + self.add_module( + k, + nn.Sequential( + Conv(in_channels, out_channels, kernel_size, bias=True), + Scale(init_value=scale_init), + ), + ) + else: + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + if self.onnx: + return tuple(m(x) for k, m in self.named_children()) + else: + return {k: m(x) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + crop=None, onnx=False, scale_init=1.0): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', + OutputBlock(out_channels, out_spec, io_kernel, onnx=onnx, scale_init=scale_init)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/models/updown3x_act.py b/deepem/models/updown3x_act.py new file mode 100644 index 0000000..439ec4e --- /dev/null +++ b/deepem/models/updown3x_act.py @@ -0,0 +1,78 @@ +import torch +import torch.nn as nn + +import emvision +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16,32,64,128,256,512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size): + super(OutputBlock, self).__init__() + for k, v in out_spec.items(): + out_channels = v[-4] + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Sequential): + def __init__(self, scale_factor=(1,3,3)): + super(DownBlock, self).__init__() + self.add_module('down', nn.AvgPool3d(scale_factor)) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, scale_factor=(1,3,3)): + super(UpBlock, self).__init__() + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample(scale_factor=scale_factor, mode='trilinear')) + + def forward(self, x): + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + scale_factor=(1,3,3), crop=None): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + + self.add_module('down', DownBlock(scale_factor=scale_factor)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) + self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/models/updown_act.py b/deepem/models/updown_act.py index 25a02f5..2b33f38 100644 --- a/deepem/models/updown_act.py +++ b/deepem/models/updown_act.py @@ -20,7 +20,7 @@ def create_model(opt): else: # Batch normalization core = rsunet_act(width=width[:depth], act=opt.act) - return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, onnx=opt.onnx) class InputBlock(nn.Sequential): @@ -30,15 +30,19 @@ def __init__(self, in_channels, out_channels, kernel_size): class OutputBlock(nn.Module): - def __init__(self, in_channels, out_spec, kernel_size): + def __init__(self, in_channels, out_spec, kernel_size, onnx=False): super(OutputBlock, self).__init__() + self.onnx = onnx for k, v in out_spec.items(): out_channels = v[-4] self.add_module(k, Conv(in_channels, out_channels, kernel_size, bias=True)) def forward(self, x): - return {k: m(x) for k, m in self.named_children()} + if self.onnx: + return tuple(m(x) for k, m in self.named_children()) + else: + return {k: m(x) for k, m in self.named_children()} class DownBlock(nn.Sequential): @@ -48,14 +52,18 @@ def __init__(self, scale_factor=(1,2,2)): class UpBlock(nn.Module): - def __init__(self, out_spec, scale_factor=(1,2,2)): + def __init__(self, out_spec, scale_factor=(1,2,2), onnx=False): super(UpBlock, self).__init__() + self.onnx = onnx for k, v in out_spec.items(): self.add_module(k, nn.Upsample(scale_factor=scale_factor, mode='trilinear')) def forward(self, x): - return {k: m(x[k]) for k, m in self.named_children()} + if self.onnx: + return tuple(m(x[i]) for i, (_, m) in enumerate(self.named_children())) + else: + return {k: m(x[k]) for k, m in self.named_children()} class Model(nn.Sequential): @@ -63,7 +71,7 @@ class Model(nn.Sequential): Residual Symmetric U-Net with down/upsampling in/output. """ def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), - scale_factor=(1,2,2), crop=None): + scale_factor=(1,2,2), crop=None, onnx=False): super(Model, self).__init__() assert len(in_spec)==1, "model takes a single input" @@ -72,7 +80,7 @@ def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), self.add_module('down', DownBlock(scale_factor=scale_factor)) self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) self.add_module('core', core) - self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) - self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor)) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel, onnx=onnx)) + self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor, onnx=onnx)) if crop is not None: self.add_module('crop', Crop(crop)) diff --git a/deepem/models/updown_act_interpolate.py b/deepem/models/updown_act_interpolate.py new file mode 100644 index 0000000..8e21a3e --- /dev/null +++ b/deepem/models/updown_act_interpolate.py @@ -0,0 +1,88 @@ +import torch.nn as nn +import torch.nn.functional as F + +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16, 32, 64, 128, 256, 512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + scale_factor=opt.updown_scale_factor) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size): + super(OutputBlock, self).__init__() + for k, v in out_spec.items(): + out_channels = v[-4] + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Module): + def __init__(self, size): + super(DownBlock, self).__init__() + self.size = size + + def forward(self, x): + return F.interpolate(x, size=self.size, mode='trilinear', align_corners=False) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, size): + super(UpBlock, self).__init__() + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample( + size=size, + mode='trilinear', + recompute_scale_factor=False, + )) + + def forward(self, x): + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + scale_factor=(1, 2, 2), crop=None): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + in_size = in_spec['input'][-3:] + assert all(s % f == 0 for s, f in zip(in_size, scale_factor)) + new_size = tuple(int(s / f) for s, f in zip(in_size, scale_factor)) + + self.add_module('down', DownBlock(size=new_size)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel)) + self.add_module('up', UpBlock(out_spec, size=in_size)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/models/updown_act_scale.py b/deepem/models/updown_act_scale.py new file mode 100644 index 0000000..888371f --- /dev/null +++ b/deepem/models/updown_act_scale.py @@ -0,0 +1,96 @@ +import torch +import torch.nn as nn + +import emvision +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop, Scale + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16,32,64,128,256,512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + onnx=opt.onnx, scale_init=opt.scale_init) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size, onnx=False, scale_init=1.0): + super(OutputBlock, self).__init__() + self.onnx = onnx + for k, v in out_spec.items(): + out_channels = v[-4] + if k == 'embedding': + self.add_module( + k, + nn.Sequential( + Conv(in_channels, out_channels, kernel_size, bias=True), + Scale(init_value=scale_init), + ), + ) + else: + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + if self.onnx: + return tuple(m(x) for k, m in self.named_children()) + else: + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Sequential): + def __init__(self, scale_factor=(1,2,2)): + super(DownBlock, self).__init__() + self.add_module('down', nn.AvgPool3d(scale_factor)) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, scale_factor=(1,2,2), onnx=False): + super(UpBlock, self).__init__() + self.onnx = onnx + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample(scale_factor=scale_factor, mode='trilinear')) + + def forward(self, x): + if self.onnx: + return tuple(m(x[i]) for i, (_, m) in enumerate(self.named_children())) + else: + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + scale_factor=(1,2,2), crop=None, onnx=False, scale_init=1.0): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + + self.add_module('down', DownBlock(scale_factor=scale_factor)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel, onnx=onnx, scale_init=scale_init)) + self.add_module('up', UpBlock(out_spec, scale_factor=scale_factor, onnx=onnx)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/models/updown_act_scale_interpolate.py b/deepem/models/updown_act_scale_interpolate.py new file mode 100644 index 0000000..571a6e5 --- /dev/null +++ b/deepem/models/updown_act_scale_interpolate.py @@ -0,0 +1,97 @@ +import torch.nn as nn +import torch.nn.functional as F + +from emvision.models import rsunet_act, rsunet_act_gn + +from deepem.models.layers import Conv, Crop, Scale + + +def create_model(opt): + if opt.width: + width = opt.width + depth = len(width) + else: + width = [16, 32, 64, 128, 256, 512] + depth = opt.depth + if opt.group > 0: + # Group normalization + core = rsunet_act_gn(width=width[:depth], group=opt.group, act=opt.act) + else: + # Batch normalization + core = rsunet_act(width=width[:depth], act=opt.act) + return Model(core, opt.in_spec, opt.out_spec, width[0], crop=opt.crop, + scale_init=opt.scale_init, scale_factor=opt.updown_scale_factor) + + +class InputBlock(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size): + super(InputBlock, self).__init__() + self.add_module('conv', Conv(in_channels, out_channels, kernel_size)) + + +class OutputBlock(nn.Module): + def __init__(self, in_channels, out_spec, kernel_size, scale_init=1.0): + super(OutputBlock, self).__init__() + for k, v in out_spec.items(): + out_channels = v[-4] + if k == 'embedding': + self.add_module( + k, + nn.Sequential( + Conv(in_channels, out_channels, kernel_size, bias=True), + Scale(init_value=scale_init), + ), + ) + else: + self.add_module(k, + Conv(in_channels, out_channels, kernel_size, bias=True)) + + def forward(self, x): + return {k: m(x) for k, m in self.named_children()} + + +class DownBlock(nn.Module): + def __init__(self, size): + super(DownBlock, self).__init__() + self.size = size + + def forward(self, x): + return F.interpolate(x, size=self.size, mode='trilinear', align_corners=False) + + +class UpBlock(nn.Module): + def __init__(self, out_spec, size): + super(UpBlock, self).__init__() + for k, v in out_spec.items(): + self.add_module(k, + nn.Upsample( + size=size, + mode='trilinear', + recompute_scale_factor=False, + )) + + def forward(self, x): + return {k: m(x[k]) for k, m in self.named_children()} + + +class Model(nn.Sequential): + """ + Residual Symmetric U-Net with down/upsampling in/output. + """ + def __init__(self, core, in_spec, out_spec, out_channels, io_kernel=(1,5,5), + crop=None, scale_init=1.0, scale_factor=(1, 2, 2)): + super(Model, self).__init__() + + assert len(in_spec)==1, "model takes a single input" + in_channels = 1 + in_size = in_spec['input'][-3:] + assert all(s % f == 0 for s, f in zip(in_size, scale_factor)) + new_size = tuple(int(s / f) for s, f in zip(in_size, scale_factor)) + + self.add_module('down', DownBlock(size=new_size)) + self.add_module('in', InputBlock(in_channels, out_channels, io_kernel)) + self.add_module('core', core) + self.add_module('out', OutputBlock(out_channels, out_spec, io_kernel, scale_init=scale_init)) + self.add_module('up', UpBlock(out_spec, size=in_size)) + if crop is not None: + self.add_module('crop', Crop(crop)) diff --git a/deepem/test/cv_utils.py b/deepem/test/cv_utils.py index 786b7d0..c9443b1 100644 --- a/deepem/test/cv_utils.py +++ b/deepem/test/cv_utils.py @@ -16,20 +16,25 @@ def make_info(num_channels, layer_type, dtype, shape, resolution, chunk_size=chunk_size) def get_coord_bbox(cvol, opt): - # Cutout offset = cvol.voxel_offset + volume_shape = cvol.shape[:3] + if opt.center is not None: assert opt.size is not None - opt.begin = tuple(x - (y//2) for x, y in zip(opt.center, opt.size)) + opt.begin = tuple(x - (y // 2) for x, y in zip(opt.center, opt.size)) opt.end = tuple(x + y for x, y in zip(opt.begin, opt.size)) else: if opt.begin is None: - opt.begin = offset + opt.begin = offset # Default to voxel_offset if not provided + if opt.end is None: if opt.size is None: - opt.end = offset + cvol.shape[:3] + # When size is not specified, set end to the end of the dataset + opt.end = tuple(o + s for o, s in zip(offset, volume_shape)) else: - opt.end = tuple(x + y for x, y in zip(opt.begin, opt.size)) + # When size is specified, calculate end based on begin + size + opt.end = tuple(b + s for b, s in zip(opt.begin, opt.size)) + return Bbox(opt.begin, opt.end) def cutout(opt, gs_path, dtype='uint8', channels=0): @@ -37,23 +42,30 @@ def cutout(opt, gs_path, dtype='uint8', channels=0): gs_path = gs_path.format(*opt.keywords) print(gs_path) - # CloudVolume. - cvol = cv.CloudVolume(gs_path, mip=opt.in_mip, cache=opt.cache, - fill_missing=True, parallel=opt.parallel) + # CloudVolume for coordinate handling (coord_mip) + coord_cvol = cv.CloudVolume(gs_path, mip=opt.coord_mip) + + # CloudVolume for data fetching (in_mip) + data_cvol = cv.CloudVolume(gs_path, mip=opt.in_mip, cache=opt.cache, + fill_missing=True, parallel=opt.parallel) + + # Get bounding box based on coord_mip + coord_bbox = get_coord_bbox(coord_cvol, opt) - # Based on MIP level args are specified in - coord_bbox = get_coord_bbox(cvol, opt) if opt.coord_mip != opt.in_mip: print(f"mip {opt.coord_mip} = {coord_bbox}") - # Based on in_mip - in_bbox = cvol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) + + # Convert bbox to in_mip coordinates if needed + in_bbox = coord_cvol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) print(f"mip {opt.in_mip} = {in_bbox}") - cutout = cvol[in_bbox.to_slices()] + + # Data cutout from in_mip + cutout = data_cvol[in_bbox.to_slices()] # Transpose & squeeze - cutout = cutout.transpose([3,2,1,0]) + cutout = cutout.transpose([3, 2, 1, 0]) - # Slice channel + # Slice channels if specified if channels > 0: cutout = cutout[:channels, ...] @@ -63,42 +75,46 @@ def cutout(opt, gs_path, dtype='uint8', channels=0): def ingest(data, opt, tag=None): # Neuroglancer format data = py_utils.to_tensor(data) - data = data.transpose((3,2,1,0)) + data = data.transpose((3, 2, 1, 0)) num_channels = data.shape[-1] shape = data.shape[:-1] - # Use CloudVolume to make sure the output bbox matches the input. - # MIP hierarchies are not guaranteed to be powers of 2 (especially - # with float resolutions), so need to use the info file and be - # consistent computing the offset between input and output. + # Use CloudVolume with coord_mip for coordinate handling gs_path = opt.gs_input if '{}' in gs_path: gs_path = gs_path.format(*opt.keywords) - in_vol = cv.CloudVolume(gs_path, mip=opt.in_mip, cache=opt.cache, - fill_missing=True, parallel=opt.parallel) - coord_bbox = get_coord_bbox(in_vol, opt) + + coord_cvol = cv.CloudVolume(gs_path, mip=opt.coord_mip) + + # Get bounding box in coord_mip + coord_bbox = get_coord_bbox(coord_cvol, opt) + # Offset is defined at coord_mip, so adjust coord_bbox first if opt.offset: start_adjust = coord_bbox.minpt - opt.offset coord_bbox -= start_adjust - in_bbox = in_vol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) + + # Convert bbox to in_mip coordinates + in_bbox = coord_cvol.bbox_to_mip(coord_bbox, mip=opt.coord_mip, to_mip=opt.in_mip) # Patch offset correction (when output patch is smaller than input patch) - patch_offset = (0,0,0) + patch_offset = (0, 0, 0) if opt.tilt_series > 0: if opt.tilt_series_crop is not None: outputsz = np.array(opt.fov) * np.array(opt.scale) patch_offset = (outputsz - np.array(opt.tilt_series_crop)) // 2 else: patch_offset = (np.array(opt.inputsz) - np.array(opt.outputsz)) // 2 + patch_offset = Vec(*np.flip(patch_offset, 0)) - in_bbox.minpt += patch_offset - # in_bbox.stop -= 2*patch_offset # using the data to define shape - # Create info + # Create info using the adjusted offset + offset = in_bbox.minpt + patch_offset info = make_info(num_channels, 'image', str(data.dtype), shape, - opt.resolution, offset=in_bbox.minpt, chunk_size=opt.chunk_size) + opt.resolution, offset=offset, chunk_size=opt.chunk_size) print(info) + + # Output path formatting gs_path = opt.gs_output if '{}' in opt.gs_output: if opt.keywords: @@ -108,21 +124,17 @@ def ingest(data, opt, tag=None): coord = "x{}_y{}_z{}".format(*opt.center) coord += "_s{}-{}-{}".format(*opt.size) else: - coord = '_'.join([f"{b}-{e}" for b,e in zip(opt.begin,opt.end)]) + coord = '_'.join([f"{b}-{e}" for b, e in zip(opt.begin, opt.end)]) gs_path = gs_path.format(coord) # Tagging if tag is not None: - if gs_path[-1] == '/': - gs_path += tag - else: - gs_path += ('/' + tag) + gs_path = gs_path.rstrip('/') + '/' + tag print(f"gs_output:\n{gs_path}") - cvol = cv.CloudVolume(gs_path, mip=0, info=info, - parallel=opt.parallel) - cvol[:,:,:,:] = data - cvol.commit_info() + data_vol = cv.CloudVolume(gs_path, mip=0, info=info, parallel=opt.parallel) + data_vol[:, :, :, :] = data + data_vol.commit_info() # Downsample if opt.downsample: diff --git a/deepem/test/option.py b/deepem/test/option.py index 03e7544..70dd9d6 100644 --- a/deepem/test/option.py +++ b/deepem/test/option.py @@ -1,4 +1,5 @@ import argparse +from collections import OrderedDict import json import math import os @@ -38,6 +39,7 @@ def initialize(self): self.parser.add_argument('--width', type=int, default=None, nargs='+') self.parser.add_argument('--group', type=int, default=0) self.parser.add_argument('--act', default='ReLU') + self.parser.add_argument('--updown_scale_factor', type=vec3f, default=None) # Tilt-series electron tomography self.parser.add_argument('--tilt_series', type=int, default=0) @@ -63,11 +65,20 @@ def initialize(self): self.parser.add_argument('--mye', action='store_true') self.parser.add_argument('--mye_thresh', type=float, default=0.5) self.parser.add_argument('--blv', action='store_true') - self.parser.add_argument('--blv_num_channels', type=int, default=2) - self.parser.add_argument('--glia', action='store_true') + self.parser.add_argument('--blv_num_channels', type=int, default=1) + self.parser.add_argument('--glia', action='store_true') self.parser.add_argument('--sem', action='store_true') self.parser.add_argument('--img', action='store_true') + # Semantic segmentation + self.parser.add_argument('--semantic', action='store_true') + self.parser.add_argument('--dend', action='store_true') # Dendrite + self.parser.add_argument('--axon', action='store_true') # Axon + self.parser.add_argument('--soma', action='store_true') # Soma + self.parser.add_argument('--nucl', action='store_true') # Nucleus + self.parser.add_argument('--ecs', action='store_true') # Extracellular space + self.parser.add_argument('--other', action='store_true') # Other class + # Test-time augmentation self.parser.add_argument('--test_aug', type=int, default=None, nargs='+') self.parser.add_argument('--test_aug16', action='store_true') @@ -203,6 +214,35 @@ def parse(self): opt.out_spec['bvessel'] = (1,) + opt.outputsz if opt.img: opt.out_spec['image'] = (1,) + opt.outputsz + if opt.dend: + opt.out_spec['dendrite'] = (1,) + opt.outputsz + if opt.axon: + opt.out_spec['axon'] = (1,) + opt.outputsz + if opt.soma: + opt.out_spec['soma'] = (1,) + opt.outputsz + if opt.nucl: + opt.out_spec['nucleus'] = (1,) + opt.outputsz + if opt.ecs: + opt.out_spec['extracellular_space'] = (1,) + opt.outputsz + if opt.other: + opt.out_spec['other_class'] = (1,) + opt.outputsz + + # Semantic segmentation + if opt.semantic: + required_keys = ['soma', 'axon', 'dendrite', 'glia', 'blood_vessel'] + + # Ensure all required keys are present in the opt.out_spec + assert all(key in opt.out_spec for key in required_keys) + + # Use OrderedDict to maintain order of required keys followed by other keys + out_spec_new = OrderedDict((key, opt.out_spec[key]) for key in required_keys) + + # Add remaining keys to out_spec_new + out_spec_new.update((key, opt.out_spec[key]) for key in opt.out_spec if key not in required_keys) + + # Convert back to standard dict if necessary + opt.out_spec = dict(out_spec_new) + assert(len(opt.out_spec) > 0) # Scan spec @@ -240,6 +280,34 @@ def parse(self): opt.scan_spec['bvessel'] = (1,) + opt.outputsz if opt.img: opt.scan_spec['image'] = (1,) + opt.outputsz + if opt.dend: + opt.scan_spec['dendrite'] = (1,) + opt.outputsz + if opt.axon: + opt.scan_spec['axon'] = (1,) + opt.outputsz + if opt.soma: + opt.scan_spec['soma'] = (1,) + opt.outputsz + if opt.nucl: + opt.scan_spec['nucleus'] = (1,) + opt.outputsz + if opt.ecs: + opt.scan_spec['extracellular_space'] = (1,) + opt.outputsz + if opt.other: + opt.scan_spec['other_class'] = (1,) + opt.outputsz + + # Semantic segmentation + if opt.semantic: + required_keys = ['soma', 'axon', 'dendrite', 'glia', 'blood_vessel'] + + # Ensure all required keys are present in the opt.scan_spec + assert all(key in opt.scan_spec for key in required_keys) + + # Use OrderedDict to maintain order of required keys followed by other keys + scan_spec_new = OrderedDict((key, opt.scan_spec[key]) for key in required_keys) + + # Add remaining keys to scan_spec_new + scan_spec_new.update((key, opt.scan_spec[key]) for key in opt.scan_spec if key not in required_keys) + + # Convert back to standard dict if necessary + opt.scan_spec = dict(scan_spec_new) # Test-time augmentation if opt.test_aug16: diff --git a/deepem/test/utils.py b/deepem/test/utils.py index 58d773b..893f40e 100644 --- a/deepem/test/utils.py +++ b/deepem/test/utils.py @@ -1,4 +1,3 @@ -import imp import numpy as np import os from types import SimpleNamespace @@ -11,7 +10,7 @@ def load_model(opt): # Create a model. - mod = imp.load_source('model', opt.model) + mod = py_utils.load_module('model', opt.model) if opt.onnx: model = OnnxModel(mod.create_model(opt), opt) else: diff --git a/deepem/train/data.py b/deepem/train/data.py index cfdb0f1..df05c3a 100644 --- a/deepem/train/data.py +++ b/deepem/train/data.py @@ -1,4 +1,5 @@ -import imp +from deepem.utils.py_utils import load_module + import numpy as np import torch @@ -45,19 +46,20 @@ def requires_grad(self, key): def build(self, opt, data, is_train, prob): # Data augmentation if opt.augment: - mod = imp.load_source('augment', opt.augment) + mod = load_module('augment', opt.augment) aug = mod.get_augmentation(is_train, **opt.aug_params) else: aug = None # Data sampler - mod = imp.load_source('sampler', opt.sampler) + mod = load_module('sampler', opt.sampler) spec = mod.get_spec(opt.in_spec, opt.out_spec) - sampler = mod.Sampler(data, spec, is_train, aug, prob=prob) + zspecs = opt.zettaset_specs + sampler = mod.Sampler(data, spec, is_train, aug, prob, zspecs) # Sample modifier if opt.modifier: - mod = imp.load_source('modifier', opt.modifier) + mod = load_module('modifier', opt.modifier) self.modifier = mod.Modifier(**opt.modifier_kwargs) else: def default_modifier(x, **kwargs): diff --git a/deepem/train/logger.py b/deepem/train/logger.py index 8fb4069..2fd7a5c 100644 --- a/deepem/train/logger.py +++ b/deepem/train/logger.py @@ -2,14 +2,8 @@ import sys import datetime from collections import OrderedDict -import numpy as np import torch -from torchvision.utils import make_grid -from tensorboardX import SummaryWriter - -from deepem.loss.mean import vec2aff -from deepem.utils import torch_utils, py_utils class Logger(object): @@ -21,9 +15,6 @@ def __init__(self, opt): self.outputsz = opt.outputsz self.lr = opt.lr - # TensorBoard logging - self.writer = SummaryWriter(opt.log_dir) if opt.tensorboard else None - # Metric learning self.delta_d = opt.delta_d @@ -40,8 +31,7 @@ def __enter__(self): return self def __exit__(self, type, value, traceback): - if self.writer: - self.writer.close() + pass def record(self, phase, loss, nmsk, **kwargs): monitor = self.monitor[phase] @@ -58,15 +48,9 @@ def record(self, phase, loss, nmsk, **kwargs): def check(self, phase, iter_num): stats = self.monitor[phase].flush() - self.log(phase, iter_num, stats) self.display(phase, iter_num, stats) return stats - def log(self, phase, iter_num, stats): - if self.writer: - for k, v in stats.items(): - self.writer.add_scalar(f"{phase}/{k}", v, iter_num) - def display(self, phase, iter_num, stats): disp = "[%s] Iter: %8d, " % (phase, iter_num) for k, v in stats.items(): @@ -95,96 +79,6 @@ def flush(self): self.norm = OrderedDict() return ret - def log_images(self, phase, iter_num, preds, sample): - if self.writer is None: - return - - # Peep output size - key = sorted(self.out_spec)[0] - cropsz = sample[key].shape[-3:] - for k in sorted(self.out_spec): - outsz = sample[k].shape[-3:] - assert np.array_equal(outsz, cropsz) - - # Inputs - for k in sorted(self.in_spec): - tag = f"{phase}/images/{k}" - tensor = sample[k][0,...].cpu() - tensor = torch_utils.crop_center_no_strict(tensor, cropsz) - num_channels = tensor.shape[-4] - if num_channels > 3: - self.log_image(tag, tensor[0:3,...], iter_num) - else: - self.log_image(tag, tensor, iter_num) - - # Outputs - for k in sorted(self.out_spec): - - if k == 'affinity': - - # Prediction - tag = f"{phase}/images/{k}" - tensor = torch.sigmoid(preds[k][0,0:3,...]).cpu() - self.log_image(tag, tensor, iter_num) - - # Mask - tag = f"{phase}/masks/{k}" - msk = sample[k + '_mask'][0,...].cpu() - self.log_image(tag, msk, iter_num) - - # Target - tag = f"{phase}/labels/{k}" - seg = sample[k][0,0,...].cpu().numpy().astype('uint32') - rgb = torch.from_numpy(py_utils.seg2rgb(seg)) - self.log_image(tag, rgb, iter_num) - - elif k == 'embedding': - - vec = preds[k][0, ...] - - # Metric graph - tag = f"{phase}/images/metric_graph" - aff = vec2aff(vec, delta_d=self.delta_d) - self.log_image(tag, aff.cpu(), iter_num) - - # Embedding - tag = f"{phase}/images/{k}" - vec = preds[k][[0],...].cpu() # 1, c, z, y, x - vec = torch_utils.vec2pca(vec) - vec = vec.select(0, 0) - self.log_image(tag, vec, iter_num) - - # Target - tag = f"{phase}/labels/{k}" - seg = sample[k][0,0,...].cpu().numpy().astype('uint32') - rgb = torch.from_numpy(py_utils.seg2rgb(seg)) - self.log_image(tag, rgb, iter_num) - - else: - - # Prediction - tag = f"{phase}/images/{k}" - pred = torch.sigmoid(preds[k][0,...]).cpu() - self.log_image(tag, pred, iter_num) - - # Mask - tag = f"{phase}/masks/{k}" - msk = sample[k + '_mask'][0,...].cpu() - self.log_image(tag, msk, iter_num) - - # Target - tag = f"{phase}/labels/{k}" - target = sample[k][0,...].cpu() - self.log_image(tag, target, iter_num) - - def log_image(self, tag, tensor, iter_num): - if self.writer: - assert(torch.is_tensor(tensor)) - depth = tensor.shape[-3] - imgs = [tensor[:,z,:,:] for z in range(depth)] - img = make_grid(imgs, nrow=depth, padding=0) - self.writer.add_image(tag, img, iter_num) - def log_params(self, params): fname = os.path.join(self.log_dir, f"{self.timestamp}_params.csv") with open(fname, "w+") as f: diff --git a/deepem/train/option.py b/deepem/train/option.py index a0aec25..615b1db 100644 --- a/deepem/train/option.py +++ b/deepem/train/option.py @@ -4,7 +4,7 @@ import numpy as np import samwise -from deepem.utils.py_utils import vec3 +from deepem.utils.py_utils import vec3, vec3f class Options(object): @@ -18,17 +18,21 @@ def __init__(self): def initialize(self): self.parser.add_argument('--exp_name', required=True) self.parser.add_argument('--model', required=True) - self.parser.add_argument('--data', required=True) self.parser.add_argument('--sampler', required=True) + self.parser.add_argument('--data', default=None) self.parser.add_argument('--augment', default=None) self.parser.add_argument('--modifier', default=None) self.parser.add_argument('--modifier_kwargs', type=json.loads, default={}) # zettasets - self.parser.add_argument('--zettaset_path', required=True, type=str, default=[], nargs='+') + self.parser.add_argument('--zettaset_path', type=str, default=[], nargs='+') + self.parser.add_argument('--zettaset_specs', type=json.loads, default={}) self.parser.add_argument('--zettaset_lookup', type=json.loads, default=None) self.parser.add_argument('--zettaset_padding', type=vec3, default=(0, 0, 0)) + self.parser.add_argument('--zettaset_padding_spec', type=json.loads, default={}) + self.parser.add_argument('--zettaset_resolution', type=vec3f, default=None) self.parser.add_argument('--zettaset_no_mask', action='store_true') + self.parser.add_argument('--zettaset_share_mask', type=str, default=None) # file synchronization for spot/preemptible training self.parser.add_argument('--samwise_map', nargs='*', default=None) @@ -60,6 +64,7 @@ def initialize(self): # Loss self.parser.add_argument('--loss', default='BCELoss') + self.parser.add_argument('--no_split_boundary', action='store_true') self.parser.add_argument('--size_average', action='store_true') self.parser.add_argument('--margin0', type=float, default=0) self.parser.add_argument('--margin1', type=float, default=0) @@ -78,6 +83,8 @@ def initialize(self): self.parser.add_argument('--delta_v', type=float, default=0.0) self.parser.add_argument('--delta_d', type=float, default=1.5) self.parser.add_argument('--recompute_ext', action='store_true') + self.parser.add_argument('--no_mask_background', action='store_true') + self.parser.add_argument('--loss_scale_factor', type=vec3f, default=None) # Optimizer self.parser.add_argument('--optim', default='Adam') @@ -99,9 +106,10 @@ def initialize(self): self.parser.add_argument('--width', type=int, default=None, nargs='+') self.parser.add_argument('--group', type=int, default=0) self.parser.add_argument('--act', default='ReLU') + self.parser.add_argument('--updown_scale_factor', type=vec3f, default=None) # Data augmentation - self.parser.add_argument('--recompute', action='store_true') + self.parser.add_argument('--recompute', type=str, default=[], nargs='+') self.parser.add_argument('--border', type=str, default=[], nargs='+') self.parser.add_argument('--flip', action='store_true') self.parser.add_argument('--grayscale', action='store_true') @@ -118,6 +126,8 @@ def initialize(self): self.parser.add_argument('--noise_min', type=float, default=0.01) self.parser.add_argument('--noise_max', type=float, default=0.1) self.parser.add_argument('--noise_per_channel', action='store_true') + self.parser.add_argument('--section_gap', type=int, default=0) + self.parser.add_argument('--mask_section_gap', action='store_true') # Tilt-series electron tomography self.parser.add_argument('--tilt_series', type=int, default=0) @@ -138,12 +148,20 @@ def initialize(self): self.parser.add_argument('--mye', type=float, default=0) # Myelin self.parser.add_argument('--fld', type=float, default=0) # Fold self.parser.add_argument('--blv', type=float, default=0) # Blood vessel - self.parser.add_argument('--blv_num_channels', type=int, default=2) - self.parser.add_argument('--glia', type=float, default=0) # Glia + self.parser.add_argument('--blv_num_channels', type=int, default=1) + self.parser.add_argument('--glia', type=float, default=0) # Glia self.parser.add_argument('--glia_mask', action='store_true') - self.parser.add_argument('--soma', type=float, default=0) # Soma self.parser.add_argument('--img', type=float, default=0) # Image + # Semantic segmentation + self.parser.add_argument('--sem', action='store_true') + self.parser.add_argument('--dend', type=float, default=0) # Dendrite + self.parser.add_argument('--axon', type=float, default=0) # Axon + self.parser.add_argument('--soma', type=float, default=0) # Soma + self.parser.add_argument('--nucl', type=float, default=0) # Nucleus + self.parser.add_argument('--ecs', type=float, default=0) # Extracellular space + self.parser.add_argument('--other', type=float, default=0) # Other class + # Metric learning self.parser.add_argument('--vec', type=float, default=0) self.parser.add_argument('--embed_dim', type=int, default=12) @@ -158,9 +176,6 @@ def initialize(self): self.parser.add_argument('--export_onnx', action='store_true') self.parser.add_argument('--opset_version', type=int, default=10) - # TensorBoard logging - self.parser.add_argument('--tensorboard', action='store_true') - self.initialized = True def parse(self): @@ -186,7 +201,12 @@ def parse(self): if (not opt.train_ids) or (not opt.val_ids): raise ValueError("Train/validation IDs unspecified") if opt.train_prob: - assert len(opt.train_ids) == len(opt.train_prob) + if len(opt.train_ids) != len(opt.train_prob): + error_message = ( + "The lengths of 'train_ids' and 'train_prob' must be the same. " + f"train_ids: {opt.train_ids}, train_prob: {opt.train_prob}" + ) + raise ValueError(error_message) if opt.val_prob: assert len(opt.val_ids) == len(opt.val_prob) @@ -205,6 +225,8 @@ def parse(self): opt.metric_params['delta_v'] = opt.delta_v opt.metric_params['delta_d'] = opt.delta_d opt.metric_params['recompute_ext'] = opt.recompute_ext + opt.metric_params['mask_background'] = not opt.no_mask_background + opt.metric_params['loss_scale_factor'] = opt.loss_scale_factor # Optimizer if opt.optim == 'Adam': @@ -217,7 +239,8 @@ def parse(self): # Data augmentation aug_keys = ['recompute', 'border', 'flip','grayscale','warping','misalign', - 'interp','missing','blur','box','mip','lost','random'] + 'interp','missing','blur','box','mip','lost','random', + 'section_gap', 'mask_section_gap'] opt.aug_params = {k: args[k] for k in aug_keys} # Noise @@ -275,6 +298,22 @@ def parse(self): 'soma': ('soma', 1), 'img': ('image', 1), 'vec': ('embedding', opt.embed_dim), + 'dend': ('dendrite', 1), + 'axon': ('axon', 1), + 'nucl': ('nucleus', 1), + 'ecs': ('extracellular_space', 1), + 'other': ('other_class', 1), + } + + semantic_mapping = { + 'dendrite': 1, + 'axon': 2, + 'soma': 3, + 'nucleus': 4, + 'glia': 5, + 'extracellular_space': 6, + 'blood_vessel': 7, + 'other_class': 10, } requires_binarize = [ @@ -286,6 +325,12 @@ def parse(self): "soma", ] + if opt.blv_num_channels == 1: + requires_binarize.append("blood_vessel") + + if opt.sem: + requires_binarize = [x for x in requires_binarize if x not in semantic_mapping] + # Test training if opt.test: opt.eval_intv = 100 @@ -309,8 +354,12 @@ def parse(self): glia_mask=opt.glia_mask, zettaset_lookup=opt.zettaset_lookup, zettaset_padding=opt.zettaset_padding, + zettaset_padding_spec=opt.zettaset_padding_spec, + zettaset_resolution=opt.zettaset_resolution, zettaset_mask=not opt.zettaset_no_mask, requires_binarize=requires_binarize, + zettaset_share_mask=opt.zettaset_share_mask, + semantic_mapping=semantic_mapping if opt.sem else {}, ) # ONNX diff --git a/deepem/train/run.py b/deepem/train/run.py index 6be9570..ec488e5 100644 --- a/deepem/train/run.py +++ b/deepem/train/run.py @@ -77,7 +77,6 @@ def train(opt): # Image logging if (i+1) % opt.imgs_intv == 0: - logger.log_images('train', i+1, preds, sample) wandb_logger.log_images('train', i+1, preds, sample) # Evaluation loop @@ -126,7 +125,6 @@ def eval_loop(iter_num, model, data_loader, opt, logger, wandb_logger): # Image logging if iter_num % opt.imgs_intv == 0: - logger.log_images('test', iter_num, preds, sample) wandb_logger.log_images('test', iter_num, preds, sample) print("-------------------------------------------") diff --git a/deepem/train/utils.py b/deepem/train/utils.py index be11580..fb6bf4f 100644 --- a/deepem/train/utils.py +++ b/deepem/train/utils.py @@ -1,4 +1,3 @@ -import imp import os import glob @@ -9,6 +8,7 @@ from deepem.train.data import Data from deepem.train.model import Model, AmpModel from deepem.loss.utils import BinaryWeightBalancer +from deepem.utils.py_utils import load_module def get_criteria(opt): @@ -31,6 +31,7 @@ def get_criteria(opt): params['size_average'] = False criteria[k] = loss.AffinityLoss(edges, criterion=getattr(loss, opt.loss)(**params), + split_boundary=not opt.no_split_boundary, size_average=opt.size_average, class_balancer=balancer, ) @@ -38,18 +39,29 @@ def get_criteria(opt): criteria[k] = getattr(loss, opt.metric_loss)(**opt.metric_params) else: params = dict(opt.loss_params) + + if ('affinity' in opt.out_spec) or ('long_range' in opt.out_spec): + balancer = BinaryWeightBalancer( + weight0=opt.class_weight1, + weight1=opt.class_weight0, + ) if opt.class_balancing else None + params['class_balancer'] = balancer + if opt.default_aux: params['margin0'] = 0 params['margin1'] = 0 params['inverse'] = False - params['class_balancer'] = balancer + params['class_balancer'] = None + criteria[k] = getattr(loss, 'BCELoss')(**params) return criteria def load_model(opt): # Create a model. - mod = imp.load_source('model', opt.model) + + mod = load_module("model", opt.model) + if opt.mixed_precision: model = AmpModel(mod.create_model(opt), get_criteria(opt), opt) else: @@ -118,10 +130,13 @@ def save_chkpt(model, fpath, chkpt_num, optimizer): def load_data(opt): - mod = imp.load_source('data', opt.data) data_ids = list(set().union(opt.train_ids, opt.val_ids)) + if opt.zettaset_specs: + from deepem.data.dataset import multi_zettaset as mod + else: + from deepem.data.dataset import zettaset as mod data = mod.load_data( - opt.zettaset_path, + opt.zettaset_specs if opt.zettaset_specs else opt.zettaset_path, data_ids=data_ids, **opt.data_params ) diff --git a/deepem/utils/py_utils.py b/deepem/utils/py_utils.py index 5f676e2..c7f5a37 100644 --- a/deepem/utils/py_utils.py +++ b/deepem/utils/py_utils.py @@ -4,6 +4,16 @@ from sklearn.decomposition import PCA +import importlib +import importlib.util + + +def load_module(name, path): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + def dict2tuple(d): return namedtuple('GenericDict', d.keys())(**d) diff --git a/docker/zettasets/Dockerfile b/docker/zettasets/Dockerfile index 345e12a..a431787 100644 --- a/docker/zettasets/Dockerfile +++ b/docker/zettasets/Dockerfile @@ -1,57 +1,61 @@ FROM mambaorg/micromamba AS intermediate -# multi-stage build to pull in zettasets (private repo) +# Multi-stage build to pull in zettasets (private repo) USER root ARG GIT_ACCESS_TOKEN -RUN apt-get --allow-releaseinfo-change update && \ - apt-get install -y build-essential git && \ - git clone https://${GIT_ACCESS_TOKEN}@github.com/ZettaAI/zettasets.git && \ + +RUN apt-get update --allow-releaseinfo-change && \ + apt-get install -y --no-install-recommends build-essential git && \ + git clone https://${GIT_ACCESS_TOKEN}@github.com/ZettaAI/zettasets.git /tmp/zettasets && \ rm -rf /var/lib/apt/lists/* +# Use an optimized PyTorch base image # FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime -# FROM pytorch/pytorch:1.2-cuda10.0-cudnn7-runtime -FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-runtime -# FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime -ARG DEBIAN_FRONTEND=noninteractive +FROM pytorch/pytorch:2.1.0-cuda12.1-cudnn8-runtime + +ENV DEBIAN_FRONTEND=noninteractive COPY --from=intermediate /tmp/zettasets /tmp/zettasets -RUN apt-get update \ - && apt-get install -y --no-install-recommends \ - build-essential git \ - libboost-all-dev \ - # gcloud cli (for samwise) +RUN apt-get update && \ + apt-get install -y --no-install-recommends \ + build-essential git libboost-all-dev \ curl apt-transport-https ca-certificates gnupg \ - && rm -rf /var/lib/apt/lists/* \ - # gcloud cli (for samwise) - && echo "deb https://packages.cloud.google.com/apt cloud-sdk main" \ - | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list \ - && curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \ - | apt-key add - \ - && apt-get update && apt-get install google-cloud-cli \ - # handle imgaug issue - && apt-get install ffmpeg libsm6 libxext6 -y \ - # python requirements - && conda install h5py cython matplotlib \ - scikit-image scikit-learn \ - && conda clean -t -p \ - # pypi packages - && pip install --no-cache-dir --upgrade \ - numpy cloud-volume task-queue tensorboardX imgaug wandb \ - # github packages - && pip install --no-cache-dir \ - git+https://github.com/seung-lab/DataTools \ - git+https://github.com/ZettaAI/DataProvider3 \ - git+https://github.com/ZettaAI/Augmentor \ - git+https://github.com/seung-lab/pytorch-emvision \ + ffmpeg libsm6 libxext6 && \ + # Install Google Cloud CLI + echo "deb https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \ + curl -fsSL https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key add - && \ + apt-get update && apt-get install -y google-cloud-cli && \ + rm -rf /var/lib/apt/lists/* + +# Install Python dependencies using Conda (now with faster solving) +RUN conda install -y h5py cython matplotlib scikit-image scikit-learn && \ + conda clean -a + +# Explicitly pin NumPy to <2 to avoid ABI issues +RUN pip install --no-cache-dir --upgrade \ + "numpy<2" cloud-volume task-queue imgaug wandb \ + cffi brotli fastremap imagecodecs + +RUN pip install --no-cache-dir \ + git+https://github.com/seung-lab/DataTools && \ + pip install --no-cache-dir \ + git+https://github.com/ZettaAI/DataProvider3 && \ + pip install --no-cache-dir \ + git+https://github.com/ZettaAI/Augmentor && \ + pip install --no-cache-dir \ + git+https://github.com/seung-lab/pytorch-emvision && \ + pip install --no-cache-dir \ git+https://github.com/ZettaAI/samwise -RUN cd /tmp/zettasets && \ - pip install -e . +WORKDIR /tmp/zettasets +RUN pip install --no-cache-dir -e . +# Set up working directory and environment RUN mkdir -p /workspace -ENV PYTHONPATH /DeepEM:${PYTHONPATH} +ENV PYTHONPATH=/DeepEM:${PYTHONPATH} COPY . /DeepEM/ WORKDIR /workspace + ENTRYPOINT ["python", "/DeepEM/deepem/train/run.py"] diff --git a/requirements.txt b/requirements.txt index 58d965a..0567101 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,8 +3,5 @@ igneous-pipeline scikit-image scikit-learn task-queue -tensorflow -tensorboard -tensorboardX wandb git+https://github.com/seung-lab/DataTools.git