Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ help:
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@cd ..; python setup.py build_ext --inplace
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
8 changes: 4 additions & 4 deletions pyrecon/iterative_fft.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implementation of Burden et al. 2015 (https://arxiv.org/abs/1504.02591) algorithm."""

from .recon import BaseReconstruction
from . import utils
from .utils import safe_divide


class IterativeFFTReconstruction(BaseReconstruction):
Expand Down Expand Up @@ -39,7 +39,7 @@ def _iterate(self):
# First compute \delta(k)/k^{2} based on current \delta_{g,\mathrm{real},n} to estimate \phi_{\mathrm{est},n} (eq. 24)
delta_k = self.mesh_delta_real.r2c()
for kslab, slab in zip(delta_k.slabs.x, delta_k.slabs):
utils.safe_divide(slab, sum(kk**2 for kk in kslab), inplace=True)
safe_divide(slab, sum(kk**2 for kk in kslab), inplace=True)

self.mesh_delta_real = self.mesh_delta.copy()
# Now compute \beta \nabla \cdot (\nabla \phi_{\mathrm{est},n} \cdot \hat{r}) \hat{r}
Expand Down Expand Up @@ -70,7 +70,7 @@ def _iterate(self):
disp_deriv = disp_deriv.c2r()
for rslab, slab in zip(disp_deriv.slabs.x, disp_deriv.slabs):
rslab = self._transform_rslab(rslab)
slab[...] *= utils.safe_divide(rslab[iaxis] * rslab[jaxis], sum(rr**2 for rr in rslab))
slab[...] *= safe_divide(rslab[iaxis] * rslab[jaxis], sum(rr**2 for rr in rslab))
factor = (1. + (iaxis != jaxis)) * self.beta # we have j >= i and double-count j > i to account for j < i
if self._iter == 0:
# Burden et al. 2015: 1504.02591, eq. 12 (flat sky approximation)
Expand All @@ -87,7 +87,7 @@ def _compute_psi(self):
psi = delta_k.copy()
for kslab, islab, slab in zip(psi.slabs.x, psi.slabs.i, psi.slabs):
mask = islab[iaxis] != self.nmesh[iaxis] // 2
slab[...] *= 1j * utils.safe_divide(kslab[iaxis], sum(kk**2 for kk in kslab)) * mask
slab[...] *= 1j * safe_divide(kslab[iaxis], sum(kk**2 for kk in kslab)) * mask
psis.append(psi.c2r())
del psi
return psis
8 changes: 4 additions & 4 deletions pyrecon/iterative_fft_particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from .recon import BaseReconstruction, ReconstructionError, format_positions_wrapper, format_positions_weights_wrapper
from . import utils
from .utils import distance, safe_divide


class OriginalIterativeFFTParticleReconstruction(BaseReconstruction):
Expand Down Expand Up @@ -136,7 +136,7 @@ def _iterate(self, return_psi=False):
del self.mesh_delta

for kslab, slab in zip(delta_k.slabs.x, delta_k.slabs):
utils.safe_divide(slab, sum(kk**2 for kk in kslab), inplace=True)
safe_divide(slab, sum(kk**2 for kk in kslab), inplace=True)

if self.mpicomm.rank == 0:
self.log_info('Computing displacement field.')
Expand All @@ -162,7 +162,7 @@ def _iterate(self, return_psi=False):
# self.log_info('A few displacements values:')
# for s in shifts[:3]: self.log_info('{}'.format(s))
if self.los is None:
los = utils.safe_divide(self._positions_data, utils.distance(self._positions_data)[:, None])
los = safe_divide(self._positions_data, distance(self._positions_data)[:, None])
else:
los = self.los
# Comments in Julian's code:
Expand Down Expand Up @@ -235,7 +235,7 @@ def _read_shifts(positions):
return shifts

if self.los is None:
los = utils.safe_divide(positions, utils.distance(positions)[:, None])
los = safe_divide(positions, distance(positions)[:, None])
else:
los = self.los.astype(positions.dtype)
rsd = self.f * np.sum(shifts * los, axis=-1)[:, None] * los
Expand Down
8 changes: 5 additions & 3 deletions pyrecon/multigrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import numpy as np

from .recon import BaseReconstruction, ReconstructionError, format_positions_wrapper
from . import _multigrid, utils, mpi
from .utils import distance, safe_divide
from .mpi import COMM_WORLD
from . import _multigrid


class OriginalMultiGridReconstruction(BaseReconstruction):
Expand Down Expand Up @@ -31,7 +33,7 @@ def _select_nmesh(nmesh):
toret.append(ntries[iclosest])
return np.array(toret, dtype='i8')

def __init__(self, *args, mpicomm=mpi.COMM_WORLD, **kwargs):
def __init__(self, *args, mpicomm=COMM_WORLD, **kwargs):
# We require a split, along axis x.
super(OriginalMultiGridReconstruction, self).__init__(*args, decomposition=(mpicomm.size, 1), mpicomm=mpicomm, **kwargs)

Expand Down Expand Up @@ -149,7 +151,7 @@ def read_shifts(self, positions, field='disp+rsd'):
if field == 'disp':
return shifts
if self.los is None:
los = utils.safe_divide(positions, utils.distance(positions)[:, None])
los = safe_divide(positions, distance(positions)[:, None])
else:
los = self.los.astype(shifts.dtype)
rsd = self.f * np.sum(shifts * los, axis=-1)[:, None] * los
Expand Down
22 changes: 11 additions & 11 deletions pyrecon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from pmesh.pm import ParticleMesh

from .mesh import _get_mesh_attrs, _get_resampler, _wrap_positions
from .utils import BaseClass
from . import utils, mpi
from .utils import BaseClass, sky_to_cartesian, cartesian_to_sky, distance, safe_divide
from .mpi import scatter, gather, COMM_WORLD


def _gaussian_kernel(smoothing_radius):
Expand Down Expand Up @@ -54,7 +54,7 @@ def __format_positions(positions):
if len(positions) != 3:
return None, 'For position type = {}, please provide a list of 3 arrays for positions (found {:d})'.format(position_type, len(positions))
if position_type == 'rdd': # RA, Dec, distance
positions = utils.sky_to_cartesian(positions[2], *positions[:2], degree=True).T
positions = sky_to_cartesian(positions[2], *positions[:2], degree=True).T
elif position_type != 'xyz':
return None, 'Position type should be one of ["pos", "xyz", "rdd"]'
return np.asarray(positions).T, None
Expand All @@ -71,7 +71,7 @@ def __format_positions(positions):
if errors:
raise ValueError(errors[0])
if mpiroot is not None and mpicomm.bcast(positions is not None if mpicomm.rank == mpiroot else None, root=mpiroot):
positions = mpi.scatter(positions, mpicomm=mpicomm, mpiroot=mpiroot)
positions = scatter(positions, mpicomm=mpicomm, mpiroot=mpiroot)
return positions


Expand All @@ -89,7 +89,7 @@ def __format_weights(weights):
if any(is_none) and not all(is_none):
raise ValueError('mpiroot = None but weights are None on some ranks')
elif not mpicomm.bcast(weights is None, root=mpiroot):
weights = mpi.scatter(weights, mpicomm=mpicomm, mpiroot=mpiroot)
weights = scatter(weights, mpicomm=mpicomm, mpiroot=mpiroot)

if size is not None and weights is not None and len(weights) != size:
raise ValueError('Weight arrays should be of the same size as position arrays')
Expand Down Expand Up @@ -127,10 +127,10 @@ def wrapper(self, positions, copy=False, dtype=None, **kwargs):
raise ValueError('positions not in box range {} - {}'.format(low, high))
toret = func(self, positions=positions, **kwargs)
if toret is not None and mpiroot is not None: # positions returned, gather on the same rank
toret = mpi.gather(toret, mpicomm=self.mpicomm, mpiroot=mpiroot)
toret = gather(toret, mpicomm=self.mpicomm, mpiroot=mpiroot)
if toret is not None and return_input_type:
if position_type == 'rdd':
dist, ra, dec = utils.cartesian_to_sky(toret)
dist, ra, dec = cartesian_to_sky(toret)
toret = [ra, dec, dist]
elif position_type == 'xyz':
toret = toret.T
Expand Down Expand Up @@ -199,7 +199,7 @@ def _select_nmesh(nmesh):

def __init__(self, f=None, bias=None, los=None, nmesh=None, boxsize=None, boxcenter=None, cellsize=None, boxpad=2., wrap=False,
data_positions=None, randoms_positions=None, data_weights=None, randoms_weights=None,
positions=None, position_type='pos', resampler='cic', decomposition=None, fft_plan='estimate', dtype='f8', mpiroot=None, mpicomm=mpi.COMM_WORLD, **kwargs):
positions=None, position_type='pos', resampler='cic', decomposition=None, fft_plan='estimate', dtype='f8', mpiroot=None, mpicomm=COMM_WORLD, **kwargs):
"""
Initialize :class:`BaseReconstruction`.

Expand Down Expand Up @@ -407,7 +407,7 @@ def set_los(self, los=None):
los = np.zeros(3, dtype='f8')
los[ilos] = 1.
los = np.array(los, dtype='f8')
self.los = los / utils.distance(los)
self.los = los / distance(los)

@property
def cellsize(self):
Expand Down Expand Up @@ -658,13 +658,13 @@ def read_shifts(self, positions, field='disp+rsd'):
if field == 'disp':
return shifts
if self.los is None:
los = utils.safe_divide(positions, utils.distance(positions)[:, None])
los = safe_divide(positions, distance(positions)[:, None])
else:
los = self.los.astype(shifts.dtype)
if self.f_callable is None:
f = self.f
else:
f = self.f_callable(utils.distance(positions))[..., None]
f = self.f_callable(distance(positions))[..., None]
rsd = f * (np.sum(shifts * los, axis=-1)[:, None] * los)
if field == 'rsd':
return rsd
Expand Down