22import copy
33import inspect
44from dataclasses import is_dataclass
5- from typing import Any , List , Optional
5+ from itertools import product
6+ from typing import Any , Dict , List , Optional , Tuple
67
7- from func_adl .util_ast import lambda_build
8+ from func_adl .util_ast import as_ast , lambda_build
89
910
1011def resolve_syntatic_sugar (a : ast .AST ) -> ast .AST :
@@ -22,6 +23,118 @@ def resolve_syntatic_sugar(a: ast.AST) -> ast.AST:
2223 """
2324
2425 class syntax_transformer (ast .NodeTransformer ):
26+ def _extract_literal_iterable (self , node : ast .AST ) -> Optional [List [ast .expr ]]:
27+ """Return literal iterable elements if ``node`` is a list/tuple literal."""
28+
29+ if isinstance (node , (ast .List , ast .Tuple )):
30+ return list (node .elts )
31+ if isinstance (node , ast .Constant ) and isinstance (node .value , (list , tuple )):
32+ return [as_ast (v ) for v in node .value ]
33+ return None
34+
35+ def _target_bindings (
36+ self , target : ast .AST , value : ast .AST , node : ast .AST
37+ ) -> Optional [Dict [str , ast .expr ]]:
38+ """Build loop-variable bindings for a single comprehension iteration.
39+
40+ Returns ``None`` when destructuring cannot be applied for this ``value``.
41+ """
42+
43+ if isinstance (target , ast .Name ):
44+ return {target .id : copy .deepcopy (value )}
45+
46+ if isinstance (target , (ast .Tuple , ast .List )):
47+ if not isinstance (value , (ast .Tuple , ast .List )):
48+ return None
49+ if len (target .elts ) != len (value .elts ):
50+ raise ValueError (
51+ "Comprehension unpacking length mismatch" f" - { ast .unparse (node )} "
52+ )
53+
54+ bindings : Dict [str , ast .expr ] = {}
55+ for target_elt , value_elt in zip (target .elts , value .elts ):
56+ child_bindings = self ._target_bindings (target_elt , value_elt , node )
57+ if child_bindings is None :
58+ return None
59+ bindings .update (child_bindings )
60+ return bindings
61+
62+ raise ValueError (
63+ f"Comprehension variable must be a name or tuple/list, but found { target } "
64+ f" - { ast .unparse (node )} "
65+ )
66+
67+ def _substitute_names (self , expr : ast .expr , bindings : Dict [str , ast .expr ]) -> ast .expr :
68+ class _name_replacer (ast .NodeTransformer ):
69+ def __init__ (self , loop_bindings : Dict [str , ast .expr ]):
70+ self ._loop_bindings = loop_bindings
71+
72+ def visit_Name (self , replace_node : ast .Name ) -> Any :
73+ if (
74+ isinstance (replace_node .ctx , ast .Load )
75+ and replace_node .id in self ._loop_bindings
76+ ):
77+ return copy .deepcopy (self ._loop_bindings [replace_node .id ])
78+ return replace_node
79+
80+ return _name_replacer (bindings ).visit (copy .deepcopy (expr ))
81+
82+ def _inline_literal_comprehension (
83+ self , lambda_body : ast .expr , generators : List [ast .comprehension ], node : ast .AST
84+ ) -> Optional [List [ast .expr ]]:
85+ """Expand comprehensions over literal iterables into literal expressions."""
86+
87+ literal_values : List [List [Tuple [Dict [str , ast .expr ], List [ast .expr ]]]] = []
88+ for generator in generators :
89+ if generator .is_async :
90+ raise ValueError (f"Comprehension can't be async - { ast .unparse (node )} ." )
91+
92+ iter_values = self ._extract_literal_iterable (generator .iter )
93+ if iter_values is None :
94+ return None
95+
96+ generator_values : List [Tuple [Dict [str , ast .expr ], List [ast .expr ]]] = []
97+ for iter_value in iter_values :
98+ bindings = self ._target_bindings (generator .target , iter_value , node )
99+ if bindings is None :
100+ return None
101+ generator_values .append ((bindings , generator .ifs ))
102+ literal_values .append (generator_values )
103+
104+ if len (literal_values ) == 0 :
105+ return []
106+
107+ expanded : List [ast .expr ] = []
108+ for combo in product (* literal_values ):
109+ merged_bindings : Dict [str , ast .expr ] = {}
110+ all_ifs : List [ast .expr ] = []
111+ for c_bindings , c_ifs in combo :
112+ merged_bindings .update (c_bindings )
113+ all_ifs .extend (c_ifs )
114+
115+ include_item = True
116+ for if_clause in all_ifs :
117+ rendered_if = self .visit (self ._substitute_names (if_clause , merged_bindings ))
118+ if not isinstance (rendered_if , ast .Constant ) or not isinstance (
119+ rendered_if .value , bool
120+ ):
121+ raise ValueError (
122+ "Literal comprehension if-clause must resolve to a bool constant"
123+ f" - { ast .unparse (if_clause )} "
124+ )
125+ if not rendered_if .value :
126+ include_item = False
127+ break
128+
129+ if include_item :
130+ rendered_item = self .visit (
131+ self ._substitute_names (lambda_body , merged_bindings )
132+ )
133+ assert isinstance (rendered_item , ast .expr )
134+ expanded .append (rendered_item )
135+
136+ return expanded
137+
25138 def _resolve_any_all_call (
26139 self , call_node : ast .Call , source_node : ast .AST
27140 ) -> Optional [ast .AST ]:
@@ -44,6 +157,8 @@ def _resolve_any_all_call(
44157 )
45158
46159 sequence = call_node .args [0 ]
160+ if isinstance (sequence , (ast .ListComp , ast .GeneratorExp )):
161+ return None
47162 if not isinstance (sequence , (ast .List , ast .Tuple )):
48163 raise ValueError (
49164 f"{ func_name } requires a list or tuple literal argument"
@@ -77,10 +192,8 @@ def resolve_generator(
77192 for c in reversed (generators ):
78193 target = c .target
79194 if not isinstance (target , ast .Name ):
80- raise ValueError (
81- f"Comprehension variable must be a name, but found { target } "
82- f" - { ast .unparse (node )} ."
83- )
195+ # Keep original comprehension for unsupported lowering cases.
196+ return node
84197 if c .is_async :
85198 raise ValueError (f"Comprehension can't be async - { ast .unparse (node )} ." )
86199 source_collection = c .iter
@@ -110,6 +223,10 @@ def visit_ListComp(self, node: ast.ListComp) -> Any:
110223 a = self .generic_visit (node )
111224
112225 if isinstance (a , ast .ListComp ):
226+ if (
227+ expanded := self ._inline_literal_comprehension (a .elt , a .generators , node )
228+ ) is not None :
229+ return ast .List (elts = expanded , ctx = ast .Load ())
113230 a = self .resolve_generator (a .elt , a .generators , node )
114231
115232 return a
@@ -119,6 +236,10 @@ def visit_GeneratorExp(self, node: ast.GeneratorExp) -> Any:
119236 a = self .generic_visit (node )
120237
121238 if isinstance (a , ast .GeneratorExp ):
239+ if (
240+ expanded := self ._inline_literal_comprehension (a .elt , a .generators , node )
241+ ) is not None :
242+ return ast .List (elts = expanded , ctx = ast .Load ())
122243 a = self .resolve_generator (a .elt , a .generators , node )
123244
124245 return a
0 commit comments