Skip to content
Open
4 changes: 4 additions & 0 deletions reproduce_issue.spy
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def main() -> None:
tup = (1, (2, 3))
a, (b, c) = tup
print(a + b + c)
Comment thread
Deadpool2000 marked this conversation as resolved.
46 changes: 33 additions & 13 deletions spy/analyze/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,44 +365,56 @@ def declare_AugAssign(self, augassign: ast.AugAssign) -> None:
self.declare(augassign.value)
Comment thread
Deadpool2000 marked this conversation as resolved.

def declare_UnpackAssign(self, unpack: ast.UnpackAssign) -> None:
for target in unpack.targets:
self._declare_target_maybe(target, unpack.value)
self._declare_unpack_targets(unpack.targets, unpack.value)
self.declare(unpack.value)

def _declare_unpack_targets(
self, targets: list[ast.Expr], value: ast.Expr
) -> None:
for target in targets:
if isinstance(target, ast.StrConst):
self._declare_target_maybe(target, value)
elif isinstance(target, ast.Tuple):
self._declare_unpack_targets(target.items, value)
else:
assert False, "WTF?"

def declare_AssignExpr(self, assignexpr: ast.AssignExpr) -> None:
self._declare_target_maybe(assignexpr.target, assignexpr.value)
self.declare(assignexpr.value)

def _declare_target_maybe(self, target: ast.StrConst, value: ast.Expr) -> None:
# if target name does not exist elsewhere, we treat it as an implicit
# declaration
level, scope, sym = self.lookup_ref(target.value)
def _declare_target_maybe(
self, target: ast.StrConst, value: ast.Expr
) -> None:
varname = target.value
level, scope, sym = self.lookup_ref(varname)
if sym is None:
# First assignment: mark as const unless in a loop
type_loc = value.loc
if self.loop_depth > 0:
varkind: VarKind = "var"
else:
varkind = "const"
self.define_name(target.value, varkind, "auto", target.loc, type_loc)
self.define_name(varname, varkind, "auto", target.loc, type_loc)
else:
# possible second assignment: promote to var if needed
self._promote_const_to_var_maybe(target)

def _promote_const_to_var_maybe(self, target: ast.StrConst) -> None:
level, scope, sym = self.lookup_ref(target.value)
varname = target.value
level, scope, sym = self.lookup_ref(varname)
if (
sym
and sym.is_local
and sym.varkind == "const"
and sym.varkind_origin == "auto"
):
if target.value in self.scope._symbols:
if varname in self.scope._symbols:
# Second assignment to a local const: make it var
old_sym = self.scope._symbols[target.value]
old_sym = self.scope._symbols[varname]
if old_sym.varkind == "const":
new_sym = old_sym.replace(varkind="var")
self.scope._symbols[target.value] = new_sym
self.scope._symbols[varname] = new_sym

def declare_While(self, whilestmt: ast.While) -> None:
# Increment loop depth before processing body
Expand Down Expand Up @@ -553,10 +565,18 @@ def flatten_Tuple(self, tup: ast.Tuple) -> None:

def flatten_UnpackAssign(self, unpack: ast.UnpackAssign) -> None:
self.mod_scope.implicit_imports.add("_tuple")
for target in unpack.targets:
self.flatten(target)
self._flatten_unpack_targets(unpack.targets)
self.flatten(unpack.value)

def _flatten_unpack_targets(self, targets: list[ast.Expr]) -> None:
for target in targets:
if isinstance(target, ast.StrConst):
self.capture_maybe(target.value)
elif isinstance(target, ast.Tuple):
self._flatten_unpack_targets(target.items)
else:
assert False, "WTF?"

def flatten_Dict(self, dict: ast.Dict) -> None:
self.mod_scope.implicit_imports.add("_dict")
for keyVal in dict.items:
Expand Down
2 changes: 1 addition & 1 deletion spy/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,7 @@ class Assign(Stmt):

@astnode
class UnpackAssign(Stmt):
targets: list[StrConst]
targets: list[StrConst | Tuple]
value: Expr


Expand Down
50 changes: 38 additions & 12 deletions spy/backend/c/cwriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,29 +175,55 @@ def emit_stmt_AssignCell(self, assign: ast.AssignCell) -> None:
def emit_stmt_UnpackAssign(self, unpack: ast.UnpackAssign) -> None:
if isinstance(unpack.value, ast.Tuple):
# Blue tuple literal: directly assign each item to its target
for target, item in zip(unpack.targets, unpack.value.items):
c_target = C_Ident(target.value)
v = self.fmt_expr(item)
self.tbc.wl(f"{c_target} = {v};")
self._emit_unpack_tuple_literal(unpack.targets, unpack.value.items)
else:
# Red tuple (struct): we save the result into a tmp variable and the assign
# all fields one by one. The code look like this more or less:
# Red tuple (struct): we save the result into a tmp variable and then assign
# all fields one by one. For example:
# {
# T tmp = some_expression()
# a = tmp._item0;
# b = tmp._item1;
# struct_type tmp = ...;
# a = tmp._item0;
# b = tmp._item1;
# }
assert unpack.value.w_T is not None
c_tuple_type = self.ctx.w2c(unpack.value.w_T)
v = self.fmt_expr(unpack.value)
self.tbc.wl("{")
with self.tbc.indent():
self.tbc.wl(f"{c_tuple_type} tmp = {v};")
for i, target in enumerate(unpack.targets):
c_target = C_Ident(target.value)
self.tbc.wl(f"{c_target} = tmp._item{i};")
self._emit_unpack_struct(unpack.targets, C.Literal("tmp"))
self.tbc.wl("}")

