[Bug][Relax] nn.attention fails to legalize when batch dim is a symbolic tirx.Var #19696

@wuyii8941

Expected behavior

relax.op.nn.attention should support symbolic batch dimensions — the common
case for dynamic-batch serving. Adjacent Relax ops (matmul, nn.conv2d,
nn.softmax, nn.layer_norm, nn.batch_norm, nn.adaptive_avg_pool2d,
nn.max_pool2d, …) all already handle a symbolic batch dim correctly.

Actual behavior

LegalizeOps on an attention with a symbolic batch dim crashes inside
topi.nn.batch_matmul:

AssertionError: batch dimension doesn't match

Static batch sizes work; only the symbolic case fails.

Reproduction

import tvm
from tvm import relax
from tvm import tirx

B = tirx.Var("B", "int64")
bb = relax.BlockBuilder()
Q = relax.Var("Q", relax.TensorStructInfo([B, 4, 2, 8], "float32"))
K = relax.Var("K", relax.TensorStructInfo([B, 4, 2, 8], "float32"))
V = relax.Var("V", relax.TensorStructInfo([B, 4, 2, 8], "float32"))
with bb.function("main", [Q, K, V]):
    with bb.dataflow():
        y = bb.emit(relax.op.nn.attention(Q, K, V))
        gv = bb.emit_output(y)
    bb.emit_func_output(gv)

mod = bb.get()
mod = relax.transform.LegalizeOps()(mod)

Stack trace:

File ".../relax/transform/legalize_ops/nn.py", in _te_attention
    p = topi.nn.batch_matmul(q, k)
File ".../topi/nn/batch_matmul.py", line 97, in batch_matmul
    assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
AssertionError: batch dimension doesn't match

Replacing the symbolic B with any literal int (e.g. [2, 4, 2, 8]) makes
the same module compile cleanly.

Root cause

Two-step chain:

  1. _te_attention builds q and k via two separate Python expressions
    for the merged batch×head dim
    (python/tvm/relax/transform/legalize_ops/nn.py):

    q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
    k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])

    batch_size * num_head is evaluated twice, producing two structurally
    equal but object-distinct Mul(B, 2) PrimExpr nodes.

  2. topi.nn.batch_matmul asserts batch-dim equality via plain Python ==
    (python/tvm/topi/nn/batch_matmul.py:97):

    if oshape is None:
        assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"

    For two PrimExpr operands, Python == returns a tirx.expr.EqualOp (a
    deferred comparison) rather than a bool. Its bool() is True only when
    both operands are the same object, so bool(Mul(B,2) == Mul(B,2))
    returns False for two distinct nodes even though they are structurally
    equal. Witness:

    >>> B = tirx.Var("B", "int64")
    >>> e1 = B * 2
    >>> e2 = B * 2
    >>> bool(e1 == e2)
    False

    So all three legs of the assert evaluate to False and the
    assertion fires.
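
The failure mode can be reproduced without TVM. The following is a minimal sketch, assuming a deferred comparison whose truth value is operand identity, as described above. The classes Expr, Var, Mul and DeferredEq and the helper structurally_equal are hypothetical stand-ins written for illustration; they are not TVM's classes or API.

```python
class Expr:
    """Toy symbolic expression; a stand-in, not TVM's PrimExpr."""

    def __mul__(self, other):
        return Mul(self, other)

    def __eq__(self, other):
        # Like the behavior described above: == builds a node, not a bool.
        return DeferredEq(self, other)

    __hash__ = object.__hash__  # keep instances hashable despite custom __eq__


class Var(Expr):
    def __init__(self, name):
        self.name = name


class Mul(Expr):
    def __init__(self, a, b):
        self.a, self.b = a, b


class DeferredEq:
    """Deferred comparison whose truth value is operand identity."""

    def __init__(self, a, b):
        self.a, self.b = a, b

    def __bool__(self):
        return self.a is self.b


def structurally_equal(x, y):
    """Recursive structural comparison for the toy nodes above."""
    if isinstance(x, Mul) and isinstance(y, Mul):
        return structurally_equal(x.a, y.a) and structurally_equal(x.b, y.b)
    if isinstance(x, Expr) or isinstance(y, Expr):
        return x is y  # variables (and mismatched node kinds) compare by identity
    return x == y  # plain Python constants


B = Var("B")
e1 = B * 2
e2 = B * 2
print(bool(e1 == e2))              # False: two distinct Mul objects
print(bool(e1 == e1))              # True: same object on both sides
print(structurally_equal(e1, e2))  # True: same tree shape and leaves
```

In this toy model, an assert written as e1 == e2 or e1 == 1 or e2 == 1 fails for the same reason as the real one: every leg is a deferred comparison between distinct objects.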

The same function already has a workaround for the K-dim
(assert XK == YK or isinstance(YK, tvm.tirx.expr.Var)), but the batch dim
was left without one.

Suggested fix

Two viable fixes; the first is local to attention and verified end-to-end:

Fix A (recommended, local, verified). Pass the merged batch dim as
oshape in _te_attention, which skips the assertion entirely:

# python/tvm/relax/transform/legalize_ops/nn.py
def _te_attention(q, k, v, bias, scale, causal_mask):
    batch_size, seq_len, num_head, head_dim = q.shape
    _, seq_len_kv, _, head_dim_v = v.shape
    bsxh = batch_size * num_head                # compute once
    q = topi.transpose(q, [0, 2, 1, 3])
    k = topi.transpose(k, [0, 2, 1, 3])
    v = topi.transpose(v, [0, 2, 1, 3])
    q = topi.reshape(q, [bsxh, seq_len, head_dim])
    k = topi.reshape(k, [bsxh, seq_len_kv, head_dim])
    v = topi.reshape(v, [bsxh, seq_len_kv, head_dim_v])
    # NEW: pass oshape so batch_matmul does not need to infer / assert it
    p = topi.nn.batch_matmul(q, k, oshape=[bsxh, seq_len, seq_len_kv])
    ...

Likewise, oshape=[bsxh, seq_len, head_dim_v] should be passed to the second
batch_matmul (the one that multiplies the attention probabilities by v).

I verified this fix locally: the original repro module above
builds and runs correctly when these oshape args are added (output shape
(2, 4, 2, 8) matches the static-batch case).
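
To make the two oshape values concrete, here is a small shape-arithmetic sketch in plain Python. The helper name attention_oshapes is hypothetical and exists only for illustration; it mirrors the unpacking in the snippet above and uses the static shape [2, 4, 2, 8] mentioned earlier, so it says nothing about how TVM itself computes shapes.

```python
def attention_oshapes(q_shape, v_shape):
    """Hypothetical helper: derive the two oshape lists used in Fix A.

    q_shape is [batch_size, seq_len, num_head, head_dim] and
    v_shape is [batch_size, seq_len_kv, num_head, head_dim_v].
    """
    batch_size, seq_len, num_head, _head_dim = q_shape
    _, seq_len_kv, _, head_dim_v = v_shape
    bsxh = batch_size * num_head  # merged batch x head dim, computed once
    scores = [bsxh, seq_len, seq_len_kv]  # first batch_matmul (q with k)
    output = [bsxh, seq_len, head_dim_v]  # second batch_matmul (probabilities with v)
    return scores, output


scores, output = attention_oshapes([2, 4, 2, 8], [2, 4, 2, 8])
print(scores)  # [4, 4, 4]
print(output)  # [4, 4, 8]
```

With a symbolic batch dim, bsxh would be a single expression object shared by both lists, which is the point of computing it once.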

Fix B (broader, optional). Generalize the batch_matmul batch check to
recognize symbolic PrimExpr equality, mirroring the existing K-dim
workaround:

# python/tvm/topi/nn/batch_matmul.py
import tvm
ana = tvm.tirx.analyzer.Analyzer()
if oshape is None:
    same_batch = (
        XB == 1
        or YB == 1
        or (isinstance(XB, tvm.tirx.expr.PrimExpr) and isinstance(YB, tvm.tirx.expr.PrimExpr)
            and ana.can_prove_equal(XB, YB))
        or XB == YB
    )
    assert same_batch, f"batch dimension doesn't match: {XB} vs {YB}"
    ...

Fix B would also unblock any other caller of topi.nn.batch_matmul that
constructs the batch dim via separate Python expressions; Fix A only fixes
attention. Either is small and self-contained.
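
The difference between the two fixes can be sketched in a toy model. This is only an illustration under stated assumptions: Sym, _Deferred, provably_equal, batch_dims_match and infer_batch are hypothetical names, provably_equal is a crude stand-in for an analyzer's equality proof, and none of this is TVM code.

```python
class Sym:
    """Hypothetical symbolic dim: Sym(name, factor) stands for name * factor."""

    def __init__(self, name, factor=1):
        self.name, self.factor = name, factor

    def __mul__(self, k):
        return Sym(self.name, self.factor * k)

    def __eq__(self, other):
        return _Deferred(self, other)  # deferred comparison, not a bool

    __hash__ = object.__hash__


class _Deferred:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __bool__(self):
        return self.a is self.b  # identity only, as in the root cause


def provably_equal(x, y):
    """Toy stand-in for an analyzer proving two dims equal."""
    if isinstance(x, Sym) and isinstance(y, Sym):
        return x.name == y.name and x.factor == y.factor
    if isinstance(x, Sym) or isinstance(y, Sym):
        return False
    return x == y


def batch_dims_match(xb, yb, strict_identity=False):
    if strict_identity:  # current behavior: plain Python == on each leg
        return bool(xb == yb) or bool(xb == 1) or bool(yb == 1)
    # Fix B style: fall back to a structural proof
    return provably_equal(xb, 1) or provably_equal(yb, 1) or provably_equal(xb, yb)


def infer_batch(xb, yb, oshape=None, strict_identity=False):
    if oshape is not None:
        return oshape[0]  # Fix A style: caller supplies the shape, no check runs
    assert batch_dims_match(xb, yb, strict_identity), "batch dimension doesn't match"
    return yb if provably_equal(xb, 1) else xb


B = Sym("B")
xb, yb = B * 2, B * 2  # two distinct but equal expressions
print(batch_dims_match(xb, yb, strict_identity=True))  # False
print(batch_dims_match(xb, yb))                        # True
print(infer_batch(xb, yb, oshape=[xb, 4, 4]) is xb)    # True
```

The sketch shows why Fix A helps only callers that pass oshape, while a Fix B style check also covers callers that rely on inference.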

Impact

Blocks nn.attention for the dynamic-batch deployment story that Relax is
otherwise designed to support. Today's workaround is to compose matmul +
softmax manually, which loses the structured op and any future BYOC /
FlashAttention dispatch hooked to relax.nn.attention.

Environment

  • TVM: latest main (commit b172d5ea3, 2026-06-08)
  • Python: 3.11
  • Reproduces on CPU (target="llvm"); the symbolic-shape rejection happens
    during LegalizeOps, before any backend code runs.

cc @junrushao

Labels: needs-triage, type: bug
