diff --git a/.github/workflows/docbuild.yml b/.github/workflows/docbuild.yml index e153686b1..7618fb1e9 100644 --- a/.github/workflows/docbuild.yml +++ b/.github/workflows/docbuild.yml @@ -3,7 +3,7 @@ name: Documentation build on: [push, pull_request] concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true permissions: diff --git a/.github/workflows/docker_pyafq.yml b/.github/workflows/docker_pyafq.yml index aa05de72c..ada2c8c74 100644 --- a/.github/workflows/docker_pyafq.yml +++ b/.github/workflows/docker_pyafq.yml @@ -2,6 +2,10 @@ name: Build and Push pyAFQ Docker Image on: [push, pull_request] +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: build: runs-on: ubuntu-latest diff --git a/.github/workflows/docker_pyafq_cuda12.yml b/.github/workflows/docker_pyafq_cuda12.yml index cf4204433..9924cba8b 100644 --- a/.github/workflows/docker_pyafq_cuda12.yml +++ b/.github/workflows/docker_pyafq_cuda12.yml @@ -2,6 +2,10 @@ name: Build and Push pyAFQ CUDA12 Image on: [push, pull_request] +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + jobs: build: runs-on: ubuntu-latest diff --git a/.github/workflows/nightly_pft_test.yml b/.github/workflows/nightly_pft_test.yml deleted file mode 100644 index 2d52269b9..000000000 --- a/.github/workflows/nightly_pft_test.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Nightly PFT test suite - -on: - schedule: - - cron: '0 7 * * *' # every day at midnight, PST - -jobs: - build: - - runs-on: ubuntu-latest - strategy: - max-parallel: 4 - matrix: - python-version: ["3.12"] - - steps: - - name: Checkout repo - uses: actions/checkout@v1 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} 
- - name: Install - run: | - python -m pip install --upgrade pip - pip install .[dev,fury,afqbrowser,plot] - - name: Lint - run: | - flake8 --ignore N802,N806,W503 --select W504 `find . -name \*.py | grep -v setup.py | grep -v version.py | grep -v __init__.py | grep -v /docs/` - - name: Test - run: | - cd && mkdir for_test && cd for_test && pytest --pyargs AFQ --cov-report term-missing --cov=AFQ -m "nightly_pft" --durations=0 diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2cc61de94..5aee256bb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -3,7 +3,7 @@ name: Test suite on: [push, pull_request] concurrency: - group: ${{ github.workflow }}-${{ github.ref }} + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} cancel-in-progress: true jobs: @@ -39,4 +39,4 @@ jobs: flake8 --ignore N802,N806,W503 --select W504 `find . -name \*.py | grep -v setup.py | grep -v version.py | grep -v __init__.py | grep -v /docs/` - name: Test run: | - cd && mkdir for_test && cd for_test && pytest --pyargs AFQ --cov-report term-missing --cov=AFQ -m "not nightly and not nightly_basic and not nightly_custom and not nightly_anisotropic and not nightly_slr and not nightly_pft and not nightly_reco and not nightly_reco80" --durations=0 + cd && mkdir for_test && cd for_test && pytest --pyargs AFQ --cov-report term-missing --cov=AFQ -m "not nightly and not nightly_basic and not nightly_custom and not nightly_anisotropic and not nightly_slr and not nightly_reco and not nightly_reco80" --durations=0 diff --git a/AFQ/api/bundle_dict.py b/AFQ/api/bundle_dict.py index cc3b6cf78..2c4c33109 100644 --- a/AFQ/api/bundle_dict.py +++ b/AFQ/api/bundle_dict.py @@ -110,8 +110,7 @@ def default18_bd(): 'exclude': [], 'space': 'template', 'prob_map': templates['CST_L_prob_map'], - 'end': templates['CST_L_start'], - 'start': templates['CST_L_end']}, + 'end': templates['CST_L_start']}, 'Right Corticospinal': { 'cross_midline': 
False, 'include': [templates['CST_roi2_R'], @@ -119,8 +118,7 @@ def default18_bd(): 'exclude': [], 'space': 'template', 'prob_map': templates['CST_R_prob_map'], - 'end': templates['CST_R_start'], - 'start': templates['CST_R_end']}, + 'end': templates['CST_R_start']}, 'Left Inferior Fronto-occipital': { 'cross_midline': False, 'include': [templates['IFO_roi2_L'], @@ -230,6 +228,8 @@ def default18_bd(): 'exclude': [templates['SLF_roi1_L']], 'space': 'template', 'start': templates['pARC_L_start'], + 'Left Arcuate': { + 'overlap': 30}, 'primary_axis': 'I/S', 'primary_axis_percentage': 40}, 'Right Posterior Arcuate': {'cross_midline': False, @@ -237,34 +237,42 @@ def default18_bd(): 'exclude': [templates['SLF_roi1_R']], 'space': 'template', 'start': templates['pARC_R_start'], + 'Right Arcuate': { + 'overlap': 30}, 'primary_axis': 'I/S', 'primary_axis_percentage': 40}, 'Left Vertical Occipital': {'cross_midline': False, 'space': 'template', - 'start': templates['VOF_L_start'], 'end': templates['VOF_L_end'], - 'inc_addtol': [4, 0], 'Left Arcuate': { 'node_thresh': 20}, 'Left Posterior Arcuate': { - 'node_thresh': 1, - 'core': 'Anterior'}, - 'Left Inferior Longitudinal': { + 'node_thresh': 20, + 'entire_core': 'Anterior'}, + 'Left Inferior Fronto-occipital': { 'core': 'Right'}, + 'orient_mahal': { + 'distance_threshold': 3, + 'clean_rounds': 5}, + 'isolation_forest': { + 'percent_outlier_thresh': 50}, 'primary_axis': 'I/S', 'primary_axis_percentage': 40}, 'Right Vertical Occipital': {'cross_midline': False, 'space': 'template', - 'start': templates['VOF_R_start'], 'end': templates['VOF_R_end'], - 'inc_addtol': [4, 0], 'Right Arcuate': { 'node_thresh': 20}, 'Right Posterior Arcuate': { - 'node_thresh': 1, - 'core': 'Anterior'}, - 'Right Inferior Longitudinal': { + 'node_thresh': 20, + 'entire_core': 'Anterior'}, + 'Right Inferior Fronto-occipital': { 'core': 'Left'}, + 'orient_mahal': { + 'distance_threshold': 3, + 'clean_rounds': 5}, + 'isolation_forest': { + 
'percent_outlier_thresh': 50}, 'primary_axis': 'I/S', 'primary_axis_percentage': 40}}) @@ -514,6 +522,7 @@ def baby_bd(): def callosal_bd(): + templates = afd.read_templates(as_img=False) callosal_templates =\ afd.read_callosum_templates(as_img=False) return BundleDict({ @@ -522,56 +531,64 @@ def callosal_bd(): 'include': [callosal_templates['R_AntFrontal'], callosal_templates['Callosum_midsag'], callosal_templates['L_AntFrontal']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}, 'Callosum Motor': { 'cross_midline': True, 'include': [callosal_templates['R_Motor'], callosal_templates['Callosum_midsag'], callosal_templates['L_Motor']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}, 'Callosum Occipital': { 'cross_midline': True, 'include': [callosal_templates['R_Occipital'], callosal_templates['Callosum_midsag'], callosal_templates['L_Occipital']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}, 'Callosum Orbital': { 'cross_midline': True, 'include': [callosal_templates['R_Orbital'], callosal_templates['Callosum_midsag'], callosal_templates['L_Orbital']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}, 'Callosum Posterior Parietal': { 'cross_midline': True, 'include': [callosal_templates['R_PostParietal'], callosal_templates['Callosum_midsag'], callosal_templates['L_PostParietal']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}, 'Callosum Superior Frontal': { 'cross_midline': True, 'include': 
[callosal_templates['R_SupFrontal'], callosal_templates['Callosum_midsag'], callosal_templates['L_SupFrontal']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}, 'Callosum Superior Parietal': { 'cross_midline': True, 'include': [callosal_templates['R_SupParietal'], callosal_templates['Callosum_midsag'], callosal_templates['L_SupParietal']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}, 'Callosum Temporal': { 'cross_midline': True, 'include': [callosal_templates['R_Temporal'], callosal_templates['Callosum_midsag'], callosal_templates['L_Temporal']], - 'exclude': [], + 'isolation_forest': {'percent_outlier_thresh': 20}, + 'exclude': [templates['CST_roi1_L'], templates['CST_roi1_R']], 'space': 'template'}}) @@ -722,7 +739,7 @@ class BundleDict(MutableMapping): If there are bundles in bundle_info with the 'space' attribute set to 'subject', their images (all ROIs and probability maps) will be resampled to the affine and shape of this image. - If None, resamples to DWI. Be careful if you use this, + If True, resamples to DWI. Be careful if you use this, that this is the correct choice. If False, no resampling will be done. 
Default: False diff --git a/AFQ/api/group.py b/AFQ/api/group.py index 1e06259f6..de653604b 100644 --- a/AFQ/api/group.py +++ b/AFQ/api/group.py @@ -15,6 +15,7 @@ import AFQ.utils.streamlines as aus from AFQ.viz.utils import get_eye from AFQ.data.utils import aws_import_msg_error +from AFQ.definitions.utils import find_file from dipy.utils.parallel import paramap from dipy.io.stateful_tractogram import StatefulTractogram, Space @@ -60,11 +61,13 @@ def clean_pandas_df(df): class _ParticipantAFQInputs: def __init__( - self, dwi_data_file, bval_file, bvec_file, results_dir, + self, dwi_data_file, bval_file, bvec_file, + t1_file, results_dir, kwargs): self.dwi_data_file = dwi_data_file self.bval_file = bval_file self.bvec_file = bvec_file + self.t1_file = t1_file self.results_dir = results_dir self.kwargs = kwargs @@ -76,6 +79,7 @@ def __init__(self, bids_path, bids_filters={"suffix": "dwi"}, preproc_pipeline="all", + t1_pipeline=None, participant_labels=None, output_dir=None, parallel_params={"engine": "serial"}, @@ -96,6 +100,10 @@ def __init__(self, preproc_pipeline : str, optional. The name of the pipeline used to preprocess the DWI data. Default: "all". + t1_pipeline : str or None, optional + The name of the pipeline used to preprocess the T1w data. + If None, defaults to the same as preproc_pipeline. + Default: None participant_labels : list or None, optional List of participant labels (subject IDs) to perform processing on. If None, all subjects are used. 
@@ -142,6 +150,8 @@ def __init__(self, if not isinstance(bids_filters, dict): raise TypeError("bids_filters must be a dict") # preproc_pipeline typechecking handled by pyBIDS + if t1_pipeline is None: + t1_pipeline = preproc_pipeline if participant_labels is not None\ and not isinstance(participant_labels, list): raise TypeError( @@ -250,15 +260,13 @@ def __init__(self, if len(self.sessions) * len(self.subjects) < 2: self.parallel_params["engine"] = "serial" - # do not parallelize segmentation if parallelizing across + # do not parallelize within subject if parallelizing across # subject-sessions if self.parallel_params["engine"] != "serial": - if "segmentation_params" not in kwargs: - kwargs["segmentation_params"] = {} - if "parallel_segmentation" not in kwargs["segmentation_params"]: - kwargs["segmentation_params"]["parallel_segmentation"] = {} - kwargs["segmentation_params"]["parallel_segmentation"]["engine"] =\ - "serial" + if "ray_n_cpus" not in kwargs: + kwargs["ray_n_cpus"] = 1 + if "numba_n_threads" not in kwargs: + kwargs["numba_n_threads"] = 1 self.valid_sub_list = [] self.valid_ses_list = [] @@ -300,15 +308,28 @@ def __init__(self, # files. 
Maintain input ``bids_filters`` in case user wants to # specify acquisition labels, but pop suffix since it is # already specified inside ``get_bvec()`` and ``get_bval()`` - suffix = bids_filters.pop("suffix", None) + nearby_filters = {**bids_filters, "scope": preproc_pipeline} + nearby_filters.pop("suffix", None) bvec_file = bids_layout.get_bvec( dwi_data_file, - **bids_filters) + **nearby_filters) bval_file = bids_layout.get_bval( dwi_data_file, - **bids_filters) - if suffix is not None: - bids_filters["suffix"] = suffix + **nearby_filters) + nearby_filters.pop("scope", None) + nearby_filters["scope"] = t1_pipeline + t1_file = find_file( + bids_layout, dwi_data_file, + nearby_filters, + "T1w", session, subject) + + self.logger.info( + f"Using the following files for subject {subject} " + f"and session {session}:") + self.logger.info(f" DWI: {dwi_data_file}") + self.logger.info(f" BVAL: {bval_file}") + self.logger.info(f" BVEC: {bvec_file}") + self.logger.info(f" T1: {t1_file}") # Call find path for all definitions for key, value in this_kwargs.items(): @@ -386,17 +407,20 @@ def __init__(self, this_pAFQ_inputs = _ParticipantAFQInputs( dwi_data_file, bval_file, bvec_file, + t1_file, results_dir, this_kwargs) this_pAFQ = ParticipantAFQ( this_pAFQ_inputs.dwi_data_file, this_pAFQ_inputs.bval_file, this_pAFQ_inputs.bvec_file, + this_pAFQ_inputs.t1_file, this_pAFQ_inputs.results_dir, **this_pAFQ_inputs.kwargs) self.plans_dict[subject][str(session)] = this_pAFQ.plans_dict self.pAFQ_list.append(this_pAFQ) self.pAFQ_inputs_list.append(this_pAFQ_inputs) + self.kwargs = self.pAFQ_list[-1].kwargs def combine_profiles(self): tract_profiles_dict = self.export("profiles") @@ -610,8 +634,8 @@ def export_all(self, viz=True, afqbrowser=True, xforms=True, self.logger.info( f"Time taken for export all: {str(time() - start_time)}") - def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[], - suffix=""): + def cmd_outputs(self, cmd="rm", dependent_on=None, up_to=None, + 
exceptions=[], suffix=""): """ Perform some command some or all outputs of pyafq. This is useful if you change a parameter and need @@ -631,6 +655,15 @@ def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[], If "recog", perform on all derivatives that depend on the bundle recognition. Default: None + up_to : str or None + If None, will perform on all derivatives. + If "track", will perform on all derivatives up to + (but not including) tractography. + If "recog", will perform on all derivatives up to + (but not including) bundle recognition. + If "prof", will perform on all derivatives up to + (but not including) bundle profiling. + Default: None exceptions : list of str Name outputs that the command should not be applied to. Default: [] @@ -648,7 +681,12 @@ def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[], suffix="~/my_other_folder/") """ for pAFQ in self.pAFQ_list: - pAFQ.cmd_outputs(cmd, dependent_on, exceptions, suffix=suffix) + pAFQ.cmd_outputs( + cmd=cmd, + dependent_on=dependent_on, + up_to=up_to, + exceptions=exceptions, + suffix=suffix) clobber = cmd_outputs # alias for default of cmd_outputs @@ -699,6 +737,7 @@ def group_montage(self, bundle_name, size, view, direc, slice_pos=None): best_scalar = self.export("best_scalar", collapse=False)[ self.valid_sub_list[0]][self.valid_ses_list[0]] + t1_dict = self.export("t1_masked", collapse=False) viz_backend_dict = self.export("viz_backend", collapse=False) b0_backend_dict = self.export("b0", collapse=False) dwi_affine_dict = self.export("dwi_affine", collapse=False) @@ -712,6 +751,7 @@ def group_montage(self, bundle_name, size, view, direc, slice_pos=None): this_ses = self.valid_ses_list[ii] viz_backend = viz_backend_dict[this_sub][this_ses] b0 = b0_backend_dict[this_sub][this_ses] + t1 = nib.load(t1_dict[this_sub][this_ses]) dwi_affine = dwi_affine_dict[this_sub][this_ses] bundles = bundles_dict[this_sub][this_ses] best_scalar = best_scalar_dict[this_sub][this_ses] @@ -736,7 +776,7 @@ 
def group_montage(self, bundle_name, size, view, direc, slice_pos=None): slice_kwargs["z_pos"] = slice_pos figure = viz_backend.visualize_volume( - best_scalar, + t1, opacity=1.0, flip_axes=flip_axes, interact=False, @@ -747,6 +787,7 @@ def group_montage(self, bundle_name, size, view, direc, slice_pos=None): figure = viz_backend.visualize_bundles( bundles, + affine=t1.affine, shade_by_volume=best_scalar, flip_axes=flip_axes, bundle=bundle_name, @@ -1065,7 +1106,7 @@ def _submit_pydra(self, runnable): if "'NoneType' object has no attribute 'replace'" not in str(e): raise - def export(self, attr_name="help", collapse=True): + def export(self, attr_name="help"): f""" Export a specific output. To print a list of available outputs, call export without arguments. @@ -1075,9 +1116,6 @@ def export(self, attr_name="help", collapse=True): ---------- attr_name : str Name of the output to export. Default: "help" - collapse : bool - Whether to collapse session dimension if there is only 1 session. - Default: True Returns ------- @@ -1095,6 +1133,7 @@ def export_sub(pAFQ_kwargs, attr_name): pAFQ_kwargs.dwi_data_file, pAFQ_kwargs.bval_file, pAFQ_kwargs.bvec_file, + pAFQ_kwargs.t1_file, pAFQ_kwargs.results_dir, **pAFQ_kwargs.kwargs) pAFQ.export(attr_name) @@ -1144,6 +1183,7 @@ def export_sub( pAFQ_kwargs.dwi_data_file, pAFQ_kwargs.bval_file, pAFQ_kwargs.bvec_file, + pAFQ_kwargs.t1_file, pAFQ_kwargs.results_dir, **pAFQ_kwargs.kwargs) pAFQ.export_all(viz, xforms, indiv) diff --git a/AFQ/api/participant.py b/AFQ/api/participant.py index e908bd2d3..1fd8f3e6d 100644 --- a/AFQ/api/participant.py +++ b/AFQ/api/participant.py @@ -35,6 +35,7 @@ class ParticipantAFQ(object): def __init__(self, dwi_data_file, bval_file, bvec_file, + t1_file, output_dir, **kwargs): """ @@ -48,6 +49,9 @@ def __init__(self, Path to bval file. bvec_file : str Path to bvec file. + t1_file : str + Path to T1-weighted image file. Must already be registered + to the DWI data, though not resampled. 
output_dir : str Path to output directory. kwargs : additional optional parameters @@ -57,10 +61,10 @@ def __init__(self, Examples -------- api.ParticipantAFQ( - dwi_data_file, bval_file, bvec_file, output_dir, + dwi_data_file, bval_file, bvec_file, t1_file, output_dir, csd_sh_order_max=4) api.ParticipantAFQ( - dwi_data_file, bval_file, bvec_file, output_dir, + dwi_data_file, bval_file, bvec_file, t1_file, output_dir, reg_template_spec="mni_t2", reg_subject_spec="b0") Notes @@ -99,6 +103,7 @@ def __init__(self, dwi_data_file=dwi_data_file, bval_file=bval_file, bvec_file=bvec_file, + t1_file=t1_file, output_dir=output_dir, base_fname=get_base_fname(output_dir, dwi_data_file), **kwargs) @@ -261,6 +266,7 @@ def participant_montage(self, images_per_row=2): self.logger.info("Generating Montage...") viz_backend = self.export("viz_backend") best_scalar = self.export(self.export("best_scalar")) + t1 = nib.load(self.export("t1_masked")) size = (images_per_row, math.ceil(len(bundle_dict) / images_per_row)) for ii, bundle_name in enumerate(tqdm(bundle_dict)): flip_axes = [False, False, False] @@ -268,12 +274,13 @@ def participant_montage(self, images_per_row=2): flip_axes[i] = (self.export("dwi_affine")[i, i] < 0) figure = viz_backend.visualize_volume( - best_scalar, + t1, flip_axes=flip_axes, interact=False, inline=False) figure = viz_backend.visualize_bundles( self.export("bundles"), + affine=t1.affine, shade_by_volume=best_scalar, color_by_direction=True, flip_axes=flip_axes, @@ -367,8 +374,8 @@ def _save_file(curr_img): _save_file(curr_img) return all_fnames - def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[], - suffix=""): + def cmd_outputs(self, cmd="rm", dependent_on=None, up_to=None, + exceptions=[], suffix=""): """ Perform some command some or all outputs of pyafq. 
This is useful if you change a parameter and need @@ -390,6 +397,15 @@ def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[], If "prof", perform on all derivatives that depend on the bundle profiling. Default: None + up_to : str or None + If None, will perform on all derivatives. + If "track", will perform on all derivatives up to + (but not including) tractography. + If "recog", will perform on all derivatives up to + (but not including) bundle recognition. + If "prof", will perform on all derivatives up to + (but not including) bundle profiling. + Default: None exceptions : list of str Name outputs that the command should not be applied to. Default: [] @@ -413,7 +429,8 @@ def cmd_outputs(self, cmd="rm", dependent_on=None, exceptions=[], cmd=cmd, exception_file_names=exception_file_names, suffix=suffix, - dependent_on=dependent_on + dependent_on=dependent_on, + up_to=up_to, ) # do not assume previous calculations are still valid diff --git a/AFQ/api/utils.py b/AFQ/api/utils.py index 56258264e..54e0a9b5c 100644 --- a/AFQ/api/utils.py +++ b/AFQ/api/utils.py @@ -31,15 +31,24 @@ "output_dir": "Path to output directory", "best_scalar": "Go-to scalar for visualizations", "base_fname": "Base file name for outputs", + "pve_wm": "White matter partial volume estimate map", + "pve_gm": "Gray matter partial volume estimate map", + "pve_csf": "Cerebrospinal fluid partial volume estimate map", } + methods_sections = { "dwi_data_file": "data", "bval_file": "data", "bvec_file": "data", + "t1_file": "data", "output_dir": "data", "best_scalar": "tractography", "base_fname": "data", + "pve_wm": "tractography", + "pve_gm": "tractography", + "pve_csf": "tractography", } + kwargs_descriptors = {} for task_module in task_modules: kwargs_descriptors[task_module] = {} @@ -142,16 +151,20 @@ def export_all_helper(api_afq_object, xforms, indiv, viz): api_afq_object.export("median_bundle_lengths") api_afq_object.export("profiles") api_afq_object.export("seed_thresh") - 
api_afq_object.export("stop_thresh") + stop_threshold = api_afq_object.kwargs.get( + "tracking_params", {}).get( + "stop_threshold", None) + if not isinstance(stop_threshold, str): + api_afq_object.export("stop_thresh") if viz: try: + import pingouin + import seaborn + import IPython + except (ImportError, ModuleNotFoundError): + api_afq_object.logger.warning(viz_import_msg_error("plot")) + else: api_afq_object.export("tract_profile_plots") - except ImportError as e: - plot_err_message = viz_import_msg_error("plot") - if str(e) != plot_err_message: - raise - else: - api_afq_object.logger.warning(plot_err_message) api_afq_object.export("all_bundles_figure") api_afq_object.export("indiv_bundles_figures") diff --git a/AFQ/data/fetch.py b/AFQ/data/fetch.py index 9007e5468..ee46418dc 100644 --- a/AFQ/data/fetch.py +++ b/AFQ/data/fetch.py @@ -1165,6 +1165,10 @@ def organize_stanford_data(path=None, clear_previous_afq=None): if not op.exists(derivatives_path): logger.info(f'creating derivatives directory: {derivatives_path}') + os.makedirs(dmriprep_folder, exist_ok=True) + os.makedirs(freesurfer_folder, exist_ok=True) + os.makedirs(afq_folder, exist_ok=True) + # anatomical data anat_folder = op.join(freesurfer_folder, 'sub-01', 'ses-01', 'anat') os.makedirs(anat_folder, exist_ok=True) diff --git a/AFQ/definitions/image.py b/AFQ/definitions/image.py index 06ffd908d..12b275e59 100644 --- a/AFQ/definitions/image.py +++ b/AFQ/definitions/image.py @@ -3,17 +3,17 @@ import nibabel as nib -from dipy.segment.mask import median_otsu from dipy.align import resample from AFQ.definitions.utils import Definition, find_file, name_from_path from skimage.morphology import convex_hull_image, binary_opening + __all__ = [ - "ImageFile", "FullImage", "RoiImage", "B0Image", "LabelledImageFile", + "ImageFile", "FullImage", "RoiImage", "LabelledImageFile", "ThresholdedImageFile", "ScalarImage", "ThresholdedScalarImage", - "TemplateImage", "GQImage"] + "TemplateImage", "ThreeTissueImage"] logger 
= logging.getLogger('AFQ') @@ -268,12 +268,14 @@ def __init__(self, use_waypoints=True, use_presegment=False, use_endpoints=False, + only_wmgmi=False, tissue_property=None, tissue_property_n_voxel=None, tissue_property_threshold=None): self.use_waypoints = use_waypoints self.use_presegment = use_presegment self.use_endpoints = use_endpoints + self.only_wmgmi = only_wmgmi self.tissue_property = tissue_property self.tissue_property_n_voxel = tissue_property_n_voxel self.tissue_property_threshold = tissue_property_threshold @@ -335,6 +337,29 @@ def _image_getter_helper(mapping, raise ValueError(( "BundleDict does not have enough ROIs to generate " f"an ROI Image: {bundle_dict._dict}")) + + if self.only_wmgmi: + wmgmi = nib.load( + data_imap["wm_gm_interface"]).get_fdata() + if not np.allclose(wmgmi.shape, image_data.shape): + logger.error("WM/GM Interface shape: %s", wmgmi.shape) + logger.error("ROI image shape: %s", image_data.shape) + raise ValueError(( + "wm_gm_interface and ROI image do not have the " + "same shape, cannot apply wm_gm_interface." + "If ROI image shape is different from DWI shape, " + "consider if you need to map your ROIs to DWI space. " + "If only resampling is required, " + "set resample_subject_to " + "to True in your BundleDict instantiation.")) + + image_data = np.logical_and( + image_data, wmgmi) + if np.sum(image_data) == 0: + raise ValueError(( + "BundleDict does not have enough ROIs to generate " + "an ROI Image with WM/GM interface applied.")) + return nib.Nifti1Image( image_data.astype(np.float32), data_imap["dwi_affine"]), dict(source="ROIs") @@ -360,90 +385,6 @@ def image_getter( return image_getter -class GQImage(ImageDefinition): - """ - Threshold the anisotropic diffusion component of the - Generalized Q-Sampling Model to generate a brain mask - which will include the eyes, optic nerve, and cerebrum - but will exclude most or all of the skull. 
- - Examples - -------- - api.GroupAFQ(brain_mask_definition=GQImage()) - """ - - def __init__(self): - pass - - def get_name(self): - return "GQ" - - def get_image_getter(self, task_name): - def image_getter_helper(gq_aso): - gq_aso_img = nib.load(gq_aso) - gq_aso_data = gq_aso_img.get_fdata() - ASO_mask = convex_hull_image( - binary_opening( - gq_aso_data > 0.1)) - - return nib.Nifti1Image( - ASO_mask.astype(np.float32), - gq_aso_img.affine), dict( - source=gq_aso, - technique="GQ ASO thresholded maps") - - if task_name == "data": - return image_getter_helper - else: - return lambda data_imap: image_getter_helper( - data_imap["gq_aso"]) - - -class B0Image(ImageDefinition): - """ - Define an image using b0 and dipy's median_otsu. - - Parameters - ---------- - median_otsu_kwargs: dict, optional - Optional arguments to pass into dipy's median_otsu. - Default: {} - - Examples - -------- - brain_image_definition = B0Image() - api.GroupAFQ(brain_image_definition=brain_image_definition) - """ - - def __init__(self, median_otsu_kwargs={}): - self.median_otsu_kwargs = median_otsu_kwargs - - def get_name(self): - return "b0" - - def get_image_getter(self, task_name): - def image_getter_helper(b0): - mean_b0_img = nib.load(b0) - mean_b0 = mean_b0_img.get_fdata() - logger.warning(( - "It is recommended that you provide a brain mask. " - "It is provided with the brain_mask_definition argument. " - "Otherwise, the default brain mask is calculated " - "by using OTSU on the median-filtered B0 image. " - "This can be unreliable. 
")) - _, image_data = median_otsu(mean_b0, **self.median_otsu_kwargs) - return nib.Nifti1Image( - image_data.astype(np.float32), - mean_b0_img.affine), dict( - source=b0, - technique="median_otsu applied to b0", - median_otsu_kwargs=self.median_otsu_kwargs) - if task_name == "data": - return image_getter_helper - else: - return lambda data_imap: image_getter_helper(data_imap["b0"]) - - class LabelledImageFile(ImageFile, CombineImageMixin): """ Define an image based on labels in a file. @@ -681,7 +622,8 @@ def __init__(self, scalar, lower_bound=None, upper_bound=None, class PFTImage(ImageDefinition): """ Define an image for use in PFT tractography. Only use - if tracker set to 'pft' in tractography. + if tracker set to 'pft' in tractography. Used to provide + custom segmentations. Parameters ---------- @@ -700,7 +642,7 @@ class PFTImage(ImageDefinition): afm.ImageFile(suffix="CSFprobseg")) api.GroupAFQ(tracking_params={ "stop_image": stop_image, - "stop_threshold": "CMC", + "stop_threshold": "ACT", "tracker": "pft"}) """ @@ -727,6 +669,53 @@ def get_image_getter(self, task_name): for probseg in self.probsegs] +class ThreeTissueImage(ImageDefinition): + """ + Define a three tissue image for use in PFT tractography. Only use + if tracker set to 'pft' in tractography. Use to generate + WM/GM/CSF probsegs from T1w. 
+ + Examples + -------- + api.GroupAFQ(tracking_params={ + "stop_image": ThreeTissueImage(), + "stop_threshold": "ACT", + "tracker": "pft"}) + """ + + def __init__(self): + pass + + def get_name(self): + return "ThreeT" + + def get_image_getter(self, task_name): + if task_name == "data": + raise ValueError(( + "ThreeTissueImage cannot be used in this context, as they" + "require later derivatives to be calculated")) + + def csf_getter(data_imap): + PVE = nib.load(data_imap["t1w_pve"]) + return nib.Nifti1Image( + PVE.get_fdata()[..., 0].astype(np.float32), + PVE.affine), dict(source=data_imap["t1w_pve"]) + + def gm_getter(data_imap): + PVE = nib.load(data_imap["t1w_pve"]) + return nib.Nifti1Image( + PVE.get_fdata()[..., 1].astype(np.float32), + PVE.affine), dict(source=data_imap["t1w_pve"]) + + def wm_getter(data_imap): + PVE = nib.load(data_imap["t1w_pve"]) + return nib.Nifti1Image( + PVE.get_fdata()[..., 2].astype(np.float32), + PVE.affine), dict(source=data_imap["t1w_pve"]) + + return [wm_getter, gm_getter, csf_getter] + + class TemplateImage(ImageDefinition): """ Define a scalar based on a template. diff --git a/AFQ/definitions/utils.py b/AFQ/definitions/utils.py index cadb80524..e69f7aaa4 100644 --- a/AFQ/definitions/utils.py +++ b/AFQ/definitions/utils.py @@ -96,6 +96,8 @@ def find_file(bids_layout, path, filters, suffix, session, subject, Helper function Generic calls to get_nearest to find a file """ + filters = filters.copy() + if "extension" not in filters: filters["extension"] = extension if "suffix" not in filters: diff --git a/AFQ/models/asym_filtering.py b/AFQ/models/asym_filtering.py new file mode 100644 index 000000000..e488869ab --- /dev/null +++ b/AFQ/models/asym_filtering.py @@ -0,0 +1,521 @@ +# -*- coding: utf-8 -*- +# Original source: github.com/scilus/scilpy +# Copyright (c) 2012-- Sherbrooke Connectivity Imaging Lab [SCIL], Université de Sherbrooke. +# Licensed under the MIT License (https://opensource.org/licenses/MIT). 
+# Modified by John Kruper for pyAFQ +# OpenCL and cosine filtering removed +# Replaced with numba + +import numpy as np +from tqdm import tqdm + +from numba import njit, prange, set_num_threads, config + +from dipy.reconst.shm import sh_to_sf_matrix, sph_harm_ind_list, sh_to_sf +from dipy.direction import peak_directions +from dipy.data import get_sphere + + +__all__ = ["unified_filtering", "compute_asymmetry_index", + "compute_odd_power_map", "compute_nufid_asym"] + + +def _get_sh_order_and_fullness(ncoeffs): + """ + Get the order of the SH basis from the number of SH coefficients + as well as a boolean indicating if the basis is full. + """ + # the two curves (sym and full) intersect at ncoeffs = 1, in what + # case both bases correspond to order 1. + sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * ncoeffs)) / 2.0 + if sym_order.is_integer(): + return sym_order, False + full_order = np.sqrt(ncoeffs) - 1.0 + if full_order.is_integer(): + return full_order, True + raise ValueError('Invalid number of coefficients for SH basis.') + + +def unified_filtering(sh_data, sphere, + sh_basis='descoteaux07', is_legacy=False, + sigma_spatial=1.0, sigma_align=0.8, + sigma_angle=None, rel_sigma_range=0.2, + n_threads=None): + """ + Unified asymmetric filtering as described in [1]. + + Parameters + ---------- + sh_data: ndarray + SH coefficients image. + sphere: str or DIPY sphere + Name of the DIPY sphere to use for SH to SF projection. + sh_basis: str + SH basis definition used for input and output SH image. + One of 'descoteaux07' or 'tournier07'. + Default: 'descoteaux07'. + is_legacy: bool + Whether the legacy SH basis definition should be used. + Default: False. + sigma_spatial: float or None + Standard deviation of spatial filter. Can be None to replace + by mean filter, in what case win_hwidth must be given. + sigma_align: float or None + Standard deviation of alignment filter. `None` disables + alignment filtering. 
+ sigma_angle: float or None
+ Standard deviation of the angle filter. `None` disables
+ angle filtering.
+ rel_sigma_range: float or None
+ Standard deviation of the range filter, relative to the
+ range of SF amplitudes. `None` disables range filtering.
+ n_threads: int or None
+ Number of threads to use for numba. If None, uses
+ the number of available threads.
+ Default: None.
+
+ References
+ ----------
+ [1] Poirier and Descoteaux, 2024, "A Unified Filtering Method for
+ Estimating Asymmetric Orientation Distribution Functions",
+ Neuroimage, https://doi.org/10.1016/j.neuroimage.2024.120516
+ """
+ if isinstance(sphere, str):
+ sphere = get_sphere(name=sphere)
+
+ if sigma_spatial is not None:
+ if sigma_spatial <= 0.0:
+ raise ValueError('sigma_spatial cannot be <= 0.')
+ if sigma_align is not None:
+ if sigma_align <= 0.0:
+ raise ValueError('sigma_align cannot be <= 0.')
+ if sigma_angle is not None:
+ if sigma_angle <= 0.0:
+ raise ValueError('sigma_angle cannot be <= 0.')
+
+ if n_threads is not None:
+ set_num_threads(n_threads)
+
+ sh_order, full_basis = _get_sh_order_and_fullness(sh_data.shape[-1])
+
+ # build filters
+ uv_filter = _unified_filter_build_uv(sigma_angle,
+ sphere.vertices.astype(np.float64))
+ nx_filter = _unified_filter_build_nx(sphere.vertices.astype(np.float64),
+ sigma_spatial, sigma_align,
+ False, False)
+ B = sh_to_sf_matrix(sphere, sh_order, sh_basis, full_basis,
+ legacy=is_legacy, return_inv=False)
+ _, B_inv = sh_to_sf_matrix(sphere, sh_order, sh_basis, True,
+ legacy=is_legacy, return_inv=True)
+
+ # compute "real" sigma_range scaled by sf amplitudes
+ # if rel_sigma_range is supplied
+ sigma_range = None
+ if rel_sigma_range is not None:
+ if rel_sigma_range <= 0.0:
+ raise ValueError('sigma_range cannot be <= 0.')
+ sigma_range = rel_sigma_range * _get_sf_range(sh_data, B)
+
+ return _unified_filter_call_python(
+ sh_data, nx_filter, uv_filter,
+ sigma_range, B, B_inv, sphere)
+
+
+@njit(fastmath=True, cache=True)
+def 
_unified_filter_build_uv(sigma_angle, directions): + """ + Build the angle filter, weighted on angle between current direction u + and neighbour direction v. + + Parameters + ---------- + sigma_angle: float + Standard deviation of filter. Values at distances greater than + sigma_angle are clipped to 0 to reduce computation time. + directions: DIPY sphere directions. + Vertices from DIPY sphere for sampling the SF. + + Returns + ------- + weights: ndarray + Angle filter of shape (N_dirs, N_dirs). + """ + if sigma_angle is not None: + dot = directions.dot(directions.T) + x = np.arccos(np.clip(dot, -1.0, 1.0)) + weights = _evaluate_gaussian_distribution(x, sigma_angle) + mask = x > (3.0 * sigma_angle) + weights[mask] = 0.0 + weights /= np.sum(weights, axis=-1) + else: + weights = np.eye(len(directions)) + return weights + + +@njit(fastmath=True, cache=True) +def _unified_filter_build_nx(directions, sigma_spatial, sigma_align, + disable_spatial, + disable_align, j_invariance=False): + """ + Original source: github.com/CHrlS98/aodf-toolkit + Copyright (c) 2023 Charles Poirier + Licensed under the MIT License (https://opensource.org/licenses/MIT). 
+ """ + directions = np.ascontiguousarray(directions.astype(np.float32)) + + half_width = int(round(3 * sigma_spatial)) + nx_weights = np.zeros((2 * half_width + 1, 2 * half_width + 1, + 2 * half_width + 1, len(directions)), + dtype=np.float32) + + for i in range(-half_width, half_width + 1): + for j in range(-half_width, half_width + 1): + for k in range(-half_width, half_width + 1): + dxy = np.array([[i, j, k]], dtype=np.float32) + len_xy = np.sqrt(dxy[0, 0]**2 + dxy[0, 1]**2 + dxy[0, 2]**2) + + if disable_spatial: + w_spatial = 1.0 + else: + # the length controls spatial weight + w_spatial = np.exp(-len_xy**2 / (2 * sigma_spatial**2)) + + # the direction controls the align weight + if i == j == k == 0 or disable_align: + # hack for main direction to have maximal weight + # w_align = np.ones((1, len(directions)), dtype=np.float32) + w_align = np.zeros((1, len(directions)), dtype=np.float32) + else: + dxy /= len_xy + w_align = np.arccos(np.clip(np.dot(dxy, directions.T), + -1.0, 1.0)) # 1, N + w_align = np.exp(-w_align**2 / (2 * sigma_align**2)) + + nx_weights[half_width + i, half_width + j, half_width + k] =\ + w_align * w_spatial + + if j_invariance: + # A filter is j-invariant if its prediction does not + # depend on the content of the current voxel + nx_weights[half_width, half_width, half_width, :] = 0.0 + + for ui in range(len(directions)): + w_sum = np.sum(nx_weights[..., ui]) + nx_weights /= w_sum + + return nx_weights + + +def _get_sf_range(sh_data, B_mat): + """ + Get the range of SF amplitudes for input `sh_data`. + + Parameters + ---------- + sh_data: ndarray + Spherical harmonics coefficients image. + B_mat: ndarray + SH to SF projection matrix. + + Returns + ------- + sf_range: float + Range of SF amplitudes. 
+ """ + sf = np.array([np.dot(i, B_mat) for i in sh_data], + dtype=sh_data.dtype) + sf[sf < 0.0] = 0.0 + sf_max = np.max(sf) + sf_min = np.min(sf) + return sf_max - sf_min + + +def _unified_filter_call_python(sh_data, nx_filter, uv_filter, sigma_range, + B_mat, B_inv, sphere): + """ + Run filtering using pure python implementation. + + Parameters + ---------- + sh_data: ndarray + Input SH data. + nx_filter: ndarray + Combined spatial and alignment filter. + uv_filter: ndarray + Angle filter. + sigma_range: float or None + Standard deviation of range filter. None disables range filtering. + B_mat: ndarray + SH to SF projection matrix. + B_inv: ndarray + SF to SH projection matrix. + sphere: DIPY sphere + Sphere for SH to SF projection. + + Returns + ------- + out_sh: ndarray + Filtered output as SH coefficients. + """ + nb_sf = len(sphere.vertices) + mean_sf = np.zeros(sh_data.shape[:-1] + (nb_sf,)) + sh_data = np.ascontiguousarray(sh_data, dtype=np.float64) + B_mat = np.ascontiguousarray(B_mat, dtype=np.float64) + + config.THREADING_LAYER = "workqueue" + + h_w, h_h, h_d = nx_filter.shape[:3] + half_w, half_h, half_d = h_w // 2, h_h // 2, h_d // 2 + sh_data_padded = np.ascontiguousarray(np.pad( + sh_data, + ((half_w, half_w), (half_h, half_h), (half_d, half_d), (0, 0)), + mode='constant' + ), dtype=np.float64) + + for u_sph_id in tqdm(range(nb_sf)): + mean_sf[..., u_sph_id] = _correlate(sh_data, sh_data_padded, + nx_filter, uv_filter, + sigma_range, u_sph_id, B_mat) + + out_sh = np.array([np.dot(i, B_inv) for i in mean_sf], + dtype=sh_data.dtype) + return out_sh + + +@njit(fastmath=True, parallel=True) +def _correlate(sh_data, sh_data_padded, nx_filter, uv_filter, + sigma_range, u_index, B_mat): + """ + Apply the filters to the SH image for the sphere direction + described by `u_index`. + + Parameters + ---------- + sh_data: ndarray + Input SH coefficients. + sh_data: ndarray + Input SH coefficients, pre-padded. 
+ nx_filter: ndarray + Combined spatial and alignment filter. + uv_filter: ndarray + Angle filter. + sigma_range: float or None + Standard deviation of range filter. None disables range filtering. + u_index: int + Index of the current sphere direction to process. + B_mat: ndarray + SH to SF projection matrix. + + Returns + ------- + out_sf: ndarray + Output SF amplitudes along the direction described by `u_index`. + """ + v_indices = np.flatnonzero(uv_filter[u_index]) + nx_filter = nx_filter[..., u_index] + h_w, h_h, h_d = nx_filter.shape[:3] + half_w, half_h, half_d = h_w // 2, h_h // 2, h_d // 2 + out_sf = np.zeros(sh_data.shape[:3]) + + # sf_u = np.dot(sh_data, B_mat[:, u_index]) + # sf_v = np.dot(sh_data, B_mat[:, v_indices]) + sf_u = np.zeros(sh_data_padded.shape[:3]) + sf_v = np.zeros(sh_data_padded.shape[:3] + (len(v_indices),)) + for i in prange(sh_data_padded.shape[0]): + for j in range(sh_data_padded.shape[1]): + for k in range(sh_data_padded.shape[2]): + for c in range(sh_data_padded.shape[3]): + sf_u[i, j, k] += sh_data_padded[i, + j, k, c] * B_mat[c, u_index] + for vi in range(len(v_indices)): + sf_v[i, j, k, vi] += sh_data_padded[i, + j, k, c] * B_mat[c, v_indices[vi]] + + uv_filter = uv_filter[u_index, v_indices] + + for ii in prange(out_sf.shape[0]): + for jj in range(out_sf.shape[1]): + for kk in range(out_sf.shape[2]): + a = sf_v[ii:ii + h_w, jj:jj + h_h, kk:kk + h_d] + b = sf_u[ii + half_w, jj + half_h, kk + half_d] + x_range = a - b + + if sigma_range is None: + range_filter = np.ones_like(x_range) + else: + range_filter = _evaluate_gaussian_distribution( + x_range, sigma_range) + + # the resulting filter for the current voxel and v_index + res_filter = range_filter * nx_filter[..., None] + res_filter =\ + res_filter * np.reshape(uv_filter, + (1, 1, 1, len(uv_filter))) + out_sf[ii, jj, kk] = np.sum( + sf_v[ii:ii + h_w, jj:jj + h_h, kk:kk + h_d] * res_filter) + out_sf[ii, jj, kk] /= np.sum(res_filter) + + return out_sf + + +@njit(fastmath=True, 
cache=True) +def _evaluate_gaussian_distribution(x, sigma): + """ + 1-dimensional 0-centered Gaussian distribution + with standard deviation sigma. + + Parameters + ---------- + x: ndarray or float + Points where the distribution is evaluated. + sigma: float + Standard deviation. + + Returns + ------- + out: ndarray or float + Values at x. + """ + if sigma <= 0.0: + raise ValueError("Sigma must be greater than 0.") + cnorm = 1.0 / sigma / np.sqrt(2.0 * np.pi) + return cnorm * np.exp(-x**2 / 2.0 / sigma**2) + + +def compute_asymmetry_index(sh_coeffs, mask): + """ + Compute asymmetry index (ASI) [1] from + asymmetric ODF volume expressed in full SH basis. + + Parameters + ---------- + sh_coeffs: ndarray (x, y, z, ncoeffs) + Input spherical harmonics coefficients. + mask: ndarray (x, y, z), bool + Mask inside which ASI should be computed. + + Returns + ------- + asi_map: ndarray (x, y, z) + Asymmetry index map. + + References + ---------- + [1] S. Cetin Karayumak, E. Özarslan, and G. Unal, + "Asymmetric Orientation Distribution Functions (AODFs) + revealing intravoxel geometry in diffusion MRI" + Magnetic Resonance Imaging, vol. 49, pp. 145-158, Jun. 2018, + doi: https://doi.org/10.1016/j.mri.2018.03.006. + """ + order, full_basis = _get_sh_order_and_fullness(sh_coeffs.shape[-1]) + + _, l_list = sph_harm_ind_list(order, full_basis=full_basis) + + sign = np.power(-1.0, l_list) + sign = np.reshape(sign, (1, 1, 1, len(l_list))) + sh_squared = sh_coeffs**2 + mask = np.logical_and(sh_squared.sum(axis=-1) > 0., mask) + + asi_map = np.zeros(sh_coeffs.shape[:-1]) + asi_map[mask] = np.sum(sh_squared * sign, axis=-1)[mask] / \ + np.sum(sh_squared, axis=-1)[mask] + + # Negatives should not happen (amplitudes always positive) + asi_map = np.clip(asi_map, 0.0, 1.0) + asi_map = np.sqrt(1 - asi_map**2) * mask + + return asi_map + + +def compute_odd_power_map(sh_coeffs, mask): + """ + Compute odd-power map [1] from + asymmetric ODF volume expressed in full SH basis. 
+ + Parameters + ---------- + sh_coeffs: ndarray (x, y, z, ncoeffs) + Input spherical harmonics coefficients. + mask: ndarray (x, y, z), bool + Mask inside which odd-power map should be computed. + + Returns + ------- + odd_power_map: ndarray (x, y, z) + Odd-power map. + + References + ---------- + [1] C. Poirier, E. St-Onge, and M. Descoteaux, + "Investigating the Occurence of Asymmetric Patterns in + White Matter Fiber Orientation Distribution Functions" + [Abstract], In: Proc. Intl. Soc. Mag. Reson. Med. 29 (2021), + 2021 May 15-20, Vancouver, BC, Abstract number 0865. + """ + order, full_basis = _get_sh_order_and_fullness(sh_coeffs.shape[-1]) + _, l_list = sph_harm_ind_list(order, full_basis=full_basis) + odd_l_list = (l_list % 2 == 1).reshape((1, 1, 1, -1)) + + odd_order_norm = np.linalg.norm(sh_coeffs * odd_l_list, + ord=2, axis=-1) + + full_order_norm = np.linalg.norm(sh_coeffs, ord=2, axis=-1) + + asym_map = np.zeros(sh_coeffs.shape[:-1]) + mask = np.logical_and(full_order_norm > 0, mask) + asym_map[mask] = odd_order_norm[mask] / full_order_norm[mask] + + return asym_map + + +def compute_nufid_asym(sh_coeffs, sphere, csf, mask): + """ + Number of fiber directions (nufid) map [1]. + + Parameters + ---------- + sh_coeffs: ndarray (x, y, z, ncoeffs) + Input spherical harmonics coefficients. + + sphere: DIPY sphere + Sphere for SH to SF projection. + + csf: ndarray (x, y, z) + CSF probability map, used to guess the absolute threshold. + + mask: ndarray (x, y, z), bool + Mask inside which ASI should be computed. + + References + ---------- + [1] C. Poirier and M. Descoteaux, + "Filtering Methods for Asymmetric ODFs: + Where and How Asymmetry Occurs in the White Matter." + bioRxiv. 2022 Jan 1; 2022.12.18.520881. 
+ doi: https://doi.org/10.1101/2022.12.18.520881 + """ + sh_order, full_basis = _get_sh_order_and_fullness(sh_coeffs.shape[-1]) + odf = sh_to_sf( + sh_coeffs, sphere, + sh_order_max=sh_order, + basis_type='descoteaux07', + full_basis=full_basis, + legacy=False) + + # Guess at threshold from 2.0 * mean of ODF maxes in CSF + absolute_threshold = 2.0 * np.mean(np.max(odf[csf > 0.99], axis=-1)) + odf[odf < absolute_threshold] = 0. + + nufid_data = np.zeros(sh_coeffs.shape[:-1], dtype=np.float32) + for ii in tqdm(range(sh_coeffs.shape[0])): + for jj in range(sh_coeffs.shape[1]): + for kk in range(sh_coeffs.shape[2]): + if mask[ii, jj, kk]: + _, peaks, _ = peak_directions( + odf[ii, jj, kk], sphere, + is_symmetric=False) + + nufid_data[ii, jj, kk] = np.count_nonzero(peaks) + + return nufid_data diff --git a/AFQ/models/dki.py b/AFQ/models/dki.py index 06cb702cd..7049962f5 100644 --- a/AFQ/models/dki.py +++ b/AFQ/models/dki.py @@ -4,6 +4,10 @@ import numpy as np import nibabel as nib +from scipy.signal import find_peaks, peak_widths +from scipy.special import erf +from scipy.ndimage import gaussian_filter1d + from dipy.reconst import dki from dipy.reconst import dki_micro from dipy.core.ndindex import ndindex diff --git a/AFQ/models/msmt.py b/AFQ/models/msmt.py new file mode 100644 index 000000000..5815b6e0a --- /dev/null +++ b/AFQ/models/msmt.py @@ -0,0 +1,120 @@ +import multiprocessing +import numpy as np +from tqdm import tqdm +import ray + +from scipy.sparse import csr_matrix + +import osqp + +from dipy.reconst.mcsd import MSDeconvFit +from dipy.reconst.mcsd import MultiShellDeconvModel + +from AFQ.utils.stats import chunk_indices + + +__all__ = ["fit"] + + +def _fit(self, data, mask=None, n_cpus=None): + """ + Use OSQP to fit the multi-shell spherical deconvolution model. 
+ """ + if n_cpus is None: + n_cpus = max(multiprocessing.cpu_count() - 1, 1) + + og_data_shape = data.shape + if len(data.shape) < 4: + data = data.reshape((1,) * (4 - data.ndim) + data.shape) + + m, n = self.fitter._reg.shape + coeff = np.zeros((*data.shape[:3], n), dtype=np.float64) + if mask is None: + mask = np.ones(data.shape[:3], dtype=bool) + + R = np.ascontiguousarray(self.fitter._X, dtype=np.float64) + A = np.ascontiguousarray(self.fitter._reg, dtype=np.float64) + b = np.zeros(A.shape[0], dtype=np.float64) + + # Normalize constraints + for i in range(A.shape[0]): + A[i] /= np.linalg.norm(A[i]) + + A_outer = np.empty((n, n, m), dtype=np.float64) + for k in range(m): + for i in range(n): + for j in range(n): + A_outer[i, j, k] = A[k, i] * A[k, j] + + Q = R.T @ R + if n_cpus > 1: + ray.init(ignore_reinit_error=True) + + data_id = ray.put(data) + mask_id = ray.put(mask) + Q_id = ray.put(Q) + A_id = ray.put(A) + b_id = ray.put(b) + R_id = ray.put(R) + + @ray.remote( + num_cpus=n_cpus) + def process_batch_remote(batch_indices, data, mask, + Q, A, b, R): + from scipy.sparse import csr_matrix + import osqp + import numpy as np + + m = osqp.OSQP() + m.setup( + P=csr_matrix(Q), A=csr_matrix(A), l=b, + u=None, q=None, + verbose=False) + return_values = np.zeros( + (len(batch_indices),) + data.shape[1:3] + (A.shape[1],), + dtype=np.float64) + for i, ii in enumerate(batch_indices): + for jj in range(data.shape[1]): + for kk in range(data.shape[2]): + if mask[ii, jj, kk]: + c = np.dot(-R.T, data[ii, jj, kk]) + m.update(q=c) + results = m.solve() + return_values[i, jj, kk] = results.x + return return_values + + # Launch tasks in chunks + all_indices = list(range(data.shape[0])) + indices_chunked = list(chunk_indices(all_indices, n_cpus * 2)) + futures = [ + process_batch_remote.remote(batch, data_id, mask_id, + Q_id, A_id, + b_id, R_id) + for batch in indices_chunked + ] + + # Collect and assign results + for batch, future in zip( + indices_chunked, tqdm(futures)): + 
results = ray.get(future) + for i, ii in enumerate(batch): + coeff[ii] = results[i] + else: + m = osqp.OSQP() + m.setup( + P=csr_matrix(Q), A=csr_matrix(A), l=b, + u=None, q=None, + verbose=False, adaptive_rho=True) + for ii in tqdm(range(data.shape[0])): + for jj in range(data.shape[1]): + for kk in range(data.shape[2]): + if mask[ii, jj, kk]: + c = np.dot(-R.T, data[ii, jj, kk]) + m.update(q=c) + results = m.solve() + coeff[ii, jj, kk] = results.x + coeff = coeff.reshape(og_data_shape[:-1] + (n,)) + return MSDeconvFit(self, coeff, None) + + +MultiShellDeconvModel.fit = _fit diff --git a/AFQ/models/wmgm_interface.py b/AFQ/models/wmgm_interface.py new file mode 100644 index 000000000..b11216ca9 --- /dev/null +++ b/AFQ/models/wmgm_interface.py @@ -0,0 +1,96 @@ +import numpy as np +import nibabel as nib + +from dipy.align import resample + +from scipy.ndimage import gaussian_filter +from skimage.segmentation import find_boundaries + + +def fit_wm_gm_interface(PVE_img, dwiref_img): + """ + Compute the white matter/gray matter interface from a PVE image. + + Parameters + ---------- + PVE_img : Nifti1Image + PVE image containing CSF, GM, and WM segmentations from T1 + dwiref_img : Nifti1Image + Reference image to find boundary in that space. 
+ """ + PVE = PVE_img.get_fdata() + + csf = PVE[..., 0] + gm = PVE[..., 1] + wm = PVE[..., 2] + + # Put in diffusion space + wm = resample( + wm, + dwiref_img.get_fdata(), + moving_affine=PVE_img.affine, + static_affine=dwiref_img.affine).get_fdata() + gm = resample( + gm, + dwiref_img.get_fdata(), + moving_affine=PVE_img.affine, + static_affine=dwiref_img.affine).get_fdata() + csf = resample( + csf, + dwiref_img.get_fdata(), + moving_affine=PVE_img.affine, + static_affine=dwiref_img.affine).get_fdata() + + wm_boundary = find_boundaries(wm, mode='inner') + gm_smoothed = gaussian_filter(gm, 1) + csf_smoothed = gaussian_filter(csf, 1) + + wm_boundary[~gm_smoothed.astype(bool)] = 0 + wm_boundary[csf_smoothed > gm_smoothed] = 0 + + return nib.Nifti1Image( + wm_boundary.astype(np.float32), dwiref_img.affine) + + +def pve_from_subcortex(t1_subcortex_data): + """ + Compute the PVE (Partial Volume Estimation) from the subcortex T1 image. + + Parameters + ---------- + t1_subcortex_data : ndarray + T1 subcortex data from brainchop + + Returns + ------- + pve_img : ndarray + PVE data with CSF, GM, and WM segmentations. 
+ """ + CSF_labels = [0, 3, 4, 11, 12] + GM_labels = [2, 6, 7, 8, 9, 10, 14, 15, 16] + WM_labels = [1, 5] + mixed_labels = [13, 17] + + PVE = np.zeros(t1_subcortex_data.shape + (3,), dtype=np.float32) + + PVE[np.isin(t1_subcortex_data, CSF_labels), 0] = 1.0 + PVE[np.isin(t1_subcortex_data, GM_labels), 1] = 1.0 + PVE[np.isin(t1_subcortex_data, WM_labels), 2] = 1.0 + + # For mixed labels, we assume they are WM interior, GM exterior + # This is a simplification, basically so they do not cause problems + # with ACT + wm_fuzzed = gaussian_filter(PVE[..., 2], 1) + nwm_fuzzed = gaussian_filter(PVE[..., 0] + PVE[..., 1], 1) + bs_exterior = np.logical_and( + find_boundaries( + np.isin(t1_subcortex_data, mixed_labels), + mode='inner'), + nwm_fuzzed >= wm_fuzzed) + bs_interior = np.logical_and( + np.isin(t1_subcortex_data, mixed_labels), + ~bs_exterior) + PVE[bs_exterior, 1] = 1.0 + PVE[bs_interior, 2] = 1.0 + + return PVE diff --git a/AFQ/nn/__init__.py b/AFQ/nn/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/AFQ/nn/brainchop.py b/AFQ/nn/brainchop.py new file mode 100644 index 000000000..da25292b5 --- /dev/null +++ b/AFQ/nn/brainchop.py @@ -0,0 +1,115 @@ +import subprocess +import tempfile +import os + +import numpy as np +import nibabel as nib + +from brainchop.utils import get_model +from brainchop.niimath import ( + conform, + bwlabel) + +from tinygrad import Tensor +from tinygrad.helpers import Context + +import logging + + +logger = logging.getLogger('AFQ') + + +def _brainchop_reslice(tmp_t1_file, tmp_out_file, output_dtype, full_input): + cmd = [ + "niimath", "-", "-reslice_nn", tmp_t1_file, + "-gz", "1", tmp_out_file, "-odt", output_dtype] + + subprocess.run(cmd, input=full_input, check=True) + + +def _run_brainchop_command(func, args): + """ + Run a Brainchop command line interface + with the provided arguments, but with error handling. 
+ """ + try: + return func(*args) + except subprocess.CalledProcessError as e: + logger.error("Command failed: %s", e.cmd) + logger.error("Return code: %s", e.returncode) + if e.stdout: + if isinstance(e.stdout, bytes): + logger.error("STDOUT:\n%s", e.stdout.decode()) + else: + logger.error("STDOUT:\n%s", e.stdout) + if e.stderr: + if isinstance(e.stderr, bytes): + logger.error("STDERR:\n%s", e.stderr.decode()) + else: + logger.error("STDERR:\n%s", e.stderr) + raise + + +def run_brainchop(t1_img, model): + """ + Run the Brainchop command line interface with the provided arguments. + """ + model = get_model(model) + output_dtype = "char" + + with tempfile.TemporaryDirectory() as temp_dir: + tmp_t1_file = f"{temp_dir}/t1.nii.gz" + tmp_out_file = f"{temp_dir}/output.nii.gz" + + t1_data = t1_img.get_fdata() + t1_data = np.clip(t1_data, 0, t1_data.max()) + nib.save(nib.Nifti1Image(t1_data, t1_img.affine), tmp_t1_file) + + volume, header = _run_brainchop_command(conform, [tmp_t1_file]) + + image = Tensor(volume.transpose((2, 1, 0)).astype(np.float32)).rearrange( + "... -> 1 1 ..." + ) + + try: + output_channels = _run_brainchop_command(model, [image]) + except Exception as e: + if "clang" in str(e).lower(): + with Context(PYTHON=1): + output_channels = model(image) + else: + raise + + output = ( + output_channels.argmax(axis=1) + .rearrange("1 x y z -> z y x") + .numpy() + .astype(np.uint8) + ) + + labels, new_header = _run_brainchop_command( + bwlabel, + [header, output]) + full_input = new_header + labels.tobytes() + + _run_brainchop_command( + _brainchop_reslice, + [ + tmp_t1_file, + tmp_out_file, + output_dtype, + full_input + ] + ) + + output_img = nib.load(tmp_out_file) + # This line below forces the data into memory to avoid issues with + # temporary files being deleted too early. + # Otherwise, nibabel lazy-loads the data and the file gets deleted + # before the data is accessed, because the file is in a temporary + # Directory. 
+ output_img = nib.Nifti1Image( + output_img.get_fdata().copy(), + output_img.affine.copy()) + + return output_img diff --git a/AFQ/recognition/cleaning.py b/AFQ/recognition/cleaning.py index 28d808a2f..5f183eb7f 100644 --- a/AFQ/recognition/cleaning.py +++ b/AFQ/recognition/cleaning.py @@ -9,7 +9,7 @@ import AFQ.recognition.utils as abu from AFQ._fixes import gaussian_weights - +from sklearn.ensemble import IsolationForest logger = logging.getLogger('AFQ') @@ -67,9 +67,61 @@ def clean_by_orientation(streamlines, primary_axis, affine, tol=None): return cleaned_idx +def clean_by_orientation_mahalanobis(streamlines, n_points=100, + core_only=0.6, min_sl=20, + distance_threshold=3, + clean_rounds=5): + fgarray = np.array(abu.resample_tg(streamlines, n_points)) + + if core_only != 0: + crop_edge = (1.0 - core_only) / 2 + fgarray = fgarray[ + :, + int(n_points * crop_edge):int(n_points * (1 - crop_edge)), + :] # Crop to middle 60% + + fgarray_dists = fgarray[:, 1:, :] - fgarray[:, :-1, :] + idx = np.arange(len(fgarray)) + rounds_elapsed = 0 + idx_dist = idx + while rounds_elapsed < clean_rounds: + # This calculates the Mahalanobis for each streamline/node: + m_dist = gaussian_weights( + fgarray_dists, return_mahalnobis=True, + n_points=None, stat=np.mean) + logger.debug(f"Shape of fgarray: {np.asarray(fgarray_dists).shape}") + logger.debug(( + f"Maximum m_dist for each fiber: " + f"{np.max(m_dist, axis=1)}")) + + if not (np.any(m_dist >= distance_threshold)): + break + idx_dist = np.all(m_dist < distance_threshold, axis=-1) + + if np.sum(idx_dist) < min_sl: + # need to sort and return exactly min_sl: + idx = idx[np.argsort(np.sum( + m_dist, axis=-1))[:min_sl].astype(int)] + logger.debug(( + f"At rounds elapsed {rounds_elapsed}, " + "minimum streamlines reached")) + break + else: + # Update by selection: + idx = idx[idx_dist] + fgarray_dists = fgarray_dists[idx_dist] + rounds_elapsed += 1 + logger.debug(( + f"Rounds elapsed: {rounds_elapsed}, " + f"num kept: 
{len(idx)}")) + logger.debug(f"Kept indicies: {idx}") + + return idx + + def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3, - length_threshold=4, min_sl=20, stat='mean', - return_idx=False): + length_threshold=4, min_sl=20, stat=np.mean, + core_only=0.6, return_idx=False): """ Clean a segmented fiber group based on the Mahalnobis distance of each streamline @@ -97,6 +149,13 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3, stat : callable or str, optional. The statistic of each node relative to which the Mahalanobis is calculated. Default: `np.mean` (but can also use median, etc.) + core_only : float, optional + If non-zero, only the core of the bundle is used for cleaning. + The core is commonly defined as the middle 60% of each streamline, + thus our default is 0.6. This means streamlines are allowed to + deviate in the starting and ending 20% of the bundle. This is useful + for allowing more diverse endpoints. + Default: 0.6 return_idx : bool Whether to return indices in the original streamlines. Default: False. @@ -127,6 +186,12 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3, # Resample once up-front: fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) + if core_only != 0: + crop_edge = (1.0 - core_only) / 2 + fgarray = fgarray[ + :, + int(n_points * crop_edge):int(n_points * (1 - crop_edge)), + :] # Crop to middle 60% # Keep this around, so you can use it for indexing at the very end: idx = np.arange(len(fgarray)) @@ -192,3 +257,74 @@ def clean_bundle(tg, n_points=100, clean_rounds=5, distance_threshold=3, return out, idx else: return out + + +def clean_by_isolation_forest(tg, n_points=100, percent_outlier_thresh=15, + min_sl=20, n_jobs=None, random_state=None): + """ + Use Isolation Forest (IF) to clean streamlines. + Nodes are passed to IF, and outlier nodes are identified. 
+ These are re-mapped back on to the streamlines, and streamlines + with too many outlier nodes are removed. This is better for + cleaning bundles that are not tube-like. + + Parameters + ---------- + tg : StatefulTractogram class instance or ArraySequence + A whole-brain tractogram to be segmented. + n_points : int, optional + Number of points to resample streamlines to. + Default: 100 + percent_outlier_thresh : int, optional + Percentage of outliers allowed in the streamline. + Default: 15 + min_sl : int, optional. + Number of streamlines in a bundle under which we will + not bother with cleaning outliers. Default: 20. + n_jobs : int, optional + Number of parallel jobs to use for LOF. + Default: None (single-threaded). + random_state : int, optional + Random state for IsolationForest. + Default: None + + Returns + ------- + indicies of streamlines that passed cleaning + """ + if hasattr(tg, "streamlines"): + streamlines = tg.streamlines + else: + streamlines = dts.Streamlines(tg) + + # We don't even bother if there aren't enough streamlines: + if len(streamlines) < min_sl: + logger.warning(( + "Isolation Forest cleaning not performed" + " due to low streamline count")) + return np.ones(len(streamlines), dtype=bool) + + # Resample once up-front: + fgarray = np.asarray(abu.resample_tg(streamlines, n_points)) + fgarray_dists = np.zeros_like(fgarray) + fgarray_dists[:, 1:, :] = fgarray[:, 1:, :] - fgarray[:, :-1, :] + fgarray_dists[:, 0, :] = fgarray_dists[:, 1, :] + X_ = np.concatenate(( + fgarray.reshape((-1, 3)), + fgarray_dists.reshape((-1, 3))), + axis=1) + idx = np.arange(len(fgarray)) + + lof = IsolationForest(n_jobs=n_jobs, random_state=random_state) + outliers = lof.fit_predict(X_) + outliers = outliers.reshape(fgarray.shape[:2]) + outliers = np.sum(outliers == -1, axis=1) + + idx_belong = outliers * 100 <= n_points * percent_outlier_thresh + + if np.sum(idx_belong) < min_sl: + # need to sort and return exactly min_sl: + return 
idx[np.argsort(outliers)[:min_sl].astype(int)] + else: + # Update by selection: + return idx[idx_belong] diff --git a/AFQ/recognition/criteria.py b/AFQ/recognition/criteria.py index 33d823803..3fbe5fa28 100644 --- a/AFQ/recognition/criteria.py +++ b/AFQ/recognition/criteria.py @@ -3,11 +3,11 @@ import numpy as np import nibabel as nib +import ray from scipy.ndimage import distance_transform_edt import dipy.tracking.streamline as dts -from dipy.utils.parallel import paramap from dipy.segment.clustering import QuickBundles from dipy.segment.metricspeed import AveragePointwiseEuclideanMetric from dipy.segment.featurespeed import ResampleFeature @@ -21,11 +21,18 @@ import AFQ.recognition.curvature as abv import AFQ.recognition.roi as abr import AFQ.recognition.other_bundles as abo +from AFQ.utils.stats import chunk_indices -bundle_criterion_order = [ + +criteria_order_pre_other_bundles = [ "prob_map", "cross_midline", "start", "end", "length", "primary_axis", "include", "exclude", - "curvature", "recobundles", "qb_thresh"] + "curvature", "recobundles"] + + +criteria_order_post_other_bundles = [ + "orient_mahal", "isolation_forest", "qb_thresh"] + valid_noncriterion = [ "space", "mahal", "primary_axis_percentage", @@ -123,7 +130,7 @@ def primary_axis(b_sls, bundle_def, img, **kwargs): def include(b_sls, bundle_def, preproc_imap, max_includes, - parallel_segmentation, **kwargs): + n_cpus, **kwargs): accept_idx = b_sls.initiate_selection("include") flip_using_include = len(bundle_def["include"]) > 1\ and not b_sls.oriented_yet @@ -139,13 +146,26 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, # with parallel segmentation, the first for loop will # only collect streamlines and does not need tqdm - if parallel_segmentation["engine"] != "serial": - inc_results = paramap( - abr.check_sl_with_inclusion, b_sls.get_selected_sls(), - func_args=[ - bundle_def["include"], include_roi_tols], - **parallel_segmentation) - + if n_cpus > 1: + inc_results = 
np.zeros(len(b_sls), dtype=tuple) + + inc_rois_id = ray.put(bundle_def["include"]) + inc_roi_tols_id = ray.put(include_roi_tols) + + _check_inc_parallel = ray.remote( + num_cpus=n_cpus)(abr.check_sls_with_inclusion) + + sls_chunks = list(chunk_indices(np.arange(len(b_sls)), n_cpus)) + futures = [ + _check_inc_parallel.remote( + b_sls.get_selected_sls()[sls_chunk], + inc_rois_id, + inc_roi_tols_id) + for sls_chunk in sls_chunks + ] + + for ii, future in enumerate(futures): + inc_results[sls_chunks[ii]] = ray.get(future) else: inc_results = abr.check_sls_with_inclusion( b_sls.get_selected_sls(), @@ -177,14 +197,7 @@ def include(b_sls, bundle_def, preproc_imap, max_includes, accept_idx[sl_idx] = 1 else: accept_idx[sl_idx] = 1 - # see https://github.com/joblib/joblib/issues/945 - if ( - (parallel_segmentation.get( - "engine", "joblib") != "serial") - and (parallel_segmentation.get( - "backend", "loky") == "loky")): - from joblib.externals.loky import get_reusable_executor - get_reusable_executor().shutdown(wait=True) + b_sls.roi_closest = roi_closest.T if flip_using_include: b_sls.reorient(to_flip) @@ -248,15 +261,16 @@ def recobundles(b_sls, mapping, bundle_def, reg_template, img, refine_reco, StatefulTractogram(b_sls.get_selected_sls(), img, Space.VOX), "template", mapping, reg_template, save_intermediates=save_intermediates).streamlines + moved_sl_resampled = abu.resample_tg(moved_sl, 100) rb = RecoBundles(moved_sl, verbose=True, rng=rng) _, rec_labels = rb.recognize( bundle_def['recobundles']['sl'], **rb_recognize_params) if refine_reco: _, rec_labels = rb.refine( - bundle_def['recobundles']['sl'], moved_sl[rec_labels], + bundle_def['recobundles']['sl'], moved_sl_resampled[rec_labels], **rb_recognize_params) - if not b_sls.oriented_yet: + if not b_sls.oriented_yet and np.sum(rec_labels) > 0: standard_sl = next(iter(bundle_def['recobundles']['centroid'])) oriented_idx = abu.orient_by_streamline( moved_sl[rec_labels], @@ -287,12 +301,20 @@ def 
clean_by_other_bundle(b_sls, bundle_def, cleaned_idx = b_sls.initiate_selection(other_bundle_name) cleaned_idx = 1 + if 'overlap' in bundle_def[other_bundle_name]: + cleaned_idx_overlap = abo.clean_by_overlap( + b_sls.get_selected_sls(), + other_bundle_sls, + bundle_def[other_bundle_name]["overlap"], + img, False) + cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_overlap) + if 'node_thresh' in bundle_def[other_bundle_name]: - cleaned_idx_node_thresh = abo.clean_by_other_density_map( + cleaned_idx_node_thresh = abo.clean_by_overlap( b_sls.get_selected_sls(), other_bundle_sls, bundle_def[other_bundle_name]["node_thresh"], - img) + img, True) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_node_thresh) if 'core' in bundle_def[other_bundle_name]: @@ -300,12 +322,38 @@ def clean_by_other_bundle(b_sls, bundle_def, bundle_def[other_bundle_name]['core'].lower(), preproc_imap["fgarray"][b_sls.selected_fiber_idxs], np.array(abu.resample_tg(other_bundle_sls, 20)), - img.affine) + img.affine, False) + cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_core) + + if 'entire_core' in bundle_def[other_bundle_name]: + cleaned_idx_core = abo.clean_relative_to_other_core( + bundle_def[other_bundle_name]['entire_core'].lower(), + preproc_imap["fgarray"][b_sls.selected_fiber_idxs], + np.array(abu.resample_tg(other_bundle_sls, 20)), + img.affine, True) cleaned_idx = np.logical_and(cleaned_idx, cleaned_idx_core) b_sls.select(cleaned_idx, other_bundle_name) +def orient_mahal(b_sls, bundle_def, **kwargs): + b_sls.initiate_selection("orient_mahal") + accept_idx = abc.clean_by_orientation_mahalanobis( + b_sls.get_selected_sls(), + **bundle_def.get("orient_mahal", {})) + b_sls.select(accept_idx, "orient_mahal") + + +def isolation_forest(b_sls, bundle_def, n_cpus, rng, **kwargs): + b_sls.initiate_selection("isolation_forest") + accept_idx = abc.clean_by_isolation_forest( + b_sls.get_selected_sls(), + percent_outlier_thresh=bundle_def["isolation_forest"].get( + 
"percent_outlier_thresh", 25), + n_jobs=n_cpus, random_state=rng) + b_sls.select(accept_idx, "isolation_forest") + + def mahalanobis(b_sls, bundle_def, clip_edges, cleaning_params, **kwargs): b_sls.initiate_selection("Mahalanobis") clean_params = bundle_def.get("mahal", {}) @@ -378,18 +426,20 @@ def check_space(roi): inputs[key] = value for potential_criterion in bundle_def.keys(): - if (potential_criterion not in bundle_criterion_order) and\ - (potential_criterion not in bundle_dict.bundle_names) and\ + if (potential_criterion not in criteria_order_post_other_bundles) and\ + (potential_criterion not in criteria_order_pre_other_bundles) and\ + (potential_criterion not in bundle_dict.bundle_names) and\ (potential_criterion not in valid_noncriterion): raise ValueError(( "Invalid criterion in bundle definition:\n" f"{potential_criterion} in bundle {bundle_name}.\n" "Valid criteria are:\n" - f"{bundle_criterion_order}\n" + f"{criteria_order_pre_other_bundles}\n" + f"{criteria_order_post_other_bundles}\n" f"{bundle_dict.bundle_names}\n" f"{valid_noncriterion}\n")) - for criterion in bundle_criterion_order: + for criterion in criteria_order_pre_other_bundles: if b_sls and criterion in bundle_def: inputs[criterion] = globals()[criterion](**inputs) if b_sls: @@ -400,8 +450,14 @@ def check_space(roi): **inputs, other_bundle_name=bundle_name, other_bundle_sls=tg.streamlines[idx]) + for criterion in criteria_order_post_other_bundles: + if b_sls and criterion in bundle_def: + inputs[criterion] = globals()[criterion](**inputs) if b_sls: - mahalanobis(**inputs) + if "mahal" in bundle_def or ( + "isolation_forest" not in bundle_def + and "orient_mahal" not in bundle_def): + mahalanobis(**inputs) if b_sls and not b_sls.oriented_yet: raise ValueError( diff --git a/AFQ/recognition/other_bundles.py b/AFQ/recognition/other_bundles.py index ddeffe7b4..f85bfc17a 100644 --- a/AFQ/recognition/other_bundles.py +++ b/AFQ/recognition/other_bundles.py @@ -9,11 +9,11 @@ logger = 
logging.getLogger('AFQ') -def clean_by_other_density_map(this_bundle_sls, other_bundle_sls, - node_thresh, img): +def clean_by_overlap(this_bundle_sls, other_bundle_sls, + overlap, img, remove=False): """ - Cleans a set of streamlines by removing those with significant overlap with - another set of streamlines. + Cleans a set of streamlines by only keeping (or removing) those with + significant overlap with another set of streamlines. Parameters ---------- @@ -21,13 +21,18 @@ def clean_by_other_density_map(this_bundle_sls, other_bundle_sls, A list or array of streamlines to be cleaned. other_bundle_sls : array-like A reference list or array of streamlines to determine overlapping regions. - node_thresh : int - The maximum number of nodes allowed to overlap between `this_bundle_sls` + overlap : int + The minimum number of nodes allowed to overlap between `this_bundle_sls` and `other_bundle_sls`. Streamlines with overlaps beyond this threshold are removed. img : nibabel.Nifti1Image or ndarray A reference 3D image that defines the spatial dimensions for the density map. + remove : bool, optional + If True, streamlines that overlap in less than `overlap` nodes are + removed. If False, streamlines that overlap in more than `overlap` nodes + are removed. + Default: False. Returns ------- @@ -41,12 +46,12 @@ def clean_by_other_density_map(this_bundle_sls, other_bundle_sls, This function computes a density map from `other_bundle_sls` to represent the spatial occupancy of the streamlines. It then calculates the probability of each streamline in `this_bundle_sls` overlapping with this map. - Streamlines that overlap in more than `node_thresh` nodes are flagged for - removal. + Streamlines that overlap in less than `overlap` nodes are flagged for + removal (or more, if remove is True). 
Examples -------- - >>> clean_idx = clean_by_other_density_map(bundle1, bundle2, 5, img) + >>> clean_idx = clean_by_overlap(bundle1, bundle2, 5, img, True) >>> cleaned_bundle = [s for i, s in enumerate(bundle1) if clean_idx[i]] """ other_bundle_density_map = dtu.density_map( @@ -55,11 +60,15 @@ def clean_by_other_density_map(this_bundle_sls, other_bundle_sls, other_bundle_density_map, this_bundle_sls, np.eye(4)) cleaned_idx = np.zeros(len(this_bundle_sls), dtype=np.bool_) for ii, fp in enumerate(fiber_probabilities): - cleaned_idx[ii] = np.sum(np.asarray(fp) >= 1) <= node_thresh + if remove: + cleaned_idx[ii] = np.sum(np.asarray(fp) >= 1) <= overlap + else: + cleaned_idx[ii] = np.sum(np.asarray(fp) >= 1) > overlap return cleaned_idx -def clean_relative_to_other_core(core, this_fgarray, other_fgarray, affine): +def clean_relative_to_other_core(core, this_fgarray, other_fgarray, affine, + entire=False): """ Removes streamlines from a set that lie on the opposite side of a specified core axis compared to another set of streamlines. @@ -76,6 +85,11 @@ def clean_relative_to_other_core(core, this_fgarray, other_fgarray, affine): An array of reference streamlines to define the core. affine : ndarray The affine transformation matrix. + entire : bool, optional + If True, the entire streamline must lie on the correct side of the core + to be retained. If False, only the closest point on the streamline to + the core is considered. + Default: False. 
Returns ------- @@ -143,11 +157,18 @@ def clean_relative_to_other_core(core, this_fgarray, other_fgarray, affine): core_bundle = np.median(other_fgarray, axis=0) cleaned_idx_core = np.zeros(this_fgarray.shape[0], dtype=np.bool_) for ii, sl in enumerate(this_fgarray): - dist_matrix = cdist(core_bundle, sl, 'sqeuclidean') - min_dist_indices = np.unravel_index(np.argmin(dist_matrix), - dist_matrix.shape) - closest_core = core_bundle[min_dist_indices[0], core_axis] - closest_sl = sl[min_dist_indices[1], core_axis] + if entire: + cleaned_idx_core[ii] = np.all( + core_direc * ( + sl[:, core_axis] - core_bundle[:, core_axis]) > 0) + else: + dist_matrix = cdist(core_bundle, sl, 'sqeuclidean') + min_dist_indices = np.unravel_index(np.argmin(dist_matrix), + dist_matrix.shape) + closest_core = core_bundle[min_dist_indices[0], core_axis] + closest_sl = sl[min_dist_indices[1], core_axis] + + cleaned_idx_core[ii] = core_direc * ( + closest_sl - closest_core) > 0 - cleaned_idx_core[ii] = core_direc * (closest_sl - closest_core) > 0 return cleaned_idx_core diff --git a/AFQ/recognition/recognize.py b/AFQ/recognition/recognize.py index 92a6ef216..12f9aa453 100644 --- a/AFQ/recognition/recognize.py +++ b/AFQ/recognition/recognize.py @@ -21,13 +21,13 @@ def recognize( mapping, bundle_dict, reg_template, + n_cpus, nb_points=False, nb_streamlines=False, clip_edges=False, - parallel_segmentation={"engine": "serial"}, rb_recognize_params=dict( model_clust_thr=1.25, - reduction_thr=25, + reduction_thr=50, pruning_thr=12), refine_reco=False, prob_threshold=0, @@ -53,6 +53,8 @@ def recognize( Dictionary of bundles to segment. reg_template : str, nib.Nifti1Image Template image for registration. + n_cpus : int + Number of CPUs to use for parallelization. nb_points : int, boolean Resample streamlines to nb_points number of points. If False, no resampling is done. Default: False @@ -62,13 +64,6 @@ def recognize( clip_edges : bool Whether to clip the streamlines to be only in between the ROIs. 
Default: False - parallel_segmentation : dict or AFQ.api.BundleDict - How to parallelize segmentation across processes when performing - waypoint ROI segmentation. Set to {"engine": "serial"} to not - perform parallelization. Some engines may cause errors, depending - on the system. See ``dipy.utils.parallel.paramap`` for - details. - Default: {"engine": "serial"} rb_recognize_params : dict RecoBundles parameters for the recognize function. Default: dict(model_clust_thr=1.25, reduction_thr=25, pruning_thr=12) @@ -185,7 +180,7 @@ def recognize( bundle_name, bundle_idx, bundle_to_flip, bundle_roi_closest, bundle_decisions, clip_edges=clip_edges, - parallel_segmentation=parallel_segmentation, + n_cpus=n_cpus, rb_recognize_params=rb_recognize_params, prob_threshold=prob_threshold, refine_reco=refine_reco, @@ -206,8 +201,8 @@ def recognize( logger.warning(( "Conflicts in bundle assignment detected. " f"{conflicts} conflicts detected in total out of " - f"{n_streamlines} total streamlines." - "Defaulting to whichever bundle appears first" + f"{n_streamlines} total streamlines. 
" + "Defaulting to whichever bundle appears first " "in the bundle_dict.")) bundle_decisions = np.concatenate(( bundle_decisions, np.ones((n_streamlines, 1))), axis=1) diff --git a/AFQ/recognition/roi.py b/AFQ/recognition/roi.py index 7868bf432..d87062f58 100644 --- a/AFQ/recognition/roi.py +++ b/AFQ/recognition/roi.py @@ -6,30 +6,28 @@ def _interp3d(roi, sl): return interpolate_scalar_3d(roi.get_fdata(), np.asarray(sl))[0] -def check_sls_with_inclusion(sls, include_rois, include_roi_tols): - for sl in sls: - yield check_sl_with_inclusion( - sl, - include_rois, - include_roi_tols) +def check_sls_with_inclusion( + sls, include_rois, include_roi_tols): + inc_results = np.zeros(len(sls), dtype=tuple) + include_rois = [roi_.get_fdata().copy() for roi_ in include_rois] + for jj, sl in enumerate(sls): + closest = np.zeros(len(include_rois), dtype=np.int32) + sl = np.asarray(sl) + valid = True + for ii, roi in enumerate(include_rois): + dist = interpolate_scalar_3d(roi, sl)[0] + closest[ii] = np.argmin(dist) + if dist[closest[ii]] > include_roi_tols[ii]: + # Too far from one of them: + inc_results[jj] = (False, []) + valid = False + break -def check_sl_with_inclusion(sl, include_rois, - include_roi_tols): - """ - Helper function to check that a streamline is close to a list of - inclusion ROIS. 
- """ - closest = np.zeros(len(include_rois), dtype=np.int32) - for ii, roi in enumerate(include_rois): - dist = _interp3d(roi, sl) - closest[ii] = np.argmin(dist) - if dist[closest[ii]] > include_roi_tols[ii]: - # Too far from one of them: - return False, [] - - # Apparently you checked all the ROIs and it was close to all of them - return True, closest + # Checked all the ROIs and it was close to all of them + if valid: + inc_results[jj] = (True, closest) + return inc_results def check_sl_with_exclusion(sl, exclude_rois, @@ -58,7 +56,7 @@ def clean_by_endpoints(streamlines, target, target_idx, tol=0, Where N is number of nodes in the array, the collection of streamlines to filter down to. target: Nifti1Image - Nifti1Image containing a boolean representation of the ROI. + Nifti1Image containing a distance transform of the ROI. target_idx: int. Index within each streamline to check if within the target region. Typically 0 for startpoint ROIs or -1 for endpoint ROIs. diff --git a/AFQ/recognition/tests/test_other_bundles.py b/AFQ/recognition/tests/test_other_bundles.py index 0dbc43e70..4af3c6775 100644 --- a/AFQ/recognition/tests/test_other_bundles.py +++ b/AFQ/recognition/tests/test_other_bundles.py @@ -14,17 +14,28 @@ node_thresh_sample = 1 -def test_clean_by_other_density_map(): - cleaned_idx = abo.clean_by_other_density_map( +def test_clean_by_overlap(): + cleaned_idx = abo.clean_by_overlap( this_bundle_sls_sample, other_bundle_sls_sample, node_thresh_sample, - img_sample + img_sample, + True ) assert isinstance(cleaned_idx, np.ndarray) assert cleaned_idx.shape[0] == this_bundle_sls_sample.shape[0] assert np.all(cleaned_idx == [False, True]) + cleaned_idx = abo.clean_by_overlap( + this_bundle_sls_sample, + other_bundle_sls_sample, + node_thresh_sample, + img_sample, + False + ) + assert isinstance(cleaned_idx, np.ndarray) + assert cleaned_idx.shape[0] == this_bundle_sls_sample.shape[0] + assert np.all(cleaned_idx == [True, False]) def 
test_clean_relative_to_other_core(): for core in ['anterior', 'posterior', 'superior', 'inferior', 'right', 'left']: @@ -32,7 +43,8 @@ def test_clean_relative_to_other_core(): core, this_bundle_sls_sample, other_bundle_sls_sample, - np.eye(4) + np.eye(4), + entire=False ) assert isinstance(cleaned_idx_core, np.ndarray) @@ -41,3 +53,22 @@ def test_clean_relative_to_other_core(): assert np.all(cleaned_idx_core == [True, True]) else: assert np.all(cleaned_idx_core == [False, False]) + + cleaned_idx_core = abo.clean_relative_to_other_core( + core, + this_bundle_sls_sample, + other_bundle_sls_sample, + np.eye(4), + entire=True + ) + + assert isinstance(cleaned_idx_core, np.ndarray) + assert cleaned_idx_core.shape[0] == this_bundle_sls_sample.shape[0] + if core == "inferior": + assert np.all(cleaned_idx_core == [True, True]) + elif core == "right": + assert np.all(cleaned_idx_core == [True, False]) + elif core == "posterior" or core == "left": + assert np.all(cleaned_idx_core == [False, True]) + else: + assert np.all(cleaned_idx_core == [False, False]) diff --git a/AFQ/recognition/tests/test_recognition.py b/AFQ/recognition/tests/test_recognition.py index 24eafea11..0f5688ab1 100644 --- a/AFQ/recognition/tests/test_recognition.py +++ b/AFQ/recognition/tests/test_recognition.py @@ -69,7 +69,7 @@ def test_segment(): nib.load(hardi_fdata), mapping, bundles, - reg_template) + reg_template, 2) # We asked for 2 fiber groups: npt.assert_equal(len(fiber_groups), 2) @@ -107,7 +107,7 @@ def test_segment_no_prob(): nib.load(hardi_fdata), mapping, bundles_no_prob, - reg_template) + reg_template, 1) # This condition should still hold npt.assert_equal(len(fiber_groups), 2) @@ -121,7 +121,7 @@ def test_segment_return_idx(): nib.load(hardi_fdata), mapping, bundles, - reg_template, + reg_template, 1, return_idx=True) npt.assert_equal(len(fiber_groups), 2) @@ -137,7 +137,7 @@ def test_segment_clip_edges_api(): nib.load(hardi_fdata), mapping, bundles, - reg_template, + reg_template, 1, 
clip_edges=True) npt.assert_equal(len(fiber_groups), 2) npt.assert_(len(fiber_groups['Right Corticospinal']) > 0) @@ -157,7 +157,7 @@ def test_segment_reco(): nib.load(hardi_fdata), mapping, bundles_reco, - reg_template, + reg_template, 1, rng=np.random.RandomState(seed=8)) # This condition should still hold @@ -191,7 +191,7 @@ def test_exclusion_ROI(): nib.load(hardi_fdata), mapping, slf_bundle, - reg_template) + reg_template, 1) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 2) @@ -203,7 +203,7 @@ def test_exclusion_ROI(): nib.load(hardi_fdata), mapping, slf_bundle, - reg_template) + reg_template, 1) npt.assert_equal(len(fiber_groups["Left Superior Longitudinal"]), 1) @@ -214,7 +214,7 @@ def test_segment_sampled_streamlines(): nib.load(hardi_fdata), mapping, bundles, - reg_template) + reg_template, 1) # Already using a subsampled tck # the Right Corticospinal has two streamlines and @@ -230,7 +230,7 @@ def test_segment_sampled_streamlines(): nib.load(hardi_fdata), mapping, bundles, - reg_template, + reg_template, 1, nb_streamlines=nb_streamlines) # sampled streamlines should be subset of the original streamlines diff --git a/AFQ/recognition/tests/test_rois.py b/AFQ/recognition/tests/test_rois.py index 1827676a1..1908b5351 100644 --- a/AFQ/recognition/tests/test_rois.py +++ b/AFQ/recognition/tests/test_rois.py @@ -6,7 +6,6 @@ from scipy.ndimage import distance_transform_edt from AFQ.recognition.roi import ( check_sls_with_inclusion, - check_sl_with_inclusion, check_sl_with_exclusion) shape = (15, 15, 15) @@ -83,15 +82,15 @@ def test_check_sls_with_inclusion(): def test_check_sl_with_inclusion_pass(): - result, dists = check_sl_with_inclusion( - streamline1, include_rois, include_roi_tols) + result, dists = check_sls_with_inclusion( + [streamline1], include_rois, include_roi_tols)[0] assert result is True assert len(dists) == 2 def test_check_sl_with_inclusion_fail(): - result, dists = check_sl_with_inclusion( - streamline2, include_rois, 
include_roi_tols) + result, dists = check_sls_with_inclusion( + [streamline2], include_rois, include_roi_tols)[0] assert result is False assert dists == [] diff --git a/AFQ/recognition/tests/test_utils.py b/AFQ/recognition/tests/test_utils.py index 0a58d8254..a0b1f10cf 100644 --- a/AFQ/recognition/tests/test_utils.py +++ b/AFQ/recognition/tests/test_utils.py @@ -7,6 +7,7 @@ import AFQ.recognition.curvature as abv import AFQ.recognition.utils as abu import AFQ.recognition.cleaning as abc +import AFQ.recognition.other_bundles as abo from dipy.io.stateful_tractogram import StatefulTractogram, Space @@ -23,6 +24,16 @@ streamlines = tg.streamlines +def _make_straight_streamlines(n_sl=5, length=10, axis=0, offset=0): + """Utility: make straight streamlines along one axis.""" + sls = [] + for i in range(n_sl): + sl = np.zeros((length, 3)) + sl[:, axis] = np.linspace(0, length - 1, length) + offset + sls.append(sl) + return sls + + def test_segment_sl_curve(): sl_disp_0 = abv.sl_curve(streamlines[4], 4) npt.assert_array_almost_equal( @@ -79,3 +90,81 @@ def test_segment_orientation(): primary_axis="I/S", affine=np.eye(4), tol=33) npt.assert_array_equal(cleaned_idx_tol, cleaned_idx) + + +def test_clean_isolation_forest_basic(): + cleaned_idx = abc.clean_by_isolation_forest(streamlines, + n_points=20, + min_sl=10) + # Should return either a boolean mask or integer indices + npt.assert_(isinstance(cleaned_idx, (np.ndarray,))) + npt.assert_(cleaned_idx.shape[0] <= len(streamlines)) + + +def test_clean_isolation_forest_outlier_thresh(): + cleaned_loose = abc.clean_by_isolation_forest(streamlines, + n_points=20, + percent_outlier_thresh=50, + min_sl=10, + random_state=42) + cleaned_strict = abc.clean_by_isolation_forest(streamlines, + n_points=20, + percent_outlier_thresh=5, + min_sl=10, + random_state=42) + npt.assert_(np.sum(cleaned_loose) >= np.sum(cleaned_strict)) + + +def test_clean_by_overlap_keep_remove(): + img = nib.Nifti1Image(np.zeros((20, 20, 20)), np.eye(4)) + + 
this_bundle = _make_straight_streamlines(n_sl=3, length=10, axis=0) + other_bundle = _make_straight_streamlines(n_sl=3, length=10, axis=0) + + cleaned_remove = abo.clean_by_overlap(this_bundle, other_bundle, + overlap=5, img=img, remove=True) + npt.assert_equal(cleaned_remove, np.zeros(3, dtype=bool)) + + cleaned_keep = abo.clean_by_overlap(this_bundle, other_bundle, + overlap=5, img=img, remove=False) + npt.assert_equal(cleaned_keep, np.ones(3, dtype=bool)) + + +def test_clean_by_overlap_partial_overlap(): + img = nib.Nifti1Image(np.zeros((20, 20, 20)), np.eye(4)) + + this_bundle = _make_straight_streamlines(n_sl=2, length=10, axis=0) + other_bundle = _make_straight_streamlines(n_sl=2, length=10, axis=1) + + # These bundles are orthogonal, so minimal overlap + cleaned = abo.clean_by_overlap(this_bundle, other_bundle, + overlap=2, img=img, remove=False) + npt.assert_equal(cleaned, np.zeros(2, dtype=bool)) + + +def test_clean_relative_to_other_core_entire_vs_closest(): + # Two bundles along x axis, separated along z + this_bundle = np.array(_make_straight_streamlines(n_sl=2, + length=5, + axis=0)) + this_bundle[0, :, 2] += 5 + this_bundle[1, :, 2] -= 5 + other_bundle = np.array(_make_straight_streamlines(n_sl=2, + length=5, + axis=0)) + affine = np.eye(4) + cleaned_entire = abo.clean_relative_to_other_core('inferior', + this_bundle, + other_bundle, + affine, + entire=True) + npt.assert_equal(cleaned_entire, [True, False]) + + # With entire=False, same result in this synthetic case + cleaned_closest = abo.clean_relative_to_other_core('inferior', + this_bundle, + other_bundle, + affine, + entire=False) + npt.assert_equal(cleaned_closest, [True, False]) + diff --git a/AFQ/tasks/data.py b/AFQ/tasks/data.py index 8ffae6112..220463239 100644 --- a/AFQ/tasks/data.py +++ b/AFQ/tasks/data.py @@ -1,10 +1,12 @@ import nibabel as nib import numpy as np import logging +import multiprocessing +import os.path as op from dipy.io.gradients import read_bvals_bvecs import 
dipy.core.gradients as dpg -from dipy.data import default_sphere +from dipy.data import default_sphere, get_sphere import immlib @@ -16,16 +18,21 @@ from dipy.reconst.rumba import RumbaSDModel, RumbaFit from dipy.reconst import shm from dipy.reconst.dki_micro import axonal_water_fraction +from dipy.reconst.mcsd import ( + mask_for_response_msmt, + multi_shell_fiber_response, + response_from_mask_msmt) +from dipy.core.gradients import unique_bvals_tolerance +from dipy.align import resample from AFQ.tasks.decorators import as_file, as_img, as_fit_deriv -from AFQ.tasks.utils import get_fname, with_name, str_to_desc +from AFQ.tasks.utils import get_fname, with_name import AFQ.api.bundle_dict as abd import AFQ.data.fetch as afd from AFQ.utils.path import drop_extension, write_json from AFQ._fixes import gwi_odf from AFQ.definitions.utils import Definition -from AFQ.definitions.image import B0Image from AFQ.models.dti import noise_from_b0 from AFQ.models.csd import _fit as csd_fit_model @@ -35,7 +42,13 @@ from AFQ.models.fwdti import _fit as fwdti_fit_model from AFQ.models.QBallTP import ( extract_odf, anisotropic_index, anisotropic_power) +from AFQ.models.wmgm_interface import fit_wm_gm_interface, pve_from_subcortex +from AFQ.models.msmt import MultiShellDeconvModel +from AFQ.models.asym_filtering import ( + unified_filtering, compute_asymmetry_index, + compute_odd_power_map, compute_nufid_asym) +from AFQ.nn.brainchop import run_brainchop logger = logging.getLogger('AFQ') @@ -87,20 +100,50 @@ def get_data_gtab(dwi_data_file, bval_file, bvec_file, min_bval=-np.inf, return data, gtab, img, img.affine +@immlib.calc("n_cpus", "n_threads") +def configure_ncpus_nthreads(ray_n_cpus=None, numba_n_threads=None): + """ + Configure the number of CPUs to use for parallel processing with Ray, + the number of threads to use for Numba + + Parameters + ---------- + ray_n_cpus : int, optional + The number of CPUs to use for parallel processing with Ray. 
+ If None, uses the number of available CPUs minus one. + Tractography, Recognition, and MSMT use Ray. + Default: None + numba_n_threads : int, optional + The number of threads to use for Numba. + If None, uses the number of available CPUs minus one, + but with a maximum of 16. + ASYM fit uses Numba. + Default: None + """ + if ray_n_cpus is None: + ray_n_cpus = max(multiprocessing.cpu_count() - 1, 1) + if numba_n_threads is None: + numba_n_threads = min( + max(multiprocessing.cpu_count() - 1, 1), 16) + + return ray_n_cpus, numba_n_threads + + @immlib.calc("b0") @as_file('_b0ref.nii.gz') +@as_img def b0(dwi, gtab): """ full path to a nifti file containing the mean b0 """ mean_b0 = np.mean(dwi.get_fdata()[..., gtab.b0s_mask], -1) - mean_b0_img = nib.Nifti1Image(mean_b0, dwi.affine) meta = dict(b0_threshold=gtab.b0_threshold) - return mean_b0_img, meta + return mean_b0, meta @immlib.calc("masked_b0") @as_file('_desc-masked_b0ref.nii.gz') +@as_img def b0_mask(b0, brain_mask): """ full path to a nifti file containing the @@ -112,11 +155,36 @@ def b0_mask(b0, brain_mask): masked_data = img.get_fdata() masked_data[~brain_mask] = 0 - masked_b0_img = nib.Nifti1Image(masked_data, img.affine) meta = dict( source=b0, masked=True) - return masked_b0_img, meta + return masked_data, meta + + +@immlib.calc("t1w_pve") +@as_file(suffix='_desc-pve_probseg.nii.gz') +def t1w_pve(t1_subcortex): + """ + WM, GM, CSF segmentations from subcortex segmentation + from brainchop on T1w image + """ + t1_subcortex_img = nib.load(t1_subcortex) + PVE = pve_from_subcortex(t1_subcortex_img.get_fdata()) + + return nib.Nifti1Image(PVE, t1_subcortex_img.affine), dict( + SubCortexParcellation=t1_subcortex, + labels=["csf", "gm", "wm"],) + + +@immlib.calc("wm_gm_interface") +@as_file(suffix='_desc-wmgmi_mask.nii.gz') +def wm_gm_interface(t1w_pve, b0): + PVE_img = nib.load(t1w_pve) + b0_img = nib.load(b0) + + wmgmi_img = fit_wm_gm_interface(PVE_img, b0_img) + + return wmgmi_img, dict(FromPVE=t1w_pve) 
@immlib.calc("dti_tf") @@ -299,6 +367,213 @@ def msdki_msk(msdki_tf): return msdki_tf.msk +@immlib.calc("msmtcsd_params") +@as_file(suffix='_model-msmtcsd_param-fod_dwimap.nii.gz', + subfolder="models") +@as_img +def msmt_params(brain_mask, gtab, data, + dwi_affine, t1w_pve, + n_cpus, + msmt_sh_order=8, + msmt_fa_thr=0.7): + """ + full path to a nifti file containing + parameters for the MSMT CSD fit + + Parameters + ---------- + msmt_sh_order : int, optional. + Spherical harmonic order to use for the MSMT CSD fit. + Default: 8 + msmt_fa_thr : float, optional. + The threshold on the FA used to calculate the multi shell auto + response. Can be useful to reduce for baby subjects. + Default: 0.7 + + References + ---------- + .. [1] B. Jeurissen, J.-D. Tournier, T. Dhollander, A. Connelly, + and J. Sijbers. Multi-tissue constrained spherical + deconvolution for improved analysis of multi-shell diffusion + MRI data. NeuroImage, 103 (2014), pp. 411–426 + """ + mask =\ + nib.load(brain_mask).get_fdata() + + pve_img = nib.load(t1w_pve) + pve_data = pve_img.get_fdata() + csf = resample(pve_data[..., 0], data[..., 0], + pve_img.affine, dwi_affine).get_fdata() + gm = resample(pve_data[..., 1], data[..., 0], + pve_img.affine, dwi_affine).get_fdata() + wm = resample(pve_data[..., 2], data[..., 0], + pve_img.affine, dwi_affine).get_fdata() + + mask_wm, mask_gm, mask_csf = mask_for_response_msmt( + gtab, + data, + roi_radii=10, + wm_fa_thr=msmt_fa_thr, + gm_fa_thr=0.3, + csf_fa_thr=0.15, + gm_md_thr=0.001, + csf_md_thr=0.0032) + mask_wm *= wm > 0.5 + mask_gm *= gm > 0.5 + mask_csf *= csf > 0.5 + response_wm, response_gm, response_csf = response_from_mask_msmt( + gtab, data, mask_wm, mask_gm, mask_csf) + ubvals = unique_bvals_tolerance(gtab.bvals) + response_mcsd = multi_shell_fiber_response( + msmt_sh_order, + ubvals, + response_wm, + response_gm, + response_csf) + + mcsd_model = MultiShellDeconvModel(gtab, response_mcsd) + logger.info("Fitting Multi-Shell CSD model...") + 
mcsd_fit = mcsd_model.fit( + data, mask, n_cpus=n_cpus) + + meta = dict( + SphericalHarmonicDegree=msmt_sh_order, + SphericalHarmonicBasis="DESCOTEAUX") + return mcsd_fit.shm_coeff, meta + + +@immlib.calc("msmt_apm") +@as_file(suffix='_model-msmtcsd_param-apm_dwimap.nii.gz', + subfolder="models") +@as_img +def msmt_apm(msmtcsd_params): + """ + full path to a nifti file containing + the anisotropic power map + """ + sh_coeff = nib.load(msmtcsd_params).get_fdata() + pmap = anisotropic_power(sh_coeff) + return pmap, dict(MSMTCSDParamsFile=msmtcsd_params) + + +@immlib.calc("msmt_aodf_params") +@as_file(suffix='_model-msmtcsd_param-aodf_dwimap.nii.gz', + subfolder="models") +@as_img +def msmt_aodf(msmtcsd_params, n_threads): + """ + full path to a nifti file containing + MSMT CSD ODFs filtered by unified filtering [1] + + References + ---------- + [1] Poirier and Descoteaux, 2024, "A Unified Filtering Method for + Estimating Asymmetric Orientation Distribution Functions", + Neuroimage, https://doi.org/10.1016/j.neuroimage.2024.120516 + """ + sh_coeff = nib.load(msmtcsd_params).get_fdata() + + logger.info("Applying unified filtering to MSMT CSD ODFs...") + aodf = unified_filtering( + sh_coeff, + get_sphere(name="repulsion724"), + n_threads=n_threads) + + return aodf, dict( + MSMTCSDParamsFile=msmtcsd_params, + Sphere="repulsion724") + + +@immlib.calc("msmt_aodf_asi") +@as_file(suffix='_model-msmtcsd_param-asi_dwimap.nii.gz', + subfolder="models") +@as_img +def msmt_aodf_asi(msmt_aodf_params, brain_mask): + """ + full path to a nifti file containing + the MSMT CSD Asymmetric Index (ASI) [1] + + References + ---------- + [1] S. Cetin Karayumak, E. Özarslan, and G. Unal, + "Asymmetric Orientation Distribution Functions (AODFs) + revealing intravoxel geometry in diffusion MRI" + Magnetic Resonance Imaging, vol. 49, pp. 145-158, Jun. 2018, + doi: https://doi.org/10.1016/j.mri.2018.03.006. 
+ """ + + aodf = nib.load(msmt_aodf_params).get_fdata() + brain_mask = nib.load(brain_mask).get_fdata().astype(bool) + asi = compute_asymmetry_index(aodf, brain_mask) + + return asi, dict(MSMTCSDParamsFile=msmt_aodf_params) + + +@immlib.calc("msmt_aodf_opm") +@as_file(suffix='_model-msmtcsd_param-opm_dwimap.nii.gz', + subfolder="models") +@as_img +def msmt_aodf_opm(msmt_aodf_params, brain_mask): + """ + full path to a nifti file containing + the MSMT CSD odd-power map [1] + + References + ---------- + [1] C. Poirier, E. St-Onge, and M. Descoteaux, + "Investigating the Occurence of Asymmetric Patterns in + White Matter Fiber Orientation Distribution Functions" + [Abstract], In: Proc. Intl. Soc. Mag. Reson. Med. 29 (2021), + 2021 May 15-20, Vancouver, BC, Abstract number 0865. + """ + + aodf = nib.load(msmt_aodf_params).get_fdata() + brain_mask = nib.load(brain_mask).get_fdata().astype(bool) + opm = compute_odd_power_map(aodf, brain_mask) + + return opm, dict(MSMTCSDParamsFile=msmt_aodf_params) + + +@immlib.calc("msmt_aodf_nufid") +@as_file(suffix='_model-msmtcsd_param-nufid_dwimap.nii.gz', + subfolder="models") +@as_img +def msmt_aodf_nufid(msmt_aodf_params, brain_mask, + t1w_pve): + """ + full path to a nifti file containing + the MSMT CSD Number of fiber directions (nufid) map [1] + + References + ---------- + [1] C. Poirier and M. Descoteaux, + "Filtering Methods for Asymmetric ODFs: + Where and How Asymmetry Occurs in the White Matter." + bioRxiv. 2022 Jan 1; 2022.12.18.520881. 
+ doi: https://doi.org/10.1101/2022.12.18.520881 + """ + pve_img = nib.load(t1w_pve) + pve_data = pve_img.get_fdata() + + aodf_img = nib.load(msmt_aodf_params) + aodf = aodf_img.get_fdata() + + csf = resample(pve_data[..., 0], aodf[..., 0], + pve_img.affine, aodf_img.affine).get_fdata() + + # Only sphere we use for AODF currently + sphere = get_sphere(name="repulsion724") + + brain_mask = nib.load(brain_mask).get_fdata().astype(bool) + + logger.info("Number of fiber directions (nufid) map from AODF...") + nufid = compute_nufid_asym(aodf, sphere, csf, brain_mask) + + return nufid, dict( + MSMTCSDParamsFile=msmt_aodf_params, + PVE=t1w_pve) + + @immlib.calc("csd_params") @as_file(suffix='_model-csd_param-fod_dwimap.nii.gz', subfolder="models") @@ -373,6 +648,34 @@ def csd_params(dwi, brain_mask, gtab, data, return csdf.shm_coeff, meta +@immlib.calc("csd_aodf_params") +@as_file(suffix='_model-csd_param-aodf_dwimap.nii.gz', + subfolder="models") +@as_img +def csd_aodf(csd_params, n_threads): + """ + full path to a nifti file containing + SSST CSD ODFs filtered by unified filtering [1] + + References + ---------- + [1] Poirier and Descoteaux, 2024, "A Unified Filtering Method for + Estimating Asymmetric Orientation Distribution Functions", + Neuroimage, https://doi.org/10.1016/j.neuroimage.2024.120516 + """ + sh_coeff = nib.load(csd_params).get_fdata() + + logger.info("Applying unified filtering to CSD ODFs...") + aodf = unified_filtering( + sh_coeff, + get_sphere(name="repulsion724"), + n_threads=n_threads) + + return aodf, dict( + CSDParamsFile=csd_params, + Sphere="repulsion724") + + @immlib.calc("csd_pmap") @as_file(suffix='_model-csd_param-apm_dwimap.nii.gz', subfolder="models") @@ -1011,6 +1314,42 @@ def dki_kfa(dki_tf): return dki_tf.kfa +@immlib.calc("dki_cl") +@as_file('_model-dki_param-cl_dwimap.nii.gz', + subfolder="models") +@as_fit_deriv('DKI') +def dki_cl(dki_tf): + """ + full path to a nifti file containing + the DKI linearity file + """ + return 
dki_tf.linearity + + +@immlib.calc("dki_cp") +@as_file('_model-dki_param-cp_dwimap.nii.gz', + subfolder="models") +@as_fit_deriv('DKI') +def dki_cp(dki_tf): + """ + full path to a nifti file containing + the DKI planarity file + """ + return dki_tf.planarity + + +@immlib.calc("dki_cs") +@as_file('_model-dki_param-cs_dwimap.nii.gz', + subfolder="models") +@as_fit_deriv('DKI') +def dki_cs(dki_tf): + """ + full path to a nifti file containing + the DKI sphericity file + """ + return dki_tf.sphericity + + @immlib.calc("dki_ga") @as_file(suffix='_model-dki_param-ga_dwimap.nii.gz', subfolder="models") @@ -1071,31 +1410,86 @@ def dki_ak(dki_tf): return dki_tf.ak -@immlib.calc("brain_mask") -@as_file('_desc-brain_mask.nii.gz') -def brain_mask(b0, brain_mask_definition=None): +@immlib.calc("t1_brain_mask") +@as_file(suffix='_desc-T1w_mask.nii.gz') +def t1_brain_mask(t1_file): """ - full path to a nifti file containing - the brain mask + full path to a nifti file containing brain mask from T1w image, - Parameters + References ---------- - brain_mask_definition : instance from `AFQ.definitions.image`, optional - This will be used to create - the brain mask, which gets applied before registration to a - template. - If you want no brain mask to be applied, use FullImage. - If None, use B0Image() - Default: None + [1] Masoud, M., Hu, F., & Plis, S. (2023). Brainchop: In-browser MRI + volumetric segmentation and rendering. Journal of Open Source + Software, 8(83), 5098. 
+ https://doi.org/10.21105/joss.05098 + """ + return run_brainchop(nib.load(t1_file), "mindgrab"), dict( + T1w=t1_file, + model="brainchop") + + +@immlib.calc("t1_masked") +@as_file(suffix='_desc-masked_T1w.nii.gz') +def t1_masked(t1_file, t1_brain_mask): + """ + full path to a nifti file containing the T1w masked """ - # Note that any case where brain_mask_definition is not None - # is handled in get_data_plan - # This is just the default - return B0Image().get_image_getter("data")(b0) + t1_img = nib.load(t1_file) + t1_data = t1_img.get_fdata() + t1_mask = nib.load(t1_brain_mask) + t1_data[t1_mask.get_fdata() == 0] = 0 + t1_img_masked = nib.Nifti1Image( + t1_data, t1_img.affine) + return t1_img_masked, dict( + T1w=t1_file, + BrainMask=t1_brain_mask) + + +@immlib.calc("t1_subcortex") +@as_file(suffix='_desc-subcortex_probseg.nii.gz') +def t1_subcortex(t1_masked): + """ + full path to a nifti file containing segmentation of + subcortical structures from T1w image using Brainchop + + References + ---------- + [1] Masoud, M., Hu, F., & Plis, S. (2023). Brainchop: In-browser MRI + volumetric segmentation and rendering. Journal of Open Source + Software, 8(83), 5098. 
+ https://doi.org/10.21105/joss.05098 + """ + t1_img_masked = nib.load(t1_masked) + + subcortical_img = run_brainchop( + t1_img_masked, "subcortical") + + meta = dict( + T1w=t1_masked, + model="subcortical", + labels=[ + "Unknown", "Cerebral-White-Matter", "Cerebral-Cortex", + "Lateral-Ventricle", "Inferior-Lateral-Ventricle", + "Cerebellum-White-Matter", "Cerebellum-Cortex", + "Thalamus", "Caudate", "Putamen", "Pallidum", + "3rd-Ventricle", "4th-Ventricle", "Brain-Stem", + "Hippocampus", "Amygdala", "Accumbens-area", "VentralDC"]) + + return subcortical_img, meta + + +@immlib.calc("brain_mask") +@as_file('_desc-brain_mask.nii.gz') +def brain_mask(t1_brain_mask, b0): + """ + full path to a nifti file containing the brain mask + """ + return resample(t1_brain_mask, b0), dict( + BrainMaskinT1w=t1_brain_mask) @immlib.calc("bundle_dict", "reg_template", "tmpl_name") -def get_bundle_dict(brain_mask, b0, +def get_bundle_dict(b0, bundle_info=None, reg_template_spec="mni_T1", reg_template_space_name="mni"): """ @@ -1139,24 +1533,20 @@ def get_bundle_dict(brain_mask, b0, if bundle_info is None: bundle_info = abd.default18_bd() + abd.callosal_bd() - use_brain_mask = True - brain_mask = nib.load(brain_mask).get_fdata() - if np.all(brain_mask == 1.0): - use_brain_mask = False if isinstance(reg_template_spec, nib.Nifti1Image): reg_template = reg_template_spec else: img_l = reg_template_spec.lower() if img_l == "mni_t2": reg_template = afd.read_mni_template( - mask=use_brain_mask, weight="T2w") + mask=True, weight="T2w") elif img_l == "mni_t1": reg_template = afd.read_mni_template( - mask=use_brain_mask, weight="T1w") + mask=True, weight="T1w") elif img_l == "dti_fa_template": - reg_template = afd.read_ukbb_fa_template(mask=use_brain_mask) + reg_template = afd.read_ukbb_fa_template(mask=True) elif img_l == "hcp_atlas": - reg_template = afd.read_mni_template(mask=use_brain_mask) + reg_template = afd.read_mni_template(mask=True) elif img_l == "pediatric": reg_template = 
afd.read_pediatric_templates()[ "UNCNeo-withCerebellum-for-babyAFQ"] @@ -1170,7 +1560,7 @@ def get_bundle_dict(brain_mask, b0, bundle_info, resample_to=reg_template) - if bundle_dict.resample_subject_to is None: + if bundle_dict.resample_subject_to == True: bundle_dict.resample_subject_to = b0 return bundle_dict, reg_template, reg_template_space_name @@ -1186,8 +1576,13 @@ def get_data_plan(kwargs): data_tasks = with_name([ get_data_gtab, b0, b0_mask, brain_mask, + t1_brain_mask, t1_subcortex, t1_masked, + configure_ncpus_nthreads, + t1w_pve, wm_gm_interface, dti_fit, dki_fit, fwdti_fit, anisotropic_power_map, - csd_anisotropic_index, + csd_anisotropic_index, csd_aodf, + msmt_params, msmt_apm, msmt_aodf, + msmt_aodf_asi, msmt_aodf_opm, msmt_aodf_nufid, dti_fa, dti_lt, dti_cfa, dti_pdd, dti_md, dki_kt, dki_lt, dki_fa, gq, gq_pmap, gq_ai, opdt_params, opdt_pmap, opdt_ai, csa_params, csa_pmap, csa_ai, @@ -1196,6 +1591,7 @@ def get_data_plan(kwargs): dki_md, dki_awf, dki_mk, dki_kfa, dki_ga, dki_rd, dti_ga, dti_rd, dti_ad, dki_ad, dki_rk, dki_ak, dti_params, dki_params, fwdti_params, + dki_cl, dki_cp, dki_cs, rumba_fit, rumba_params, rumba_model, rumba_f_csf, rumba_f_gm, rumba_f_wm, csd_params, get_bundle_dict]) @@ -1218,18 +1614,4 @@ def get_data_plan(kwargs): scalars.append(scalar) kwargs["scalars"] = scalars - bm_def = kwargs.get( - "brain_mask_definition", None) - if bm_def is not None: - if not isinstance(bm_def, Definition): - raise TypeError( - "brain_mask_definition must be a Definition") - del kwargs["brain_mask_definition"] - data_tasks["brain_mask_res"] = immlib.calc("brain_mask")( - as_file( - suffix=( - f'_desc-{str_to_desc(bm_def.get_name())}' - '_mask.nii.gz'), - subfolder="models")(bm_def.get_image_getter("data"))) - return immlib.plan(**data_tasks) diff --git a/AFQ/tasks/decorators.py b/AFQ/tasks/decorators.py index 678fe52f5..54b74fca8 100644 --- a/AFQ/tasks/decorators.py +++ b/AFQ/tasks/decorators.py @@ -80,8 +80,12 @@ def wrapper_as_file(*args, 
**kwargs): if not op.exists(this_file): logger.info(f"Calculating {suffix}") - gen, meta = func( - *args, **kwargs) + try: + gen, meta = func( + *args, **kwargs) + except Exception: + print(f"Error in task: {func.__qualname__}") + raise logger.info(f"{suffix} completed. Saving to {this_file}") if isinstance(gen, nib.Nifti1Image): diff --git a/AFQ/tasks/mapping.py b/AFQ/tasks/mapping.py index 6d42cd590..823e8c5eb 100644 --- a/AFQ/tasks/mapping.py +++ b/AFQ/tasks/mapping.py @@ -15,6 +15,7 @@ from dipy.io.streamline import load_tractogram from dipy.io.stateful_tractogram import Space +from dipy.align import resample logger = logging.getLogger('AFQ') @@ -189,7 +190,7 @@ def sls_mapping(base_fname, dwi_data_file, reg_subject, data_imap, @immlib.calc("reg_subject") def get_reg_subject(data_imap, - reg_subject_spec="power_map"): + reg_subject_spec="t1w"): """ Nifti1Image which represents this subject when registering the subject to the template @@ -199,11 +200,11 @@ def get_reg_subject(data_imap, reg_subject_spec : str, instance of `AFQ.definitions.ImageDefinition`, optional # noqa The source image data to be registered. Can either be a Nifti1Image, an ImageFile, or str. - if "b0", "dti_fa_subject", "subject_sls", or "power_map," + if "b0", "dti_fa_subject", "subject_sls", "t1w", or "power_map," image data will be loaded automatically. If "subject_sls" is used, slr registration will be used and reg_template should be "hcp_atlas". 
- Default: "power_map" + Default: "t1w" """ if not isinstance(reg_subject_spec, str)\ and not isinstance(reg_subject_spec, nib.Nifti1Image): @@ -216,6 +217,7 @@ def get_reg_subject(data_imap, "power_map": "csd_pmap", "dti_fa_subject": "dti_fa", "subject_sls": "b0", + "t1w": "t1_masked" } bm = nib.load(data_imap["brain_mask"]) @@ -223,6 +225,11 @@ def get_reg_subject(data_imap, reg_subject_spec = data_imap[filename_dict[reg_subject_spec]] if isinstance(reg_subject_spec, str): img = nib.load(reg_subject_spec) + + if not np.allclose(img.affine, bm.affine) or not np.allclose( + img.get_fdata().shape, bm.get_fdata().shape): + img = resample(img, bm) + bm = bm.get_fdata().astype(bool) masked_data = img.get_fdata() masked_data[~bm] = 0 diff --git a/AFQ/tasks/segmentation.py b/AFQ/tasks/segmentation.py index 21ece86d2..e3b829697 100644 --- a/AFQ/tasks/segmentation.py +++ b/AFQ/tasks/segmentation.py @@ -16,6 +16,7 @@ from AFQ.tasks.utils import get_default_args import AFQ.utils.volume as auv from AFQ._fixes import gaussian_weights +import AFQ.recognition.utils as abu try: from trx.io import load as load_trx @@ -32,6 +33,9 @@ from nibabel.affines import voxel_sizes from nibabel.orientations import aff2axcodes from dipy.io.stateful_tractogram import StatefulTractogram +from dipy.align import resample +from scipy.spatial import cKDTree + import gzip import shutil @@ -58,7 +62,9 @@ def segment(data_imap, mapping_imap, bundle_dict = data_imap["bundle_dict"] reg_template = data_imap["reg_template"] streamlines = tractography_imap["streamlines"] - if streamlines.endswith(".trk") or streamlines.endswith(".tck"): + if streamlines.endswith(".trk") or\ + streamlines.endswith(".tck") or\ + streamlines.endswith(".vtk"): tg = load_tractogram( streamlines, data_imap["dwi"], Space.VOX, bbox_valid_check=False) @@ -113,6 +119,7 @@ def segment(data_imap, mapping_imap, mapping_imap["mapping"], bundle_dict, reg_template, + data_imap["n_cpus"], **segmentation_params) seg_sft = 
aus.SegmentedSFT(bundles, Space.VOX) @@ -272,6 +279,67 @@ def export_density_maps(bundles, data_imap): source=bundles, bundles=list(seg_sft.bundle_names)) +@immlib.calc("endpoint_maps") +@as_file('_desc-endpoints_tractography.nii.gz') +def export_endpoint_maps(bundles, data_imap, endpoint_threshold=3): + """ + full path to a NIfTI file containing endpoint maps for each bundle + + Parameters + ---------- + endpoint_threshold : float, optional + The threshold for the endpoint maps. + If None, no endpoint maps are exported as distance to endpoints maps, + which the user can then threshold as needed. + Default: 3 + """ + seg_sft = aus.SegmentedSFT.fromfile(bundles) + entire_endpoint_map = np.zeros(( + *data_imap["data"].shape[:3], + len(seg_sft.bundle_names))) + + b0_img = nib.load(data_imap["b0"]) + pve_img = nib.load(data_imap["t1w_pve"]) + pve_data = pve_img.get_fdata() + gm = resample(pve_data[..., 1], b0_img.get_fdata(), + pve_img.affine, b0_img.affine).get_fdata() + + R = b0_img.affine[0:3, 0:3] + vox_to_mm = np.mean(np.diag(np.linalg.cholesky(R.T.dot(R)))) + + for ii, bundle_name in enumerate(seg_sft.bundle_names): + bundle_sl = seg_sft.get_bundle(bundle_name) + if len(bundle_sl.streamlines) == 0: + continue + + bundle_sl.to_vox() + + endpoints = np.vstack([s[0] for s in bundle_sl.streamlines] + + [s[-1] for s in bundle_sl.streamlines]) + + shape = b0_img.get_fdata().shape + xv, yv, zv = np.meshgrid(np.arange(shape[0]), + np.arange(shape[1]), + np.arange(shape[2]), indexing='ij') + grid_points = np.column_stack([xv.ravel(), yv.ravel(), zv.ravel()]) + + kdtree = cKDTree(endpoints) + distances, _ = kdtree.query(grid_points) + tractogram_distance = distances.reshape(shape) + + entire_endpoint_map[..., ii] = tractogram_distance * ( + gm > 0.5).astype(np.float32) * vox_to_mm + + if endpoint_threshold is not None: + entire_endpoint_map = np.logical_and( + entire_endpoint_map < endpoint_threshold, + entire_endpoint_map != 0.0).astype(np.float32) + + return 
nib.Nifti1Image( + entire_endpoint_map, data_imap["dwi_affine"]), dict( + source=bundles, bundles=list(seg_sft.bundle_names)) + + @immlib.calc("profiles") @as_file('_desc-profiles_tractography.csv') def tract_profiles(bundles, @@ -434,6 +502,7 @@ def get_segmentation_plan(kwargs): export_bundle_lengths, export_bundles, export_density_maps, + export_endpoint_maps, segment, tract_profiles]) diff --git a/AFQ/tasks/tractography.py b/AFQ/tasks/tractography.py index 5d1f483b0..c3eeb9dc8 100644 --- a/AFQ/tasks/tractography.py +++ b/AFQ/tasks/tractography.py @@ -6,14 +6,13 @@ import dipy.data as dpd import immlib -import multiprocessing from AFQ.tasks.decorators import as_file from AFQ.tasks.utils import with_name from AFQ.definitions.utils import Definition import AFQ.tractography.tractography as aft from AFQ.tasks.utils import get_default_args -from AFQ.definitions.image import ScalarImage +from AFQ.definitions.image import ScalarImage, ThreeTissueImage from AFQ.tractography.utils import gen_seeds, get_percentile_threshold from trx.trx_file_memmap import TrxFile @@ -133,12 +132,16 @@ def export_stop_mask_thresholded(data_imap, stop, tracking_params): full path to a nifti file containing the tractography stop mask thresholded """ - thresh = tracking_params['stop_threshold'] - threshed_data = nib.load(stop).get_fdata() > thresh - stop_mask_desc = dict(source=stop, thresh=thresh) - return nib.Nifti1Image( - threshed_data.astype(np.float32), - data_imap["dwi_affine"]), stop_mask_desc + if isinstance(tracking_params['stop_threshold'], str): + raise ValueError("Cannot generate thresholded " + "stop mask for CMC or ACT") + else: + thresh = tracking_params['stop_threshold'] + threshed_data = nib.load(stop).get_fdata() > thresh + stop_mask_desc = dict(source=stop, thresh=thresh) + return nib.Nifti1Image( + threshed_data.astype(np.float32), + data_imap["dwi_affine"]), stop_mask_desc @immlib.calc("stop") @@ -175,15 +178,14 @@ def streamlines(data_imap, seed, stop, fodf, 
this_tracking_params['seed_mask'] = nib.load(seed).get_fdata() if isinstance(stop, str): this_tracking_params['stop_mask'] = nib.load(stop).get_fdata() + elif isinstance(stop, nib.Nifti1Image): + this_tracking_params['stop_mask'] = stop.get_fdata() else: this_tracking_params['stop_mask'] = stop is_trx = this_tracking_params.get("trx", False) - num_chunks = this_tracking_params.pop("num_chunks", False) - - if num_chunks is True: - num_chunks = multiprocessing.cpu_count() - 1 + num_chunks = data_imap["n_cpus"] if is_trx: start_time = time() @@ -278,10 +280,12 @@ def delete_lazyt(self, id): sft = trx_concatenate(sfts) else: lazyt = aft.track(fodf, **this_tracking_params) + # Chunk size is number of streamlines tracked before saving to disk. sft = TrxFile.from_lazy_tractogram( lazyt, seed, - dtype_dict=dtype_dict) + dtype_dict=dtype_dict, + chunk_size=1e5) n_streamlines = len(sft) else: @@ -364,7 +368,7 @@ def gpu_tractography(data_imap, tracking_params, fodf, seed, stop, sft = gpu_track( data, data_imap["gtab"], - nib.load(seed), nib.load(stop), + seed, stop, tracking_params["odf_model"], sphere, tracking_params["directions"], @@ -444,26 +448,27 @@ def get_tractography_plan(kwargs): if isinstance(kwargs["tracking_params"]["odf_model"], str): kwargs["tracking_params"]["odf_model"] =\ kwargs["tracking_params"]["odf_model"].upper() + if kwargs["tracking_params"]["seed_mask"] is None: kwargs["tracking_params"]["seed_mask"] = ScalarImage( - kwargs["best_scalar"]) - kwargs["tracking_params"]["seed_threshold"] = 0.2 + "wm_gm_interface") + kwargs["tracking_params"]["seed_threshold"] = 0.5 logger.info(( - "No seed mask given, using FA (or first scalar if none are FA)" - "thresholded to 0.2")) + "No seed mask given, using GM-WM interface " + "from 3T prob maps esimated from T1w")) + if kwargs["tracking_params"]["stop_mask"] is None: - kwargs["tracking_params"]["stop_mask"] = ScalarImage( - kwargs["best_scalar"]) - kwargs["tracking_params"]["stop_threshold"] = 0.2 + 
kwargs["tracking_params"]["stop_threshold"] = "ACT" + kwargs["tracking_params"]["stop_mask"] = ThreeTissueImage() logger.info(( - "No stop mask given, using FA (or first scalar if none are FA)" - "thresholded to 0.2")) + "No stop mask given, using ACT " + "and 3T prob maps esimated from T1w")) stop_mask = kwargs["tracking_params"]['stop_mask'] seed_mask = kwargs["tracking_params"]['seed_mask'] odf_model = kwargs["tracking_params"]['odf_model'] - if kwargs["tracking_params"]["tracker"] == "pft": + if isinstance(kwargs["tracking_params"]["stop_threshold"], str): probseg_funcs = stop_mask.get_image_getter("tractography") tractography_tasks["wm_res"] = immlib.calc("pve_wm")(as_file( '_desc-wm_probseg.nii.gz', subfolder="tractography")( diff --git a/AFQ/tasks/viz.py b/AFQ/tasks/viz.py index 72925bb47..1e4acc305 100644 --- a/AFQ/tasks/viz.py +++ b/AFQ/tasks/viz.py @@ -10,7 +10,6 @@ from dipy.align import resample from AFQ.tasks.utils import get_fname, with_name, str_to_desc -import AFQ.utils.volume as auv from AFQ.viz.utils import Viz import AFQ.utils.streamlines as aus from AFQ.utils.path import write_json, drop_extension @@ -20,14 +19,16 @@ logger = logging.getLogger('AFQ') -def _viz_prepare_vol(vol, xform, mapping, scalar_dict): +def _viz_prepare_vol(vol, xform, mapping, scalar_dict, ref): if vol in scalar_dict.keys(): vol = scalar_dict[vol] - if isinstance(vol, str): - vol = nib.load(vol) - vol = vol.get_fdata() + if isinstance(vol, str): - vol = nib.load(vol).get_fdata() + vol = nib.load(vol) + + vol = resample(vol, ref) + + vol = vol.get_fdata() if xform: vol = mapping.transform_inverse(vol) vol[np.isnan(vol)] = 0 @@ -75,11 +76,12 @@ def viz_bundles(base_fname, mapping = mapping_imap["mapping"] scalar_dict = segmentation_imap["scalar_dict"] profiles_file = segmentation_imap["profiles"] - volume = data_imap["masked_b0"] + volume = nib.load(data_imap["t1_masked"]) + t1_affine = nib.load(data_imap["t1_masked"]).affine shade_by_volume = data_imap[best_scalar] - volume 
= _viz_prepare_vol(volume, False, mapping, scalar_dict) shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict) + shade_by_volume, False, mapping, scalar_dict, volume) + volume = _viz_prepare_vol(volume, False, mapping, scalar_dict, volume) flip_axes = [False, False, False] for i in range(3): @@ -102,6 +104,7 @@ def viz_bundles(base_fname, figure = viz_backend.visualize_bundles( segmentation_imap["bundles"], + affine=t1_affine, shade_by_volume=shade_by_volume, sbv_lims=sbv_lims_bundles, include_profiles=(pd.read_csv(profiles_file), best_scalar), @@ -168,17 +171,17 @@ def viz_indivBundle(base_fname, """ mapping = mapping_imap["mapping"] bundle_dict = data_imap["bundle_dict"] - reg_template = data_imap["reg_template"] scalar_dict = segmentation_imap["scalar_dict"] - volume = data_imap["masked_b0"] + volume_img = nib.load(data_imap["t1_masked"]) + t1_affine = nib.load(data_imap["t1_masked"]).affine shade_by_volume = data_imap[best_scalar] profiles = pd.read_csv(segmentation_imap["profiles"]) start_time = time() volume = _viz_prepare_vol( - volume, False, mapping, scalar_dict) + volume_img, False, mapping, scalar_dict, volume_img) shade_by_volume = _viz_prepare_vol( - shade_by_volume, False, mapping, scalar_dict) + shade_by_volume, False, mapping, scalar_dict, volume_img) flip_axes = [False, False, False] for i in range(3): @@ -214,6 +217,7 @@ def viz_indivBundle(base_fname, if len(bundles.get_bundle(bundle_name)) > 0: figure = viz_backend.visualize_bundles( bundles, + affine=t1_affine, shade_by_volume=shade_by_volume, sbv_lims=sbv_lims_indiv, bundle=bundle_name, @@ -231,8 +235,10 @@ def viz_indivBundle(base_fname, name = roi_fname.split("desc-")[1].split("_")[0] if "probseg" in roi_fname: name = f"{name}Probseg" + roi_img = nib.load(roi_fname) + roi_img = resample(roi_img, volume_img) figure = viz_backend.visualize_roi( - roi_fname, + roi_img, name=name, flip_axes=flip_axes, inline=False, @@ -293,6 +299,7 @@ def viz_indivBundle(base_fname, 
inline=False) core_fig = viz_backend.visualize_bundles( segmentation_imap["bundles"], + affine=t1_affine, shade_by_volume=shade_by_volume, sbv_lims=sbv_lims_indiv, bundle=bundle_name, @@ -307,6 +314,7 @@ def viz_indivBundle(base_fname, segmentation_imap["bundles"], bundle_name, best_scalar, + affine=t1_affine, flip_axes=flip_axes, figure=core_fig, include_profile=True) diff --git a/AFQ/tests/test_api.py b/AFQ/tests/test_api.py index b3155bd56..2fbf7ec26 100644 --- a/AFQ/tests/test_api.py +++ b/AFQ/tests/test_api.py @@ -22,9 +22,7 @@ import dipy.tracking.utils as dtu import dipy.tracking.streamline as dts -from dipy.data import get_fnames from dipy.testing.decorators import xvfb_it -from dipy.io.streamline import load_tractogram import AFQ.api.bundle_dict as abd from AFQ.api.group import GroupAFQ, ParallelGroupAFQ @@ -33,8 +31,7 @@ import AFQ.utils.streamlines as aus import AFQ.utils.bin as afb from AFQ.definitions.mapping import SynMap, AffMap, SlrMap, IdentityMap -from AFQ.definitions.image import (RoiImage, PFTImage, ImageFile, ScalarImage, - TemplateImage) +from AFQ.definitions.image import ImageFile, ScalarImage, TemplateImage def touch(fname, times=None): @@ -266,21 +263,13 @@ def test_AFQ_custom_tract(): myafq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', bundle_info=bundle_info, import_tract={ "suffix": "tractography", "scope": "vistasoft" }) - # equivalent ParticipantAFQ version of call may be useful as reference - # myafq = ParticipantAFQ( - # sub_path + "/sub-01_ses-01_dwi.nii.gz", - # sub_path + "/sub-01_ses-01_dwi.bval", - # sub_path + "/sub-01_ses-01_dwi.bvec", - # output_dir=sub_path, - # bundle_info=bundle_names, - # import_tract=sub_path + '/subsampled_tractography.trk') - myafq.export("streamlines") @@ -310,11 +299,13 @@ def test_AFQ_fury(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', + tracking_params={"n_seeds": 250000}, viz_backend_spec="fury") 
myafq.export("all_bundles_figure") -@pytest.mark.nightly_pft +@pytest.mark.nightly_custom def test_AFQ_trx(): tmpdir = tempfile.TemporaryDirectory() bids_path = op.join(tmpdir.name, "stanford_hardi") @@ -323,9 +314,10 @@ def test_AFQ_trx(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', # should throw warning but not error scalars=["dti_fa", "dti_md", ImageFile(suffix="DNE")], - tracking_params={"trx": True}) + tracking_params={"trx": True, "n_seeds": 250000}) myafq.export("all_bundles_figure") @@ -378,6 +370,7 @@ def test_AFQ_data(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', mapping_definition=mapping) npt.assert_equal(nib.load(myafq.export("b0")["01"]).shape, myafq.export("dwi")["01"].shape[:3]) @@ -399,6 +392,7 @@ def test_AFQ_anisotropic(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', min_bval=1990, max_bval=2010, b0_threshold=50, @@ -426,6 +420,7 @@ def test_AFQ_anisotropic(): 'models/sub-01_ses-01_model-csd_param-apm_dwimap.nii.gz')) +@pytest.mark.nightly_basic def test_API_type_checking(): _, bids_path, _ = get_temp_hardi() seed = 2022 @@ -451,6 +446,7 @@ def test_API_type_checking(): myafq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', import_tract=["dwi"]) myafq.export("streamlines") except LazyError as e: @@ -459,24 +455,18 @@ def test_API_type_checking(): raise e del myafq - with pytest.raises( - TypeError, - match="brain_mask_definition must be a Definition"): - myafq = GroupAFQ( - bids_path, - preproc_pipeline='vistasoft', - brain_mask_definition="not a brain mask") - with pytest.raises( ValueError, match=r"No file found with these parameters:\n*"): myafq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', - brain_mask_definition=ImageFile( - suffix='dne_dne', - filters={'scope': 'dne_dne'})) - myafq.export("brain_mask") + t1_pipeline='freesurfer', + 
tracking_params=dict( + seed_mask=ImageFile( + suffix='dne_dne', + filters={'scope': 'dne_dne'}))) + myafq.export("seed_mask") with pytest.raises( TypeError, @@ -486,6 +476,7 @@ def test_API_type_checking(): myafq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', bundle_info=[2, 3]) try: myafq.export("bundle_dict") @@ -502,6 +493,7 @@ def test_API_type_checking(): myafq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', mapping_definition=IdentityMap(), reg_subject_spec="dti_fa_subject", tracking_params={ @@ -526,6 +518,7 @@ def test_API_type_checking(): myafq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', viz_backend_spec="matplotlib") try: myafq.export("viz_backend") @@ -536,6 +529,7 @@ def test_API_type_checking(): del myafq +@pytest.mark.nightly_anisotropic def test_AFQ_slr(): """ Test if API can run using slr map @@ -553,6 +547,7 @@ def test_AFQ_slr(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', reg_subject_spec='subject_sls', reg_template_spec='hcp_atlas', import_tract=op.join( @@ -561,8 +556,8 @@ def test_AFQ_slr(): 'stanford_hardi_tractography', 'full_segmented_cleaned_tractography.trk'), segmentation_params={ - "dist_to_waypoint": 10, - "parallel_segmentation": {"engine": "serial"}}, + "dist_to_waypoint": 10}, + n_cpus=1, bundle_info=bd, mapping_definition=SlrMap(slr_kwargs={ "rng": np.random.RandomState(seed)})) @@ -582,9 +577,11 @@ def test_AFQ_reco(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', viz_backend_spec="plotly", profile_weights="median", bundle_info=abd.reco_bd(16), + tracking_params={"n_seeds": 1e6}, segmentation_params={ 'rng': 42}) @@ -610,6 +607,7 @@ def test_AFQ_reco80(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', tracking_params=tracking_params, bundle_info=abd.reco_bd(16), 
segmentation_params={ @@ -620,10 +618,17 @@ def test_AFQ_reco80(): npt.assert_(len(seg_sft.get_bundle('CCMid').streamlines) > 0) +@pytest.mark.nightly_reco80 def test_AFQ_pydra(): - _, bids_path = afd.fetch_hbn_preproc(["NDARAA948VFH", "NDARAV554TP2"]) - pga = ParallelGroupAFQ(bids_path, preproc_pipeline="qsiprep") + participants = ["NDARAA948VFH", "NDARAV554TP2"] + _, bids_path = afd.fetch_hbn_preproc(participants) + pga = ParallelGroupAFQ( + bids_path, + output_dir=op.join(bids_path, 'derivatives', 'pydra_afq'), + participant_labels=participants, + preproc_pipeline="qsiprep") pga.export("dti_fa") + pga.export("wm_gm_interface") def test_AFQ_filterb(): @@ -631,63 +636,11 @@ def test_AFQ_filterb(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', max_bval=1000) myafq.export("b0") -@pytest.mark.nightly_pft -def test_AFQ_pft(): - """ - Test pft interface for AFQ - """ - _, bids_path, sub_path = get_temp_hardi() - - bundle_names = abd.default18_bd()[ - "Left Superior Longitudinal", - "Right Superior Longitudinal", - "Left Arcuate", - "Right Arcuate", - "Left Corticospinal", - "Right Corticospinal", - "Forceps Minor"] - - f_pve_csf, f_pve_gm, f_pve_wm = get_fnames('stanford_pve_maps') - os.rename(f_pve_wm, op.join(sub_path, - "sub-01_ses-01_label-WM_probseg.nii.gz")) - os.rename(f_pve_gm, op.join(sub_path, - "sub-01_ses-01_label-GM_probseg.nii.gz")) - os.rename(f_pve_csf, op.join(sub_path, - "sub-01_ses-01_label-CSF_probseg.nii.gz")) - - stop_mask = PFTImage( - ImageFile(suffix="probseg", filters={"label": "WM"}), - ImageFile(suffix="probseg", filters={"label": "GM"}), - ImageFile(suffix="probseg", filters={"label": "CSF"})) - t_output_dir = tempfile.TemporaryDirectory() - - myafq = GroupAFQ( - bids_path, - preproc_pipeline='vistasoft', - bundle_info=bundle_names, - output_dir=t_output_dir.name, - tracking_params={ - "stop_mask": stop_mask, - "stop_threshold": "CMC", - "tracker": "pft", - "maxlen": 150, - }) - sl_file = 
myafq.export("streamlines")["01"] - dwi_file = myafq.export("dwi")["01"] - sls = load_tractogram( - sl_file, - dwi_file, - bbox_valid_check=False, - trk_header_check=False).streamlines - for sl in sls: - # double the maxlen, due to step size of 0.5 - assert len(sl) <= 300 - - @pytest.mark.nightly_custom def test_AFQ_custom_subject_reg(): """ @@ -708,6 +661,7 @@ def test_AFQ_custom_subject_reg(): b0_file = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', bundle_info=bundle_info).export("b0")["01"] # make a different temporary directly to test this custom file in @@ -718,6 +672,7 @@ def test_AFQ_custom_subject_reg(): myafq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', bundle_info=bundle_info, reg_template_spec="mni_T2", reg_subject_spec=ImageFile( @@ -736,6 +691,7 @@ def test_AFQ_FA(): myafq = GroupAFQ( bids_path=bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', reg_template_spec='dti_fa_template', reg_subject_spec='dti_fa_subject') myafq.export("rois") @@ -769,27 +725,6 @@ def test_auto_cli(): afb.parse_config_run_afq(config_file, arg_dict, False) -@pytest.mark.skip(reason="causes segmentation fault") -def test_run_using_auto_cli(): - tmpdir, bids_path, _ = get_temp_hardi() - config_file = op.join(tmpdir.name, 'test.toml') - - arg_dict = afb.func_dict_to_arg_dict() - - # set our custom defaults for the toml file - # It is easier to edit them here, than to parse the file and edit them - # after the file is written - arg_dict['BIDS_PARAMS']['bids_path']['default'] = bids_path - arg_dict['BIDS_PARAMS']['dmriprep']['default'] = 'vistasoft' - arg_dict['DATA']['bundle_info']['default'] = abd.default18_bd()[( - "Left Corticospinal")] - arg_dict['TRACTOGRAPHY_PARAMS']['n_seeds']['default'] = 500 - arg_dict['TRACTOGRAPHY_PARAMS']['random_seeds']['default'] = True - - afb.generate_config(config_file, arg_dict, False) - afb.parse_config_run_afq(config_file, arg_dict, False) - - def 
test_AFQ_data_waypoint(): """ Test with some actual data again, this time for track segmentation @@ -805,6 +740,9 @@ def test_AFQ_data_waypoint(): vista_folder = op.join( bids_path, "derivatives/vistasoft/sub-01/ses-01/dwi") + freesurfer_folder = op.join( + bids_path, + "derivatives/freesurfer/sub-01/ses-01/anat") # Prepare LV1 ROI lv1_files, lv1_folder = afd.fetch_stanford_hardi_lv1() @@ -838,8 +776,8 @@ def test_AFQ_data_waypoint(): } tracking_params = dict(odf_model="csd", - seed_mask=RoiImage(), - n_seeds=200, + n_seeds=2000, + directions="prob", # for efficiency random_seeds=True, rng_seed=42) segmentation_params = dict(return_idx=True) @@ -850,6 +788,7 @@ def test_AFQ_data_waypoint(): op.join(vista_folder, "sub-01_ses-01_dwi.nii.gz"), op.join(vista_folder, "sub-01_ses-01_dwi.bval"), op.join(vista_folder, "sub-01_ses-01_dwi.bvec"), + op.join(freesurfer_folder, "sub-01_ses-01_T1w.nii.gz"), afq_folder, bundle_info=bundle_info, scalars=[ @@ -892,14 +831,14 @@ def test_AFQ_data_waypoint(): seg_sft = aus.SegmentedSFT.fromfile( myafq.export("bundles")) npt.assert_(len(seg_sft.get_bundle( - 'Left Corticospinal').streamlines) > 0) + 'Left Superior Longitudinal').streamlines) > 0) # Test bundles exporting: myafq.export("indiv_bundles") assert op.exists(op.join( myafq.export("output_dir"), 'bundles', - 'sub-01_ses-01_desc-LeftCorticospinal_tractography.trk')) # noqa + 'sub-01_ses-01_desc-RightSuperiorLongitudinal_tractography.trk')) # noqa tract_profile_fname = myafq.export("profiles") tract_profiles = pd.read_csv(tract_profile_fname) @@ -933,11 +872,12 @@ def test_AFQ_data_waypoint(): # Set up config to use the same parameters as above: # ROI mask needs to be put in quotes in config - tracking_params = dict(odf_model="CSD", - seed_mask="RoiImage()", - n_seeds=200, - random_seeds=True, - rng_seed=42) + tracking_params = dict( + odf_model="CSD", + n_seeds=2000, + directions="prob", # for efficiency + random_seeds=True, + rng_seed=42) bundle_dict_as_str = ( 
'default18_bd()[' '"Left Superior Longitudinal",' @@ -953,7 +893,8 @@ def test_AFQ_data_waypoint(): config = dict( BIDS_PARAMS=dict( bids_path=bids_path, - preproc_pipeline='vistasoft'), + preproc_pipeline='vistasoft', + t1_pipeline='freesurfer',), DATA=dict( bundle_info=bundle_dict_as_str), SEGMENTATION=dict( diff --git a/AFQ/tests/test_msmt.py b/AFQ/tests/test_msmt.py new file mode 100644 index 000000000..fedafa57d --- /dev/null +++ b/AFQ/tests/test_msmt.py @@ -0,0 +1,458 @@ +# Original source: github.com/dipy/dipy +# Copyright (c) 2008-2025, dipy developers. +# Licensed under the 3-clause BSD license +# Modified by John Kruper for numba MSMT testing +# Will ultimately be moved upstream to DIPY + +import warnings + +import numpy as np +import numpy.testing as npt + +from dipy.core.gradients import GradientTable +from dipy.data import default_sphere, get_3shell_gtab +from dipy.reconst import shm +from dipy.reconst.mcsd import ( + auto_response_msmt, + mask_for_response_msmt, + multi_shell_fiber_response, + response_from_mask_msmt, +) +from AFQ.models.msmt import MultiShellDeconvModel +from dipy.sims.voxel import add_noise, multi_tensor, single_tensor +from dipy.testing.decorators import set_random_number_generator + +wm_response = np.array( + [ + [1.7e-3, 0.4e-3, 0.4e-3, 25.0], + [1.7e-3, 0.4e-3, 0.4e-3, 25.0], + [1.7e-3, 0.4e-3, 0.4e-3, 25.0], + ] +) +csf_response = np.array( + [ + [3.0e-3, 3.0e-3, 3.0e-3, 100.0], + [3.0e-3, 3.0e-3, 3.0e-3, 100.0], + [3.0e-3, 3.0e-3, 3.0e-3, 100.0], + ] +) +gm_response = np.array( + [ + [4.0e-4, 4.0e-4, 4.0e-4, 40.0], + [4.0e-4, 4.0e-4, 4.0e-4, 40.0], + [4.0e-4, 4.0e-4, 4.0e-4, 40.0], + ] +) + + +def get_test_data(rng): + gtab = get_3shell_gtab() + evals_list = [ + np.array([1.7e-3, 0.4e-3, 0.4e-3]), + np.array([6.0e-4, 4.0e-4, 4.0e-4]), + np.array([3.0e-3, 3.0e-3, 3.0e-3]), + ] + s0 = [0.8, 1, 4] + signals = [single_tensor(gtab, x[0], evals=x[1]) for x in zip(s0, evals_list)] + tissues = [0, 0, 2, 0, 1, 0, 0, 1, 2] # wm=0, gm=1, 
csf=2 + data = [add_noise(signals[tissue], 80, s0[0], rng=rng) for tissue in tissues] + data = np.asarray(data).reshape((3, 3, 1, len(signals[0]))) + tissues = np.asarray(tissues).reshape((3, 3, 1)) + masks = [np.where(tissues == x, 1, 0) for x in range(3)] + responses = [np.concatenate((x[0], [x[1]])) for x in zip(evals_list, s0)] + return gtab, data, masks, responses + + +def _expand(m, iso, coeff): + params = np.zeros(len(m)) + params[m == 0] = coeff[iso:] + params = np.concatenate([coeff[:iso], params]) + return params + + +def test_mcsd_model_delta(): + sh_order_max = 8 + gtab = get_3shell_gtab() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + response = multi_shell_fiber_response( + sh_order_max, [0, 1000, 2000, 3500], wm_response, gm_response, csf_response + ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + model = MultiShellDeconvModel(gtab, response) + iso = response.iso + + theta, phi = default_sphere.theta, default_sphere.phi + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + B = shm.real_sh_descoteaux_from_index( + response.m_values, response.l_values, theta[:, None], phi[:, None] + ) + + wm_delta = model.delta.copy() + # set isotropic components to zero + wm_delta[:iso] = 0.0 + wm_delta = _expand(model.m_values, iso, wm_delta) + + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + for i, s in enumerate([0, 1000, 2000, 3500]): + g = GradientTable(default_sphere.vertices * s) + signal = model.predict(wm_delta, gtab=g) + expected = np.dot(response.response[i, iso:], B.T) + npt.assert_array_almost_equal(signal, expected) + + with 
warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + signal = model.predict(wm_delta, gtab=gtab) + fit = model.fit(signal) + m = model.m_values + npt.assert_array_almost_equal(fit.shm_coeff[m != 0], 0.0, 2) + + +def test_MultiShellDeconvModel_response(): + gtab = get_3shell_gtab() + + sh_order_max = 8 + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + response = multi_shell_fiber_response( + sh_order_max, [0, 1000, 2000, 3500], wm_response, gm_response, csf_response + ) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + model_1 = MultiShellDeconvModel(gtab, response, sh_order_max=sh_order_max) + responses = np.array([wm_response, gm_response, csf_response]) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + model_2 = MultiShellDeconvModel(gtab, responses, sh_order_max=sh_order_max) + response_1 = model_1.response.response + response_2 = model_2.response.response + npt.assert_array_almost_equal(response_1, response_2, 0) + + npt.assert_raises(ValueError, MultiShellDeconvModel, gtab, np.ones((4, 3, 4))) + npt.assert_raises( + ValueError, MultiShellDeconvModel, gtab, np.ones((3, 3, 4)), iso=3 + ) + + +def test_MultiShellDeconvModel(): + gtab = get_3shell_gtab() + + mevals = np.array([wm_response[0, :3], wm_response[0, :3]]) + angles = [(0, 0), (60, 0)] + + S_wm, sticks = multi_tensor( + gtab, + mevals, + S0=wm_response[0, 3], + angles=angles, + fractions=[30.0, 70.0], + snr=None, + ) + S_gm = gm_response[0, 3] * np.exp(-gtab.bvals * gm_response[0, 0]) + S_csf = csf_response[0, 3] * np.exp(-gtab.bvals * csf_response[0, 0]) + + sh_order_max = 8 + with 
warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + response = multi_shell_fiber_response( + sh_order_max, [0, 1000, 2000, 3500], wm_response, gm_response, csf_response + ) + model = MultiShellDeconvModel(gtab, response) + vf = [0.325, 0.2, 0.475] + signal = sum(i * j for i, j in zip(vf, [S_csf, S_gm, S_wm])) + fit = model.fit(signal) + + # Testing both ways to predict + S_pred_fit = fit.predict() + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + S_pred_model = model.predict(fit.all_shm_coeff) + + npt.assert_array_almost_equal(S_pred_fit, S_pred_model, 0) + npt.assert_array_almost_equal(S_pred_fit, signal, 0) + + +def test_MSDeconvFit(): + gtab = get_3shell_gtab() + + mevals = np.array([wm_response[0, :3], wm_response[0, :3]]) + angles = [(0, 0), (60, 0)] + + S_wm, sticks = multi_tensor( + gtab, + mevals, + S0=wm_response[0, 3], + angles=angles, + fractions=[30.0, 70.0], + snr=None, + ) + S_gm = gm_response[0, 3] * np.exp(-gtab.bvals * gm_response[0, 0]) + S_csf = csf_response[0, 3] * np.exp(-gtab.bvals * csf_response[0, 0]) + + sh_order_max = 8 + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + response = multi_shell_fiber_response( + sh_order_max, [0, 1000, 2000, 3500], wm_response, gm_response, csf_response + ) + model = MultiShellDeconvModel(gtab, response) + vf = [0.325, 0.2, 0.475] + signal = sum(i * j for i, j in zip(vf, [S_csf, S_gm, S_wm])) + fit = model.fit(signal) + + # Testing volume fractions + npt.assert_array_almost_equal(fit.volume_fractions, vf, 1) + + +def test_multi_shell_fiber_response(): + sh_order_max = 8 + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + 
category=PendingDeprecationWarning, + ) + response = multi_shell_fiber_response( + sh_order_max, [0, 1000, 2000, 3500], wm_response, gm_response, csf_response + ) + + npt.assert_equal(response.response.shape, (4, 7)) + + btens = ["LTE", "PTE", "STE", "CTE"] + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=shm.descoteaux07_legacy_msg, + category=PendingDeprecationWarning, + ) + response = multi_shell_fiber_response( + sh_order_max, + [0, 1000, 2000, 3500], + wm_response, + gm_response, + csf_response, + btens=btens, + ) + + npt.assert_equal(response.response.shape, (4, 7)) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always", category=PendingDeprecationWarning) + response = multi_shell_fiber_response( + sh_order_max, [1000, 2000, 3500], wm_response, gm_response, csf_response + ) + # Test that the number of warnings raised is greater than 1, with + # deprecation warnings being raised from using legacy SH bases as well + # as a warning from multi_shell_fiber_response + npt.assert_(len(w) > 1) + # The last warning in list is the one from multi_shell_fiber_response + npt.assert_(issubclass(w[-1].category, UserWarning)) + npt.assert_("""No b0 given. 
Proceeding either way.""" in str(w[-1].message)) + npt.assert_equal(response.response.shape, (3, 7)) + + +@set_random_number_generator() +def test_mask_for_response_msmt(rng): + gtab, data, masks_gt, _ = get_test_data(rng) + + with warnings.catch_warnings(record=True) as w: + wm_mask, gm_mask, csf_mask = mask_for_response_msmt( + gtab, + data, + roi_center=None, + roi_radii=(1, 1, 0), + wm_fa_thr=0.7, + gm_fa_thr=0.3, + csf_fa_thr=0.15, + gm_md_thr=0.001, + csf_md_thr=0.0032, + ) + + npt.assert_equal(len(w), 1) + npt.assert_(issubclass(w[0].category, UserWarning)) + npt.assert_("""Some b-values are higher than 1200.""" in str(w[0].message)) + + # Verifies that masks are not empty: + masks_sum = int(np.sum(wm_mask) + np.sum(gm_mask) + np.sum(csf_mask)) + npt.assert_equal(masks_sum != 0, True) + + npt.assert_array_almost_equal(masks_gt[0], wm_mask) + npt.assert_array_almost_equal(masks_gt[1], gm_mask) + npt.assert_array_almost_equal(masks_gt[2], csf_mask) + + +@set_random_number_generator() +def test_mask_for_response_msmt_nvoxels(rng): + gtab, data, _, _ = get_test_data(rng) + + with warnings.catch_warnings(record=True) as w: + wm_mask, gm_mask, csf_mask = mask_for_response_msmt( + gtab, + data, + roi_center=None, + roi_radii=(1, 1, 0), + wm_fa_thr=0.7, + gm_fa_thr=0.3, + csf_fa_thr=0.15, + gm_md_thr=0.001, + csf_md_thr=0.0032, + ) + + npt.assert_equal(len(w), 1) + npt.assert_(issubclass(w[0].category, UserWarning)) + npt.assert_("""Some b-values are higher than 1200.""" in str(w[0].message)) + + wm_nvoxels = np.sum(wm_mask) + gm_nvoxels = np.sum(gm_mask) + csf_nvoxels = np.sum(csf_mask) + npt.assert_equal(wm_nvoxels, 5) + npt.assert_equal(gm_nvoxels, 2) + npt.assert_equal(csf_nvoxels, 2) + + with warnings.catch_warnings(record=True) as w: + wm_mask, gm_mask, csf_mask = mask_for_response_msmt( + gtab, + data, + roi_center=None, + roi_radii=(1, 1, 0), + wm_fa_thr=1, + gm_fa_thr=0, + csf_fa_thr=0, + gm_md_thr=0, + csf_md_thr=0, + ) + npt.assert_equal(len(w), 6) + 
npt.assert_(issubclass(w[0].category, UserWarning)) + npt.assert_("""Some b-values are higher than 1200.""" in str(w[0].message)) + npt.assert_("No voxel with a FA higher than 1 were found" in str(w[1].message)) + npt.assert_("No voxel with a FA lower than 0 were found" in str(w[2].message)) + npt.assert_("No voxel with a MD lower than 0 were found" in str(w[3].message)) + npt.assert_("No voxel with a FA lower than 0 were found" in str(w[4].message)) + npt.assert_("No voxel with a MD lower than 0 were found" in str(w[5].message)) + + wm_nvoxels = np.sum(wm_mask) + gm_nvoxels = np.sum(gm_mask) + csf_nvoxels = np.sum(csf_mask) + npt.assert_equal(wm_nvoxels, 0) + npt.assert_equal(gm_nvoxels, 0) + npt.assert_equal(csf_nvoxels, 0) + + +@set_random_number_generator() +def test_response_from_mask_msmt(rng): + gtab, data, masks_gt, responses_gt = get_test_data(rng) + + response_wm, response_gm, response_csf = response_from_mask_msmt( + gtab, data, masks_gt[0], masks_gt[1], masks_gt[2], tol=20 + ) + + # Verifying that csf's response is greater than gm's + npt.assert_equal(np.sum(response_csf[:, :3]) > np.sum(response_gm[:, :3]), True) + # Verifying that csf and gm are described by spheres + npt.assert_almost_equal(response_csf[:, 1], response_csf[:, 2]) + npt.assert_allclose(response_csf[:, 0], response_csf[:, 1], rtol=1, atol=0) + npt.assert_almost_equal(response_gm[:, 1], response_gm[:, 2]) + npt.assert_allclose(response_gm[:, 0], response_gm[:, 1], rtol=1, atol=0) + # Verifying that wm is anisotropic in one direction + npt.assert_almost_equal(response_wm[:, 1], response_wm[:, 2]) + npt.assert_equal(response_wm[:, 0] > 2.5 * response_wm[:, 1], True) + + # Verifying with ground truth for the first bvalue + npt.assert_array_almost_equal(response_wm[0], responses_gt[0], 1) + npt.assert_array_almost_equal(response_gm[0], responses_gt[1], 1) + npt.assert_array_almost_equal(response_csf[0], responses_gt[2], 1) + + +@set_random_number_generator() +def 
test_auto_response_msmt(rng): + gtab, data, _, _ = get_test_data(rng) + + with warnings.catch_warnings(record=True) as w: + response_auto_wm, response_auto_gm, response_auto_csf = auto_response_msmt( + gtab, + data, + tol=20, + roi_center=None, + roi_radii=(1, 1, 0), + wm_fa_thr=0.7, + gm_fa_thr=0.3, + csf_fa_thr=0.15, + gm_md_thr=0.001, + csf_md_thr=0.0032, + ) + + npt.assert_(issubclass(w[0].category, UserWarning)) + npt.assert_( + """Some b-values are higher than 1200. + The DTI fit might be affected. It is advised to use + mask_for_response_msmt with bvalues lower than 1200, followed by + response_from_mask_msmt with all bvalues to overcome this.""" + in str(w[0].message) + ) + + mask_wm, mask_gm, mask_csf = mask_for_response_msmt( + gtab, + data, + roi_center=None, + roi_radii=(1, 1, 0), + wm_fa_thr=0.7, + gm_fa_thr=0.3, + csf_fa_thr=0.15, + gm_md_thr=0.001, + csf_md_thr=0.0032, + ) + + response_from_mask_wm, response_from_mask_gm, response_from_mask_csf = ( + response_from_mask_msmt(gtab, data, mask_wm, mask_gm, mask_csf, tol=20) + ) + + npt.assert_array_equal(response_auto_wm, response_from_mask_wm) + npt.assert_array_equal(response_auto_gm, response_from_mask_gm) + npt.assert_array_equal(response_auto_csf, response_from_mask_csf) diff --git a/AFQ/tests/test_nn.py b/AFQ/tests/test_nn.py new file mode 100644 index 000000000..98b095b4b --- /dev/null +++ b/AFQ/tests/test_nn.py @@ -0,0 +1,20 @@ +import nibabel as nib +import os.path as op +import tempfile +import numpy.testing as npt + +from AFQ.nn.brainchop import run_brainchop +import AFQ.data.fetch as afd + +def test_run_brainchop(): + tmpdir = tempfile.mkdtemp() + afd.organize_stanford_data(path=tmpdir) + + t1_path = op.join( + tmpdir, + ( + "stanford_hardi/derivatives/freesurfer/" + "sub-01/ses-01/anat/sub-01_ses-01_T1w.nii.gz")) + chopped_brain = run_brainchop(nib.load(t1_path), "mindgrab") + + npt.assert_(chopped_brain.get_fdata().sum() > 200000) diff --git a/AFQ/tests/test_tractography.py 
b/AFQ/tests/test_tractography.py index 14de5e980..43b17cb54 100644 --- a/AFQ/tests/test_tractography.py +++ b/AFQ/tests/test_tractography.py @@ -84,10 +84,9 @@ def test_pft_tracking(): ["DTI", "CSD"]): img = nib.load(fdata) data_shape = img.shape - data_affine = img.affine - pve_wm_data = nib.Nifti1Image(np.ones(data_shape[:3]), img.affine) - pve_gm_data = nib.Nifti1Image(np.zeros(data_shape[:3]), img.affine) - pve_csf_data = nib.Nifti1Image(np.zeros(data_shape[:3]), img.affine) + pve_wm_data = np.ones(data_shape[:3]) + pve_gm_data = np.ones(data_shape[:3]) + pve_csf_data = np.ones(data_shape[:3]) stop_mask = (pve_wm_data, pve_gm_data, pve_csf_data) for directions in ["det", "prob"]: diff --git a/AFQ/tractography/gputractography.py b/AFQ/tractography/gputractography.py index ebf8941dd..273150028 100644 --- a/AFQ/tractography/gputractography.py +++ b/AFQ/tractography/gputractography.py @@ -1,6 +1,7 @@ import cuslines import numpy as np +import nibabel as nib from math import radians from tqdm import tqdm import logging @@ -8,6 +9,7 @@ from dipy.reconst.shm import OpdtModel, CsaOdfModel from dipy.reconst import shm from dipy.io.stateful_tractogram import StatefulTractogram, Space +from dipy.align import resample from nibabel.streamlines.array_sequence import concatenate from nibabel.streamlines.tractogram import Tractogram @@ -21,7 +23,7 @@ # Modified from https://github.com/dipy/GPUStreamlines/blob/master/run_dipy_gpu.py -def gpu_track(data, gtab, seed_img, stop_img, +def gpu_track(data, gtab, seed_path, stop_path, odf_model, sphere, directions, seed_threshold, stop_threshold, thresholds_as_percentages, max_angle, step_size, n_seeds, random_seeds, rng_seed, use_trx, ngpus, @@ -35,18 +37,18 @@ def gpu_track(data, gtab, seed_img, stop_img, DWI data. gtab : GradientTable The gradient table. - seed_img : Nifti1Image + seed_path : str Float or binary mask describing the ROI within which we seed for tracking. 
- stop_img : Nifti1Image + stop_path : str A float or binary mask that determines a stopping criterion (e.g. FA). odf_model : str, optional One of {"OPDT", "CSA"} seed_threshold : float - The value of the seed_img above which tracking is seeded. + The value of the seed_path above which tracking is seeded. stop_threshold : float - The value of the stop_img below which tracking is + The value of the stop_path below which tracking is terminated. thresholds_as_percentages : bool Interpret seed_threshold and stop_threshold as percentages of the @@ -78,7 +80,23 @@ def gpu_track(data, gtab, seed_img, stop_img, Returns ------- """ - sh_order_max = 8 + seed_img = nib.load(seed_path) + + # Roughly handle ACT/CMC for now + if isinstance(stop_threshold, str): + stop_threshold = 0.3 + stop_img = stop_path[0] # Grab WM + + if isinstance(stop_img, str): + stop_img = nib.load(stop_img) + + stop_img = resample( + stop_img.get_fdata(), + seed_img.get_fdata(), + moving_affine=stop_img.affine, + static_affine=seed_img.affine) + else: + stop_img = nib.load(stop_path) seed_data = seed_img.get_fdata() stop_data = stop_img.get_fdata() @@ -89,7 +107,25 @@ def gpu_track(data, gtab, seed_img, stop_img, theta = sphere.theta phi = sphere.phi - sampling_matrix, _, _ = shm.real_sym_sh_basis(sh_order_max, theta, phi) + + if directions == "boot": + sh_order_max = 6 + full_basis = False + else: + # Determine sh_order and full_basis + sym_order = (-3.0 + np.sqrt(1.0 + 8.0 * data.shape[3])) / 2.0 + if sym_order.is_integer(): + sh_order_max = sym_order + full_basis = False + full_order = np.sqrt(data.shape[3]) - 1.0 + if full_order.is_integer(): + sh_order_max = full_order + full_basis = True + + sampling_matrix, _, _ = shm.real_sh_descoteaux( + sh_order_max, theta, phi, + full_basis=full_basis, + legacy=False) if directions == "boot": if odf_model.lower() == "opdt": diff --git a/AFQ/tractography/tractography.py b/AFQ/tractography/tractography.py index 7dfeb7739..03dac05e5 100644 --- 
a/AFQ/tractography/tractography.py +++ b/AFQ/tractography/tractography.py @@ -13,6 +13,7 @@ from dipy.tracking.stopping_criterion import (ThresholdStoppingCriterion, CmcStoppingCriterion, ActStoppingCriterion) +from dipy.reconst import shm from nibabel.streamlines.tractogram import LazyTractogram @@ -24,10 +25,10 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, seed_mask=None, seed_threshold=0.5, thresholds_as_percentages=False, - n_seeds=1, random_seeds=False, rng_seed=None, stop_mask=None, + n_seeds=2000000, random_seeds=True, rng_seed=None, stop_mask=None, stop_threshold=0.5, step_size=0.5, minlen=50, maxlen=250, odf_model="CSD", basis_type="descoteaux07", legacy=True, - tracker="local", trx=False): + tracker="pft", trx=False): """ Tractography @@ -57,10 +58,10 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, voxel on each dimension (for example, 2 => [2, 2, 2]). If this is a 2D array, these are the coordinates of the seeds. Unless random_seeds is set to True, in which case this is the total number of random seeds - to generate within the mask. Default: 1 + to generate within the mask. Default: 2000000 random_seeds : bool Whether to generate a total of n_seeds random seeds in the mask. - Default: False. + Default: True rng_seed : int random seed used to generate random seeds if random_seeds is set to True. Default: None @@ -108,7 +109,7 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, tracker : str, optional Which strategy to use in tracking. This can be the standard local tracking ("local") or Particle Filtering Tracking ([Girard2014]_). - One of {"local", "pft"}. Default: "local" + One of {"local", "pft"}. Default: "pft" trx : bool, optional Whether to return the streamlines compatible with input to TRX file (i.e., as a LazyTractogram class instance). 
@@ -174,65 +175,59 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, from_lower_triangular(model_params)) odf = tensor_odf(evals, evecs, sphere) dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere) + elif "AODF" in odf_model: + sh_order = shm.order_from_ncoef( + model_params.shape[3], full_basis=True) + pmf = shm.sh_to_sf( + model_params, sphere, + sh_order_max=sh_order, full_basis=True) + pmf[pmf < 0] = 0 + dg = dg.from_pmf( + np.asarray(pmf, dtype=float), + max_angle=max_angle, sphere=sphere) else: dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere, basis_type=basis_type, legacy=legacy) - if tracker == "local": - if stop_mask is None: - stop_mask = np.ones(params_img.shape[:3]) + if stop_mask is None: + stop_mask = np.ones(params_img.shape[:3]) - if len(np.unique(stop_mask)) <= 2: - stopping_criterion = ThresholdStoppingCriterion(stop_mask, - 0.5) - else: - if thresholds_as_percentages: - stop_threshold = get_percentile_threshold( - stop_mask, stop_threshold) - stop_mask_copy = np.copy(stop_mask) - stop_thresh_copy = np.copy(stop_threshold) - stopping_criterion = ThresholdStoppingCriterion(stop_mask_copy, - stop_thresh_copy) - - my_tracker = LocalTracking - - elif tracker == "pft": - if not isinstance(stop_threshold, str): - raise RuntimeError( - "You are using PFT tracking, but did not provide a string ", - "'stop_threshold' input. ", - "Possible inputs are: 'CMC' or 'ACT'") + if isinstance(stop_threshold, str): if not (isinstance(stop_mask, Iterable) and len(stop_mask) == 3): raise RuntimeError( - "You are using PFT tracking, but did not provide a length " + "You are using CMC/ACT stropping, but did not provide a length " "3 iterable for `stop_mask`. 
" "Expected a (pve_wm, pve_gm, pve_csf) tuple.") pves = [] - pve_imgs = [] - for ii, pve in enumerate(stop_mask): + pve_affines = [] + for pve in stop_mask: if isinstance(pve, str): - img = nib.load(pve) + seg_data = nib.load(pve).get_fdata() + seg_affine = nib.load(pve).affine + elif isinstance(pve, nib.Nifti1Image): + seg_data = pve.get_fdata() + seg_affine = pve.affine else: - img = pve - pve_imgs.append(img) - pves.append(pve_imgs[-1].get_fdata()) + seg_data = pve + seg_affine = params_img.affine + pves.append(seg_data) + pve_affines.append(seg_affine) - pve_wm_img, pve_gm_img, pve_csf_img = pve_imgs pve_wm_data, pve_gm_data, pve_csf_data = pves + pve_wm_affine, pve_gm_affine, pve_csf_affine = pve_affines pve_wm_data = resample( pve_wm_data, model_params[..., 0], - moving_affine=pve_wm_img.affine, + moving_affine=pve_wm_affine, static_affine=params_img.affine).get_fdata() pve_gm_data = resample( pve_gm_data, model_params[..., 0], - moving_affine=pve_gm_img.affine, + moving_affine=pve_gm_affine, static_affine=params_img.affine).get_fdata() pve_csf_data = resample( pve_csf_data, model_params[..., 0], - moving_affine=pve_csf_img.affine, + moving_affine=pve_csf_affine, static_affine=params_img.affine).get_fdata() - my_tracker = ParticleFilteringTracking if stop_threshold == "CMC": stopping_criterion = CmcStoppingCriterion.from_pve( pve_wm_data, @@ -246,6 +241,32 @@ def track(params_file, directions="prob", max_angle=30., sphere=None, pve_wm_data, pve_gm_data, pve_csf_data) + else: + if len(stop_mask) == 3: + raise RuntimeError( + "You are not using CMC/ACT stropping, but provided tissue " + "probability maps in `stop_mask`. 
Please provide a single " + "3D array for `stop_mask` or use CMC/ACT") + + if len(np.unique(stop_mask)) <= 2: + stopping_criterion = ThresholdStoppingCriterion(stop_mask, + 0.5) + else: + if thresholds_as_percentages: + stop_threshold = get_percentile_threshold( + stop_mask, stop_threshold) + stop_mask_copy = np.copy(stop_mask) + stop_thresh_copy = np.copy(stop_threshold) + stopping_criterion = ThresholdStoppingCriterion(stop_mask_copy, + stop_thresh_copy) + + if tracker == "local": + my_tracker = LocalTracking + elif tracker == "pft": + my_tracker = ParticleFilteringTracking + else: + raise ValueError(f"Unrecognized tracker '{tracker}'. Must be one of " + "{'local', 'pft'}.") logger.info( f"Tracking with {len(seeds)} seeds, 2 directions per seed...") diff --git a/AFQ/utils/path.py b/AFQ/utils/path.py index 6075e4212..f3312072b 100644 --- a/AFQ/utils/path.py +++ b/AFQ/utils/path.py @@ -56,19 +56,35 @@ def space_from_fname(dwi_fname): def apply_cmd_to_afq_derivs( derivs_dir, base_fname, cmd="rm", exception_file_names=[], suffix="", - dependent_on=None): - if dependent_on is None: - dependent_on_list = ["dwi", "trk", "rec", "prof"] - elif dependent_on.lower() == "track": - dependent_on_list = ["trk", "rec", "prof"] - elif dependent_on.lower() == "recog": - dependent_on_list = ["rec", "prof"] - elif dependent_on.lower() == "prof": - dependent_on_list = ["prof"] - else: + dependent_on=None, up_to=None): + dependent_options = { + None: ["dwi", "trk", "rec", "prof"], + "track": ["trk", "rec", "prof"], + "recog": ["rec", "prof"], + "prof": ["prof"] + } + if dependent_on is not None: + dependent_on = dependent_on.lower() + dependent_on_list = dependent_options.get(dependent_on) + if dependent_on_list is None: raise ValueError(( - "dependent_on must be one of " - "None, 'track', 'recog', 'prof'.")) + "dependent_on must be one of None, " + "'track', 'recog', 'prof'.")) + + removal_patterns = { + "track": ["trk", "rec", "prof"], + "recog": ["rec", "prof"], + "prof": ["prof"] 
+ } + if up_to is not None: + up_to = up_to.lower() + if up_to in removal_patterns: + dependent_on_list = [item for item in dependent_on_list + if item not in removal_patterns[up_to]] + else: + raise ValueError(( + "up_to must be one of None, " + "'track', 'recog', 'prof'.")) if cmd == "rm" or cmd == "cp": cmd = cmd + " -r" diff --git a/AFQ/utils/stats.py b/AFQ/utils/stats.py index 66e5e653b..a146f5bc1 100644 --- a/AFQ/utils/stats.py +++ b/AFQ/utils/stats.py @@ -1,3 +1,9 @@ +def chunk_indices(indices, num_batches): + batch_size = (len(indices) + num_batches - 1) // num_batches + for i in range(0, len(indices), batch_size): + yield indices[i:i + batch_size] + + def contrast_index(x1, x2, double=True): """ Calculate the contrast index between two arrays. diff --git a/AFQ/utils/volume.py b/AFQ/utils/volume.py index fd399b64d..5cc57782e 100644 --- a/AFQ/utils/volume.py +++ b/AFQ/utils/volume.py @@ -2,7 +2,7 @@ import numpy as np import scipy.ndimage as ndim -from skimage.morphology import binary_dilation, convex_hull_image +from skimage.morphology import binary_dilation from scipy.spatial.distance import dice import nibabel as nib diff --git a/AFQ/viz/fury_backend.py b/AFQ/viz/fury_backend.py index 42d9d7e64..8f2ad3984 100644 --- a/AFQ/viz/fury_backend.py +++ b/AFQ/viz/fury_backend.py @@ -36,7 +36,9 @@ def _inline_interact(scene, inline, interact): return scene -def visualize_bundles(seg_sft, n_points=None, +def visualize_bundles(seg_sft, + affine=None, + n_points=None, bundle=None, colors=None, color_by_direction=False, opacity=1.0, @@ -54,6 +56,10 @@ def visualize_bundles(seg_sft, n_points=None, A SegmentedSFT containing streamline information or a path to a segmented trk file. + affine : ndarray (4, 4), optional + Affine of the image to register streamlines to. + Default: None + n_points : int or None n_points to resample streamlines to before plotting. If None, no resampling is done. 
@@ -102,7 +108,7 @@ def visualize_bundles(seg_sft, n_points=None, figure.SetBackground(background[0], background[1], background[2]) for (sls, color, name, dimensions) in vut.tract_generator( - seg_sft, bundle, colors, n_points): + seg_sft, bundle, colors, n_points, affine): sls = list(sls) if name == "all_bundles": color = line_colors(sls) @@ -495,6 +501,7 @@ def _draw_core(sls, n_points, figure, bundle_name, indiv_profile, def single_bundle_viz(indiv_profile, seg_sft, bundle, scalar_name, + affine=None, flip_axes=[False, False, False], labelled_nodes=[0, -1], figure=None, @@ -518,6 +525,10 @@ def single_bundle_viz(indiv_profile, seg_sft, scalar_name : str The name of the scalar being used. + affine : ndarray (4, 4), optional + Affine of the image to register streamlines to. + Default: None + flip_axes : ndarray Which axes to flip, to orient the image as RAS, which is how we visualize. @@ -545,7 +556,7 @@ def single_bundle_viz(indiv_profile, seg_sft, n_points = len(indiv_profile) sls, _, bundle_name, dimensions = next(vut.tract_generator( - seg_sft, bundle, None, n_points)) + seg_sft, bundle, None, n_points, affine)) _draw_core( sls, n_points, figure, bundle_name, indiv_profile, diff --git a/AFQ/viz/plotly_backend.py b/AFQ/viz/plotly_backend.py index 493daeb8e..ff81450fe 100644 --- a/AFQ/viz/plotly_backend.py +++ b/AFQ/viz/plotly_backend.py @@ -302,7 +302,7 @@ def _plot_profiles(profiles, bundle_name, color, fig, scalar): tickfont=dict(color='white')))) -def visualize_bundles(seg_sft, n_points=None, +def visualize_bundles(seg_sft, affine=None, n_points=None, bundle=None, colors=None, shade_by_volume=None, color_by_streamline=None, n_sls_viz=3600, sbv_lims=[None, None], include_profiles=(None, None), @@ -318,6 +318,10 @@ def visualize_bundles(seg_sft, n_points=None, A SegmentedSFT containing streamline information or a path to a segmented trk file. + affine : ndarray (4, 4), optional + Affine of the image to register streamlines to. 
+ Default: None + n_points : int or None n_points to resample streamlines to before plotting. If None, no resampling is done. @@ -413,7 +417,7 @@ def visualize_bundles(seg_sft, n_points=None, set_layout(figure) for (sls, color, name, dimensions) in vut.tract_generator( - seg_sft, bundle, colors, n_points, + seg_sft, bundle, colors, n_points, affine, n_sls_viz=n_sls_viz): if isinstance(color_by_streamline, dict): if name in color_by_streamline: @@ -598,8 +602,8 @@ def _draw_slice(figure, axis, volume, opacity=0.3, pos=0.5, colorscale="greys", invert_colorscale=False): height = int(volume.shape[axis] * pos) - v_min = np.percentile(volume, 10) - sf = np.percentile(volume, 90) - v_min + v_min = np.percentile(volume, 20) + sf = np.percentile(volume, 80) - v_min if axis == Axes.X: X, Y, Z = np.mgrid[height:height + 1, @@ -808,6 +812,7 @@ def _draw_core(sls, n_points, figure, bundle_name, indiv_profile, def single_bundle_viz(indiv_profile, seg_sft, bundle, scalar_name, + affine=None, flip_axes=[False, False, False], labelled_nodes=[0, -1], figure=None, @@ -831,6 +836,10 @@ def single_bundle_viz(indiv_profile, seg_sft, scalar_name : str The name of the scalar being used. + affine : ndarray (4, 4), optional + Affine of the image to register streamlines to. + Default: None + flip_axes : ndarray Which axes to flip, to orient the image as RAS, which is how we visualize. 
@@ -866,7 +875,7 @@ def single_bundle_viz(indiv_profile, seg_sft, n_points = len(indiv_profile) sls, _, bundle_name, dimensions = next(vut.tract_generator( - seg_sft, bundle, None, n_points)) + seg_sft, bundle, None, n_points, affine)) line_color = _draw_core( sls, n_points, figure, bundle_name, indiv_profile, diff --git a/AFQ/viz/utils.py b/AFQ/viz/utils.py index ad1d379ae..a9e9552f4 100644 --- a/AFQ/viz/utils.py +++ b/AFQ/viz/utils.py @@ -8,10 +8,11 @@ import matplotlib.pyplot as plt import matplotlib.patches as mpatches -import matplotlib.transforms as mtransforms import nibabel as nib import dipy.tracking.streamlinespeed as dps +from dipy.tracking.streamline import transform_streamlines +from dipy.io.stateful_tractogram import StatefulTractogram, Space from dipy.align import resample import AFQ.utils.volume as auv @@ -381,8 +382,17 @@ def viz_import_msg_error(module): return msg +def _sls_to_t1(sls, ref_sft, t1_affine): + sft = StatefulTractogram.from_sft(sls, ref_sft) + sft.to_rasmm() + sls = transform_streamlines( + sft.streamlines, + np.linalg.inv(t1_affine)) + return sls + + def tract_generator(trk_file, bundle, colors, n_points, - n_sls_viz=65536, n_sls_min=256): + t1_affine, n_sls_viz=65536, n_sls_min=256): """ Generates bundles of streamlines from the tractogram. Only generates from relevant bundle if bundle is set. @@ -407,11 +417,15 @@ def tract_generator(trk_file, bundle, colors, n_points, n_points to resample streamlines to before plotting. If None, no resampling is done. + t1_affine : ndarray (4, 4) + Affine of the T1-weighted image to register streamlines to. + n_sls_viz : int Number of streamlines to randomly select if plotting all bundles. Selections will be proportional to the original number of streamlines per bundle. Default: 3600 + n_sls_min : int Minimun number of streamlines to display per bundle. 
Default: 75 @@ -429,7 +443,10 @@ def tract_generator(trk_file, bundle, colors, n_points, if colors is None: colors = gen_color_dict(seg_sft.bundle_names) - seg_sft.sft.to_vox() + if t1_affine is None: + t1_affine = np.eye(4) + + seg_sft.sft.to_rasmm() streamlines = seg_sft.sft.streamlines viz_logger.info("Generating colorful lines from tractography...") @@ -445,7 +462,8 @@ def tract_generator(trk_file, bundle, colors, n_points, streamlines = streamlines[idx] if n_points is not None: streamlines = dps.set_number_of_points(streamlines, n_points) - yield streamlines, colors[0], "all_bundles", seg_sft.sft.dimensions + yield _sls_to_t1(streamlines, seg_sft.sft, t1_affine), \ + colors[0], "all_bundles", seg_sft.sft.dimensions else: if bundle is None: # No selection: visualize all of them: @@ -465,7 +483,8 @@ def tract_generator(trk_file, bundle, colors, n_points, color = colors[bundle_name] else: color = colors[0] - yield these_sls, color, bundle_name, seg_sft.sft.dimensions + yield _sls_to_t1(these_sls, seg_sft.sft, t1_affine), \ + color, bundle_name, seg_sft.sft.dimensions else: these_sls = seg_sft.get_bundle(bundle).streamlines if len(these_sls) > n_sls_viz: @@ -478,7 +497,8 @@ def tract_generator(trk_file, bundle, colors, n_points, color = colors[bundle] else: color = colors[0] - yield these_sls, color, bundle, seg_sft.sft.dimensions + yield _sls_to_t1(these_sls, seg_sft.sft, t1_affine), \ + color, bundle, seg_sft.sft.dimensions def bbox(img): diff --git a/NOTICE.md b/NOTICE.md new file mode 100644 index 000000000..16b046bbc --- /dev/null +++ b/NOTICE.md @@ -0,0 +1,4 @@ +## Third-Party Code Attribution +- `AFQ/models/asym_filtering.py` from Scilpy (github.com/scilus/scilpy) + Copyright (c) 2012-- Sherbrooke Connectivity Imaging Lab [SCIL], Université de Sherbrooke. + Used under the MIT License. 
diff --git a/docs/Makefile b/docs/Makefile index 1ed10da5f..f12775089 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -42,7 +42,9 @@ distclean: clean # delete more files than distclean. this would also remove the data files. realclean: distclean - @echo Removing data files from $(HOME) + @echo "WARNING: This will delete all data in your AFQ_data folder and .dipy directory." + @echo -n "Proceed? [y/N] " && read ans && [ $${ans:-N} = y ] + rm -rf $(HOME)/.dipy/ rm -rf $(HOME)/AFQ_data/ diff --git a/docs/source/_static/endpoint_maps_threshold_2.png b/docs/source/_static/endpoint_maps_threshold_2.png new file mode 100644 index 000000000..54188070a Binary files /dev/null and b/docs/source/_static/endpoint_maps_threshold_2.png differ diff --git a/docs/source/_static/endpoint_maps_threshold_3.png b/docs/source/_static/endpoint_maps_threshold_3.png new file mode 100644 index 000000000..f76fb51c6 Binary files /dev/null and b/docs/source/_static/endpoint_maps_threshold_3.png differ diff --git a/docs/source/_static/endpoint_maps_threshold_4.png b/docs/source/_static/endpoint_maps_threshold_4.png new file mode 100644 index 000000000..7863cc7dc Binary files /dev/null and b/docs/source/_static/endpoint_maps_threshold_4.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py index 3c1f9a0a6..064ed2383 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -239,6 +239,7 @@ 'gallery_dirs': ['howto/howto_examples', 'tutorials/tutorial_examples'], 'image_scrapers': image_scrapers, 'reset_modules': (reset_progressbars), + 'filename_pattern': r'/plot_(?!.*(003_rerun|006_bids_layout)).*\.py$', 'show_memory': True, 'abort_on_example_error': True, 'within_subsection_order': FileNameSortKey, diff --git a/docs/source/reference/bundledict.rst b/docs/source/reference/bundledict.rst index cde8dad90..8d5003216 100644 --- a/docs/source/reference/bundledict.rst +++ b/docs/source/reference/bundledict.rst @@ -6,41 +6,88 @@ pyAFQ has a system for defining custom bundles. 
Custom bundles are defined by passing a custom `bundle_info` dictionary to :class:`AFQ.api.bundle_dict.BundleDict`: The keys of `bundle_info` are bundle names; the values are another dictionary describing the bundle, with these -key-value pairs:: - - - 'include' : a list of paths to Nifti files containing inclusion ROI(s). - One must either have at least 1 include ROI, or 'start' or 'end' ROIs. - - 'exclude' : a list of paths to Nifti files containing exclusion ROI(s), - optional. - - 'start' : path to a Nifti file containing the start ROI, optional - - 'end' : path to a Nifti file containing the end ROI, optional - - 'cross_midline' : boolean describing whether the bundle is required to - cross the midline (True) or prohibited from crossing (False), optional. - If None, the bundle may or may not cross the midline. - - 'space' : a string which is either 'template' or 'subject', optional +key-value pairs: + +- 'include' : a list of paths to Nifti files containing inclusion ROI(s). + One must either have at least 1 include ROI, or 'start' or 'end' ROIs. +- 'exclude' : a list of paths to Nifti files containing exclusion ROI(s), + optional. +- 'start' : path to a Nifti file containing the start ROI, optional +- 'end' : path to a Nifti file containing the end ROI, optional +- 'cross_midline' : boolean describing whether the bundle is required to + cross the midline (True) or prohibited from crossing (False), optional. + If None, the bundle may or may not cross the midline. +- 'space' : a string which is either 'template' or 'subject', optional If this field is not given or 'template' is given, the ROI will be transformed from template to subject space before being used. - - 'prob_map' : path to a Nifti file which is the probability map, - optional. - - 'inc_addtol' : List of floats describing how much tolerance to add or - subtract in mm from each of the inclusion ROIs. The list must be the - same length as 'include'. optional. 
- - 'exc_addtol' : List of floats describing how much tolerance to add or - subtract in mm from each of the exclusion ROIs. The list must be the - same length as 'exclude'. optional. - - 'mahal': Dict describing the parameters for cleaning. By default, we - use the default behavior of the seg.clean_bundle function. - - 'recobundles': Dict which should contain an 'sl' key and 'centroid' - key. The 'sl' key should be the reference streamline and the 'centroid' - key should be the centroid threshold for Recobundles. - - 'qb_thresh': Float which is the threshold for Quickbundles cleaning. - - 'primary_axis': string which is the primary axis the - bundle should travel in. Can be one of: 'L/R', 'P/A', 'I/S'. - - 'primary_axis_percentage': Used with primary_axis, defines what fraction - of a streamlines movement should be in the primary axis. - - 'length': dicitonary containing 'min_len' and 'max_len' +- 'prob_map' : path to a Nifti file which is the probability map, + optional. +- 'inc_addtol' : List of floats describing how much tolerance to add or + subtract in mm from each of the inclusion ROIs. The list must be the + same length as 'include'. optional. +- 'exc_addtol' : List of floats describing how much tolerance to add or + subtract in mm from each of the exclusion ROIs. The list must be the + same length as 'exclude'. optional. +- 'mahal': Dict describing the parameters for cleaning. By default, we + use the default behavior of the seg.clean_bundle function. +- 'recobundles': Dict which should contain an 'sl' key and 'centroid' + key. The 'sl' key should be the reference streamline and the 'centroid' + key should be the centroid threshold for Recobundles. +- 'qb_thresh': Float which is the threshold for Quickbundles cleaning. +- 'primary_axis': string which is the primary axis the + bundle should travel in. Can be one of: 'L/R', 'P/A', 'I/S'. 
+- 'primary_axis_percentage': Used with primary_axis, defines what fraction + of a streamline's movement should be in the primary axis. +- 'length': dictionary containing 'min_len' and 'max_len' +- 'curvature': +- 'mahal': done by default unless orient_mahal or isolation_forest + are specified. Dictionary with optional keys 'n_points', 'core_only', + 'min_sl', 'distance_threshold', and 'clean_rounds'. These parameters + control the Mahalanobis distance cleaning of the bundle, further + information can be found at :func:`AFQ.recognition.cleaning.clean_bundle`. +- 'orient_mahal': cleans streamlines based on Mahalanobis distance of their + orientation to the mean orientation of the bundle. It should be a + dictionary which can be empty or contain n_points, core_only, min_sl, + distance_threshold, or clean_rounds as in 'mahal'. +- 'isolation_forest': dictionary with optional key 'percent_outlier_thresh', + which gives the percentage threshold for outliers (default 25). + + +Filtering by Other Bundles +========================== +Custom bundle definitions can also include keys that match the names of other +bundles in the same `BundleDict`. This allows you to filter streamlines in one +bundle based on their spatial relationship to another bundle. Note bundles are +segmented in the order they appear in their `BundleDict`, so later bundles cannot +be used to segment earlier bundles. The following options are supported: + +- **`overlap`** - Keeps streamlines that spatially overlap with another bundle + by at least the given node threshold. +- **`node_thresh`** - Remove streamlines that share at least the specified number + of nodes with another bundle. +- **`core`** - Removes streamlines based on whether their closest point lies + on the specified side of the *core* of another bundle. The value should be + one of `'Left'`, `'Right'`, `'Anterior'`, `'Posterior'`, `'Superior'`, + or `'Inferior'`. 
+- **`entire_core`** - Similar to `core`, but the entire streamline must lie on + the correct side of the core to be retained, not just the closest point. + +These references allow defining tracts relative to previously recognized +bundles. For example, the Vertical Occipital Fasciculus (VOF) can be defined in +relation to the Left Arcuate and Inferior Longitudinal fasciculi: +.. code-block:: python + 'Left Vertical Occipital': { + 'cross_midline': False, + 'start': templates['VOF_L_start'], + 'end': templates['VOF_L_end'], + 'Left Arcuate': {'node_thresh': 20}, + 'Left Inferior Longitudinal': {'core': 'Left'}, + } + +Filtering Order +=============== When doing bundle recognition, streamlines are filtered out from the whole tractography according to the series of steps defined in the bundle dictionaries. Of course, no one bundle uses every step, but here is the order @@ -52,49 +99,59 @@ of the steps: 5. Min and Max length 6. Primary axis 7. Include - 8. Curvature - 9. Exclude + 8. Exclude + 9. Curvature 10. Recobundles - 11. Quickbundles Cleaning - 12. Mahalanobis Cleaning + 11. Cleaning by other bundles + 12. Mahalanobis Orientation Cleaning + 13. Isolation Forest Cleaning + 14. Quickbundles Cleaning + 15. Mahalanobis Cleaning If a streamline passes all steps for a bundle, it is included in that bundle. If a streamline passess all steps for multiple bundles, then a warning is thrown and the tie goes to whichever bundle is first in the bundle dictionary. -If, for debugging purposes, you want to save out the streamlines -remaining after each step, set `save_intermediates` to a path in -`segmentation_params`. Then the streamlines will be saved out after each step -to that path. Only do this for one subject at a time. +.. note:: + If, for debugging purposes, you want to save out the streamlines + remaining after each step, set `save_intermediates` to a path in + `segmentation_params`. Then the streamlines will be saved out after each step + to that path. 
Only do this for one subject at a time. +Examples +======== Custom bundle definitions such as the OR, and the standard BundleDict can be combined through addition. For an example, see `Plotting the Optic Radiations `_. Some tracts, such as the Vertical Occipital Fasciculus, may be defined relative to other tracts. In those cases, the custom tract definitions should appear in the BundleDict object after the reference tracts have been defined. These reference tracts can -be included as keys in the same dictionary for that tract. For example:: - - newVOF = abd.BundleDict({ - 'Left Vertical Occipital': {'cross_midline': False, - 'space': 'template', - 'start': templates['VOF_L_start'], - 'end': templates['VOF_L_end'], - 'inc_addtol': [4, 0], - 'Left Arcuate': { - 'node_thresh': 20}, - 'Left Posterior Arcuate': { - 'node_thresh': 1, - 'core': 'Posterior'}, - 'Left Inferior Longitudinal': { - 'core': 'Left'}, - 'primary_axis': 'I/S', - 'primary_axis_percentage': 40} - }) +be included as keys in the same dictionary for that tract. For example: + +.. code-block:: python + + newVOF = abd.BundleDict({ + 'Left Vertical Occipital': { + 'cross_midline': False, + 'space': 'template', + 'start': templates['VOF_L_start'], + 'end': templates['VOF_L_end'], + 'inc_addtol': [4, 0], + 'Left Arcuate': { + 'node_thresh': 20}, + 'Left Posterior Arcuate': { + 'node_thresh': 1, + 'core': 'Posterior'}, + 'Left Inferior Longitudinal': { + 'core': 'Left'}, + 'primary_axis': 'I/S', + 'primary_axis_percentage': 40} + }) This definition of the VOF in the custom BundleDict would first require left ARC, left pARC, and left ILF to be defined, in the same way the tiebreaker above works. You would then construct your custom -BundleDict like this. The order of addition matters here:: +BundleDict like this. The order of addition matters here: - BundleDictCustomVOF = abd.default18_bd() + newVOF +.. 
code-block:: python + + BundleDictCustomVOF = abd.default18_bd() + newVOF diff --git a/docs/source/reference/kwargs.rst b/docs/source/reference/kwargs.rst index 978658348..b32df3bd6 100644 --- a/docs/source/reference/kwargs.rst +++ b/docs/source/reference/kwargs.rst @@ -18,25 +18,34 @@ Here are the arguments you can pass to kwargs, to customize the tractometry pipe DATA ========================================================== min_bval: float - Minimum b value you want to use from the dataset (other than b0), inclusive. If None, there is no minimum limit. Default: None + Minimum b value you want to use from the dataset (other than b0), inclusive. If None, there is no minimum limit. Default: -np.inf max_bval: float - Maximum b value you want to use from the dataset (other than b0), inclusive. If None, there is no maximum limit. Default: None - -filter_b: bool - Whether to filter the DWI data based on min or max bvals. Default: True + Maximum b value you want to use from the dataset (other than b0), inclusive. If None, there is no maximum limit. Default: np.inf b0_threshold: int The value of b under which it is considered to be b0. Default: 50. +ray_n_cpus: int + The number of CPUs to use for parallel processing with Ray. If None, uses the number of available CPUs minus one. Tractography and Recognition use Ray. Default: None + +numba_n_threads: int + The number of threads to use for Numba. If None, uses the number of available CPUs minus one. MSMT and ASYM fits use Numba. Default: None + robust_tensor_fitting: bool Whether to use robust_tensor_fitting when doing dti. Only applies to dti. Default: False +msmt_sh_order: int + Spherical harmonic order to use for the MSMT CSD fit. Default: 8 + +msmt_fa_thr: float + The threshold on the FA used to calculate the multi shell auto response. Can be useful to reduce for baby subjects. Default: 0.7 + csd_response: tuple or None The response function to be used by CSD, as a tuple with two elements. 
The first is the eigen-values as an (3,) ndarray and the second is the signal value for the response function without diffusion-weighting (i.e. S0). If not provided, auto_response will be used to calculate these values. Default: None csd_sh_order_max: int or None - default: infer the number of parameters from the number of data volumes, but no larger than 8. Default: None + If None, infer the number of parameters from the number of data volumes, but no larger than 8. Default: None csd_lambda_: float weight given to the constrained-positivity regularization part of the deconvolution equation. Default: 1 @@ -62,10 +71,10 @@ rumba_csf_response: float rumba_n_iter: int Number of iterations for fODF estimation. Must be a positive int. Default: 600 -opdt_sh_order: int +opdt_sh_order_max: int Spherical harmonics order for OPDT model. Must be even. Default: 8 -csa_sh_order: int +csa_sh_order_max: int Spherical harmonics order for CSA model. Must be even. Default: 8 sphere: Sphere class instance @@ -74,9 +83,6 @@ sphere: Sphere class instance gtol: float This input is to refine kurtosis maxima under the precision of the directions sampled on the sphere class instance. The gradient of the convergence procedure must be less than gtol before successful termination. If gtol is None, fiber direction is directly taken from the initial sampled directions of the given sphere object. Default: 1e-2 -brain_mask_definition: instance from `AFQ.definitions.image` - This will be used to create the brain mask, which gets applied before registration to a template. If you want no brain mask to be applied, use FullImage. If None, use B0Image() Default: None - bundle_info: dict or BundleDict A dictionary or BundleDict for use in segmentation. See `Defining Custom Bundle Dictionaries` in the `usage` section of pyAFQ's documentation for details. If None, will get all appropriate bundles for the chosen segmentation algorithm. 
Default: None @@ -101,7 +107,10 @@ reg_subject_spec: str SEGMENTATION ========================================================== segmentation_params: dict - The parameters for segmentation. Default: use the default behavior of the seg.Segmentation object. + The parameters for segmentation. Defaults to using the default behavior of the seg.Segmentation object. + +endpoint_threshold: float + The threshold for the endpoint maps. If None, no endpoint maps are exported as distance to endpoints maps, which the user can then threshold as needed. Default: 3 profile_weights: str How to weight each streamline (1D) or each node (2D) when calculating the tract-profiles. If callable, this is a function that calculates weights. If None, no weighting will be applied. If "gauss", gaussian weights will be used. If "median", the median of values at each node will be used instead of a mean or weighted mean. Default: "gauss" @@ -117,7 +126,7 @@ scalars: list of strings and/or scalar definitions TRACTOGRAPHY ========================================================== tracking_params: dict - The parameters for tracking. Default: use the default behavior of the aft.track function. Seed mask and seed threshold, if not specified, are replaced with scalar masks from scalar[0] thresholded to 0.2. The ``seed_mask`` and ``stop_mask`` items of this dict may be ``AFQ.definitions.image.ImageFile`` instances. If ``tracker`` is set to "pft" then ``stop_mask`` should be an instance of ``AFQ.definitions.image.PFTImage``. + The parameters for tracking. Defaults to using the default behavior of the aft.track function. Seed mask and seed threshold, if not specified, are replaced with scalar masks from scalar[0] thresholded to 0.2. The ``seed_mask`` and ``stop_mask`` items of this dict may be ``AFQ.definitions.image.ImageFile`` instances. If ``tracker`` is set to "pft" then ``stop_mask`` should be an instance of ``AFQ.definitions.image.PFTImage``. 
import_tract: dict or str or None BIDS filters for inputing a user made tractography file, or a path to the tractography file. If None, DIPY is used to generate the tractography. Default: None @@ -151,7 +160,7 @@ n_points_indiv: int or None n_points to resample streamlines to before plotting. If None, no resampling is done. Default: 40 virtual_frame_buffer: bool - Whether to use a virtual fram buffer. This is neccessary if generating GIFs in a headless environment. Default: False + Whether to use a virtual frame buffer. This is necessary if generating GIFs in a headless environment. Default: False viz_backend_spec: str Which visualization backend to use. See Visualization Backends page in documentation for details https://tractometry.org/pyAFQ/reference/viz_backend.html One of {"fury", "plotly", "plotly_no_gif"}. Default: "plotly_no_gif" diff --git a/docs/source/reference/methods.rst b/docs/source/reference/methods.rst index f0775a0df..441cf93ca 100644 --- a/docs/source/reference/methods.rst +++ b/docs/source/reference/methods.rst @@ -35,6 +35,18 @@ base_fname: Base file name for outputs +pve_wm: + White matter partial volume estimate map + + +pve_gm: + Gray matter partial volume estimate map + + +pve_csf: + Cerebrospinal fluid partial volume estimate map + + data: DWI data as an ndarray for selected b values @@ -51,6 +63,14 @@ dwi_affine: the affine transformation of the DWI data +n_cpus: + Configure the number of CPUs to use for parallel processing with Ray + + +n_threads: + the number of threads to use for Numba + + b0: full path to a nifti file containing the mean b0 @@ -59,6 +79,14 @@ masked_b0: full path to a nifti file containing the mean b0 after applying the brain mask +t1w_pve: + WM, GM, CSF segmentations from subcortex segmentation from brainchop on T1w image + + +wm_gm_interface: + + + dti_tf: DTI TensorFit object @@ -99,10 +127,38 @@ msdki_msk: full path to a nifti file containing the MSDKI mean signal kurtosis +msmtcsd_params: + full path to a nifti 
file containing parameters for the MSMT CSD fit + + +msmt_apm: + full path to a nifti file containing the anisotropic power map + + +msmt_aodf_params: + full path to a nifti file containing MSMT CSD ODFs filtered by unified filtering [1] + + +msmt_aodf_asi: + full path to a nifti file containing the MSMT CSD Asymmetric Index (ASI) [1] + + +msmt_aodf_opm: + full path to a nifti file containing the MSMT CSD odd-power map [1] + + +msmt_aodf_nufid: + full path to a nifti file containing the MSMT CSD Number of fiber directions (nufid) map [1] + + csd_params: full path to a nifti file containing parameters for the CSD fit +csd_aodf_params: + full path to a nifti file containing SSST CSD ODFs filtered by unified filtering [1] + + csd_pmap: full path to a nifti file containing the anisotropic power map @@ -355,6 +411,18 @@ dki_kfa: full path to a nifti file containing the DKI kurtosis FA file +dki_cl: + full path to a nifti file containing the DKI linearity file + + +dki_cp: + full path to a nifti file containing the DKI planarity file + + +dki_cs: + full path to a nifti file containing the DKI sphericity file + + dki_ga: full path to a nifti file containing the DKI geodesic anisotropy @@ -375,6 +443,18 @@ dki_ak: full path to a nifti file containing the DKI axial kurtosis file +t1_brain_mask: + full path to a nifti file containing brain mask from T1w image, + + +t1_masked: + full path to a nifti file containing the T1w masked + + +t1_subcortex: + full path to a nifti file containing segmentation of subcortical structures from T1w image using Brainchop + + brain_mask: full path to a nifti file containing the brain mask @@ -431,6 +511,10 @@ density_maps: full path to 4d nifti file containing streamline counts per voxel per bundle, where the 4th dimension encodes the bundle +endpoint_maps: + full path to a NIfTI file containing endpoint maps for each bundle + + profiles: full path to a CSV file containing tract profiles diff --git 
a/examples/howto_examples/acoustic_radiations.py b/examples/howto_examples/acoustic_radiations.py index 3216f797a..31a05bbef 100644 --- a/examples/howto_examples/acoustic_radiations.py +++ b/examples/howto_examples/acoustic_radiations.py @@ -79,18 +79,11 @@ # are passed as `bundle_info=bundles`. The call to `my_afq.export_all()` # initiates the pipeline. -brain_mask_definition = ImageFile( - suffix="mask", - filters={'desc': 'brain', - 'space': 'T1w', - 'scope': 'qsiprep'}) - my_afq = GroupAFQ( bids_path=study_dir, preproc_pipeline="qsiprep", participant_labels=["NDARAA948VFH"], output_dir=op.join(study_dir, "derivatives", "afq_ar"), - brain_mask_definition=brain_mask_definition, tracking_params={"n_seeds": 4, "directions": "prob", "odf_model": "CSD", diff --git a/examples/howto_examples/add_custom_bundle.py b/examples/howto_examples/add_custom_bundle.py index 9462ed6cc..db34d06dd 100644 --- a/examples/howto_examples/add_custom_bundle.py +++ b/examples/howto_examples/add_custom_bundle.py @@ -169,22 +169,14 @@ # are passed as `bundle_info=bundles`. The call to `my_afq.export_all()` # initiates the pipeline. -brain_mask_definition = ImageFile( - suffix="mask", - filters={'desc': 'brain', - 'space': 'T1w', - 'scope': 'qsiprep'}) - my_afq = GroupAFQ( bids_path=study_dir, preproc_pipeline="qsiprep", output_dir=op.join(study_dir, "derivatives", "afq_slf"), - brain_mask_definition=brain_mask_definition, tracking_params={"n_seeds": 4, "directions": "prob", "odf_model": "CSD", "seed_mask": RoiImage()}, - segmentation_params={"parallel_segmentation": {"engine": "serial"}}, bundle_info=bundles) # If you want to redo different stages you can use the `clobber` method. 
diff --git a/examples/howto_examples/cerebellar_peduncles.py b/examples/howto_examples/cerebellar_peduncles.py index 3659d54ca..5e6e8bed2 100644 --- a/examples/howto_examples/cerebellar_peduncles.py +++ b/examples/howto_examples/cerebellar_peduncles.py @@ -49,18 +49,6 @@ """ The bundle dict has been defined, and now we are ready to run the AFQ pipeline. -In this case, we are using data that has been preprocessed with QSIprep, so -we have a brain mask that was generated from the T1w data of this subject. -""" - -brain_mask_definition = ImageFile( - suffix="mask", - filters={'desc': 'brain', - 'space': 'T1w', - 'scope': 'qsiprep'}) - - -""" Next, we define a GroupAFQ object. In this case, the tracking parameters focus specifically on the CP, by using the ``RoiImage`` class to define the seed region. We seed extensively in the ROIs that define the CPs. @@ -70,7 +58,6 @@ name="cp_afq", bids_path=bids_path, preproc_pipeline="qsiprep", - brain_mask_definition=brain_mask_definition, tracking_params={ "n_seeds": 4, "directions": "prob", diff --git a/examples/howto_examples/cloudknot_example.py b/examples/howto_examples/cloudknot_example.py index 1f82a3f02..dd5391d6b 100644 --- a/examples/howto_examples/cloudknot_example.py +++ b/examples/howto_examples/cloudknot_example.py @@ -62,19 +62,10 @@ def afq_process_subject(subject): "local_bids_dir", include_derivs=["pipeline_name"]) - # you can optionally provide your own segmentation file - # in this case, we look for a file with suffix 'seg' - # in the 'pipeline_name' pipeline, - # and we consider all non-zero labels to be a part of the brain - brain_mask_definition = afm.LabelledImageFile( - suffix='seg', filters={'scope': 'pipeline_name'}, - exclusive_labels=[0]) - # define the api AFQ object myafq = GroupAFQ( "local_bids_dir", preproc_pipeline="pipeline_name", - brain_mask_definition=brain_mask_definition, viz_backend_spec='plotly', # this will generate both interactive html and GIFs # noqa scalars=["dki_fa", "dki_md"]) 
diff --git a/examples/howto_examples/cloudknot_hcp_example.py b/examples/howto_examples/cloudknot_hcp_example.py index 9f25b88c0..b58da085d 100644 --- a/examples/howto_examples/cloudknot_hcp_example.py +++ b/examples/howto_examples/cloudknot_hcp_example.py @@ -76,16 +76,9 @@ def afq_process_subject(subject, seed_mask, n_seeds, "n_seeds": n_seeds, "random_seeds": random_seeds} - # use segmentation file from HCP to get a brain mask, - # where everything not labelled 0 is considered a part of the brain - brain_mask_definition = afm.LabelledImageFile( - suffix='seg', filters={'scope': 'dmriprep'}, - exclusive_labels=[0]) - # define the api GroupAFQ object myafq = GroupAFQ( hcp_bids, - brain_mask_definition=brain_mask_definition, tracking_params=tracking_params) # export_all runs the entire pipeline and creates many useful derivates diff --git a/examples/howto_examples/optic_radiations.py b/examples/howto_examples/optic_radiations.py index dd07d933d..9c5d05d67 100644 --- a/examples/howto_examples/optic_radiations.py +++ b/examples/howto_examples/optic_radiations.py @@ -86,18 +86,11 @@ # are passed as `bundle_info=bundles`. The call to `my_afq.export_all()` # initiates the pipeline. 
-brain_mask_definition = ImageFile( - suffix="mask", - filters={'desc': 'brain', - 'space': 'T1w', - 'scope': 'qsiprep'}) - my_afq = GroupAFQ( bids_path=study_dir, preproc_pipeline="qsiprep", participant_labels=["NDARAA948VFH"], output_dir=op.join(study_dir, "derivatives", "afq_or"), - brain_mask_definition=brain_mask_definition, tracking_params={"n_seeds": 4, "directions": "prob", "odf_model": "CSD", diff --git a/examples/howto_examples/plot_afq_callosal.py b/examples/howto_examples/run_afq_callosal.py similarity index 98% rename from examples/howto_examples/plot_afq_callosal.py rename to examples/howto_examples/run_afq_callosal.py index e38297ef7..0699d6f7a 100644 --- a/examples/howto_examples/plot_afq_callosal.py +++ b/examples/howto_examples/run_afq_callosal.py @@ -35,7 +35,7 @@ # We only do this to make this example faster and consume less space. tracking_params = dict(seed_mask=RoiImage(), - n_seeds=10000, + n_seeds=25000, random_seeds=True, rng_seed=42) @@ -67,6 +67,7 @@ myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', bundle_info=abd.callosal_bd(), tracking_params=tracking_params, segmentation_params=segmentation_params, diff --git a/examples/howto_examples/run_afq_fwdti.py b/examples/howto_examples/run_afq_fwdti.py index 0e21478a3..32e41e8c4 100644 --- a/examples/howto_examples/run_afq_fwdti.py +++ b/examples/howto_examples/run_afq_fwdti.py @@ -45,20 +45,14 @@ # -------------------- # In addition to preprocessd dMRI data, HBN-POD2 contains brain mask and mapping # information for each subject. We can use this information in our pipeline, by -# inserting this information as `mapping_definition` and `brain_mask_definition` -# inputs to the `GroupAFQ` class initializer. When initializing this object, we -# will also ask for the fwDTI scalars to be computed. For expedience, we will -# limit our investigation to the bilateral arcuate fasciculus and track only -# around that bundle. 
If you would like to do this for all bundles, you would -# remove the `bundle_dict` and `tracking_params` inputs to the initializer that +# inserting this information as `mapping_definition` inputs to the `GroupAFQ` class +# initializer. When initializing this object, we will also ask for the fwDTI scalars +# to be computed. For expedience, we will limit our investigation to the bilateral +# arcuate fasciculus and track only around that bundle. If you would like to do this +# for all bundles, you would remove the `bundle_dict` and `tracking_params` inputs +# to the initializer that # are provided below. -brain_mask_definition = ImageFile( - suffix="mask", - filters={'desc': 'brain', - 'space': 'T1w', - 'scope': 'qsiprep'}) - bundle_names = ["Left Arcuate", "Right Arcuate"] bundle_dict = abd.default18_bd()[bundle_names] @@ -72,7 +66,6 @@ "random_seeds": True, "seed_mask": RoiImage(use_waypoints=True, use_endpoints=True), }, - brain_mask_definition=brain_mask_definition, scalars=["fwdti_fa", "fwdti_md", "fwdti_fwf", "dti_fa", "dti_md"]) ############################################################################# diff --git a/examples/howto_examples/run_pyAFQ_with_GPU.py b/examples/howto_examples/run_pyAFQ_with_GPU.py index a6f61a493..447533111 100644 --- a/examples/howto_examples/run_pyAFQ_with_GPU.py +++ b/examples/howto_examples/run_pyAFQ_with_GPU.py @@ -40,6 +40,7 @@ myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', tracking_params=tracking_params, tractography_ngpus=1) diff --git a/examples/howto_examples/plot_recobundles.py b/examples/howto_examples/run_recobundles.py similarity index 93% rename from examples/howto_examples/plot_recobundles.py rename to examples/howto_examples/run_recobundles.py index bfd9c042f..91fa9ef23 100644 --- a/examples/howto_examples/plot_recobundles.py +++ b/examples/howto_examples/run_recobundles.py @@ -32,7 +32,8 @@ # Parameters of this process are set 
through a dictionary input to the # `segmentation_params` argument of the GroupAFQ object. In this case, we # use `abd.reco_bd(16)`, which tells pyAFQ to use the RecoBundles -# algorithm for bundle recognition. +# algorithm for bundle recognition. This uses 16 bundles, there is also +# an atlas `abd.reco_bd(80)` which uses 80 bundles. myafq = GroupAFQ( output_dir=op.join(afd.afq_home, 'stanford_hardi', 'derivatives', @@ -41,6 +42,7 @@ # Set the algorithm to use RecoBundles for bundle recognition: bundle_info=abd.reco_bd(16), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', tracking_params=tracking_params, viz_backend_spec='plotly_no_gif') diff --git a/examples/howto_examples/use_subject_space_rois_from_freesurfer.py b/examples/howto_examples/use_subject_space_rois_from_freesurfer.py index f322d35bc..0d1fbf3e4 100644 --- a/examples/howto_examples/use_subject_space_rois_from_freesurfer.py +++ b/examples/howto_examples/use_subject_space_rois_from_freesurfer.py @@ -139,6 +139,7 @@ myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', tracking_params=tracking_params, bundle_info=bundles) diff --git a/examples/howto_examples/vof_example.py b/examples/howto_examples/vof_example.py index e35c8d2b3..26e76d738 100644 --- a/examples/howto_examples/vof_example.py +++ b/examples/howto_examples/vof_example.py @@ -32,6 +32,7 @@ op.join(afd.afq_home, 'stanford_hardi'), bundle_info=bundle_dict, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', tracking_params={ "n_seeds": 50000, "random_seeds": True, diff --git a/examples/tutorial_examples/plot_001_group_afq_api.py b/examples/tutorial_examples/plot_001_group_afq_api.py index f726a9889..85850c214 100644 --- a/examples/tutorial_examples/plot_001_group_afq_api.py +++ b/examples/tutorial_examples/plot_001_group_afq_api.py @@ -56,17 +56,13 @@ # --------------------------------------- # We make create a `tracking_params` variable, which we will pass 
to the # GroupAFQ object which specifies that we want 25,000 seeds randomly -# distributed in the white matter. We only do this to make this example -# faster and consume less space. We also set ``num_chunks`` to `True`, -# which will use ray to parallelize the tracking across all cores. -# This can be removed to process in serial, or set to use a particular -# distribution of work by setting `n_chunks` to an integer number. +# distributed in the white matter. We only do this to make this example faster +# and consume less space; normally, we use more seeds tracking_params = dict(n_seeds=25000, random_seeds=True, - rng_seed=2022, - trx=True, - num_chunks=True) + rng_seed=2025, + trx=True) ########################################################################## # Initialize a GroupAFQ object: @@ -92,12 +88,15 @@ # We will also be using plotly to generate an interactive visualization. # The value `plotly_no_gif` indicates that interactive visualizations will be # generated as html web-pages that can be opened in a browser, but not as -# static gif files. +# static gif files. We set ray_n_cpus=1 to avoid memory issues running this +# example on servers. 
myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', tracking_params=tracking_params, + ray_n_cpus=1, viz_backend_spec='plotly_no_gif') ########################################################################## @@ -223,8 +222,8 @@ for ind in bundle_counts.index: if ind == "Total Recognized": threshold = 1000 - elif "Vertical Occipital" in ind: - threshold = 1 + elif "Callosum" in ind: + threshold = 0 else: threshold = 10 if bundle_counts["n_streamlines"][ind] < threshold: diff --git a/examples/tutorial_examples/plot_002_participant_afq_api.py b/examples/tutorial_examples/plot_002_participant_afq_api.py index 913a07521..f07a5f74c 100644 --- a/examples/tutorial_examples/plot_002_participant_afq_api.py +++ b/examples/tutorial_examples/plot_002_participant_afq_api.py @@ -10,11 +10,9 @@ import matplotlib.pyplot as plt import nibabel as nib import plotly -import pandas as pd from AFQ.api.participant import ParticipantAFQ import AFQ.data.fetch as afd -import AFQ.viz.altair as ava ########################################################################## # Example data @@ -39,7 +37,7 @@ # stored in the ``~/AFQ_data/stanford_hardi/`` BIDS directory. Set it to None if # you want to use the results of previous runs. -afd.organize_stanford_data(clear_previous_afq="track") +afd.organize_stanford_data() ########################################################################## # Defining data files @@ -63,31 +61,32 @@ dwi_data_file = op.join(data_dir, "sub-01_ses-01_dwi.nii.gz") bval_file = op.join(data_dir, "sub-01_ses-01_dwi.bval") bvec_file = op.join(data_dir, "sub-01_ses-01_dwi.bvec") +t1_file = op.join(afd.afq_home, "stanford_hardi", "derivatives", + "freesurfer", "sub-01", "ses-01", "anat", + "sub-01_ses-01_T1w.nii.gz") # You will also need to define the output directory where you want to store the # results. The output directory needs to exist before exporting ParticipantAFQ # results. 
output_dir = op.join(afd.afq_home, "stanford_hardi", - "derivatives", "afq", "sub-01") + "derivatives", "afq", "sub-01", + "ses-01", "dwi") os.makedirs(output_dir, exist_ok=True) ########################################################################## # Set tractography parameters (optional) # --------------------------------------- # We make create a `tracking_params` variable, which we will pass to the -# ParticipantAFQ object which specifies that we want 25,000 seeds randomly -# distributed in the white matter. We only do this to make this example -# faster and consume less space. We also set ``num_chunks`` to `True`, -# which will use ray to parallelize the tracking across all cores. -# This can be removed to process in serial, or set to use a particular -# distribution of work by setting `n_chunks` to an integer number. +# ParticipantAFQ object which specifies that we want 10,000 seeds randomly +# distributed in the white matter, propogated using DIPY's probabilistic +# algorithm. We only do this to make this example faster +# and consume less space; normally, we use more seeds tracking_params = dict(n_seeds=25000, random_seeds=True, - rng_seed=2022, - trx=True, - num_chunks=True) + rng_seed=2025, + trx=True) ########################################################################## # Initialize a ParticipantAFQ object: @@ -105,14 +104,17 @@ # # To initialize the object, we will pass in the diffusion data files and specify # the output directory where we want to store the results. We will also -# pass in the tracking parameters we defined above. +# pass in the tracking parameters we defined above. We set ray_n_cpus=1 +# to avoid memory issues running this example on servers. 
myafq = ParticipantAFQ( dwi_data_file=dwi_data_file, bval_file=bval_file, bvec_file=bvec_file, + t1_file=t1_file, output_dir=output_dir, tracking_params=tracking_params, + ray_n_cpus=1, ) ########################################################################## diff --git a/examples/tutorial_examples/plot_003_rerun.py b/examples/tutorial_examples/plot_003_rerun.py index e46698706..d9387340b 100644 --- a/examples/tutorial_examples/plot_003_rerun.py +++ b/examples/tutorial_examples/plot_003_rerun.py @@ -32,12 +32,12 @@ tracking_params = dict(n_seeds=100, random_seeds=True, rng_seed=2022, - trx=True, - num_chunks=True) + trx=True) myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', tracking_params=tracking_params) ################### @@ -63,6 +63,7 @@ myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', b0_threshold=100, tracking_params=tracking_params) @@ -99,12 +100,12 @@ random_seeds=True, max_angle=60, rng_seed=12, - trx=True, - num_chunks=True) + trx=True) myafq = GroupAFQ( bids_path=op.join(afd.afq_home, 'stanford_hardi'), preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', b0_threshold=100, tracking_params=tracking_params) diff --git a/examples/tutorial_examples/plot_004_export.py b/examples/tutorial_examples/plot_004_export.py index 56f5bc683..1577ceeb6 100644 --- a/examples/tutorial_examples/plot_004_export.py +++ b/examples/tutorial_examples/plot_004_export.py @@ -23,7 +23,7 @@ # :doc:`plot_002_participant_afq_api` example. Please refer to that # example for a detailed description of the parameters. 
-afd.organize_stanford_data(clear_previous_afq="track") +afd.organize_stanford_data() data_dir = op.join(afd.afq_home, "stanford_hardi", "derivatives", "vistasoft", "sub-01", "ses-01", "dwi") @@ -31,6 +31,9 @@ dwi_data_file = op.join(data_dir, "sub-01_ses-01_dwi.nii.gz") bval_file = op.join(data_dir, "sub-01_ses-01_dwi.bval") bvec_file = op.join(data_dir, "sub-01_ses-01_dwi.bvec") +t1_file = op.join(afd.afq_home, "stanford_hardi", "derivatives", + "freesurfer", "sub-01", "ses-01", "anat", + "sub-01_ses-01_T1w.nii.gz") output_dir = op.join(afd.afq_home, "stanford_hardi", "derivatives", "afq", "sub-01", "ses-01", "dwi") @@ -41,13 +44,14 @@ dwi_data_file=dwi_data_file, bval_file=bval_file, bvec_file=bvec_file, + t1_file=t1_file, output_dir=output_dir, + ray_n_cpus=1, tracking_params={ - "n_seeds": 25000, + "n_seeds": 10000, "random_seeds": True, "rng_seed": 2022, - "trx": True, - "num_chunks": True + "trx": True }, ) diff --git a/examples/tutorial_examples/plot_006_bids_layout.py b/examples/tutorial_examples/plot_006_bids_layout.py index dd77b49e7..846a6dacd 100644 --- a/examples/tutorial_examples/plot_006_bids_layout.py +++ b/examples/tutorial_examples/plot_006_bids_layout.py @@ -245,6 +245,7 @@ my_afq = GroupAFQ( bids_path, preproc_pipeline='vistasoft', + t1_pipeline='freesurfer', bundle_info=bundle_info, import_tract={ "suffix": "tractography", diff --git a/examples/tutorial_examples/viz_008_endpoints.py b/examples/tutorial_examples/viz_008_endpoints.py new file mode 100644 index 000000000..34d8c1dcd --- /dev/null +++ b/examples/tutorial_examples/viz_008_endpoints.py @@ -0,0 +1,184 @@ +""" +=================== +PyAFQ Endpoint Maps +=================== +Here we extract endpoint maps for pyAFQ run under different configurations +for an HBN subject. 
+""" + +#################################################### +# Import libraries, load the defautl tract templates +import matplotlib +matplotlib.use('Agg') # Use Agg backend for headless plotting +import matplotlib.pyplot as plt + +import nibabel as nib +import numpy as np + +from AFQ.api.group import GroupAFQ +import AFQ.definitions.image as afm +from dipy.data import get_sphere + +import os.path as op +import os +import AFQ.data.fetch as afd + +from AFQ.viz.utils import COLOR_DICT +from dipy.align import resample + +############################################################ +# Use an example subject from the Healthy Brain Network (HBN). + +subject_id = "NDARKP893TWU" # Example subject ID +ses_id = "HBNsiteRU" # Example session ID +_, study_dir = afd.fetch_hbn_preproc([subject_id]) + +endpoint_maps = { # Endpoint maps by threshold in mm + "2": {}, + "3": {}, + "4": {} +} + +bundle_names = [ # pyAFQ defaults + "Left Anterior Thalamic", "Right Anterior Thalamic", + "Left Cingulum Cingulate", "Right Cingulum Cingulate", + "Left Corticospinal", "Right Corticospinal", + "Left Inferior Fronto-occipital", "Right Inferior Fronto-occipital", + "Left Inferior Longitudinal", "Right Inferior Longitudinal", + "Left Superior Longitudinal", "Right Superior Longitudinal", + "Left Arcuate", "Right Arcuate", + "Left Uncinate", "Right Uncinate", + "Left Posterior Arcuate", "Right Posterior Arcuate", + "Left Vertical Occipital", "Right Vertical Occipital", + "Callosum Anterior Frontal", "Callosum Motor", + "Callosum Occipital", "Callosum Orbital", + "Callosum Posterior Parietal", "Callosum Superior Frontal", + "Callosum Superior Parietal", "Callosum Temporal" +] + +########################################################################### +# Compare endpoint maps for single-shell and multi-shell CSD +# For both local (probabilistic) and PFT (particle filtering) tractography. 
+ +for odf_model in ["csd", "msmtcsd"]: + for tracker in ["local", "pft"]: + output_dir = op.join( + study_dir, "derivatives", + f"afq_{odf_model}_{tracker}") + + myafq = GroupAFQ( + op.join(afd.afq_home, "HBN"), + participant_labels=[subject_id], + preproc_pipeline="qsiprep", + tracking_params={ + "tracker": tracker, + "odf_model": odf_model, + "sphere": get_sphere(name="repulsion724"), + "seed_mask": afm.ScalarImage("wm_gm_interface"), + "seed_threshold": 0.5, + "stop_mask": afm.ThreeTissueImage(), + "stop_threshold": "ACT", + "n_seeds": 2000000, + "random_seeds": True}, + output_dir=output_dir, + endpoint_threshold=None) + + endpoints_maps = myafq.export("endpoint_maps") + + # Copy outputs of first runs for later use + # Up to but not including tractography + if odf_model == "csd" and tracker == "local": + other_output_paths = [ + op.join(study_dir, ( + "derivatives/afq_csd_pft/" + f"sub-{subject_id}/ses-{ses_id}/dwi")), + op.join(study_dir, ( + "derivatives/afq_msmtcsd_local/" + f"sub-{subject_id}/ses-{ses_id}/dwi")), + op.join(study_dir, ( + "derivatives/afq_msmtcsd_pft/" + f"sub-{subject_id}/ses-{ses_id}/dwi")), + ] + for other_output_path in other_output_paths: + os.makedirs(other_output_path, exist_ok=True) + myafq.cmd_outputs( + "cp", suffix=other_output_path, up_to="track") + + endpoint_data = nib.load( + endpoints_maps[subject_id]).get_fdata() + endpoint_maps["2"][f"{odf_model}_{tracker}"] = \ + np.logical_and(endpoint_data < 2.0, endpoint_data != 0.0) + endpoint_maps["3"][f"{odf_model}_{tracker}"] = \ + np.logical_and(endpoint_data < 3.0, endpoint_data != 0.0) + endpoint_maps["4"][f"{odf_model}_{tracker}"] = \ + np.logical_and(endpoint_data < 4.0, endpoint_data != 0.0) + +t1 = nib.load(myafq.export("t1_file")[subject_id]) +b0 = nib.load(myafq.export("b0")[subject_id]) + +t1_dwi_space = resample( + t1.get_fdata(), + b0.get_fdata(), + moving_affine=t1.affine, + static_affine=b0.affine).get_fdata() + +# Find the best z-slice for visualization 
+sum_of_all_maps = np.zeros(endpoint_maps["2"]["csd_local"].shape[:3]) +for threshold, maps in endpoint_maps.items(): + for map_name, _map in maps.items(): + sum_of_all_maps += np.sum(_map, axis=-1) +best_z = np.argmax(np.sum(sum_of_all_maps, axis=(0, 1))) + +t1_slice = t1_dwi_space[..., best_z] +t1_slice = (t1_slice - t1_slice.min()) / (t1_slice.max() - t1_slice.min()) + +for threshold, maps in endpoint_maps.items(): + fig, axes = plt.subplots(2, 2, figsize=(10, 10)) + fig.suptitle(f"Endpoint Maps for Threshold < {threshold} mm") + for ii, (map_name, _map) in enumerate(maps.items()): + image = np.zeros(_map.shape[:2] + (3,)) + image_counts = np.zeros(_map.shape[:2]) + for z in range(_map.shape[3]): + # Assign each z-slice a unique color from our colormap + rgb = COLOR_DICT[bundle_names[z]] + image += _map[..., best_z, z][..., np.newaxis] * rgb + image_counts += _map[..., best_z, z] + + # Interesting Metrics + endpoint_voxel_count = np.sum(image_counts) + median_vc = np.median(np.sum(_map, axis=(0, 1, 2))) + + # Normalize the image by the number of bundles + image_counts[image_counts == 0] = 1 + image /= image_counts[..., np.newaxis] + + # Blend the T1 slice with the endpoint map + image = (np.stack([t1_slice] * 3, axis=-1) + image) / 2.0 + axes[ii // 2, ii % 2].imshow(np.rot90(image), interpolation='none') + axes[ii // 2, ii % 2].axis("off") + axes[ii // 2, ii % 2].set_title(map_name.replace("_", " ").upper()) + + axes[ii // 2, ii % 2].text( + 0.5, 0.1, + f'Total Endpoint Voxel Count {endpoint_voxel_count}', + color='white', + ha='center', va='center', + transform=axes[ii // 2, ii % 2].transAxes) + + axes[ii // 2, ii % 2].text( + 0.5, 0.05, + f'Median across bundles {median_vc}', + color='white', + ha='center', va='center', + transform=axes[ii // 2, ii % 2].transAxes) + + plt.tight_layout() + plt.savefig(f"endpoint_maps_threshold_{threshold}.png") + plt.close(fig) + +########################################################################### +# This Example would 
take too long to run in the documentation. +# So, we provide the results as static images. +# .. image:: ../../_static/endpoint_maps_threshold_2.png +# .. image:: ../../_static/endpoint_maps_threshold_3.png +# .. image:: ../../_static/endpoint_maps_threshold_4.png diff --git a/gpu_docker/Dockerfile b/gpu_docker/Dockerfile index 3ee362917..0cdb0461f 100644 --- a/gpu_docker/Dockerfile +++ b/gpu_docker/Dockerfile @@ -8,7 +8,7 @@ ENV DEBIAN_FRONTEND=noninteractive # upgrade RUN apt-get update && apt-get install --assume-yes apt-transport-https \ ca-certificates gnupg software-properties-common \ - gcc git wget curl numactl cmake + gcc git wget curl numactl cmake clang # Miniconda3 RUN curl -L "https://repo.anaconda.com/miniconda/Miniconda3-py312_25.3.1-1-Linux-x86_64.sh" \ @@ -18,6 +18,7 @@ RUN rm -rf /tmp/Miniconda3.sh RUN cd /opt && eval "$(/opt/anaconda/bin/conda shell.bash hook)" ENV PATH /opt/anaconda/bin:${PATH} ENV LD_LIBRARY_PATH /opt/anaconda/lib:${LD_LIBRARY_PATH} +ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH} # python prereqs RUN conda install -c conda-forge git diff --git a/gpu_docker/cuda_track_template.def b/gpu_docker/cuda_track_template.def index 3f304bd69..9d6956b8a 100644 --- a/gpu_docker/cuda_track_template.def +++ b/gpu_docker/cuda_track_template.def @@ -8,13 +8,14 @@ From: nvidia/cuda:12.0.1-devel-ubuntu20.04 export DEBIAN_FRONTEND=noninteractive export PATH=/opt/anaconda/bin:${PATH} export LD_LIBRARY_PATH=/opt/anaconda/lib:${LD_LIBRARY_PATH} + export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH export CXXFLAGS="-ftemplate-depth=2048" %post # System update and basic tools installation apt-get update && apt-get install --assume-yes apt-transport-https \ ca-certificates gnupg software-properties-common \ - gcc git wget curl numactl cmake + gcc git wget curl numactl cmake clang # Miniconda3 curl -L "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" \ diff --git a/pyafq_docker/Dockerfile 
b/pyafq_docker/Dockerfile index 47666b4c0..57d50ec29 100644 --- a/pyafq_docker/Dockerfile +++ b/pyafq_docker/Dockerfile @@ -7,6 +7,9 @@ FROM python:3.11 ARG COMMIT +RUN apt-get update && apt-get install --assume-yes cmake clang +ENV LD_LIBRARY_PATH /usr/lib/x86_64-linux-gnu:${LD_LIBRARY_PATH} + # Install pyAFQ RUN pip install --no-cache-dir git+https://github.com/tractometry/pyAFQ.git@${COMMIT} RUN pip install fslpy diff --git a/pyproject.toml b/pyproject.toml index e98405e73..c92e204a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=42", "wheel", "setuptools_scm[toml]>=3.4"] +requires = ["setuptools>=42", "wheel", "setuptools_scm>=3.4"] build-backend = "setuptools.build_meta" [tool.pytest.ini_options] @@ -9,7 +9,9 @@ markers = [ "nightly_anisotropic", "nightly_reco", "nightly_reco80", - "nightly_pft", "nightly_custom", "nightly", -] \ No newline at end of file +] + +[tool.setuptools_scm] +write_to = "AFQ/version.py" diff --git a/setup.cfg b/setup.cfg index b641c66fd..2d8896613 100644 --- a/setup.cfg +++ b/setup.cfg @@ -22,23 +22,28 @@ long_description_content_type = text/markdown platforms = OS Independent [options] -setup_requires = - setuptools_scm +setup_requires = setuptools_scm>=3.4.0 python_requires = >=3.10, <3.13 install_requires = # core packages scikit_image>=0.14.2 dipy>=1.11.0,<1.12.0 + scikit-learn pandas pybids>=0.16.2 templateflow>=0.8 immlib - pydra trx-python + # efficiency + numba + osqp + pydra ray + # neural networks + tinygrad @ git+https://github.com/tinygrad/tinygrad.git@846a2826ab4bc00056a366b0dcbd5df17047e016 + brainchop # CLI interpretation toml>=0.10.0 - setuptools_scm[toml]>=3.4.0,<5.1.0 # plotly libraries plotly==5.12.0 kaleido==0.2.1 @@ -53,19 +58,17 @@ packages = find: [options.extras_require] dev = - docutils==0.15.2 - astroid<=2.15.8 sphinx memory-profiler - pytest==7.2.0 - pytest-cov==2.10.0 + pytest + pytest-cov flake8 sphinx_gallery sphinx_rtd_theme numpydoc==1.2 
sphinx-autoapi rapidfuzz - xvfbwrapper==0.2.9 + xvfbwrapper>=0.2.9 moto>=3.0.0,<5.0.0 pydata-sphinx-theme sphinx-design @@ -74,8 +77,8 @@ dev = wget fury = fury==0.12.0 - xvfbwrapper==0.2.9 - ipython>=7.13.0,<=7.20.0 + xvfbwrapper>=0.2.9 + ipython fsl = fslpy afqbrowser = @@ -83,7 +86,7 @@ afqbrowser = plot = pingouin>=0.3 seaborn>=0.11.0 - ipython>=7.13.0,<=7.20.0 + ipython aws = s3bids>=0.1.7 s3fs