Skip to content

Commit 0c7adc2

Browse files
authored
[Relax] Add FRelaxInferLayout for scatter_elements operator (#18638)
## Why

The scatter_elements operator was missing FRelaxInferLayout support, which prevented proper layout transformation when it was used together with operators like conv2d that require layout conversion.

## How

- Implement an InferLayoutScatterElements function that handles layout inference for scatter_elements.
- Transform the axis attribute according to the inferred layout using FindAxis.
- Fall back to the initial layout when the inferred layout is sub-indexed.
- Add a test case covering conv2d + scatter_elements layout conversion.
1 parent a95b90c commit 0c7adc2

File tree

3 files changed

+92
-8
lines changed

3 files changed

+92
-8
lines changed

src/relax/op/tensor/manipulate.cc

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2613,14 +2613,43 @@ StructInfo InferStructInfoScatterElements(const Call& call, const BlockBuilder&
26132613
return TensorStructInfo(data_sinfo->dtype, data_sinfo->ndim, data_sinfo->vdevice);
26142614
}
26152615

2616-
// TODO(relax-team): implement FRelaxInferLayout for scatter_elements
2616+
InferLayoutOutput InferLayoutScatterElements(
2617+
const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
2618+
const VarLayoutMap& var_layout_map) {
2619+
ICHECK(NoDesiredLayout(call, desired_layouts));
2620+
const auto* attrs = call->attrs.as<ScatterElementsAttrs>();
2621+
ICHECK(attrs) << "Invalid Call";
2622+
2623+
LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
2624+
LayoutDecision indices_layout = GetLayoutDecision(var_layout_map, call->args[1]);
2625+
LayoutDecision updates_layout = GetLayoutDecision(var_layout_map, call->args[2]);
2626+
2627+
LayoutDecision layout = data_layout;
2628+
if (NLayoutEqual()(indices_layout, updates_layout)) {
2629+
layout = indices_layout;
2630+
}
2631+
2632+
if (layout->layout.ndim() != layout->layout.ndim_primal()) {
2633+
const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
2634+
ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
2635+
ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
2636+
int ndim = tensor_sinfo->ndim;
2637+
layout = LayoutDecision(InitialLayout(ndim));
2638+
}
2639+
2640+
ObjectPtr<ScatterElementsAttrs> new_attrs = ffi::make_object<ScatterElementsAttrs>(*attrs);
2641+
new_attrs->axis = FindAxis(layout->layout, attrs->axis->value);
2642+
return InferLayoutOutput({layout, layout, layout}, {layout}, Attrs(new_attrs));
2643+
}
2644+
26172645
// Operator registration for relax.scatter_elements: three tensor inputs
// (data, indices, updates), struct-info and layout inference hooks, and
// purity marking (no side effects).
TVM_REGISTER_OP("relax.scatter_elements")
    .set_attrs_type<ScatterElementsAttrs>()
    .set_num_inputs(3)
    .add_argument("data", "Tensor", "The input tensor.")
    .add_argument("indices", "Tensor", "The indices tensor.")
    .add_argument("updates", "Tensor", "The input tensor of updates.")
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoScatterElements)
    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutScatterElements)
    .set_attr<Bool>("FPurity", Bool(true));
26252654

26262655
/* relax.scatter_nd */

tests/python/contrib/test_msc/test_graph_build.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2443,22 +2443,22 @@ def forward(self, data, index, src):
24432443

24442444
# Expected graph with the index tensor folded as a constant: two graph
# inputs plus one constant node. Layouts are now "AB" (identity, rank 2)
# because scatter_elements participates in layout inference — TODO confirm
# this matches the FRelaxInferLayout registration for the op.
expected1 = {
    "inputs": [
        {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"},
        {"name": "inp_1", "shape": [2, 5], "dtype": "float32", "layout": "AB"},
    ],
    "outputs": [
        {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}
    ],
    "nodes": {"total": 4, "input": 2, "constant": 1, "scatter_elements": 1},
}
# Expected graph when the index tensor is a third graph input (int64),
# so there is no constant node.
expected2 = {
    "inputs": [
        {"name": "inp_0", "shape": [bz, 20], "dtype": "float32", "layout": "AB"},
        {"name": "inp_1", "shape": [2, 5], "dtype": "int64", "layout": "AB"},
        {"name": "inp_2", "shape": [2, 5], "dtype": "float32", "layout": "AB"},
    ],
    "outputs": [
        {"name": "scatter_elements", "shape": [bz, 20], "dtype": "float32", "layout": "AB"}
    ],
    "nodes": {"total": 4, "input": 3, "scatter_elements": 1},
}

tests/python/relax/test_transform_convert_layout.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5327,5 +5327,60 @@ def main(
53275327
verify(Input, Expected)
53285328

53295329

5330+
def test_conv2d_scatter_elements():
    """ConvertLayout on conv2d followed by scatter_elements.

    conv2d is converted to NHWC; the scatter_elements consumer must then
    infer the NHWC layout, permute its `indices` operand to match, and
    remap axis=1 (channel in NCHW) to axis=3 (channel in NHWC).
    """

    # Input module in the default NCHW layout.
    @I.ir_module
    class Input:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), "float32"),
            w: R.Tensor((4, 3, 3, 3), "float32"),
            indices: R.Tensor((2, 4, 26, 26), "int64"),
        ) -> R.Tensor(None, "float32", ndim=4):
            with R.dataflow():
                data: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
                updates: R.Tensor((2, 4, 26, 26), "float32") = R.nn.relu(data)
                # Scatter along the channel axis (axis=1 in NCHW).
                gv = R.scatter_elements(data, indices, updates, axis=1)
                R.output(gv)
            return gv

    # Expected module after ConvertLayout to NHWC.
    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((2, 3, 28, 28), dtype="float32"),
            w: R.Tensor((4, 3, 3, 3), dtype="float32"),
            indices: R.Tensor((2, 4, 26, 26), dtype="int64"),
        ) -> R.Tensor(None, dtype="float32", ndim=4):
            with R.dataflow():
                # Permute activations NCHW->NHWC and weights OIHW->OHWI.
                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
                data: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
                    lv,
                    lv1,
                    strides=[1, 1],
                    padding=[0, 0, 0, 0],
                    dilation=[1, 1],
                    groups=1,
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                    out_layout="NHWC",
                    out_dtype="float32",
                )
                updates: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.relu(data)
                # The indices operand must follow the same NCHW->NHWC permutation.
                lv2: R.Tensor((2, 26, 26, 4), dtype="int64") = R.permute_dims(
                    indices, axes=[0, 2, 3, 1]
                )
                # Channel axis moved from 1 (NCHW) to 3 (NHWC).
                lv3: R.Tensor((2, 26, 26, 4), dtype="float32") = R.scatter_elements(
                    data, lv2, updates, axis=3, reduction="update"
                )
                # Permute the result back so the function signature is unchanged.
                gv: R.Tensor((2, 4, 26, 26), dtype="float32") = R.permute_dims(
                    lv3, axes=[0, 3, 1, 2]
                )
                R.output(gv)
            return gv

    verify(Input, Expected)
53305385
# Allow running this test file directly via the tvm.testing entry point.
if __name__ == "__main__":
    tvm.testing.main()

0 commit comments

Comments
 (0)