Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 14 additions & 15 deletions src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,23 @@
using System.Reactive.Linq;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.LinearAlgebra
namespace Bonsai.ML.Torch.LinearAlgebra;

/// <summary>
/// Represents an operator that computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
/// </summary>
[Combinator]
[Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class CholeskyDecomposition
{
/// <summary>
/// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
/// </summary>
[Combinator]
[Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class CholeskyDecomposition
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
/// <summary>
/// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
return source.Select(linalg.cholesky);
}
return source.Select(linalg.cholesky);
}
}
}
34 changes: 34 additions & 0 deletions src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using static TorchSharp.torch;
using static TorchSharp.torch.linalg;

namespace Bonsai.ML.Torch.LinearAlgebra;

/// <summary>
/// Represents an operator that computes the cross product of 2 tensors.
/// </summary>
[Combinator]
[Description("Computes the cross product of 2 tensors.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class CrossProduct
{
/// <summary>
/// The dimension to perform the operation.
/// </summary>
public long Dimension { get; set; } = -1;

/// <summary>
/// Computes the cross product of 2 tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tuple<Tensor, Tensor>> source)
{
return source.Select(value =>
{
return cross(value.Item1, value.Item2, Dimension);
});
}
}
29 changes: 14 additions & 15 deletions src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,23 @@
using System.Reactive.Linq;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.LinearAlgebra
namespace Bonsai.ML.Torch.LinearAlgebra;

/// <summary>
/// Represents an operator that computes the determinant of a square matrix.
/// </summary>
[Combinator]
[Description("Computes the determinant of a square matrix.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class Determinant
{
/// <summary>
/// Computes the determinant of a square matrix.
/// </summary>
[Combinator]
[Description("Computes the determinant of a square matrix.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class Determinant
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
/// <summary>
/// Computes the determinant of a square matrix.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
return source.Select(linalg.det);
}
return source.Select(linalg.det);
}
}
}
42 changes: 29 additions & 13 deletions src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,40 @@
using System.Reactive.Linq;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.LinearAlgebra
namespace Bonsai.ML.Torch.LinearAlgebra;

/// <summary>
/// Represents an operator that computes the eigenvalue decomposition of a square matrix if it exists.
/// </summary>
[Combinator]
[Description("Computes the eigenvalue decomposition of a square matrix if it exists.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class EigenvalueDecomposition
{
/// <summary>
/// Computes the eigenvalue decomposition of a square matrix if it exists.
/// </summary>
[Combinator]
[Description("Computes the eigenvalue decomposition of a square matrix if it exists.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class EigenvalueDecomposition
/// <param name="source"></param>
/// <returns></returns>
public IObservable<EigenDecompositionResult> Process(IObservable<Tensor> source)
{
return source.Select(tensor => new EigenDecompositionResult(linalg.eig(tensor)));
}

/// <summary>
/// Represents the result of an eigenvalue decomposition.
/// </summary>
/// <param name="result">The tuple containing the eigenvalues and eigenvectors.</param>
public readonly struct EigenDecompositionResult((Tensor eigenvalues, Tensor eigenvectors) result)
{
/// <summary>
/// Gets the eigenvalues of the decomposition.
/// </summary>
public Tensor Eigenvalues => result.eigenvalues;

/// <summary>
/// Computes the eigenvalue decomposition of a square matrix if it exists.
/// Gets the eigenvectors of the decomposition.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tuple<Tensor, Tensor>> Process(IObservable<Tensor> source)
{
return source.Select(tensor => linalg.eig(tensor).ToTuple());
}
public Tensor Eigenvectors => result.eigenvectors;
}
}
}
29 changes: 14 additions & 15 deletions src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,23 @@
using static TorchSharp.torch;
using static TorchSharp.torch.linalg;

namespace Bonsai.ML.Torch.LinearAlgebra
namespace Bonsai.ML.Torch.LinearAlgebra;

/// <summary>
/// Represents an operator that computes the inverse of the input matrix.
/// </summary>
[Combinator]
[Description("Computes the inverse of the input matrix.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class Inverse
{
/// <summary>
/// Computes the inverse of the input matrix.
/// </summary>
[Combinator]
[Description("Computes the inverse of the input matrix.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class Inverse
/// <param name="source">The input matrix to invert.</param>
/// <returns>The inverse of the input matrix.</returns>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
/// <summary>
/// Computes the inverse of the input matrix.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tensor> source)
{
return source.Select(inv);
}
return source.Select(inv);
}
}
}
57 changes: 57 additions & 0 deletions src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
using System;
using System.ComponentModel;
using System.Reactive.Linq;
using static TorchSharp.torch;

namespace Bonsai.ML.Torch.LinearAlgebra;

/// <summary>
/// Represents an operator that computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k.
/// </summary>
[Combinator]
[Description("Computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class LeastSquaresSolve
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
public class LeastSquaresSolve
public class LeastSquares

It feels off to have the verb at the end. I would either move it to the beginning, i.e. SolveLeastSquares, or follow the linalg convention and drop it, i.e. just call this operator LeastSquares.

{
/// <summary>
/// Computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<LeastSquaresResult> Process(IObservable<Tuple<Tensor, Tensor>> source)
{
return source.Select(value => new LeastSquaresResult(linalg.lstsq(value.Item1, value.Item2)));
}

/// <summary>
/// Represents the result of solving of linear equations using the least squares method.
/// </summary>
/// <param name="result"></param>
public readonly struct LeastSquaresResult((
Tensor solution,
Tensor residuals,
Tensor rank,
Tensor singularValues
) result)
{
/// <summary>
/// The solution to the system of equations.
/// </summary>
public Tensor Solution => result.solution;

/// <summary>
/// The residual error.
/// </summary>
public Tensor Residuals => result.residuals;

/// <summary>
/// The effective rank of the solution.
/// </summary>
public Tensor Rank => result.rank;

/// <summary>
/// The singular values of the solution.
/// </summary>
public Tensor SingularValues => result.singularValues;
}
}
123 changes: 123 additions & 0 deletions src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
using System;
using System.Collections;
using System.Collections.Generic;
using System.ComponentModel;
using System.Reactive.Linq;
using static TorchSharp.torch;
using static TorchSharp.torch.linalg;

namespace Bonsai.ML.Torch.LinearAlgebra;

/// <summary>
/// Represents an operator that performs matrix multiplication of 2 or more tensors.
/// </summary>
[Combinator]
[Description("Performs matrix multiplication of 2 or more tensors.")]
[WorkflowElementCategory(ElementCategory.Transform)]
public class MatrixMultiply
{
/// <summary>
/// Performs matrix multiplication of 2 tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tuple<Tensor, Tensor>> source)
{
return source.Select(input =>input.Item1.matmul(input.Item2));
}

/// <summary>
/// Performs matrix multiplication of 3 tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tuple<Tensor, Tensor, Tensor>> source)
{
return source.Select(input =>
{
return multi_dot([input.Item1, input.Item2, input.Item3]);
});
}

/// <summary>
/// Performs matrix multiplication of 4 tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tuple<Tensor, Tensor, Tensor, Tensor>> source)
{
return source.Select(input =>
{
return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4]);
});
}

/// <summary>
/// Performs matrix multiplication of 5 tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tuple<Tensor, Tensor, Tensor, Tensor, Tensor>> source)
{
return source.Select(input =>
{
return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5]);
});
}

/// <summary>
/// Performs matrix multiplication of 6 tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>> source)
{
return source.Select(input =>
{
return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5, input.Item6]);
});
}

/// <summary>
/// Performs matrix multiplication of 7 tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>> source)
{
return source.Select(input =>
{
return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5, input.Item6, input.Item7]);
});
}

/// <summary>
/// Performs matrix multiplication of an array of tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<Tensor[]> source)
{
return source.Select(multi_dot);
}

/// <summary>
/// Performs matrix multiplication of a list of tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<IList<Tensor>> source)
{
return source.Select(multi_dot);
}

/// <summary>
/// Performs matrix multiplication of an enumerable of tensors.
/// </summary>
/// <param name="source"></param>
/// <returns></returns>
public IObservable<Tensor> Process(IObservable<IEnumerable<Tensor>> source)
{
return source.Select(input => multi_dot([.. input]));
}
}
Loading
Loading