Skip to content

Commit 7e4b43b

Browse files
kkollsgaclaudejsignell
authored
Fix sortby descending order placing NaNs at beginning instead of end (#11118)
Use duck_array_ops.notnull as additional sort keys to ensure null values sort to the end in descending order. This is cleaner than the previous approach of manually tracking NaN positions. Fixes #7358 Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Julia Signell <jsignell@gmail.com>
1 parent e993e1c commit 7e4b43b

File tree

3 files changed

+59
-2
lines changed

3 files changed

+59
-2
lines changed

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ Bug Fixes
4646
- Fix silent data corruption when writing dask arrays to sharded Zarr stores.
4747
Dask chunk boundaries must now align with shard boundaries, not just internal
4848
Zarr chunk boundaries (:issue:`10831`).
49+
- Fix :py:meth:`Dataset.sortby` and :py:meth:`DataArray.sortby` placing NaN values
50+
at the beginning instead of the end when using ``ascending=False`` (:issue:`7358`).
51+
By `Kristian Kollsgård <https://github.com/kkollsga>`_.
4952

5053
Documentation
5154
~~~~~~~~~~~~~

xarray/core/dataset.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8129,8 +8129,21 @@ def sortby(
81298129

81308130
indices = {}
81318131
for key, arrays in vars_by_dim.items():
8132-
order = np.lexsort(tuple(reversed(arrays)))
8133-
indices[key] = order if ascending else order[::-1]
8132+
if ascending:
8133+
indices[key] = np.lexsort(tuple(reversed(arrays)))
8134+
else:
8135+
# For descending order, we need to keep NaNs at the end.
8136+
# By adding notnull(arr) as additional sort keys, null values
8137+
# sort to the beginning (False=0 < True=1), then reversing
8138+
# puts them at the end. See https://github.com/pydata/xarray/issues/7358
8139+
indices[key] = np.lexsort(
8140+
tuple(
8141+
[
8142+
*reversed(arrays),
8143+
*[duck_array_ops.notnull(arr) for arr in reversed(arrays)],
8144+
]
8145+
)
8146+
)[::-1]
81348147
return aligned_self.isel(indices)
81358148

81368149
def quantile(

xarray/tests/test_dataset.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7232,6 +7232,47 @@ def test_sortby(self) -> None:
72327232
actual = ds.sortby(["x", "y"], ascending=False)
72337233
assert_equal(actual, ds)
72347234

7235+
def test_sortby_descending_nans(self) -> None:
7236+
# Regression test for https://github.com/pydata/xarray/issues/7358
7237+
# NaN values should remain at the end when sorting in descending order
7238+
ds = Dataset({"var": ("x", [3.0, np.nan, 4.0, 2.0, np.nan])})
7239+
7240+
# Ascending: NaNs at end
7241+
result_asc = ds.sortby("var", ascending=True)
7242+
assert_array_equal(result_asc["var"].values[:3], [2.0, 3.0, 4.0])
7243+
assert np.all(np.isnan(result_asc["var"].values[3:]))
7244+
7245+
# Descending: NaNs should also be at end (not beginning)
7246+
result_desc = ds.sortby("var", ascending=False)
7247+
assert_array_equal(result_desc["var"].values[:3], [4.0, 3.0, 2.0])
7248+
assert np.all(np.isnan(result_desc["var"].values[3:]))
7249+
7250+
def test_sortby_descending_nans_multi_key(self) -> None:
7251+
# Test sortby with multiple keys where one has NaN values
7252+
# Regression test for https://github.com/pydata/xarray/issues/7358
7253+
ds = Dataset(
7254+
{
7255+
"A": (("x", "y"), [[1, 2, 3], [4, 5, 6]]),
7256+
"B": (("x", "y"), [[7, 8, 9], [10, 11, 12]]),
7257+
},
7258+
coords={"x": ["b", "a"], "y": [np.nan, 1, 0]},
7259+
)
7260+
7261+
# Sort by multiple keys in descending order
7262+
result = ds.sortby(["x", "y"], ascending=False)
7263+
7264+
# x should be sorted descending: ["b", "a"]
7265+
assert_array_equal(result["x"].values, ["b", "a"])
7266+
7267+
# y should be sorted descending with NaN at end: [1, 0, nan]
7268+
assert_array_equal(result["y"].values[:2], [1, 0])
7269+
assert np.isnan(result["y"].values[2])
7270+
7271+
# Verify data is reordered correctly
7272+
# Original y=[nan, 1, 0] -> sorted y=[1, 0, nan] means columns reordered [1, 2, 0]
7273+
assert_array_equal(result["A"].values, [[2, 3, 1], [5, 6, 4]])
7274+
assert_array_equal(result["B"].values, [[8, 9, 7], [11, 12, 10]])
7275+
72357276
def test_attribute_access(self) -> None:
72367277
ds = create_test_data(seed=1)
72377278
for key in ["var1", "var2", "var3", "time", "dim1", "dim2", "dim3", "numbers"]:

0 commit comments

Comments
 (0)