Skip to content

Commit a95b90c

Browse files
authored
[Relax] Add FRelaxInferLayout for flip operator (#18637)
## Why The flip operator lacked layout inference support, preventing it from participating in layout transformations during the ConvertLayout pass. ## How - Add InferLayoutFlip function that transforms the axis attribute according to the input layout - Register FRelaxInferLayout attribute for relax.flip operator - Add test case for conv2d followed by flip with layout conversion
1 parent 304a74a commit a95b90c

File tree

2 files changed

+77
-0
lines changed

2 files changed

+77
-0
lines changed

src/relax/op/tensor/manipulate.cc

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,11 +2047,44 @@ StructInfo InferStructInfoFlip(const Call& call, const BlockBuilder& ctx) {
20472047
return data_sinfo;
20482048
}
20492049

2050+
InferLayoutOutput InferLayoutFlip(
2051+
const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
2052+
const VarLayoutMap& var_layout_map) {
2053+
ICHECK(NoDesiredLayout(call, desired_layouts));
2054+
2055+
const auto* attrs = call->attrs.as<FlipAttrs>();
2056+
ICHECK(attrs != nullptr) << "Invalid Call";
2057+
const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
2058+
ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
2059+
ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
2060+
2061+
LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
2062+
int ndim = tensor_sinfo->ndim;
2063+
2064+
if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
2065+
existing_layout = LayoutDecision(InitialLayout(ndim));
2066+
}
2067+
2068+
int axis = attrs->axis.IntValue();
2069+
if (axis < 0) {
2070+
axis += ndim;
2071+
}
2072+
2073+
const int new_axis = FindAxis(existing_layout->layout, axis);
2074+
ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";
2075+
2076+
ObjectPtr<FlipAttrs> new_attrs = ffi::make_object<FlipAttrs>(*attrs);
2077+
new_attrs->axis = Integer(new_axis);
2078+
2079+
return InferLayoutOutput({existing_layout}, {existing_layout}, Attrs(new_attrs));
2080+
}
2081+
20502082
TVM_REGISTER_OP("relax.flip")
20512083
.set_attrs_type<FlipAttrs>()
20522084
.set_num_inputs(1)
20532085
.add_argument("data", "Tensor", "The input tensor.")
20542086
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoFlip)
2087+
.set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutFlip)
20552088
.set_attr<Bool>("FPurity", Bool(true));
20562089

20572090
/* relax.gather_elements */

tests/python/relax/test_transform_convert_layout.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5283,5 +5283,49 @@ def main(
52835283
verify(Input, Expected)
52845284

52855285

5286+
def test_conv2d_flip():
5287+
@I.ir_module
5288+
class Input:
5289+
@R.function
5290+
def main(
5291+
x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
5292+
) -> R.Tensor(None, "float32", ndim=4):
5293+
with R.dataflow():
5294+
gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
5295+
gv2: R.Tensor((2, 4, 26, 26), "float32") = R.flip(gv, axis=1)
5296+
R.output(gv2)
5297+
return gv2
5298+
5299+
@I.ir_module
5300+
class Expected:
5301+
@R.function
5302+
def main(
5303+
x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
5304+
) -> R.Tensor(None, dtype="float32", ndim=4):
5305+
with R.dataflow():
5306+
lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
5307+
lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
5308+
gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
5309+
lv,
5310+
lv1,
5311+
strides=[1, 1],
5312+
padding=[0, 0, 0, 0],
5313+
dilation=[1, 1],
5314+
groups=1,
5315+
data_layout="NHWC",
5316+
kernel_layout="OHWI",
5317+
out_layout="NHWC",
5318+
out_dtype="float32",
5319+
)
5320+
lv2: R.Tensor((2, 26, 26, 4), dtype="float32") = R.flip(gv, axis=3)
5321+
gv2: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
5322+
lv2, axes=[0, 3, 1, 2]
5323+
)
5324+
R.output(gv2)
5325+
return gv2
5326+
5327+
verify(Input, Expected)
5328+
5329+
52865330
if __name__ == "__main__":
52875331
tvm.testing.main()

0 commit comments

Comments
 (0)