Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 122 additions & 18 deletions experanto/interpolators.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,10 @@ class Interpolator:
Experiment : High-level interface that manages multiple interpolators.
"""

def __init__(self, root_folder: str) -> None:
def __init__(self, root_folder: str | Path) -> None:
self.root_folder = Path(root_folder)
self.n_signals: int = 0
self.interpolation_mode: str | None = None
self.start_time = None
self.end_time = None
# Valid interval can be different to start time and end time.
Expand All @@ -67,6 +69,53 @@ def load_meta(self):
meta = yaml.safe_load(f)
return meta

def _resolve_indices(self, neuron_ids, neuron_indices):
    """Translate a neuron selection into validated column indexes.

    Parameters
    ----------
    neuron_ids : iterable of int, optional
        Biological unit IDs, mapped to column indexes via
        ``meta/unit_ids.npy`` under ``self.root_folder``.
    neuron_indices : iterable of int, optional
        Explicit column indexes into the signal array.

    Returns
    -------
    list[int] | None
        Validated column indexes, or ``None`` when no selection was
        requested.

    Raises
    ------
    ValueError
        If a neuron ID is not present in ``unit_ids``, or if both
        arguments are given but refer to different neurons.
    """
    if neuron_ids is None and neuron_indices is None:
        return None

    if neuron_ids is not None:
        unit_ids = np.load(self.root_folder / "meta/unit_ids.npy")
        ids_to_indexes = []

        for nid in neuron_ids:
            match = np.where(unit_ids == nid)[0]
            if len(match) == 0:
                raise ValueError(f"Neuron id {nid} not found")
            ids_to_indexes.append(int(match[0]))

        if neuron_indices is None:
            # Fix: run the ID-derived indexes through the same
            # bounds/duplicate validation applied to explicit
            # neuron_indices (previously they were returned unchecked,
            # so duplicated IDs slipped through silently).
            return self._validate_indices(ids_to_indexes)

        if set(ids_to_indexes) != set(neuron_indices):
            raise ValueError(
                "neuron_ids and neuron_indices refer to different neurons"
            )

        warnings.warn(
            "Both neuron_ids and neuron_indices provided; using neuron_indices",
            stacklevel=2,
        )

    return self._validate_indices(neuron_indices)

def _validate_indices(self, neuron_indices):
    """Validate a sequence of neuron column indexes.

    Ensures the argument is an iterable of unique integers within
    ``[0, self.n_signals)`` and returns it as a plain list.

    Raises
    ------
    TypeError
        If the argument is not iterable, or contains values that are
        not integers (booleans included).
    ValueError
        If any index is out of bounds or duplicated.
    """
    try:
        indexes_seq = list(neuron_indices)
    except TypeError as exc:
        raise TypeError("neuron_indices must be iterable") from exc

    # bool is a subclass of int, but a boolean here is almost certainly
    # a mistake: numpy would treat it as a mask element rather than a
    # column index, so reject it explicitly.
    if not all(
        isinstance(i, (int, np.integer)) and not isinstance(i, (bool, np.bool_))
        for i in indexes_seq
    ):
        raise TypeError("neuron_indices must contain integers")

    if indexes_seq:
        if min(indexes_seq) < 0 or max(indexes_seq) >= self.n_signals:
            raise ValueError("neuron_indices out of bounds")

    if len(set(indexes_seq)) != len(indexes_seq):
        raise ValueError("neuron_indices contain duplicates")

    return indexes_seq
Comment on lines +72 to +117
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, but I don't like putting those functions into Interpolator abstract class as these functions only make sense for neuron-related interpolators. Can we move these functions outside of any class and just pass all needed params into the function (as we can't use self then anymore)? I would then load unit_ids before calling _resolve_indices and pass unit_ids directly instead of self.root_folder.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will surely fix it sir


@abstractmethod
def interpolate(
self, times: np.ndarray, return_valid: bool = False
Expand Down Expand Up @@ -165,6 +214,10 @@ class SequenceInterpolator(Interpolator):
If True, subtracts mean during normalization.
normalize_std_threshold : float, optional
Minimum std threshold to prevent division by near-zero values.
neuron_ids : list, optional
Biological neuron IDs to include. Converted to indexes using meta/unit_ids.npy.
neuron_indices : list, optional
Column indexes of neurons to include.
**kwargs
Additional keyword arguments (ignored).

Expand All @@ -190,10 +243,12 @@ class SequenceInterpolator(Interpolator):

def __init__(
self,
root_folder: str,
root_folder: str | Path,
cache_data: bool = False, # already cached, put it here for consistency
keep_nans: bool = False,
interpolation_mode: str = "nearest_neighbor",
neuron_ids: list[int] | None = None,
neuron_indices: list[int] | None = None,
normalize: bool = False,
normalize_subtract_mean: bool = False,
normalize_std_threshold: float | None = None, # or 0.01
Expand All @@ -203,18 +258,21 @@ def __init__(
meta = self.load_meta()
self.keep_nans = keep_nans
self.interpolation_mode = interpolation_mode
self.neuron_ids = neuron_ids
self.normalize = normalize
self.normalize_subtract_mean = normalize_subtract_mean
self.normalize_std_threshold = normalize_std_threshold
self.sampling_rate = meta["sampling_rate"]
self.time_delta = 1.0 / self.sampling_rate
self.start_time = meta["start_time"]
self.end_time = meta["end_time"]
self.is_mem_mapped = meta["is_mem_mapped"] if "is_mem_mapped" in meta else False
self.is_mem_mapped = meta.get("is_mem_mapped", False)
# Valid interval can be different to start time and end time.
self.valid_interval = TimeInterval(self.start_time, self.end_time)

self.n_signals = meta["n_signals"]
self.neuron_indices = self._resolve_indices(neuron_ids, neuron_indices)

# read .mem (memmap) or .npy file
if self.is_mem_mapped:
self._data = np.memmap(
Expand All @@ -223,28 +281,41 @@ def __init__(
mode="r",
shape=(meta["n_timestamps"], meta["n_signals"]),
)

if cache_data:
self._data = np.array(self._data).astype(
np.float32
) # Convert memmap to ndarray
else:
self._data = np.load(self.root_folder / "data.npy")

# Apply indexing BEFORE caching
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @pollytur
as far as I know, if one applies indexes to memmap as lists or arrays, this loads a copy into RAM (see https://stackoverflow.com/questions/18614927/how-to-slice-memmap-efficiently or https://stackoverflow.com/questions/78426050/how-to-index-a-numpy-memmap-without-creating-an-in-memory-copy)
Only regular continuous slicing creates just a view, but I don't think neuron ids will always be continuous.

Have you thought about this? We might need to set caching to True if neurons are indexed. Or we find a workaround (I haven't investigated it yet).

Copy link
Copy Markdown
Contributor

@pollytur pollytur Mar 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great catch, no I have not thought about it tbh
we probably want to investigate the workaround (changing the order of neurons as we need and saving it as a temp memmap file is the first thought, but that's insanely memory inefficient since neuronal responses are also the heaviest part of the dataset from a memory perspective...)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or maybe instead of directly indexing, we could iteratively fetch only the required columns (or in chunks) and optionally cache them.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Vrittigyl could you please provide an implementation for what you proposed and show that it still caches the data if cache_data == True and doesn't cache the data otherwise, also when we select only certain neurons? Thanks!

if self.neuron_indices is not None:
self._data = self._data[:, self.neuron_indices]
self.n_signals = len(self.neuron_indices)

# Cache only selected data
if self.is_mem_mapped and cache_data:
self._data = np.array(self._data, dtype=np.float32)

Comment on lines 273 to +295
Copy link

Copilot AI Mar 10, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New neuron-selection behavior (filtering _data, updating n_signals, filtering means/stds, and filtering phase_shifts in PhaseShiftedSequenceInterpolator) is not covered by existing tests. Please add test cases exercising neuron_ids and indexes selection (including mismatch between the two, ordering behavior, and normalization stats shape after filtering).

Copilot uses AI. Check for mistakes.
if self.normalize:
self.normalize_init()

def normalize_init(self):
self.mean = np.load(self.root_folder / "meta/means.npy")
self.std = np.load(self.root_folder / "meta/stds.npy")
mean = np.load(self.root_folder / "meta/means.npy") # shape: (n_total_signals,)
std = np.load(self.root_folder / "meta/stds.npy")

# Filter to selected neurons, before assertion
if self.neuron_indices is not None:
mean = mean[self.neuron_indices]
std = std[self.neuron_indices]

self.mean = mean.T
self.std = std.T

# Now n_signals and shape are guaranteed to match
assert (
self.mean.shape[0] == self.n_signals
), f"mean shape does not match: {self.mean.shape} vs {self._data.shape}"
), f"mean shape does not match: {self.mean.shape[0]} vs {self.n_signals}"
assert (
self.std.shape[0] == self.n_signals
), f"std shape does not match: {self.std.shape} vs {self._data.shape}"
self.mean = self.mean.T
self.std = self.std.T
), f"std shape does not match: {self.std.shape[0]} vs {self.n_signals}"

if self.normalize_std_threshold:
threshold = self.normalize_std_threshold * np.nanmean(self.std)
idx = self.std > threshold
Expand Down Expand Up @@ -358,7 +429,7 @@ class PhaseShiftedSequenceInterpolator(SequenceInterpolator):

def __init__(
self,
root_folder: str,
root_folder: str | Path,
cache_data: bool = False, # already cached, put it here for consistency
keep_nans: bool = False,
interpolation_mode: str = "nearest_neighbor",
Expand All @@ -379,6 +450,10 @@ def __init__(
)

self._phase_shifts = np.load(self.root_folder / "meta/phase_shifts.npy")
# Forward the required indexes
if self.neuron_indices is not None:
self._phase_shifts = self._phase_shifts[self.neuron_indices]

self.valid_interval = TimeInterval(
self.start_time
+ (np.max(self._phase_shifts) if len(self._phase_shifts) > 0 else 0),
Expand Down Expand Up @@ -504,7 +579,7 @@ class ScreenInterpolator(Interpolator):

def __init__(
self,
root_folder: str,
root_folder: str | Path,
cache_data: bool = False, # New parameter
rescale: bool = False,
rescale_size: tuple[int, int] | None = None,
Expand Down Expand Up @@ -713,7 +788,7 @@ class TimeIntervalInterpolator(Interpolator):
*i*-th valid time falls within any interval for the *j*-th label.
"""

