Skip to content

Shifting to JAX (WIP) #23

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
c17f83e
Add NT Observation
b-biswas Feb 28, 2025
b603090
jaxify metacal func
b-biswas Feb 28, 2025
4b1f299
ngmix obs to NT
b-biswas Feb 28, 2025
bc74c85
to jax
b-biswas Feb 28, 2025
21657f0
added old version
b-biswas Feb 28, 2025
3df3089
update tests
b-biswas Feb 28, 2025
b15dd57
updated environment.yml file
b-biswas Mar 3, 2025
4d5be8b
pre-commit changes
b-biswas Mar 3, 2025
9c5d97d
set jax x64
b-biswas Mar 3, 2025
f7f21b4
jax x64
b-biswas Mar 3, 2025
0b1fb92
update tests
b-biswas Mar 3, 2025
5e067c0
undo changes in utils
b-biswas Mar 5, 2025
c31edc7
revert old code
b-biswas Mar 9, 2025
aa6d20f
add jax code
b-biswas Mar 9, 2025
1edd156
minor [skip ci]
b-biswas Mar 10, 2025
1d566df
initial pr-riv
b-biswas Apr 2, 2025
20e2ab0
[skip ci] add stepk computation
b-biswas Apr 3, 2025
0f9c73e
[skip ci] remove jnp.array calls
b-biswas Apr 3, 2025
0ff1e26
[skip ci] update documentation
b-biswas Apr 3, 2025
19a6580
Merge branch 'main' into to-jax
beckermr Apr 3, 2025
3267769
Bug fix
b-biswas Apr 15, 2025
c16dfd1
compare jax and ngmix
b-biswas Apr 15, 2025
d91beef
fix tests
b-biswas Apr 15, 2025
f204303
minor
b-biswas Apr 15, 2025
6c8d922
update tests
b-biswas Apr 16, 2025
f00ad98
use AffineTransform
b-biswas Apr 16, 2025
794402f
minor
b-biswas Apr 16, 2025
aaea974
bug fix
b-biswas Apr 18, 2025
1f28511
fix bug in testing
b-biswas Apr 18, 2025
227b8b8
minor changes
b-biswas Apr 18, 2025
81d12d7
minor
b-biswas Apr 21, 2025
522abac
minor
b-biswas Apr 21, 2025
657ac5d
minor
b-biswas Apr 21, 2025
63f2877
[skip ci] deconvolution
b-biswas May 2, 2025
05210cd
Merge branch 'main' into to-jax
beckermr May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
496 changes: 496 additions & 0 deletions deep_field_metadetect/jaxify/jax_metacal.py

Large diffs are not rendered by default.

126 changes: 126 additions & 0 deletions deep_field_metadetect/jaxify/jax_metadetect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import ngmix
import numpy as np

from deep_field_metadetect.detect import (
generate_mbobs_for_detections,
run_detection_sep,
)
from deep_field_metadetect.jaxify.jax_metacal import (
DEFAULT_SHEARS,
DEFAULT_STEP,
jax_metacal_wide_and_deep_psf_matched,
)
from deep_field_metadetect.mfrac import compute_mfrac_interp_image
from deep_field_metadetect.utils import fit_gauss_mom_obs, fit_gauss_mom_obs_and_psf


