Skip to content

Commit c797a49

Browse files
committed
Add qualitymetrics_based curation
1 parent a2550d2 commit c797a49

File tree

10 files changed

+298
-103
lines changed

10 files changed

+298
-103
lines changed
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# ---
2+
# jupyter:
3+
# jupytext:
4+
# cell_metadata_filter: -all
5+
# formats: ipynb,py:percent
6+
# text_representation:
7+
# extension: .py
8+
# format_name: percent
9+
# format_version: '1.3'
10+
# jupytext_version: 1.19.1
11+
# kernelspec:
12+
# display_name: Python 3 (ipykernel)
13+
# language: python
14+
# name: python3
15+
# ---
16+
17+
# %% [markdown]
18+
# # Automatic labeling units after spike sorting
19+
#
20+
# This example shows how to automatically label units after spike sorting, using three different approaches:
21+
#
22+
# 1. Simple filter based on quality metrics
23+
# 2. Bombcell: heuristic approach to label units based on quality and template metrics [Fabre]_
24+
# 3. UnitRefine: pre-trained classifiers to label units as noise or SUA/MUA [Jain]_
25+
26+
# %%
27+
import spikeinterface as si
28+
import spikeinterface.curation as sc
29+
import spikeinterface.widgets as sw
30+
31+
# %%
# NOTE(review): hard-coded local path — replace with the location of your own analyzer
analyzer_path = "/ssd980/working/analyzer_np2_shank1.zarr"

# %%
analyzer = si.load(analyzer_path)

# %%
# Compute quality metrics on the loaded analyzer.
# Fix: the original referenced `analyzer_zarr`, which is undefined in this script —
# the object loaded above is named `analyzer`.
qm = analyzer.compute("quality_metrics", delete_existing_metrics=True)
qm.get_data()

# %%

