Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ LinearSolvePETScExt = ["PETSc", "SparseArrays"]
LinearSolveParUExt = ["ParU_jll", "SparseArrays"]
LinearSolvePardisoExt = ["Pardiso", "SparseArrays"]
LinearSolveRecursiveFactorizationExt = "RecursiveFactorization"
LinearSolveSTRUMPACKExt = "SparseArrays"
LinearSolveSparseArraysExt = "SparseArrays"
LinearSolveSparspakExt = ["SparseArrays", "Sparspak"]

Expand Down
239 changes: 239 additions & 0 deletions ext/LinearSolveSTRUMPACKExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
module LinearSolveSTRUMPACKExt

using LinearSolve: LinearSolve, LinearVerbosity, OperatorAssumptions
using SparseArrays: SparseArrays, AbstractSparseMatrixCSC, getcolptr, rowvals, nonzeros
using SciMLBase: SciMLBase, ReturnCode
using SciMLLogging: @SciMLMessage
using Libdl: Libdl

# Return codes of STRUMPACK's C sparse-solver interface.
# NOTE(review): these mirror STRUMPACK's `STRUMPACK_RETURN_CODE` enum —
# confirm the numeric values against the `StrumpackSparseSolver.h` header
# of the library version actually linked at runtime.
const STRUMPACK_SUCCESS = Cint(0)
const STRUMPACK_MATRIX_NOT_SET = Cint(1)
const STRUMPACK_REORDERING_ERROR = Cint(2)
const STRUMPACK_ZERO_PIVOT = Cint(3)
const STRUMPACK_NO_CONVERGENCE = Cint(4)
const STRUMPACK_INACCURATE_INERTIA = Cint(5)

# Solver configuration passed to `STRUMPACK_init_mt`.
# NOTE(review): presumably `STRUMPACK_PRECISION` double = 1 and
# `STRUMPACK_INTERFACE` multithreaded = 0 — verify against the C header.
const STRUMPACK_DOUBLE = Cint(1)
const STRUMPACK_MT = Cint(0)

# Handle to the STRUMPACK shared library; stays C_NULL when none is found.
const _libstrumpack = Ref{Ptr{Cvoid}}(C_NULL)

"""
    _load_libstrumpack() -> Ptr{Cvoid}

Probe a list of well-known `libstrumpack` shared-library names and return the
first handle the dynamic loader can open, or `C_NULL` when none resolves.
Uses `Libdl.dlopen_e`, which never throws on failure.
"""
function _load_libstrumpack()
    candidates = (
        "libstrumpack.so",
        "libstrumpack.so.8",
        "libstrumpack.so.7",
        "libstrumpack.dylib",
        "strumpack",
    )
    for candidate in candidates
        handle = Libdl.dlopen_e(candidate)
        if handle != C_NULL
            return handle
        end
    end
    return C_NULL
end

# Resolve the STRUMPACK shared library once when the extension loads.
function __init__()
return _libstrumpack[] = _load_libstrumpack()
end

# Whether a usable STRUMPACK shared library was found by the dynamic loader.
strumpack_isavailable() = _libstrumpack[] != C_NULL

"""
    STRUMPACKCache()

Per-problem state for a STRUMPACK factorization: the opaque solver handle and
the 0-based CSR copy of the matrix most recently handed to STRUMPACK. The
CSR buffers are kept alive here so the C library never reads freed memory.
A finalizer releases the native solver when the cache is garbage collected.
"""
mutable struct STRUMPACKCache
    solver::Ref{Ptr{Cvoid}}
    rowptr::Vector{Int32}
    colind::Vector{Int32}
    nzval::Vector{Float64}

    function STRUMPACKCache()
        c = new(Ref{Ptr{Cvoid}}(C_NULL), Int32[], Int32[], Float64[])
        # `finalizer` returns its second argument, so this also returns `c`.
        return finalizer(_strumpack_destroy!, c)
    end
end

"""
    _strumpack_destroy!(cache::STRUMPACKCache)

Release the native STRUMPACK solver held by `cache`, if any. Safe to call
repeatedly and from a finalizer: it is a no-op when the library was never
loaded or the handle is already `C_NULL`.
"""
function _strumpack_destroy!(cache::STRUMPACKCache)
    if _libstrumpack[] != C_NULL && cache.solver[] != C_NULL
        ccall((:STRUMPACK_destroy, _libstrumpack[]), Cvoid, (Ref{Ptr{Cvoid}},), cache.solver)
        cache.solver[] = C_NULL
    end
    return
end

