
Fix gradient propagation for period parameter and add validation suite #94

Open
gevero wants to merge 6 commits into kc-ml2:main from gevero:fix/gradient-period

Conversation

gevero commented on Jan 19, 2026

Description

This PR resolves an issue where gradients with respect to the period parameter did not propagate correctly in either the JAX or the PyTorch backend. It also introduces a comprehensive example script that validates these gradients against numerical finite-difference results.

Changes

Core Library Fixes

  • PyTorch Backend: Modified the period setter in _BaseRCWA to use graph-preserving operations (view, repeat, and torch.stack). Previously, the use of torch.tensor() created new leaf tensors, which detached the autograd graph and prevented gradient flow back to the input parameters. A minimal sketch of the failure mode and the fix follows this list.
  • JAX Backend: Fixed a TypeError in the period setter occurring when scalar Tracers (e.g., from jax.grad) were passed. Since 0-rank JAX arrays do not support len(), the logic was updated to use an ndim check for proper broadcasting (see the JAX sketch after this list).
  • Documentation: Added detailed internal comments to the modified setters explaining the necessity of these changes for maintaining gradient integrity.
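To make the PyTorch fix concrete, here is a minimal sketch of the bug and the graph-preserving alternative. This is not the actual _BaseRCWA setter; `p` stands in for the user-supplied period, and the real setter handles more input cases than shown here.

```python
import torch

# `p` stands in for a user-supplied scalar period that should receive gradients.
p = torch.tensor(1.2, requires_grad=True)

# Broken: torch.tensor() copies values into a brand-new leaf tensor,
# so the result has no grad_fn and the autograd graph is cut.
detached = torch.tensor(p)            # PyTorch emits a UserWarning for this reason
assert detached.grad_fn is None

# Graph-preserving: view/repeat (and torch.stack for separate px, py tensors)
# are differentiable ops, so the result stays connected to `p` in the graph.
period = p.view(1).repeat(2)          # scalar -> [p, p]
assert period.grad_fn is not None

period.sum().backward()
print(p.grad)                         # tensor(2.) -- gradient reaches the input
```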
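The JAX side of the fix follows the same spirit. The helper below is illustrative only, not the library's setter: it shows why an ndim check is needed where len() fails on 0-rank arrays and Tracers.

```python
import jax
import jax.numpy as jnp

# Illustrative helper (not the actual setter): broadcast a scalar period
# to [px, py] using ndim, since len() fails on 0-rank arrays/Tracers.
def normalize_period(value):
    value = jnp.asarray(value)
    if value.ndim == 0:                    # scalar, incl. 0-rank Tracers from jax.grad
        return jnp.stack([value, value])   # -> shape (2,)
    return value                           # already a 1D vector

# The original failure mode: len(jnp.asarray(1.0)) raises
# "TypeError: len() of unsized object", so a len()-based branch breaks
# as soon as jax.grad feeds a scalar Tracer through the setter.

def loss(period_scalar):
    period = normalize_period(period_scalar)
    return jnp.sum(period ** 2)            # toy objective

print(jax.grad(loss)(1.0))                 # 4.0
```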

Validation Suite

  • New Example: Added examples/gradient_check_period.py, a robust script that validates AD gradients against Central Finite Difference (FD) approximations (a condensed sketch of the check appears after this list).
  • Coverage:
    • Backends: Validates both JAX and PyTorch implementations.
    • Modes: Tests both 1D and 2D simulations.
    • Input Types: Ensures correctness for period passed as a scalar, a 1D vector (array), or a list of values/tensors.
  • Sensitivity: The 2D test case uses a checkerboard pattern to ensure that gradients for both $p_x$ and $p_y$ are properly calculated and verified.
  • Observability: The script reports both the calculated gradient magnitudes and the absolute difference between AD and FD for easier debugging and verification.
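For reference, the validation pattern reduces to comparing backward-mode gradients against a central difference. The sketch below substitutes a toy `loss_fn` for the full RCWA simulation that the actual script drives; all names here are illustrative placeholders.

```python
import torch

def central_fd(loss_fn, x, eps=1e-4):
    """Central finite difference: (f(x + e_i*eps) - f(x - e_i*eps)) / (2*eps)."""
    grad = torch.zeros_like(x)
    for i in range(x.numel()):
        step = torch.zeros_like(x)
        step[i] = eps
        grad[i] = (loss_fn(x + step) - loss_fn(x - step)) / (2 * eps)
    return grad

def loss_fn(period):
    # Stand-in figure of merit; the real script evaluates an RCWA simulation.
    return torch.sum(torch.sin(period) * period)

period = torch.tensor([0.7, 1.3], requires_grad=True)
loss_fn(period).backward()

ad = period.grad                         # autodiff gradient
fd = central_fd(loss_fn, period.detach())  # numerical reference
print("max |AD - FD|:", (ad - fd).abs().max().item())  # small for this toy loss
```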

Verification Results

Validated the fixes using the new examples/gradient_check_period.py script. In all tested configurations (1D/2D, JAX/Torch, multiple Fourier orders), the AD gradients match the FD results to within approximately $10^{-5}$ to $10^{-6}$.


Disclaimer: This pull request and the associated code modifications were prepared with the assistance of gemini-cli.

…ends

- In PyTorch backend, fixed 'period' setter to avoid creating new leaf tensors, ensuring gradients propagate correctly.
- In JAX backend, fixed 'period' setter to correctly handle scalar Tracers/Arrays by checking 'ndim' instead of relying on 'len()', which fails for scalars.
…eck example

- Added detailed comments to 'period' setters in JAX and PyTorch backends explaining the fixes.
- Added 'examples/gradient_check_period.py' which tests the gradient of the 'period' parameter for both backends, in 1D and 2D modes, across all supported input formats (scalar, vector, list).
