@@ -298,6 +298,10 @@ class Deinterleaver : public IRGraphMutator {
298298 } else {
299299
300300 Type t = op->type .with_lanes (new_lanes);
301+ internal_assert ((op->type .lanes () - starting_lane + lane_stride - 1 ) / lane_stride == new_lanes)
302+ << " Deinterleaving with lane stride " << lane_stride << " and staring lane " << starting_lane
303+ << " for var of Type " << op->type << " to " << t << " drops lanes unexpectedly."
304+ << " Deinterleaver probably recursed too deep into types of different lane count." ;
301305 if (external_lets.contains (op->name ) &&
302306 starting_lane == 0 &&
303307 lane_stride == 2 ) {
@@ -392,8 +396,12 @@ class Deinterleaver : public IRGraphMutator {
392396 int index = indices.front ();
393397 for (const auto &i : op->vectors ) {
394398 if (index < i.type ().lanes ()) {
395- ScopedValue<int > lane (starting_lane, index);
396- return mutate (i);
399+ if (i.type ().lanes () == op->type .lanes ()) {
400+ ScopedValue<int > scoped_starting_lane (starting_lane, index);
401+ return mutate (i);
402+ } else {
403+ return Shuffle::make (op->vectors , indices);
404+ }
397405 }
398406 index -= i.type ().lanes ();
399407 }
@@ -405,10 +413,18 @@ class Deinterleaver : public IRGraphMutator {
405413};
406414
407415Expr deinterleave (Expr e, int starting_lane, int lane_stride, int new_lanes, const Scope<> &lets) {
416+ debug (3 ) << " Deinterleave "
417+ << " (start:" << starting_lane << " , stide:" << lane_stride << " , new_lanes:" << new_lanes << " ): "
418+ << e << " of Type: " << e.type () << " \n " ;
419+ Type original_type = e.type ();
408420 e = substitute_in_all_lets (e);
409421 Deinterleaver d (starting_lane, lane_stride, new_lanes, lets);
410422 e = d.mutate (e);
411423 e = common_subexpression_elimination (e);
424+ Type final_type = e.type ();
425+ int expected_lanes = (original_type.lanes () + lane_stride - starting_lane - 1 ) / lane_stride;
426+ internal_assert (original_type.code () == final_type.code ()) << " Underlying types not identical after interleaving." ;
427+ internal_assert (expected_lanes == final_type.lanes ()) << " Number of lanes incorrect after interleaving: " << final_type.lanes () << " while expected was " << expected_lanes << " ." ;
412428 return simplify (e);
413429}
414430
@@ -419,12 +435,12 @@ Expr extract_odd_lanes(const Expr &e, const Scope<> &lets) {
419435
420436Expr extract_even_lanes (const Expr &e, const Scope<> &lets) {
421437 internal_assert (e.type ().lanes () % 2 == 0 );
422- return deinterleave (e, 0 , 2 , ( e.type ().lanes () + 1 ) / 2 , lets);
438+ return deinterleave (e, 0 , 2 , e.type ().lanes () / 2 , lets);
423439}
424440
425441Expr extract_mod3_lanes (const Expr &e, int lane, const Scope<> &lets) {
426442 internal_assert (e.type ().lanes () % 3 == 0 );
427- return deinterleave (e, lane, 3 , ( e.type ().lanes () + 2 ) / 3 , lets);
443+ return deinterleave (e, lane, 3 , e.type ().lanes () / 3 , lets);
428444}
429445
430446} // namespace
0 commit comments