diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs
index 92369615..3e901641 100644
--- a/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/CholeskyDecomposition.cs
@@ -3,24 +3,23 @@
using System.Reactive.Linq;
using static TorchSharp.torch;
-namespace Bonsai.ML.Torch.LinearAlgebra
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
+///
+[Combinator]
+[Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class CholeskyDecomposition
{
///
/// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
///
- [Combinator]
- [Description("Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class CholeskyDecomposition
+ ///
+ ///
+ public IObservable Process(IObservable source)
{
- ///
- /// Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix.
- ///
- ///
- ///
- public IObservable Process(IObservable source)
- {
- return source.Select(linalg.cholesky);
- }
+ return source.Select(linalg.cholesky);
}
-}
\ No newline at end of file
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs b/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs
new file mode 100644
index 00000000..dd4be273
--- /dev/null
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/CrossProduct.cs
@@ -0,0 +1,34 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the cross product of 2 tensors.
+///
+[Combinator]
+[Description("Computes the cross product of 2 tensors.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class CrossProduct
+{
+ ///
+ /// The dimension to perform the operation.
+ ///
+ public long Dimension { get; set; } = -1;
+
+ ///
+ /// Computes the cross product of 2 tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return cross(value.Item1, value.Item2, Dimension);
+ });
+ }
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs
index 475651d0..4a5cb78d 100644
--- a/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/Determinant.cs
@@ -3,24 +3,23 @@
using System.Reactive.Linq;
using static TorchSharp.torch;
-namespace Bonsai.ML.Torch.LinearAlgebra
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the determinant of a square matrix.
+///
+[Combinator]
+[Description("Computes the determinant of a square matrix.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class Determinant
{
///
/// Computes the determinant of a square matrix.
///
- [Combinator]
- [Description("Computes the determinant of a square matrix.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class Determinant
+ ///
+ ///
+ public IObservable Process(IObservable source)
{
- ///
- /// Computes the determinant of a square matrix.
- ///
- ///
- ///
- public IObservable Process(IObservable source)
- {
- return source.Select(linalg.det);
- }
+ return source.Select(linalg.det);
}
-}
\ No newline at end of file
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs
index 6784b1bf..d4029fc7 100644
--- a/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/EigenvalueDecomposition.cs
@@ -3,24 +3,40 @@
using System.Reactive.Linq;
using static TorchSharp.torch;
-namespace Bonsai.ML.Torch.LinearAlgebra
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the eigenvalue decomposition of a square matrix if it exists.
+///
+[Combinator]
+[Description("Computes the eigenvalue decomposition of a square matrix if it exists.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class EigenvalueDecomposition
{
///
/// Computes the eigenvalue decomposition of a square matrix if it exists.
///
- [Combinator]
- [Description("Computes the eigenvalue decomposition of a square matrix if it exists.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class EigenvalueDecomposition
+ ///
+ ///
+ public IObservable Process(IObservable source)
{
+ return source.Select(tensor => new EigenDecompositionResult(linalg.eig(tensor)));
+ }
+
+ ///
+ /// Represents the result of an eigenvalue decomposition.
+ ///
+ /// The tuple containing the eigenvalues and eigenvectors.
+ public readonly struct EigenDecompositionResult((Tensor eigenvalues, Tensor eigenvectors) result)
+ {
+ ///
+ /// Gets the eigenvalues of the decomposition.
+ ///
+ public Tensor Eigenvalues => result.eigenvalues;
+
///
- /// Computes the eigenvalue decomposition of a square matrix if it exists.
+ /// Gets the eigenvectors of the decomposition.
///
- ///
- ///
- public IObservable> Process(IObservable source)
- {
- return source.Select(tensor => linalg.eig(tensor).ToTuple());
- }
+ public Tensor Eigenvectors => result.eigenvectors;
}
-}
\ No newline at end of file
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs
index 1d879b5b..530979fc 100644
--- a/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/Inverse.cs
@@ -4,24 +4,23 @@
using static TorchSharp.torch;
using static TorchSharp.torch.linalg;
-namespace Bonsai.ML.Torch.LinearAlgebra
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the inverse of the input matrix.
+///
+[Combinator]
+[Description("Computes the inverse of the input matrix.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class Inverse
{
///
/// Computes the inverse of the input matrix.
///
- [Combinator]
- [Description("Computes the inverse of the input matrix.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class Inverse
+ /// The input matrix to invert.
+ /// The inverse of the input matrix.
+ public IObservable Process(IObservable source)
{
- ///
- /// Computes the inverse of the input matrix.
- ///
- ///
- ///
- public IObservable Process(IObservable source)
- {
- return source.Select(inv);
- }
+ return source.Select(inv);
}
-}
\ No newline at end of file
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs
new file mode 100644
index 00000000..aaa058c3
--- /dev/null
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/LeastSquaresSolve.cs
@@ -0,0 +1,57 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k.
+///
+[Combinator]
+[Description("Computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class LeastSquaresSolve
+{
+ ///
+ /// Computes the solution to the least squares and least norm problems for a full rank matrix A of size m*n and a matrix B of size m*k.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value => new LeastSquaresResult(linalg.lstsq(value.Item1, value.Item2)));
+ }
+
+ ///
+ /// Represents the result of solving of linear equations using the least squares method.
+ ///
+ ///
+ public readonly struct LeastSquaresResult((
+ Tensor solution,
+ Tensor residuals,
+ Tensor rank,
+ Tensor singularValues
+ ) result)
+ {
+ ///
+ /// The solution to the system of equations.
+ ///
+ public Tensor Solution => result.solution;
+
+ ///
+ /// The residual error.
+ ///
+ public Tensor Residuals => result.residuals;
+
+ ///
+ /// The effective rank of the solution.
+ ///
+ public Tensor Rank => result.rank;
+
+ ///
+ /// The singular values of the solution.
+ ///
+ public Tensor SingularValues => result.singularValues;
+ }
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs
new file mode 100644
index 00000000..0703dc3d
--- /dev/null
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixMultiply.cs
@@ -0,0 +1,123 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that performs matrix multiplication of 2 or more tensors.
+///
+[Combinator]
+[Description("Performs matrix multiplication of 2 or more tensors.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class MatrixMultiply
+{
+ ///
+ /// Performs matrix multiplication of 2 tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(input =>input.Item1.matmul(input.Item2));
+ }
+
+ ///
+ /// Performs matrix multiplication of 3 tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(input =>
+ {
+ return multi_dot([input.Item1, input.Item2, input.Item3]);
+ });
+ }
+
+ ///
+ /// Performs matrix multiplication of 4 tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(input =>
+ {
+ return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4]);
+ });
+ }
+
+ ///
+ /// Performs matrix multiplication of 5 tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(input =>
+ {
+ return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5]);
+ });
+ }
+
+ ///
+ /// Performs matrix multiplication of 6 tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(input =>
+ {
+ return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5, input.Item6]);
+ });
+ }
+
+ ///
+ /// Performs matrix multiplication of 7 tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(input =>
+ {
+ return multi_dot([input.Item1, input.Item2, input.Item3, input.Item4, input.Item5, input.Item6, input.Item7]);
+ });
+ }
+
+ ///
+ /// Performs matrix multiplication of an array of tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(multi_dot);
+ }
+
+ ///
+ /// Performs matrix multiplication of a list of tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(multi_dot);
+ }
+
+ ///
+ /// Performs matrix multiplication of an enumerable of tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(input => multi_dot([.. input]));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/MatrixRank.cs b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixRank.cs
new file mode 100644
index 00000000..84eb4c37
--- /dev/null
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/MatrixRank.cs
@@ -0,0 +1,46 @@
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the numerical rank of a matrix.
+///
+[Combinator]
+[Description("Computes the numerical rank of a matrix.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class MatrixRank
+{
+ ///
+ /// Gets or sets the absolute tolerance for singular values to be considered non-zero.
+ ///
+ [Description("The absolute tolerance for singular values to be considered non-zero.")]
+ public double? AbsoluteTolerance { get; set; } = null;
+
+ ///
+ /// Gets or sets the relative tolerance for singular values to be considered non-zero.
+ ///
+ [Description("The relative tolerance for singular values to be considered non-zero.")]
+ public double? RelativeTolerance { get; set; } = null;
+
+ ///
+ /// Gets or sets a value indicating whether to treat the input matrix as Hermitian if input is complex or symmetric if real.
+ ///
+ [Description("Indicates whether to treat the input matrix as Hermitian if input is complex or symmetric if real.")]
+ public bool Hermitian { get; set; } = false;
+
+ ///
+ /// Computes the numerical rank of a matrix.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(input => matrix_rank(input, atol: AbsoluteTolerance, rtol: RelativeTolerance, hermitian: Hermitian));
+ }
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs
index eb18920d..510a7b59 100644
--- a/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/Norm.cs
@@ -3,35 +3,36 @@
using System.Reactive.Linq;
using static TorchSharp.torch;
-namespace Bonsai.ML.Torch.LinearAlgebra
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes a vector or matrix norm.
+///
+[Combinator]
+[Description("Computes a vector or matrix norm.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class Norm
{
///
- /// Computes a vector or matrix norm.
+ /// The dimensions along which to compute the norm.
///
- [Combinator]
- [Description("Computes a vector or matrix norm.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class Norm
- {
- ///
- /// The dimensions along which to compute the norm.
- ///
- [TypeConverter(typeof(UnidimensionalArrayConverter))]
- public long[] Dimensions { get; set; } = null;
+ [TypeConverter(typeof(UnidimensionalArrayConverter))]
+ [Description("The dimensions along which to compute the norm.")]
+ public long[] Dimensions { get; set; } = null;
- ///
- /// If true, the reduced dimensions are retained in the result as dimensions with size one.
- ///
- public bool Keepdim { get; set; } = false;
+ ///
+ /// If true, the reduced dimensions are retained in the result as dimensions with size one.
+ ///
+ [Description("If true, the reduced dimensions are retained in the result as dimensions with size one.")]
+ public bool Keepdim { get; set; } = false;
- ///
- /// Computes a matrix norm.
- ///
- ///
- ///
- public IObservable Process(IObservable source)
- {
- return source.Select(tensor => linalg.norm(tensor, dims: Dimensions, keepdim: Keepdim));
- }
+ ///
+ /// Computes a matrix norm.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(tensor => linalg.norm(tensor, dims: Dimensions, keepdim: Keepdim));
}
-}
\ No newline at end of file
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/QRDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/QRDecomposition.cs
new file mode 100644
index 00000000..eb01aaf3
--- /dev/null
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/QRDecomposition.cs
@@ -0,0 +1,50 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the QR decomposition of a matrix.
+///
+[Combinator]
+[Description("Computes the QR decomposition of a matrix.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class QRDecomposition
+{
+ ///
+ /// Gets or sets the mode of the QR decomposition.
+ ///
+ [Description("The mode of the QR decomposition.")]
+ public QRMode Mode { get; set; } = QRMode.Reduced;
+
+ ///
+ /// Computes the QR decomposition of a matrix.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(tensor => new QRDecompositionResult(qr(tensor, mode: Mode)));
+ }
+
+ ///
+ /// Represents the result of a QR decomposition.
+ ///
+ ///
+ public readonly struct QRDecompositionResult((Tensor Q, Tensor R) result)
+ {
+ ///
+ /// Gets the orthogonal matrix Q.
+ ///
+ public Tensor Q => result.Q;
+
+ ///
+ /// Gets the upper triangular matrix R.
+ ///
+ public Tensor R => result.R;
+ }
+
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs
index 6d29c910..2b9eb66d 100644
--- a/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/SignLogDeterminant.cs
@@ -3,24 +3,40 @@
using System.Reactive.Linq;
using static TorchSharp.torch;
-namespace Bonsai.ML.Torch.LinearAlgebra
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.
+///
+[Combinator]
+[Description("Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class SignLogDeterminant
{
///
- /// Computes the determinant of a square matrix.
+ /// Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.
///
- [Combinator]
- [Description("Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class SignLogDeterminant
+ ///
+ ///
+ public IObservable Process(IObservable source)
{
+ return source.Select(result => new SignLogDeterminantResult(linalg.slogdet(result)));
+ }
+
+ ///
+ /// Represents the result of computing the sign and natural logarithm of the absolute value of the determinant.
+ ///
+ ///
+ public readonly struct SignLogDeterminantResult((Tensor sign, Tensor logabsdet) result)
+ {
+ ///
+ /// Gets the sign of the determinant.
+ ///
+ public Tensor Sign => result.sign;
+
///
- /// Computes the determinant of a square matrix.
+ /// Gets the natural logarithm of the absolute value of the determinant.
///
- ///
- ///
- public IObservable<(Tensor, Tensor)> Process(IObservable source)
- {
- return source.Select(linalg.slogdet);
- }
+ public Tensor LogAbsDeterminant => result.logabsdet;
}
-}
\ No newline at end of file
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs b/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs
index c440316f..25ac9581 100644
--- a/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/SingularValueDecomposition.cs
@@ -3,29 +3,55 @@
using System.Reactive.Linq;
using static TorchSharp.torch;
-namespace Bonsai.ML.Torch.LinearAlgebra
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the singular value decomposition (SVD) of a matrix.
+///
+[Combinator]
+[Description("Computes the singular value decomposition (SVD) of a matrix.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class SingularValueDecomposition
{
+ ///
+ /// Whether to compute the full or reduced SVD.
+ ///
+ [Description("Whether to compute the full or reduced SVD.")]
+ public bool FullMatrices { get; set; } = false;
+
///
/// Computes the singular value decomposition (SVD) of a matrix.
///
- [Combinator]
- [Description("Computes the singular value decomposition (SVD) of a matrix.")]
- [WorkflowElementCategory(ElementCategory.Transform)]
- public class SingularValueDecomposition
+ ///
+ ///
+ public IObservable Process(IObservable source)
+ {
+ return source.Select(tensor => new SingularValueDecompositionResult(linalg.svd(tensor, fullMatrices: FullMatrices)));
+ }
+
+ ///
+ /// Represents the result of a singular value decomposition.
+ ///
+ ///
+ public readonly struct SingularValueDecompositionResult((
+ Tensor u,
+ Tensor s,
+ Tensor vh
+ ) result)
{
///
- /// Whether to compute the full or reduced SVD.
+ /// The U tensor.
+ ///
+ public Tensor U => result.u;
+
+ ///
+ /// The singular values.
///
- public bool FullMatrices { get; set; } = false;
+ public Tensor S => result.s;
///
- /// Computes the singular value decomposition (SVD) of a matrix.
+ /// The Vh tensor.
///
- ///
- ///
- public IObservable> Process(IObservable source)
- {
- return source.Select(tensor => linalg.svd(tensor, fullMatrices: FullMatrices).ToTuple());
- }
+ public Tensor Vh => result.vh;
}
-}
\ No newline at end of file
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs
new file mode 100644
index 00000000..52b5f5d3
--- /dev/null
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/TensorSolve.cs
@@ -0,0 +1,36 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the solution X to the system tensordot(A, X) = B.
+///
+[Combinator]
+[Description("Computes the solution to the system tensordot(A, X) = B.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class TensorSolve
+{
+ ///
+ /// The dimensions to perform the operation.
+ ///
+ [TypeConverter(typeof(UnidimensionalArrayConverter))]
+ [Description("The dimensions to perform the operation.")]
+ public long[] Dimensions { get; set; } = [];
+
+ ///
+ /// Computes the solution to the system tensordot(A, X) = B.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return tensorsolve(value.Item1, value.Item2, Dimensions);
+ });
+ }
+}
diff --git a/src/Bonsai.ML.Torch/LinearAlgebra/TriangularSolve.cs b/src/Bonsai.ML.Torch/LinearAlgebra/TriangularSolve.cs
new file mode 100644
index 00000000..8f7b1f54
--- /dev/null
+++ b/src/Bonsai.ML.Torch/LinearAlgebra/TriangularSolve.cs
@@ -0,0 +1,47 @@
+using System;
+using System.ComponentModel;
+using System.Reactive.Linq;
+using static TorchSharp.torch;
+using static TorchSharp.torch.linalg;
+
+namespace Bonsai.ML.Torch.LinearAlgebra;
+
+///
+/// Represents an operator that computes the solution to a triangular system of linear equations with a unique solution.
+///
+[Combinator]
+[Description("Computes the solution to a triangular system of linear equations with a unique solution.")]
+[WorkflowElementCategory(ElementCategory.Transform)]
+public class TriangularSolve
+{
+ ///
+ /// Gets or sets a value indicating whether the first matrix is upper triangular.
+ ///
+ [Description("Indicates whether the first matrix is upper triangular.")]
+ public bool Upper { get; set; } = true;
+
+ ///
+ /// Gets or sets a value indicating whether to solve the system with the first matrix on the left or right (AX = B or XA = B).
+ ///
+ [Description("Indicates whether to solve the system with the first matrix on the left or right (AX = B or XA = B).")]
+ public bool Left { get; set; } = true;
+
+ ///
+ /// Gets or sets a value indicating whether the first matrix has a unit diagonal, i.e., all diagonal elements are assumed to be 1.
+ ///
+ [Description("Indicates whether the first matrix has a unit diagonal, i.e., all diagonal elements are assumed to be 1.")]
+ public bool UnitDiagonal { get; set; } = false;
+
+ ///
+ /// Computes the solution to a triangular system of linear equations for each pair of input tensors.
+ ///
+ ///
+ ///
+ public IObservable Process(IObservable> source)
+ {
+ return source.Select(value =>
+ {
+ return solve_triangular(value.Item1, value.Item2, upper: Upper, left: Left, unitriangular: UnitDiagonal);
+ });
+ }
+}