Skip to content

Commit b18ddbc

Browse files
committed
fix logic
1 parent 17fbdc7 commit b18ddbc

File tree

2 files changed

+57
-14
lines changed

2 files changed

+57
-14
lines changed

packages/graphrag/graphrag/index/workflows/create_communities.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,17 @@ async def run_workflow(
3131
reader = DataReader(context.output_table_provider)
3232
relationships = await reader.relationships()
3333

34-
title_to_entity_id: dict[str, str] = {}
35-
async with context.output_table_provider.open("entities") as entities_table:
36-
async for row in entities_table:
37-
title_to_entity_id[row["title"]] = row["id"]
38-
3934
max_cluster_size = config.cluster_graph.max_cluster_size
4035
use_lcc = config.cluster_graph.use_lcc
4136
seed = config.cluster_graph.seed
4237

43-
async with context.output_table_provider.open("communities") as communities_table:
38+
async with (
39+
context.output_table_provider.open("entities") as entities_table,
40+
context.output_table_provider.open("communities") as communities_table,
41+
):
4442
sample_rows = await create_communities(
4543
communities_table,
46-
title_to_entity_id,
44+
entities_table,
4745
relationships,
4846
max_cluster_size=max_cluster_size,
4947
use_lcc=use_lcc,
@@ -56,7 +54,7 @@ async def run_workflow(
5654

5755
async def create_communities(
5856
communities_table: Table,
59-
title_to_entity_id: dict[str, str],
57+
entities_table: Table,
6058
relationships: pd.DataFrame,
6159
max_cluster_size: int,
6260
use_lcc: bool,
@@ -68,8 +66,8 @@ async def create_communities(
6866
----
6967
communities_table: Table
7068
Output table to write community rows to.
71-
title_to_entity_id: dict[str, str]
72-
Mapping of entity title to entity id.
69+
entities_table: Table
70+
Table containing entity rows.
7371
relationships: pd.DataFrame
7472
Relationships DataFrame with source, target, weight,
7573
text_unit_ids columns.
@@ -92,6 +90,10 @@ async def create_communities(
9290
seed=seed,
9391
)
9492

93+
title_to_entity_id: dict[str, str] = {}
94+
async for row in entities_table:
95+
title_to_entity_id[row["title"]] = row["id"]
96+
9597
communities = pd.DataFrame(
9698
clusters, columns=pd.Index(["level", "community", "parent", "title"])
9799
).explode("title")

tests/unit/indexing/test_create_communities.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
create_communities,
2121
)
2222
from graphrag_storage.tables.csv_table import CSVTable
23+
from graphrag_storage.tables.table import Table
2324

2425

2526
class FakeTable(CSVTable):
@@ -33,15 +34,55 @@ async def write(self, row: dict[str, Any]) -> None:
3334
self.rows.append(row)
3435

3536

37+
class FakeEntitiesTable(Table):
38+
"""In-memory read-only table that supports async iteration."""
39+
40+
def __init__(self, rows: list[dict[str, Any]]) -> None:
41+
self._rows = rows
42+
self._index = 0
43+
44+
def __aiter__(self):
45+
"""Return an async iterator over the rows."""
46+
self._index = 0
47+
return self
48+
49+
async def __anext__(self) -> dict[str, Any]:
50+
"""Yield the next row or stop."""
51+
if self._index >= len(self._rows):
52+
raise StopAsyncIteration
53+
row = self._rows[self._index]
54+
self._index += 1
55+
return row
56+
57+
async def length(self) -> int:
58+
"""Return number of rows."""
59+
return len(self._rows)
60+
61+
async def has(self, row_id: str) -> bool:
62+
"""Check if a row with the given ID exists."""
63+
return any(r.get("id") == row_id for r in self._rows)
64+
65+
async def write(self, row: dict[str, Any]) -> None:
66+
"""Not supported for read-only table."""
67+
raise NotImplementedError
68+
69+
async def close(self) -> None:
70+
"""No-op."""
71+
72+
3673
async def _run_create_communities(
3774
title_to_entity_id: dict[str, str],
3875
relationships: pd.DataFrame,
3976
**kwargs: Any,
4077
) -> pd.DataFrame:
41-
"""Helper that runs create_communities with a FakeTable and returns all rows as a DataFrame."""
42-
table = FakeTable()
43-
await create_communities(table, title_to_entity_id, relationships, **kwargs)
44-
return pd.DataFrame(table.rows)
78+
"""Helper that runs create_communities with fake tables and returns all rows as a DataFrame."""
79+
communities_table = FakeTable()
80+
entity_rows = [
81+
{"id": eid, "title": title} for title, eid in title_to_entity_id.items()
82+
]
83+
entities_table = FakeEntitiesTable(entity_rows)
84+
await create_communities(communities_table, entities_table, relationships, **kwargs)
85+
return pd.DataFrame(communities_table.rows)
4586

4687

4788
def _make_title_to_entity_id(

0 commit comments

Comments
 (0)