Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
f6d4362
setup barebones sphinx gallery build
willGraham01 Mar 10, 2026
f981686
Placeholder first script
willGraham01 Mar 10, 2026
2ff8e8c
This configuration works!
willGraham01 Mar 10, 2026
e37cb25
Copy across one particular notebook, it still errors b/c translation …
willGraham01 Mar 10, 2026
6d06a7f
Remove JAX_CUDA_HEALPix example for now, since it requires a CUDA device
willGraham01 Mar 12, 2026
4cb57bb
Port across JAX_HEALPix_frontend
willGraham01 Mar 12, 2026
bcd86b0
Add ignore rules for the docs/exmaples directory
willGraham01 Mar 12, 2026
4448a97
Port JAX_SSHT_frontend - this and HEALPix frontend are very similar btw
willGraham01 Mar 12, 2026
29d1839
Restructure folder dir for examples so sphinx-gallery is happy
willGraham01 Mar 12, 2026
22f16da
Translate spherical_harmoinc_transform. NOTE this example is SLOW
willGraham01 Mar 12, 2026
e7fc386
Translate spherical_rotation.py (FAST example)
willGraham01 Mar 12, 2026
2b9c3cc
Translate torch_frontend example
willGraham01 Mar 12, 2026
9721ea0
Translate Wigner example
willGraham01 Mar 12, 2026
279d3a3
Delete now-redundant notebooks
willGraham01 Mar 12, 2026
a02ccaf
RST-ify old markdown content
willGraham01 Mar 12, 2026
897043e
TOML format not synced to repo settings
willGraham01 Mar 12, 2026
3e38c0d
I said SYNC
willGraham01 Mar 12, 2026
ea6b6de
Random relative path in config...
willGraham01 Mar 12, 2026
e8d4257
Split out examples dependencies, have docs install both so notebooks …
willGraham01 Mar 12, 2026
960903a
More hidden packages needed
willGraham01 Mar 12, 2026
9aec818
ducc0 also required in notebooks
willGraham01 Mar 13, 2026
8721663
Preserve link to JAX_CUDA notebook
willGraham01 Mar 13, 2026
709cf6c
One-line pip install rather than running twice
willGraham01 Mar 13, 2026
e34d546
Update Notebooks section of README to match new expectations
willGraham01 Mar 13, 2026
c27cab8
Preserve old tutorial page information in gallery header
willGraham01 Mar 13, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: |
sudo apt install pandoc
python -m pip install --upgrade pip
pip install .[docs]
pip install .[docs,examples]

- name: Build documentation
run: cd docs && make html
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ install_manifest.txt
compile_commands.json
CTestTestfile.cmake
_deps

# Docs build
docs/tutorials/
docs/sg_execution_times.rst
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ pytest

## Documentation 📖

