Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e518313
Added obscured_airy_disk() function and added PBCorrectedCube to Simp…
kdesoto-astro Oct 27, 2022
1149d73
Modified gridding checks to include optional frequency array.
kdesoto-astro Nov 26, 2022
da9fdad
Added class attribute chan_freq, changed naming to better differentia…
kdesoto-astro Nov 26, 2022
7d8a9d0
Set up framework for PrimaryBeamCube and dish types
kdesoto-astro Nov 27, 2022
c40f1b1
First attempt at both versions of primary beam corrections implemented
kdesoto-astro Nov 27, 2022
0ec9862
Fixed typo in diameter to wavelength ratio calculation
kdesoto-astro Nov 27, 2022
c60d0b7
Fixed miscellaneous errors, including missing arcsec to radian conver…
kdesoto-astro Nov 28, 2022
3065dc6
Normed factor IS 1 for the uniform case, no need to calculate it
kdesoto-astro Nov 30, 2022
3d1de04
Added obscured_airy_disk() function and added PBCorrectedCube to Simp…
kdesoto-astro Oct 27, 2022
5551ff6
Modified gridding checks to include optional frequency array.
kdesoto-astro Nov 26, 2022
ad3653a
Added class attribute chan_freq, changed naming to better differentia…
kdesoto-astro Nov 26, 2022
d3852b2
Set up framework for PrimaryBeamCube and dish types
kdesoto-astro Nov 27, 2022
2006878
First attempt at both versions of primary beam corrections implemented
kdesoto-astro Nov 27, 2022
5eb95ee
Fixed typo in diameter to wavelength ratio calculation
kdesoto-astro Nov 27, 2022
02651fb
Fixed miscellaneous errors, including missing arcsec to radian conver…
kdesoto-astro Nov 28, 2022
b7d9a78
Normed factor IS 1 for the uniform case, no need to calculate it
kdesoto-astro Nov 30, 2022
5a7d7f0
Rebase + removed _setup_coords function for PrimaryBeamCorrected
kdesoto-astro Mar 23, 2023
aa8eb55
Resolving merge conflicts
kdesoto-astro Mar 23, 2023
503acd3
Removed last merge conflict from gridding
kdesoto-astro Mar 23, 2023
cf042b8
Removed duplicated PrimaryBeamCube class
kdesoto-astro Mar 23, 2023
45639dd
Small typo in gridding.py
kdesoto-astro Mar 23, 2023
e276817
Corrected bugs in primary_beam.py, moved to separate file, added defa…
kdesoto-astro Apr 10, 2023
c47ce17
Merge branch 'MPoL-dev:main' into main
kdesoto-astro Apr 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 43 additions & 7 deletions src/mpol/gridding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .datasets import GriddedDataset


def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=None):
def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=None, freq=None):
"""
Check that all data inputs are the same shape, the weights are positive, and the data_re and data_im are floats.

Expand All @@ -34,6 +34,9 @@ def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=N
"All dataset inputs must be the same input shape and size."
)

if freq is not None: # TODO:change to wrongdimensionerror
assert len(uu) == len(freq), "uu must have same number of channels as freq array."

if np.any(weight <= 0.0):
raise ValueError("Not all thermal weights are positive, check inputs.")

Expand All @@ -53,8 +56,7 @@ def _check_data_inputs_2d(uu=None, vv=None, weight=None, data_re=None, data_im=N
# check to see that uu, vv and data do not contain Hermitian pairs
verify_no_hermitian_pairs(uu, vv, data_re + 1.0j * data_im)

return uu, vv, weight, data_re, data_im

return uu, vv, weight, data_re, data_im, freq

def verify_no_hermitian_pairs(uu, vv, data, test_vis=5, test_channel=0):
r"""
Expand Down Expand Up @@ -144,6 +146,33 @@ def verify_no_hermitian_pairs(uu, vv, data, test_vis=5, test_channel=0):
return False


def _check_freq_1d(freq=None):
"""
Check that the frequency input array contains only positive floats.

If the user supplied a float, convert to a 1D array. If no frequency array
was supplied, simply skip.

"""
if freq is None:
return freq

assert (
np.isscalar(freq) or freq.ndim == 1
), "Input data vectors should be either None, scalar, or 1D array."

assert np.all(freq > 0.0), "Not all frequencies are positive, check inputs."

if np.isscalar(freq):
freq = np.atleast_1d(freq)

assert (freq.dtype == np.single) or (
freq.dtype == np.double
), "freq should be type single or double"

return freq


