diff --git a/gsplat/compression/sort.py b/gsplat/compression/sort.py index 0580eea80..6f516b344 100644 --- a/gsplat/compression/sort.py +++ b/gsplat/compression/sort.py @@ -32,7 +32,7 @@ def sort_splats(splats: Dict[str, Tensor], verbose: bool = True) -> Dict[str, Te sort_keys = ["means", "quats", "scales", "opacities"] if "sh0" in splats: sort_keys.append("sh0") - + params_to_sort = torch.cat([splats[k].reshape(n_gs, -1) for k in sort_keys], dim=-1) shuffled_indices = torch.randperm( params_to_sort.shape[0], device=params_to_sort.device