def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
def __init__(self, root_folder: str | Path, cache_data: bool = False, **kwargs):
super().__init__(root_folder)
self.cache_data = cache_data

Expand Down Expand Up @@ -1007,11 +1082,13 @@ class SpikeInterpolator(Interpolator):

def __init__(
self,
root_folder: str,
root_folder: str | Path,
cache_data: bool = False,
interpolation_window: float = 0.3,
interpolation_align: str = "center",
smoothing_sigma: float = 0.0,
neuron_ids: list[int] | None = None,
neuron_indices: list[int] | None = None,
):
super().__init__(root_folder)

Expand Down Expand Up @@ -1061,6 +1138,33 @@ def __init__(
else:
self.spikes = np.load(self.dat_path)

neuron_indices = self._resolve_indices(neuron_ids, neuron_indices)

# If specific neuron indexes are requested, rebuild the spike array so that it
# only contains spikes from the selected neurons. We also rebuild the indices
# array so that it matches the new compacted spike array.
if neuron_indices is not None:
if len(neuron_indices) == 0:
# No neurons selected: represent this as an empty spike train
self.spikes = np.empty((0,), dtype=self.spikes.dtype)
self.indices = np.array([0], dtype=np.int64)
self.n_signals = 0
else:
new_indices = [0]
new_spikes = []

for i in neuron_indices:
start = self.indices[i]
end = self.indices[i + 1]
neuron_spikes = self.spikes[start:end]

new_spikes.append(neuron_spikes)
new_indices.append(new_indices[-1] + len(neuron_spikes))

self.spikes = np.concatenate(new_spikes)
self.indices = np.array(new_indices, dtype=np.int64)
self.n_signals = len(neuron_indices)

def interpolate(
self, times: np.ndarray, return_valid: bool = False
) -> tuple[np.ndarray, np.ndarray] | np.ndarray:
Expand Down
64 changes: 64 additions & 0 deletions tests/test_sequence_interpolator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_nearest_neighbor_interpolation_with_phase_shifts(
"shifts_per_signal": True,
}
) as (timestamps, data, shift, seq_interp):
assert shift is not None
assert isinstance(
seq_interp, PhaseShiftedSequenceInterpolator
), "Interpolation object is not a PhaseShiftedSequenceInterpolator"
Expand Down Expand Up @@ -282,6 +283,7 @@ def test_linear_interpolation_with_phase_shifts(
},
interp_kwargs={"keep_nans": keep_nans},
) as (timestamps, data, shift, seq_interp):
assert shift is not None
assert isinstance(
seq_interp, PhaseShiftedSequenceInterpolator
), "Not a PhaseShiftedSequenceInterpolator"
Expand Down Expand Up @@ -389,6 +391,7 @@ def test_interpolation_with_phase_shifts_for_invalid_times(
},
interp_kwargs={"keep_nans": keep_nans},
) as (_, _, phase_shifts, seq_interp):
assert phase_shifts is not None
assert isinstance(
seq_interp, PhaseShiftedSequenceInterpolator
), "Interpolation object is not a PhaseShiftedSequenceInterpolator"
Expand Down Expand Up @@ -547,5 +550,66 @@ def test_interpolation_mode_not_implemented():
seq_interp.interpolate(np.array([0.0, 1.0, 2.0]), return_valid=True)


