Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ GenericLinearAlgebra = "0.3.19"
GenericSchur = "0.5.6"
JET = "0.9, 0.10"
LinearAlgebra = "1"
Mooncake = "0.4.183"
Mooncake = "0.4.195"
ParallelTestRunner = "2"
Random = "1"
SafeTestsets = "0.1"
Expand Down
7 changes: 6 additions & 1 deletion ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: LQViaTransposedQR, TruncationStrategy, NoTruncation, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_gesvdj!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _gpu_heev!, _gpu_heevx!, _sylvester
using AMDGPU
using LinearAlgebra
using LinearAlgebra: BlasFloat
Expand Down Expand Up @@ -171,4 +171,9 @@ end
# Index intersection for `ROCVector{Int}` inputs: both operands are materialised
# with `collect` and the call is forwarded to the `_ind_intersect` method for the
# collected arrays (presumably plain host `Vector{Int}`s — confirm).
MatrixAlgebraKit._ind_intersect(A::ROCVector{Int}, B::ROCVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

# `_sylvester` method for ROC matrices: every operand is materialised with
# `collect`, `sylvester` is called on the collected arrays, and the solution is
# wrapped in a `ROCArray` again.
# NOTE(review): presumably a host round-trip used until a native device solver
# is available — confirm.
function _sylvester(A::AnyROCMatrix, B::AnyROCMatrix, C::AnyROCMatrix)
    collected = map(collect, (A, B, C))
    return ROCArray(sylvester(collected...))
end

end
18 changes: 16 additions & 2 deletions ext/MatrixAlgebraKitCUDAExt/MatrixAlgebraKitCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ module MatrixAlgebraKitCUDAExt
using MatrixAlgebraKit
using MatrixAlgebraKit: @algdef, Algorithm, check_input
using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular!
using MatrixAlgebraKit: diagview, sign_safe
using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol
using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm
using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm
import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!
import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester
using CUDA, CUDA.CUBLAS
using CUDA: i32
using LinearAlgebra
Expand Down Expand Up @@ -195,4 +195,18 @@ end
# Index intersection for `CuVector{Int}` inputs: both operands are materialised
# with `collect` and the call is forwarded to the `_ind_intersect` method for the
# collected arrays (presumably plain host `Vector{Int}`s — confirm).
MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) =
MatrixAlgebraKit._ind_intersect(collect(A), collect(B))

# Default pullback tolerances for CUDA-backed arrays: `eps(norm(·, Inf))^(3 / 4)`.
# The argument is converted with `CuArray(A)` before taking the norm, so that
# wrapped arrays matched by `AnyCuArray` are handed to `norm` as a plain `CuArray`.
# NOTE(review): presumably needed because `norm` fails on wrappers such as an
# adjoint of a `CuArray` — confirm.
MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4)
# Same tolerance for gauge checks, except that a zero tangent yields a tolerance of `0`.
MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4)
function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...)
As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...))
return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why was this needed? What breaks if we don't do CuArray.(As')?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

norm doesn't work for Adjoint{CuArray} for example

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it help if we use the regular norm instead of the Inf one?

end

function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix)
# https://github.com/JuliaGPU/CUDA.jl/issues/3021
# to add native sylvester to CUDA
hX = sylvester(collect(A), collect(B), collect(C))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very awful but I wasn't able to find a correct way to do it in five minutes so there you go

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there any chance we could:

  1. open an issue for CUDA and add the link in some comment here
  2. insert a function _sylvester to avoid type piracy

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes to both

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should open an issue at AMDGPU also

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add this as a comment, and replace our use of LinearAlgebra.sylvester with a hook _sylvester throughout the code? That way we aren't actually doing any type piracy, and it is easy to remove in the future.

return CuArray(hX)
end

