Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 27 additions & 6 deletions src/lightcurvelynx/astro_utils/snia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from scipy.stats.sampling import NumericalInversePolynomial

from lightcurvelynx.base_models import FunctionNode
from lightcurvelynx.math_nodes.np_random import NumpyRandomFunc
from lightcurvelynx.math_nodes.scipy_random import NumericalInversePolynomialFunc


Expand Down Expand Up @@ -356,23 +357,44 @@ class SNCoordGivenPhysicalSep(FunctionNode):
The Hubble constant.
Omega_m : constant
The matter density Omega_m.
pos_angle : parameter or None
The position angle for the SN location relative to the host galaxy (in radians).
If None, a random position angle is generated for each sample.
**kwargs : dict, optional
Any additional keyword arguments.
"""

def __init__(self, host_ra, host_dec, physical_sep_kpc, redshift, H0=73.0, Omega_m=0.3, **kwargs):
def __init__(
self,
host_ra,
host_dec,
physical_sep_kpc,
redshift,
*,
H0=73.0,
Omega_m=0.3,
pos_angle=None,
**kwargs,
):
# Create the cosmology once for this node.
if not isinstance(H0, float) or not isinstance(Omega_m, float): # pragma: no cover
raise ValueError("H0 and Omega_m must be constants.")
self.cosmo = FlatLambdaCDM(H0=H0, Om0=Omega_m)

# The _sn_coord function uses a position angle that we want to randomly
# generate for each sample. If the user does not provide their own setter
# we use a numpy random function to generate the position angle [0., 2*pi].
if pos_angle is None:
pos_angle = NumpyRandomFunc("uniform", low=0.0, high=2 * np.pi)

# Call the super class's constructor with the needed information.
super().__init__(
func=self._sn_coord,
host_ra=host_ra,
host_dec=host_dec,
physical_sep_kpc=physical_sep_kpc,
redshift=redshift,
pos_angle=pos_angle,
outputs=["ra", "dec"],
**kwargs,
)
Expand All @@ -399,7 +421,7 @@ def _host_sn_angular_separation(self, physical_sep_kpc, redshift):

return angular_sep_rad

def _sn_coord(self, host_ra, host_dec, physical_sep_kpc, redshift):
def _sn_coord(self, host_ra, host_dec, physical_sep_kpc, redshift, pos_angle):
"""
Function to generate SN coordinates given the host coordinates and angular separation.

Expand All @@ -413,6 +435,8 @@ def _sn_coord(self, host_ra, host_dec, physical_sep_kpc, redshift):
The physical host-sn separation(s) in kpc.
redshift : float or numpy.ndarray
The redshift to convert physical separation to angular separation.
pos_angle : float or numpy.ndarray
The position angle(s) for the SN location relative to the host galaxy (in radians).

Returns
-------
Expand All @@ -421,12 +445,9 @@ def _sn_coord(self, host_ra, host_dec, physical_sep_kpc, redshift):
dec : float or numpy.ndarray
SN DEC values (in degrees).
"""

self.angular_sep_rad = self._host_sn_angular_separation(physical_sep_kpc, redshift)
center = SkyCoord(host_ra * u.deg, host_dec * u.deg, frame="icrs")
rand_pa = 2 * np.pi * np.random.random(center.size) * u.rad # random position angle

sn_coord = center.directional_offset_by(rand_pa, self.angular_sep_rad)
sn_coord = center.directional_offset_by(pos_angle * u.rad, self.angular_sep_rad)
ra = sn_coord.ra.deg
dec = sn_coord.dec.deg

Expand Down
2 changes: 1 addition & 1 deletion src/lightcurvelynx/obstable/obs_table_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def _init_formulas(self):
self.add_formula(
parameter="sky_bg_electrons",
inputs=["skybrightness", "pixel_scale", "zp"],
formula=lambda skybrightness, pixel_scale, zp: (mag2flux(skybrightness) * pixel_scale**2 / zp),
formula=lambda skybrightness, pixel_scale, zp: mag2flux(skybrightness) * pixel_scale**2 / zp,
)
self.add_formula(
parameter="sky_bg_electrons",
Expand Down
50 changes: 48 additions & 2 deletions tests/lightcurvelynx/astro_utils/test_snia_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def test_sn_coord_given_physical_separation():
host_ra, host_dec, physical_sep_kpc, redshift, H0=H0, Omega_m=Omega_m, node_label="sncoord"
)
state = sncoord.sample_parameters()
sn_ra = state["sncoord.ra"][0]
sn_dec = state["sncoord.dec"][0]
sn_ra = state["sncoord.ra"]
sn_dec = state["sncoord.dec"]
# calculate physical separation from host coord, sn coor, and redshift
host = SkyCoord(host_ra * u.deg, host_dec * u.deg)
sn = SkyCoord(sn_ra * u.deg, sn_dec * u.deg)
Expand All @@ -153,3 +153,49 @@ def test_sn_coord_given_physical_separation():
calculated_physical_sep = sep.to(u.rad).value * DA

assert np.isclose(calculated_physical_sep, physical_sep_kpc)


def test_sn_coord_given_physical_separation_fixed():
"""Test that we always get the same sn coor if we fix the position angle."""
sncoord = SNCoordGivenPhysicalSep(
0.0, # host_ra
0.0, # host_dec
1.0, # physical_sep_kpc
0.5, # redshift
pos_angle=1.0,
H0=70.0,
Omega_m=0.3,
node_label="sncoord_fixed_pa",
)
state = sncoord.sample_parameters(num_samples=100)
assert len(np.unique(state["sncoord_fixed_pa.ra"])) == 1
assert len(np.unique(state["sncoord_fixed_pa.dec"])) == 1


def test_sn_coord_given_physical_separation_seeded():
"""Test that we can control the randomness with a seeded random number generator."""
sncoord = SNCoordGivenPhysicalSep(
0.0, # host_ra
0.0, # host_dec
1.0, # physical_sep_kpc
0.5, # redshift
H0=70.0,
Omega_m=0.3,
node_label="sncoord_fixed_pa",
)

# Given the same seed, we should get the same results.
rng1 = np.random.default_rng(seed=234)
state1 = sncoord.sample_parameters(num_samples=1000, rng_info=rng1)

rng2 = np.random.default_rng(seed=234)
state2 = sncoord.sample_parameters(num_samples=1000, rng_info=rng2)

assert np.allclose(state1["sncoord_fixed_pa.ra"], state2["sncoord_fixed_pa.ra"])
assert np.allclose(state1["sncoord_fixed_pa.dec"], state2["sncoord_fixed_pa.dec"])

# Given a different seed, we should get different results.
rng3 = np.random.default_rng(seed=235)
state3 = sncoord.sample_parameters(num_samples=1000, rng_info=rng3)
assert not np.allclose(state1["sncoord_fixed_pa.ra"], state3["sncoord_fixed_pa.ra"])
assert not np.allclose(state1["sncoord_fixed_pa.dec"], state3["sncoord_fixed_pa.dec"])
Loading