Skip to content

Commit 9c4f745

Browse files
committed
Solve conflicts
2 parents c797a49 + d395c43 commit 9c4f745

13 files changed

+164
-135
lines changed

src/spikeinterface/benchmark/benchmark_base.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -457,27 +457,6 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):
457457
benchmark.compute_result(**result_params)
458458
benchmark.save_result(self.folder / "results" / self.key_to_str(key))
459459

460-
def create_sorting_analyzer_gt(self, case_keys=None, return_in_uV=True, random_params={}, **job_kwargs):
461-
print("###### Study.create_sorting_analyzer_gt() is not used anymore!!!!!!")
462-
# if case_keys is None:
463-
# case_keys = self.cases.keys()
464-
465-
# base_folder = self.folder / "sorting_analyzer"
466-
# base_folder.mkdir(exist_ok=True)
467-
468-
# dataset_keys = [self.cases[key]["dataset"] for key in case_keys]
469-
# dataset_keys = set(dataset_keys)
470-
# for dataset_key in dataset_keys:
471-
# # the waveforms depend on the dataset key
472-
# folder = base_folder / self.key_to_str(dataset_key)
473-
# recording, gt_sorting = self.datasets[dataset_key]
474-
# sorting_analyzer = create_sorting_analyzer(
475-
# gt_sorting, recording, format="binary_folder", folder=folder, return_in_uV=return_in_uV
476-
# )
477-
# sorting_analyzer.compute("random_spikes", **random_params)
478-
# sorting_analyzer.compute("templates", **job_kwargs)
479-
# sorting_analyzer.compute("noise_levels")
480-
481460
def get_sorting_analyzer(self, case_key=None, dataset_key=None):
482461
if case_key is not None:
483462
dataset_key = self.cases[case_key]["dataset"]

src/spikeinterface/benchmark/benchmark_clustering.py

Lines changed: 13 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,21 @@ def plot_performances_ordered(self, *args, **kwargs):
222222

223223
return plot_performances_ordered(self, *args, **kwargs)
224224

225+
def plot_some_over_merged(self, *args, **kwargs):
226+
from .benchmark_plot_tools import plot_some_over_merged
227+
228+
return plot_some_over_merged(self, *args, **kwargs)
229+
230+
def plot_some_over_splited(self, *args, **kwargs):
231+
from .benchmark_plot_tools import plot_some_over_splited
232+
233+
return plot_some_over_splited(self, *args, **kwargs)
234+
225235
def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):
226236

227237
if case_keys is None:
228238
case_keys = list(self.cases.keys())
229-
import pylab as plt
239+
import matplotlib.pyplot as plt
230240

231241
fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
232242

@@ -263,7 +273,7 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5
263273

264274
if case_keys is None:
265275
case_keys = list(self.cases.keys())
266-
import pylab as plt
276+
import matplotlib.pyplot as plt
267277

268278
if axes is None:
269279
fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
@@ -322,7 +332,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
322332

323333
if case_keys is None:
324334
case_keys = list(self.cases.keys())
325-
import pylab as plt
335+
import matplotlib.pyplot as plt
326336

327337
fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
328338

@@ -391,81 +401,3 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
391401
fig.colorbar(im, cax=cbar_ax, label=metric)
392402