end
21 changes: 21 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ for eig in (:eig, :eigh)
eig_t! = Symbol(eig, "_trunc!")
eig_t_pb = Symbol(eig, "_trunc_pullback")
_make_eig_t_pb = Symbol("_make_", eig_t_pb)
eig_t_ne! = Symbol(eig, "_trunc_no_error!")
eig_t_ne_pb = Symbol(eig, "_trunc_no_error_pullback")
_make_eig_t_ne_pb = Symbol("_make_", eig_t_ne_pb)
eig_v = Symbol(eig, "_vals")
eig_v! = Symbol(eig_v, "!")
eig_v_pb = Symbol(eig_v, "_pullback")
Expand Down Expand Up @@ -136,6 +139,24 @@ for eig in (:eig, :eigh)
end
return $eig_t_pb
end
# Reverse rule for `$eig_t_ne!` with a `TruncatedAlgorithm`.
# Runs the underlying factorization `$eig_f!` with `alg.alg` on
# `copy_input($eig_f, A)`, truncates the result according to `alg.trunc`, and
# returns the truncated output together with a pullback that closes over `A`,
# the untruncated `DV`, and the index set `ind` returned by `truncate`.
function ChainRulesCore.rrule(::typeof($eig_t_ne!), A, DV, alg::TruncatedAlgorithm)
Ac = copy_input($eig_f, A)
DV = $(eig_f!)(Ac, DV, alg.alg)
# `truncate` is called with `$eig_t!` as its first argument and returns the
# truncated output plus `ind`
DV′, ind = MatrixAlgebraKit.truncate($eig_t!, DV, alg.trunc)
return DV′, $(_make_eig_t_ne_pb)(A, DV, ind)
end
# Build the pullback closure used by the `$eig_t_ne!` reverse rule.
# `A`, `DV` and `ind` are captured and forwarded unchanged to the
# `MatrixAlgebraKit` pullback; the returned function has two methods.
function $(_make_eig_t_ne_pb)(A, DV, ind)
# Generic cotangent: accumulate into a fresh `zero(A)` by calling the pullback
# with the unthunked `(ΔD, ΔV)` pair and `ind`.
function $eig_t_ne_pb(ΔDV)
ΔA = zero(A)
ΔD, ΔV = ΔDV
MatrixAlgebraKit.$eig_pb!(ΔA, A, DV, unthunk.((ΔD, ΔV)), ind)
# one tangent per rrule argument: the function, `A`, `DV`, `alg`
return NoTangent(), ΔA, ZeroTangent(), NoTangent()
end
# Both cotangents are `ZeroTangent`: return zero tangents without calling the pullback.
function $eig_t_ne_pb(::Tuple{ZeroTangent, ZeroTangent}) # is this extra definition useful?
return NoTangent(), ZeroTangent(), ZeroTangent(), NoTangent()
end
return $eig_t_ne_pb
end
function ChainRulesCore.rrule(::typeof($eig_v!), A, D, alg)
DV = $eig_f(A, alg)
function $eig_v_pb(ΔD)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module MatrixAlgebraKitMooncakeExt
using Mooncake
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
using MatrixAlgebraKit
using MatrixAlgebraKit: inv_safe, diagview, copy_input
using MatrixAlgebraKit: inv_safe, diagview, copy_input, initialize_output
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_vals_pullback!
Expand All @@ -18,14 +18,16 @@ Mooncake.tangent_type(::Type{<:MatrixAlgebraKit.AbstractAlgorithm}) = Mooncake.N
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(copy_input), Any, Any}
# Reverse rule for `copy_input(f, A)`.
#
# The primal copy `Ac` is wrapped in a zero-initialised forward codual, and the
# tangent stored in that codual is the buffer downstream rules accumulate into.
# On the reverse pass that buffer is added onto the tangent of the original `A`.
#
# The earlier `dAc = Mooncake.zero_tangent(Ac)` assignment was a dead store
# (overwritten immediately) and the second `return` was unreachable; both are
# removed so the tangent is built once and the codual is returned as is.
function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDual)
    Ac = copy_input(Mooncake.primal(f_df), Mooncake.primal(A_dA))
    Ac_dAc = Mooncake.zero_fcodual(Ac)
    dAc = Mooncake.tangent(Ac_dAc)
    function copy_input_pb(::NoRData)
        # propagate what was accumulated on the copy back to the input's tangent
        Mooncake.increment!!(Mooncake.tangent(A_dA), dAc)
        # no rdata for the function, `f`, or `A`
        return NoRData(), NoRData(), NoRData()
    end
    return Ac_dAc, copy_input_pb
end

