-
Notifications
You must be signed in to change notification settings - Fork 5
Use Testsuite for AD tests #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
09d7f69
239686e
df74a86
1629c0c
111cc89
86777c4
3708f83
c3be142
c4627bc
6f72754
b343415
2832937
b9bb9d5
7e49379
7098f45
6452b9e
b47d076
da5f1bb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,11 +3,11 @@ module MatrixAlgebraKitCUDAExt | |
| using MatrixAlgebraKit | ||
| using MatrixAlgebraKit: @algdef, Algorithm, check_input | ||
| using MatrixAlgebraKit: one!, zero!, uppertriangular!, lowertriangular! | ||
| using MatrixAlgebraKit: diagview, sign_safe | ||
| using MatrixAlgebraKit: diagview, sign_safe, default_pullback_gauge_atol, default_pullback_rank_atol | ||
| using MatrixAlgebraKit: LQViaTransposedQR, TruncationByValue, AbstractAlgorithm | ||
| using MatrixAlgebraKit: default_qr_algorithm, default_lq_algorithm, default_svd_algorithm, default_eig_algorithm, default_eigh_algorithm | ||
| import MatrixAlgebraKit: _gpu_geqrf!, _gpu_ungqr!, _gpu_unmqr!, _gpu_gesvd!, _gpu_Xgesvdp!, _gpu_Xgesvdr!, _gpu_gesvdj!, _gpu_geev! | ||
| import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd! | ||
| import MatrixAlgebraKit: _gpu_heevj!, _gpu_heevd!, _sylvester | ||
| using CUDA, CUDA.CUBLAS | ||
| using CUDA: i32 | ||
| using LinearAlgebra | ||
|
|
@@ -195,4 +195,18 @@ end | |
| MatrixAlgebraKit._ind_intersect(A::CuVector{Int}, B::CuVector{Int}) = | ||
| MatrixAlgebraKit._ind_intersect(collect(A), collect(B)) | ||
|
|
||
| MatrixAlgebraKit.default_pullback_rank_atol(A::AnyCuArray) = eps(norm(CuArray(A), Inf))^(3 / 4) | ||
| MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray) = MatrixAlgebraKit.iszerotangent(A) ? 0 : eps(norm(CuArray(A), Inf))^(3 / 4) | ||
| function MatrixAlgebraKit.default_pullback_gauge_atol(A::AnyCuArray, As...) | ||
| As′ = filter(!MatrixAlgebraKit.iszerotangent, (A, As...)) | ||
| return isempty(As′) ? 0 : eps(norm(CuArray.(As′), Inf))^(3 / 4) | ||
| end | ||
|
|
||
| function _sylvester(A::AnyCuMatrix, B::AnyCuMatrix, C::AnyCuMatrix) | ||
| # https://github.com/JuliaGPU/CUDA.jl/issues/3021 | ||
| # to add native sylvester to CUDA | ||
| hX = sylvester(collect(A), collect(B), collect(C)) | ||
|
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is very awful but I wasn't able to find a correct way to do it in five minutes so there you go
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any chance we could:
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes to both
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should open an issue at AMDGPU also
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add this as a comment, and replace our use of |
||
| return CuArray(hX) | ||
| end | ||
|
|
||
| end | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -78,16 +78,6 @@ function eig_pullback!( | |
| end | ||
| return ΔA | ||
| end | ||
| function eig_pullback!( | ||
| ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); | ||
| degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), | ||
| gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) | ||
| ) | ||
| ΔA_full = zero!(similar(ΔA, size(ΔA))) | ||
| ΔA_full = eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) | ||
| diagview(ΔA) .+= diagview(ΔA_full) | ||
| return ΔA | ||
| end | ||
|
|
||
| """ | ||
| eig_trunc_pullback!( | ||
|
|
@@ -150,7 +140,7 @@ function eig_trunc_pullback!( | |
| # add contribution from orthogonal complement | ||
| PA = A - (A * V) / V | ||
| Y = mul!(ΔVperp, PA', Z, 1, 1) | ||
| X = sylvester(PA', -Dmat', Y) | ||
| X = _sylvester(PA', -Dmat', Y) | ||
| Z .+= X | ||
|
|
||
| if eltype(ΔA) <: Real | ||
|
|
@@ -161,16 +151,6 @@ function eig_trunc_pullback!( | |
| end | ||
| return ΔA | ||
| end | ||
| function eig_trunc_pullback!( | ||
| ΔA::Diagonal, A, DV, ΔDV; | ||
| degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), | ||
| gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) | ||
| ) | ||
| ΔA_full = zero!(similar(ΔA, size(ΔA))) | ||
| ΔA_full = eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) | ||
| diagview(ΔA) .+= diagview(ΔA_full) | ||
| return ΔA | ||
| end | ||
|
|
||
| """ | ||
| eig_vals_pullback!( | ||
|
|
@@ -195,3 +175,25 @@ function eig_vals_pullback!( | |
| ΔDV = (diagonal(ΔD), nothing) | ||
| return eig_pullback!(ΔA, A, DV, ΔDV, ind; degeneracy_atol) | ||
| end | ||
|
|
||
| function eig_pullback!( | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I like it slightly better to put the same functions together, rather than group by argument type, but obviously that's just apersonal preference |
||
| ΔA::Diagonal, A, DV, ΔDV, ind = Colon(); | ||
| degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), | ||
| gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) | ||
| ) | ||
| ΔA_full = zero!(similar(ΔA, size(ΔA))) | ||
| eig_pullback!(ΔA_full, A, DV, ΔDV, ind; degeneracy_atol, gauge_atol) | ||
| diagview(ΔA) .+= diagview(ΔA_full) | ||
| return ΔA | ||
| end | ||
|
|
||
| function eig_trunc_pullback!( | ||
| ΔA::Diagonal, A, DV, ΔDV; | ||
| degeneracy_atol::Real = default_pullback_rank_atol(DV[1]), | ||
| gauge_atol::Real = default_pullback_gauge_atol(ΔDV[2]) | ||
| ) | ||
| ΔA_full = zero!(similar(ΔA, size(ΔA))) | ||
| eig_trunc_pullback!(ΔA_full, A, DV, ΔDV; degeneracy_atol, gauge_atol) | ||
| diagview(ΔA) .+= diagview(ΔA_full) | ||
| return ΔA | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why was this needed? what breaks if we don't do
CuArray.(As')?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
normdoesn't work forAdjoint{CuArray}for exampleThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does it help if we use the regular
norminstead of theInfone?