JVP (Merge ready) #372

@unalmis

Description

It is preferable to define the custom derivative via the tangent mapping (JVP) instead of the adjoint, as JAX can derive the adjoint map for free from the tangent map. The current choice of defining the adjoint (VJP) instead breaks the ability of downstream code to differentiate via forward mode, which is unfortunate because forward-mode derivatives are much more efficient for square systems.

For example, not having a JVP means that the full Jacobian matrix must be materialized when solving a nonlinear square system $f(x) = 0$ in which computing $f$ requires a call to s2fft.

A Newton-Krylov solver that can compute the JVP of $f$ has computational cost $O(k \cdot \text{jvp cost}(f))$, where the number of iterations $k \ll N = \text{size}(f)$ and $\text{jvp cost}(f) \approx \text{cost}(f)$. When limited to VJPs, the cost balloons to $O(N \cdot \text{vjp cost}(f) + k N^2)$.

|  | Initial cost | Krylov iteration cost | Memory |
| --- | --- | --- | --- |
| Newton-Krylov with VJP | $O(N \cdot \text{vjp cost}(f))$ | $O(k N^2)$ | $O(N^2) + \text{reverse-mode memory graph of } f$ |
| Newton-Krylov with JVP | $O(\text{cost}(f))$ | $O(k \cdot \text{jvp cost}(f))$ | $O(N)$ |
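As a concrete sketch of the costs above (with a made-up residual function standing in for one that calls s2fft internally), a Newton-Krylov step needs only Jacobian-vector products, which `jax.scipy.sparse.linalg.gmres` accepts as a matvec closure, so the $N \times N$ Jacobian is never built:

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def residual(x):
    # Hypothetical nonlinear square system f(x) = 0; a stand-in for a
    # residual whose evaluation calls s2fft transforms.
    return jnp.tanh(x) - 0.5 * jnp.roll(x, 1)


def newton_krylov_step(x):
    fx = residual(x)
    # GMRES only needs the action of the Jacobian on a vector, supplied
    # here as a JVP: cost per iteration is one jvp of f, not O(N^2).
    matvec = lambda v: jax.jvp(residual, (x,), (v,))[1]
    dx, _ = gmres(matvec, -fx)
    return x + dx


x = 0.1 * jnp.ones(8)
for _ in range(10):
    x = newton_krylov_step(x)
```

If `residual` only exposed a custom VJP, `jax.jvp` in the matvec would fail, forcing the dense-Jacobian path from the first row of the table.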

It is trivial to define the JVPs since the transforms are linear.

```python
from functools import partial

import jax
import s2fft


# Note: jax.custom_jvp takes nondiff_argnums (positional indices),
# not nondiff_argnames.
@partial(jax.custom_jvp, nondiff_argnums=(1,))
def _thing(f, junk):  # replace thing with forward or inverse
    L, sampling, reality, method, precomps = junk
    return s2fft.thing(  # replace with the actual code this thing calls
        f,
        L,
        sampling=sampling,
        reality=reality,
        method=method,
        precomps=precomps,
    )


@_thing.defjvp
def _thing_jvp(junk, primals, tangents):
    (p,) = primals
    (t,) = tangents
    # The transform is linear, so its tangent map is the transform itself.
    return _thing(p, junk), _thing(t, junk)
```

Then make sure to remove all custom VJPs; JAX derives them from the JVP by transposition.
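For reference, here is a self-contained toy version of the pattern above, with a simple linear map (`jnp.cumsum`, a hypothetical stand-in for an s2fft transform) so it runs without s2fft. Once the JVP is defined, forward mode works and reverse mode comes for free:

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.custom_jvp, nondiff_argnums=(1,))
def linear_transform(f, scale):
    # Any linear map follows the same pattern as the s2fft transforms.
    return scale * jnp.cumsum(f)


@linear_transform.defjvp
def linear_transform_jvp(scale, primals, tangents):
    (p,) = primals
    (t,) = tangents
    # Linearity: push the tangent through the transform unchanged.
    return linear_transform(p, scale), linear_transform(t, scale)


x = jnp.arange(4.0)
# Forward mode now works; JAX also derives the VJP automatically by
# transposing the linear tangent map, so reverse mode keeps working too.
y, y_dot = jax.jvp(lambda f: linear_transform(f, 2.0), (x,), (jnp.ones(4),))
```

Here `y_dot` equals `linear_transform(jnp.ones(4), 2.0)`, exactly as in the `_thing_jvp` sketch.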