Mooncake.@zero_derivative Mooncake.DefaultCtx Tuple{typeof(initialize_output), Any, Any, Any}
# two-argument in-place factorizations like LQ, QR, EIG
for (f!, f, pb, adj) in (
(:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint),
Expand Down
1 change: 1 addition & 0 deletions src/common/defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ default_pullback_degeneracy_atol(A) = eps(norm(A, Inf))^(3 / 4)
Default tolerance for deciding what values should be considered equal to 0.
"""
default_pullback_rank_atol(A) = eps(norm(A, Inf))^(3 / 4)
# For a `Diagonal`, compute the tolerance from `diagview(A)` instead of from `A` itself.
default_pullback_rank_atol(A::Diagonal) = default_pullback_rank_atol(diagview(A))

"""
default_hermitian_tol(A)
Expand Down
3 changes: 3 additions & 0 deletions src/common/pullbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,6 @@ function iszerotangent end

iszerotangent(::Any) = false
iszerotangent(::Nothing) = true

# Package-owned hook for solving Sylvester equations; the pullbacks call
# `_sylvester` rather than `LinearAlgebra.sylvester` directly, so other array
# backends can add `_sylvester` methods for their own matrix types.
# NOTE(review): presumably introduced to avoid type piracy on
# `LinearAlgebra.sylvester` — confirm.
# This generic fallback forwards to `LinearAlgebra.sylvester`.
_sylvester(A, B, C) = LinearAlgebra.sylvester(A, B, C)
44 changes: 23 additions & 21 deletions src/pullbacks/eig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,6 @@ function eig_pullback!(
end
return ΔA
end
# `eig_pullback!` for a `Diagonal` cotangent accumulator.
# The pullback is evaluated into a zeroed scratch array obtained from
# `similar(ΔA, size(ΔA))` (presumably a dense matrix, so that another
# `eig_pullback!` method is selected — confirm), and only the diagonal of that
# result is added onto the diagonal of `ΔA`. Returns `ΔA`.
function eig_pullback!(
        ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
        degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
        gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
    )
    scratch = zero!(similar(ΔA, size(ΔA)))
    full_grad = eig_pullback!(
        scratch, A, DV, ΔDV, ind;
        degeneracy_atol = degeneracy_atol, gauge_atol = gauge_atol
    )
    diagview(ΔA) .+= diagview(full_grad)
    return ΔA
end

"""
eig_trunc_pullback!(
Expand Down Expand Up @@ -150,7 +140,7 @@ function eig_trunc_pullback!(
# add contribution from orthogonal complement
PA = A - (A * V) / V
Y = mul!(ΔVperp, PA', Z, 1, 1)
X = sylvester(PA', -Dmat', Y)
X = _sylvester(PA', -Dmat', Y)
Z .+= X

if eltype(ΔA) <: Real
Expand All @@ -161,16 +151,6 @@ function eig_trunc_pullback!(
end
return ΔA
end
# `eig_trunc_pullback!` for a `Diagonal` cotangent accumulator.
# The pullback is evaluated into a zeroed scratch array obtained from
# `similar(ΔA, size(ΔA))` (presumably a dense matrix — confirm), and only the
# diagonal of that result is added onto the diagonal of `ΔA`. Returns `ΔA`.
function eig_trunc_pullback!(
        ΔA::Diagonal, A, DV, ΔDV;
        degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
        gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
    )
    scratch = zero!(similar(ΔA, size(ΔA)))
    full_grad = eig_trunc_pullback!(
        scratch, A, DV, ΔDV;
        degeneracy_atol = degeneracy_atol, gauge_atol = gauge_atol
    )
    diagview(ΔA) .+= diagview(full_grad)
    return ΔA
end

"""
eig_vals_pullback!(
Expand All @@ -195,3 +175,25 @@ function eig_vals_pullback!(
ΔDV = (diagonal(ΔD), nothing)
return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol)
end

function eig_pullback!(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I like it slightly better to put the same functions together, rather than group by argument type, but obviously that's just a personal preference.

ΔA::Diagonal, A, DV, ΔDV, ind = Colon();
degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
)
ΔA_full = zero!(similar(ΔA, size(ΔA)))
eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol)
diagview(ΔA) .+= diagview(ΔA_full)
return ΔA
end

# `eig_trunc_pullback!` for a `Diagonal` cotangent accumulator.
# A zeroed scratch array from `similar(ΔA, size(ΔA))` (presumably dense —
# confirm) is filled in place by the pullback; afterwards only its diagonal is
# added onto the diagonal of `ΔA`. Returns `ΔA`.
function eig_trunc_pullback!(
        ΔA::Diagonal, A, DV, ΔDV;
        degeneracy_atol::Real = default_pullback_rank_atol(DV[1]),
        gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2])
    )
    scratch = zero!(similar(ΔA, size(ΔA)))
    # the scratch buffer is mutated in place; the return value is not needed
    eig_trunc_pullback!(
        scratch, A, DV, ΔDV;
        degeneracy_atol = degeneracy_atol, gauge_atol = gauge_atol
    )
    diagview(ΔA) .+= diagview(scratch)
    return ΔA
