Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 152 additions & 10 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: inv_safe, diagview, copy_input, zero!
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
Expand Down Expand Up @@ -52,11 +52,11 @@ for (f!, f, pb, adj) in (
$f!(A, args, Mooncake.primal(alg_dalg))
function $adj(::NoRData)
copy!(A, Ac)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
copy!(arg1, arg1c)
copy!(arg2, arg2c)
MatrixAlgebraKit.zero!(darg1)
MatrixAlgebraKit.zero!(darg2)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
zero!(darg1)
zero!(darg2)
return NoRData(), NoRData(), NoRData(), NoRData()
end
return args_dargs, $adj
Expand All @@ -76,8 +76,8 @@ for (f!, f, pb, adj) in (
arg1, darg1 = arrayify(arg1, darg1_)
arg2, darg2 = arrayify(arg2, darg2_)
$pb(dA, A, (arg1, arg2), (darg1, darg2))
MatrixAlgebraKit.zero!(darg1)
MatrixAlgebraKit.zero!(darg2)
zero!(darg1)
zero!(darg2)
return NoRData(), NoRData(), NoRData()
end
return output_codual, $adj
Expand All @@ -99,8 +99,8 @@ for (f!, f, pb, adj) in (
$f!(A, arg, Mooncake.primal(alg_dalg))
function $adj(::NoRData)
copy!(A, Ac)
$pb(dA, A, arg, darg)
copy!(arg, argc)
$pb(dA, A, arg, darg)
MatrixAlgebraKit.zero!(darg)
return NoRData(), NoRData(), NoRData(), NoRData()
end
Expand Down Expand Up @@ -137,6 +137,7 @@ for (f!, f, f_full, pb, adj) in (
copy!(D, diagview(DV[1]))
V = DV[2]
function $adj(::NoRData)
copy!(D, diagview(DV[1]))
$pb(dA, A, DV, dD)
MatrixAlgebraKit.zero!(dD)
return NoRData(), NoRData(), NoRData(), NoRData()
Expand All @@ -163,12 +164,43 @@ for (f!, f, f_full, pb, adj) in (
end
end

for (f, f_ne, pb, adj) in (
(:eig_trunc, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
(:eigh_trunc, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
for (f!, f, f_ne!, f_ne, pb, adj) in (
(:eig_trunc!, :eig_trunc, :eig_trunc_no_error!, :eig_trunc_no_error, :eig_trunc_pullback!, :eig_trunc_adjoint),
(:eigh_trunc!, :eigh_trunc, :eigh_trunc_no_error!, :eigh_trunc_no_error, :eigh_trunc_pullback!, :eigh_trunc_adjoint),
)
@eval begin
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
DV = Mooncake.primal(DV_dDV)
dDV = Mooncake.tangent(DV_dDV)
Ac = copy(A)
DVc = copy.(DV)
alg = Mooncake.primal(alg_dalg)
output = $f!(A, DV, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function $adj(dy::Tuple{NoRData, NoRData, T}) where {T <: Real}
copy!(A, Ac)
copy!(DV[1], DVc[1])
copy!(DV[2], DVc[2])
Dtrunc, Vtrunc, ϵ = Mooncake.primal(output_codual)
dDtrunc_, dVtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[3]) > MatrixAlgebraKit.defaulttol(dy[3]) && @warn "Pullback for $f does not yet support non-zero tangent for the truncation error"
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
$pb(dA, A, (D′, V′), (dD′, dV′))
MatrixAlgebraKit.zero!(dD)
MatrixAlgebraKit.zero!(dV)
return NoRData(), NoRData(), NoRData()
end
return output_codual, $adj
end
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand All @@ -192,7 +224,37 @@ for (f, f_ne, pb, adj) in (
end
return output_codual, $adj
end
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f_ne), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof($f_ne!)}, A_dA::CoDual, DV_dDV::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
alg = Mooncake.primal(alg_dalg)
DV = Mooncake.primal(DV_dDV)
dDV = Mooncake.tangent(DV_dDV)
Ac = copy(A)
DVc = copy.(DV)
output = $f_ne(A, DV, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function $adj(::NoRData)
copy!(A, Ac)
copy!(DV[1], DVc[1])
copy!(DV[2], DVc[2])
Dtrunc, Vtrunc = Mooncake.primal(output_codual)
dDtrunc_, dVtrunc_ = Mooncake.tangent(output_codual)
D′, dD′ = arrayify(Dtrunc, dDtrunc_)
V′, dV′ = arrayify(Vtrunc, dVtrunc_)
$pb(dA, A, (D′, V′), (dD′, dV′))
MatrixAlgebraKit.zero!(dD)
MatrixAlgebraKit.zero!(dV)
return NoRData(), NoRData(), NoRData()
end
return output_codual, $adj
end
function Mooncake.rrule!!(::CoDual{typeof($f_ne)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
Expand Down Expand Up @@ -232,9 +294,13 @@ for (f!, f) in (
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = $f!(A, Mooncake.primal(alg_dalg))
function svd_adjoint(::NoRData)
copy!(A, Ac)
copy!(U, USVᴴc[1])
copy!(S, USVᴴc[2])
copy!(Vᴴ, USVᴴc[3])
if $(f! == svd_compact!)
svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
else # full
Expand Down Expand Up @@ -301,6 +367,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua
function svd_vals_adjoint(::NoRData)
svd_vals_pullback!(dA, A, USVᴴ, dS)
MatrixAlgebraKit.zero!(dS)
copy!(S, diagview(USVᴴ[2]))
return NoRData(), NoRData(), NoRData(), NoRData()
end
return S_dS, svd_vals_adjoint
Expand All @@ -326,6 +393,44 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co
return S_codual, svd_vals_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
alg = Mooncake.primal(alg_dalg)
Ac = copy(A)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = svd_trunc!(A, USVᴴ, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real}
copy!(A, Ac)
copy!(U, USVᴴc[1])
copy!(S, USVᴴc[2])
copy!(Vᴴ, USVᴴc[3])
Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual)
abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc does not yet support non-zero tangent for the truncation error"
U′, dU′ = arrayify(Utrunc, dUtrunc_)
S′, dS′ = arrayify(Strunc, dStrunc_)
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
MatrixAlgebraKit.zero!(dU)
MatrixAlgebraKit.zero!(dS)
MatrixAlgebraKit.zero!(dVᴴ)
return NoRData(), NoRData(), NoRData()
end
return output_codual, svd_trunc_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
Expand Down Expand Up @@ -355,6 +460,43 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::C
return output_codual, svd_trunc_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual)
# compute primal
A, dA = arrayify(A_dA)
alg = Mooncake.primal(alg_dalg)
Ac = copy(A)
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
USVᴴc = copy.(USVᴴ)
output = svd_trunc_no_error!(A, USVᴴ, alg)
# fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal
# of ComplexF32) into the correct **forwards** data type (since we are now in the forward
# pass). For many types this is done automatically when the forward step returns, but
# not for nested structs with various fields (like Diagonal{Complex})
output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output)))
function svd_trunc_adjoint(::NoRData)
copy!(A, Ac)
copy!(U, USVᴴc[1])
copy!(S, USVᴴc[2])
copy!(Vᴴ, USVᴴc[3])
Utrunc, Strunc, Vᴴtrunc = Mooncake.primal(output_codual)
dUtrunc_, dStrunc_, dVᴴtrunc_ = Mooncake.tangent(output_codual)
U′, dU′ = arrayify(Utrunc, dUtrunc_)
S′, dS′ = arrayify(Strunc, dStrunc_)
Vᴴ′, dVᴴ′ = arrayify(Vᴴtrunc, dVᴴtrunc_)
svd_trunc_pullback!(dA, A, (U′, S′, Vᴴ′), (dU′, dS′, dVᴴ′))
MatrixAlgebraKit.zero!(dU)
MatrixAlgebraKit.zero!(dS)
MatrixAlgebraKit.zero!(dVᴴ)
return NoRData(), NoRData(), NoRData()
end
return output_codual, svd_trunc_adjoint
end

@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc_no_error), Any, MatrixAlgebraKit.AbstractAlgorithm}
function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, alg_dalg::CoDual)
# compute primal
Expand Down
Loading