def jax_single_band_deep_field_metadetect(
obs_wide,
obs_deep,
obs_deep_noise,
nxy,
nxy_psf,
step=DEFAULT_STEP,
shears=None,
skip_obs_wide_corrections=False,
skip_obs_deep_corrections=False,
nodet_flags=0,
scale=0.2,
) -> dict:
"""Run deep-field metadetection for a simple scenario of a single band
with a single image per band using only post-PSF Gaussian weighted moments.

Parameters
----------
obs_wide : DFMdetObservation
The wide-field observation.
obs_deep : DFMdetObservation
The deep-field observation.
obs_deep_noise : DFMdetObservation
The deep-field noise observation.
nxy: int
Image size
nxy_psf: int
PSF size
step : float, optional
The step size for the metacalibration, by default DEFAULT_STEP.
shears : list, optional
The shears to use for the metacalibration, by default DEFAULT_SHEARS
if set to None.
skip_obs_wide_corrections : bool, optional
Skip the observation corrections for the wide-field observations,
by default False.
skip_obs_deep_corrections : bool, optional
Skip the observation corrections for the deep-field observations,
by default False.
nodet_flags : int, optional
The bmask flags marking area in the image to skip, by default 0.
scale: float
pixel scale

Returns
-------
dfmdet_res : dict
The deep-field metadetection results, a dictionary with keys from `shears`
and values containing the detection+measurement results for the corresponding
shear.
"""
if shears is None:
shears = DEFAULT_SHEARS

mcal_res = jax_metacal_wide_and_deep_psf_matched(
obs_wide=obs_wide,
obs_deep=obs_deep,
obs_deep_noise=obs_deep_noise,
nxy=nxy,
nxy_psf=nxy_psf,
step=step,
shears=shears,
skip_obs_wide_corrections=skip_obs_wide_corrections,
skip_obs_deep_corrections=skip_obs_deep_corrections,
scale=scale,
) # This returns ngmix Obs for now

psf_res = fit_gauss_mom_obs(mcal_res["noshear"].psf)
dfmdet_res = []
for shear, obs in mcal_res.items():
detres = run_detection_sep(obs, nodet_flags=nodet_flags)

ixc = (detres["catalog"]["x"] + 0.5).astype(int)
iyc = (detres["catalog"]["y"] + 0.5).astype(int)
bmask_flags = obs.bmask[iyc, ixc]

mfrac_vals = np.zeros_like(bmask_flags, dtype="f4")
if np.any(obs.mfrac > 0):
_interp_mfrac = compute_mfrac_interp_image(
obs.mfrac,
obs.jacobian.get_galsim_wcs(),
)
for i, (x, y) in enumerate(
zip(detres["catalog"]["x"], detres["catalog"]["y"])
):
mfrac_vals[i] = _interp_mfrac.xValue(x, y)

for ind, (obj, mbobs) in enumerate(
generate_mbobs_for_detections(
ngmix.observation.get_mb_obs(obs),
xs=detres["catalog"]["x"],
ys=detres["catalog"]["y"],
)
):
fres = fit_gauss_mom_obs_and_psf(mbobs[0][0], psf_res=psf_res)
dfmdet_res.append(
(ind + 1, obj["x"], obj["y"], shear, bmask_flags[ind], mfrac_vals[ind])
+ tuple(fres[0])
)

total_dtype = [
("id", "i8"),
("x", "f8"),
("y", "f8"),
("mdet_step", "U7"),
("bmask_flags", "i4"),
("mfrac", "f4"),
] + fres.dtype.descr

return np.array(dfmdet_res, dtype=total_dtype)
22 changes: 22 additions & 0 deletions deep_field_metadetect/jaxify/jax_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import jax.numpy as jnp


# @partial(jax.jit, static_argnames=["pixel_scale", "image_size"])
def compute_stepk(pixel_scale, image_size):
"""Compute psf fourier scale based on pixel scale and psf image dimension
The size if obtained from from galsim.GSObject.getGoodImageSize
The factor 1/4 from deep_field_metadetect.metacal.get_gauss_reconv_psf_galsim

Parameters:
-----------
pixel_scale : float
The scale of a single pixel in the image.
image_size : int
The dimension of the PSF image (typically a square size).

Returns:
--------
float
The computed stepk value, which represents the Fourier-space sampling frequency.
"""
return 2 * jnp.pi / (image_size * pixel_scale) / 4
125 changes: 125 additions & 0 deletions deep_field_metadetect/jaxify/observation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from typing import NamedTuple, Optional

import jax
import jax_galsim
import ngmix
import numpy as np
from ngmix.observation import Observation


