Skip to content

Commit 78cc9cb

Browse files
Merge pull request #622 from Starlitnightly/fix/visium-hd-spacerangerv4-compat
Fix/visium hd spacerangerv4 compat
2 parents 26ffcdb + f4fdc92 commit 78cc9cb

File tree

3 files changed

+214
-20
lines changed

3 files changed

+214
-20
lines changed

omicverse/external/bin2cell/bin2cell.py

Lines changed: 209 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ def _paste_with_relabel(
272272

273273
return next_label_start
274274

275-
def stardist(
275+
def cellseg(
276276
image_path: str,
277277
labels_npz_path: str,
278278
stardist_model: str = "2D_versatile_he",
@@ -288,22 +288,49 @@ def stardist(
288288
build_sparse_directly: bool = True,
289289
show_progress: bool = True,
290290
progress_desc: Optional[str] = None,
291+
backend: str = "cellpose",
292+
bbox_threshold: float = 0.4,
291293
**kwargs,
292294
):
293295
"""
294-
Drop-in replacement for bin2cell.stardist() using Cellpose with tiled inference.
295-
Keeps I/O identical: writes a SciPy sparse CSR label matrix to labels_npz_path.
296+
Cell segmentation with tiled inference using Cellpose or CellSAM.
297+
Writes a SciPy sparse CSR label matrix to labels_npz_path.
296298
297-
Parameters kept for compatibility:
298-
- block_size, min_overlap, context: control tiling; mirrors original StarDist big-predict.
299-
300-
Additional parameters:
301-
- show_progress: whether to display a progress bar over tiles.
302-
- progress_desc: optional description for the progress bar.
299+
Parameters
300+
----------
301+
image_path : str
302+
Path to the input image.
303+
labels_npz_path : str
304+
Output path for the sparse label matrix (.npz).
305+
stardist_model : str
306+
Model hint. For cellpose: '2D_versatile_he' -> 'cyto',
307+
'2D_versatile_fluo' -> 'nuclei'. Ignored for CellSAM.
308+
block_size, min_overlap, context : int
309+
Tiling parameters.
310+
gpu : bool
311+
Whether to use GPU.
312+
backend : str
313+
Segmentation backend: 'cellpose' (default) or 'cellsam'.
314+
bbox_threshold : float
315+
CellSAM bounding box confidence threshold. Default 0.4.
316+
show_progress : bool
317+
Display progress bar.
318+
**kwargs
319+
Additional arguments forwarded to the backend.
303320
"""
304-
# map Stardist model to Cellpose model
305-
from cellpose import models
306321
from skimage.io import imread
322+
323+
if backend == "cellsam":
324+
return _cellseg_cellsam(
325+
image_path, labels_npz_path,
326+
block_size=block_size, min_overlap=min_overlap, context=context,
327+
gpu=gpu, iou_merge_threshold=iou_merge_threshold,
328+
show_progress=show_progress, progress_desc=progress_desc,
329+
bbox_threshold=bbox_threshold, **kwargs,
330+
)
331+
332+
# --- Cellpose backend ---
333+
from cellpose import models
307334
if stardist_model == "2D_versatile_he":
308335
model_type = "cyto"
309336
elif stardist_model == "2D_versatile_fluo":
@@ -357,7 +384,7 @@ def stardist(
357384
total_tiles = len(y_starts) * len(x_starts)
358385
try:
359386
from tqdm import tqdm # type: ignore
360-
pbar = tqdm(total=total_tiles, desc=progress_desc or "bin2cell.stardist", unit="tile")
387+
pbar = tqdm(total=total_tiles, desc=progress_desc or "bin2cell.cellseg", unit="tile")
361388
except Exception:
362389
use_fallback_progress = True
363390
print_every = max(1, total_tiles // 10) if total_tiles > 0 else 1
@@ -418,7 +445,7 @@ def stardist(
418445
flows = styles = diams = None
419446
except Exception as e:
420447
# safe fallback: yield empty mask for this tile
421-
print(f"[bin2cell.stardist] Cellpose eval failed on a tile: {e}. Using empty mask for this tile.")
448+
print(f"[bin2cell.cellseg] Cellpose eval failed on a tile: {e}. Using empty mask for this tile.")
422449
masks = np.zeros(tile.shape[:2], dtype=np.int32)
423450
flows, styles, diams = None, None, None
424451

@@ -471,7 +498,7 @@ def stardist(
471498
elif use_fallback_progress:
472499
tiles_done += 1
473500
if (tiles_done % print_every == 0) or (tiles_done == total_tiles):
474-
print(f"[bin2cell.stardist] Processed {tiles_done}/{total_tiles} tiles")
501+
print(f"[bin2cell.cellseg] Processed {tiles_done}/{total_tiles} tiles")
475502

476503
# close progress bar if used
477504
if pbar is not None:
@@ -489,7 +516,169 @@ def stardist(
489516
sparse.save_npz(labels_npz_path, labels_csr)
490517

491518

492-
def view_stardist_labels(image_path, labels_npz_path, crop, **kwargs):
519+
def _cellseg_cellsam(
520+
image_path: str,
521+
labels_npz_path: str,
522+
block_size: int = 1024,
523+
min_overlap: int = 128,
524+
context: int = 64,
525+
gpu: bool = False,
526+
iou_merge_threshold: float = 0.5,
527+
show_progress: bool = True,
528+
progress_desc: Optional[str] = None,
529+
bbox_threshold: float = 0.4,
530+
cellsam_model: str = "cellsam_general",
531+
**kwargs,
532+
):
533+
"""CellSAM backend for tiled cell segmentation.
534+
535+
Uses the CellSAM foundation model (vanvalenlab) with non-overlapping
536+
tiling to keep memory usage low. Each tile is segmented independently
537+
and labels are assigned sequentially.
538+
539+
Requires ``pip install git+https://github.com/vanvalenlab/cellSAM.git``
540+
and the ``DEEPCELL_ACCESS_TOKEN`` environment variable for model download.
541+
"""
542+
try:
543+
from cellSAM import get_model, segment_cellular_image
544+
except ImportError:
545+
raise ImportError(
546+
"CellSAM is not installed. Install with:\n"
547+
" pip install git+https://github.com/vanvalenlab/cellSAM.git\n"
548+
"Also set DEEPCELL_ACCESS_TOKEN environment variable."
549+
)
550+
551+
device = "cuda" if gpu else "cpu"
552+
553+
# Read image dimensions without loading full array
554+
import tifffile
555+
with tifffile.TiffFile(image_path) as tif:
556+
page = tif.pages[0]
557+
H, W = page.shape[:2]
558+
n_channels = page.shape[2] if page.ndim > 2 else 1
559+
560+
# Load CellSAM model
561+
print(f"Loading CellSAM model ({cellsam_model})...", flush=True)
562+
model = get_model(model=cellsam_model)
563+
print(f"CellSAM model loaded. Image: {H}x{W}", flush=True)
564+
565+
# Non-overlapping tiling. Read tiles lazily to keep memory bounded.
566+
import tempfile, os as _os
567+
568+
stride = block_size
569+
y_starts = list(range(0, H, stride))
570+
x_starts = list(range(0, W, stride))
571+
total_tiles = len(y_starts) * len(x_starts)
572+
573+
# Write partial results to temp files every N tiles to bound memory
574+
flush_every = max(1, min(50, total_tiles // 10))
575+
tmp_dir = tempfile.mkdtemp(prefix="cellsam_")
576+
part_files = []
577+
buf_rows, buf_cols, buf_vals = [], [], []
578+
next_label = 1
579+
buf_pixels = 0
580+
581+
pbar = None
582+
tiles_done = 0
583+
if show_progress:
584+
try:
585+
from tqdm import tqdm
586+
pbar = tqdm(total=total_tiles, desc=progress_desc or "bin2cell.cellsam", unit="tile")
587+
except Exception:
588+
pass
589+
590+
def _flush_buf():
591+
nonlocal buf_rows, buf_cols, buf_vals, buf_pixels
592+
if not buf_rows:
593+
return
594+
r = np.concatenate(buf_rows)
595+
c = np.concatenate(buf_cols)
596+
v = np.concatenate(buf_vals)
597+
part_path = _os.path.join(tmp_dir, f"part_{len(part_files)}.npz")
598+
np.savez_compressed(part_path, r=r, c=c, v=v)
599+
part_files.append(part_path)
600+
buf_rows.clear(); buf_cols.clear(); buf_vals.clear()
601+
buf_pixels = 0
602+
603+
# Lazy tile reader via zarr (avoids loading full image into memory)
604+
_tif = tifffile.TiffFile(image_path)
605+
_store = _tif.pages[0].aszarr()
606+
import zarr as _zarr
607+
_img_z = _zarr.open(_store, mode='r')
608+
609+
for ys in y_starts:
610+
for xs in x_starts:
611+
y1 = min(ys + block_size, H)
612+
x1 = min(xs + block_size, W)
613+
if _img_z.ndim == 2:
614+
tile_gray = np.array(_img_z[ys:y1, xs:x1])
615+
tile = np.stack([tile_gray] * 3, axis=-1)
616+
elif _img_z.shape[-1] >= 3:
617+
tile = np.array(_img_z[ys:y1, xs:x1, :3])
618+
else:
619+
tile = np.array(_img_z[ys:y1, xs:x1])
620+
621+
try:
622+
result = segment_cellular_image(
623+
tile, model, device=device,
624+
bbox_threshold=bbox_threshold,
625+
**{k: v for k, v in kwargs.items()
626+
if k in ('normalize', 'postprocess', 'remove_boundaries', 'fast')},
627+
)
628+
masks = result[0] if result[0] is not None else None
629+
except Exception:
630+
masks = None
631+
632+
if masks is not None and masks.max() > 0:
633+
# Vectorised relabel: shift all labels at once
634+
n_cells = masks.max()
635+
nz = masks > 0
636+
ry, rx = np.where(nz)
637+
shifted = masks[nz] + (next_label - 1)
638+
buf_rows.append((ry + ys).astype(np.int32))
639+
buf_cols.append((rx + xs).astype(np.int32))
640+
buf_vals.append(shifted.astype(np.int32))
641+
buf_pixels += len(ry)
642+
next_label += n_cells
643+
644+
tiles_done += 1
645+
if pbar is not None:
646+
pbar.update(1)
647+
elif show_progress and tiles_done % max(1, total_tiles // 20) == 0:
648+
print(f"[bin2cell.cellsam] {tiles_done}/{total_tiles} tiles", flush=True)
649+
650+
# Flush to disk when buffer gets large (~50M pixels)
651+
if buf_pixels > 50_000_000:
652+
_flush_buf()
653+
654+
if pbar is not None:
655+
pbar.close()
656+
657+
# Final flush
658+
_flush_buf()
659+
660+
# Merge all parts into a single CSR matrix
661+
all_r, all_c, all_v = [], [], []
662+
for pf in part_files:
663+
d = np.load(pf)
664+
all_r.append(d['r']); all_c.append(d['c']); all_v.append(d['v'])
665+
_os.remove(pf)
666+
_os.rmdir(tmp_dir)
667+
668+
if all_r:
669+
rows = np.concatenate(all_r)
670+
cols = np.concatenate(all_c)
671+
vals = np.concatenate(all_v)
672+
labels_csr = sparse.csr_matrix((vals, (rows, cols)), shape=(H, W), dtype=np.int32)
673+
else:
674+
labels_csr = sparse.csr_matrix((H, W), dtype=np.int32)
675+
676+
_tif.close()
677+
print(f"CellSAM segmentation complete: {next_label - 1} cells", flush=True)
678+
sparse.save_npz(labels_npz_path, labels_csr)
679+
680+
681+
def view_cellseg_labels(image_path, labels_npz_path, crop, **kwargs):
493682
'''
494683
Use StarDist's label rendering to view segmentation results in a crop
495684
of the input image.
@@ -2257,3 +2446,8 @@ def _plot_boundaries_matplotlib(cell_adata, color, segmentation_key, library_id,
22572446
ax.set_ylabel('Y coordinate')
22582447

22592448
return ax
2449+
2450+
# Backward-compatible aliases
2451+
stardist = cellseg
2452+
view_stardist_labels = view_cellseg_labels
2453+

omicverse/space/_tools.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ def visium_10x_hd_cellpose_he(
883883
Convert Visium 10x data to cell-level data.
884884
885885
"""
886-
from ..external.bin2cell import destripe, scaled_he_image, stardist, insert_labels
886+
from ..external.bin2cell import destripe, scaled_he_image, cellseg, insert_labels
887887

888888
spatial_key = f"spatial_cropped_{buffer}_buffer"
889889
if not os.path.exists(he_save_path):
@@ -897,7 +897,7 @@ def visium_10x_hd_cellpose_he(
897897
destripe(adata)
898898
scaled_he_image(adata, mpp=mpp, buffer=buffer, save_path=None,
899899
backend=backend)
900-
stardist(image_path=he_save_path ,
900+
cellseg(image_path=he_save_path ,
901901
labels_npz_path=he_save_path.replace(".tiff", ".npz"),
902902
stardist_model="2D_versatile_he",
903903
prob_thresh=prob_thresh,
@@ -1028,7 +1028,7 @@ def visium_10x_hd_cellpose_gex(
10281028
None
10291029
Writes ``labels_gex`` back into ``adata``.
10301030
"""
1031-
from ..external.bin2cell import grid_image, stardist, insert_labels,destripe
1031+
from ..external.bin2cell import grid_image, cellseg, insert_labels,destripe
10321032
#if gex_save_path's file exist, jump grid_image to stardist
10331033
if obs_key not in adata.obs.keys():
10341034
destripe(adata)
@@ -1037,7 +1037,7 @@ def visium_10x_hd_cellpose_gex(
10371037
mpp=mpp, sigma=sigma, save_path=gex_save_path)
10381038
else:
10391039
print(f"gex_save_path {gex_save_path} already exists, skipping grid_image")
1040-
stardist(image_path=gex_save_path,
1040+
cellseg(image_path=gex_save_path,
10411041
labels_npz_path=gex_save_path.replace(".tiff", ".npz"),
10421042
stardist_model="2D_versatile_fluo",
10431043
prob_thresh=prob_thresh,

0 commit comments

Comments
 (0)