Skip to content

Commit 4a647ad

Browse files
committed
Optimize LinearAlgebra.backSubstitute
1 parent b6fe4ad commit 4a647ad

4 files changed

Lines changed: 146 additions & 125 deletions

File tree

src/FsMath/Algebra/LinearAlgebra.fs

Lines changed: 141 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -84,37 +84,115 @@ type LinearAlgebra =
8484

8585
Matrix(m, n, qData), r
8686

87-
/// <summary>Back substitute to solve R * x = y</summary>
88-
/// <remarks>R is upper triangular</remarks>
87+
/// <summary>Forward substitute to solve L * x = y</summary>
88+
/// <remarks>L is lower triangular</remarks>
89+
static member inline forwardSubstitute<'T when 'T :> Numerics.INumber<'T>
90+
and 'T : (new: unit -> 'T)
91+
and 'T : struct
92+
and 'T : comparison
93+
and 'T :> ValueType>
94+
(L : Matrix<'T>)
95+
(y : Vector<'T>) : Vector<'T> =
96+
97+
let n = L.NumRows
98+
99+
if L.NumCols <> n || y.Length <> n then
100+
invalidArg "dimensions" "L must be square and match the length of y"
101+
102+
let x = Array.zeroCreate<'T> n
103+
let cols = L.NumCols
104+
let lData = L.Data
105+
106+
// Scalar loop for now; the inner sum is a dot product and could be vectorized (SIMD) later
107+
for i = 0 to n - 1 do
108+
let mutable s = y.[i]
109+
let rowOffset = i * cols
110+
for j = 0 to i - 1 do
111+
s <- s - lData.[rowOffset + j] * x.[j]
112+
let diag = lData.[rowOffset + i]
113+
if diag = 'T.Zero then
114+
invalidArg $"Matrix[{i},{i}]" "Diagonal element is zero. Cannot divide."
115+
x.[i] <- s / diag
116+
117+
x
118+
119+
120+
121+
///// <summary>Back substitute to solve R * x = y</summary>
122+
///// <remarks>R is upper triangular</remarks>
123+
//static member inline backSubstitute<'T when 'T :> Numerics.INumber<'T>
124+
// and 'T : (new: unit -> 'T)
125+
// and 'T : struct
126+
// and 'T : comparison
127+
// and 'T :> ValueType>
128+
// (r: Matrix<'T>)
129+
// (y: Vector<'T>) : Vector<'T> =
130+
131+
// let n = r.NumRows
132+
133+
// if r.NumCols <> n || y.Length <> n then
134+
// invalidArg "dimensions" "R must be square and match the length of y"
135+
136+
// let x = Array.zeroCreate<'T> n
137+
138+
// for i = n - 1 downto 0 do
139+
// let mutable sum = y.[i]
140+
// for j = i + 1 to n - 1 do
141+
// sum <- sum - r.[i, j] * x.[j]
142+
// let diag = r.[i, i]
143+
// if diag = 'T.Zero then
144+
// invalidArg $"Matrix[{i},{i}]" "Diagonal element is zero. Cannot divide."
145+
// x.[i] <- sum / diag
146+
147+
// x
148+
149+
89150
static member inline backSubstitute<'T when 'T :> Numerics.INumber<'T>
90151
and 'T : (new: unit -> 'T)
91152
and 'T : struct
92153
and 'T : comparison
93154
and 'T :> ValueType>
94-
(r: Matrix<'T>)
155+
(R: Matrix<'T>)
95156
(y: Vector<'T>) : Vector<'T> =
96157

97-
let n = r.NumRows
158+
let n = R.NumRows
98159

99-
if r.NumCols <> n || y.Length <> n then
160+
if R.NumCols <> n || y.Length <> n then
100161
invalidArg "dimensions" "R must be square and match the length of y"
101162

102163
let x = Array.zeroCreate<'T> n
164+
let cols = R.NumCols
165+
let rData = R.Data // row-major underlying array
103166

167+
// Backward substitution
104168
for i = n - 1 downto 0 do
105169
let mutable sum = y.[i]
106-
for j = i + 1 to n - 1 do
107-
sum <- sum - r.[i, j] * x.[j]
108-
let diag = r.[i, i]
170+
171+
let startJ = i + 1
172+
let len = n - startJ
173+
174+
if len > 0 then
175+
// row slice: R[i, i+1 .. n-1]
176+
let rowOffset = i * cols + startJ
177+
let rowTailSpan = ReadOnlySpan<'T>(rData, rowOffset, len)
178+
179+
// x slice: x[i+1 .. n-1]
180+
let xTailSpan = ReadOnlySpan<'T>(x, startJ, len)
181+
182+
// subtract SIMD dot product
183+
let dot = SpanMath.dot(rowTailSpan, xTailSpan)
184+
sum <- sum - dot
185+
186+
let diag = R.[i, i]
109187
if diag = 'T.Zero then
110-
invalidArg $"r[{i},{i}]" "Diagonal element is zero. Cannot divide."
188+
invalidArg $"Matrix[{i},{i}]" "Diagonal element is zero. Cannot divide."
189+
111190
x.[i] <- sum / diag
112191

113192
x
114193

115194

116195

117-
118196
/// Solve A * x = b for x, where A is a square matrix (n×n) and b is a vector (length n).
119197
static member inline solveLinearQR<'T when 'T :> Numerics.INumber<'T>
120198
and 'T : (new: unit -> 'T)
@@ -249,10 +327,56 @@ type LinearAlgebra =
249327

250328

251329

252-
/// Solve K * x = v (triangular system) in-place, returning a copy of x.
330+
///// Solve K * x = v (triangular system) in-place, returning a copy of x.
331+
///// K must be n×n, v must be length n.
332+
///// isLower = true => forward substitution
333+
///// isLower = false => backward substitution
334+
//static member inline solveTriangularLinearSystem
335+
// (K : Matrix<'T>)
336+
// (v : Vector<'T>)
337+
// (isLower : bool)
338+
// : Vector<'T> =
339+
340+
// let nK, mK = K.NumRows, K.NumCols
341+
// let nV = v.Length
342+
// if nK <> mK || nV <> nK then
343+
// invalidArg (nameof K) "K must be square, and v must match its dimension."
344+
345+
// let x = Array.copy v
346+
// let Kdata = K.Data // row-major flattened
347+
348+
// // Forward or backward substitution
349+
// if isLower then
350+
// // For i in [0..n-1]:
351+
// // x[i] <- ( x[i] - sum_{j=0..i-1}(K[i,j] * x[j]) ) / K[i,i]
352+
// for i = 0 to nK - 1 do
353+
// let mutable s = x.[i]
354+
// let rowOffset = i * nK
355+
// for j = 0 to i - 1 do
356+
// s <- s - (Kdata.[rowOffset + j] * x.[j])
357+
// let diag = Kdata.[rowOffset + i]
358+
// if diag = 'T.Zero then
359+
// invalidArg $"K[{i},{i}]" "Diagonal element is zero. Cannot divide."
360+
// x.[i] <- s / diag
361+
// else
362+
// // For i in [n-1..downto..0]:
363+
// // x[i] <- ( x[i] - sum_{j=i+1..n-1}(K[i,j] * x[j]) ) / K[i,i]
364+
// for i = nK - 1 downto 0 do
365+
// let mutable s = x.[i]
366+
// let rowOffset = i * nK
367+
// for j = i + 1 to nK - 1 do
368+
// s <- s - (Kdata.[rowOffset + j] * x.[j])
369+
// let diag = Kdata.[rowOffset + i]
370+
// if diag = 'T.Zero then
371+
// invalidArg $"K[{i},{i}]" "Diagonal element is zero. Cannot divide."
372+
// x.[i] <- s / diag
373+
374+
// x
375+
376+
/// Solve K * x = v (triangular system), returning a new x.
253377
/// K must be n×n, v must be length n.
254-
/// isLower = true => forward substitution
255-
/// isLower = false => backward substitution
378+
/// isLower = true => forward substitution (K lower triangular)
379+
/// isLower = false => backward substitution (K upper triangular)
256380
static member inline solveTriangularLinearSystem
257381
(K : Matrix<'T>)
258382
(v : Vector<'T>)
@@ -264,36 +388,12 @@ type LinearAlgebra =
264388
if nK <> mK || nV <> nK then
265389
invalidArg (nameof K) "K must be square, and v must match its dimension."
266390

267-
let x = Array.copy v
268-
let Kdata = K.Data // row-major flattened
269-
270-
// Forward or backward substitution
271391
if isLower then
272-
// For i in [0..n-1]:
273-
// x[i] <- ( x[i] - sum_{j=0..i-1}(K[i,j] * x[j]) ) / K[i,i]
274-
for i = 0 to nK - 1 do
275-
let mutable s = x.[i]
276-
let rowOffset = i * nK
277-
for j = 0 to i - 1 do
278-
s <- s - (Kdata.[rowOffset + j] * x.[j])
279-
let diag = Kdata.[rowOffset + i]
280-
if diag = 'T.Zero then
281-
invalidArg $"K[{i},{i}]" "Diagonal element is zero. Cannot divide."
282-
x.[i] <- s / diag
392+
// L * x = v
393+
LinearAlgebra.forwardSubstitute K v
283394
else
284-
// For i in [n-1..downto..0]:
285-
// x[i] <- ( x[i] - sum_{j=i+1..n-1}(K[i,j] * x[j]) ) / K[i,i]
286-
for i = nK - 1 downto 0 do
287-
let mutable s = x.[i]
288-
let rowOffset = i * nK
289-
for j = i + 1 to nK - 1 do
290-
s <- s - (Kdata.[rowOffset + j] * x.[j])
291-
let diag = Kdata.[rowOffset + i]
292-
if diag = 'T.Zero then
293-
invalidArg $"K[{i},{i}]" "Diagonal element is zero. Cannot divide."
294-
x.[i] <- s / diag
295-
296-
x
395+
// R * x = v
396+
LinearAlgebra.backSubstitute K v
297397

298398

299399

tests/FsMath.Tests/LinearAlgebraErrorTestsAdditional.fs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ module SolveTriangularSystemZeroDiagonalTests =
6262
let ex = Assert.Throws<ArgumentException>(fun () ->
6363
LinearAlgebra.solveTriangularLinearSystem L v true |> ignore)
6464
Assert.Contains("Diagonal element is zero", ex.Message)
65-
Assert.Contains("K[0,0]", ex.Message)
65+
Assert.Contains("Matrix[0,0]", ex.Message)
6666

6767
[<Fact>]
6868
let ``solveTriangularLinearSystem throws on zero diagonal - forward sub in middle`` () =
@@ -76,7 +76,7 @@ module SolveTriangularSystemZeroDiagonalTests =
7676
let ex = Assert.Throws<ArgumentException>(fun () ->
7777
LinearAlgebra.solveTriangularLinearSystem L v true |> ignore)
7878
Assert.Contains("Diagonal element is zero", ex.Message)
79-
Assert.Contains("K[2,2]", ex.Message)
79+
Assert.Contains("Matrix[2,2]", ex.Message)
8080

8181
[<Fact>]
8282
let ``solveTriangularLinearSystem throws on zero diagonal - backward sub at position 0`` () =
@@ -89,7 +89,7 @@ module SolveTriangularSystemZeroDiagonalTests =
8989
let ex = Assert.Throws<ArgumentException>(fun () ->
9090
LinearAlgebra.solveTriangularLinearSystem U v false |> ignore)
9191
Assert.Contains("Diagonal element is zero", ex.Message)
92-
Assert.Contains("K[0,0]", ex.Message)
92+
Assert.Contains("Matrix[0,0]", ex.Message)
9393

9494
[<Fact>]
9595
let ``solveTriangularLinearSystem throws on zero diagonal - backward sub in middle`` () =
@@ -103,7 +103,7 @@ module SolveTriangularSystemZeroDiagonalTests =
103103
let ex = Assert.Throws<ArgumentException>(fun () ->
104104
LinearAlgebra.solveTriangularLinearSystem U v false |> ignore)
105105
Assert.Contains("Diagonal element is zero", ex.Message)
106-
Assert.Contains("K[1,1]", ex.Message)
106+
Assert.Contains("Matrix[1,1]", ex.Message)
107107

108108
[<Fact>]
109109
let ``solveTriangularLinearSystem throws on zero diagonal - backward sub at last`` () =
@@ -115,7 +115,7 @@ module SolveTriangularSystemZeroDiagonalTests =
115115
let ex = Assert.Throws<ArgumentException>(fun () ->
116116
LinearAlgebra.solveTriangularLinearSystem U v false |> ignore)
117117
Assert.Contains("Diagonal element is zero", ex.Message)
118-
Assert.Contains("K[1,1]", ex.Message)
118+
Assert.Contains("Matrix[1,1]", ex.Message)
119119

120120

121121
module CholeskyNonSquareTests =

tests/FsMath.Tests/VectorOpsCoverageTests.fs

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -427,37 +427,3 @@ module VectorOpsCoverageTests =
427427
let result2 = evalQ <@ (scalar .* v1) .+ (scalar .* v2) @>
428428
floatArrayClose result1 result2 1e-10
429429

430-
// ========================================
431-
// Quotation tests for @ operator
432-
// Note: @ operator currently calls Power.Invoke (not Dot.Invoke despite the comment)
433-
// ========================================
434-
435-
[<Fact>]
436-
let ``operator @_Q: applies power operation`` () =
437-
let v = [| 2.0; 3.0; 4.0 |]
438-
let result = evalQ <@ v @ 2.0 @>
439-
floatArrayClose [| 4.0; 9.0; 16.0 |] result 1e-10
440-
441-
[<Fact>]
442-
let ``operator @_Q: fractional power`` () =
443-
let v = [| 4.0; 9.0; 16.0 |]
444-
let result = evalQ <@ v @ 0.5 @>
445-
floatArrayClose [| 2.0; 3.0; 4.0 |] result 1e-10
446-
447-
[<Fact>]
448-
let ``operator @_Q: negative power`` () =
449-
let v = [| 2.0; 4.0; 5.0 |]
450-
let result = evalQ <@ v @ -1.0 @>
451-
floatArrayClose [| 0.5; 0.25; 0.2 |] result 1e-10
452-
453-
[<Fact>]
454-
let ``operator @_Q: power of zero`` () =
455-
let v = [| 2.0; 3.0; 4.0 |]
456-
let result = evalQ <@ v @ 0.0 @>
457-
floatArrayClose [| 1.0; 1.0; 1.0 |] result 1e-10
458-
459-
[<Fact>]
460-
let ``operator @_Q: integer power`` () =
461-
let v = [| 2.0; 3.0; 4.0 |]
462-
let result = evalQ <@ v @ 3.0 @>
463-
floatArrayClose [| 8.0; 27.0; 64.0 |] result 1e-10

tests/FsMath.Tests/VectorOpsTests.fs

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -302,48 +302,3 @@ module VectorOpsTests =
302302
let result2 = (scalar .* v1) .+ (scalar .* v2)
303303
floatArrayClose result1 result2 1e-10
304304

305-
// =============================================
306-
// @ Operator Tests (Power operator - NOTE: Comment in VectorOps.fs incorrectly says "Dot product")
307-
// =============================================
308-
309-
[<Fact>]
310-
let ``@ operator applies power operation (float)`` () =
311-
// The @ operator currently calls Power.Invoke, not Dot.Invoke
312-
// despite the comment saying "// Dot product ( @ )"
313-
let v = [| 2.0; 3.0; 4.0 |]
314-
let power = 2.0
315-
let result = v @ power
316-
let expected = [| 4.0; 9.0; 16.0 |]
317-
floatArrayClose expected result 1e-10
318-
319-
[<Fact>]
320-
let ``@ operator with fractional power (float)`` () =
321-
let v = [| 4.0; 9.0; 16.0 |]
322-
let power = 0.5
323-
let result = v @ power
324-
let expected = [| 2.0; 3.0; 4.0 |]
325-
floatArrayClose expected result 1e-10
326-
327-
[<Fact>]
328-
let ``@ operator with negative power (float)`` () =
329-
let v = [| 2.0; 4.0; 5.0 |]
330-
let power = -1.0
331-
let result = v @ power
332-
let expected = [| 0.5; 0.25; 0.2 |]
333-
floatArrayClose expected result 1e-10
334-
335-
[<Fact>]
336-
let ``@ operator with zero power returns ones (float)`` () =
337-
let v = [| 2.0; 3.0; 4.0 |]
338-
let power = 0.0
339-
let result = v @ power
340-
let expected = [| 1.0; 1.0; 1.0 |]
341-
floatArrayClose expected result 1e-10
342-
343-
[<Fact>]
344-
let ``@ operator with integer power (float)`` () =
345-
let v = [| 2.0; 3.0; 4.0 |]
346-
let power = 3.0
347-
let result = v @ power
348-
let expected = [| 8.0; 27.0; 64.0 |]
349-
floatArrayClose expected result 1e-10

0 commit comments

Comments
 (0)