393403
return fig
394-
395-
def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
396-
if case_keys is None:
397-
case_keys = list(self.cases.keys())
398-
import pylab as plt
399-
400-
figs = []
401-
for count, key in enumerate(case_keys):
402-
label = self.cases[key]["label"]
403-
comp = self.get_result(key)["gt_comparison"]
404-
405-
unit_index = np.flatnonzero(np.sum(comp.agreement_scores.values > overmerged_score, axis=0) > 1)
406-
overmerged_ids = comp.sorting2.unit_ids[unit_index]
407-
408-
n = min(len(overmerged_ids), max_units)
409-
if n > 0:
410-
fig, axs = plt.subplots(nrows=n, figsize=figsize)
411-
for i, unit_id in enumerate(overmerged_ids[:n]):
412-
gt_unit_indices = np.flatnonzero(comp.agreement_scores.loc[:, unit_id].values > overmerged_score)
413-
gt_unit_ids = comp.sorting1.unit_ids[gt_unit_indices]
414-
ax = axs[i]
415-
ax.set_title(f"unit {unit_id} - GTids {gt_unit_ids}")
416-
417-
analyzer = self.get_sorting_analyzer(key)
418-
419-
wf_template = analyzer.get_extension("templates")
420-
templates = wf_template.get_templates(unit_ids=gt_unit_ids)
421-
if analyzer.sparsity is not None:
422-
chan_mask = np.any(analyzer.sparsity.mask[gt_unit_indices, :], axis=0)
423-
templates = templates[:, :, chan_mask]
424-
ax.plot(templates.swapaxes(1, 2).reshape(templates.shape[0], -1).T)
425-
ax.set_xticks([])
426-
427-
fig.suptitle(label)
428-
figs.append(fig)
429-
else:
430-
print(key, "no overmerged")
431-
432-
return figs
433-
434-
def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
435-
if case_keys is None:
436-
case_keys = list(self.cases.keys())
437-
import pylab as plt
438-
439-
figs = []
440-
for count, key in enumerate(case_keys):
441-
label = self.cases[key]["label"]
442-
comp = self.get_result(key)["gt_comparison"]
443-
444-
gt_unit_indices = np.flatnonzero(np.sum(comp.agreement_scores.values > oversplit_score, axis=1) > 1)
445-
oversplit_ids = comp.sorting1.unit_ids[gt_unit_indices]
446-
447-
n = min(len(oversplit_ids), max_units)
448-
if n > 0:
449-
fig, axs = plt.subplots(nrows=n, figsize=figsize)
450-
for i, unit_id in enumerate(oversplit_ids[:n]):
451-
unit_indices = np.flatnonzero(comp.agreement_scores.loc[unit_id, :].values > oversplit_score)
452-
unit_ids = comp.sorting2.unit_ids[unit_indices]
453-
ax = axs[i]
454-
ax.set_title(f"Gt unit {unit_id} - unit_ids: {unit_ids}")
455-
456-
templates = self.get_result(key)["clustering_templates"]
457-
458-
template_arrays = templates.get_dense_templates()[unit_indices, :, :]
459-
if templates.sparsity is not None:
460-
chan_mask = np.any(templates.sparsity.mask[gt_unit_indices, :], axis=0)
461-
template_arrays = template_arrays[:, :, chan_mask]
462-
463-
ax.plot(template_arrays.swapaxes(1, 2).reshape(template_arrays.shape[0], -1).T)
464-
ax.set_xticks([])
465-
466-
fig.suptitle(label)
467-
figs.append(fig)
468-
else:
469-
print(key, "no over splited")
470-
471-
return figs

src/spikeinterface/benchmark/benchmark_peak_localization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def plot_comparison_positions(self, case_keys=None):
8686

8787
if case_keys is None:
8888
case_keys = list(self.cases.keys())
89-
import pylab as plt
89+
import matplotlib.pyplot as plt
9090

9191
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))
9292

@@ -222,7 +222,7 @@ def plot_template_errors(self, case_keys=None, show_probe=True):
222222

223223
if case_keys is None:
224224
case_keys = list(self.cases.keys())
225-
import pylab as plt
225+
import matplotlib.pyplot as plt
226226

227227
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))
228228

@@ -248,7 +248,7 @@ def plot_comparison_positions(self, case_keys=None):
248248

249249
if case_keys is None:
250250
case_keys = list(self.cases.keys())
251-
import pylab as plt
251+
import matplotlib.pyplot as plt
252252

253253
fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))
254254

@@ -416,7 +416,7 @@ def plot_comparison_positions(self, case_keys=None):
416416

