@@ -272,7 +272,7 @@ def _paste_with_relabel(
272272
273273 return next_label_start
274274
275- def stardist (
275+ def cellseg (
276276 image_path : str ,
277277 labels_npz_path : str ,
278278 stardist_model : str = "2D_versatile_he" ,
@@ -288,22 +288,49 @@ def stardist(
288288 build_sparse_directly : bool = True ,
289289 show_progress : bool = True ,
290290 progress_desc : Optional [str ] = None ,
291+ backend : str = "cellpose" ,
292+ bbox_threshold : float = 0.4 ,
291293 ** kwargs ,
292294):
293295 """
294- Drop-in replacement for bin2cell.stardist() using Cellpose with tiled inference .
295- Keeps I/O identical: writes a SciPy sparse CSR label matrix to labels_npz_path.
296+ Cell segmentation with tiled inference using Cellpose or CellSAM .
297+ Writes a SciPy sparse CSR label matrix to labels_npz_path.
296298
297- Parameters kept for compatibility:
298- - block_size, min_overlap, context: control tiling; mirrors original StarDist big-predict.
299-
300- Additional parameters:
301- - show_progress: whether to display a progress bar over tiles.
302- - progress_desc: optional description for the progress bar.
299+ Parameters
300+ ----------
301+ image_path : str
302+ Path to the input image.
303+ labels_npz_path : str
304+ Output path for the sparse label matrix (.npz).
305+ stardist_model : str
306+ Model hint. For cellpose: '2D_versatile_he' -> 'cyto',
307+ '2D_versatile_fluo' -> 'nuclei'. Ignored for CellSAM.
308+ block_size, min_overlap, context : int
309+ Tiling parameters.
310+ gpu : bool
311+ Whether to use GPU.
312+ backend : str
313+ Segmentation backend: 'cellpose' (default) or 'cellsam'.
314+ bbox_threshold : float
315+ CellSAM bounding box confidence threshold. Default 0.4.
316+ show_progress : bool
317+ Display progress bar.
318+ **kwargs
319+ Additional arguments forwarded to the backend.
303320 """
304- # map Stardist model to Cellpose model
305- from cellpose import models
306321 from skimage .io import imread
322+
323+ if backend == "cellsam" :
324+ return _cellseg_cellsam (
325+ image_path , labels_npz_path ,
326+ block_size = block_size , min_overlap = min_overlap , context = context ,
327+ gpu = gpu , iou_merge_threshold = iou_merge_threshold ,
328+ show_progress = show_progress , progress_desc = progress_desc ,
329+ bbox_threshold = bbox_threshold , ** kwargs ,
330+ )
331+
332+ # --- Cellpose backend ---
333+ from cellpose import models
307334 if stardist_model == "2D_versatile_he" :
308335 model_type = "cyto"
309336 elif stardist_model == "2D_versatile_fluo" :
@@ -357,7 +384,7 @@ def stardist(
357384 total_tiles = len (y_starts ) * len (x_starts )
358385 try :
359386 from tqdm import tqdm # type: ignore
360- pbar = tqdm (total = total_tiles , desc = progress_desc or "bin2cell.stardist " , unit = "tile" )
387+ pbar = tqdm (total = total_tiles , desc = progress_desc or "bin2cell.cellseg " , unit = "tile" )
361388 except Exception :
362389 use_fallback_progress = True
363390 print_every = max (1 , total_tiles // 10 ) if total_tiles > 0 else 1
@@ -418,7 +445,7 @@ def stardist(
418445 flows = styles = diams = None
419446 except Exception as e :
420447 # safe fallback: yield empty mask for this tile
421- print (f"[bin2cell.stardist ] Cellpose eval failed on a tile: { e } . Using empty mask for this tile." )
448+ print (f"[bin2cell.cellseg ] Cellpose eval failed on a tile: { e } . Using empty mask for this tile." )
422449 masks = np .zeros (tile .shape [:2 ], dtype = np .int32 )
423450 flows , styles , diams = None , None , None
424451
@@ -471,7 +498,7 @@ def stardist(
471498 elif use_fallback_progress :
472499 tiles_done += 1
473500 if (tiles_done % print_every == 0 ) or (tiles_done == total_tiles ):
474- print (f"[bin2cell.stardist ] Processed { tiles_done } /{ total_tiles } tiles" )
501+ print (f"[bin2cell.cellseg ] Processed { tiles_done } /{ total_tiles } tiles" )
475502
476503 # close progress bar if used
477504 if pbar is not None :
@@ -489,7 +516,169 @@ def stardist(
489516 sparse .save_npz (labels_npz_path , labels_csr )
490517
491518
492- def view_stardist_labels (image_path , labels_npz_path , crop , ** kwargs ):
519+ def _cellseg_cellsam (
520+ image_path : str ,
521+ labels_npz_path : str ,
522+ block_size : int = 1024 ,
523+ min_overlap : int = 128 ,
524+ context : int = 64 ,
525+ gpu : bool = False ,
526+ iou_merge_threshold : float = 0.5 ,
527+ show_progress : bool = True ,
528+ progress_desc : Optional [str ] = None ,
529+ bbox_threshold : float = 0.4 ,
530+ cellsam_model : str = "cellsam_general" ,
531+ ** kwargs ,
532+ ):
533+ """CellSAM backend for tiled cell segmentation.
534+
535+ Uses the CellSAM foundation model (vanvalenlab) with non-overlapping
536+ tiling to keep memory usage low. Each tile is segmented independently
537+ and labels are assigned sequentially.
538+
539+ Requires ``pip install git+https://github.com/vanvalenlab/cellSAM.git``
540+ and the ``DEEPCELL_ACCESS_TOKEN`` environment variable for model download.
541+ """
542+ try :
543+ from cellSAM import get_model , segment_cellular_image
544+ except ImportError :
545+ raise ImportError (
546+ "CellSAM is not installed. Install with:\n "
547+ " pip install git+https://github.com/vanvalenlab/cellSAM.git\n "
548+ "Also set DEEPCELL_ACCESS_TOKEN environment variable."
549+ )
550+
551+ device = "cuda" if gpu else "cpu"
552+
553+ # Read image dimensions without loading full array
554+ import tifffile
555+ with tifffile .TiffFile (image_path ) as tif :
556+ page = tif .pages [0 ]
557+ H , W = page .shape [:2 ]
558+ n_channels = page .shape [2 ] if page .ndim > 2 else 1
559+
560+ # Load CellSAM model
561+ print (f"Loading CellSAM model ({ cellsam_model } )..." , flush = True )
562+ model = get_model (model = cellsam_model )
563+ print (f"CellSAM model loaded. Image: { H } x{ W } " , flush = True )
564+
565+ # Non-overlapping tiling. Read tiles lazily to keep memory bounded.
566+ import tempfile , os as _os
567+
568+ stride = block_size
569+ y_starts = list (range (0 , H , stride ))
570+ x_starts = list (range (0 , W , stride ))
571+ total_tiles = len (y_starts ) * len (x_starts )
572+
573+ # Write partial results to temp files every N tiles to bound memory
574+ flush_every = max (1 , min (50 , total_tiles // 10 ))
575+ tmp_dir = tempfile .mkdtemp (prefix = "cellsam_" )
576+ part_files = []
577+ buf_rows , buf_cols , buf_vals = [], [], []
578+ next_label = 1
579+ buf_pixels = 0
580+
581+ pbar = None
582+ tiles_done = 0
583+ if show_progress :
584+ try :
585+ from tqdm import tqdm
586+ pbar = tqdm (total = total_tiles , desc = progress_desc or "bin2cell.cellsam" , unit = "tile" )
587+ except Exception :
588+ pass
589+
590+ def _flush_buf ():
591+ nonlocal buf_rows , buf_cols , buf_vals , buf_pixels
592+ if not buf_rows :
593+ return
594+ r = np .concatenate (buf_rows )
595+ c = np .concatenate (buf_cols )
596+ v = np .concatenate (buf_vals )
597+ part_path = _os .path .join (tmp_dir , f"part_{ len (part_files )} .npz" )
598+ np .savez_compressed (part_path , r = r , c = c , v = v )
599+ part_files .append (part_path )
600+ buf_rows .clear (); buf_cols .clear (); buf_vals .clear ()
601+ buf_pixels = 0
602+
603+ # Lazy tile reader via zarr (avoids loading full image into memory)
604+ _tif = tifffile .TiffFile (image_path )
605+ _store = _tif .pages [0 ].aszarr ()
606+ import zarr as _zarr
607+ _img_z = _zarr .open (_store , mode = 'r' )
608+
609+ for ys in y_starts :
610+ for xs in x_starts :
611+ y1 = min (ys + block_size , H )
612+ x1 = min (xs + block_size , W )
613+ if _img_z .ndim == 2 :
614+ tile_gray = np .array (_img_z [ys :y1 , xs :x1 ])
615+ tile = np .stack ([tile_gray ] * 3 , axis = - 1 )
616+ elif _img_z .shape [- 1 ] >= 3 :
617+ tile = np .array (_img_z [ys :y1 , xs :x1 , :3 ])
618+ else :
619+ tile = np .array (_img_z [ys :y1 , xs :x1 ])
620+
621+ try :
622+ result = segment_cellular_image (
623+ tile , model , device = device ,
624+ bbox_threshold = bbox_threshold ,
625+ ** {k : v for k , v in kwargs .items ()
626+ if k in ('normalize' , 'postprocess' , 'remove_boundaries' , 'fast' )},
627+ )
628+ masks = result [0 ] if result [0 ] is not None else None
629+ except Exception :
630+ masks = None
631+
632+ if masks is not None and masks .max () > 0 :
633+ # Vectorised relabel: shift all labels at once
634+ n_cells = masks .max ()
635+ nz = masks > 0
636+ ry , rx = np .where (nz )
637+ shifted = masks [nz ] + (next_label - 1 )
638+ buf_rows .append ((ry + ys ).astype (np .int32 ))
639+ buf_cols .append ((rx + xs ).astype (np .int32 ))
640+ buf_vals .append (shifted .astype (np .int32 ))
641+ buf_pixels += len (ry )
642+ next_label += n_cells
643+
644+ tiles_done += 1
645+ if pbar is not None :
646+ pbar .update (1 )
647+ elif show_progress and tiles_done % max (1 , total_tiles // 20 ) == 0 :
648+ print (f"[bin2cell.cellsam] { tiles_done } /{ total_tiles } tiles" , flush = True )
649+
650+ # Flush to disk when buffer gets large (~50M pixels)
651+ if buf_pixels > 50_000_000 :
652+ _flush_buf ()
653+
654+ if pbar is not None :
655+ pbar .close ()
656+
657+ # Final flush
658+ _flush_buf ()
659+
660+ # Merge all parts into a single CSR matrix
661+ all_r , all_c , all_v = [], [], []
662+ for pf in part_files :
663+ d = np .load (pf )
664+ all_r .append (d ['r' ]); all_c .append (d ['c' ]); all_v .append (d ['v' ])
665+ _os .remove (pf )
666+ _os .rmdir (tmp_dir )
667+
668+ if all_r :
669+ rows = np .concatenate (all_r )
670+ cols = np .concatenate (all_c )
671+ vals = np .concatenate (all_v )
672+ labels_csr = sparse .csr_matrix ((vals , (rows , cols )), shape = (H , W ), dtype = np .int32 )
673+ else :
674+ labels_csr = sparse .csr_matrix ((H , W ), dtype = np .int32 )
675+
676+ _tif .close ()
677+ print (f"CellSAM segmentation complete: { next_label - 1 } cells" , flush = True )
678+ sparse .save_npz (labels_npz_path , labels_csr )
679+
680+
681+ def view_cellseg_labels (image_path , labels_npz_path , crop , ** kwargs ):
493682 '''
494683 Use StarDist's label rendering to view segmentation results in a crop
495684 of the input image.
@@ -2257,3 +2446,8 @@ def _plot_boundaries_matplotlib(cell_adata, color, segmentation_key, library_id,
22572446 ax .set_ylabel ('Y coordinate' )
22582447
22592448 return ax
2449+
2450+ # Backward-compatible aliases
2451+ stardist = cellseg
2452+ view_stardist_labels = view_cellseg_labels
2453+
0 commit comments