def test_sequence_interpolator_indexes_selection():
    """Selecting columns via neuron_indices shrinks the data to those columns."""
    with sequence_data_and_interpolator(
        data_kwargs={"n_signals": 10, "use_mem_mapped": False}
    ) as (_, _, _, seq_interp):
        # Re-instantiate over the same folder with an explicit column
        # selection; use a distinct name so the assertions unambiguously
        # target the filtered interpolator (the original rebinding of
        # `seq_interp` shadowed the fixture object).
        selected = SequenceInterpolator(
            seq_interp.root_folder, neuron_indices=[1, 3, 5]
        )

        assert selected.n_signals == 3
        assert selected._data.shape[1] == 3


def test_sequence_interpolator_neuron_ids_selection():
    """neuron_ids are mapped to column indexes via meta/unit_ids.npy."""
    # Fix: dropped the unused `tmp_path` fixture parameter — the test
    # writes into the fixture-provided root_folder, not tmp_path.
    with sequence_data_and_interpolator(data_kwargs={"n_signals": 4}) as (
        _,
        _,
        _,
        seq_interp,
    ):
        meta_folder = seq_interp.root_folder / "meta"
        meta_folder.mkdir(exist_ok=True)

        # IDs 20 and 40 correspond to columns 1 and 3.
        unit_ids = np.array([10, 20, 30, 40])
        np.save(meta_folder / "unit_ids.npy", unit_ids)

        interp = SequenceInterpolator(seq_interp.root_folder, neuron_ids=[20, 40])

        assert interp.n_signals == 2


