Skip to content

Commit a3d91be

Browse files
committed
Refine literal comprehension sugar behavior and docs
1 parent 4053a9f commit a3d91be

File tree

4 files changed

+174
-11
lines changed

4 files changed

+174
-11
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,18 @@ There are several python expressions and idioms that are translated behind your
5757
--- | --- | --- |
5858
|List Comprehension | `[j.pt() for j in jets]` | `jets.Select(lambda j: j.pt())` |
5959
|List Comprehension | `[j.pt() for j in jets if abs(j.eta()) < 2.4]` | `jets.Where(lambda j: abs(j.eta()) < 2.4).Select(lambda j: j.pt())` |
60+
|Literal List Comprehension|`[i for i in [1, 2, 3]]`|`[1, 2, 3]`|
6061
| Data Classes<br>(typed) | `@dataclass`<br>`class my_data:`<br>`x: ObjectStream[Jets]`<br><br>`Select(lambda e: my_data(x=e.Jets()).x)` | `Select(lambda e: {'x': e.Jets()}.x)` |
6162
| Named Tuple<br>(typed) | `class my_data(NamedTuple):`<br>`x: ObjectStream[Jets]`<br><br>`Select(lambda e: my_data(x=e.Jets()).x)` | `Select(lambda e: {'x': e.Jets()}.x)` |
6263
|List Membership|`p.absPdgId() in [35, 51]`|`p.absPdgId() == 35 or p.absPdgId() == 51`|
63-
| `any`/`all` | `any(e.pt()>10, abs(e.eta()) < 2.5` | `e.pt() > 10 | abs(e.eta()) < 2.5` |
64+
| `any`/`all` | `any([e.pt() > 10, abs(e.eta()) < 2.5])` | `e.pt() > 10 or abs(e.eta()) < 2.5` |
6465

6566
Note: Everything that goes for a list comprehension also goes for a generator expression.
6667

68+
For `any`/`all`, generator/list comprehensions over a literal (or captured literal constant)
69+
are first expanded to a literal list and then reduced as usual. For example,
70+
`any(f(a) for a in [1, 2])` is treated like `any([f(1), f(2)])`.
71+
6772
## Extensibility
6873

6974
There are several extensibility points:

docs/source/generic/query_structure.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,15 @@ This can be continued to deeper and deeper levels within the data. For example,
3434

3535
Due to the flexible nature of FuncADL there are multiple ways to structure each query. Throughout this documentation different structures will be used for the sake of demonstration.
3636

37+
## Syntactic Sugar
38+
39+
Inside query lambdas, FuncADL also rewrites a few common Python forms into query-friendly
40+
expressions:
41+
42+
- List/generator comprehensions over streams are lowered to `.Where(...)`/`.Select(...)`.
43+
- List/generator comprehensions over literal iterables are expanded directly. For example,
44+
`[i for i in [1, 2, 3]]` becomes `[1, 2, 3]`.
45+
- `any`/`all` over literal lists/tuples are reduced to boolean `or`/`and` expressions.
46+
47+
This means patterns like `any(expr(x) for x in LITERAL_LIST)` can be simplified in-query,
48+
as long as the iterable is a literal (or a captured literal constant).

func_adl/ast/syntatic_sugar.py

Lines changed: 127 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import copy
33
import inspect
44
from dataclasses import is_dataclass
5-
from typing import Any, List, Optional
5+
from itertools import product
6+
from typing import Any, Dict, List, Optional, Tuple
67

7-
from func_adl.util_ast import lambda_build
8+
from func_adl.util_ast import as_ast, lambda_build
89

910

