diff --git a/src/lightcurvelynx/utils/plotting.py b/src/lightcurvelynx/utils/plotting.py index e5c28952..7b8d6135 100644 --- a/src/lightcurvelynx/utils/plotting.py +++ b/src/lightcurvelynx/utils/plotting.py @@ -7,6 +7,23 @@ from lightcurvelynx.astro_utils.mag_flux import flux2mag +def _build_colormap(unique_filters): + """Construct a colormap for a given set of filters.""" + filter_list = list(unique_filters) + n_filters = len(filter_list) + + # Use different colormaps depending on the number of filters. + if n_filters <= 10: + cmap = plt.get_cmap("tab10", 10) + elif n_filters <= 20: + cmap = plt.get_cmap("tab20", 20) + else: + cmap = plt.get_cmap("turbo", n_filters) + colormap = {f: cmap(i) for i, f in enumerate(filter_list)} + + return colormap + + def plot_lightcurves( fluxes, times, @@ -104,10 +121,7 @@ def plot_lightcurves( raise ValueError(f"Mismatched array sizes for fluxes ({num_pts}) and fluxerrs ({len(fluxerrs)}).") if colormap is None: - colormap = {} - colors = "bgrcmyk" - for i, f in enumerate(unique_filters): - colormap[f] = colors[i] + colormap = _build_colormap(unique_filters) # Plot the data with one line for each filter. for filter in unique_filters: diff --git a/tests/lightcurvelynx/utils/test_plotting.py b/tests/lightcurvelynx/utils/test_plotting.py index 0ea193dc..2cfe8f0b 100644 --- a/tests/lightcurvelynx/utils/test_plotting.py +++ b/tests/lightcurvelynx/utils/test_plotting.py @@ -1,7 +1,23 @@ import matplotlib.pyplot as plt import numpy as np import pytest -from lightcurvelynx.utils.plotting import plot_bandflux_lightcurves, plot_flux_spectrogram, plot_lightcurves +from lightcurvelynx.utils.plotting import ( + _build_colormap, + plot_bandflux_lightcurves, + plot_flux_spectrogram, + plot_lightcurves, +) + + +def test_build_colormap(): + """Test that we can build a colormap.""" + for n_filters in [5, 10, 20, 50]: + filters = [f"Filter {i}" for i in range(n_filters)] + colormap = _build_colormap(filters) + assert isinstance(colormap, dict) + assert set(colormap.keys()) == set(filters) + for color in colormap.values(): + assert isinstance(color, tuple) and len(color) == 4 # RGBA color def test_plot_lightcurves(): @@ -62,6 +78,64 @@ def test_plot_lightcurves(): plt.close("all") +def test_plot_lightcurves_many_filters(): + """Test that we can plot light curves with many filters.""" + # Test minimal input + fluxes = np.array([1.0, 2.0, 3.0]) + times = np.array([1.0, 2.0, 3.0]) + plot_lightcurves(fluxes, times) + # ValueError if len(times) != len(fluxes) + wrong_times = np.array([1.0, 2.0]) + with pytest.raises(ValueError): + plot_lightcurves(fluxes, wrong_times) + + # ValueError if len(filters) != len(fluxes) + wrong_filters = ["none"] + with pytest.raises(ValueError): + plot_lightcurves(fluxes, times, filters=wrong_filters) + + # ValueError if fluxerrs given and len(fluxerrs) != len(fluxes) + wrong_fluxerrs = np.array([0.1, 0.2]) + with pytest.raises(ValueError): + plot_lightcurves(fluxes, times, fluxerrs=wrong_fluxerrs) + + # Test with almost all inputs given: + # - fluxerrs (same length as fluxes to pass the ValueError check) + # - filters (same length as fluxes to pass the ValueError check) + # - title + fluxerrs = np.array([0.1, 0.2, 0.3]) + filters = np.array(["A", "B", "A"]) + title = "Test Title" + plot_lightcurves(fluxes, times, fluxerrs=fluxerrs, filters=filters, title=title) + + # Plot in magnitudes. + plot_lightcurves(fluxes, times, fluxerrs=fluxerrs, filters=filters, plot_magnitudes=True) + + # Plot in magnitudes with an underlying model. + underlying_model = { + "A": np.array([1.5, 2.5, 3.5]), + "B": np.array([2.0, 3.0, 4.0]), + "times": np.array([1.0, 2.0, 3.0]), + } + plot_lightcurves( + fluxes, + times, + fluxerrs=fluxerrs, + filters=filters, + plot_magnitudes=True, + underlying_model=underlying_model, + ) + + # Test with all inputs given: + # - ax (matplotlib axes object) + # - figure (matplotlib figure object) + fig, ax = plt.subplots() + plot_lightcurves(fluxes, times, fluxerrs=fluxerrs, filters=filters, ax=ax, figure=fig, title=title) + + # Close all the open figures. + plt.close("all") + + def test_plot_bandflux_lightcurves(): """Test that we can plot bandflux light curves.""" # Test minimal input