Description
To support multiple array backends, many functions in the package have separate NumPy and JAX implementations. This adds a significant maintenance burden, since bug fixes, documentation, and tests must all be duplicated. PyTorch support is provided by wrapping the JAX functions, relying on both libraries' support for the DLPack in-memory tensor format. However, this forces PyTorch users to also install JAX, and it complicates GPU memory management because JAX by default preallocates a large proportion of GPU memory.
It would be beneficial to refactor S2FFT to use the Array API standard, enabling backend-agnostic implementations for NumPy, JAX, and PyTorch (via array_api_compat). This would reduce code duplication and so simplify maintenance, allow native PyTorch support without a dependency on JAX, and potentially enable future compatibility with other libraries that support the Array API. As a side benefit, the changes would remove the need for manual backend selection, improving usability for downstream libraries that use the Array API, such as GLASS.