Skip to content

Commit a6b0c7a

Browse files
committed
Cleanup analysis.py
1 parent d3296f1 commit a6b0c7a

2 files changed

Lines changed: 19 additions & 23 deletions

File tree

ufl/algorithms/analysis.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -69,33 +69,30 @@ def extract_type(a, ufl_types):
6969
objects = set()
7070
arg_types = tuple(t for t in ufl_types if issubclass(t, BaseArgument))
7171
if arg_types:
72-
objects.update([e for e in a.arguments() if isinstance(e, arg_types)])
72+
objects.update(e for e in a.arguments() if isinstance(e, arg_types))
7373
coeff_types = tuple(t for t in ufl_types if issubclass(t, BaseCoefficient))
7474
if coeff_types:
75-
objects.update([e for e in a.coefficients() if isinstance(e, coeff_types)])
75+
objects.update(e for e in a.coefficients() if isinstance(e, coeff_types))
7676
return objects
7777

7878
if all(issubclass(t, Terminal) for t in ufl_types):
7979
# Optimization
80-
objects = set(
81-
o
82-
for e in iter_expressions(a)
83-
for o in traverse_unique_terminals(e)
84-
if any(isinstance(o, t) for t in ufl_types)
85-
)
80+
traversal = traverse_unique_terminals
8681
else:
87-
objects = set(
88-
o
89-
for e in iter_expressions(a)
90-
for o in unique_pre_traversal(e)
91-
if any(isinstance(o, t) for t in ufl_types)
92-
)
82+
traversal = unique_pre_traversal
83+
84+
objects = set(
85+
o
86+
for e in iter_expressions(a)
87+
for o in traversal(e)
88+
if isinstance(o, ufl_types)
89+
)
9390

9491
# Need to extract objects contained in base form operators whose
9592
# type is in ufl_types
9693
base_form_ops = set(e for e in objects if isinstance(e, BaseFormOperator))
9794
ufl_types_no_args = tuple(t for t in ufl_types if not issubclass(t, BaseArgument))
98-
base_form_objects = ()
95+
base_form_objects = []
9996
for o in base_form_ops:
10097
# This accounts for having BaseFormOperator in Forms: if N is a BaseFormOperator
10198
# `N(u; v*) * v * dx` <=> `action(v1 * v * dx, N(...; v*))`
@@ -106,17 +103,17 @@ def extract_type(a, ufl_types):
106103
# argument of the Coargument and not its primal argument.
107104
if isinstance(ai, Coargument):
108105
new_types = tuple(Coargument if t is BaseArgument else t for t in ufl_types)
109-
base_form_objects += tuple(extract_type(ai, new_types))
106+
base_form_objects.extend(extract_type(ai, new_types))
110107
else:
111-
base_form_objects += tuple(extract_type(ai, ufl_types))
108+
base_form_objects.extend(extract_type(ai, ufl_types))
112109
# Look for BaseArguments in BaseFormOperator's argument slots
113110
# only since that's where they are by definition. Don't look
114111
# into operands, which is convenient for external operator
115112
# composition, e.g. N1(N2; v*) where N2 is seen as an operator
116113
# and not a form.
117114
slots = o.ufl_operands
118115
for ai in slots:
119-
base_form_objects += tuple(extract_type(ai, ufl_types_no_args))
116+
base_form_objects.extend(extract_type(ai, ufl_types_no_args))
120117
objects.update(base_form_objects)
121118

122119
# `Remove BaseFormOperator` objects if there were initially not in `ufl_types`
@@ -249,10 +246,10 @@ def extract_unique_elements(form):
249246

250247
def extract_sub_elements(elements):
251248
"""Build sorted tuple of all sub elements (including parent element)."""
252-
sub_elements = tuple(chain(*[e.sub_elements for e in elements]))
249+
sub_elements = tuple(chain(*(e.sub_elements for e in elements)))
253250
if not sub_elements:
254251
return tuple(elements)
255-
return tuple(elements) + extract_sub_elements(sub_elements)
252+
return (*elements, *extract_sub_elements(sub_elements))
256253

257254

258255
def sort_elements(elements):
@@ -268,7 +265,7 @@ def sort_elements(elements):
268265
nodes = list(elements)
269266

270267
# Set edges
271-
edges = dict((node, []) for node in nodes)
268+
edges = {node: [] for node in nodes}
272269
for element in elements:
273270
for sub_element in element.sub_elements:
274271
edges[element].append(sub_element)

ufl/checks.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,7 @@ def is_cellwise_constant(expr):
4242
if is_cellwise_constant(expr):
4343
return True
4444
elif isinstance(expr, SpatialCoordinate):
45-
element = expr.ufl_domain().ufl_coordinate_element()
46-
return element.embedded_superdegree <= 1
45+
return expr.ufl_domain().is_piecewise_linear_simplex_domain()
4746
elif isinstance(expr, Coefficient):
4847
element = expr.ufl_element()
4948
return element.embedded_superdegree <= 1

0 commit comments

Comments
 (0)