def _get_varname(self, target: ast.StrConst) -> str:
assert isinstance(target, ast.StrConst)
return target.value

def _emit_unpack_tuple_literal(
self, targets: list[ast.Expr], items: list[ast.Expr]
) -> None:
for target, item in zip(targets, items):
if isinstance(target, ast.StrConst):
varname = self._get_varname(target)
c_target = C_Ident(varname)
v = self.fmt_expr(item)
self.tbc.wl(f"{c_target} = {v};")
elif isinstance(target, ast.Tuple):
assert isinstance(item, ast.Tuple)
self._emit_unpack_tuple_literal(target.items, item.items)
else:
assert False, "WTF?"

def _emit_unpack_struct(self, targets: list[ast.Expr], c_value: C.Expr) -> None:
for i, target in enumerate(targets):
c_item = C.Dot(c_value, f"_item{i}")
if isinstance(target, ast.StrConst):
varname = self._get_varname(target)
c_target = C_Ident(varname)
self.tbc.wl(f"{c_target} = {c_item};")
elif isinstance(target, ast.Tuple):
self._emit_unpack_struct(target.items, c_item)
else:
assert False, "WTF?"

def emit_stmt_StmtExpr(self, stmt: ast.StmtExpr) -> None:
v = self.fmt_expr(stmt.value)
if v is C.Void():
Expand Down
11 changes: 10 additions & 1 deletion spy/backend/spy.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,19 @@ def emit_stmt_AugAssign(self, node: ast.AugAssign) -> None:
self.wl(f"{varname} {op}= {v}")

def emit_stmt_UnpackAssign(self, unpack: ast.UnpackAssign) -> None:
targets = ", ".join([t.value for t in unpack.targets])
targets = ", ".join([self.fmt_unpack_target(t) for t in unpack.targets])
v = self.fmt_expr(unpack.value)
self.wl(f"{targets} = {v}")

def fmt_unpack_target(self, target: ast.Expr) -> str:
if isinstance(target, ast.StrConst):
return target.value
elif isinstance(target, ast.Tuple):
items = [self.fmt_unpack_target(t) for t in target.items]
return "(" + ", ".join(items) + ")"
else:
assert False, "WTF?"

Comment thread
Deadpool2000 marked this conversation as resolved.
def emit_stmt_SetAttr(self, node: ast.SetAttr) -> None:
t = self.fmt_expr(node.target)
a = node.attr.value
Expand Down
14 changes: 10 additions & 4 deletions spy/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,16 +539,22 @@ def from_py_stmt_Assign(self, py_node: py_ast.Assign) -> spy.ast.Stmt:
value=value,
)
elif isinstance(py_target, py_ast.Tuple):
targets = []
for item in py_target.elts:
assert isinstance(item, py_ast.Name)
targets.append(spy.ast.StrConst(item.loc, item.id))
targets = [self._from_py_unpack_target(item) for item in py_target.elts]
return spy.ast.UnpackAssign(
loc=py_node.loc, targets=targets, value=self.from_py_expr(py_node.value)
)
else:
self.unsupported(py_target, "assign to complex expressions")

def _from_py_unpack_target(self, py_node: py_ast.expr) -> spy.ast.Expr:
if isinstance(py_node, py_ast.Name):
return spy.ast.StrConst(py_node.loc, py_node.id)
elif isinstance(py_node, py_ast.Tuple):
items = [self._from_py_unpack_target(item) for item in py_node.elts]
return spy.ast.Tuple(py_node.loc, items)
else:
self.unsupported(py_node, "complex unpacking target")

Comment thread
Deadpool2000 marked this conversation as resolved.
def from_py_stmt_AugAssign(self, py_node: py_ast.AugAssign) -> spy.ast.AugAssign:
py_target = py_node.target
if isinstance(py_target, py_ast.Name):
Expand Down
27 changes: 27 additions & 0 deletions spy/tests/compiler/test_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,3 +96,30 @@ def foo() -> None:
("this is `i32`", "42"),
)
self.compile_raises(src, "foo", errors)

def test_nested_unpack_basic(self):
mod = self.compile("""
def main() -> i32:
tup = (1, (2, 3))
a, (b, c) = tup
return a + b + c
""")
assert mod.main() == 6

def test_nested_unpack_deep(self):
mod = self.compile("""
def main() -> i32:
tup = (1, (2, (3, 4)))
a, (b, (c, d)) = tup
return a + b + c + d
""")
assert mod.main() == 10