1011
def resolve_syntatic_sugar(a: ast.AST) -> ast.AST:
@@ -22,6 +23,118 @@ def resolve_syntatic_sugar(a: ast.AST) -> ast.AST:
2223
"""
2324

2425
class syntax_transformer(ast.NodeTransformer):
26+
def _extract_literal_iterable(self, node: ast.AST) -> Optional[List[ast.expr]]:
27+
"""Return literal iterable elements if ``node`` is a list/tuple literal."""
28+
29+
if isinstance(node, (ast.List, ast.Tuple)):
30+
return list(node.elts)
31+
if isinstance(node, ast.Constant) and isinstance(node.value, (list, tuple)):
32+
return [as_ast(v) for v in node.value]
33+
return None
34+
35+
def _target_bindings(
36+
self, target: ast.AST, value: ast.AST, node: ast.AST
37+
) -> Optional[Dict[str, ast.expr]]:
38+
"""Build loop-variable bindings for a single comprehension iteration.
39+
40+
Returns ``None`` when destructuring cannot be applied for this ``value``.
41+
"""
42+
43+
if isinstance(target, ast.Name):
44+
return {target.id: copy.deepcopy(value)}
45+
46+
if isinstance(target, (ast.Tuple, ast.List)):
47+
if not isinstance(value, (ast.Tuple, ast.List)):
48+
return None
49+
if len(target.elts) != len(value.elts):
50+
raise ValueError(
51+
"Comprehension unpacking length mismatch" f" - {ast.unparse(node)}"
52+
)
53+
54+
bindings: Dict[str, ast.expr] = {}
55+
for target_elt, value_elt in zip(target.elts, value.elts):
56+
child_bindings = self._target_bindings(target_elt, value_elt, node)
57+
if child_bindings is None:
58+
return None
59+
bindings.update(child_bindings)
60+
return bindings
61+
62+
raise ValueError(
63+
f"Comprehension variable must be a name or tuple/list, but found {target}"
64+
f" - {ast.unparse(node)}"
65+
)
66+
67+
def _substitute_names(self, expr: ast.expr, bindings: Dict[str, ast.expr]) -> ast.expr:
68+
class _name_replacer(ast.NodeTransformer):
69+
def __init__(self, loop_bindings: Dict[str, ast.expr]):
70+
self._loop_bindings = loop_bindings
71+
72+
def visit_Name(self, replace_node: ast.Name) -> Any:
73+
if (
74+
isinstance(replace_node.ctx, ast.Load)
75+
and replace_node.id in self._loop_bindings
76+
):
77+
return copy.deepcopy(self._loop_bindings[replace_node.id])
78+
return replace_node
79+
80+
return _name_replacer(bindings).visit(copy.deepcopy(expr))
81+
82+
def _inline_literal_comprehension(
83+
self, lambda_body: ast.expr, generators: List[ast.comprehension], node: ast.AST
84+
) -> Optional[List[ast.expr]]:
85+
"""Expand comprehensions over literal iterables into literal expressions."""
86+
87+
literal_values: List[List[Tuple[Dict[str, ast.expr], List[ast.expr]]]] = []
88+
for generator in generators:
89+
if generator.is_async:
90+
raise ValueError(f"Comprehension can't be async - {ast.unparse(node)}.")
91+
92+
iter_values = self._extract_literal_iterable(generator.iter)
93+
if iter_values is None:
94+
return None
95+
96+
generator_values: List[Tuple[Dict[str, ast.expr], List[ast.expr]]] = []
97+
for iter_value in iter_values:
98+
bindings = self._target_bindings(generator.target, iter_value, node)
99+
if bindings is None:
100+
return None
101+
generator_values.append((bindings, generator.ifs))
102+
literal_values.append(generator_values)
103+
104+
if len(literal_values) == 0:
105+
return []
106+
107+
expanded: List[ast.expr] = []
108+
for combo in product(*literal_values):
109+
merged_bindings: Dict[str, ast.expr] = {}
110+
all_ifs: List[ast.expr] = []
111+
for c_bindings, c_ifs in combo:
112+
merged_bindings.update(c_bindings)
113+
all_ifs.extend(c_ifs)
114+
115+
include_item = True
116+
for if_clause in all_ifs:
117+
rendered_if = self.visit(self._substitute_names(if_clause, merged_bindings))
118+
if not isinstance(rendered_if, ast.Constant) or not isinstance(
119+
rendered_if.value, bool
120+
):
121+
raise ValueError(
122+
"Literal comprehension if-clause must resolve to a bool constant"
123+
f" - {ast.unparse(if_clause)}"
124+
)
125+
if not rendered_if.value:
126+
include_item = False
127+
break
128+
129+
if include_item:
130+
rendered_item = self.visit(
131+
self._substitute_names(lambda_body, merged_bindings)
132+
)
133+
assert isinstance(rendered_item, ast.expr)
134+
expanded.append(rendered_item)
135+
136+
return expanded
137+
25138
def _resolve_any_all_call(
26139
self, call_node: ast.Call, source_node: ast.AST
27140
) -> Optional[ast.AST]:
@@ -44,6 +157,8 @@ def _resolve_any_all_call(
44157
)
45158

46159
sequence = call_node.args[0]
160+
if isinstance(sequence, (ast.ListComp, ast.GeneratorExp)):
161+
return None
47162
if not isinstance(sequence, (ast.List, ast.Tuple)):
48163
raise ValueError(
49164
f"{func_name} requires a list or tuple literal argument"
@@ -77,10 +192,8 @@ def resolve_generator(
77192
for c in reversed(generators):
78193
target = c.target
79194
if not isinstance(target, ast.Name):
80-
raise ValueError(
81-
f"Comprehension variable must be a name, but found {target}"
82-
f" - {ast.unparse(node)}."
83-
)
195+
# Keep original comprehension for unsupported lowering cases.
196+
return node
84197
if c.is_async:
85198
raise ValueError(f"Comprehension can't be async - {ast.unparse(node)}.")
86199
source_collection = c.iter
@@ -110,6 +223,10 @@ def visit_ListComp(self, node: ast.ListComp) -> Any:
110223
a = self.generic_visit(node)
111224

112225
if isinstance(a, ast.ListComp):
226+
if (
227+
expanded := self._inline_literal_comprehension(a.elt, a.generators, node)
228+
) is not None:
229+
return ast.List(elts=expanded, ctx=ast.Load())
113230
a = self.resolve_generator(a.elt, a.generators, node)
114231

115232
return a
@@ -119,6 +236,10 @@ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
119236
a = self.generic_visit(node)
120237

121238
if isinstance(a, ast.GeneratorExp):
239+
if (
240+
expanded := self._inline_literal_comprehension(a.elt, a.generators, node)
241+
) is not None:
242+
return ast.List(elts=expanded, ctx=ast.Load())
122243
a = self.resolve_generator(a.elt, a.generators, node)
123244

124245
return a

tests/ast/test_syntatic_sugar.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,13 @@ def test_resolve_generator():
3030
assert ast.dump(ast.parse("jets.Select(lambda j: j.pt())")) == ast.dump(a_new)
3131

3232

33+
def test_resolve_literal_list_comp():
34+
a = ast.parse("[i for i in [1, 2, 3]]")
35+
a_new = resolve_syntatic_sugar(a)
36+
37+
assert ast.dump(ast.parse("[1, 2, 3]")) == ast.dump(a_new)
38+
39+
3340
def test_resolve_listcomp_if():
3441
a = ast.parse("[j.pt() for j in jets if j.pt() > 100]")
3542
a_new = resolve_syntatic_sugar(a)
@@ -62,11 +69,11 @@ def test_resolve_2generator():
6269

6370
def test_resolve_bad_iterator():
6471
a = ast.parse("[j.pt() for idx,j in enumerate(jets)]")
72+
a_new = resolve_syntatic_sugar(a)
6573

66-
with pytest.raises(ValueError) as e:
67-
resolve_syntatic_sugar(a)
68-
69-
assert "name" in str(e)
74+
# Unsupported lowering (tuple target with non-literal source) should be
75+
# preserved for downstream processing.
76+
assert ast.unparse(a_new) == ast.unparse(a)
7077

7178

7279
def test_resolve_no_async():
@@ -396,3 +403,21 @@ def test_resolve_any_requires_literal_sequence():
396403

397404
with pytest.raises(ValueError, match="list or tuple literal"):
398405
resolve_syntatic_sugar(a)
406+
407+
408+
def test_resolve_any_generator_from_literal_capture():
409+
bib_triggers = [(1, 2), (3, 4)]
410+
411+
def tdt_chain_fired(chain: int) -> bool:
412+
return chain > 1
413+
414+
a = parse_as_ast(
415+
lambda e: any(
416+
tdt_chain_fired(incl_trig) and not tdt_chain_fired(bib_trig)
417+
for incl_trig, bib_trig in bib_triggers
418+
)
419+
)
420+
a_resolved = resolve_syntatic_sugar(a)
421+
422+
a_expected = ast.parse("lambda e: (1 > 1 and not (2 > 1)) or (3 > 1 and not (4 > 1))")
423+
assert ast.unparse(a_resolved) == ast.unparse(a_expected)

0 commit comments

Comments
 (0)