Description
It is preferable to define the custom derivative via the tangent mapping (JVP) instead of the adjoint (VJP), since JAX can derive the adjoint for free from the tangent map by transposition. The current choice of defining only the adjoint breaks the ability of downstream code to differentiate in forward mode, which is unfortunate because the forward derivative is much more efficient for square systems.
For example, not having a JVP means that the full Jacobian matrix must be materialized when solving a nonlinear square system f(x) = 0 in which computing f requires a call to s2fft.
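To see why a VJP-only rule is limiting, here is a minimal sketch (using a toy linear map, not s2fft itself) of how wrapping a function with only `jax.custom_vjp` makes forward-mode differentiation fail:

```python
import jax
import jax.numpy as jnp

# Toy stand-in for a transform wrapped with a custom VJP only.
@jax.custom_vjp
def f(x):
    return 2.0 * x

def f_fwd(x):
    return f(x), None

def f_bwd(res, g):
    # Adjoint of the linear map x -> 2x.
    return (2.0 * g,)

f.defvjp(f_fwd, f_bwd)

# Reverse mode works fine:
grad = jax.grad(lambda x: f(x).sum())(jnp.ones(3))

# Forward mode raises: custom_vjp functions cannot be JVP'd.
try:
    jax.jvp(f, (jnp.ones(3),), (jnp.ones(3),))
except TypeError as e:
    print("forward mode failed:", e)
```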
A Newton-Krylov solver that can compute the JVP of f has computational cost O(k · jvp_cost(f)), where the number of Krylov iterations k << N = size(f) and jvp_cost(f) ≈ vjp_cost(f). When limited to VJPs, the full Jacobian must first be assembled (one VJP per row), and the cost balloons to O(N · vjp_cost(f) + k · N²).
| | Initial cost | Krylov iteration cost | Memory |
|---|---|---|---|
| Newton-Krylov with VJP | O(N · vjp cost(f)) to build the Jacobian | O(N²) per matvec | O(N²) |
| Newton-Krylov with JVP | O(1) | O(jvp cost(f)) per matvec | O(N) |
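The JVP-based cost above comes from a matrix-free Newton-Krylov loop. A rough sketch (with a toy residual standing in for a function that would internally call s2fft; `newton_krylov` and `g` are illustrative names, not s2fft API):

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

def newton_krylov(f, x0, iters=10):
    """Matrix-free Newton solve of f(x) = 0 using JVPs only."""
    x = x0
    for _ in range(iters):
        fx = f(x)
        # Jacobian-vector product without materializing the Jacobian:
        jv = lambda v: jax.jvp(f, (x,), (v,))[1]
        dx, _ = gmres(jv, -fx)
        x = x + dx
    return x

# Toy stand-in residual; each evaluation would be one transform call.
g = lambda x: jnp.tanh(x) + 0.5 * x - 1.0
root = newton_krylov(g, jnp.zeros(4))
```

Each Newton step costs one residual evaluation plus O(k) JVPs, so the Jacobian never needs to be stored.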
It is trivial to define the JVPs since the transforms are linear.
```python
from functools import partial

import jax
import s2fft

# Note: jax.custom_jvp marks static arguments by position via
# nondiff_argnums (it has no nondiff_argnames parameter).
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _thing(f, junk):  # replace "thing" with forward or inverse
    L, sampling, reality, method, precomps = junk
    return s2fft.thing(  # replace with the actual code this thing calls
        f,
        L,
        sampling=sampling,
        reality=reality,
        method=method,
        precomps=precomps,
    )

@_thing.defjvp
def _thing_jvp(junk, primals, tangents):
    (p,) = primals
    (t,) = tangents
    # The transform is linear, so its JVP is the transform
    # applied to the tangent.
    return _thing(p, junk), _thing(t, junk)
```

Then make sure to remove all custom VJPs.
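With only the JVP defined, both differentiation modes become available, since JAX transposes the linear JVP rule to obtain the VJP. A self-contained demonstration of the same pattern, with a fixed matrix as a hypothetical stand-in for the s2fft transform:

```python
from functools import partial

import jax
import jax.numpy as jnp

# Stand-in linear transform (a fixed matrix) in place of the s2fft call.
A = jnp.arange(12.0).reshape(4, 3)

@partial(jax.custom_jvp, nondiff_argnums=(1,))
def transform(f, junk):
    return A @ f

@transform.defjvp
def transform_jvp(junk, primals, tangents):
    (p,) = primals
    (t,) = tangents
    # Linearity: the JVP is the transform applied to the tangent.
    return transform(p, junk), transform(t, junk)

x = jnp.ones(3)
J_fwd = jax.jacfwd(transform)(x, None)  # forward mode now works
J_rev = jax.jacrev(transform)(x, None)  # reverse mode comes for free
```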