Description
The package provides a CUDA extension module that reduces the just-in-time compilation overhead of transforms using the HEALPix reduced grid on GPUs. A GitHub Actions workflow uses cibuildwheel to build binary wheels and a source distribution for the package. Because the NVIDIA toolchain is not available by default in these build environments, the build jobs do not compile the extension module. As a result, the binary wheels published to PyPI omit this feature, and users who need it must instead build the package from source. This is a barrier to users benefiting from the feature, and it also increases the risk that breaking changes to this functionality go undetected.
To overcome these issues, we should update our build system to compile the CUDA extension as part of our regular CI/CD jobs. To avoid making heavyweight CUDA libraries a required dependency, it may make sense to publish the CUDA extension functionality in a separate plug-in package, following the model used by JAX itself.
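With a separate plug-in package, the main package would treat the CUDA functionality as optional. It would detect at runtime whether the plug-in is installed and fall back to the pure JAX path otherwise. The following is a minimal Python sketch of that detection logic under this assumption. The module name `healpix_cuda_plugin` and the backend labels are hypothetical and are not part of the existing package.

```python
import importlib
from types import ModuleType
from typing import Optional

# Hypothetical name of a separately published CUDA plug-in package.
CUDA_PLUGIN_MODULE = "healpix_cuda_plugin"


def load_cuda_plugin(module_name: str = CUDA_PLUGIN_MODULE) -> Optional[ModuleType]:
    """Return the plug-in module if it is importable, otherwise None."""
    try:
        return importlib.import_module(module_name)
    except ImportError:
        # Plug-in not installed (or failed to import): CUDA path unavailable.
        return None


def select_healpix_backend(module_name: str = CUDA_PLUGIN_MODULE) -> str:
    """Pick a (hypothetical) backend label based on plug-in availability."""
    if load_cuda_plugin(module_name) is not None:
        return "cuda_extension"
    # Fall back to the JIT-compiled JAX implementation.
    return "jax"
```

With this design, the core wheel stays free of CUDA dependencies. Users who want the extension would install the plug-in alongside it, much as JAX users install a CUDA plug-in package next to `jax`.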