66from itertools import chain
77
88from spikeinterface .core import BaseSorting , SortingAnalyzer , apply_merges_to_sorting , apply_splits_to_sorting
9- from spikeinterface .curation .curation_model import CurationModel
9+ from spikeinterface .curation .curation_model import CurationModel , SequentialCuration
1010
1111
1212def validate_curation_dict (curation_dict : dict ):
@@ -138,7 +138,7 @@ def apply_curation_labels(
138138
139139def apply_curation (
140140 sorting_or_analyzer : BaseSorting | SortingAnalyzer ,
141- curation_dict_or_model : dict | CurationModel ,
141+ curation_dict_or_model : dict | list | CurationModel | SequentialCuration ,
142142 censor_ms : float | None = None ,
143143 new_id_strategy : str = "append" ,
144144 merging_mode : str = "soft" ,
@@ -164,7 +164,7 @@ def apply_curation(
164164 ----------
165165 sorting_or_analyzer : Sorting | SortingAnalyzer
166166 The Sorting or SortingAnalyzer object to apply merges.
167- curation_dict : dict or CurationModel
167+ curation_dict : dict | CurationModel | SequentialCuration
168168 The curation dict or model.
169169 censor_ms : float | None, default: None
170170 When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of
@@ -199,14 +199,32 @@ def apply_curation(
199199 sorting_or_analyzer , (BaseSorting , SortingAnalyzer )
200200 ), f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type { type (sorting_or_analyzer )} "
201201 assert isinstance (
202- curation_dict_or_model , (dict , CurationModel )
203- ), f"`curation_dict_or_model` must be a dict or a CurationModel, not an object of type { type (curation_dict_or_model )} "
202+ curation_dict_or_model , (dict , list , CurationModel , SequentialCuration )
203+ ), f"`curation_dict_or_model` must be a dict, CurationModel or a SequentialCuration not an object of type { type (curation_dict_or_model )} "
204204 if isinstance (curation_dict_or_model , dict ):
205205 curation_model = CurationModel (** curation_dict_or_model )
206+ elif isinstance (curation_dict_or_model , list ):
207+ curation_model = SequentialCuration (curation_steps = curation_dict_or_model )
206208 else :
207209 curation_model = curation_dict_or_model .model_copy (deep = True )
208210
209- if not np .array_equal (np .asarray (curation_model .unit_ids ), sorting_or_analyzer .unit_ids ):
211+ if isinstance (curation_model , SequentialCuration ):
212+ for c , single_curation_model in enumerate (curation_model .curation_steps ):
213+ if verbose :
214+ print (f"Applying curation step: { c + 1 } / { len (curation_model .curation_steps )} " )
215+ sorting_or_analyzer = apply_curation (
216+ sorting_or_analyzer ,
217+ single_curation_model ,
218+ censor_ms = censor_ms ,
219+ merging_mode = merging_mode ,
220+ sparsity_overlap = sparsity_overlap ,
221+ raise_error_if_overlap_fails = raise_error_if_overlap_fails ,
222+ verbose = verbose ,
223+ job_kwargs = job_kwargs ,
224+ )
225+ return sorting_or_analyzer
226+
227+ if not set (curation_model .unit_ids ) == set (sorting_or_analyzer .unit_ids ):
210228 raise ValueError ("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer" )
211229
212230 # 1. Apply labels
@@ -228,13 +246,15 @@ def apply_curation(
228246 curated_sorting_or_analyzer , _ , _ = apply_merges_to_sorting (
229247 curated_sorting_or_analyzer ,
230248 merge_unit_groups = merge_unit_groups ,
249+ new_unit_ids = merge_new_unit_ids ,
231250 censor_ms = censor_ms ,
232251 new_id_strategy = new_id_strategy ,
233252 return_extra = True ,
234253 )
235254 else :
236255 curated_sorting_or_analyzer , _ = curated_sorting_or_analyzer .merge_units (
237256 merge_unit_groups = merge_unit_groups ,
257+ new_unit_ids = merge_new_unit_ids ,
238258 censor_ms = censor_ms ,
239259 merging_mode = merging_mode ,
240260 sparsity_overlap = sparsity_overlap ,
0 commit comments