Skip to content
47 changes: 23 additions & 24 deletions src/spikeinterface/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,30 +112,20 @@ def check_sortings_equal(

max_spike_index = SX1.to_spike_vector()["sample_index"].max()

# TODO for later use to_spike_vector() to do this without looping
for segment_idx in range(SX1.get_num_segments()):
# get_unit_ids
ids1 = np.sort(np.array(SX1.get_unit_ids()))
ids2 = np.sort(np.array(SX2.get_unit_ids()))
assert_array_equal(ids1, ids2)
for id in ids1:
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx))
assert np.array_equal(train1, train2)
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30))
assert np.array_equal(train1, train2)
# test that slicing works correctly
train1 = np.sort(SX1.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
train2 = np.sort(SX2.get_unit_spike_train(id, segment_index=segment_idx, end_frame=max_spike_index - 30))
assert np.array_equal(train1, train2)
train1 = np.sort(
SX1.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
)
train2 = np.sort(
SX2.get_unit_spike_train(id, segment_index=segment_idx, start_frame=30, end_frame=max_spike_index - 30)
)
assert np.array_equal(train1, train2)
s1 = SX1.to_spike_vector()
s2 = SX2.to_spike_vector()
assert_array_equal(s1, s2)

for start_frame, end_frame in [
(None, None),
(30, None),
(None, max_spike_index - 30),
(30, max_spike_index - 30),
]:

slice1 = _slice_spikes(s1, start_frame, end_frame)
slice2 = _slice_spikes(s2, start_frame, end_frame)
assert np.array_equal(slice1, slice2)

if check_annotations:
check_extractor_annotations_equal(SX1, SX2)
Expand All @@ -155,3 +145,12 @@ def check_extractor_properties_equal(EX1, EX2) -> None:

for property_name in EX1.get_property_keys():
assert_array_equal(EX1.get_property(property_name), EX2.get_property(property_name))


def _slice_spikes(spikes, start_frame=None, end_frame=None):
mask = np.ones(spikes.size, dtype=bool)
if start_frame is not None:
mask &= spikes["sample_index"] >= start_frame
if end_frame is not None:
mask &= spikes["sample_index"] <= end_frame
return spikes[mask]
Loading