# Lazily create the native solver handle on first use; later calls are no-ops.
function _ensure_initialized!(cache::STRUMPACKCache)
cache.solver[] != C_NULL && return
# NOTE(review): argument order assumed to match
# `STRUMPACK_init_mt(S, precision, interface, argc, argv, verbose)` —
# confirm against the C header of the linked STRUMPACK version.
ccall(
(:STRUMPACK_init_mt, _libstrumpack[]),
Cvoid,
(Ref{Ptr{Cvoid}}, Cint, Cint, Cint, Ptr{Ptr{UInt8}}, Cint),
cache.solver,
STRUMPACK_DOUBLE,
STRUMPACK_MT,
Cint(0),                 # argc = 0: no command-line options forwarded
Ptr{Ptr{UInt8}}(C_NULL), # argv = NULL
Cint(0)                  # verbose = 0: keep the C library quiet
)
return
end

"""
    _csc_to_csr_0based(A::AbstractSparseMatrixCSC)

Convert the CSC storage of `A` into 0-based CSR arrays
`(rowptr, colind, nzval)` for STRUMPACK's C interface.

Returns `rowptr::Vector{Int32}` of length `size(A, 1) + 1`, plus
`colind::Vector{Int32}` and `nzval::Vector{Float64}` of length `nnz`.
Column indices within each CSR row come out in ascending order because the
CSC columns are visited left to right.

Only the structural entries delimited by `getcolptr(A)` are read: the raw
`rowvals`/`nonzeros` buffers are allowed to carry unused trailing capacity,
so their full `length` must not be traversed.
"""
function _csc_to_csr_0based(A::AbstractSparseMatrixCSC)
    m, ncols = size(A)
    colptr = getcolptr(A)
    rowval = rowvals(A)
    vals = nonzeros(A)

    # Number of stored entries per the column pointer — NOT length(vals),
    # which may include slack capacity beyond the structural nonzeros.
    nstored = Int(colptr[ncols + 1]) - 1

    rowptr = zeros(Int32, m + 1)

    # Count entries per row, shifted by one slot for the prefix sum below.
    @inbounds for p in 1:nstored
        rowptr[Int(rowval[p]) + 1] += 1
    end

    # Prefix sum turns counts into 0-based CSR row offsets.
    @inbounds for i in 1:m
        rowptr[i + 1] += rowptr[i]
    end

    nextidx = copy(rowptr)
    colind = Vector{Int32}(undef, nstored)
    outvals = Vector{Float64}(undef, nstored)

    # Scatter column by column into each row's next free slot.
    @inbounds for j in 1:ncols
        for p in colptr[j]:(colptr[j + 1] - 1)
            row = Int(rowval[p])
            pos = Int(nextidx[row] + 1)  # 0-based offset -> 1-based index
            nextidx[row] += 1
            colind[pos] = Int32(j - 1)
            outvals[pos] = Float64(vals[p])
        end
    end

    return rowptr, colind, outvals
end

"""
    _retcode_from_strumpack(info::Cint)

Translate a STRUMPACK C-interface return code into a SciML `ReturnCode`.
Unrecognized codes map to `ReturnCode.Failure`.
"""
function _retcode_from_strumpack(info::Cint)
    info == STRUMPACK_SUCCESS && return ReturnCode.Success
    info == STRUMPACK_ZERO_PIVOT && return ReturnCode.Infeasible
    info == STRUMPACK_NO_CONVERGENCE && return ReturnCode.ConvergenceFailure
    info == STRUMPACK_INACCURATE_INERTIA && return ReturnCode.Unstable
    # STRUMPACK_MATRIX_NOT_SET, STRUMPACK_REORDERING_ERROR, and anything
    # unknown are all reported as a generic failure.
    return ReturnCode.Failure
end

# Real sparse CSC input: allocate the STRUMPACK-backed cache.
function LinearSolve.init_cacheval(
::LinearSolve.STRUMPACKFactorization,
A::AbstractSparseMatrixCSC{<:AbstractFloat}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions
)
return STRUMPACKCache()
end

# Any other matrix type is unsupported; `solve!` turns the `nothing`
# cacheval into a descriptive error.
function LinearSolve.init_cacheval(
::LinearSolve.STRUMPACKFactorization,
A, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions
)
return nothing
end

"""
    SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::STRUMPACKFactorization; kwargs...)

Solve the linear system held in `cache` with STRUMPACK's multithreaded sparse
direct solver. On a fresh matrix (`cache.isfresh`) the matrix is converted to
0-based CSR, handed to STRUMPACK, and factorized; subsequent solves reuse the
factorization. Requires a square `AbstractSparseMatrixCSC` and a discoverable
`libstrumpack`; solves are carried out in `Float64`.
"""
function SciMLBase.solve!(
cache::LinearSolve.LinearCache,
alg::LinearSolve.STRUMPACKFactorization;
kwargs...
)
if _libstrumpack[] == C_NULL
error("STRUMPACKFactorization requires a discoverable STRUMPACK shared library (`libstrumpack`)")
end

