Expected behavior
relax.op.nn.attention should support symbolic batch dimensions — the common
case for dynamic-batch serving. Adjacent Relax ops (matmul, nn.conv2d,
nn.softmax, nn.layer_norm, nn.batch_norm, nn.adaptive_avg_pool2d,
nn.max_pool2d, …) all already handle a symbolic batch dim correctly.
Actual behavior
LegalizeOps on an attention with a symbolic batch dim crashes inside
topi.nn.batch_matmul:
AssertionError: batch dimension doesn't match
Static batch sizes work; only the symbolic case fails.
Reproduction
```python
import tvm
from tvm import relax
from tvm import tirx

B = tirx.Var("B", "int64")
bb = relax.BlockBuilder()
Q = relax.Var("Q", relax.TensorStructInfo([B, 4, 2, 8], "float32"))
K = relax.Var("K", relax.TensorStructInfo([B, 4, 2, 8], "float32"))
V = relax.Var("V", relax.TensorStructInfo([B, 4, 2, 8], "float32"))
with bb.function("main", [Q, K, V]):
    with bb.dataflow():
        y = bb.emit(relax.op.nn.attention(Q, K, V))
        gv = bb.emit_output(y)
    bb.emit_func_output(gv)
mod = bb.get()
mod = relax.transform.LegalizeOps()(mod)  # raises the AssertionError below
```
Stack trace:
```
  File ".../relax/transform/legalize_ops/nn.py", in _te_attention
    p = topi.nn.batch_matmul(q, k)
  File ".../topi/nn/batch_matmul.py", line 97, in batch_matmul
    assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
AssertionError: batch dimension doesn't match
```
Replacing the symbolic B with any literal int (e.g. [2, 4, 2, 8]) makes
the same module compile cleanly.
Root cause
Two-step chain:
- _te_attention builds q and k via two separate Python expressions for the
  merged batch×head dim (python/tvm/relax/transform/legalize_ops/nn.py):

  ```python
  q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
  k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
  ```

  batch_size * num_head is evaluated twice, producing two structurally
  equal but object-distinct Mul(B, 2) PrimExpr nodes.
- topi.nn.batch_matmul checks batch-dim equality with plain Python ==
  (python/tvm/topi/nn/batch_matmul.py:97):

  ```python
  if oshape is None:
      assert XB == YB or XB == 1 or YB == 1, "batch dimension doesn't match"
  ```

  For two PrimExpr operands, Python == returns a tirx.expr.EqualOp (a
  deferred comparison). Its bool() does not test structural equality: it
  returns True only when both operands are the same object. Hence
  bool(Mul(B, 2) == Mul(B, 2)) is False when the two nodes are distinct
  objects. Witness:

  ```python
  >>> B = tirx.Var("B", "int64")
  >>> e1 = B * 2
  >>> e2 = B * 2
  >>> bool(e1 == e2)
  False
  ```

  So every operand of the or-chain in the assert evaluates to False, and
  the assertion fires.
The same function already has a workaround for the K-dim
(assert XK == YK or isinstance(YK, tvm.tirx.expr.Var)), but the batch dim
was left without one.
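The identity-based comparison can be reproduced without TVM. The sketch below is a toy model, not TVM code: Expr, Var, Mul and EqualOp are hypothetical stand-ins whose __eq__ returns a deferred comparison object and whose __bool__ checks object identity, mimicking the EqualOp behavior described above. Two evaluations of B * 2 yield distinct nodes, so an assert shaped like the one in batch_matmul fails.

```python
class Expr:
    """Toy symbolic expression (hypothetical; not TVM's PrimExpr)."""

    def __mul__(self, other):
        return Mul(self, other)

    def __eq__(self, other):
        # Build a deferred comparison node instead of returning a bool.
        return EqualOp(self, other)


class Var(Expr):
    def __init__(self, name):
        self.name = name


class Mul(Expr):
    def __init__(self, a, b):
        self.a, self.b = a, b


class EqualOp:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __bool__(self):
        # True only when both operands are the very same object.
        return self.a is self.b


B = Var("B")
e1 = B * 2  # first evaluation of batch_size * num_head
e2 = B * 2  # second evaluation: same structure, different object

print(bool(e1 == e2))  # False
print(bool(e1 == e1))  # True

try:
    # Same shape as the batch check in batch_matmul.
    assert e1 == e2 or e1 == 1 or e2 == 1, "batch dimension doesn't match"
except AssertionError as err:
    print("AssertionError:", err)
```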
Suggested fix
Two viable fixes; the first is local to attention and verified end-to-end:
Fix A (recommended, local, verified). Compute the merged batch dim once and
pass the full output shape via oshape in _te_attention; with oshape set,
batch_matmul skips the assertion entirely:
```python
# python/tvm/relax/transform/legalize_ops/nn.py
def _te_attention(q, k, v, bias, scale, causal_mask):
    batch_size, seq_len, num_head, head_dim = q.shape
    _, seq_len_kv, _, head_dim_v = v.shape
    bsxh = batch_size * num_head  # compute once
    q = topi.transpose(q, [0, 2, 1, 3])
    k = topi.transpose(k, [0, 2, 1, 3])
    v = topi.transpose(v, [0, 2, 1, 3])
    q = topi.reshape(q, [bsxh, seq_len, head_dim])
    k = topi.reshape(k, [bsxh, seq_len_kv, head_dim])
    v = topi.reshape(v, [bsxh, seq_len_kv, head_dim_v])
    # NEW: pass oshape so batch_matmul does not need to infer / assert it
    p = topi.nn.batch_matmul(q, k, oshape=[bsxh, seq_len, seq_len_kv])
    ...
```
The second batch_matmul (the one that multiplies the attention probabilities
by v) should likewise receive oshape=[bsxh, seq_len, head_dim_v].
I verified this fix locally: the original repro module above builds and
runs correctly when these oshape arguments are added (output shape
(2, 4, 2, 8) for a batch of 2, matching the static-batch case).
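As a quick check that the two oshape values in Fix A are the shapes batch_matmul would otherwise infer, the sketch below replays the shape arithmetic with concrete integers. merge_heads, bmm_out_shape and attention_oshapes are hypothetical helpers written for this note; bmm_out_shape is a simplified model of batch_matmul's static shape rule, and the second call is assumed to use transpose_b=False because the multiplication by v is elided in the snippet above.

```python
def merge_heads(shape):
    """(batch, seq, heads, dim) -> transpose -> reshape: [batch * heads, seq, dim]."""
    batch, seq, heads, dim = shape
    return [batch * heads, seq, dim]


def bmm_out_shape(x_shape, y_shape, transpose_b=True):
    """Static output shape of a batched matmul (simplified model).

    x is (batch, i, k); y is (batch, j, k) if transpose_b else (batch, k, j).
    """
    xb, xi, xk = x_shape
    if transpose_b:
        yb, yj, yk = y_shape
    else:
        yb, yk, yj = y_shape
    assert xb == yb or xb == 1 or yb == 1, "batch dimension doesn't match"
    assert xk == yk, "reduction dimension doesn't match"
    return [max(xb, yb), xi, yj]


def attention_oshapes(q_shape, v_shape):
    """The two oshape lists Fix A passes, derived from the 4-D input shapes."""
    batch_size, seq_len, num_head, _ = q_shape
    _, seq_len_kv, _, head_dim_v = v_shape
    bsxh = batch_size * num_head
    return [bsxh, seq_len, seq_len_kv], [bsxh, seq_len, head_dim_v]


# Repro shapes with B = 2: (batch, seq_len, num_head, head_dim).
q_shape = k_shape = v_shape = (2, 4, 2, 8)
oshape_qk, oshape_pv = attention_oshapes(q_shape, v_shape)

p_shape = bmm_out_shape(merge_heads(q_shape), merge_heads(k_shape))
o_shape = bmm_out_shape(p_shape, merge_heads(v_shape), transpose_b=False)
print(oshape_qk, p_shape)  # [4, 4, 4] [4, 4, 4]
print(oshape_pv, o_shape)  # [4, 4, 8] [4, 4, 8]
```

With a symbolic B the arithmetic is the same; the point of oshape is that batch_matmul receives these shapes directly and never evaluates == on two distinct Mul(B, 2) nodes.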
Fix B (broader, optional). Generalize the batch_matmul batch check to
recognize symbolic PrimExpr equality, mirroring the existing K-dim
workaround:
```python
# python/tvm/topi/nn/batch_matmul.py
import tvm

ana = tvm.arith.Analyzer()
if oshape is None:
    same_batch = (
        XB == 1
        or YB == 1
        or (
            isinstance(XB, tvm.tirx.expr.PrimExpr)
            and isinstance(YB, tvm.tirx.expr.PrimExpr)
            and ana.can_prove_equal(XB, YB)
        )
        or XB == YB
    )
    assert same_batch, f"batch dimension doesn't match: {XB} vs {YB}"
    ...
```
Fix B would also unblock any other caller of topi.nn.batch_matmul that
constructs the batch dim via separate Python expressions; Fix A only fixes
attention. Either is small and self-contained.
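Fix B's core idea, answering the batch check with a semantic equality test instead of the identity-based ==, can be sketched on a toy expression tree. Var, Mul, structurally_equal and same_batch are hypothetical stand-ins, not TVM APIs. structurally_equal is a much weaker substitute for ana.can_prove_equal: it compares trees node by node (variables by identity) with no arithmetic simplification, so, for example, it treats Mul(2, B) and Mul(B, 2) as different.

```python
class Var:
    """Toy symbolic variable (hypothetical; not TVM's tirx.Var)."""

    def __init__(self, name):
        self.name = name

    def __mul__(self, other):
        return Mul(self, other)


class Mul:
    """Toy product node; no __eq__ override, so == means identity."""

    def __init__(self, a, b):
        self.a, self.b = a, b


def structurally_equal(x, y):
    """Compare two toy expressions node by node."""
    if x is y:
        return True  # covers the same Var object on both sides
    if isinstance(x, int) and isinstance(y, int):
        return x == y
    if isinstance(x, Mul) and isinstance(y, Mul):
        return structurally_equal(x.a, y.a) and structurally_equal(x.b, y.b)
    return False  # distinct Vars, or mismatched node kinds


def same_batch(xb, yb):
    """Batch check in the spirit of Fix B: broadcastable or provably equal."""
    return (
        structurally_equal(xb, 1)
        or structurally_equal(yb, 1)
        or structurally_equal(xb, yb)
    )


B = Var("B")
xb = B * 2  # batch dim of q after reshape
yb = B * 2  # batch dim of k: separate evaluation, distinct object
print(xb is yb)                      # False
print(same_batch(xb, yb))            # True
print(same_batch(xb, Var("C") * 2))  # False
```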
Impact
Blocks nn.attention for the dynamic-batch deployment story that Relax is
otherwise designed to support. Today's workaround is to compose matmul +
softmax manually, which loses the structured op and any future BYOC /
FlashAttention dispatch hooked to relax.nn.attention.
Environment
- TVM: latest main (commit b172d5ea3, 2026-06-08)
- Python: 3.11
- Reproduces on CPU (target="llvm"); the symbolic-shape rejection happens
  during LegalizeOps, before any backend code runs.
cc @junrushao