Skip to content

Commit 72c2746

Browse files
authored
fix nested matches (#126)
Co-authored-by: Tom Wambsgans <[email protected]>
1 parent 028cb7a commit 72c2746

File tree

2 files changed

+165
-5
lines changed

2 files changed

+165
-5
lines changed

crates/lean_compiler/src/b_compile_intermediate.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ struct Compiler {
99
match_blocks: Vec<MatchBlock>,
1010
if_counter: usize,
1111
call_counter: usize,
12+
match_counter: usize,
1213
func_name: String,
1314
stack_frame_layout: StackFrameLayout,
1415
args_count: usize,
@@ -210,8 +211,9 @@ fn compile_lines(
210211
SimpleLine::Match { value, arms } => {
211212
compiler.stack_frame_layout.scopes.push(ScopeLayout::default());
212213

213-
let match_index = compiler.match_blocks.len();
214-
let end_label = Label::match_end(match_index);
214+
let label_id = compiler.match_counter;
215+
compiler.match_counter += 1;
216+
let end_label = Label::match_end(label_id);
215217

216218
let value_simplified = IntermediateValue::from_simple_expr(value, compiler);
217219

@@ -232,6 +234,8 @@ fn compile_lines(
232234
function_name: function_name.clone(),
233235
match_cases: compiled_arms,
234236
});
237+
// Get the actual index AFTER pushing (nested matches may have pushed their blocks first)
238+
let match_index = compiler.match_blocks.len() - 1;
235239

236240
let value_scaled_offset = IntermediateValue::MemoryAfterFp {
237241
offset: compiler.stack_pos.into(),
@@ -263,9 +267,9 @@ fn compile_lines(
263267
compiler.bytecode.insert(end_label, remaining);
264268

265269
compiler.stack_frame_layout.scopes.pop();
266-
compiler.stack_pos = saved_stack_pos;
267-
// It is not necessary to update compiler.stack_size here because the preceding call to
268-
// compile lines should have done so.
270+
// Don't reset stack_pos here - we need to preserve space for the temps we allocated.
271+
// Nested matches would otherwise reuse the same temp positions, causing conflicts.
272+
// This is consistent with IfNotZero which also doesn't reset stack_pos.
269273

270274
return Ok(instructions);
271275
}

crates/lean_compiler/tests/test_compiler.rs

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1455,3 +1455,159 @@ fn test_len_2d_array() {
14551455
false,
14561456
);
14571457
}
1458+
1459+
#[test]
1460+
fn test_nested_matches() {
1461+
let program = r#"
1462+
fn main() {
1463+
assert test_func(0, 0) == 6;
1464+
assert test_func(1, 0) == 3;
1465+
return;
1466+
}
1467+
1468+
fn test_func(a, b) -> 1 {
1469+
x = 1;
1470+
1471+
var mut_x_2;
1472+
match a {
1473+
0 => {
1474+
var mut_x_1;
1475+
mut_x_1 = x + 2;
1476+
match b {
1477+
0 => {
1478+
mut_x_2 = mut_x_1 + 3;
1479+
}
1480+
}
1481+
}
1482+
1 => {
1483+
mut_x_2 = x + 2;
1484+
}
1485+
}
1486+
1487+
return mut_x_2;
1488+
}
1489+
"#;
1490+
compile_and_run(
1491+
&ProgramSource::Raw(program.to_string()),
1492+
(&[], &[]),
1493+
DEFAULT_NO_VEC_RUNTIME_MEMORY,
1494+
false,
1495+
);
1496+
}
1497+
1498+
#[test]
1499+
fn test_deeply_nested_match() {
1500+
// Test with 3 levels of nesting, multiple arms, and variables at each level
1501+
let program = r#"
1502+
fn main() {
1503+
// Test each combination with expected values
1504+
// (0,0,0): base=1000, local_a=5, local_b=8, inner_val=1008
1505+
assert compute(0, 0, 0) == 1008;
1506+
// (0,0,1): base=1000, local_a=5, local_b=8, inner_val=1009
1507+
assert compute(0, 0, 1) == 1009;
1508+
// (0,1,0): base=1000, local_a=5, local_b=12, inner_val=1012
1509+
assert compute(0, 1, 0) == 1012;
1510+
// (0,1,1): base=1000, local_a=5, local_b=12, inner_val=1013
1511+
assert compute(0, 1, 1) == 1013;
1512+
// (1,0,0): base=1000, local_a=16, local_b=36, inner_val=1036
1513+
assert compute(1, 0, 0) == 1036;
1514+
// (1,0,1): base=1000, local_a=16, local_b=36, inner_val=1037
1515+
assert compute(1, 0, 1) == 1037;
1516+
// (1,1,0): base=1000, local_a=16, local_b=46, inner_val=1046
1517+
assert compute(1, 1, 0) == 1046;
1518+
// (1,1,1): base=1000, local_a=16, local_b=46, inner_val=1047
1519+
assert compute(1, 1, 1) == 1047;
1520+
return;
1521+
}
1522+
1523+
fn compute(a, b, c) -> 1 {
1524+
base = 1000;
1525+
var outer_val;
1526+
var mid_val;
1527+
var inner_val;
1528+
1529+
match a {
1530+
0 => {
1531+
outer_val = 5;
1532+
var local_a;
1533+
local_a = a + outer_val; // local_a = 5
1534+
1535+
match b {
1536+
0 => {
1537+
mid_val = 3;
1538+
var local_b;
1539+
local_b = local_a + mid_val; // local_b = 8
1540+
1541+
match c {
1542+
0 => {
1543+
inner_val = base + local_b + c; // 1000 + 8 + 0 = 1008
1544+
}
1545+
1 => {
1546+
inner_val = base + local_b + c; // 1000 + 8 + 1 = 1009
1547+
}
1548+
}
1549+
}
1550+
1 => {
1551+
mid_val = 7;
1552+
var local_b;
1553+
local_b = local_a + mid_val; // local_b = 12
1554+
1555+
match c {
1556+
0 => {
1557+
inner_val = base + local_b + c; // 1000 + 12 + 0 = 1012
1558+
}
1559+
1 => {
1560+
inner_val = base + local_b + c; // 1000 + 12 + 1 = 1013
1561+
}
1562+
}
1563+
}
1564+
}
1565+
}
1566+
1 => {
1567+
outer_val = 15;
1568+
var local_a;
1569+
local_a = a + outer_val; // local_a = 16
1570+
1571+
match b {
1572+
0 => {
1573+
mid_val = 20;
1574+
var local_b;
1575+
local_b = local_a + mid_val; // local_b = 36
1576+
1577+
match c {
1578+
0 => {
1579+
inner_val = base + local_b + c; // 1000 + 36 + 0 = 1036
1580+
}
1581+
1 => {
1582+
inner_val = base + local_b + c; // 1000 + 36 + 1 = 1037
1583+
}
1584+
}
1585+
}
1586+
1 => {
1587+
mid_val = 30;
1588+
var local_b;
1589+
local_b = local_a + mid_val; // local_b = 46
1590+
1591+
match c {
1592+
0 => {
1593+
inner_val = base + local_b + c; // 1000 + 46 + 0 = 1046
1594+
}
1595+
1 => {
1596+
inner_val = base + local_b + c; // 1000 + 46 + 1 = 1047
1597+
}
1598+
}
1599+
}
1600+
}
1601+
}
1602+
}
1603+
1604+
return inner_val;
1605+
}
1606+
"#;
1607+
compile_and_run(
1608+
&ProgramSource::Raw(program.to_string()),
1609+
(&[], &[]),
1610+
DEFAULT_NO_VEC_RUNTIME_MEMORY,
1611+
false,
1612+
);
1613+
}

0 commit comments

Comments
 (0)