class GridderBase:
r"""
This class is not designed to be used directly, but rather to be subclassed.
Expand Down Expand Up @@ -172,13 +201,18 @@ def __init__(
weight=None,
data_re=None,
data_im=None,
chan_freq=None,
):

# check frequency array is 1d or None, expand if not
chan_freq = _check_freq_1d(chan_freq)

# check everything should be 2d, expand if not
# also checks data does not contain Hermitian pairs
uu, vv, weight, data_re, data_im = _check_data_inputs_2d(
uu, vv, weight, data_re, data_im
uu, vv, weight, data_re, data_im, chan_freq = _check_data_inputs_2d(
uu, vv, weight, data_re, data_im, chan_freq
)

# setup the coordinates object
self.coords = coords
self.nchan = len(uu)
Expand All @@ -193,6 +227,7 @@ def __init__(
self.weight = weight
self.data_re = data_re
self.data_im = data_im
self.chan_freq = chan_freq

# and register cell indices against data
self._create_cell_indices()
Expand All @@ -211,6 +246,7 @@ def from_image_properties(
coords = GridCoords(cell_size, npix)
return cls(coords, uu, vv, weight, data_re, data_im)


def _create_cell_indices(self):
# figure out which visibility cell each datapoint lands in, so that
# we can later assign it the appropriate robust weight for that cell
Expand Down Expand Up @@ -586,7 +622,7 @@ def __init__(
):
# check everything should be 2d, expand if not
# also checks data does not contain Hermitian pairs
uu, vv, weight, data_re, data_im = _check_data_inputs_2d(
uu, vv, weight, data_re, data_im, freq = _check_data_inputs_2d(
uu, vv, weight, data_re, data_im
)

Expand Down
2 changes: 2 additions & 0 deletions src/mpol/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import numpy as np
from scipy.special import j1
import torch
import torch.fft # to avoid conflicts with old torch.fft *function*
from torch import nn
Expand All @@ -11,6 +12,7 @@
from .coordinates import GridCoords



class BaseCube(nn.Module):
r"""
A base cube of the same dimensions as the image cube. Designed to use a pixel mapping function :math:`f_\mathrm{map}` from the base cube values to the ImageCube domain.
Expand Down
20 changes: 19 additions & 1 deletion src/mpol/precomposed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from mpol.coordinates import GridCoords

from . import fourier, images
from . import fourier, images, primary_beam


class SimpleNet(torch.nn.Module):
Expand Down Expand Up @@ -35,7 +35,12 @@ def __init__(
coords=None,
nchan=1,
base_cube=None,
chan_freqs=None,
dish_type=None,
dish_radius=None,
**dish_kwargs,
):

super().__init__()

