From 26e1dc3832505bdd66162d6eba5b1a373a60890e Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Nov 2025 18:04:53 +0000 Subject: [PATCH 01/10] Added operator for performing `CrossProduct` --- .../LinearAlgebra/CrossProduct.cs | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs b/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs new file mode 100644 index 00000000..7627344f --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs @@ -0,0 +1,34 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Computes the cross product of 2 tensors. +/// +[Combinator] +[Description("Computes the cross product of 2 tensors.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class CrossProduct +{ + /// + /// The dimension to perform the operation. + /// + public long Dimension { get; set; } = -1; + + /// + /// Computes the cross product of 2 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return cross(value.Item1, value.Item2, Dimension); + }); + } +} \ No newline at end of file From 383de204d7e29e1689bbb62bc1e9d44ed73afacd Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Nov 2025 18:05:38 +0000 Subject: [PATCH 02/10] Added operator `LeastSquaresSolve` for solving systems of linear equations --- .../LinearAlgebra/LeastSquaresSolve.cs | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs new file mode 100644 index 00000000..d5560708 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs @@ -0,0 +1,69 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the solution to the least squares and least norm problems for a full rank matrix A of size m×n and a matrix B of size m×k. +/// +[Combinator] +[Description("Computes the solution to the system tensordot(A, X) = B.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class LeastSquaresSolve +{ + /// + /// Computes the solution to the least squares and least norm problems for a full rank matrix A of size m×n and a matrix B of size m×k. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + var (solution, residuals, rank, singularValues) = linalg.lstsq(value.Item1, value.Item2); + return new LeastSquaresResult( + solution, + residuals, + rank, + singularValues); + }); + } + + /// + /// Represents the result of solving of linear equations using the least squares method. + /// + /// + /// + /// + /// + public readonly struct LeastSquaresResult( + Tensor solution, + Tensor residuals, + Tensor rank, + Tensor singularValues + ) + { + /// + /// The solution to the system of equations. + /// + public Tensor Solution => solution; + + /// + /// The residual error. + /// + public Tensor Residuals => residuals; + + /// + /// The effective rank of the solution. + /// + public Tensor Rank => rank; + + /// + /// The singular values of the solution. + /// + public Tensor SingularValues => singularValues; + } +} \ No newline at end of file From 5bccb087bc8a1eb283027c9719899fd000a05291 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Nov 2025 18:06:24 +0000 Subject: [PATCH 03/10] Added operator `TensorSolve` to compute a tensor solution to the problem AX=B --- .../LinearAlgebra/TensorSolve.cs | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs new file mode 100644 index 00000000..1f1be9a5 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs @@ -0,0 +1,35 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the solution X to the system tensordot(A, X) = B. +/// +[Combinator] +[Description("Computes the solution to the system tensordot(A, X) = B.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class TensorSolve +{ + /// + /// The dimension to perform the operation. + /// + [TypeConverter(typeof(UnidimensionalArrayConverter))] + public long[] Dimensions { get; set; } = []; + + /// + /// Computes the cross product of 2 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return tensorsolve(value.Item1, value.Item2, Dimensions); + }); + } +} \ No newline at end of file From dacb4fadb9c6c7c821e47b8e04039d38ecfa868a Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Nov 2025 18:08:09 +0000 Subject: [PATCH 04/10] Updated `SingularValueDecomposition` to return a struct output instead of tuple --- .../SingularValueDecomposition.cs | 35 +++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs index c440316f..ba22652f 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs @@ -23,9 +23,40 @@ public class SingularValueDecomposition /// /// /// - public IObservable> Process(IObservable source) + public IObservable Process(IObservable source) { - return source.Select(tensor => linalg.svd(tensor, fullMatrices: FullMatrices).ToTuple()); + return source.Select(tensor => + { + var (U, S, Vh) = linalg.svd(tensor, fullMatrices: FullMatrices); + return new SvdResult(U, S, Vh); + }); + } + + /// + /// Represents the result of a singular value decomposition. + /// + /// + /// + /// + public readonly struct SvdResult( + Tensor u, + Tensor s, + Tensor vh) + { + /// + /// The U tensor. + /// + public Tensor U => u; + + /// + /// The singular values. + /// + public Tensor S => s; + + /// + /// The Vh tensor. + /// + public Tensor Vh => vh; } } } \ No newline at end of file From b44b79d8643510b986c2449a523ae8b7c677b018 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 13 Nov 2025 18:09:17 +0000 Subject: [PATCH 05/10] Added operator to support chaining matrix multiplication --- .../LinearAlgebra/MatrixMultiply.cs | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs new file mode 100644 index 00000000..94952c45 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs @@ -0,0 +1,126 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that performs matrix multiplication with 2 or more tensors. +/// +[Combinator] +[Description("Performs matrix multiplication with 2 or more tensors.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class MatrixMultiply +{ + /// + /// Performs matrix multiplication with 2 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return matmul(input.Item1, input.Item2); + }); + } + + /// + /// Performs matrix multiplication with 3 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return multi_dot([input.Item1, input.Item2, input.Item3]); + }); + } + + /// + /// Performs matrix multiplication with 4 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4]); + }); + } + + /// + /// Performs matrix multiplication with 5 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5]); + }); + } + + /// + /// Performs matrix multiplication with 6 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5, input.Item6]); + }); + } + + /// + /// Performs matrix multiplication with 7 tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => + { + return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5, input.Item6, input.Item7]); + }); + } + + /// + /// Performs matrix multiplication with an array of tensors. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(multi_dot); + } + + /// + /// Performs matrix multiplication with a list of tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(multi_dot); + } + + /// + /// Performs matrix multiplication with an enumerable of tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(input => multi_dot([.. input])); + } +} \ No newline at end of file From 35ba73fea54e2f7821c2ec1f43555d3084094d44 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 12 Dec 2025 13:53:07 +0000 Subject: [PATCH 06/10] Improved XML documentation --- .../LinearAlgebra/CholeskyDecomposition.cs | 29 ++++--- .../LinearAlgebra/CrossProduct.cs | 4 +- .../LinearAlgebra/Determinant.cs | 29 ++++--- .../LinearAlgebra/EigenvalueDecomposition.cs | 42 +++++++--- src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs | 29 ++++--- .../LinearAlgebra/LeastSquaresSolve.cs | 36 +++------ .../LinearAlgebra/MatrixMultiply.cs | 24 +++--- src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs | 53 ++++++------ .../LinearAlgebra/SignLogDeterminant.cs | 44 ++++++---- .../SingularValueDecomposition.cs | 81 +++++++++---------- .../LinearAlgebra/TensorSolve.cs | 7 +- 11 files changed, 196 insertions(+), 182 deletions(-) diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs index 92369615..3e901641 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs @@ -3,24 +3,23 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LinearAlgebra +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. +/// +[Combinator] +[Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class CholeskyDecomposition { /// /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. /// - [Combinator] - [Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class CholeskyDecomposition + /// + /// + public IObservable Process(IObservable source) { - /// - /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(linalg.cholesky); - } + return source.Select(linalg.cholesky); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs b/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs index 7627344f..dd4be273 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs @@ -7,7 +7,7 @@ namespace Bonsai.ML.Torch.LinearAlgebra; /// -/// Computes the cross product of 2 tensors. +/// Represents an operator that computes the cross product of 2 tensors. /// [Combinator] [Description("Computes the cross product of 2 tensors.")] @@ -31,4 +31,4 @@ public IObservable Process(IObservable> source) return cross(value.Item1, value.Item2, Dimension); }); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs index 475651d0..4a5cb78d 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs @@ -3,24 +3,23 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LinearAlgebra +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the determinant of a square matrix. +/// +[Combinator] +[Description("Computes the determinant of a square matrix.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Determinant { /// /// Computes the determinant of a square matrix. /// - [Combinator] - [Description("Computes the determinant of a square matrix.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Determinant + /// + /// + public IObservable Process(IObservable source) { - /// - /// Computes the determinant of a square matrix. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(linalg.det); - } + return source.Select(linalg.det); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs index 6784b1bf..d4029fc7 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs @@ -3,24 +3,40 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LinearAlgebra +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the eigenvalue decomposition of a square matrix if it exists. +/// +[Combinator] +[Description("Computes the eigenvalue decomposition of a square matrix if it exists.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class EigenvalueDecomposition { /// /// Computes the eigenvalue decomposition of a square matrix if it exists. /// - [Combinator] - [Description("Computes the eigenvalue decomposition of a square matrix if it exists.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class EigenvalueDecomposition + /// + /// + public IObservable Process(IObservable source) { + return source.Select(tensor => new EigenDecompositionResult(linalg.eig(tensor))); + } + + /// + /// Represents the result of an eigenvalue decomposition. + /// + /// The tuple containing the eigenvalues and eigenvectors. + public readonly struct EigenDecompositionResult((Tensor eigenvalues, Tensor eigenvectors) result) + { + /// + /// Gets the eigenvalues of the decomposition. + /// + public Tensor Eigenvalues => result.eigenvalues; + /// - /// Computes the eigenvalue decomposition of a square matrix if it exists. + /// Gets the eigenvectors of the decomposition. /// - /// - /// - public IObservable> Process(IObservable source) - { - return source.Select(tensor => linalg.eig(tensor).ToTuple()); - } + public Tensor Eigenvectors => result.eigenvectors; } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs index 1d879b5b..530979fc 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs @@ -4,24 +4,23 @@ using static TorchSharp.torch; using static TorchSharp.torch.linalg; -namespace Bonsai.ML.Torch.LinearAlgebra +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the inverse of the input matrix. +/// +[Combinator] +[Description("Computes the inverse of the input matrix.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Inverse { /// /// Computes the inverse of the input matrix. /// - [Combinator] - [Description("Computes the inverse of the input matrix.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Inverse + /// The input matrix to invert. + /// The inverse of the input matrix. + public IObservable Process(IObservable source) { - /// - /// Computes the inverse of the input matrix. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(inv); - } + return source.Select(inv); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs index d5560708..aaa058c3 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs @@ -2,68 +2,56 @@ using System.ComponentModel; using System.Reactive.Linq; using static TorchSharp.torch; -using static TorchSharp.torch.linalg; namespace Bonsai.ML.Torch.LinearAlgebra; /// -/// Represents an operator that computes the solution to the least squares and least norm problems for a full rank matrix A of size m×n and a matrix B of size m×k. +/// Represents an operator that computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k. /// [Combinator] -[Description("Computes the solution to the system tensordot(A, X) = B.")] +[Description("Computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k.")] [WorkflowElementCategory(ElementCategory.Transform)] public class LeastSquaresSolve { /// - /// Computes the solution to the least squares and least norm problems for a full rank matrix A of size m×n and a matrix B of size m×k. + /// Computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k. /// /// /// public IObservable Process(IObservable> source) { - return source.Select(value => - { - var (solution, residuals, rank, singularValues) = linalg.lstsq(value.Item1, value.Item2); - return new LeastSquaresResult( - solution, - residuals, - rank, - singularValues); - }); + return source.Select(value => new LeastSquaresResult(linalg.lstsq(value.Item1, value.Item2))); } /// /// Represents the result of solving of linear equations using the least squares method. /// - /// - /// - /// - /// - public readonly struct LeastSquaresResult( + /// + public readonly struct LeastSquaresResult(( Tensor solution, Tensor residuals, Tensor rank, Tensor singularValues - ) + ) result) { /// /// The solution to the system of equations. /// - public Tensor Solution => solution; + public Tensor Solution => result.solution; /// /// The residual error. /// - public Tensor Residuals => residuals; + public Tensor Residuals => result.residuals; /// /// The effective rank of the solution. /// - public Tensor Rank => rank; + public Tensor Rank => result.rank; /// /// The singular values of the solution. /// - public Tensor SingularValues => singularValues; + public Tensor SingularValues => result.singularValues; } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs index 94952c45..30a25c8a 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs @@ -9,15 +9,15 @@ namespace Bonsai.ML.Torch.LinearAlgebra; /// -/// Represents an operator that performs matrix multiplication with 2 or more tensors. +/// Represents an operator that performs matrix multiplication of 2 or more tensors. /// [Combinator] -[Description("Performs matrix multiplication with 2 or more tensors.")] +[Description("Performs matrix multiplication of 2 or more tensors.")] [WorkflowElementCategory(ElementCategory.Transform)] public class MatrixMultiply { /// - /// Performs matrix multiplication with 2 tensors. + /// Performs matrix multiplication of 2 tensors. /// /// /// @@ -30,7 +30,7 @@ public IObservable Process(IObservable> source) } /// - /// Performs matrix multiplication with 3 tensors. + /// Performs matrix multiplication of 3 tensors. /// /// /// @@ -43,7 +43,7 @@ public IObservable Process(IObservable> so } /// - /// Performs matrix multiplication with 4 tensors. + /// Performs matrix multiplication of 4 tensors. /// /// /// @@ -56,7 +56,7 @@ public IObservable Process(IObservable - /// Performs matrix multiplication with 5 tensors. + /// Performs matrix multiplication of 5 tensors. /// /// /// @@ -69,7 +69,7 @@ public IObservable Process(IObservable - /// Performs matrix multiplication with 6 tensors. + /// Performs matrix multiplication of 6 tensors. /// /// /// @@ -82,7 +82,7 @@ public IObservable Process(IObservable - /// Performs matrix multiplication with 7 tensors. + /// Performs matrix multiplication of 7 tensors. /// /// /// @@ -95,7 +95,7 @@ public IObservable Process(IObservable - /// Performs matrix multiplication with an array of tensors. + /// Performs matrix multiplication of an array of tensors. /// /// /// @@ -105,7 +105,7 @@ public IObservable Process(IObservable source) } /// - /// Performs matrix multiplication with a list of tensors. + /// Performs matrix multiplication of a list of tensors. /// /// /// @@ -115,7 +115,7 @@ public IObservable Process(IObservable> source) } /// - /// Performs matrix multiplication with an enumerable of tensors. + /// Performs matrix multiplication of an enumerable of tensors. /// /// /// @@ -123,4 +123,4 @@ public IObservable Process(IObservable> source) { return source.Select(input => multi_dot([.. input])); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs index eb18920d..510a7b59 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs @@ -3,35 +3,36 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LinearAlgebra +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes a vector or matrix norm. +/// +[Combinator] +[Description("Computes a vector or matrix norm.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class Norm { /// - /// Computes a vector or matrix norm. + /// The dimensions along which to compute the norm. /// - [Combinator] - [Description("Computes a vector or matrix norm.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class Norm - { - /// - /// The dimensions along which to compute the norm. - /// - [TypeConverter(typeof(UnidimensionalArrayConverter))] - public long[] Dimensions { get; set; } = null; + [TypeConverter(typeof(UnidimensionalArrayConverter))] + [Description("The dimensions along which to compute the norm.")] + public long[] Dimensions { get; set; } = null; - /// - /// If true, the reduced dimensions are retained in the result as dimensions with size one. - /// - public bool Keepdim { get; set; } = false; + /// + /// If true, the reduced dimensions are retained in the result as dimensions with size one. + /// + [Description("If true, the reduced dimensions are retained in the result as dimensions with size one.")] + public bool Keepdim { get; set; } = false; - /// - /// Computes a matrix norm. - /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => linalg.norm(tensor, dims: Dimensions, keepdim: Keepdim)); - } + /// + /// Computes a matrix norm. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => linalg.norm(tensor, dims: Dimensions, keepdim: Keepdim)); } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs index 6d29c910..2b9eb66d 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs @@ -3,24 +3,40 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LinearAlgebra +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the sign and natural logarithm of the absolute value of the determinant of a square matrix. +/// +[Combinator] +[Description("Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class SignLogDeterminant { /// - /// Computes the determinant of a square matrix. + /// Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix. /// - [Combinator] - [Description("Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class SignLogDeterminant + /// + /// + public IObservable Process(IObservable source) { + return source.Select(result => new SignLogDeterminantResult(linalg.slogdet(result))); + } + + /// + /// Represents the result of computing the sign and natural logarithm of the absolute value of the determinant. + /// + /// + public readonly struct SignLogDeterminantResult((Tensor sign, Tensor logabsdet) result) + { + /// + /// Gets the sign of the determinant. + /// + public Tensor Sign => result.sign; + /// - /// Computes the determinant of a square matrix. + /// Gets the natural logarithm of the absolute value of the determinant. /// - /// - /// - public IObservable<(Tensor, Tensor)> Process(IObservable source) - { - return source.Select(linalg.slogdet); - } + public Tensor LogAbsDeterminant => result.logabsdet; } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs index ba22652f..25ac9581 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs @@ -3,60 +3,55 @@ using System.Reactive.Linq; using static TorchSharp.torch; -namespace Bonsai.ML.Torch.LinearAlgebra +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the singular value decomposition (SVD) of a matrix. +/// +[Combinator] +[Description("Computes the singular value decomposition (SVD) of a matrix.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class SingularValueDecomposition { + /// + /// Whether to compute the full or reduced SVD. + /// + [Description("Whether to compute the full or reduced SVD.")] + public bool FullMatrices { get; set; } = false; + /// /// Computes the singular value decomposition (SVD) of a matrix. /// - [Combinator] - [Description("Computes the singular value decomposition (SVD) of a matrix.")] - [WorkflowElementCategory(ElementCategory.Transform)] - public class SingularValueDecomposition + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => new SingularValueDecompositionResult(linalg.svd(tensor, fullMatrices: FullMatrices))); + } + + /// + /// Represents the result of a singular value decomposition. + /// + /// + public readonly struct SingularValueDecompositionResult(( + Tensor u, + Tensor s, + Tensor vh + ) result) { /// - /// Whether to compute the full or reduced SVD. + /// The U tensor. /// - public bool FullMatrices { get; set; } = false; + public Tensor U => result.u; /// - /// Computes the singular value decomposition (SVD) of a matrix. + /// The singular values. /// - /// - /// - public IObservable Process(IObservable source) - { - return source.Select(tensor => - { - var (U, S, Vh) = linalg.svd(tensor, fullMatrices: FullMatrices); - return new SvdResult(U, S, Vh); - }); - } + public Tensor S => result.s; /// - /// Represents the result of a singular value decomposition. + /// The Vh tensor. /// - /// - /// - /// - public readonly struct SvdResult( - Tensor u, - Tensor s, - Tensor vh) - { - /// - /// The U tensor. - /// - public Tensor U => u; - - /// - /// The singular values. - /// - public Tensor S => s; - - /// - /// The Vh tensor. - /// - public Tensor Vh => vh; - } + public Tensor Vh => result.vh; } -} \ No newline at end of file +} diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs index 1f1be9a5..52b5f5d3 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs @@ -15,13 +15,14 @@ namespace Bonsai.ML.Torch.LinearAlgebra; public class TensorSolve { /// - /// The dimension to perform the operation. + /// The dimensions to perform the operation. /// [TypeConverter(typeof(UnidimensionalArrayConverter))] + [Description("The dimensions to perform the operation.")] public long[] Dimensions { get; set; } = []; /// - /// Computes the cross product of 2 tensors. + /// Computes the solution to the system tensordot(A, X) = B. /// /// /// @@ -32,4 +33,4 @@ public IObservable Process(IObservable> source) return tensorsolve(value.Item1, value.Item2, Dimensions); }); } -} \ No newline at end of file +} From 854e7f7b5562a06e02d9e5075b10f9597922bcc7 Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 19 Dec 2025 13:52:37 +0000 Subject: [PATCH 07/10] Refactored `MatrixMultiply` overload with a tuple of 2 tensors to use object instead of static method --- src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs index 30a25c8a..0703dc3d 100644 --- a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs +++ b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs @@ -23,10 +23,7 @@ public class MatrixMultiply /// public IObservable Process(IObservable> source) { - return source.Select(input => - { - return matmul(input.Item1, input.Item2); - }); + return source.Select(input =>input.Item1.matmul(input.Item2)); } /// From 5f7c67071ef1f7a74d523dd4bcb93264544759ac Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 19 Dec 2025 13:53:19 +0000 Subject: [PATCH 08/10] Added operator to compute the rank of a matrix --- .../LinearAlgebra/MatrixRank.cs | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/MatrixRank.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixRank.cs b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixRank.cs new file mode 100644 index 00000000..84eb4c37 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixRank.cs @@ -0,0 +1,46 @@ +using System; +using System.Collections; +using System.Collections.Generic; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the numerical rank of a matrix. +/// +[Combinator] +[Description("Computes the numerical rank of a matrix.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class MatrixRank +{ + /// + /// Gets or sets the absolute tolerance for singular values to be considered non-zero. + /// + [Description("The absolute tolerance for singular values to be considered non-zero.")] + public double? AbsoluteTolerance { get; set; } = null; + + /// + /// Gets or sets the relative tolerance for singular values to be considered non-zero. + /// + [Description("The relative tolerance for singular values to be considered non-zero.")] + public double? RelativeTolerance { get; set; } = null; + + /// + /// Gets or sets a value indicating whether to treat the input matrix as Hermitian if input is complex or symmetric if real. + /// + [Description("Indicates whether to treat the input matrix as Hermitian if input is complex or symmetric if real.")] + public bool Hermitian { get; set; } = false; + + /// + /// Computes the numerical rank of a matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(input => matrix_rank(input, atol: AbsoluteTolerance, rtol: RelativeTolerance, hermitian: Hermitian)); + } +} From 62495e8551056c7dc323cc090e2397550b5d84eb Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 19 Dec 2025 13:53:44 +0000 Subject: [PATCH 09/10] Added operator to compute the QR decomposition of a matrix --- .../LinearAlgebra/QRDecomposition.cs | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/QRDecomposition.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/QRDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/QRDecomposition.cs new file mode 100644 index 00000000..eb01aaf3 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/QRDecomposition.cs @@ -0,0 +1,50 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the QR decomposition of a matrix. +/// +[Combinator] +[Description("Computes the QR decomposition of a matrix.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class QRDecomposition +{ + /// + /// Gets or sets the mode of the QR decomposition. + /// + [Description("The mode of the QR decomposition.")] + public QRMode Mode { get; set; } = QRMode.Reduced; + + /// + /// Computes the QR decomposition of a matrix. + /// + /// + /// + public IObservable Process(IObservable source) + { + return source.Select(tensor => new QRDecompositionResult(qr(tensor, mode: Mode))); + } + + /// + /// Represents the result of a QR decomposition. + /// + /// + public readonly struct QRDecompositionResult((Tensor Q, Tensor R) result) + { + /// + /// Gets the orthogonal matrix Q. + /// + public Tensor Q => result.Q; + + /// + /// Gets the upper triangular matrix R. + /// + public Tensor R => result.R; + } + +} From 838a90dc1a65ffd68a84e038c5e97702b5e7cc0f Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Fri, 19 Dec 2025 13:54:10 +0000 Subject: [PATCH 10/10] Added an operator to solve a triangular system of equations --- .../LinearAlgebra/TriangularSolve.cs | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) create mode 100644 src/Bonsai.ML.Torch/LinearAlgebra/TriangularSolve.cs diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/TriangularSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/TriangularSolve.cs new file mode 100644 index 00000000..8f7b1f54 --- /dev/null +++ b/src/Bonsai.ML.Torch/LinearAlgebra/TriangularSolve.cs @@ -0,0 +1,47 @@ +using System; +using System.ComponentModel; +using System.Reactive.Linq; +using static TorchSharp.torch; +using static TorchSharp.torch.linalg; + +namespace Bonsai.ML.Torch.LinearAlgebra; + +/// +/// Represents an operator that computes the solution to a triangular system of linear equations with a unique solution. +/// +[Combinator] +[Description("Computes the solution to a triangular system of linear equations with a unique solution.")] +[WorkflowElementCategory(ElementCategory.Transform)] +public class TriangularSolve +{ + /// + /// Gets or sets a value indicating whether the first matrix is upper triangular. + /// + [Description("Indicates whether the first matrix is upper triangular.")] + public bool Upper { get; set; } = true; + + /// + /// Gets or sets a value indicating whether to solve the system with the first matrix on the left or right (AX = B or XA = B). + /// + [Description("Indicates whether to solve the system with the first matrix on the left or right (AX = B or XA = B).")] + public bool Left { get; set; } = true; + + /// + /// Gets or sets a value indicating whether the first matrix has a unit diagonal, i.e., all diagonal elements are assumed to be 1. + /// + [Description("Indicates whether the first matrix has a unit diagonal, i.e., all diagonal elements are assumed to be 1.")] + public bool UnitDiagonal { get; set; } = false; + + /// + /// Computes the solution to a triangular system of linear equations for each pair of input tensors. + /// + /// + /// + public IObservable Process(IObservable> source) + { + return source.Select(value => + { + return solve_triangular(value.Item1, value.Item2, upper: Upper, left: Left, unitriangular: UnitDiagonal); + }); + } +}