@@ -842,6 +842,7 @@ def plot_performances_comparison(
842842 performance_colors = {"accuracy" : "g" , "recall" : "b" , "precision" : "r" },
843843 levels_to_group_by = None ,
844844 ylim = (- 0.1 , 1.1 ),
845+ axs = None ,
845846):
846847 """
847848 Plot performances comparison for a study.
@@ -881,7 +882,8 @@ def plot_performances_comparison(
881882 [key in performance_colors for key in performance_names ]
882883 ), f"performance_colors must have a color for each performance name: { performance_names } "
883884
884- fig , axs = plt .subplots (ncols = num_methods - 1 , nrows = num_methods - 1 , figsize = figsize , squeeze = False )
885+ if axs is None :
886+ fig , axs = plt .subplots (ncols = num_methods - 1 , nrows = num_methods - 1 , figsize = figsize , squeeze = False )
885887 for i , key1 in enumerate (case_keys ):
886888 for j , key2 in enumerate (case_keys ):
887889 if i < j :
@@ -897,7 +899,8 @@ def plot_performances_comparison(
897899 comp1 = study .get_result (sub_key1 )["gt_comparison" ]
898900 comp2 = study .get_result (sub_key2 )["gt_comparison" ]
899901
900- for performance_name , color in performance_colors .items ():
902+ for performance_name in performance_names :
903+ color = performance_colors [performance_name ]
901904 perf1 = comp1 .get_performance ()[performance_name ]
902905 perf2 = comp2 .get_performance ()[performance_name ]
903906 ax .scatter (perf2 , perf1 , marker = "." , label = performance_name , color = color )
@@ -923,9 +926,11 @@ def plot_performances_comparison(
923926 patches = []
924927 from matplotlib .patches import Patch
925928
926- for name , color in performance_colors .items ():
927- patches .append (Patch (color = color , label = name ))
929+ for performance_name in performance_names :
930+ color = performance_colors [performance_name ]
931+ patches .append (Patch (color = color , label = performance_name ))
928932 ax .legend (handles = patches )
933+ fig = ax .figure
929934 fig .subplots_adjust (hspace = 0.1 , wspace = 0.1 )
930935 return fig
931936
@@ -964,7 +969,7 @@ def plot_performances_vs_depth_and_snr(
964969 fig : matplotlib.figure.Figure
965970 The resulting figure containing the plots.
966971 """
967- import pylab as plt
972+ import matplotlib . pyplot as plt
968973
969974 if case_keys is None :
970975 case_keys = list (study .cases .keys ())
@@ -1082,3 +1087,102 @@ def plot_performance_losses(
10821087 despine (axs )
10831088
10841089 return fig
1090+
1091+
1092+ def plot_some_over_merged (study , case_keys = None , overmerged_score = 0.05 , max_units = 5 , figsize = None ):
1093+ """
1094+ Plot some waveforms of overmerged units.
1095+ """
1096+
1097+ if case_keys is None :
1098+ case_keys = list (study .cases .keys ())
1099+ import matplotlib .pyplot as plt
1100+
1101+ figs = []
1102+ for count , key in enumerate (case_keys ):
1103+ label = study .cases [key ]["label" ]
1104+ comp = study .get_result (key )["gt_comparison" ]
1105+
1106+ unit_index = np .flatnonzero (np .sum (comp .agreement_scores .values > overmerged_score , axis = 0 ) > 1 )
1107+ overmerged_ids = comp .sorting2 .unit_ids [unit_index ]
1108+
1109+ n = min (len (overmerged_ids ), max_units )
1110+ if n > 0 :
1111+ fig , axs = plt .subplots (nrows = n , figsize = figsize , squeeze = False )
1112+ axs = axs [:, 0 ]
1113+ for i , unit_id in enumerate (overmerged_ids [:n ]):
1114+ gt_unit_indices = np .flatnonzero (comp .agreement_scores .loc [:, unit_id ].values > overmerged_score )
1115+ gt_unit_ids = comp .sorting1 .unit_ids [gt_unit_indices ]
1116+ ax = axs [i ]
1117+ ax .set_title (f"unit { unit_id } - GTids { gt_unit_ids } " )
1118+
1119+ analyzer = study .get_sorting_analyzer (key )
1120+
1121+ wf_template = analyzer .get_extension ("templates" )
1122+ templates = wf_template .get_templates (unit_ids = gt_unit_ids )
1123+ if analyzer .sparsity is not None :
1124+ chan_mask = np .any (analyzer .sparsity .mask [gt_unit_indices , :], axis = 0 )
1125+ templates = templates [:, :, chan_mask ]
1126+ ax .plot (templates .swapaxes (1 , 2 ).reshape (templates .shape [0 ], - 1 ).T )
1127+ ax .set_xticks ([])
1128+
1129+ fig .suptitle (label )
1130+ figs .append (fig )
1131+ else :
1132+ print (key , "no overmerged" )
1133+
1134+ return figs
1135+
1136+
1137+ def plot_some_over_splited (study , case_keys = None , oversplit_score = 0.05 , max_units = 5 , figsize = None ):
1138+ """
1139+ Plot some waveforms of over-splitted units.
1140+ """
1141+ if case_keys is None :
1142+ case_keys = list (study .cases .keys ())
1143+ import matplotlib .pyplot as plt
1144+
1145+ print (case_keys )
1146+ figs = []
1147+ for count , key in enumerate (case_keys ):
1148+ print (key )
1149+ label = study .cases [key ]["label" ]
1150+ comp = study .get_result (key )["gt_comparison" ]
1151+
1152+ gt_unit_indices = np .flatnonzero (np .sum (comp .agreement_scores .values > oversplit_score , axis = 1 ) > 1 )
1153+ oversplit_ids = comp .sorting1 .unit_ids [gt_unit_indices ]
1154+
1155+ n = min (len (oversplit_ids ), max_units )
1156+ if n > 0 :
1157+ fig , axs = plt .subplots (nrows = n , figsize = figsize , squeeze = False )
1158+ axs = axs [:, 0 ]
1159+ for i , unit_id in enumerate (oversplit_ids [:n ]):
1160+ unit_indices = np .flatnonzero (comp .agreement_scores .loc [unit_id , :].values > oversplit_score )
1161+ unit_ids = comp .sorting2 .unit_ids [unit_indices ]
1162+ ax = axs [i ]
1163+ ax .set_title (f"Gt unit { unit_id } - unit_ids: { unit_ids } " )
1164+
1165+ results = study .get_result (key )
1166+ if "clustering_templates" in results :
1167+ # ClusteringBenchmark has this
1168+ templates = results ["clustering_templates" ]
1169+ elif "sorter_analyzer" in results :
1170+ # SorterBenchmark has this
1171+ templates = results ["sorter_analyzer" ].get_extension ("templates" ).get_data (outputs = "Templates" )
1172+ else :
1173+ raise ValueError ("This benchmark do not have templates computed" )
1174+
1175+ template_arrays = templates .get_dense_templates ()[unit_indices , :, :]
1176+ if templates .sparsity is not None :
1177+ chan_mask = np .any (templates .sparsity .mask [gt_unit_indices , :], axis = 0 )
1178+ template_arrays = template_arrays [:, :, chan_mask ]
1179+
1180+ ax .plot (template_arrays .swapaxes (1 , 2 ).reshape (template_arrays .shape [0 ], - 1 ).T )
1181+ ax .set_xticks ([])
1182+
1183+ fig .suptitle (label )
1184+ figs .append (fig )
1185+ else :
1186+ print (key , "no over splited" )
1187+
1188+ return figs
0 commit comments