Skip to content

Commit ceb7967

Browse files
Allow KDEDistanceEvaluator to handle multiple values per dialog
1 parent d5d6ce8 commit ceb7967

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

src/sdialog/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def __len__(self):
6969
return len(self.text.split())
7070

7171
def __str__(self):
72-
return f"{self.speaker}: {self.text}"
72+
return f"{self.speaker}: {self.text}" if self.speaker else self.text
7373

7474
def prompt(self) -> str:
7575
"""Generates a prompt string for this turn."""

src/sdialog/evaluation/__init__.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2063,7 +2063,17 @@ def __init__(self,
20632063
for dialogue in tqdm(reference_dialogues,
20642064
desc=f"Computing reference {self.name} scores",
20652065
leave=verbose)]
2066-
self.reference_scores = np.array([s for s in self.reference_scores if s is not None])
2066+
# Flatten if reference_scores contains lists/arrays (e.g. multiple scores per dialog)
2067+
# This handles cases like TurnLength which returns a list of values per dialog
2068+
flattened_scores = []
2069+
for s in self.reference_scores:
2070+
if s is None:
2071+
continue
2072+
if isinstance(s, (list, np.ndarray)):
2073+
flattened_scores.extend(s)
2074+
else:
2075+
flattened_scores.append(s)
2076+
self.reference_scores = np.array(flattened_scores)
20672077

20682078
def __plot__(self, dialog_scores: Dict[str, np.ndarray], plot: Optional[plt.Axes] = None, zoom: bool = False):
20692079
"""

0 commit comments

Comments
 (0)