[Bug][Relax] nn.conv{1,2,3}d_transpose legalize silently bails for unsupported dilation / data_layout, causing opaque VM-codegen crash #19694

@wuyii8941

Expected behavior

relax.op.nn.conv{1,2,3}d_transpose accept dilation and data_layout /
kernel_layout / out_layout parameters. PyTorch's ConvTranspose{1,2,3}d
and ONNX ConvTranspose both support dilation > 1; TF/Keras and many
mobile/accelerator paths use non-NCHW layouts.

When a configuration is not yet implemented in the TOPI lowering, the legalize
pass should raise a clear, actionable OpNotImplemented (or
NotImplementedError) that mentions the operator and the unsupported attribute.

Actual behavior

For any unsupported attribute combination, the legalize function logs an
INFO-level message and returns the call unchanged. The unlegalized intrinsic
then crashes during VM codegen with:

tvm.error.InternalError: CodeGenVM cannot handle this intrinsic now:
Op(relax.nn.conv2d_transpose)

The user has no way to tell from this error which attribute caused the bail.
The INFO log is hidden under default logging settings. This pattern affects
three independent code paths:

  1. dilation > 1 (1D / 2D / 3D)
  2. data_layout other than NCHW / NCW / NCDHW (or mismatched
    kernel_layout)
  3. out_layout != data_layout

All three short-circuit to the same opaque codegen crash. Note that
#17890 (closed by #17968) fixed the basic case where the intrinsic was
unhandled altogether, but did not touch any of these conditional bail-out
branches — they remain.
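
To illustrate why the INFO message goes unseen, here is a minimal sketch that uses only the Python standard library. `legalize_stub` is a hypothetical stand-in for one bail-out branch, not TVM code. It assumes the legalize calls `logging.info` on the root logger, as in the snippet under "Root cause".

```python
import io
import logging


def legalize_stub(dilation):
    """Hypothetical stand-in for one bail-out branch (not the TVM source)."""
    if any(d != 1 for d in dilation):
        logging.info("TOPI conv2d_transpose does not support dilations other than 1")
        return "unlegalized"
    return "legalized"


def run_with_level(level):
    """Run the stub with the root logger at `level`; return (result, captured log)."""
    stream = io.StringIO()
    handler = logging.StreamHandler(stream)
    root = logging.getLogger()
    old_level = root.level
    root.addHandler(handler)
    root.setLevel(level)
    try:
        result = legalize_stub((2, 2))
    finally:
        # Restore the root logger so the demo has no lasting side effects.
        root.removeHandler(handler)
        root.setLevel(old_level)
    return result, stream.getvalue()


# WARNING is the root logger's default level: the INFO record is dropped.
result_default, log_default = run_with_level(logging.WARNING)
# With INFO enabled, the reason for the bail becomes visible.
result_info, log_info = run_with_level(logging.INFO)
```

Under the same assumption, calling `logging.basicConfig(level=logging.INFO)` before running `LegalizeOps` should surface the message as a stopgap, but it does not change the later codegen crash.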

Reproduction

import numpy as np
import tvm
from tvm import relax

def build_and_compile(builder):
    bb = relax.BlockBuilder()
    builder(bb)
    mod = bb.get()
    mod = relax.transform.LegalizeOps()(mod)
    return relax.build(mod, target="llvm")


# === A. conv2d_transpose with dilation > 1 ===
def case_dilation_2d():
    x = np.random.randn(1, 1, 5, 5).astype(np.float32)
    W = np.random.randn(1, 2, 3, 3).astype(np.float32)
    def builder(bb):
        X = relax.Var("X", relax.TensorStructInfo(x.shape, "float32"))
        Wv = relax.Var("W", relax.TensorStructInfo(W.shape, "float32"))
        with bb.function("main", [X, Wv]):
            with bb.dataflow():
                y = bb.emit(relax.op.nn.conv2d_transpose(
                    X, Wv, strides=(1, 1), padding=(0, 0), dilation=(2, 2)))
                gv = bb.emit_output(y)
            bb.emit_func_output(gv)
    return builder


# === B. conv1d_transpose with dilation > 1 ===
def case_dilation_1d():
    x = np.random.randn(1, 1, 5).astype(np.float32)
    W = np.random.randn(1, 2, 3).astype(np.float32)
    def builder(bb):
        X = relax.Var("X", relax.TensorStructInfo(x.shape, "float32"))
        Wv = relax.Var("W", relax.TensorStructInfo(W.shape, "float32"))
        with bb.function("main", [X, Wv]):
            with bb.dataflow():
                y = bb.emit(relax.op.nn.conv1d_transpose(
                    X, Wv, strides=(1,), padding=(0,), dilation=(2,)))
                gv = bb.emit_output(y)
            bb.emit_func_output(gv)
    return builder


# === C. conv2d_transpose with NHWC layout ===
def case_layout_nhwc():
    x = np.random.randn(1, 4, 4, 1).astype(np.float32)
    W = np.random.randn(1, 2, 3, 3).astype(np.float32)
    def builder(bb):
        X = relax.Var("X", relax.TensorStructInfo(x.shape, "float32"))
        Wv = relax.Var("W", relax.TensorStructInfo(W.shape, "float32"))
        with bb.function("main", [X, Wv]):
            with bb.dataflow():
                y = bb.emit(relax.op.nn.conv2d_transpose(
                    X, Wv, strides=(1, 1), padding=(0, 0),
                    data_layout="NHWC", kernel_layout="IOHW"))
                gv = bb.emit_output(y)
            bb.emit_func_output(gv)
    return builder


for label, case in [("dilation_2d", case_dilation_2d),
                    ("dilation_1d", case_dilation_1d),
                    ("layout_nhwc",  case_layout_nhwc)]:
    try:
        build_and_compile(case())
        print(f"{label}: built OK (unexpected)")
    except Exception as e:
        msg = str(e).splitlines()[0]
        print(f"{label}: {type(e).__name__}: {msg}")

# All three print:
# InternalError: CodeGenVM cannot handle this intrinsic now:

(conv3d_transpose with dilation > 1 or data_layout != "NCDHW" reproduces
the same crash; the legalize for it has the identical structure.)

Root cause

python/tvm/relax/transform/legalize_ops/nn.py — each of _nn_conv1d_transpose,
_nn_conv2d_transpose, _nn_conv3d_transpose shares this pattern:

@register_legalize("relax.nn.conv2d_transpose")
def _nn_conv2d_transpose(bb, call):
    if call.attrs.out_layout != call.attrs.data_layout:
        logging.info(
            "TOPI conv2d_transpose does not support different input-output "
            "layouts, and thus cannot be legalized by TOPI"
        )
        return call                          # <-- silent bail (1)
    if call.attrs.data_layout != "NCHW" or call.attrs.kernel_layout != "IOHW":
        logging.info(
            "TOPI conv2d_transpose does not support input layout other than NCHW, "
            "and kernel layout other than IOHW, so cannot be legalized by TOPI"
        )
        return call                          # <-- silent bail (2)
    dilation = call.attrs.dilation
    if len(dilation) != 2 or any(d != 1 for d in dilation):
        logging.info(
            "TOPI conv2d_transpose does not support dilations other than 1, "
            "and thus cannot be legalized by TOPI"
        )
        return call                          # <-- silent bail (3)
    return bb.call_te(topi.nn.group_conv2d_transpose_nchw, ...)

Underlying TOPI primitives (python/tvm/topi/nn/conv2d_transpose.py,
conv1d_transpose.py, conv3d_transpose.py) do not accept a dilation
parameter and only target the NCHW / NCW / NCDHW data layouts.

Suggested fix

Two layers of fix, both useful independently:

(1) Short-term — fail loudly in the legalize. Replace each silent bail
with a structured error so users get an actionable message:

if any(d != 1 for d in dilation):
    raise tvm.error.OpNotImplemented(
        f"nn.conv2d_transpose with dilation={list(dilation)} is not yet "
        f"lowered by TOPI; please open an issue if you need this."
    )
if call.attrs.data_layout != "NCHW" or call.attrs.kernel_layout != "IOHW":
    raise tvm.error.OpNotImplemented(
        f"nn.conv2d_transpose with data_layout={call.attrs.data_layout!r}, "
        f"kernel_layout={call.attrs.kernel_layout!r} is not yet lowered by TOPI."
    )
if call.attrs.out_layout != call.attrs.data_layout:
    raise tvm.error.OpNotImplemented(
        f"nn.conv2d_transpose with out_layout != data_layout is not yet "
        f"lowered by TOPI."
    )

This is a small, mechanical change per branch (swap the logging.info call
and the return for a raise) and resolves the user-facing pain immediately.
The same edit applies symmetrically to _nn_conv1d_transpose
(NCW / IOW) and _nn_conv3d_transpose (NCDHW / IODHW).
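
Since the three legalize functions differ only in the supported layout strings, the checks could be factored into one shared helper. The following is a hedged sketch in plain Python. `check_conv_transpose_attrs` and its parameters are hypothetical names, and it raises the built-in `NotImplementedError` so that it runs without TVM. In the real patch the exception type and attribute access would follow the snippet above.

```python
def check_conv_transpose_attrs(
    op_name,
    data_layout,
    kernel_layout,
    out_layout,
    dilation,
    supported_data_layout,
    supported_kernel_layout,
):
    """Raise a descriptive error for configurations TOPI cannot lower.

    Hypothetical helper: all names here are illustrative, not TVM APIs.
    """
    if out_layout != data_layout:
        raise NotImplementedError(
            f"{op_name} with out_layout={out_layout!r} different from "
            f"data_layout={data_layout!r} is not yet lowered by TOPI."
        )
    if data_layout != supported_data_layout or kernel_layout != supported_kernel_layout:
        raise NotImplementedError(
            f"{op_name} with data_layout={data_layout!r}, "
            f"kernel_layout={kernel_layout!r} is not yet lowered by TOPI "
            f"(supported: {supported_data_layout!r} / {supported_kernel_layout!r})."
        )
    if any(d != 1 for d in dilation):
        raise NotImplementedError(
            f"{op_name} with dilation={list(dilation)} is not yet lowered by TOPI."
        )


# A supported configuration passes silently.
check_conv_transpose_attrs(
    "nn.conv2d_transpose", "NCHW", "IOHW", "NCHW", (1, 1), "NCHW", "IOHW"
)
```

Each legalize would then call the helper once with its own layout pair (NCW / IOW, NCHW / IOHW, NCDHW / IODHW), which keeps the three error messages consistent.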

(2) Longer-term — implement the missing configurations.

  • dilation > 1: pre-dilate the kernel before calling the existing
    transposed-conv compute. Output spatial dim is
    (in - 1) * stride - 2 * pad + dilation * (k - 1) + 1 + output_padding.
    The change in TOPI is small (add a dilation parameter, dilate the
    kernel) and consistent across 1D / 2D / 3D.
  • NHWC layout: either add a layout-sandwich inside the legalize
    (transpose → conv2d_transpose_nchw → transpose) or extend TOPI directly.
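
The kernel pre-dilation idea and the output-size formula above can be checked with a small pure-Python reference. This is a sketch under simplifying assumptions (single channel, 1D, output_padding = 0). The function names are illustrative and none of this is TOPI code.

```python
def out_dim(in_size, k, stride=1, pad=0, dilation=1, output_padding=0):
    """Output spatial size of a transposed convolution (formula from the text)."""
    return (in_size - 1) * stride - 2 * pad + dilation * (k - 1) + 1 + output_padding


def dilate_kernel(w, dilation):
    """Insert (dilation - 1) zeros between adjacent kernel taps."""
    if not w:
        return []
    out = [0.0] * (dilation * (len(w) - 1) + 1)
    for j, v in enumerate(w):
        out[j * dilation] = v
    return out


def conv1d_transpose(x, w, stride=1, pad=0, dilation=1):
    """Reference single-channel 1D transposed conv in scatter-add form.

    Assumes output_padding = 0; `pad` is cropped from both ends of the result.
    """
    full_len = (len(x) - 1) * stride + dilation * (len(w) - 1) + 1
    full = [0.0] * full_len
    for i, xv in enumerate(x):
        for j, wv in enumerate(w):
            full[i * stride + j * dilation] += xv * wv
    return full[pad:full_len - pad]


x = [1.0, 2.0, 3.0, 4.0, 5.0]
w = [1.0, -1.0, 2.0]

# Dilated transposed conv computed directly...
direct = conv1d_transpose(x, w, stride=2, pad=1, dilation=2)
# ...and via an undilated transposed conv on the pre-dilated kernel.
via_predilation = conv1d_transpose(x, dilate_kernel(w, 2), stride=2, pad=1, dilation=1)
```

For reproduction case A (input 5, kernel 3, stride 1, pad 0, dilation 2), `out_dim(5, 3, dilation=2)` gives 9 per spatial axis.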

Environment

  • TVM: latest main (commit b172d5e)
  • Python: 3.11

cc @junrushao

Labels: needs-triage, type: bug