|
5 | 5 | from spikeinterface.curation.model_based_curation import ModelBasedClassification |
6 | 6 | from spikeinterface.curation import model_based_label_units, load_model |
7 | 7 | from spikeinterface.curation.train_manual_curation import _get_computed_metrics |
8 | | -from spikeinterface.curation import unitrefine_label_units |
9 | 8 |
|
10 | 9 |
|
11 | 10 | import numpy as np |
@@ -171,83 +170,3 @@ def test_exception_raised_when_metric_params_not_equal(sorting_analyzer_for_cura |
171 | 170 | model, model_info = load_model(model_folder=trained_pipeline_path, trusted=["numpy.dtype"]) |
172 | 171 | model_based_classification = ModelBasedClassification(sorting_analyzer_for_curation, model) |
173 | 172 | model_based_classification._check_params_for_classification(enforce_metric_params=True, model_info=model_info) |
174 | | - |
175 | | - |
176 | | -def test_unitrefine_label_units_hf(sorting_analyzer_for_curation): |
177 | | - """Test the `unitrefine_label_units` function.""" |
178 | | - sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True) |
179 | | - sorting_analyzer_for_curation.compute("quality_metrics") |
180 | | - |
181 | | - # test passing both classifiers |
182 | | - labels = unitrefine_label_units( |
183 | | - sorting_analyzer_for_curation, |
184 | | - noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", |
185 | | - sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", |
186 | | - ) |
187 | | - |
188 | | - assert "label" in labels.columns |
189 | | - assert "probability" in labels.columns |
190 | | - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) |
191 | | - |
192 | | - # test only noise neural classifier |
193 | | - labels = unitrefine_label_units( |
194 | | - sorting_analyzer_for_curation, |
195 | | - noise_neural_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", |
196 | | - sua_mua_classifier=None, |
197 | | - ) |
198 | | - |
199 | | - assert "label" in labels.columns |
200 | | - assert "probability" in labels.columns |
201 | | - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) |
202 | | - |
203 | | - # test only sua mua classifier |
204 | | - labels = unitrefine_label_units( |
205 | | - sorting_analyzer_for_curation, |
206 | | - noise_neural_classifier=None, |
207 | | - sua_mua_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", |
208 | | - ) |
209 | | - |
210 | | - assert "label" in labels.columns |
211 | | - assert "probability" in labels.columns |
212 | | - assert labels.shape[0] == len(sorting_analyzer_for_curation.sorting.unit_ids) |
213 | | - |
214 | | - # test passing none |
215 | | - with pytest.raises(ValueError): |
216 | | - labels = unitrefine_label_units( |
217 | | - sorting_analyzer_for_curation, |
218 | | - noise_neural_classifier=None, |
219 | | - sua_mua_classifier=None, |
220 | | - ) |
221 | | - |
222 | | - # test warnings when unexpected labels are returned |
223 | | - with pytest.warns(UserWarning): |
224 | | - labels = unitrefine_label_units( |
225 | | - sorting_analyzer_for_curation, |
226 | | - noise_neural_classifier="SpikeInterface/UnitRefine_sua_mua_classifier_lightweight", |
227 | | - sua_mua_classifier=None, |
228 | | - ) |
229 | | - |
230 | | - with pytest.warns(UserWarning): |
231 | | - labels = unitrefine_label_units( |
232 | | - sorting_analyzer_for_curation, |
233 | | - noise_neural_classifier=None, |
234 | | - sua_mua_classifier="SpikeInterface/UnitRefine_noise_neural_classifier_lightweight", |
235 | | - ) |
236 | | - |
237 | | - |
238 | | -def test_unitrefine_label_units_with_local_models(sorting_analyzer_for_curation, trained_pipeline_path): |
239 | | - # test with trained local models |
240 | | - sorting_analyzer_for_curation.compute("template_metrics", include_multi_channel_metrics=True) |
241 | | - sorting_analyzer_for_curation.compute("quality_metrics") |
242 | | - |
243 | | - # test passing model folder |
244 | | - labels = unitrefine_label_units( |
245 | | - sorting_analyzer_for_curation, |
246 | | - noise_neural_classifier=trained_pipeline_path, |
247 | | - ) |
248 | | - |
249 | | - # test passing model folder |
250 | | - labels = unitrefine_label_units( |
251 | | - sorting_analyzer_for_curation, |
252 | | - noise_neural_classifier=trained_pipeline_path / "best_model.skops", |
253 | | - ) |
0 commit comments