self.coords = coords
Expand All @@ -50,12 +55,23 @@ def __init__(
self.icube = images.ImageCube(
coords=self.coords, nchan=self.nchan, passthrough=True
)

self.pbcube = primary_beam.PrimaryBeamCube(
coords = self.coords,
nchan=self.nchan,
chan_freqs=chan_freqs,
dish_type=dish_type,
dish_radius=dish_radius,
**dish_kwargs
)
self.fcube = fourier.FourierCube(coords=self.coords)


@classmethod
def from_image_properties(cls, cell_size, npix, nchan, base_cube):
coords = GridCoords(cell_size, npix)
return cls(coords, nchan, base_cube)


def forward(self):
r"""
Expand All @@ -66,5 +82,7 @@ def forward(self):
x = self.bcube()
x = self.conv_layer(x)
x = self.icube(x)
x = self.pbcube(x)
vis = self.fcube(x)

return vis
184 changes: 184 additions & 0 deletions src/mpol/primary_beam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
r"""The ``primary_beam`` module provides the core functionality of MPoL via :class:`mpol.fourier.PrimaryBeamCube`."""

from __future__ import annotations

import numpy as np
import torch
import torch.fft # to avoid conflicts with old torch.fft *function*
import torchkbnufft
from torch import nn

from . import utils
from .coordinates import GridCoords

from .gridding import _check_freq_1d

class PrimaryBeamCube(nn.Module):
r"""
A ImageCube representing the primary beam of a described dish type. Currently can correct for a
uniform or center-obscured dish. The forward() method multiplies an image cube by this primary beam mask.

Args:
cell_size (float): the width of a pixel [arcseconds]
npix (int): the number of pixels per image side
coords (GridCoords): an object already instantiated from the GridCoords class. If providing this, cannot provide ``cell_size`` or ``npix``.
nchan (int): the number of channels in the image
dish_type (string): the type of dish to correct for. Either 'uniform' or 'obscured'.
dish_radius (float): the radius of the dish (in meters)
dish_kwargs (dict): any additional arguments needed for special dish types. Currently only uses:
dish_obscured_radius (float): the radius of the obscured portion of the dish
"""
def __init__(
self,
coords,
nchan=1,
chan_freqs=None,
dish_type=None,
dish_radius=None,
**dish_kwargs,
):
super().__init__()

#_setup_coords(self, cell_size, npix, coords, nchan) TODO: update this

_check_freq_1d(chan_freqs)
assert (chan_freqs is None) or (len(chan_freqs) == nchan), "Length of chan_freqs must be equal to nchan"

assert (dish_type is None) or (dish_type in ["uniform", "obscured"]), "Provided dish_type must be 'uniform' or 'obscured'"

self.coords = coords
self.nchan = nchan

self.default_mask = nn.Parameter(
torch.full(
(self.nchan, self.coords.npix, self.coords.npix),
fill_value=1.0,
requires_grad=False,
dtype=torch.double,
)
)

if dish_type is None:
self.pb_mask = self.default_mask
elif dish_type == "uniform":
self.pb_mask = self.uniform_mask(chan_freqs, dish_radius)
elif dish_type == "obscured":
self.pb_mask = self.obscured_mask(chan_freqs, dish_radius, **dish_kwargs)

@classmethod
def from_image_properties(
cls, cell_size, npix, nchan=1,
chan_freqs=None, dish_type=None,
dish_radius=None, **dish_kwargs
) -> ImageCube:
coords = GridCoords(cell_size, npix)
return cls(coords, nchan, chan_freqs, dish_type, dish_radius, **dish_kwargs)

def forward(self, cube):
r"""Args:
cube (torch.double tensor, of shape ``(nchan, npix, npix)``): a prepacked image cube, for example, from ImageCube.forward()

Returns:
(torch.complex tensor, of shape ``(nchan, npix, npix)``): the FFT of the image cube, in packed format.
"""
return torch.mul(self.pb_mask, cube)


def uniform_mask(self, chan_freqs, dish_radius):
r"""
Generates airy disk primary beam correction mask.
"""
assert dish_radius > 0., "Dish radius must be positive"
ratio = 2. * dish_radius * np.array([[chan_freqs]]).T / 2.998e8

ratio_cube = np.tile(ratio,(1,self.coords.npix,self.coords.npix))
r_2D = np.sqrt(self.coords.packed_x_centers_2D**2 + self.coords.packed_y_centers_2D**2) # arcsec
r_2D_rads = r_2D * np.pi / 180. / 60. / 60. # radians
r_cube = np.tile(r_2D_rads,(self.nchan,1,1))

r_normed_cube = np.pi * r_cube * ratio_cube

mask = np.where(r_normed_cube > 0.,
(2. * j1(r_normed_cube) / r_normed_cube)**2,
1.)
return torch.tensor(mask)


def obscured_mask(self, chan_freqs, dish_radius, dish_obscured_radius=None, **extra_kwargs):
r"""
Generates airy disk primary beam correction mask.
"""
assert dish_obscured_radius is not None, "Obscured dish requires kwarg 'dish_obscured_radius'"
assert dish_radius > 0., "Dish radius must be positive"
assert dish_obscured_radius > 0., "Obscured dish radius must be positive"
assert dish_radius > dish_obscured_radius, "Primary dish radius must be greater than obscured radius"

ratio = 2. * dish_radius * np.array([[chan_freqs]]).T / 2.998e8
ratio_cube = np.tile(ratio,(1,self.coords.npix,self.coords.npix))
r_2D = np.sqrt(self.coords.packed_x_centers_2D**2 + self.coords.packed_y_centers_2D**2) # arcsec
r_2D_rads = r_2D * np.pi / 180. / 60. / 60. # radians
r_cube = np.tile(r_2D_rads,(self.nchan,1,1))

eps = dish_obscured_radius / dish_radius
r_normed_cube = np.pi * r_cube * ratio_cube

norm_factor = (1.-eps**2)**2
mask = np.where(r_normed_cube > 0.,
(j1(r_normed_cube) / r_normed_cube
- eps*j1(eps*r_normed_cube) / r_normed_cube)**2 / norm_factor,
1.)
return torch.tensor(mask)

@property
def sky_cube(self):
"""
The primary beam mask arranged as it would appear on the sky.

Returns:
torch.double : 3D image cube of shape ``(nchan, npix, npix)``

"""
return utils.packed_cube_to_sky_cube(self.pb_mask)

def to_FITS(self, fname="cube.fits", overwrite=False, header_kwargs=None):
"""
Export the primary beam cube to a FITS file.

Args:
fname (str): the name of the FITS file to export to.
overwrite (bool): if the file already exists, overwrite?
header_kwargs (dict): Extra keyword arguments to write to the FITS header.

Returns:
None
"""

try:
from astropy import wcs
from astropy.io import fits
except ImportError:
print(
"Please install the astropy package to use FITS export functionality."
)

w = wcs.WCS(naxis=2)

w.wcs.crpix = np.array([1, 1])
w.wcs.cdelt = (
np.array([self.coords.cell_size, self.coords.cell_size]) / 3600
) # decimal degrees
w.wcs.ctype = ["RA---TAN", "DEC--TAN"]

header = w.to_header()

# add in the kwargs to the header
if header_kwargs is not None:
for k, v in header_kwargs.items():
header[k] = v

hdu = fits.PrimaryHDU(self.pb_mask.detach().cpu().numpy(), header=header)

hdul = fits.HDUList([hdu])
hdul.writeto(fname, overwrite=overwrite)

hdul.close()
9 changes: 9 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
from astropy.utils.data import download_file
import torch

from mpol import coordinates, gridding

Expand Down Expand Up @@ -53,6 +54,14 @@ def coords():
return coordinates.GridCoords(cell_size=0.005, npix=800)


@pytest.fixture
def unit_cube(coords):
nchan = 4
input_cube = torch.full(
(nchan, coords.npix, coords.npix), fill_value=1.0, dtype=torch.double
)
return input_cube

@pytest.fixture
def averager(mock_visibility_data, coords):
uu, vv, weight, data_re, data_im = mock_visibility_data
Expand Down
Loading