-
Notifications
You must be signed in to change notification settings - Fork 35
Add neuron selection support to SequenceInterpolator and SpikeInterpolator #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
98ce7f4
9341ccb
a48b55f
69aed17
6e19ecd
ff9f0b3
b1eb129
4416b2f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -55,8 +55,10 @@ class Interpolator: | |
| Experiment : High-level interface that manages multiple interpolators. | ||
| """ | ||
|
|
||
| def __init__(self, root_folder: str) -> None: | ||
| def __init__(self, root_folder: str | Path) -> None: | ||
| self.root_folder = Path(root_folder) | ||
| self.n_signals: int = 0 | ||
| self.interpolation_mode: str | None = None | ||
| self.start_time = None | ||
| self.end_time = None | ||
| # Valid interval can be different to start time and end time. | ||
|
|
@@ -67,6 +69,53 @@ def load_meta(self): | |
| meta = yaml.safe_load(f) | ||
| return meta | ||
|
|
||
| def _resolve_indices(self, neuron_ids, neuron_indices): | ||
| if neuron_ids is None and neuron_indices is None: | ||
| return None | ||
|
|
||
| if neuron_ids is not None: | ||
| unit_ids = np.load(self.root_folder / "meta/unit_ids.npy") | ||
| ids_to_indexes = [] | ||
|
|
||
| for nid in neuron_ids: | ||
| match = np.where(unit_ids == nid)[0] | ||
| if len(match) == 0: | ||
| raise ValueError(f"Neuron id {nid} not found") | ||
| ids_to_indexes.append(int(match[0])) | ||
|
|
||
| if neuron_indices is None: | ||
| return ids_to_indexes | ||
|
|
||
| if set(ids_to_indexes) != set(neuron_indices): | ||
| raise ValueError( | ||
| "neuron_ids and neuron_indices refer to different neurons" | ||
| ) | ||
|
|
||
| warnings.warn( | ||
| "Both neuron_ids and neuron_indices provided; using neuron_indices", | ||
| stacklevel=2, | ||
| ) | ||
|
|
||
| return self._validate_indices(neuron_indices) | ||
|
|
||
| def _validate_indices(self, neuron_indices): | ||
| try: | ||
| indexes_seq = list(neuron_indices) | ||
| except TypeError as exc: | ||
| raise TypeError("neuron_indices must be iterable") from exc | ||
|
|
||
| if not all(isinstance(i, (int, np.integer)) for i in indexes_seq): | ||
| raise TypeError("neuron_indices must contain integers") | ||
|
|
||
| if indexes_seq: | ||
| if min(indexes_seq) < 0 or max(indexes_seq) >= self.n_signals: | ||
| raise ValueError("neuron_indices out of bounds") | ||
|
|
||
| if len(set(indexes_seq)) != len(indexes_seq): | ||
| raise ValueError("neuron_indices contain duplicates") | ||
|
|
||
| return indexes_seq | ||
|
|
||
| @abstractmethod | ||
| def interpolate( | ||
| self, times: np.ndarray, return_valid: bool = False | ||
|
|
@@ -165,6 +214,10 @@ class SequenceInterpolator(Interpolator): | |
| If True, subtracts mean during normalization. | ||
| normalize_std_threshold : float, optional | ||
| Minimum std threshold to prevent division by near-zero values. | ||
| neuron_ids : list, optional | ||
| Biological neuron IDs to include. Converted to indexes using meta/unit_ids.npy. | ||
| neuron_indices : list, optional | ||
| Column indexes of neurons to include. | ||
| **kwargs | ||
| Additional keyword arguments (ignored). | ||
|
|
||
|
|
@@ -190,10 +243,12 @@ class SequenceInterpolator(Interpolator): | |
|
|
||
| def __init__( | ||
| self, | ||
| root_folder: str, | ||
| root_folder: str | Path, | ||
| cache_data: bool = False, # already cached, put it here for consistency | ||
| keep_nans: bool = False, | ||
| interpolation_mode: str = "nearest_neighbor", | ||
| neuron_ids: list[int] | None = None, | ||
| neuron_indices: list[int] | None = None, | ||
| normalize: bool = False, | ||
| normalize_subtract_mean: bool = False, | ||
| normalize_std_threshold: float | None = None, # or 0.01 | ||
|
|
@@ -203,18 +258,21 @@ def __init__( | |
| meta = self.load_meta() | ||
| self.keep_nans = keep_nans | ||
| self.interpolation_mode = interpolation_mode | ||
| self.neuron_ids = neuron_ids | ||
| self.normalize = normalize | ||
| self.normalize_subtract_mean = normalize_subtract_mean | ||
| self.normalize_std_threshold = normalize_std_threshold | ||
| self.sampling_rate = meta["sampling_rate"] | ||
| self.time_delta = 1.0 / self.sampling_rate | ||
| self.start_time = meta["start_time"] | ||
| self.end_time = meta["end_time"] | ||
| self.is_mem_mapped = meta["is_mem_mapped"] if "is_mem_mapped" in meta else False | ||
| self.is_mem_mapped = meta.get("is_mem_mapped", False) | ||
| # Valid interval can be different to start time and end time. | ||
| self.valid_interval = TimeInterval(self.start_time, self.end_time) | ||
|
|
||
| self.n_signals = meta["n_signals"] | ||
| self.neuron_indices = self._resolve_indices(neuron_ids, neuron_indices) | ||
|
|
||
| # read .mem (memmap) or .npy file | ||
| if self.is_mem_mapped: | ||
| self._data = np.memmap( | ||
|
|
@@ -223,28 +281,41 @@ def __init__( | |
| mode="r", | ||
| shape=(meta["n_timestamps"], meta["n_signals"]), | ||
| ) | ||
|
|
||
| if cache_data: | ||
| self._data = np.array(self._data).astype( | ||
| np.float32 | ||
| ) # Convert memmap to ndarray | ||
| else: | ||
| self._data = np.load(self.root_folder / "data.npy") | ||
|
|
||
| # Apply indexing BEFORE caching | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @pollytur, have you thought about this? We might need to set caching to True if neurons are indexed. Or we find a workaround (I haven't investigated it yet).
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great catch, no I have not thought about it tbh
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Or maybe instead of directly indexing, we could iteratively fetch only the required columns (or in chunks) and optionally cache them.
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @Vrittigyl, could you please provide an implementation for what you proposed and show that it still caches the data if `cache_data == True` and doesn't cache the data otherwise, also when we select only certain neurons? Thanks!
||
| if self.neuron_indices is not None: | ||
| self._data = self._data[:, self.neuron_indices] | ||
| self.n_signals = len(self.neuron_indices) | ||
|
|
||
| # Cache only selected data | ||
| if self.is_mem_mapped and cache_data: | ||
| self._data = np.array(self._data, dtype=np.float32) | ||
|
|
||
|
Comment on lines
273
to
+295
|
||
| if self.normalize: | ||
| self.normalize_init() | ||
|
|
||
| def normalize_init(self): | ||
| self.mean = np.load(self.root_folder / "meta/means.npy") | ||
| self.std = np.load(self.root_folder / "meta/stds.npy") | ||
| mean = np.load(self.root_folder / "meta/means.npy") # shape: (n_total_signals,) | ||
| std = np.load(self.root_folder / "meta/stds.npy") | ||
|
|
||
| # Filter to selected neurons, before assertion | ||
| if self.neuron_indices is not None: | ||
| mean = mean[self.neuron_indices] | ||
| std = std[self.neuron_indices] | ||
|
|
||
| self.mean = mean.T | ||
| self.std = std.T | ||
|
|
||
| # Now n_signals and shape are guaranteed to match | ||
| assert ( | ||
| self.mean.shape[0] == self.n_signals | ||
| ), f"mean shape does not match: {self.mean.shape} vs {self._data.shape}" | ||
| ), f"mean shape does not match: {self.mean.shape[0]} vs {self.n_signals}" | ||
| assert ( | ||
| self.std.shape[0] == self.n_signals | ||
| ), f"std shape does not match: {self.std.shape} vs {self._data.shape}" | ||
| self.mean = self.mean.T | ||
| self.std = self.std.T | ||
| ), f"std shape does not match: {self.std.shape[0]} vs {self.n_signals}" | ||
|
|
||
| if self.normalize_std_threshold: | ||
| threshold = self.normalize_std_threshold * np.nanmean(self.std) | ||
| idx = self.std > threshold | ||
|
|
@@ -358,7 +429,7 @@ class PhaseShiftedSequenceInterpolator(SequenceInterpolator): | |
|
|
||
| def __init__( | ||
| self, | ||
| root_folder: str, | ||
| root_folder: str | Path, | ||
| cache_data: bool = False, # already cached, put it here for consistency | ||
| keep_nans: bool = False, | ||
| interpolation_mode: str = "nearest_neighbor", | ||
|
|
@@ -379,6 +450,10 @@ def __init__( | |
| ) | ||
|
|
||
| self._phase_shifts = np.load(self.root_folder / "meta/phase_shifts.npy") | ||
| # Forward the required indexes | ||
| if self.neuron_indices is not None: | ||
| self._phase_shifts = self._phase_shifts[self.neuron_indices] | ||
|
|
||
| self.valid_interval = TimeInterval( | ||
| self.start_time | ||
| + (np.max(self._phase_shifts) if len(self._phase_shifts) > 0 else 0), | ||
|
|
@@ -504,7 +579,7 @@ class ScreenInterpolator(Interpolator): | |
|
|
||
| def __init__( | ||
| self, | ||
| root_folder: str, | ||
| root_folder: str | Path, | ||
| cache_data: bool = False, # New parameter | ||
| rescale: bool = False, | ||
| rescale_size: tuple[int, int] | None = None, | ||
|
|
@@ -713,7 +788,7 @@ class TimeIntervalInterpolator(Interpolator): | |
| *i*-th valid time falls within any interval for the *j*-th label. | ||
| """ | ||
|
|
||
| def __init__(self, root_folder: str, cache_data: bool = False, **kwargs): | ||
| def __init__(self, root_folder: str | Path, cache_data: bool = False, **kwargs): | ||
| super().__init__(root_folder) | ||
| self.cache_data = cache_data | ||
|
|
||
|
|
@@ -1007,11 +1082,13 @@ class SpikeInterpolator(Interpolator): | |
|
|
||
| def __init__( | ||
| self, | ||
| root_folder: str, | ||
| root_folder: str | Path, | ||
| cache_data: bool = False, | ||
| interpolation_window: float = 0.3, | ||
| interpolation_align: str = "center", | ||
| smoothing_sigma: float = 0.0, | ||
| neuron_ids: list[int] | None = None, | ||
| neuron_indices: list[int] | None = None, | ||
| ): | ||
| super().__init__(root_folder) | ||
|
|
||
|
|
@@ -1061,6 +1138,33 @@ def __init__( | |
| else: | ||
| self.spikes = np.load(self.dat_path) | ||
|
|
||
| neuron_indices = self._resolve_indices(neuron_ids, neuron_indices) | ||
|
|
||
| # If specific neuron indexes are requested, rebuild the spike array so that it | ||
| # only contains spikes from the selected neurons. We also rebuild the indices | ||
| # array so that it matches the new compacted spike array. | ||
| if neuron_indices is not None: | ||
| if len(neuron_indices) == 0: | ||
| # No neurons selected: represent this as an empty spike train | ||
| self.spikes = np.empty((0,), dtype=self.spikes.dtype) | ||
| self.indices = np.array([0], dtype=np.int64) | ||
| self.n_signals = 0 | ||
| else: | ||
| new_indices = [0] | ||
| new_spikes = [] | ||
|
|
||
| for i in neuron_indices: | ||
| start = self.indices[i] | ||
| end = self.indices[i + 1] | ||
| neuron_spikes = self.spikes[start:end] | ||
|
|
||
| new_spikes.append(neuron_spikes) | ||
| new_indices.append(new_indices[-1] + len(neuron_spikes)) | ||
|
|
||
| self.spikes = np.concatenate(new_spikes) | ||
| self.indices = np.array(new_indices, dtype=np.int64) | ||
| self.n_signals = len(neuron_indices) | ||
|
|
||
| def interpolate( | ||
| self, times: np.ndarray, return_valid: bool = False | ||
| ) -> tuple[np.ndarray, np.ndarray] | np.ndarray: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, but I don't like putting those functions into Interpolator abstract class as these functions only make sense for neuron-related interpolators. Can we move these functions outside of any class and just pass all needed params into the function (as we can't use
`self` then anymore)? I would then load `unit_ids` before calling `_resolve_indices` and pass `unit_ids` directly instead of `self.root_folder`. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will surely fix it sir