src/spikeinterface/curation/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
# automated curation
2323
from .model_based_curation import model_based_label_units, load_model, auto_label_units
2424
from .train_manual_curation import train_model, get_default_classifier_search_spaces
25+
from .qualitymetrics_curation import qualitymetrics_label_units
2526
from .unitrefine_curation import unitrefine_label_units
2627
from .bombcell_curation import (
2728
bombcell_get_default_thresholds,

src/spikeinterface/curation/bombcell_curation.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
import json
1515
import numpy as np
1616

17+
from .curation_tools import _is_threshold_disabled
18+
1719
NOISE_METRICS = [
1820
"num_positive_peaks",
1921
"num_negative_peaks",
@@ -75,15 +77,6 @@ def bombcell_get_default_thresholds() -> dict:
7577
}
7678

7779

78-
def _is_threshold_disabled(value):
79-
"""Check if a threshold value is disabled (None or np.nan)."""
80-
if value is None:
81-
return True
82-
if isinstance(value, float) and np.isnan(value):
83-
return True
84-
return False
85-
86-
8780
def bombcell_label_units(
8881
sorting_analyzer=None,
8982
thresholds: dict | str | Path | None = None,

src/spikeinterface/curation/curation_tools.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@
1414
_methods_numpy = ("keep_first", "random", "keep_last")
1515

1616

17+
def _is_threshold_disabled(value):
18+
"""Check if a threshold value is disabled (None or np.nan)."""
19+
if value is None:
20+
return True
21+
if isinstance(value, float) and np.isnan(value):
22+
return True
23+
return False
24+
25+
1726
def _find_duplicated_spikes_numpy(
1827
spike_train: np.ndarray,
1928
censored_period: int,
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import json
2+
from pathlib import Path
3+
4+
import numpy as np
5+
6+
from spikeinterface.core.analyzer_extension_core import SortingAnalyzer
7+
8+
from .curation_tools import _is_threshold_disabled
9+
10+
11+
def qualitymetrics_label_units(
    analyzer: SortingAnalyzer,
    thresholds: dict | str | Path,
):
    """Label units as 'good' or 'noise' based on quality-metric thresholds.

    Parameters
    ----------
    analyzer : SortingAnalyzer
        The SortingAnalyzer object containing the computed quality metrics.
    thresholds : dict | str | Path
        A dictionary (or path to a JSON file) whose keys are metric names and
        whose values are threshold specifications. Each key must correspond to
        a column of the analyzer's quality-metrics DataFrame; each value is a
        dict with optional "min" and/or "max" keys. A bound set to None or NaN
        is ignored. Units passing every enabled bound are labeled 'good'; all
        others are labeled 'noise'.

    Returns
    -------
    labels : pd.DataFrame
        A DataFrame with unit IDs as index and a column 'label' containing the
        assigned labels ('noise' or 'good').

    Raises
    ------
    ValueError
        If quality metrics are not computed on the analyzer, if `thresholds`
        is neither a dict nor a path, or if a threshold names a metric absent
        from the quality-metrics DataFrame.
    """
    import pandas as pd

    # Explicit exception instead of `assert`, so the check survives `python -O`
    if not analyzer.has_extension("quality_metrics"):
        raise ValueError(
            "The provided analyzer does not have quality metrics computed. "
            "Please compute quality metrics before labeling units."
        )
    qm = analyzer.get_extension("quality_metrics").get_data()

    # Load thresholds from file if a path is provided
    if isinstance(thresholds, (str, Path)):
        with open(thresholds, "r") as f:
            thresholds_dict = json.load(f)
    elif isinstance(thresholds, dict):
        thresholds_dict = thresholds
    else:
        raise ValueError("Thresholds must be a dictionary or a path to a JSON file containing the thresholds.")

    # Fail early on metrics that are not available in the quality metrics DataFrame
    missing_metrics = [metric for metric in thresholds_dict if metric not in qm.columns]
    if missing_metrics:
        raise ValueError(
            f"Metric(s) {missing_metrics} specified in thresholds are not present in the quality metrics DataFrame. "
            f"Available metrics are: {qm.columns.tolist()}"
        )

    # Every unit starts as 'noise'; those passing all enabled bounds become 'good'
    labels = pd.DataFrame(index=qm.index, dtype=str)
    labels["label"] = "noise"

    good_mask = np.ones(len(qm), dtype=bool)
    for metric_name, threshold in thresholds_dict.items():
        min_value = threshold.get("min", None)
        max_value = threshold.get("max", None)
        # None/NaN disables the corresponding bound
        if not _is_threshold_disabled(min_value):
            good_mask &= qm[metric_name] >= min_value
        if not _is_threshold_disabled(max_value):
            good_mask &= qm[metric_name] <= max_value

    labels.loc[good_mask, "label"] = "good"

    return labels

src/spikeinterface/curation/tests/test_model_based_curation.py

Lines changed: 0 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from spikeinterface.curation.model_based_curation import ModelBasedClassification
66
from spikeinterface.curation import model_based_label_units, load_model
77
from spikeinterface.curation.train_manual_curation import _get_computed_metrics
8-
from spikeinterface.curation import unitrefine_label_units
98

109

1110
import numpy as np
@@ -171,83 +170,3 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura
171170
model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"])
172171
model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model)
173172
model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info)
174-
175-
176-
def test_unitrefine_label_units_hf(sorting_analyzer_for_curation):
177-
"""Test the `unitrefine_label_units` function."""
178-
sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
179-
sorting_analyzer_for_curation.compute("quality_metrics")
180-
181-
# test passing both classifiers
182-
labels = unitrefine_label_units(
183-
sorting_analyzer_for_curation,
184-
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
185-
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
186-
)
187-
188-
assert "label" in labels.columns
189-
assert "probability" in labels.columns
190-
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
191-
192-
# test only noise neural classifier
193-
labels = unitrefine_label_units(
194-
sorting_analyzer_for_curation,
195-
noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
196-
sua_mua_classifier=None,
197-
)
198-
199-
assert "label" in labels.columns
200-
assert "probability" in labels.columns
201-
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
202-
203-
# test only sua mua classifier
204-
labels = unitrefine_label_units(
205-
sorting_analyzer_for_curation,
206-
noise_neural_classifier=None,
207-
sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
208-
)
209-
210-
assert "label" in labels.columns
211-
assert "probability" in labels.columns
212-
assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)
213-
214-
# test passing none
215-
with pytest.raises(ValueError):
216-
labels = unitrefine_label_units(
217-
sorting_analyzer_for_curation,
218-
noise_neural_classifier=None,
219-
sua_mua_classifier=None,
220-
)
221-
222-
# test warnings when unexpected labels are returned
223-
with pytest.warns(UserWarning):
224-
labels = unitrefine_label_units(
225-
sorting_analyzer_for_curation,
226-
noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight",
227-
sua_mua_classifier=None,
228-
)
229-
230-
with pytest.warns(UserWarning):
231-
labels = unitrefine_label_units(
232-
sorting_analyzer_for_curation,
233-
noise_neural_classifier=None,
234-
sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight",
235-
)
236-
237-
238-
def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path):
239-
# test with trained local models
240-
sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True)
241-
sorting_analyzer_for_curation.compute("quality_metrics")
242-
243-
# test passing model folder
244-
labels = unitrefine_label_units(
245-
sorting_analyzer_for_curation,
246-
noise_neural_classifier=trained_pipeline_path,
247-
)
248-
249-
# test passing model folder
250-
labels = unitrefine_label_units(
251-
sorting_analyzer_for_curation,
252-
noise_neural_classifier=trained_pipeline_path / "best_model.skops",
253-
)
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
import pytest
2+
import json
3+
4+
from spikeinterface.curation.tests.common import sorting_analyzer_for_curation
5+
from spikeinterface.curation import qualitymetrics_label_units
6+
7+
8+
def test_qualitymetrics_label_units(sorting_analyzer_for_curation):
    """Check `qualitymetrics_label_units` against manually applied thresholds."""
    sorting_analyzer_for_curation.compute("quality_metrics")

    thresholds = {
        "snr": {"min": 5.0},
        "firing_rate": {"min": 0.1, "max": 20.0},
    }

    labels = qualitymetrics_label_units(sorting_analyzer_for_curation, thresholds)

    assert "label" in labels.columns
    assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)

    # Re-derive the expected label for each unit straight from the metrics:
    # 'good' iff snr and firing_rate both fall within their thresholds.
    qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data()
    for unit_id in sorting_analyzer_for_curation.sorting.unit_ids:
        snr_ok = qm.loc[unit_id, "snr"] >= thresholds["snr"]["min"]
        fr = qm.loc[unit_id, "firing_rate"]
        fr_ok = thresholds["firing_rate"]["min"] <= fr <= thresholds["firing_rate"]["max"]
        expected = "good" if (snr_ok and fr_ok) else "noise"
        assert labels.loc[unit_id, "label"] == expected
