diff --git a/experanto/interpolators.py b/experanto/interpolators.py index 4e6a077..13b2b7d 100644 --- a/experanto/interpolators.py +++ b/experanto/interpolators.py @@ -55,8 +55,10 @@ class Interpolator: Experiment : High-level interface that manages multiple interpolators. """ - def __init__(self, root_folder: str) -> None: + def __init__(self, root_folder: str | Path) -> None: self.root_folder = Path(root_folder) + self.n_signals: int = 0 + self.interpolation_mode: str | None = None self.start_time = None self.end_time = None # Valid interval can be different to start time and end time. @@ -67,6 +69,53 @@ def load_meta(self): meta = yaml.safe_load(f) return meta + def _resolve_indices(self, neuron_ids, neuron_indices): + if neuron_ids is None and neuron_indices is None: + return None + + if neuron_ids is not None: + unit_ids = np.load(self.root_folder / "meta/unit_ids.npy") + ids_to_indexes = [] + + for nid in neuron_ids: + match = np.where(unit_ids == nid)[0] + if len(match) == 0: + raise ValueError(f"Neuron id {nid} not found") + ids_to_indexes.append(int(match[0])) + + if neuron_indices is None: + return ids_to_indexes + + if set(ids_to_indexes) != set(neuron_indices): + raise ValueError( + "neuron_ids and neuron_indices refer to different neurons" + ) + + warnings.warn( + "Both neuron_ids and neuron_indices provided; using neuron_indices", + stacklevel=2, + ) + + return self._validate_indices(neuron_indices) + + def _validate_indices(self, neuron_indices): + try: + indexes_seq = list(neuron_indices) + except TypeError as exc: + raise TypeError("neuron_indices must be iterable") from exc + + if not all(isinstance(i, (int, np.integer)) for i in indexes_seq): + raise TypeError("neuron_indices must contain integers") + + if indexes_seq: + if min(indexes_seq) < 0 or max(indexes_seq) >= self.n_signals: + raise ValueError("neuron_indices out of bounds") + + if len(set(indexes_seq)) != len(indexes_seq): + raise ValueError("neuron_indices contain duplicates") + + return indexes_seq + @abstractmethod def interpolate( self, times: np.ndarray, return_valid: bool = False @@ -165,6 +214,10 @@ class SequenceInterpolator(Interpolator): If True, subtracts mean during normalization. normalize_std_threshold : float, optional Minimum std threshold to prevent division by near-zero values. + neuron_ids : list, optional + Biological neuron IDs to include. Converted to indexes using meta/unit_ids.npy. + neuron_indices : list, optional + Column indexes of neurons to include. **kwargs Additional keyword arguments (ignored). @@ -190,10 +243,12 @@ class SequenceInterpolator(Interpolator): def __init__( self, - root_folder: str, + root_folder: str | Path, cache_data: bool = False, # already cached, put it here for consistency keep_nans: bool = False, interpolation_mode: str = "nearest_neighbor", + neuron_ids: list[int] | None = None, + neuron_indices: list[int] | None = None, normalize: bool = False, normalize_subtract_mean: bool = False, normalize_std_threshold: float | None = None, # or 0.01 @@ -203,6 +258,7 @@ def __init__( meta = self.load_meta() self.keep_nans = keep_nans self.interpolation_mode = interpolation_mode + self.neuron_ids = neuron_ids self.normalize = normalize self.normalize_subtract_mean = normalize_subtract_mean self.normalize_std_threshold = normalize_std_threshold @@ -210,11 +266,13 @@ def __init__( self.time_delta = 1.0 / self.sampling_rate self.start_time = meta["start_time"] self.end_time = meta["end_time"] - self.is_mem_mapped = meta["is_mem_mapped"] if "is_mem_mapped" in meta else False + self.is_mem_mapped = meta.get("is_mem_mapped", False) # Valid interval can be different to start time and end time. self.valid_interval = TimeInterval(self.start_time, self.end_time) self.n_signals = meta["n_signals"] + self.neuron_indices = self._resolve_indices(neuron_ids, neuron_indices) + # read .mem (memmap) or .npy file if self.is_mem_mapped: self._data = np.memmap( @@ -223,28 +281,41 @@ def __init__( mode="r", shape=(meta["n_timestamps"], meta["n_signals"]), ) - - if cache_data: - self._data = np.array(self._data).astype( - np.float32 - ) # Convert memmap to ndarray else: self._data = np.load(self.root_folder / "data.npy") + # Apply indexing BEFORE caching + if self.neuron_indices is not None: + self._data = self._data[:, self.neuron_indices] + self.n_signals = len(self.neuron_indices) + + # Cache only selected data + if self.is_mem_mapped and cache_data: + self._data = np.array(self._data, dtype=np.float32) + if self.normalize: self.normalize_init() def normalize_init(self): - self.mean = np.load(self.root_folder / "meta/means.npy") - self.std = np.load(self.root_folder / "meta/stds.npy") + mean = np.load(self.root_folder / "meta/means.npy") # shape: (n_total_signals,) + std = np.load(self.root_folder / "meta/stds.npy") + + # Filter to selected neurons, before assertion + if self.neuron_indices is not None: + mean = mean[self.neuron_indices] + std = std[self.neuron_indices] + + self.mean = mean.T + self.std = std.T + + # Now n_signals and shape are guaranteed to match assert ( self.mean.shape[0] == self.n_signals - ), f"mean shape does not match: {self.mean.shape} vs {self._data.shape}" + ), f"mean shape does not match: {self.mean.shape[0]} vs {self.n_signals}" assert ( self.std.shape[0] == self.n_signals - ), f"std shape does not match: {self.std.shape} vs {self._data.shape}" - self.mean = self.mean.T - self.std = self.std.T + ), f"std shape does not match: {self.std.shape[0]} vs {self.n_signals}" + if self.normalize_std_threshold: threshold = self.normalize_std_threshold * np.nanmean(self.std) idx = self.std > threshold @@ -358,7 +429,7 @@ class PhaseShiftedSequenceInterpolator(SequenceInterpolator): def __init__( self, - root_folder: str, + root_folder: str | Path, cache_data: bool = False, # already cached, put it here for consistency keep_nans: bool = False, interpolation_mode: str = "nearest_neighbor", @@ -379,6 +450,10 @@ def __init__( ) self._phase_shifts = np.load(self.root_folder / "meta/phase_shifts.npy") + # Forward the required indexes + if self.neuron_indices is not None: + self._phase_shifts = self._phase_shifts[self.neuron_indices] + self.valid_interval = TimeInterval( self.start_time + (np.max(self._phase_shifts) if len(self._phase_shifts) > 0 else 0), @@ -504,7 +579,7 @@ class ScreenInterpolator(Interpolator): def __init__( self, - root_folder: str, + root_folder: str | Path, cache_data: bool = False, # New parameter rescale: bool = False, rescale_size: tuple[int, int] | None = None, @@ -713,7 +788,7 @@ class TimeIntervalInterpolator(Interpolator): *i*-th valid time falls within any interval for the *j*-th label. """ - def __init__(self, root_folder: str, cache_data: bool = False, **kwargs): + def __init__(self, root_folder: str | Path, cache_data: bool = False, **kwargs): super().__init__(root_folder) self.cache_data = cache_data @@ -1007,11 +1082,13 @@ class SpikeInterpolator(Interpolator): def __init__( self, - root_folder: str, + root_folder: str | Path, cache_data: bool = False, interpolation_window: float = 0.3, interpolation_align: str = "center", smoothing_sigma: float = 0.0, + neuron_ids: list[int] | None = None, + neuron_indices: list[int] | None = None, ): super().__init__(root_folder) @@ -1061,6 +1138,33 @@ def __init__( else: self.spikes = np.load(self.dat_path) + neuron_indices = self._resolve_indices(neuron_ids, neuron_indices) + + # If specific neuron indexes are requested, rebuild the spike array so that it + # only contains spikes from the selected neurons. We also rebuild the indices + # array so that it matches the new compacted spike array. + if neuron_indices is not None: + if len(neuron_indices) == 0: + # No neurons selected: represent this as an empty spike train + self.spikes = np.empty((0,), dtype=self.spikes.dtype) + self.indices = np.array([0], dtype=np.int64) + self.n_signals = 0 + else: + new_indices = [0] + new_spikes = [] + + for i in neuron_indices: + start = self.indices[i] + end = self.indices[i + 1] + neuron_spikes = self.spikes[start:end] + + new_spikes.append(neuron_spikes) + new_indices.append(new_indices[-1] + len(neuron_spikes)) + + self.spikes = np.concatenate(new_spikes) + self.indices = np.array(new_indices, dtype=np.int64) + self.n_signals = len(neuron_indices) + def interpolate( self, times: np.ndarray, return_valid: bool = False ) -> tuple[np.ndarray, np.ndarray] | np.ndarray: diff --git a/tests/test_sequence_interpolator.py b/tests/test_sequence_interpolator.py index 8ec6993..08b4df3 100644 --- a/tests/test_sequence_interpolator.py +++ b/tests/test_sequence_interpolator.py @@ -130,6 +130,7 @@ def test_nearest_neighbor_interpolation_with_phase_shifts( "shifts_per_signal": True, } ) as (timestamps, data, shift, seq_interp): + assert shift is not None assert isinstance( seq_interp, PhaseShiftedSequenceInterpolator ), "Interpolation object is not a PhaseShiftedSequenceInterpolator" @@ -282,6 +283,7 @@ def test_linear_interpolation_with_phase_shifts( }, interp_kwargs={"keep_nans": keep_nans}, ) as (timestamps, data, shift, seq_interp): + assert shift is not None assert isinstance( seq_interp, PhaseShiftedSequenceInterpolator ), "Not a PhaseShiftedSequenceInterpolator" @@ -389,6 +391,7 @@ def test_interpolation_with_phase_shifts_for_invalid_times( }, interp_kwargs={"keep_nans": keep_nans}, ) as (_, _, phase_shifts, seq_interp): + assert phase_shifts is not None assert isinstance( seq_interp, PhaseShiftedSequenceInterpolator ), "Interpolation object is not a PhaseShiftedSequenceInterpolator" @@ -547,5 +550,66 @@ def test_interpolation_mode_not_implemented(): seq_interp.interpolate(np.array([0.0, 1.0, 2.0]), return_valid=True) +def test_sequence_interpolator_indexes_selection(): + with sequence_data_and_interpolator( + data_kwargs={"n_signals": 10, "use_mem_mapped": False} + ) as (_, data, _, seq_interp): + + seq_interp = SequenceInterpolator( + seq_interp.root_folder, neuron_indices=[1, 3, 5] + ) + + assert seq_interp.n_signals == 3 + assert seq_interp._data.shape[1] == 3 + + +def test_sequence_interpolator_neuron_ids_selection(tmp_path): + with sequence_data_and_interpolator(data_kwargs={"n_signals": 4}) as ( + _, + _, + _, + seq_interp, + ): + + meta_folder = seq_interp.root_folder / "meta" + meta_folder.mkdir(exist_ok=True) + + unit_ids = np.array([10, 20, 30, 40]) + np.save(meta_folder / "unit_ids.npy", unit_ids) + + interp = SequenceInterpolator(seq_interp.root_folder, neuron_ids=[20, 40]) + + assert interp.n_signals == 2 + + +def test_sequence_interpolator_neuron_ids_indexes_mismatch(): + with sequence_data_and_interpolator(data_kwargs={"n_signals": 5}) as ( + _, + _, + _, + seq_interp, + ): + meta_folder = seq_interp.root_folder / "meta" + meta_folder.mkdir(exist_ok=True) + np.save(meta_folder / "unit_ids.npy", np.arange(5)) + + with pytest.raises(ValueError): + SequenceInterpolator( + seq_interp.root_folder, neuron_ids=[1], neuron_indices=[2] + ) + + +def test_phase_shift_interpolator_indexes_filtering(): + with sequence_data_and_interpolator( + data_kwargs={"n_signals": 6, "shifts_per_signal": True} + ) as (_, _, phase_shifts, seq_interp): + + interp = PhaseShiftedSequenceInterpolator( + seq_interp.root_folder, neuron_indices=[0, 2, 4] + ) + + assert len(interp._phase_shifts) == 3 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_spikes_interpolator.py b/tests/test_spikes_interpolator.py index aa1c951..36754e5 100644 --- a/tests/test_spikes_interpolator.py +++ b/tests/test_spikes_interpolator.py @@ -124,7 +124,7 @@ def test_spikes_cache_data(): data_kwargs={"duration": 5.0, "n_neurons": 2}, interp_kwargs={"cache_data": True}, ) as (gt_spikes, interp): - + assert isinstance(interp, SpikeInterpolator) assert isinstance(interp.spikes, np.ndarray) assert not isinstance(interp.spikes, np.memmap) @@ -169,6 +169,7 @@ def test_memmap_loading(): }, interp_kwargs={"cache_data": False}, ) as (gt_spikes, interp): + assert isinstance(interp, SpikeInterpolator) assert isinstance(interp.spikes, np.memmap), "Expected a memmap object" # Verify content matches ground truth @@ -186,6 +187,7 @@ def test_memmap_loading(): }, interp_kwargs={"cache_data": True}, ) as (gt_spikes, interp): + assert isinstance(interp, SpikeInterpolator) assert isinstance( interp.spikes, np.ndarray ), "Expected a numpy array (loaded into RAM)" @@ -193,3 +195,44 @@ def test_memmap_loading(): flat_gt = np.concatenate(gt_spikes) np.testing.assert_allclose(interp.spikes, flat_gt) + + +def test_spikes_neuron_indices_filtering(): + with spikes_data_and_interpolator(data_kwargs={"n_neurons": 5}) as ( + gt_spikes, + interp, + ): + + interp = SpikeInterpolator(interp.root_folder, neuron_indices=[1, 3]) + + assert interp.n_signals == 2 + + # verify spikes correspond to selected neurons + selected = [gt_spikes[1], gt_spikes[3]] + flat_selected = np.concatenate(selected) + + np.testing.assert_allclose(interp.spikes, flat_selected) + + +def test_spikes_neuron_ids_indices_mismatch(): + with spikes_data_and_interpolator(data_kwargs={"n_neurons": 5}) as (_, interp): + + meta_folder = interp.root_folder / "meta" + meta_folder.mkdir(parents=True, exist_ok=True) + np.save(meta_folder / "unit_ids.npy", np.arange(5)) + + with pytest.raises(ValueError): + SpikeInterpolator( + interp.root_folder, + neuron_ids=[0, 1], + neuron_indices=[2, 3], + ) + + +def test_spikes_empty_selection(): + with spikes_data_and_interpolator(data_kwargs={"n_neurons": 5}) as (_, interp): + + interp = SpikeInterpolator(interp.root_folder, neuron_indices=[]) + + assert interp.n_signals == 0 + assert interp.spikes.shape == (0,)