Documentation for the released version is available [here](https://astro-informatics.github.io/s2fft/).
Documentation for the released version [is available here](https://astro-informatics.github.io/s2fft/).
To install the documentation dependencies, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
with the extra documentation dependencies by running from the root of the repository

Expand All @@ -172,14 +172,13 @@ open _build/html/index.html

## Notebooks 📓

A series of tutorial notebooks are included in the `notebooks` directory
and rendered [in the documentation](https://astro-informatics.github.io/s2fft/tutorials/index.html).
A series of tutorial notebooks have been rendered in, and are available for download from, [the documentation](https://astro-informatics.github.io/s2fft/tutorials/index.html).

To install the dependencies required to run the notebooks locally, clone the repository and install the package (in [editable mode](https://setuptools.pypa.io/en/latest/userguide/development_mode.html))
with the extra documentation and plotting dependencies by running from the root of the repository
with the extra documentation, plotting, and examples dependencies by running from the root of the repository

```bash
pip install -e ".[docs,plotting]"
pip install -e ".[docs,examples,plotting]"
```

To run the notebooks in Jupyter Lab, run from the root of the repository
Expand Down
17 changes: 14 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import os
import sys
from importlib.metadata import version as get_version
from pathlib import Path

sys.path.insert(0, os.path.abspath(".."))

DOCS_DIR = Path(__file__).parent


# -- Project information -----------------------------------------------------

Expand All @@ -39,19 +42,18 @@
# ones.
extensions = [
"sphinx_copybutton",
"nbsphinx_link",
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
"sphinx.ext.mathjax",
"sphinx.ext.githubpages",
"sphinx_rtd_theme",
"nbsphinx",
"IPython.sphinxext.ipython_console_highlighting",
"sphinx_tabs.tabs",
"sphinx_git",
"sphinxcontrib.texfigure",
"sphinx.ext.autosectionlabel",
"sphinxemoji.sphinxemoji",
"sphinx_gallery.gen_gallery",
"sphinx_mdinclude",
]

Expand All @@ -60,9 +62,18 @@
napoleon_include_init_with_doc = True
napoleon_numpy_docstring = False

sphinx_gallery_conf = {
"examples_dirs": "../examples",
"gallery_dirs": "./tutorials",
"filename_pattern": "",
# For whatever reason, default_thumb_file is interpreted as
# relative to CWD in which build is run, unless an absolute path is provided.
"default_thumb_file": str(DOCS_DIR / "assets/sax_logo"),
}

# Add any paths that contain templates here, relative to this directory.
templates_path = ["_templates"]
source_suffix = [".rst", ".ipynb"]
source_suffix = [".rst"]

# The master toctree document.
master_doc = "index"
Expand Down
3 changes: 0 additions & 3 deletions docs/tutorials/JAX_HEALPix/JAX_HEALPix_frontend.nblink

This file was deleted.

3 changes: 0 additions & 3 deletions docs/tutorials/JAX_SSHT/JAX_SSHT_frontend.nblink

This file was deleted.

3 changes: 0 additions & 3 deletions docs/tutorials/rotation/rotation.nblink

This file was deleted.

This file was deleted.

3 changes: 0 additions & 3 deletions docs/tutorials/torch_frontend/torch_frontend.nblink

This file was deleted.

3 changes: 0 additions & 3 deletions docs/tutorials/wigner/wigner_transform.nblink

This file was deleted.

28 changes: 11 additions & 17 deletions docs/tutorials/index.rst → examples/GALLERY_HEADER.rst
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
:html_theme.sidebar_secondary.remove:

*****************************
Tutorials
*****************************
=========

This section contains a series of tutorial notebooks which go through some of the
This section contains :ref:`a series of tutorial notebooks <tutorial-notebooks-label>` which go through some of the
key features of the ``S2FFT`` package.

At a high-level the ``S2FFT`` package is structured such that the two primary transforms,
the Wigner and spherical harmonic transforms, can easily be accessed.

Core usage |:rocket:|
-----------------------------
---------------------

To import and use ``S2FFT`` is as simple follows:

Expand All @@ -33,14 +30,11 @@ To import and use ``S2FFT`` is as simple follows:
| f = s2fft.inverse(flm, L, method="jax") | f = s2fft.wigner.inverse(flmn, L, N, method="jax") |
+-------------------------------------------------------+------------------------------------------------------------+

.. toctree::
:hidden:
:maxdepth: 3
:caption: Tutorials

spherical_harmonic/spherical_harmonic_transform.nblink
wigner/wigner_transform.nblink
rotation/rotation.nblink
torch_frontend/torch_frontend.nblink
JAX_SSHT/JAX_SSHT_frontend.nblink
JAX_HEALPix/JAX_HEALPix_frontend.nblink
.. _tutorial-notebooks-label:

Tutorial notebooks
------------------

Below are a few short tutorials that cover how to use specific features of ``S2FFT``.

We also have a notebook demonstrating how to use CUDA-accelerated HEALPix spherical harmonic transforms in ``S2FFT``, which `is accessible in notebook format here <https://github.com/astro-informatics/s2fft/blob/main/notebooks/JAX_CUDA_HEALPix.ipynb>`_, or alternatively can be `opened in Google Colab <https://colab.research.google.com/github/astro-informatics/s2fft/blob/main/notebooks/JAX_CUDA_HEALPix.ipynb>`_.
76 changes: 76 additions & 0 deletions examples/JAX_HEALPix_frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
JAX HEALPix Frontend
====================

This short tutorial demonstrates how to use the custom ``JAX`` frontend support ``S2FFT`` provides for the `HEALPix <https://healpix.jpl.nasa.gov>`_ C++ library.
"""

# %%
# ``S2FFT``'s support for the `HEALPix <https://healpix.jpl.nasa.gov>`_ C++ library resolves issues involving long JIT compile times for HEALPix when running on CPU.
# As with the other introductions, let's import some packages and define an arbitrary bandlimited signal to work with.

import jax
import numpy as np

import s2fft

jax.config.update("jax_enable_x64", True)

L = 128
nside = 64
method = "jax_healpy"
sampling = "healpix"
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)

# %%
# Calling forward HEALPix C++ function from JAX
# ---------------------------------------------

flm = s2fft.forward(f, L, nside=nside, sampling=sampling, method=method)

# %%
# Calling inverse HEALPix C++ function from JAX
# ---------------------------------------------

f_recov = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)

# %%
# Computing the roundtrip error
# -----------------------------
#
# Let's check the associated error, which should be around ``1e-5`` for healpix, which is not an exact sampling of the sphere.
# Note that increasing ``iters`` will reduce the numerical error here slightly, at the cost of linearly increased compute.

print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")

# %%
# Differentiating through HEALPix C++ functions
# ---------------------------------------------
#
# So far all this is doing is providing an interface between ``JAX`` and ``HEALPix``, the real novelty comes when we differentiate through the C++ library.


# Define an arbitrary JAX function
def differentiable_test(flm) -> int:
f = s2fft.inverse(flm, L, nside=nside, sampling=sampling, method=method)
return jax.numpy.nanmean(jax.numpy.abs(f) ** 2)


# Create the JAX reverse mode gradient function
gradient_func = jax.grad(differentiable_test)

# Compute the gradient automatically
gradient = gradient_func(flm)

# %%
# Validating these gradients
# --------------------------
#
# This is all well and good, but how do we know these gradients are correct?
# Thankfully ``JAX`` provides a simple function to check this...

from jax.test_util import check_grads

check_grads(differentiable_test, (flm,), order=1, modes=("rev"))
72 changes: 72 additions & 0 deletions examples/JAX_SSHT_frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
JAX SSHT frontend
=================

This short tutorial demonstrates how to use the custom ``JAX`` frontend support ``S2FFT`` provides for the `SSHT <https://github.com/astro-informatics/ssht>`_ C library.
"""

# %%
# As with the other introductions, let's import some packages and define an arbitrary bandlimited signal to work with.

import jax
import numpy as np

import s2fft

jax.config.update("jax_enable_x64", True)

L = 128
method = "jax_ssht"
rng = np.random.default_rng(23457801234570)
flm = s2fft.utils.signal_generator.generate_flm(rng, L)
f = s2fft.inverse(flm, L, method=method)

# %%
# Calling forward SSHT C function from JAX
# ----------------------------------------

flm = s2fft.forward(f, L, method=method)

# %%
# Calling inverse SSHT C function from JAX
# ----------------------------------------

f_recov = s2fft.inverse(flm, L, method=method)

# %%
# Computing the roundtrip error
# -----------------------------
#
# Let's check the associated error, which should be close to machine precision for the sampling scheme used.

print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")

# %%
# Differentiating through SSHT C functions
# ----------------------------------------
#
# So far all this is doing is providing an interface between ``JAX`` and ``SSHT``, the real novelty comes when we differentiate through the C library.


# Define an arbitrary JAX function
def differentiable_test(flm) -> int:
f = s2fft.inverse(flm, L, method=method)
return jax.numpy.nanmean(jax.numpy.abs(f) ** 2)


# Create the JAX reverse mode gradient function
gradient_func = jax.grad(differentiable_test)

# Compute the gradient automatically
gradient = gradient_func(flm)

# %%
# Validating these gradients
# --------------------------
#
# This is all well and good, but how do we know these gradients are correct?
# Thankfully ``JAX`` provides a simple function to check this...

from jax.test_util import check_grads

check_grads(differentiable_test, (flm,), order=1, modes=("rev"))
File renamed without changes.
76 changes: 76 additions & 0 deletions examples/spherical_harmonic_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""
Spherical harmonic transform
============================

This tutorial demonstrates how to use ``S2FFT`` to compute spherical harmonic transforms.
"""

# %%
# In this example we will adopt the sampling scheme of `McEwen & Wiaux (2012) <https://arxiv.org/abs/1110.6298>`_.
# First let's load an input signal that is sampled on the sphere with this sampling scheme.

import jax

jax.config.update("jax_enable_x64", True)

import cartopy.crs as ccrs
import numpy as np
from matplotlib import pyplot as plt

import s2fft

sampling = "mw"
f = np.load("data/Gaia_EDR3_flux.npy")
L = f.shape[0]

# %%
# Let's look at the input signal:

plt.figure(figsize=(10, 5))
ax = plt.axes(projection=ccrs.Mollweide())
im = ax.imshow(f, transform=ccrs.PlateCarree(), cmap="magma")
plt.axis("off")
plt.show()

# %%
# Computing the forward spherical harmonic transform
# --------------------------------------------------
#
# Let's now run the ``JAX`` function to compute the spherical harmonic transform of this map.

flm = s2fft.forward_jax(f, L)

# %%
# If you are planning on applying this transform many times (e.g. during training of a model) we recommend precomputing and storing some small arrays that are used every time.
# This trades off additional memory usage for enhanced speed and should be fine at small and moderate bandlimits ``L``.
#
# To do this simply compute these and pass as a static argument.

precomps = s2fft.generate_precomputes_jax(L, forward=True)
flm_pre = s2fft.forward_jax(f, L, precomps=precomps)

# %%
# Computing the inverse spherical harmonic transform
# --------------------------------------------------
#
# Let's run the ``JAX`` function to compute the inverse spherical harmonic transform to get back to the input map.

f_recov = s2fft.inverse_jax(flm, L)

# %%
# Again, if you are planning on applying this transform many times we recommend precomputing and storing some small arrays that are used every time.
# Recall, this trades off additional memory usage for enhanced speed and should be fine at small and moderate bandlimits ``L``.
#
# To do this simply compute these and pass as a static argument.

precomps = s2fft.generate_precomputes_jax(L, forward=False)
f_recov_pre = s2fft.inverse_jax(flm_pre, L, precomps=precomps)

# %%
# Computing the error
# -------------------
#
# Let's check the associated error, which should be close to machine precision for the sampling scheme used.

print(f"Mean absolute error = {np.nanmean(np.abs(f_recov - f))}")
print(f"Mean absolute error using precomputes = {np.nanmean(np.abs(f_recov_pre - f))}")
Loading
Loading