417417

418418
# def plot_comparison_precision(benchmarks):
419-
# import pylab as plt
419+
# import matplotlib.pyplot as plt
420420

421421
# fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(15, 10), squeeze=False)
422422

@@ -487,7 +487,7 @@ def plot_comparison_positions(self, case_keys=None):
487487
# norms = np.linalg.norm(benchmark.gt_positions[:, :2], axis=1)
488488
# cell_ind = np.argsort(norms)[0]
489489

490-
# import pylab as plt
490+
# import matplotlib.pyplot as plt
491491

492492
# fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10))
493493
# plot_probe_map(benchmark.recording, ax=axs[0, 0])

src/spikeinterface/benchmark/benchmark_plot_tools.py

Lines changed: 109 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -842,6 +842,7 @@ def plot_performances_comparison(
842842
performance_colors={"accuracy": "g", "recall": "b", "precision": "r"},
843843
levels_to_group_by=None,
844844
ylim=(-0.1, 1.1),
845+
axs=None,
845846
):
846847
"""
847848
Plot performances comparison for a study.
@@ -881,7 +882,8 @@ def plot_performances_comparison(
881882
[key in performance_colors for key in performance_names]
882883
), f"performance_colors must have a color for each performance name: {performance_names}"
883884

884-
fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=figsize, squeeze=False)
885+
if axs is None:
886+
fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=figsize, squeeze=False)
885887
for i, key1 in enumerate(case_keys):
886888
for j, key2 in enumerate(case_keys):
887889
if i < j:
@@ -897,7 +899,8 @@ def plot_performances_comparison(
897899
comp1 = study.get_result(sub_key1)["gt_comparison"]
898900
comp2 = study.get_result(sub_key2)["gt_comparison"]
899901

900-
for performance_name, color in performance_colors.items():
902+
for performance_name in performance_names:
903+
color = performance_colors[performance_name]
901904
perf1 = comp1.get_performance()[performance_name]
902905
perf2 = comp2.get_performance()[performance_name]
903906
ax.scatter(perf2, perf1, marker=".", label=performance_name, color=color)
@@ -923,9 +926,11 @@ def plot_performances_comparison(
923926
patches = []
924927
from matplotlib.patches import Patch
925928

926-
for name, color in performance_colors.items():
927-
patches.append(Patch(color=color, label=name))
929+
for performance_name in performance_names:
930+
color = performance_colors[performance_name]
931+
patches.append(Patch(color=color, label=performance_name))
928932
ax.legend(handles=patches)
933+
fig = ax.figure
929934
fig.subplots_adjust(hspace=0.1, wspace=0.1)
930935
return fig
931936

@@ -964,7 +969,7 @@ def plot_performances_vs_depth_and_snr(
964969
fig : matplotlib.figure.Figure
965970
The resulting figure containing the plots.
966971
"""
967-
import pylab as plt
972+
import matplotlib.pyplot as plt
968973

