Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions experanto/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
"""
Experanto: Interpolating recordings and stimuli in neuroscience experiments.

Main exports:
- Experiment: Load and query multi-modal experiment data
- ChunkDataset: PyTorch Dataset for training
- Interpolator: Base class for modality interpolators
"""

from .experiment import Experiment
from .datasets import ChunkDataset
from .interpolators import Interpolator

__all__ = [
"Experiment",
"ChunkDataset",
"Interpolator",
]

__version__ = "0.1.0"
49 changes: 48 additions & 1 deletion experanto/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,12 @@ class SequenceInterpolator(Interpolator):
If True, subtracts mean during normalization.
normalize_std_threshold : float, optional
Minimum std threshold to prevent division by near-zero values.
unit_ids : list, optional
List of unit IDs to select from unit_ids.npy. If provided, only these
neurons will be returned.
neuron_indexes : list, optional
List of neuron indexes (positions) to select. If provided, only these
neurons will be returned.
**kwargs
Additional keyword arguments (ignored).

Expand Down Expand Up @@ -193,6 +199,8 @@ def __init__(
root_folder: str,
cache_data: bool = False, # already cached, put it here for consistency
keep_nans: bool = False,
unit_ids=None,
neuron_indexes=None,
interpolation_mode: str = "nearest_neighbor",
normalize: bool = False,
normalize_subtract_mean: bool = False,
Expand All @@ -215,6 +223,17 @@ def __init__(
self.valid_interval = TimeInterval(self.start_time, self.end_time)

self.n_signals = meta["n_signals"]

# Filter by unit_ids or neuron_indexes if provided
if unit_ids is not None:
available_ids = np.load(self.root_folder / "meta/unit_ids.npy")
neuron_indexes = [np.where(available_ids == uid)[0][0] for uid in unit_ids]

if neuron_indexes is not None:
self._neuron_indexes = np.array(neuron_indexes)
else:
self._neuron_indexes = None

# read .mem (memmap) or .npy file
if self.is_mem_mapped:
self._data = np.memmap(
Expand Down Expand Up @@ -283,6 +302,8 @@ def interpolate(

if self.interpolation_mode == "nearest_neighbor":
data = self._data[idx_lower]
if self._neuron_indexes is not None:
data = data[:, self._neuron_indexes]

return (data, valid) if return_valid else data

Expand Down Expand Up @@ -315,6 +336,10 @@ def interpolate(
data_lower = self._data[idx_lower]
data_upper = self._data[idx_upper]

if self._neuron_indexes is not None:
data_lower = data_lower[:, self._neuron_indexes]
data_upper = data_upper[:, self._neuron_indexes]

interpolated = (
lower_signal_ratio * data_lower + upper_signal_ratio * data_upper
)
Expand Down Expand Up @@ -1003,6 +1028,12 @@ class SpikeInterpolator(Interpolator):
sigma should be ~3.
Set to 0.0 to disable smoothing.
Default is 0.0.
unit_ids : list, optional
List of unit IDs to select from unit_ids.npy. If provided, only these
neurons will be returned.
neuron_indexes : list, optional
List of neuron indexes (positions) to select. If provided, only these
neurons will be returned.
"""

def __init__(
Expand All @@ -1012,6 +1043,8 @@ def __init__(
interpolation_window: float = 0.3,
interpolation_align: str = "center",
smoothing_sigma: float = 0.0,
unit_ids=None,
neuron_indexes=None,
):
super().__init__(root_folder)

Expand All @@ -1034,6 +1067,17 @@ def __init__(
# Ensure indices are typed correctly for Numba
self.indices = np.array(meta["spike_indices"]).astype(np.int64)
self.n_signals = len(self.indices) - 1

# Filter by unit_ids or neuron_indexes if provided
if unit_ids is not None:
available_ids = np.load(self.root_folder / "meta/unit_ids.npy")
neuron_indexes = [np.where(available_ids == uid)[0][0] for uid in unit_ids]

if neuron_indexes is not None:
self._neuron_indexes = np.array(neuron_indexes)
else:
self._neuron_indexes = None

meta_n_signals = meta.get("n_signals")
if meta_n_signals is not None and meta_n_signals != self.n_signals:
raise ValueError(
Expand Down Expand Up @@ -1110,6 +1154,9 @@ def interpolate(
# sigma should be ~3.
counts = gaussian_filter1d(counts, sigma=self.smoothing_sigma, axis=0)

if self._neuron_indexes is not None:
counts = counts[:, self._neuron_indexes]

return (counts, valid) if return_valid else counts

def close(self):
Expand All @@ -1119,4 +1166,4 @@ def close(self):
_mmap_obj = getattr(self.spikes, "_mmap", None)
if _mmap_obj is not None:
_mmap_obj.close()
del self.spikes
del self.spikes
Loading