Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions src/lightcurvelynx/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,23 @@
from lightcurvelynx.astro_utils.mag_flux import flux2mag


def _build_colormap(unique_filters):
"""Construct a colormap for a given set of filters."""
filter_list = list(unique_filters)
n_filters = len(filter_list)

# Use different colormaps depending on the number of filters.
if n_filters <= 10:
cmap = plt.get_cmap("tab10", 10)
elif n_filters <= 20:
cmap = plt.get_cmap("tab20", 20)
else:
cmap = plt.get_cmap("turbo", n_filters)
colormap = {f: cmap(i) for i, f in enumerate(filter_list)}

return colormap


def plot_lightcurves(
fluxes,
times,
Expand Down Expand Up @@ -104,10 +121,7 @@ def plot_lightcurves(
raise ValueError(f"Mismatched array sizes for fluxes ({num_pts}) and fluxerrs ({len(fluxerrs)}).")

if colormap is None:
colormap = {}
colors = "bgrcmyk"
for i, f in enumerate(unique_filters):
colormap[f] = colors[i]
colormap = _build_colormap(unique_filters)

# Plot the data with one line for each filter.
for filter in unique_filters:
Expand Down
76 changes: 75 additions & 1 deletion tests/lightcurvelynx/utils/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
import matplotlib.pyplot as plt
import numpy as np
import pytest
from lightcurvelynx.utils.plotting import plot_bandflux_lightcurves, plot_flux_spectrogram, plot_lightcurves
from lightcurvelynx.utils.plotting import (
_build_colormap,
plot_bandflux_lightcurves,
plot_flux_spectrogram,
plot_lightcurves,
)


def test_build_colormap():
"""Test that we can build a colormap."""
for n_filters in [5, 10, 20, 50]:
filters = [f"Filter {i}" for i in range(n_filters)]
colormap = _build_colormap(filters)
assert isinstance(colormap, dict)
assert set(colormap.keys()) == set(filters)
for color in colormap.values():
assert isinstance(color, tuple) and len(color) == 4 # RGBA color


def test_plot_lightcurves():
Expand Down Expand Up @@ -62,6 +78,64 @@ def test_plot_lightcurves():
plt.close("all")


def test_plot_lightcurves_many_filters():
"""Test that we can plot light curves with many filters."""
# Test minimal input
fluxes = np.array([1.0, 2.0, 3.0])
times = np.array([1.0, 2.0, 3.0])
plot_lightcurves(fluxes, times)
# ValueError if len(times) != len(fluxes)
wrong_times = np.array([1.0, 2.0])
with pytest.raises(ValueError):
plot_lightcurves(fluxes, wrong_times)

# ValueError if len(filters) != len(fluxes)
wrong_filters = ["none"]
with pytest.raises(ValueError):
plot_lightcurves(fluxes, times, filters=wrong_filters)

# ValueError if fluxerrs given and len(fluxerrs) != len(fluxes)
wrong_fluxerrs = np.array([0.1, 0.2])
with pytest.raises(ValueError):
plot_lightcurves(fluxes, times, fluxerrs=wrong_fluxerrs)

# Test with almost all inputs given:
# - fluxerrs (same length as fluxes to pass the ValueError check)
# - filters (same length as fluxes to pass the ValueError check)
# - title
fluxerrs = np.array([0.1, 0.2, 0.3])
filters = np.array(["A", "B", "A"])
title = "Test Title"
plot_lightcurves(fluxes, times, fluxerrs=fluxerrs, filters=filters, title=title)

# Plot in magnitudes.
plot_lightcurves(fluxes, times, fluxerrs=fluxerrs, filters=filters, plot_magnitudes=True)

# Plot in magnitudes with an underlying model.
underlying_model = {
"A": np.array([1.5, 2.5, 3.5]),
"B": np.array([2.0, 3.0, 4.0]),
"times": np.array([1.0, 2.0, 3.0]),
}
plot_lightcurves(
fluxes,
times,
fluxerrs=fluxerrs,
filters=filters,
plot_magnitudes=True,
underlying_model=underlying_model,
)

# Test with all inputs given:
# - ax (matplotlib axes object)
# - figure (matplotlib figure object)
fig, ax = plt.subplots()
plot_lightcurves(fluxes, times, fluxerrs=fluxerrs, filters=filters, ax=ax, figure=fig, title=title)

# Close all the open figures.
plt.close("all")


def test_plot_bandflux_lightcurves():
"""Test that we can plot bandflux light curves."""
# Test minimal input
Expand Down
Loading