-
Notifications
You must be signed in to change notification settings - Fork 1
Shifting to JAX (WIP) #23
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
b-biswas
wants to merge
35
commits into
beckermr:main
Choose a base branch
from
b-biswas:to-jax
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
35 commits
Select commit
Hold shift + click to select a range
c17f83e
Add NT Observation
b-biswas b603090
jaxify metacal func
b-biswas 4b1f299
ngmix obs to NT
b-biswas bc74c85
to jax
b-biswas 21657f0
added old version
b-biswas 3df3089
update tests
b-biswas b15dd57
updated environment.yml file
b-biswas 4d5be8b
pre-commit changes
b-biswas 9c5d97d
set jax x64
b-biswas f7f21b4
jax x64
b-biswas 0b1fb92
update tests
b-biswas 5e067c0
undo changes in utils
b-biswas c31edc7
revert old code
b-biswas aa6d20f
add jax code
b-biswas 1edd156
minor [skip ci]
b-biswas 1d566df
initial pr-riv
b-biswas 20e2ab0
[skip ci] add stepk computation
b-biswas 0f9c73e
[skip ci] remove jnp.array calls
b-biswas 0ff1e26
[skip ci] update documentation
b-biswas 19a6580
Merge branch 'main' into to-jax
beckermr 3267769
Bug fix
b-biswas c16dfd1
compare jax and ngmix
b-biswas d91beef
fix tests
b-biswas f204303
minor
b-biswas 6c8d922
update tests
b-biswas f00ad98
use AffineTransform
b-biswas 794402f
minor
b-biswas aaea974
bug fix
b-biswas 1f28511
fix bug in testing
b-biswas 227b8b8
minor changes
b-biswas 81d12d7
minor
b-biswas 522abac
minor
b-biswas 657ac5d
minor
b-biswas 63f2877
[skip ci] deconvolution
b-biswas 05210cd
Merge branch 'main' into to-jax
beckermr File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
import ngmix | ||
import numpy as np | ||
|
||
from deep_field_metadetect.detect import ( | ||
generate_mbobs_for_detections, | ||
run_detection_sep, | ||
) | ||
from deep_field_metadetect.jaxify.jax_metacal import ( | ||
DEFAULT_SHEARS, | ||
DEFAULT_STEP, | ||
jax_metacal_wide_and_deep_psf_matched, | ||
) | ||
from deep_field_metadetect.mfrac import compute_mfrac_interp_image | ||
from deep_field_metadetect.utils import fit_gauss_mom_obs, fit_gauss_mom_obs_and_psf | ||
|
||
|
||
def jax_single_band_deep_field_metadetect( | ||
obs_wide, | ||
obs_deep, | ||
obs_deep_noise, | ||
nxy, | ||
nxy_psf, | ||
step=DEFAULT_STEP, | ||
shears=None, | ||
skip_obs_wide_corrections=False, | ||
skip_obs_deep_corrections=False, | ||
nodet_flags=0, | ||
scale=0.2, | ||
) -> dict: | ||
"""Run deep-field metadetection for a simple scenario of a single band | ||
with a single image per band using only post-PSF Gaussian weighted moments. | ||
|
||
Parameters | ||
---------- | ||
obs_wide : DFMdetObservation | ||
The wide-field observation. | ||
obs_deep : DFMdetObservation | ||
The deep-field observation. | ||
obs_deep_noise : DFMdetObservation | ||
The deep-field noise observation. | ||
nxy: int | ||
Image size | ||
nxy_psf: int | ||
PSF size | ||
step : float, optional | ||
The step size for the metacalibration, by default DEFAULT_STEP. | ||
shears : list, optional | ||
The shears to use for the metacalibration, by default DEFAULT_SHEARS | ||
if set to None. | ||
skip_obs_wide_corrections : bool, optional | ||
Skip the observation corrections for the wide-field observations, | ||
by default False. | ||
skip_obs_deep_corrections : bool, optional | ||
Skip the observation corrections for the deep-field observations, | ||
by default False. | ||
nodet_flags : int, optional | ||
The bmask flags marking area in the image to skip, by default 0. | ||
scale: float | ||
pixel scale | ||
|
||
Returns | ||
------- | ||
dfmdet_res : dict | ||
The deep-field metadetection results, a dictionary with keys from `shears` | ||
and values containing the detection+measurement results for the corresponding | ||
shear. | ||
""" | ||
if shears is None: | ||
shears = DEFAULT_SHEARS | ||
|
||
mcal_res = jax_metacal_wide_and_deep_psf_matched( | ||
obs_wide=obs_wide, | ||
obs_deep=obs_deep, | ||
obs_deep_noise=obs_deep_noise, | ||
nxy=nxy, | ||
nxy_psf=nxy_psf, | ||
step=step, | ||
shears=shears, | ||
skip_obs_wide_corrections=skip_obs_wide_corrections, | ||
skip_obs_deep_corrections=skip_obs_deep_corrections, | ||
scale=scale, | ||
) # This returns ngmix Obs for now | ||
|
||
psf_res = fit_gauss_mom_obs(mcal_res["noshear"].psf) | ||
dfmdet_res = [] | ||
for shear, obs in mcal_res.items(): | ||
detres = run_detection_sep(obs, nodet_flags=nodet_flags) | ||
|
||
ixc = (detres["catalog"]["x"] + 0.5).astype(int) | ||
iyc = (detres["catalog"]["y"] + 0.5).astype(int) | ||
bmask_flags = obs.bmask[iyc, ixc] | ||
|
||
mfrac_vals = np.zeros_like(bmask_flags, dtype="f4") | ||
if np.any(obs.mfrac > 0): | ||
_interp_mfrac = compute_mfrac_interp_image( | ||
obs.mfrac, | ||
obs.jacobian.get_galsim_wcs(), | ||
) | ||
for i, (x, y) in enumerate( | ||
zip(detres["catalog"]["x"], detres["catalog"]["y"]) | ||
): | ||
mfrac_vals[i] = _interp_mfrac.xValue(x, y) | ||
|
||
for ind, (obj, mbobs) in enumerate( | ||
generate_mbobs_for_detections( | ||
ngmix.observation.get_mb_obs(obs), | ||
xs=detres["catalog"]["x"], | ||
ys=detres["catalog"]["y"], | ||
) | ||
): | ||
fres = fit_gauss_mom_obs_and_psf(mbobs[0][0], psf_res=psf_res) | ||
dfmdet_res.append( | ||
(ind + 1, obj["x"], obj["y"], shear, bmask_flags[ind], mfrac_vals[ind]) | ||
+ tuple(fres[0]) | ||
) | ||
|
||
total_dtype = [ | ||
("id", "i8"), | ||
("x", "f8"), | ||
("y", "f8"), | ||
("mdet_step", "U7"), | ||
("bmask_flags", "i4"), | ||
("mfrac", "f4"), | ||
] + fres.dtype.descr | ||
|
||
return np.array(dfmdet_res, dtype=total_dtype) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import jax.numpy as jnp | ||
|
||
|
||
# @partial(jax.jit, static_argnames=["pixel_scale", "image_size"]) | ||
def compute_stepk(pixel_scale, image_size): | ||
"""Compute psf fourier scale based on pixel scale and psf image dimension | ||
The size if obtained from from galsim.GSObject.getGoodImageSize | ||
The factor 1/4 from deep_field_metadetect.metacal.get_gauss_reconv_psf_galsim | ||
|
||
Parameters: | ||
----------- | ||
pixel_scale : float | ||
The scale of a single pixel in the image. | ||
image_size : int | ||
The dimension of the PSF image (typically a square size). | ||
|
||
Returns: | ||
-------- | ||
float | ||
The computed stepk value, which represents the Fourier-space sampling frequency. | ||
""" | ||
return 2 * jnp.pi / (image_size * pixel_scale) / 4 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from typing import NamedTuple, Optional | ||
|
||
import jax | ||
import jax_galsim | ||
import ngmix | ||
import numpy as np | ||
from ngmix.observation import Observation | ||
|
||
|
||
@jax.tree_util.register_pytree_node_class | ||
class DFMdetObservation(NamedTuple): | ||
image: jax.Array | ||
weight: Optional[jax.Array] | ||
bmask: Optional[jax.Array] | ||
ormask: Optional[jax.Array] | ||
noise: Optional[jax.Array] | ||
aft: Optional[jax_galsim.wcs.AffineTransform] | ||
psf: Optional["DFMdetObservation"] | ||
mfrac: Optional[jax.Array] | ||
meta: Optional[dict] | ||
store_pixels: bool | ||
ignore_zero_weight: bool | ||
|
||
def tree_flatten(self): | ||
children = ( | ||
self.image, | ||
self.weight, | ||
self.bmask, | ||
self.ormask, | ||
self.noise, | ||
self.aft, | ||
self.psf, | ||
self.mfrac, | ||
) | ||
|
||
aux_data = (self.meta, self.store_pixels, self.ignore_zero_weight) | ||
|
||
return children, aux_data | ||
|
||
@classmethod | ||
def tree_unflatten(cls, aux_data, children): | ||
# Reconstruct the object from flattened data | ||
return cls(*children, *aux_data) | ||
|
||
def has_bmask(self) -> bool: | ||
if self.bmask is None: | ||
return False | ||
return True | ||
|
||
def has_mfrac(self) -> bool: | ||
if self.bmask is None: | ||
return False | ||
return True | ||
|
||
def has_noise(self) -> bool: | ||
if self.noise is None: | ||
return False | ||
return True | ||
|
||
def has_ormask(self) -> bool: | ||
if self.ormask is None: | ||
return False | ||
return True | ||
|
||
def has_psf(self) -> bool: | ||
if self.psf is None: | ||
return False | ||
return True | ||
|
||
|
||
def ngmix_obs_to_dfmd_obs(obs: ngmix.observation.Observation) -> DFMdetObservation: | ||
jacobian = obs.get_jacobian() | ||
|
||
psf = None | ||
if obs.has_psf(): | ||
psf = ngmix_obs_to_dfmd_obs(obs.get_psf()) | ||
|
||
return DFMdetObservation( | ||
image=obs.image, | ||
weight=obs.weight, | ||
bmask=obs.bmask if obs.has_bmask() else None, | ||
ormask=obs.ormask if obs.has_ormask() else None, | ||
noise=obs.noise if obs.has_noise() else None, | ||
aft=jax_galsim.wcs.AffineTransform( | ||
dudx=jacobian.dudcol, | ||
dudy=jacobian.dudrow, | ||
dvdx=jacobian.dvdcol, | ||
dvdy=jacobian.dvdrow, | ||
origin=jax_galsim.PositionD( | ||
y=jacobian.row0 + 1, | ||
x=jacobian.col0 + 1, | ||
), | ||
), | ||
psf=psf, | ||
meta=obs.meta, | ||
mfrac=obs.mfrac if obs.has_mfrac() else None, | ||
store_pixels=getattr(obs, "store_pixels", True), | ||
ignore_zero_weight=getattr(obs, "ignore_zero_weight", True), | ||
) | ||
|
||
|
||
def dfmd_obs_to_ngmix_obs(dfmd_obs) -> Observation: | ||
psf = None | ||
if dfmd_obs.psf is not None: | ||
psf = dfmd_obs_to_ngmix_obs(dfmd_obs.psf) | ||
return Observation( | ||
image=np.array(dfmd_obs.image), | ||
weight=np.array(dfmd_obs.weight), | ||
bmask=dfmd_obs.bmask, | ||
ormask=dfmd_obs.ormask, | ||
noise=dfmd_obs.noise if dfmd_obs.noise is None else np.array(dfmd_obs.noise), | ||
jacobian=ngmix.jacobian.Jacobian( | ||
row=dfmd_obs.aft.origin.y - 1, | ||
col=dfmd_obs.aft.origin.x - 1, | ||
dudcol=dfmd_obs.aft.dudx, | ||
dudrow=dfmd_obs.aft.dudy, | ||
dvdcol=dfmd_obs.aft.dvdx, | ||
dvdrow=dfmd_obs.aft.dvdy, | ||
), | ||
psf=psf, | ||
mfrac=dfmd_obs.mfrac if dfmd_obs.mfrac is None else np.array(dfmd_obs.mfrac), | ||
meta=dfmd_obs.meta, | ||
store_pixels=np.array(dfmd_obs.store_pixels, dtype=np.bool_), | ||
ignore_zero_weight=np.array(dfmd_obs.ignore_zero_weight, dtype=np.bool_), | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking ahead, I think we should call this
wcs
.