@jax.tree_util.register_pytree_node_class
class DFMdetObservation(NamedTuple):
image: jax.Array
weight: Optional[jax.Array]
bmask: Optional[jax.Array]
ormask: Optional[jax.Array]
noise: Optional[jax.Array]
aft: Optional[jax_galsim.wcs.AffineTransform]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking ahead, I think we should call this wcs.

psf: Optional["DFMdetObservation"]
mfrac: Optional[jax.Array]
meta: Optional[dict]
store_pixels: bool
ignore_zero_weight: bool

def tree_flatten(self):
children = (
self.image,
self.weight,
self.bmask,
self.ormask,
self.noise,
self.aft,
self.psf,
self.mfrac,
)

aux_data = (self.meta, self.store_pixels, self.ignore_zero_weight)

return children, aux_data

@classmethod
def tree_unflatten(cls, aux_data, children):
# Reconstruct the object from flattened data
return cls(*children, *aux_data)

def has_bmask(self) -> bool:
if self.bmask is None:
return False
return True

def has_mfrac(self) -> bool:
if self.bmask is None:
return False
return True

def has_noise(self) -> bool:
if self.noise is None:
return False
return True

def has_ormask(self) -> bool:
if self.ormask is None:
return False
return True

Check warning on line 63 in deep_field_metadetect/jaxify/observation.py

View check run for this annotation

Codecov / codecov/patch

deep_field_metadetect/jaxify/observation.py#L63

Added line #L63 was not covered by tests

def has_psf(self) -> bool:
if self.psf is None:
return False
return True


def ngmix_obs_to_dfmd_obs(obs: ngmix.observation.Observation) -> DFMdetObservation:
jacobian = obs.get_jacobian()

psf = None
if obs.has_psf():
psf = ngmix_obs_to_dfmd_obs(obs.get_psf())

return DFMdetObservation(
image=obs.image,
weight=obs.weight,
bmask=obs.bmask if obs.has_bmask() else None,
ormask=obs.ormask if obs.has_ormask() else None,
noise=obs.noise if obs.has_noise() else None,
aft=jax_galsim.wcs.AffineTransform(
dudx=jacobian.dudcol,
dudy=jacobian.dudrow,
dvdx=jacobian.dvdcol,
dvdy=jacobian.dvdrow,
origin=jax_galsim.PositionD(
y=jacobian.row0 + 1,
x=jacobian.col0 + 1,
),
),
psf=psf,
meta=obs.meta,
mfrac=obs.mfrac if obs.has_mfrac() else None,
store_pixels=getattr(obs, "store_pixels", True),
ignore_zero_weight=getattr(obs, "ignore_zero_weight", True),
)


def dfmd_obs_to_ngmix_obs(dfmd_obs) -> Observation:
psf = None
if dfmd_obs.psf is not None:
psf = dfmd_obs_to_ngmix_obs(dfmd_obs.psf)
return Observation(
image=np.array(dfmd_obs.image),
weight=np.array(dfmd_obs.weight),
bmask=dfmd_obs.bmask,
ormask=dfmd_obs.ormask,
noise=dfmd_obs.noise if dfmd_obs.noise is None else np.array(dfmd_obs.noise),
jacobian=ngmix.jacobian.Jacobian(
row=dfmd_obs.aft.origin.y - 1,
col=dfmd_obs.aft.origin.x - 1,
dudcol=dfmd_obs.aft.dudx,
dudrow=dfmd_obs.aft.dudy,
dvdcol=dfmd_obs.aft.dvdx,
dvdrow=dfmd_obs.aft.dvdy,
),
psf=psf,
mfrac=dfmd_obs.mfrac if dfmd_obs.mfrac is None else np.array(dfmd_obs.mfrac),
meta=dfmd_obs.meta,
store_pixels=np.array(dfmd_obs.store_pixels, dtype=np.bool_),
ignore_zero_weight=np.array(dfmd_obs.ignore_zero_weight, dtype=np.bool_),
)
Loading