|
8 | 8 | from pytensor.scan.rewriting import scan_seqopt1 |
9 | 9 | from pytensor.tensor._linalg.solve.linear_control import ( |
10 | 10 | SolveBilinearDiscreteLyapunov, |
| 11 | + SolveSylvester, |
11 | 12 | solve_discrete_lyapunov, |
12 | 13 | ) |
13 | 14 | from pytensor.tensor._linalg.solve.tridiagonal import ( |
14 | 15 | tridiagonal_lu_factor, |
15 | 16 | tridiagonal_lu_solve, |
16 | 17 | ) |
17 | | -from pytensor.tensor.basic import atleast_Nd |
| 18 | +from pytensor.tensor.basic import atleast_Nd, diagonal |
18 | 19 | from pytensor.tensor.blockwise import Blockwise |
19 | 20 | from pytensor.tensor.elemwise import DimShuffle |
20 | | -from pytensor.tensor.rewriting.basic import register_specialize |
| 21 | +from pytensor.tensor.rewriting.basic import ( |
| 22 | + register_canonicalize, |
| 23 | + register_specialize, |
| 24 | + register_stabilize, |
| 25 | +) |
21 | 26 | from pytensor.tensor.rewriting.blockwise import blockwise_of |
22 | | -from pytensor.tensor.rewriting.linalg import is_matrix_transpose |
| 27 | +from pytensor.tensor.rewriting.linalg import _is_diagonal, is_matrix_transpose |
23 | 28 | from pytensor.tensor.slinalg import Solve, cho_solve, cholesky, lu_factor, lu_solve |
24 | 29 | from pytensor.tensor.variable import TensorVariable |
25 | 30 |
|
@@ -306,3 +311,27 @@ def jax_bilinear_lyapunov_to_direct(fgraph, node): |
306 | 311 | "jax", |
307 | 312 | position=0.9, # Run before canonicalization |
308 | 313 | ) |
| 314 | + |
| 315 | + |
| 316 | +@register_canonicalize |
| 317 | +@register_stabilize |
| 318 | +@node_rewriter([blockwise_of(SolveSylvester)]) |
| 319 | +def rewrite_sylvester_diag_to_elemwise(fgraph, node): |
| 320 | + """solve_sylvester(diag_A, diag_B, C) -> C / (a[:, None] + b[None, :]) |
| 321 | +
|
| 322 | + When both coefficient matrices are diagonal, :math:`AX + XB = C` decouples into |
| 323 | + :math:`X_{ij} = C_{ij} / (a_i + b_j)`. |
| 324 | + """ |
| 325 | + A, B, C = node.inputs |
| 326 | + |
| 327 | + if not _is_diagonal(A, fgraph) or not _is_diagonal(B, fgraph): |
| 328 | + return None |
| 329 | + |
| 330 | + diag_a = diagonal(A, axis1=-2, axis2=-1) |
| 331 | + diag_b = diagonal(B, axis1=-2, axis2=-1) |
| 332 | + denom = diag_a[..., :, None] + diag_b[..., None, :] |
| 333 | + |
| 334 | + X_out = node.outputs[0] |
| 335 | + new_x = (C / denom).astype(X_out.type.dtype) |
| 336 | + |
| 337 | + return [new_x] |
0 commit comments