Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions AFQ/models/asym_csd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import numpy as np
import sys
import itertools
import os
import multiprocessing as mp

from cvxopt import matrix
from cvxopt.solvers import options, qp

import ray
from ray.experimental import tqdm_ray

from dipy.reconst import shm
from dipy.reconst import csdeconv as csd
from dipy.core.sphere import HemiSphere


options['show_progress'] = False # disable cvxopt output
options['maxiters'] = 100 # maximum number of qp iteration
options['abstol'] = 1e-6
options['reltol'] = 1e-6
options['feastol'] = 1e-9

nprocs = mp.cpu_count()
remote_tqdm = ray.remote(tqdm_ray.tqdm)


__all__ = ['AsymConstrainedSphericalDeconvModel']


def _get_weights(vertices, sigma=40):
'''Computes neighbouring fod weights for asymmetric CSD.

Vendorized from:
https://github.com/mabast85/aFOD/blob/master/aFOD/csdeconv/csdeconv.py

Generates matrix that contains the weight for each point on the
neighbouring fod based on their distance to the current voxel and
the angle between the current fod point and the point of the
neighbouring fod.

Args:
vertices: Nx3 numpy array with vertices of the unit sphere.
sigma: cut-off angle.

Returns:
26xN weight matrix as numpy array.
'''
neighs = np.array(list(itertools.product([-1, 0, 1], repeat=3)))
neighs = np.delete(neighs, 13, 0) # Remove [0, 0, 0]
d = np.linalg.norm(neighs, ord=2, axis=1)
deg_mat = np.arccos(np.dot(neighs / d[:, np.newaxis], vertices.T))
weights = np.exp(-deg_mat / np.deg2rad(sigma))
# Do not consider vertices that are not aligned with any neighbouring voxel
weights[deg_mat > np.deg2rad(60)] = 0
weights = weights / d[:, np.newaxis] # Account for distance
# Divide by the vertex-wise weight sum
weights = weights / np.sum(weights, axis=0)[np.newaxis, :]
weights[np.isnan(weights)] = 0 # Check for nans
return weights


class AsymConstrainedSphericalDeconvModel(csd.ConstrainedSphericalDeconvModel):
'''
Vendorized from:
https://github.com/mabast85/aFOD/blob/master/aFOD/csdeconv/csdeconv.py
'''

def fit_prev(self, data, **kwargs):
self.prev_fod = super().fit(data, **kwargs).shm_coeff

def fit(self, data, **kwargs):
# if isinstance(self.sphere, HemiSphere): # TODO: is this necessary?
# raise ValueError("Asym CSD does not support HemiSphere")

_w = _get_weights(self.sphere.vertices)
_X = np.concatenate((self._X, self.B_reg), axis=0)

if not hasattr(self, 'prev_fod'):
self.fit_prev(data, engine="ray", **kwargs)

data = data[..., self._where_dwi]
fod = np.zeros(
(*data[..., 0].shape, self.B_reg.shape[1]),
dtype=np.float32)

neighs = np.array(list(itertools.product([-1, 0, 1], repeat=3)))
neighs = np.delete(neighs, 13, 0) # Remove [0, 0, 0]

if not ray.is_initialized():
ray.init()
data_ref = ray.put(data)
prev_fod_ref = ray.put(self.prev_fod)

@ray.remote
def _ray_fitter(_P, B_reg, neighs, _X, _w, xyz, bar):
h = matrix(np.zeros(B_reg.shape[0]))
args = [matrix(_P), 0, matrix(-B_reg), h]
data = ray.get(data_ref)
prev_fod = ray.get(prev_fod_ref)

shm_coeffs = []
for x, y, z in xyz:
signal = data[x, y, z, :]

fNeighs = prev_fod[
x + neighs[:, 0],
y + neighs[:, 1],
z + neighs[:, 2]]
n_fod = np.diag(np.dot(np.dot(-B_reg, fNeighs.T), _w))
signal = np.concatenate((signal, n_fod))
f = np.dot(-_X.T, signal)
# Using cvxopt
args[1] = matrix(f)

# Suppress cvxopt output
sys.stdout = open(os.devnull, 'w')
sol = qp(*args)
sys.stdout = sys.__stdout__

