Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import matplotlib.pyplot as plt
import numpy as np
from numpy.lib.recfunctions import append_fields

from flarestack.core.astro import angular_distance
from flarestack.data.icecube.ps_tracks.ps_v002_p01 import IC86_1_dict
Expand Down Expand Up @@ -93,9 +92,7 @@ def weighted_quantile(values, quantiles, weight):
cut_mc = mc[mask]

percentile = np.ones_like(cut_mc["ra"]) * np.nan
cut_mc = append_fields(
cut_mc, "percentile", percentile, usemask=False, dtypes=[np.float]
)
cut_mc.add_column(percentile, name="percentile")

weights = cut_mc["ow"] * cut_mc["trueE"] ** -gamma
# weights = np.ones_like(cut_mc["ow"])
Expand Down

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions flarestack/core/injector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
import zipfile
import zlib
from typing import TYPE_CHECKING

import numpy as np
from astropy.table import Table
Expand All @@ -14,6 +15,9 @@
from flarestack.shared import band_mask_cache_name, k_to_flux
from flarestack.utils.catalogue_loader import calculate_source_weight

if TYPE_CHECKING:
from flarestack.data import SeasonWithMC

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -245,7 +249,7 @@ def __init__(self, season, sources, **kwargs):
logger.warning("No Injection Arguments. Are you unblinding?")
pass

def get_mc(self, season):
def get_mc(self, season: "SeasonWithMC") -> Table:
return season.get_mc()

def select_mc_band(self, source):
Expand Down Expand Up @@ -537,14 +541,12 @@ class TableInjector(MCInjector):
For 1000 sources, calculate_n_exp() is ~60x faster than MCInjector.
"""

def get_mc(self, season):
mc: np.ndarray = season.get_mc()
# Sort rows by trueDec, and store as columns in a Table
table = Table(mc[np.argsort(mc["trueDec"].copy())])
# Prevent in-place modifications
for k in table.columns:
table[k].setflags(write=False)
return table
def get_mc(self, season: "SeasonWithMC") -> Table:
mc = season.get_mc().copy(copy_data=False)
mc.sort("trueDec")
for col in mc.columns.values():
col.setflags(write=False)
return mc

def get_band_mask(self, source, min_dec, max_dec):
return slice(*np.searchsorted(self._mc["trueDec"], [min_dec, max_dec]))
Expand Down
21 changes: 19 additions & 2 deletions flarestack/core/llh.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scipy import sparse

from flarestack.core.energy_pdf import EnergyPDF, read_e_pdf_dict
from flarestack.core.spatial_pdf import SpatialPDF
from flarestack.core.spatial_pdf import SpatialPDF, angular_distance
from flarestack.core.time_pdf import TimePDF, read_t_pdf_dict
from flarestack.shared import (
SoB_spline_path,
Expand Down Expand Up @@ -1363,6 +1363,12 @@ class StdMatrixKDEEnabledLLH(StandardOverlappingLLH):
"""

def __init__(self, season, sources, llh_dict):
# propagate the spatial_box_width from llh_dict to the
# llh_spatial_pdf if not explicitly set
if "spatial_box_width" not in llh_dict["llh_spatial_pdf"]:
llh_dict["llh_spatial_pdf"]["spatial_box_width"] = llh_dict.get(
"spatial_box_width", default_spacial_box_width
)
super().__init__(season, sources, llh_dict)

if llh_dict["llh_spatial_pdf"]["spatial_pdf_name"] != "northern_tracks_kde":
Expand Down Expand Up @@ -1400,7 +1406,18 @@ def get_spatially_coincident_indices(self, data, source) -> np.ndarray:
ra_dist = np.fabs(
(data["ra"][dec_range] - source["ra_rad"] + np.pi) % (2.0 * np.pi) - np.pi
)
return np.nonzero(ra_dist < dPhi / 2.0)[0] + dec_range.start

# Indices (with respect to the start of the declination band) of events inside the box
idx = np.nonzero(ra_dist < dPhi / 2.0)[0]

# Cut the box down to a circle
psi = angular_distance(
data["ra"][dec_range][idx],
data["dec"][dec_range][idx],
source["ra_rad"],
source["dec_rad"],
)
return idx[np.nonzero(psi < width)[0]] + dec_range.start

def create_kwargs(self, data, pull_corrector, weight_f=None):
if weight_f is None:
Expand Down
Loading
Loading