
Array API refactor #339

@matt-graham

Description

To support multiple array backends, many functions in the package have separate NumPy and JAX implementations. This adds a significant maintenance burden, as bug fixes, documentation and tests must all be duplicated. PyTorch support is provided by wrapping the JAX functions and converting tensors between the two libraries via DLPack, the in-memory tensor exchange format that both support. However, this forces PyTorch users to also install JAX, and it complicates GPU memory management, because JAX by default preallocates a large proportion (75%) of the available GPU memory.

It would be beneficial to refactor S2FFT to use the Array API standard, enabling backend-agnostic implementations for NumPy, JAX and PyTorch (via array_api_compat). This would reduce code duplication and so simplify maintenance, allow native PyTorch support without a dependency on JAX, and open the way to future compatibility with other libraries that support the Array API. As a side benefit, the backend could be inferred from the input arrays, removing the need for manual backend selection and improving usability for downstream libraries that use the Array API, such as GLASS.
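A minimal sketch of the proposed pattern, assuming NumPy is installed and, optionally, array_api_compat. Both `get_namespace` and `weighted_l2_norm` are hypothetical names used only for illustration and are not part of S2FFT. The idea is that a single implementation obtains the namespace `xp` from its inputs and calls only Array API functions on it, so no separate per-backend version or manual backend selection is needed.

```python
import numpy as np


def get_namespace(*arrays):
    """Return the Array API namespace shared by ``arrays``.

    Hypothetical helper: prefers array_api_compat (which also wraps NumPy and
    PyTorch), then the standard ``__array_namespace__`` protocol, and finally
    falls back to plain NumPy for NumPy versions that predate the protocol.
    """
    try:
        from array_api_compat import array_namespace
    except ImportError:
        array_namespace = None
    if array_namespace is not None:
        # Raises TypeError if the arrays come from different libraries.
        return array_namespace(*arrays)

    namespaces = set()
    for a in arrays:
        # Arrays that implement the standard natively expose __array_namespace__.
        if hasattr(a, "__array_namespace__"):
            namespaces.add(a.__array_namespace__())
        else:
            namespaces.add(np)
    if len(namespaces) != 1:
        raise TypeError("arrays must all belong to the same array library")
    return namespaces.pop()


def weighted_l2_norm(f, weights):
    """Quadrature-weighted L2 norm sqrt(sum(weights * |f|**2)).

    Hypothetical example function (not part of S2FFT), written once for all
    backends instead of once per backend.
    """
    xp = get_namespace(f, weights)
    # Only functions defined by the Array API standard are called on ``xp``.
    return xp.sqrt(xp.sum(weights * xp.abs(f) ** 2))
```

For example, `weighted_l2_norm(np.array([3.0, 4.0]), np.ones(2))` evaluates to 5.0. With array_api_compat installed, passing two PyTorch tensors instead would be expected to dispatch to PyTorch directly, without importing JAX or converting via DLPack.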

Metadata

Labels: enhancement (New feature or request)