shm_coeffs.append((
x, y, z,
np.array(sol['x']).reshape((f.shape[0],)),
'optimal' not in sol['status']))
bar.update.remote(1)
return shm_coeffs

# Chunk up indices
mask = kwargs.get('mask', np.ones_like(data[..., 0], dtype=bool))
ii = np.where(mask)
batches = np.array_split(list(zip(*ii)), nprocs)
bar = remote_tqdm.remote(total=len(ii[0]), desc="Running Asym CSD")
tasks = []
for batch in batches:
task_args = [
self._P, self.B_reg, neighs,
_X, _w, batch, bar]
tasks.append(_ray_fitter.remote(*task_args))

results = []
for batch_result in ray.get(tasks):
results.extend(batch_result)

suboptimal_count = 0
for x, y, z, sol, suboptimal in results:
suboptimal_count += int(suboptimal)
fod[x, y, z, :] = sol
print("Suboptimal solutions: %d" % suboptimal_count)
bar.close.remote()

return shm.SphHarmFit(self, fod, mask)
72 changes: 34 additions & 38 deletions AFQ/models/csd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import nibabel as nib

from dipy.reconst import csdeconv as csd
from dipy.reconst import mcsd
from dipy.reconst import shm
from dipy.core.gradients import gradient_table, unique_bvals_magnitude
from dipy.data import get_sphere, small_sphere

import AFQ.utils.models as ut
from AFQ.models.asym_csd import AsymConstrainedSphericalDeconvModel


# Monkey patch fixed spherical harmonics for conda from
Expand All @@ -23,9 +25,10 @@ class CsdNanResponseError(Exception):
pass


def _model(gtab, data, response=None, sh_order=None, csd_fa_thr=0.7):
def _fit(gtab, data, mask, response=None, sh_order=None,
asym=False, csd_fa_thr=0.7):
"""
Helper function that defines a CSD model.
Helper function that does the core of fitting a model to data.
"""
if sh_order is None:
ndata = np.sum(~gtab.b0s_mask)
Expand All @@ -37,39 +40,41 @@ def _model(gtab, data, response=None, sh_order=None, csd_fa_thr=0.7):
if sh_order > 8:
sh_order = 8

my_model = csd.ConstrainedSphericalDeconvModel
unique_bvals = unique_bvals_magnitude(gtab.bvals)
if len(unique_bvals[unique_bvals > 0]) > 1:
low_shell_idx = gtab.bvals <= unique_bvals[unique_bvals > 0][0]
response_gtab = gradient_table(gtab.bvals[low_shell_idx],
gtab.bvecs[low_shell_idx])
data = data[..., low_shell_idx]
else:
response_gtab = gtab

if response is None:
unique_bvals = unique_bvals_magnitude(gtab.bvals)
if len(unique_bvals[unique_bvals > 0]) > 1:
low_shell_idx = gtab.bvals <= unique_bvals[unique_bvals > 0][0]
response_gtab = gradient_table(gtab.bvals[low_shell_idx],
gtab.bvecs[low_shell_idx])
data = data[..., low_shell_idx]
else:
response_gtab = gtab
response, _ = csd.auto_response_ssst(response_gtab,
data,
roi_radii=10,
fa_thr=csd_fa_thr)
response, _ = csd.auto_response_ssst(
response_gtab,
data,
roi_radii=10,
fa_thr=csd_fa_thr)
# Catch conditions where an auto-response could not be calculated:
if np.all(np.isnan(response[0])):
raise CsdNanResponseError

csdmodel = my_model(gtab, response, sh_order=sh_order)
return csdmodel


def _fit(gtab, data, mask, response=None, sh_order=None,
lambda_=1, tau=0.1, csd_fa_thr=0.7):
"""
Helper function that does the core of fitting a model to data.
"""
return _model(gtab, data, response, sh_order, csd_fa_thr).fit(
data, mask=mask)
if asym:
acsdmodel = AsymConstrainedSphericalDeconvModel(
gtab, response,
reg_sphere=small_sphere,
sh_order=sh_order)
return acsdmodel.fit(
data, mask=mask)
else:
csdmodel = csd.ConstrainedSphericalDeconvModel(
gtab, response, sh_order=sh_order)
return csdmodel.fit(
data, mask=mask)