969974
if case_keys is None:
970975
case_keys = list(study.cases.keys())
@@ -1082,3 +1087,102 @@ def plot_performance_losses(
10821087
despine(axs)
10831088

10841089
return fig
1090+
1091+
1092+
def plot_some_over_merged(study, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
1093+
"""
1094+
Plot some waveforms of overmerged units.
1095+
"""
1096+
1097+
if case_keys is None:
1098+
case_keys = list(study.cases.keys())
1099+
import matplotlib.pyplot as plt
1100+
1101+
figs = []
1102+
for count, key in enumerate(case_keys):
1103+
label = study.cases[key]["label"]
1104+
comp = study.get_result(key)["gt_comparison"]
1105+
1106+
unit_index = np.flatnonzero(np.sum(comp.agreement_scores.values > overmerged_score, axis=0) > 1)
1107+
overmerged_ids = comp.sorting2.unit_ids[unit_index]
1108+
1109+
n = min(len(overmerged_ids), max_units)
1110+
if n > 0:
1111+
fig, axs = plt.subplots(nrows=n, figsize=figsize, squeeze=False)
1112+
axs = axs[:, 0]
1113+
for i, unit_id in enumerate(overmerged_ids[:n]):
1114+
gt_unit_indices = np.flatnonzero(comp.agreement_scores.loc[:, unit_id].values > overmerged_score)
1115+
gt_unit_ids = comp.sorting1.unit_ids[gt_unit_indices]
1116+
ax = axs[i]
1117+
ax.set_title(f"unit {unit_id} - GTids {gt_unit_ids}")
1118+
1119+
analyzer = study.get_sorting_analyzer(key)
1120+
1121+
wf_template = analyzer.get_extension("templates")
1122+
templates = wf_template.get_templates(unit_ids=gt_unit_ids)
1123+
if analyzer.sparsity is not None:
1124+
chan_mask = np.any(analyzer.sparsity.mask[gt_unit_indices, :], axis=0)
1125+
templates = templates[:, :, chan_mask]
1126+
ax.plot(templates.swapaxes(1, 2).reshape(templates.shape[0], -1).T)
1127+
ax.set_xticks([])
1128+
1129+
fig.suptitle(label)
1130+
figs.append(fig)
1131+
else:
1132+
print(key, "no overmerged")
1133+
1134+
return figs
1135+
1136+
1137+
def plot_some_over_splited(study, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
1138+
"""
1139+
Plot some waveforms of over-splitted units.
1140+
"""
1141+
if case_keys is None:
1142+
case_keys = list(study.cases.keys())
1143+
import matplotlib.pyplot as plt
1144+
1145+
print(case_keys)
1146+
figs = []
1147+
for count, key in enumerate(case_keys):
1148+
print(key)
1149+
label = study.cases[key]["label"]
1150+
comp = study.get_result(key)["gt_comparison"]
1151+
1152+
gt_unit_indices = np.flatnonzero(np.sum(comp.agreement_scores.values > oversplit_score, axis=1) > 1)
1153+
oversplit_ids = comp.sorting1.unit_ids[gt_unit_indices]
1154+
1155+
n = min(len(oversplit_ids), max_units)
1156+
if n > 0:
1157+
fig, axs = plt.subplots(nrows=n, figsize=figsize, squeeze=False)
1158+
axs = axs[:, 0]
1159+
for i, unit_id in enumerate(oversplit_ids[:n]):
1160+
unit_indices = np.flatnonzero(comp.agreement_scores.loc[unit_id, :].values > oversplit_score)
1161+
unit_ids = comp.sorting2.unit_ids[unit_indices]
1162+
ax = axs[i]
1163+
ax.set_title(f"Gt unit {unit_id} - unit_ids: {unit_ids}")
1164+
1165+
results = study.get_result(key)
1166+
if "clustering_templates" in results:
1167+
# ClusteringBenchmark has this
1168+
templates = results["clustering_templates"]
1169+
elif "sorter_analyzer" in results:
1170+
# SorterBenchmark has this
1171+
templates = results["sorter_analyzer"].get_extension("templates").get_data(outputs="Templates")
1172+
else:
1173+
raise ValueError("This benchmark do not have templates computed")
1174+
1175+
template_arrays = templates.get_dense_templates()[unit_indices, :, :]
1176+
if templates.sparsity is not None:
1177+
chan_mask = np.any(templates.sparsity.mask[gt_unit_indices, :], axis=0)
1178+
template_arrays = template_arrays[:, :, chan_mask]
1179+
1180+
ax.plot(template_arrays.swapaxes(1, 2).reshape(template_arrays.shape[0], -1).T)
1181+
ax.set_xticks([])
1182+
1183+
fig.suptitle(label)
1184+
figs.append(fig)
1185+
else:
1186+
print(key, "no over splited")
1187+
1188+
return figs

0 commit comments

Comments
 (0)