def test_sequence_interpolator_neuron_ids_indexes_mismatch():
    """Conflicting neuron_ids / neuron_indices must raise ValueError."""
    with sequence_data_and_interpolator(data_kwargs={"n_signals": 5}) as bundle:
        seq_interp = bundle[3]

        meta_dir = seq_interp.root_folder / "meta"
        meta_dir.mkdir(exist_ok=True)
        np.save(meta_dir / "unit_ids.npy", np.arange(5))

        # ID 1 maps to column 1, which contradicts the explicit index 2.
        with pytest.raises(ValueError):
            SequenceInterpolator(
                seq_interp.root_folder, neuron_ids=[1], neuron_indices=[2]
            )


def test_phase_shift_interpolator_indexes_filtering():
    """Phase shifts are filtered down to the selected neuron columns."""
    with sequence_data_and_interpolator(
        data_kwargs={"n_signals": 6, "shifts_per_signal": True}
    ) as (_, _, _, seq_interp):
        picked = [0, 2, 4]

        interp = PhaseShiftedSequenceInterpolator(
            seq_interp.root_folder, neuron_indices=picked
        )

        assert len(interp._phase_shifts) == len(picked)


# Allow running this test module directly: `python test_sequence_interpolator.py`.
if __name__ == "__main__":
    pytest.main([__file__])
Loading
Loading