def fit_csd(data_files, bval_files, bvec_files, mask=None, response=None,
b0_threshold=50, sh_order=None, lambda_=1, tau=0.1, out_dir=None):
b0_threshold=50, sh_order=None, asym=False, out_dir=None):
"""
Fit the CSD model and save file with SH coefficients.

Expand Down Expand Up @@ -97,15 +102,6 @@ def fit_csd(data_files, bval_files, bvec_files, mask=None, response=None,
sh_order : int, optional.
default: infer the number of parameters from the number of data
volumes, but no larger than 8.
lambda_ : float, optional.
weight given to the constrained-positivity regularization part of
the deconvolution equation. Default: 1
tau : float, optional.
threshold controlling the amplitude below which the corresponding
fODF is assumed to be zero. Ideally, tau should be set to
zero. However, to improve the stability of the algorithm, tau is
set to tau*100 % of the mean fODF amplitude (here, 10% by default)
(see [1]_). Default: 0.1
out_dir : str, optional
A full path to a directory to store the maps that get computed.
Default: file with coefficients gets stored in the same directory as
Expand All @@ -127,7 +123,7 @@ def fit_csd(data_files, bval_files, bvec_files, mask=None, response=None,
mask=mask)

csdfit = _fit(gtab, data, mask, response=response, sh_order=sh_order,
lambda_=lambda_, tau=tau)
asym=asym)

if out_dir is None:
out_dir = op.join(op.split(data_files)[0], 'dki')
Expand Down
23 changes: 9 additions & 14 deletions AFQ/tasks/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def msdki_msk(msdki_tf):
@as_img
def csd_params(dwi, brain_mask, gtab, data,
csd_response=None, csd_sh_order=None,
csd_lambda_=1, csd_tau=0.1,
csd_asym=False,
csd_fa_thr=0.7):
"""
full path to a nifti file containing
Expand All @@ -313,16 +313,8 @@ def csd_params(dwi, brain_mask, gtab, data,
default: infer the number of parameters from the number of data
volumes, but no larger than 8.
Default: None
csd_lambda_ : float, optional.
weight given to the constrained-positivity regularization part of
the deconvolution equation. Default: 1
csd_tau : float, optional.
threshold controlling the amplitude below which the corresponding
fODF is assumed to be zero. Ideally, tau should be set to
zero. However, to improve the stability of the algorithm, tau is
set to tau*100 percent of the mean fODF amplitude (here, 10 percent
by default)
(see [1]_). Default: 0.1
csd_asym : bool, optional.
Whehter to use
csd_fa_thr : float, optional.
The threshold on the FA used to calculate the single shell auto
response. Can be useful to reduce for baby subjects. Default: 0.7
Expand All @@ -333,6 +325,10 @@ def csd_params(dwi, brain_mask, gtab, data,
the fibre orientation distribution in diffusion MRI:
Non-negativity constrained super-resolved spherical
deconvolution
.. [2] Bastiani, M., Cottaar, M., Dikranian, K., Ghosh, A., Zhang, H.,
Alexander, D.C., Behrens, T.E., Jbabdi, S.,
Sotiropoulos, S.N., 2017. Improved tractography using asymmetric
fibre orientation distributions. Neuroimage 158, 205-218.
"""
mask =\
nib.load(brain_mask).get_fdata()
Expand All @@ -341,7 +337,7 @@ def csd_params(dwi, brain_mask, gtab, data,
gtab, data,
mask=mask,
response=csd_response, sh_order=csd_sh_order,
lambda_=csd_lambda_, tau=csd_tau,
asym=csd_asym,
csd_fa_thr=csd_fa_thr)
except CsdNanResponseError as e:
raise CsdNanResponseError(
Expand All @@ -351,8 +347,7 @@ def csd_params(dwi, brain_mask, gtab, data,
meta = dict(
SphericalHarmonicDegree=csd_sh_order,
ResponseFunctionTensor=csd_response,
lambda_=csd_lambda_,
tau=csd_tau,
asym=csd_asym,
csd_fa_thr=csd_fa_thr)
meta["SphericalHarmonicBasis"] = "DESCOTEAUX"
meta["ModelURL"] = f"{DIPY_GH}reconst/csdeconv.py"
Expand Down
Loading