A = convert(AbstractMatrix, cache.A)
if !(A isa AbstractSparseMatrixCSC)
error("STRUMPACKFactorization currently supports only sparse CSC matrices")
end
size(A, 1) == size(A, 2) || error("STRUMPACKFactorization requires a square matrix")

# `init_cacheval` returned `nothing` for unsupported element/matrix types.
scache = LinearSolve.@get_cacheval(cache, :STRUMPACKFactorization)
if scache === nothing
error("STRUMPACKFactorization currently supports `AbstractSparseMatrixCSC{<:AbstractFloat}`")
end

_ensure_initialized!(scache)

if cache.isfresh
# Keep the CSR buffers alive on `scache`: STRUMPACK may retain
# pointers into them after `STRUMPACK_set_csr_matrix`.
scache.rowptr, scache.colind, scache.nzval = _csc_to_csr_0based(A)
ccall(
(:STRUMPACK_set_csr_matrix, _libstrumpack[]),
Cvoid,
(Ptr{Cvoid}, Cint, Ref{Cint}, Ref{Cint}, Ref{Cdouble}, Cint),
scache.solver[],
Cint(size(A, 1)),
scache.rowptr,
scache.colind,
scache.nzval,
Cint(0)                  # symmetric_pattern = 0 (general pattern)
)

info = ccall((:STRUMPACK_factor, _libstrumpack[]), Cint, (Ptr{Cvoid},), scache.solver[])
if info != STRUMPACK_SUCCESS
@SciMLMessage(
"STRUMPACK factorization failed (code $(Int(info)))",
cache.verbose,
:solver_failure
)
# NOTE(review): `isfresh` is cleared even though factorization failed,
# so a retry with the same matrix skips refactorization — confirm this
# is intentional (it avoids repeating a doomed factorization).
cache.isfresh = false
return SciMLBase.build_linear_solution(
alg,
cache.u,
nothing,
cache;
retcode = _retcode_from_strumpack(info)
)
end
cache.isfresh = false
end

# Work in Float64 copies; `Float64.(...)` also shields `cache.b`/`cache.u`
# from any in-place mutation by the C library.
bvec = Float64.(cache.b)
xvec = Float64.(cache.u)

info = ccall(
(:STRUMPACK_solve, _libstrumpack[]),
Cint,
(Ptr{Cvoid}, Ref{Cdouble}, Ref{Cdouble}, Cint),
scache.solver[],
bvec,
xvec,
Cint(alg.use_initial_guess)  # nonzero: xvec seeds iterative refinement
)

if info != STRUMPACK_SUCCESS
@SciMLMessage(
"STRUMPACK solve failed (code $(Int(info)))",
cache.verbose,
:solver_failure
)
# On failure `cache.u` is left untouched.
return SciMLBase.build_linear_solution(
alg,
cache.u,
nothing,
cache;
retcode = _retcode_from_strumpack(info)
)
end

# Write the Float64 result back into the caller's solution vector
# (which may have a different eltype).
copyto!(cache.u, xvec)
return SciMLBase.build_linear_solution(
alg,
cache.u,
nothing,
cache;
retcode = ReturnCode.Success
)
end

end
4 changes: 3 additions & 1 deletion src/LinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,7 @@ for alg in (
:DiagonalFactorization, :CholeskyFactorization, :BunchKaufmanFactorization,
:CHOLMODFactorization, :LDLtFactorization, :AppleAccelerateLUFactorization,
:MKLLUFactorization, :MetalLUFactorization, :CUSOLVERRFFactorization, :ParUFactorization,
:STRUMPACKFactorization,
)
@eval needs_square_A(::$(alg)) = true
end
Expand Down Expand Up @@ -513,7 +514,8 @@ export LUFactorization, SVDFactorization, QRFactorization, GenericFactorization,
UMFPACKFactorization, KLUFactorization, FastLUFactorization, FastQRFactorization,
SparspakFactorization, DiagonalFactorization, CholeskyFactorization,
BunchKaufmanFactorization, CHOLMODFactorization, LDLtFactorization,
CUSOLVERRFFactorization, CliqueTreesFactorization, ParUFactorization
CUSOLVERRFFactorization, CliqueTreesFactorization, ParUFactorization,
STRUMPACKFactorization

export LinearSolveFunction, DirectLdiv!, show_algorithm_choices

Expand Down
46 changes: 46 additions & 0 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,52 @@ struct SparspakFactorization <: AbstractSparseFactorization
end
end

