Skip to content

Commit 3b6f14d

Browse files
committed
Bugfix in Deinterleaver recursing into shuffle arguments with a different lane count.
1 parent 3a6398f commit 3b6f14d

File tree

3 files changed

+32
-6
lines changed

3 files changed

+32
-6
lines changed

src/Deinterleave.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,10 @@ class Deinterleaver : public IRGraphMutator {
298298
} else {
299299

300300
Type t = op->type.with_lanes(new_lanes);
301+
internal_assert((op->type.lanes() - starting_lane + lane_stride - 1) / lane_stride == new_lanes)
302+
<< "Deinterleaving with lane stride " << lane_stride << " and starting lane " << starting_lane
303+
<< " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly."
304+
<< " Deinterleaver probably recursed too deep into types of different lane count.";
301305
if (external_lets.contains(op->name) &&
302306
starting_lane == 0 &&
303307
lane_stride == 2) {
@@ -392,8 +396,12 @@ class Deinterleaver : public IRGraphMutator {
392396
int index = indices.front();
393397
for (const auto &i : op->vectors) {
394398
if (index < i.type().lanes()) {
395-
ScopedValue<int> lane(starting_lane, index);
396-
return mutate(i);
399+
if (i.type().lanes() == op->type.lanes()) {
400+
ScopedValue<int> scoped_starting_lane(starting_lane, index);
401+
return mutate(i);
402+
} else {
403+
return Shuffle::make(op->vectors, indices);
404+
}
397405
}
398406
index -= i.type().lanes();
399407
}
@@ -405,10 +413,18 @@ class Deinterleaver : public IRGraphMutator {
405413
};
406414

407415
Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) {
416+
debug(3) << "Deinterleave "
417+
<< "(start:" << starting_lane << ", stride:" << lane_stride << ", new_lanes:" << new_lanes << "): "
418+
<< e << " of Type: " << e.type() << "\n";
419+
Type original_type = e.type();
408420
e = substitute_in_all_lets(e);
409421
Deinterleaver d(starting_lane, lane_stride, new_lanes, lets);
410422
e = d.mutate(e);
411423
e = common_subexpression_elimination(e);
424+
Type final_type = e.type();
425+
int expected_lanes = (original_type.lanes() + lane_stride - starting_lane - 1) / lane_stride;
426+
internal_assert(original_type.code() == final_type.code()) << "Underlying types not identical after deinterleaving.";
427+
internal_assert(expected_lanes == final_type.lanes()) << "Number of lanes incorrect after deinterleaving: " << final_type.lanes() << " while expected was " << expected_lanes << ".";
412428
return simplify(e);
413429
}
414430

@@ -419,12 +435,12 @@ Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) {
419435

420436
Expr extract_even_lanes(const Expr &e, const Scope<> &lets) {
421437
internal_assert(e.type().lanes() % 2 == 0);
422-
return deinterleave(e, 0, 2, (e.type().lanes() + 1) / 2, lets);
438+
return deinterleave(e, 0, 2, e.type().lanes() / 2, lets);
423439
}
424440

425441
Expr extract_mod3_lanes(const Expr &e, int lane, const Scope<> &lets) {
426442
internal_assert(e.type().lanes() % 3 == 0);
427-
return deinterleave(e, lane, 3, (e.type().lanes() + 2) / 3, lets);
443+
return deinterleave(e, lane, 3, e.type().lanes() / 3, lets);
428444
}
429445

430446
} // namespace

src/Simplify_Let.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) {
9898
Expr new_var = Variable::make(f.new_value.type(), f.new_name);
9999
Expr replacement = new_var;
100100

101-
debug(4) << "simplify let " << op->name << " = " << f.value << " in...\n";
101+
debug(4) << "simplify let " << op->name << " = (" << f.value.type() << ") " << f.value << " in...\n";
102102

103103
while (true) {
104104
const Variable *var = f.new_value.template as<Variable>();

test/correctness/vector_shuffle.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ int main(int argc, char **argv) {
7171
printf("Testing vector size %d...\n", vec_size);
7272
std::vector<int> indices0, indices1;
7373

74-
// Test 1: All indices: foreward
74+
// Test 1: All indices: forward/backward and combined
7575
for (int i = 0; i < vec_size; ++i) {
7676
indices0.push_back(i); // forward
7777
indices1.push_back(vec_size - i - 1); // backward
@@ -137,6 +137,16 @@ int main(int argc, char **argv) {
137137
if (test_with_indices(target, indices0, indices1)) {
138138
return 1;
139139
}
140+
141+
if (vec_size == 4) {
142+
indices0 = {1, 3, 2, 0};
143+
indices1 = {2, 3, 1, 0};
144+
145+
printf(" Specific index combination, known to have caused problems...\n");
146+
if (test_with_indices(target, indices0, indices1)) {
147+
return 1;
148+
}
149+
}
140150
}
141151

142152
printf("Success!\n");

0 commit comments

Comments
 (0)