Skip to content

Commit 9f6e7f5

Browse files
Sylvester rewrite
1 parent d1a166e commit 9f6e7f5

File tree

1 file changed

+32
-3
lines changed

1 file changed

+32
-3
lines changed

pytensor/tensor/_linalg/solve/rewriting.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,23 @@
88
from pytensor.scan.rewriting import scan_seqopt1
99
from pytensor.tensor._linalg.solve.linear_control import (
1010
SolveBilinearDiscreteLyapunov,
11+
SolveSylvester,
1112
solve_discrete_lyapunov,
1213
)
1314
from pytensor.tensor._linalg.solve.tridiagonal import (
1415
tridiagonal_lu_factor,
1516
tridiagonal_lu_solve,
1617
)
17-
from pytensor.tensor.basic import atleast_Nd
18+
from pytensor.tensor.basic import atleast_Nd, diagonal
1819
from pytensor.tensor.blockwise import Blockwise
1920
from pytensor.tensor.elemwise import DimShuffle
20-
from pytensor.tensor.rewriting.basic import register_specialize
21+
from pytensor.tensor.rewriting.basic import (
22+
register_canonicalize,
23+
register_specialize,
24+
register_stabilize,
25+
)
2126
from pytensor.tensor.rewriting.blockwise import blockwise_of
22-
from pytensor.tensor.rewriting.linalg import is_matrix_transpose
27+
from pytensor.tensor.rewriting.linalg import _is_diagonal, is_matrix_transpose
2328
from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve
2429
from pytensor.tensor.variable import TensorVariable
2530

@@ -306,3 +311,27 @@ def jax_bilinear_lyapunov_to_direct(fgraph, node):
306311
"jax",
307312
position=0.9, # Run before canonicalization
308313
)
314+
315+
316+
@register_canonicalize
317+
@register_stabilize
318+
@node_rewriter([blockwise_of(SolveSylvester)])
319+
def rewrite_sylvester_diag_to_elemwise(fgraph, node):
320+
"""solve_sylvester(diag_A, diag_B, C) -> C / (a[:, None] + b[None, :])
321+
322+
When both coefficient matrices are diagonal, :math:`AX + XB = C` decouples into
323+
:math:`X_{ij} = C_{ij} / (a_i + b_j)`.
324+
"""
325+
A, B, C = node.inputs
326+
327+
if not _is_diagonal(A, fgraph) or not _is_diagonal(B, fgraph):
328+
return None
329+
330+
diag_a = diagonal(A, axis1=-2, axis2=-1)
331+
diag_b = diagonal(B, axis1=-2, axis2=-1)
332+
denom = diag_a[..., :, None] + diag_b[..., None, :]
333+
334+
X_out = node.outputs[0]
335+
new_x = (C / denom).astype(X_out.type.dtype)
336+
337+
return [new_x]

0 commit comments

Comments
 (0)