def test_nested_unpack_multiple(self):
mod = self.compile("""
def main() -> i32:
tup = ((1, 2), (3, 4))
(a, b), (c, d) = tup
return a + b + c + d
""")
assert mod.main() == 10
8 changes: 8 additions & 0 deletions spy/tests/test_backend_spy.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ def foo() -> None:
self.compile(src)
self.assert_dump(src)

def test_nested_unpack_assign(self):
src = """
def foo() -> None:
a, (b, c) = x
"""
self.compile(src)
self.assert_dump(src)

def test_aug_assign(self):
src = """
def foo() -> None:
Expand Down
33 changes: 33 additions & 0 deletions spy/tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,39 @@ def foo() -> None:
"""
self.assert_dump(stmt, expected)

def test_NestedUnpackAssign(self):
mod = self.parse("""
def foo() -> None:
a, (b, c) = x
""")
stmt = mod.get_funcdef("foo").body[0]
expected = """
UnpackAssign(
targets=[
StrConst(value='a'),
Tuple(
items=[
StrConst(value='b'),
StrConst(value='c'),
],
),
],
value=Name(id='x'),
)
"""
self.assert_dump(stmt, expected)

def test_UnpackAssign_error(self):
src = """
def foo() -> None:
a, b.c = x
"""
self.expect_errors(
src,
"not implemented yet: complex unpacking target",
("this is not supported", "b.c"),
)

def test_Call(self):
mod = self.parse("""
def foo() -> i32:
Expand Down
35 changes: 35 additions & 0 deletions spy/tests/test_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,6 +715,41 @@ def foo() -> None:
scope = scopes.by_module()
assert scope.implicit_imports == {"_tuple"}

def test_unpack_assign(self):
scopes = self.analyze("""
def foo() -> None:
a, b = (1, 2)
c, (d, e) = (3, (4, 5))
f, (g, (h, i)) = (6, (7, (8, 9)))
""")
funcdef = self.mod.get_funcdef("foo")
scope = scopes.by_funcdef(funcdef)
assert scope._symbols["a"] == MatchSymbol("a", "const", "auto")
assert scope._symbols["b"] == MatchSymbol("b", "const", "auto")
assert scope._symbols["c"] == MatchSymbol("c", "const", "auto")
assert scope._symbols["d"] == MatchSymbol("d", "const", "auto")
assert scope._symbols["e"] == MatchSymbol("e", "const", "auto")
assert scope._symbols["f"] == MatchSymbol("f", "const", "auto")
assert scope._symbols["g"] == MatchSymbol("g", "const", "auto")
assert scope._symbols["h"] == MatchSymbol("h", "const", "auto")
assert scope._symbols["i"] == MatchSymbol("i", "const", "auto")

def test_unpack_assign_var(self):
scopes = self.analyze("""
def foo() -> None:
a, b = (1, 2)
a, b = (3, 4)
c, (d, e) = (5, (6, 7))
d = 8
""")
funcdef = self.mod.get_funcdef("foo")
scope = scopes.by_funcdef(funcdef)
assert scope._symbols["a"] == MatchSymbol("a", "var", "auto")
assert scope._symbols["b"] == MatchSymbol("b", "var", "auto")
assert scope._symbols["c"] == MatchSymbol("c", "const", "auto")
assert scope._symbols["d"] == MatchSymbol("d", "var", "auto")
assert scope._symbols["e"] == MatchSymbol("e", "const", "auto")

def test_dict_literal(self):
scopes = self.analyze("""
def foo() -> None:
Expand Down
24 changes: 17 additions & 7 deletions spy/vm/astframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,13 +636,23 @@ def exec_stmt_UnpackAssign(self, unpack: ast.UnpackAssign) -> None:
value=unpack.value,
args=[ast.Constant(loc=unpack.value.loc, value=i)],
)
# fabricate an ast.Assign
# XXX: ideally we should cache the specialization instead of
# rebuilding it at every exec
assign = self._specialize_Assign(
ast.Assign(loc=unpack.loc, target=target, value=expr)
)
self.exec_stmt(assign)
if isinstance(target, ast.StrConst):
# fabricate an ast.Assign
# XXX: ideally we should cache the specialization instead of
# rebuilding it at every exec
assign = self._specialize_Assign(
ast.Assign(loc=unpack.loc, target=target, value=expr)
)
self.exec_stmt(assign)
elif isinstance(target, ast.Tuple):
new_unpack = ast.UnpackAssign(
loc=unpack.loc,
targets=target.items,
value=expr
)
self.exec_stmt_UnpackAssign(new_unpack)
else:
assert False, "WTF?"

def exec_stmt_AugAssign(self, node: ast.AugAssign) -> None:
# XXX: eventually we want to support things like __IADD__ etc, but for
Expand Down
6 changes: 6 additions & 0 deletions test_side_effects.spy
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
def f() -> None:
print("f called")
return (1, 2)

def main() -> None:
a, b = f()
Comment thread
Deadpool2000 marked this conversation as resolved.