end
26 changes: 15 additions & 11 deletions src/pullbacks/eigh.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,15 @@
# Gauge-sensitivity check for `eigh` cotangents.
# Entries of `aVᴴΔV` at positions where two values of `D` differ by less than
# `degeneracy_atol` (this includes the diagonal) are collected, and a warning is
# emitted when their norm exceeds `gauge_atol`. Always returns `nothing`.
function check_eigh_cotangents(
        D, aVᴴΔV;
        degeneracy_atol::Real = default_pullback_rank_atol(D),
        gauge_atol::Real = default_pullback_gauge_atol(aVᴴΔV)
    )
    degenerate = abs.(D' .- D) .< degeneracy_atol
    Δgauge = norm(view(aVᴴΔV, degenerate))
    if !(Δgauge ≤ gauge_atol)
        @warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return
end

"""
eigh_pullback!(
ΔA::AbstractMatrix, A, DV, ΔDV, [ind];
Expand Down Expand Up @@ -41,12 +53,7 @@ function eigh_pullback!(
length(indV) == pV || throw(DimensionMismatch())
mul!(view(VᴴΔV, :, indV), V', ΔV)
aVᴴΔV = project_antihermitian(VᴴΔV) # can't use in-place or recycling doesn't work

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"

check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)
aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

if !iszerotangent(ΔDmat)
Expand Down Expand Up @@ -120,10 +127,7 @@ function eigh_trunc_pullback!(
VᴴΔV = V' * ΔV
aVᴴΔV = project_antihermitian!(VᴴΔV)

mask = abs.(D' .- D) .< degeneracy_atol
Δgauge = norm(view(aVᴴΔV, mask))
Δgauge ≤ gauge_atol ||
@warn "`eigh` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_eigh_cotangents(D, aVᴴΔV; degeneracy_atol, gauge_atol)

aVᴴΔV .*= inv_safe.(D' .- D, degeneracy_atol)

Expand All @@ -138,7 +142,7 @@ function eigh_trunc_pullback!(
# add contribution from orthogonal complement
W = qr_null(V)
WᴴΔV = W' * ΔV
X = sylvester(W' * A * W, -Dmat, WᴴΔV)
X = _sylvester(W' * A * W, -Dmat, WᴴΔV)
Z = mul!(Z, W, X, 1, 1)

# put everything together: symmetrize for hermitian case
Expand Down
92 changes: 65 additions & 27 deletions src/pullbacks/lq.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,42 @@
# Gauge-sensitivity check for `lq` cotangents when `A` is rank-deficient
# (`minmn > p`). In that case the number of Householder reflections changes
# under small variations, so for a gauge-invariant cost function the trailing
# rows of `ΔQ` (and the trailing block of `ΔL`) should vanish. A warning is
# emitted when the larger of their norms exceeds `gauge_atol`.
# Always returns `nothing`.
function check_lq_cotangents(
        L, Q, ΔL, ΔQ, minmn::Int, p::Int;
        gauge_atol::Real = default_pullback_gauge_atol(ΔQ)
    )
    # full-rank case: nothing to check
    minmn > p || return

    Δgauge = abs(zero(eltype(Q)))
    if !iszerotangent(ΔQ)
        trailing_ΔQ = view(ΔQ, (p + 1):size(Q, 1), :)
        Δgauge = max(Δgauge, norm(trailing_ΔQ))
    end
    if !iszerotangent(ΔL)
        trailing_ΔL = view(ΔL, (p + 1):size(L, 1), (p + 1):minmn)
        Δgauge = max(Δgauge, norm(trailing_ΔL))
    end
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return
end

# Gauge-sensitivity check for the `lq_full` case, where `A` has full rank but
# `Q` has more rows than `L` has columns. Because `Q` is unitary, the
# projection of `ΔQ2` onto `Q1` carries gauge-invariant information. The
# number of Householder reflections is fixed in the full-rank case, so `Q` is
# expected to rotate smoothly; how the full `Q2` changes might even be
# predictable, but that is omitted for now and the remainder of `ΔQ2` is
# treated as gauge dependent. A warning is emitted when the `Inf`-norm of
# `ΔQ2 - ΔQ2Q1ᴴ * Q1` exceeds `gauge_atol`. Always returns `nothing`.
function check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol::Real = default_pullback_gauge_atol(Q1))
    remainder = copy(ΔQ2)
    mul!(remainder, ΔQ2Q1ᴴ, Q1, -1, 1)
    Δgauge = norm(remainder, Inf)
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return
end


"""
lq_pullback!(
ΔA, A, LQ, ΔLQ;
Expand Down Expand Up @@ -36,28 +75,12 @@ function lq_pullback!(
ΔA1 = view(ΔA, 1:p, :)
ΔA2 = view(ΔA, (p + 1):m, :)

if minmn > p # case where A is rank-deficient
Δgauge = abs(zero(eltype(Q)))
if !iszerotangent(ΔQ)
# in this case the number Householder reflections will
# change upon small variations, and all of the remaining
# columns of ΔQ should be zero for a gauge-invariant
# cost function
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Δgauge = max(Δgauge, norm(ΔQ2, Inf))
end
if !iszerotangent(ΔL)
ΔL22 = view(ΔL, (p + 1):m, (p + 1):minmn)
Δgauge = max(Δgauge, norm(ΔL22, Inf))
end
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
end
check_lq_cotangents(L, Q, ΔL, ΔQ, minmn, p; gauge_atol)

ΔQ̃ = zero!(similar(Q, (p, n)))
if !iszerotangent(ΔQ)
ΔQ1 = view(ΔQ, 1:p, :)
copy!(ΔQ̃, ΔQ1)
ΔQ̃ .= ΔQ1
if p < size(Q, 1)
Q2 = view(Q, (p + 1):size(Q, 1), :)
ΔQ2 = view(ΔQ, (p + 1):size(Q, 1), :)
Expand All @@ -69,9 +92,7 @@ function lq_pullback!(
# how the full Q2 will change, but this we omit for now, and we consider
# Q2' * ΔQ2 as a gauge dependent quantity.
ΔQ2Q1ᴴ = ΔQ2 * Q1'
Δgauge = norm(mul!(copy(ΔQ2), ΔQ2Q1ᴴ, Q1, -1, 1), Inf)
Δgauge ≤ gauge_atol ||
@warn "`lq` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_lq_full_cotangents(Q1, ΔQ2, ΔQ2Q1ᴴ; gauge_atol)
ΔQ̃ = mul!(ΔQ̃, ΔQ2Q1ᴴ', Q2, -1, 1)
end
end
Expand All @@ -95,12 +116,32 @@ function lq_pullback!(
Md = diagview(M)
Md .= real.(Md)
end
ldiv!(LowerTriangular(L11)', M)
ldiv!(LowerTriangular(L11)', ΔQ̃)
# not GPU friendly...
L11arr = typeof(L)(L11)
ldiv!(LowerTriangular(L11arr)', M)
ldiv!(LowerTriangular(L11arr)', ΔQ̃)
ΔA1 = mul!(ΔA1, M, Q1, +1, 1)
ΔA1 .+= ΔQ̃
return ΔA
end
# `lq_pullback!` for a `Diagonal` cotangent accumulator.
# The pullback is evaluated into a zeroed scratch array obtained from
# `similar(ΔA, size(ΔA))` (presumably a dense matrix — confirm), and only the
# diagonal of that result is added onto the diagonal of `ΔA`. Returns `ΔA`.
function lq_pullback!(
        ΔA::Diagonal, A, LQ, ΔLQ;
        rank_atol::Real = default_pullback_rank_atol(LQ[1]),
        gauge_atol::Real = default_pullback_gauge_atol(ΔLQ[2])
    )
    scratch = zero!(similar(ΔA, size(ΔA)))
    full_grad = lq_pullback!(
        scratch, A, LQ, ΔLQ;
        rank_atol = rank_atol, gauge_atol = gauge_atol
    )
    diagview(ΔA) .+= diagview(full_grad)
    return ΔA
end

# Gauge-sensitivity check for `lq_null` cotangents.
# The product `Nᴴ * ΔNᴴ'` is passed through `project_antihermitian!`, and a
# warning is emitted when the norm of the result exceeds `gauge_atol`.
# Always returns `nothing`.
function check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ))
    NᴴΔN = Nᴴ * ΔNᴴ'
    Δgauge = norm(project_antihermitian!(NᴴΔN))
    if !(Δgauge ≤ gauge_atol)
        @warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
    end
    return
end

"""
lq_null_pullback!(
Expand All @@ -118,10 +159,7 @@ function lq_null_pullback!(
gauge_atol::Real = default_pullback_gauge_atol(ΔNᴴ)
)
if !iszerotangent(ΔNᴴ) && size(Nᴴ, 1) > 0
aNᴴΔN = project_antihermitian!(Nᴴ * ΔNᴴ')
Δgauge = norm(aNᴴΔN)
Δgauge ≤ gauge_atol ||
@warn "`lq_null` cotangent sensitive to gauge choice: (|Δgauge| = $Δgauge)"
check_lq_null_cotangents(Nᴴ, ΔNᴴ; gauge_atol)
L, Q = lq_compact(A; positive = true) # should we be able to provide algorithm here?
X = ldiv!(LowerTriangular(L)', Q * ΔNᴴ')
ΔA = mul!(ΔA, X, Nᴴ, -1, 1)
Expand Down
Loading
Loading