Skip to content

Commit 8081815

Browse files
shimwellpaulromano
andauthored
Add RegularMesh.get_indices_at_coords method (#3824)
Co-authored-by: Paul Romano <[email protected]>
1 parent 54c8e3d commit 8081815

File tree

2 files changed

+130
-2
lines changed

2 files changed

+130
-2
lines changed

openmc/mesh.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,10 @@ def _axis_labels(self):
547547
def _grids(self):
548548
pass
549549

550+
@abstractmethod
551+
def get_indices_at_coords(self, coords: Sequence[float]) -> tuple:
552+
pass
553+
550554
@property
551555
def vertices(self):
552556
"""Return coordinates of mesh vertices in Cartesian coordinates. Also
@@ -1432,6 +1436,47 @@ def build_cells(self, bc: str | None = None):
14321436

14331437
return root_cell, cells
14341438

1439+
def get_indices_at_coords(self, coords: Sequence[float]) -> tuple:
1440+
"""Finds the index of the mesh element at the specified coordinates.
1441+
1442+
.. versionadded:: 0.15.4
1443+
1444+
Parameters
1445+
----------
1446+
coords : Sequence[float]
1447+
Cartesian coordinates of the point.
1448+
1449+
Returns
1450+
-------
1451+
tuple
1452+
Mesh indices matching the dimensionality of the mesh
1453+
1454+
"""
1455+
ndim = self.n_dimension
1456+
if len(coords) < ndim:
1457+
raise ValueError(
1458+
f"coords must have at least {ndim} values for a "
1459+
f"{ndim}D mesh, got {len(coords)}"
1460+
)
1461+
1462+
coords_array = np.array(coords[:ndim])
1463+
lower_left = np.array(self.lower_left)
1464+
upper_right = np.array(self.upper_right)
1465+
dimension = np.array(self.dimension)
1466+
1467+
if np.any(coords_array < lower_left) or np.any(coords_array > upper_right):
1468+
raise ValueError(
1469+
f"coords {tuple(coords_array)} are outside mesh bounds "
1470+
f"[{tuple(lower_left)}, {tuple(upper_right)}]"
1471+
)
1472+
1473+
# Calculate spacing for each dimension
1474+
spacing = (upper_right - lower_left) / dimension
1475+
1476+
# Calculate indices for each coordinate
1477+
indices = np.floor((coords_array - lower_left) / spacing).astype(int)
1478+
return tuple(int(i) for i in indices[:ndim])
1479+
14351480

14361481
def Mesh(*args, **kwargs):
14371482
warnings.warn("Mesh has been renamed RegularMesh. Future versions of "
@@ -1643,6 +1688,11 @@ def to_xml_element(self):
16431688

16441689
return element
16451690

1691+
def get_indices_at_coords(self, coords: Sequence[float]) -> tuple:
1692+
raise NotImplementedError(
1693+
"get_indices_at_coords is not yet implemented for RectilinearMesh"
1694+
)
1695+
16461696

16471697
class CylindricalMesh(StructuredMesh):
16481698
"""A 3D cylindrical mesh
@@ -1835,14 +1885,14 @@ def get_indices_at_coords(
18351885
self,
18361886
coords: Sequence[float]
18371887
) -> tuple[int, int, int]:
1838-
"""Finds the index of the mesh voxel at the specified x,y,z coordinates.
1888+
"""Finds the index of the mesh element at the specified coordinates.
18391889
18401890
.. versionadded:: 0.15.0
18411891
18421892
Parameters
18431893
----------
18441894
coords : Sequence[float]
1845-
The x, y, z axis coordinates
1895+
Cartesian coordinates of the point.
18461896
18471897
Returns
18481898
-------
@@ -2478,6 +2528,11 @@ def _convert_to_cartesian(arr, origin: Sequence[float]):
24782528
arr[..., 2] = z + origin[2]
24792529
return arr
24802530

2531+
def get_indices_at_coords(self, coords: Sequence[float]) -> tuple:
2532+
raise NotImplementedError(
2533+
"get_indices_at_coords is not yet implemented for SphericalMesh"
2534+
)
2535+
24812536

24822537
def require_statepoint_data(func):
24832538
@wraps(func)

tests/unit_tests/test_mesh.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,3 +920,76 @@ def test_filter_time_mesh(run_in_tmpdir):
920920
f"Collision vs tracklength tallies disagree: chi2={chi2_stat:.2f} "
921921
f">= {crit=:.2f} ({dof=}, {alpha=})"
922922
)
923+
924+
925+
def test_regular_mesh_get_indices_at_coords():
926+
"""Test get_indices_at_coords method for RegularMesh"""
927+
# Create a 10x10x10 mesh from (0,0,0) to (1,1,1)
928+
# Each voxel is 0.1 x 0.1 x 0.1
929+
mesh = openmc.RegularMesh()
930+
mesh.lower_left = (0, 0, 0)
931+
mesh.upper_right = (1, 1, 1)
932+
mesh.dimension = [10, 10, 10]
933+
934+
# Test lower-left corner maps to first voxel (0, 0, 0)
935+
assert mesh.get_indices_at_coords([0.0, 0.0, 0.0]) == (0, 0, 0)
936+
937+
# Test centroid of first voxel
938+
# Voxel 0 spans [0.0, 0.1], so centroid is at 0.05
939+
assert mesh.get_indices_at_coords([0.05, 0.05, 0.05]) == (0, 0, 0)
940+
941+
# Test centroid of last voxel maps correctly
942+
# Voxel 9 spans [0.9, 1.0], so centroid is at 0.95
943+
assert mesh.get_indices_at_coords([0.95, 0.95, 0.95]) == (9, 9, 9)
944+
945+
# Test a middle voxel
946+
# Voxel 4 spans [0.4, 0.5], so 0.45 should map to it
947+
assert mesh.get_indices_at_coords([0.45, 0.45, 0.45]) == (4, 4, 4)
948+
949+
# Test mixed indices
950+
assert mesh.get_indices_at_coords([0.05, 0.45, 0.95]) == (0, 4, 9)
951+
assert mesh.get_indices_at_coords([0.95, 0.05, 0.45]) == (9, 0, 4)
952+
953+
# Test coordinates outside mesh bounds raise ValueError
954+
with pytest.raises(ValueError):
955+
mesh.get_indices_at_coords([-0.5, 0.5, 0.5])
956+
with pytest.raises(ValueError):
957+
mesh.get_indices_at_coords([1.5, 0.5, 0.5])
958+
with pytest.raises(ValueError):
959+
mesh.get_indices_at_coords([0.5, -0.5, 0.5])
960+
with pytest.raises(ValueError):
961+
mesh.get_indices_at_coords([0.5, 1.5, 0.5])
962+
with pytest.raises(ValueError):
963+
mesh.get_indices_at_coords([0.5, 0.5, -0.5])
964+
with pytest.raises(ValueError):
965+
mesh.get_indices_at_coords([0.5, 0.5, 1.5])
966+
967+
# Test that results match expected dimensionality (3D mesh returns 3-tuple)
968+
result = mesh.get_indices_at_coords([0.5, 0.5, 0.5])
969+
assert isinstance(result, tuple)
970+
assert len(result) == 3
971+
972+
# Test that indices can be used directly with centroids array
973+
idx = mesh.get_indices_at_coords([0.95, 0.95, 0.95])
974+
centroid = mesh.centroids[idx]
975+
np.testing.assert_array_almost_equal(centroid, [0.95, 0.95, 0.95])
976+
977+
# Test with a 2D mesh
978+
mesh_2d = openmc.RegularMesh()
979+
mesh_2d.lower_left = (0, 0)
980+
mesh_2d.upper_right = (1, 1)
981+
mesh_2d.dimension = [10, 10]
982+
result_2d = mesh_2d.get_indices_at_coords([0.5, 0.5, 999.0])
983+
assert isinstance(result_2d, tuple)
984+
assert len(result_2d) == 2
985+
assert result_2d == (5, 5)
986+
987+
# Test with a 1D mesh
988+
mesh_1d = openmc.RegularMesh()
989+
mesh_1d.lower_left = [0]
990+
mesh_1d.upper_right = [1]
991+
mesh_1d.dimension = [10]
992+
result_1d = mesh_1d.get_indices_at_coords([0.5, 999.0, 999.0])
993+
assert isinstance(result_1d, tuple)
994+
assert len(result_1d) == 1
995+
assert result_1d == (5,)

0 commit comments

Comments
 (0)