37+
38+
39+
def test_qualitymetrics_label_units_with_file(sorting_analyzer_for_curation, tmp_path):
    """Check `qualitymetrics_label_units` when thresholds come from a JSON file."""
    sorting_analyzer_for_curation.compute("quality_metrics")

    thresholds = {
        "snr": {"min": 5.0},
        "firing_rate": {"min": 0.1},
    }

    # Round-trip the thresholds through a JSON file on disk
    thresholds_file = tmp_path / "thresholds.json"
    thresholds_file.write_text(json.dumps(thresholds))

    labels = qualitymetrics_label_units(sorting_analyzer_for_curation, thresholds_file)

    assert "label" in labels.columns
    assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids)

    # Re-derive the expected label for each unit straight from the metrics:
    # 'good' iff snr and firing_rate both pass their minimum thresholds.
    qm = sorting_analyzer_for_curation.get_extension("quality_metrics").get_data()
    for unit_id in sorting_analyzer_for_curation.sorting.unit_ids:
        snr_ok = qm.loc[unit_id, "snr"] >= thresholds["snr"]["min"]
        fr_ok = qm.loc[unit_id, "firing_rate"] >= thresholds["firing_rate"]["min"]
        expected = "good" if (snr_ok and fr_ok) else "noise"
        assert labels.loc[unit_id, "label"] == expected

0 commit comments

Comments
 (0)