"""
`STRUMPACKFactorization(; use_initial_guess = false)`

A sparse direct solver based on
[STRUMPACK](https://github.com/pghysels/STRUMPACK) via the
`LinearSolveSTRUMPACKExt` extension.

This wrapper targets the single-node (`MT`) sparse interface and currently supports
real sparse matrices (`AbstractSparseMatrixCSC{<:AbstractFloat}`), solving in
`Float64` precision.

!!! note

Using this solver requires:
1. `using SparseArrays` (to enable sparse matrix support), and
2. a system installation of `libstrumpack` discoverable by the dynamic loader.
"""
struct STRUMPACKFactorization <: AbstractSparseFactorization
    use_initial_guess::Bool

    # The constructor eagerly checks that the extension is loaded and a
    # `libstrumpack` was found, so misconfiguration fails at construction
    # rather than at solve time. Pass `throwerror = false` to skip the check.
    function STRUMPACKFactorization(; use_initial_guess = false, throwerror = true)
        ext = Base.get_extension(@__MODULE__, :LinearSolveSTRUMPACKExt)
        if throwerror && (ext === nothing || !ext.strumpack_isavailable())
            error("STRUMPACKFactorization requires a discoverable STRUMPACK shared library (`libstrumpack`) and `using SparseArrays`")
        end
        return new(use_initial_guess)
    end
end

# Without the extension loaded there is nothing to cache; the extension
# overrides this for `AbstractSparseMatrixCSC{<:AbstractFloat}` inputs.
function init_cacheval(
::STRUMPACKFactorization,
::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions
)
return nothing
end

# Static arrays are never routed to STRUMPACK; keep the cacheval empty.
function init_cacheval(
::STRUMPACKFactorization, ::StaticArray, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions
)
return nothing
end

function init_cacheval(
alg::SparspakFactorization,
A::Union{AbstractMatrix, Nothing, AbstractSciMLOperator}, b, u, Pl, Pr,
Expand Down
7 changes: 7 additions & 0 deletions test/defaults_loading.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ rhs[begin] = rhs[end] = -2
prob = LinearProblem(mat, rhs)
@test_throws ["SparspakFactorization required", "using Sparspak"] sol = solve(prob).u

# STRUMPACK availability depends on the host system: without a discoverable
# `libstrumpack` the constructor must throw a descriptive error; with one,
# construction must succeed.
STRUMPACKExt = Base.get_extension(LinearSolve, :LinearSolveSTRUMPACKExt)
if STRUMPACKExt === nothing || !STRUMPACKExt.strumpack_isavailable()
@test_throws ["STRUMPACKFactorization", "libstrumpack"] STRUMPACKFactorization()
else
@test STRUMPACKFactorization() isa STRUMPACKFactorization
end

using Sparspak
sol = solve(prob).u
@test sol isa Vector{BigFloat}
8 changes: 6 additions & 2 deletions test/resolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@ using LinearSolve: AbstractDenseFactorization, AbstractSparseFactorization,
AMDGPUOffloadLUFactorization, AMDGPUOffloadQRFactorization,
SparspakFactorization

# Gate STRUMPACK-dependent resolve tests on a system `libstrumpack`.
const STRUMPACKExt = Base.get_extension(LinearSolve, :LinearSolveSTRUMPACKExt)
const HAS_STRUMPACK = STRUMPACKExt !== nothing && STRUMPACKExt.strumpack_isavailable()

# Function to check if an algorithm is mixed precision
function is_mixed_precision_alg(alg)
alg_name = string(alg)
Expand Down Expand Up @@ -48,14 +51,15 @@ for alg in vcat(
(!(alg == AppleAccelerate32MixedLUFactorization) || Sys.isapple()) &&
(!(alg == OpenBLAS32MixedLUFactorization) || LinearSolve.useopenblas) &&
(!(alg == SparspakFactorization) || false) &&
(!(alg == STRUMPACKFactorization) || HAS_STRUMPACK) &&
(
!(alg == ParUFactorization) ||
Base.get_extension(LinearSolve, :LinearSolveParUExt) !== nothing
)
A = [1.0 2.0; 3.0 4.0]
alg in [
KLUFactorization, UMFPACKFactorization, SparspakFactorization,
ParUFactorization,
ParUFactorization, STRUMPACKFactorization,
] &&
(A = sparse(A))
A = A' * A
Expand Down Expand Up @@ -84,7 +88,7 @@ for alg in vcat(
A = [1.0 2.0; 3.0 4.0]
alg in [
KLUFactorization, UMFPACKFactorization, SparspakFactorization,
ParUFactorization,
ParUFactorization, STRUMPACKFactorization,
] &&
(A = sparse(A))
A = A' * A
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ if GROUP == "DefaultsLoading"
@time @safetestset "Defaults Loading Tests" include("defaults_loading.jl")
end

# STRUMPACK tests run in the "All" group and in a dedicated CI group.
if GROUP == "All" || GROUP == "LinearSolveSTRUMPACK"
@time @safetestset "LinearSolveSTRUMPACK" include("strumpack/strumpack.jl")
end

if GROUP == "LinearSolveAutotune"
Pkg.activate(joinpath(dirname(@__DIR__), "lib", GROUP))
Pkg.test(
Expand Down
Loading
Loading