Skip to content
36 changes: 29 additions & 7 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,8 +332,32 @@ def list(
if limit is not None:
query += " LIMIT ?"
param_values = (*param_values, limit)
with self.cursor(transaction=False) as cur, closing(self.conn.cursor()) as wcur:
with self.cursor(transaction=False) as cur:
cur.execute(query, param_values)
checkpoint_rows = cur.fetchall()
if not checkpoint_rows:
return
# Batch fetch all writes in a single query (2 queries total)
# instead of one query per checkpoint (N+1 queries total).
keys = [(row[0], row[1], row[2]) for row in checkpoint_rows]
placeholders = " OR ".join(
"(thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?)"
for _ in keys
)
cur.execute(
f"SELECT thread_id, checkpoint_ns, checkpoint_id, task_id, channel, type, value "
f"FROM writes WHERE {placeholders} "
f"ORDER BY thread_id, checkpoint_ns, checkpoint_id, task_id, idx",
tuple(p for key in keys for p in key),
)
writes_by_checkpoint: dict[
tuple[str, str, str], list[tuple[str, str, str, Any]]
] = {}
for w_thread_id, w_ns, w_cp_id, task_id, channel, w_type, value in cur:
w_key = (w_thread_id, w_ns, w_cp_id)
if w_key not in writes_by_checkpoint:
writes_by_checkpoint[w_key] = []
writes_by_checkpoint[w_key].append((task_id, channel, w_type, value))
for (
thread_id,
checkpoint_ns,
Expand All @@ -342,11 +366,7 @@ def list(
type,
checkpoint,
metadata,
) in cur:
wcur.execute(
"SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx",
(thread_id, checkpoint_ns, checkpoint_id),
)
) in checkpoint_rows:
yield CheckpointTuple(
{
"configurable": {
Expand All @@ -373,7 +393,9 @@ def list(
),
[
(task_id, channel, self.serde.loads_typed((type, value)))
for task_id, channel, type, value in wcur
for task_id, channel, type, value in writes_by_checkpoint.get(
(thread_id, checkpoint_ns, checkpoint_id), []
)
],
)

Expand Down
52 changes: 40 additions & 12 deletions libs/checkpoint-sqlite/langgraph/checkpoint/sqlite/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,24 +428,50 @@ async def alist(
if limit is not None:
query += " LIMIT ?"
params = (*params, limit)
async with (
self.lock,
self.conn.execute(query, params) as cur,
self.conn.cursor() as wcur,
):
async for (
async with self.lock, self.conn.execute(query, params) as cur:
checkpoint_rows = await cur.fetchall()
if not checkpoint_rows:
return
# Batch fetch all writes in a single query (2 queries total)
# instead of one query per checkpoint (N+1 queries total).
keys = [(row[0], row[1], row[2]) for row in checkpoint_rows]
placeholders = " OR ".join(
"(thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ?)"
for _ in keys
)
async with self.conn.execute(
f"SELECT thread_id, checkpoint_ns, checkpoint_id, task_id, channel, type, value "
f"FROM writes WHERE {placeholders} "
f"ORDER BY thread_id, checkpoint_ns, checkpoint_id, task_id, idx",
tuple(p for key in keys for p in key),
) as wcur:
writes_by_checkpoint: dict[
tuple[str, str, str], list[tuple[str, str, str, Any]]
] = {}
async for (
w_thread_id,
w_ns,
w_cp_id,
task_id,
channel,
w_type,
value,
) in wcur:
w_key = (w_thread_id, w_ns, w_cp_id)
if w_key not in writes_by_checkpoint:
writes_by_checkpoint[w_key] = []
writes_by_checkpoint[w_key].append(
(task_id, channel, w_type, value)
)
for (
thread_id,
checkpoint_ns,
checkpoint_id,
parent_checkpoint_id,
type,
checkpoint,
metadata,
) in cur:
await wcur.execute(
"SELECT task_id, channel, type, value FROM writes WHERE thread_id = ? AND checkpoint_ns = ? AND checkpoint_id = ? ORDER BY task_id, idx",
(thread_id, checkpoint_ns, checkpoint_id),
)
) in checkpoint_rows:
yield CheckpointTuple(
{
"configurable": {
Expand All @@ -472,7 +498,9 @@ async def alist(
),
[
(task_id, channel, self.serde.loads_typed((type, value)))
async for task_id, channel, type, value in wcur
for task_id, channel, type, value in writes_by_checkpoint.get(
(thread_id, checkpoint_ns, checkpoint_id), []
)
],
)

Expand Down