Skip to content

Commit ac63024

Browse files
authored
Merge pull request #65 from underworldcode/bugfix/fix-swarm-cache
Fix swarm cache invalidation after parallel migration
2 parents 9971667 + ba3f050 commit ac63024

File tree

3 files changed

+116
-8
lines changed

3 files changed

+116
-8
lines changed

src/underworld3/function/_function.pyx

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,12 @@ def global_evaluate_nd( expr,
502502
evaluation_swarm.dm.migrate(remove_sent_points=True)
503503
uw.mpi.barrier()
504504

505+
# Invalidate cached data after bare-bones dm.migrate —
506+
# particle count and values changed but Swarm.migrate() was bypassed.
507+
evaluation_swarm._particle_coordinates._canonical_data = None
508+
for var in evaluation_swarm._vars.values():
509+
if hasattr(var, "_canonical_data"):
510+
var._canonical_data = None
505511

506512
index = original_index.array[:,0,0]
507513

src/underworld3/swarm.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3315,14 +3315,13 @@ def migrate(
33153315
for index in indices:
33163316
self.dm.removePointAtIndex(index)
33173317

3318-
# CRITICAL FIX: Invalidate cached data after removing particles
3319-
# The _particle_coordinates variable caches data - must refresh after DM changes
3320-
self._particle_coordinates._canonical_data = None
3321-
3322-
# Also invalidate caches for all swarm variables
3323-
for var in self._vars.values():
3324-
if hasattr(var, "_canonical_data"):
3325-
var._canonical_data = None
3318+
# Invalidate all cached data after migration.
3319+
# Any particle movement (send, receive, or balanced swap) makes
3320+
# cached arrays stale — both size and values may have changed.
3321+
self._particle_coordinates._canonical_data = None
3322+
for var in self._vars.values():
3323+
if hasattr(var, "_canonical_data"):
3324+
var._canonical_data = None
33263325

33273326
return
33283327

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
Regression test for swarm cache invalidation after migration.
3+
4+
Verifies that SwarmVariable._canonical_data caches are properly invalidated
5+
after Swarm.migrate() and after bare-bones dm.migrate() in global_evaluate.
6+
7+
Bug: SwarmVariable caches were only invalidated inside the delete_lost_points
8+
branch of Swarm.migrate(), so caches became stale when particles moved between
9+
ranks without deletion. This caused shape mismatches in global_evaluate.
10+
11+
See: https://github.com/underworldcode/underworld3/issues/64
12+
13+
Run with:
14+
mpirun -n 2 python -m pytest --with-mpi tests/parallel/test_0760_swarm_cache_migration.py
15+
"""
16+
17+
import pytest
18+
import numpy as np
19+
import underworld3 as uw
20+
from mpi4py import MPI
21+
22+
pytestmark = [pytest.mark.mpi(min_size=2), pytest.mark.timeout(60)]
23+
24+
25+
@pytest.mark.mpi(min_size=2)
26+
@pytest.mark.level_1
27+
@pytest.mark.tier_a
28+
def test_swarm_cache_valid_after_migration():
29+
"""Swarm variable caches must reflect actual particle count after migration."""
30+
mesh = uw.meshing.UnstructuredSimplexBox(minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0))
31+
swarm = uw.swarm.Swarm(mesh)
32+
var = uw.swarm.SwarmVariable("test_var", swarm, vtype=uw.VarType.SCALAR, _proxy=False)
33+
34+
# Add particles at random positions — distribution will be uneven across ranks
35+
np.random.seed(42 + uw.mpi.rank)
36+
coords = np.random.random((200, mesh.dim))
37+
swarm.add_particles_with_global_coordinates(coords, migrate=False)
38+
var.data[...] = uw.mpi.rank
39+
40+
pre_count = swarm.dm.getLocalSize()
41+
42+
# Migrate — particles move to owning rank
43+
swarm.migrate(remove_sent_points=True, delete_lost_points=False)
44+
45+
post_count = swarm.dm.getLocalSize()
46+
coords_cached = swarm._particle_coordinates.data.shape[0]
47+
var_cached = var.data.shape[0]
48+
49+
# Cached sizes must match the actual DM particle count
50+
assert coords_cached == post_count, (
51+
f"Rank {uw.mpi.rank}: coordinate cache ({coords_cached}) != "
52+
f"DM count ({post_count}) after migration"
53+
)
54+
assert var_cached == post_count, (
55+
f"Rank {uw.mpi.rank}: variable cache ({var_cached}) != "
56+
f"DM count ({post_count}) after migration"
57+
)
58+
59+
60+
@pytest.mark.mpi(min_size=2)
61+
@pytest.mark.level_1
62+
@pytest.mark.tier_a
63+
def test_global_evaluate_after_migration():
64+
"""global_evaluate must succeed with coordinates that force heavy migration."""
65+
mesh = uw.meshing.UnstructuredSimplexBox(minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0))
66+
v = uw.discretisation.MeshVariable("u", mesh, mesh.dim, degree=2)
67+
68+
# Bias coordinates to one side — forces cross-rank particle movement
69+
np.random.seed(42)
70+
N = 300
71+
coords = np.random.random((N, mesh.dim))
72+
coords[:, 0] = 0.5 + 0.5 * coords[:, 0] # all in right half
73+
74+
result = uw.function.global_evaluate(v.sym, coords)
75+
76+
assert result.shape[0] == N, (
77+
f"Rank {uw.mpi.rank}: expected {N} results, got {result.shape[0]}"
78+
)
79+
80+
81+
@pytest.mark.mpi(min_size=2)
82+
@pytest.mark.level_1
83+
@pytest.mark.tier_a
84+
def test_global_evaluate_displaced_nodes():
85+
"""global_evaluate with displaced node coordinates (DDt/SemiLagrangian path)."""
86+
mesh = uw.meshing.UnstructuredSimplexBox(minCoords=(0.0, 0.0), maxCoords=(1.0, 1.0))
87+
v = uw.discretisation.MeshVariable("u", mesh, mesh.dim, degree=2)
88+
89+
# Displace node coordinates — simulates semi-Lagrangian departure points
90+
node_coords = mesh.X.coords
91+
np.random.seed(7)
92+
displacement = np.random.random(node_coords.shape) * 0.3
93+
mid_pt_coords = node_coords - displacement
94+
95+
# Clamp to domain
96+
mid_pt_coords = np.clip(mid_pt_coords, 0.0, 1.0)
97+
98+
result = uw.function.evaluate(v.sym, mid_pt_coords)
99+
100+
assert result.shape[0] == node_coords.shape[0], (
101+
f"Rank {uw.mpi.rank}: expected {node_coords.shape[0]} results, "
102+
f"got {result.shape[0]}"
103+
)

0 commit comments

Comments
 (0)