Commit efb3973

Use keyset pagination for meta and data read methods
Why these changes are being introduced:

For all read methods, the former approach was to perform a metadata query, store the entire result set in memory, then loop through chunks of that metadata and build SQL queries to perform data retrieval. This worked even for metadata queries that bring back 3-4 million results, but there is an upper limit. Ideally, we would perform all of our queries -- metadata and data -- in chunks to ease memory pressure, and in some cases this can also increase performance.

How this addresses that need:

This reworks the base read_batches_iter() method to perform smaller, chunked metadata queries. To paginate the results, instead of using the slow LIMIT / OFFSET approach, we use keyset pagination: each page looks for values greater than a tuple of ordered values taken from the previous chunk. This is often the preferred way to paginate when you have well-ordered columns. In support of this, we also begin hashing the filename and run_id columns for ordering, providing almost an order-of-magnitude speedup; the cost of computing the hash is more than offset by ordering integers rather than very long strings. The net effect is no change to the input/output signatures of the read methods, but improved memory usage and performance.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/TIMX-543
1 parent 7337f31 commit efb3973
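
To make the keyset approach described in the commit message concrete, here is a minimal, hypothetical sketch of paginating a DuckDB metadata table with a row-value (tuple) comparison instead of LIMIT / OFFSET. The `records` table name, the `iter_meta_pages` helper, and the chunk size are illustrative assumptions, not the library's actual query builder.

# Illustrative sketch only -- not the TIMDEXDatasetMetadata implementation.
# Assumes a DuckDB connection with a "records" metadata table containing
# timdex_record_id, run_id, run_record_offset, and filename columns.
import duckdb


def iter_meta_pages(conn: duckdb.DuckDBPyConnection, chunk_size: int = 100_000):
    """Yield pyarrow Tables of metadata rows using keyset pagination."""
    # start "before" every row; DuckDB's hash() returns non-negative integers
    keyset = (0, 0, 0)  # (filename_hash, run_id_hash, run_record_offset)
    while True:
        page = conn.execute(
            """
            select
                timdex_record_id,
                run_id,
                hash(run_id) as run_id_hash,
                run_record_offset,
                filename,
                hash(filename) as filename_hash
            from records
            where (hash(filename), hash(run_id), run_record_offset) > (?, ?, ?)
            order by hash(filename), hash(run_id), run_record_offset
            limit ?
            """,
            [*keyset, chunk_size],
        ).fetch_arrow_table()

        if page.num_rows == 0:  # an empty page signals the end of pagination
            break
        yield page

        # the next page starts strictly after the last row of this page
        keyset = (
            page["filename_hash"][-1].as_py(),
            page["run_id_hash"][-1].as_py(),
            page["run_record_offset"][-1].as_py(),
        )

Ordering and comparing on hash(filename) and hash(run_id) keeps the keyset columns as integers, which is where the commit's near order-of-magnitude speedup over ordering very long strings comes from.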

File tree

4 files changed: +139 -68 lines changed

tests/conftest.py
timdex_dataset_api/__init__.py
timdex_dataset_api/dataset.py
timdex_dataset_api/metadata.py

tests/conftest.py

Lines changed: 2 additions & 2 deletions
@@ -114,7 +114,7 @@ def timdex_dataset_multi_source(tmp_path_factory) -> TIMDEXDataset:

     # ensure static metadata database exists for read methods
     dataset.metadata.rebuild_dataset_metadata()
-    dataset.metadata.refresh()
+    dataset.refresh()

     return dataset

@@ -234,7 +234,7 @@ def timdex_dataset_with_runs_with_metadata(
 ) -> TIMDEXDataset:
     """TIMDEXDataset with runs and static metadata created for read tests."""
     timdex_dataset_with_runs.metadata.rebuild_dataset_metadata()
-    timdex_dataset_with_runs.metadata.refresh()
+    timdex_dataset_with_runs.refresh()
     return timdex_dataset_with_runs

timdex_dataset_api/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from timdex_dataset_api.metadata import TIMDEXDatasetMetadata
 from timdex_dataset_api.record import DatasetRecord

-__version__ = "3.1.0"
+__version__ = "3.2.0"

 __all__ = [
     "DatasetRecord",

timdex_dataset_api/dataset.py

Lines changed: 102 additions & 42 deletions
@@ -15,6 +15,7 @@
 import boto3
 import pandas as pd
 import pyarrow as pa
+import pyarrow.compute as pc
 import pyarrow.dataset as ds
 from duckdb import DuckDBPyConnection
 from pyarrow import fs
@@ -364,7 +365,7 @@ def read_batches_iter(
     ) -> Iterator[pa.RecordBatch]:
         """Yield ETL records as pyarrow.RecordBatches.

-        This method performs a two step process:
+        This method performs a two-step process:

         1. Perform a "metadata" query that narrows down records and physical parquet
         files to read from.
@@ -383,33 +384,36 @@ def read_batches_iter(
         """
         start_time = time.perf_counter()

-        # build and execute metadata query
-        metadata_time = time.perf_counter()
-        meta_query = self.metadata.build_meta_query(table, limit, where, **filters)
-        meta_df = self.metadata.conn.query(meta_query).to_df()
-        logger.debug(
-            f"Metadata query identified {len(meta_df)} rows, "
-            f"across {len(meta_df.filename.unique())} parquet files, "
-            f"elapsed: {round(time.perf_counter()-metadata_time,2)}s"
-        )
-
         # execute data queries in batches and yield results
         total_yield_count = 0
-        for i, meta_chunk_df in enumerate(self._iter_meta_chunks(meta_df)):
+        for i, meta_chunk in enumerate(
+            self._iter_meta_chunks(
+                table,
+                limit=limit,
+                where=where,
+                **filters,
+            )
+        ):
             batch_time = time.perf_counter()
-            batch_yield_count = len(meta_chunk_df)
+
+            batch_yield_count = meta_chunk.num_rows
             total_yield_count += batch_yield_count

-            if batch_yield_count == 0:
-                continue
+            # register meta chunk as DuckDB asset
+            self.conn.register("meta_chunk", meta_chunk)

-            self.conn.register("meta_chunk", meta_chunk_df)
+            # perform data query and yield results
             data_query = self._build_data_query_for_chunk(
                 columns,
-                meta_chunk_df,
+                meta_chunk,
                 registered_metadata_chunk="meta_chunk",
             )
-            yield from self._stream_data_query_batches(data_query)
+            cursor = self.conn.execute(data_query)
+            yield from cursor.fetch_record_batch(
+                rows_per_batch=self.config.read_batch_size
+            )
+
+            # deregister meta chunk
             self.conn.unregister("meta_chunk")

             batch_rps = int(batch_yield_count / (time.perf_counter() - batch_time))
@@ -422,32 +426,94 @@ def read_batches_iter(
             f"read_batches_iter() elapsed: {round(time.perf_counter()-start_time, 2)}s"
         )

-    def _iter_meta_chunks(self, meta_df: pd.DataFrame) -> Iterator[pd.DataFrame]:
-        """Utility method to yield chunks of metadata query results."""
-        for start in range(0, len(meta_df), self.config.duckdb_join_batch_size):
-            yield meta_df.iloc[start : start + self.config.duckdb_join_batch_size]
+    def _iter_meta_chunks(
+        self,
+        table: str = "records",
+        limit: int | None = None,
+        where: str | None = None,
+        **filters: Unpack[DatasetFilters],
+    ) -> Iterator[pa.lib.Table]:
+        """Utility method to yield pyarrow Table chunks of metadata query results.

-    def _build_parquet_file_list(self, meta_chunk_df: pd.DataFrame) -> str:
-        """Build SQL list of parquet filepaths."""
-        filenames = meta_chunk_df["filename"].unique().tolist()
-        if self.location_scheme == "s3":
-            filenames = [f"s3://{f.removeprefix('s3://')}" for f in filenames]
-        return "[" + ",".join((f"'{f}'") for f in filenames) + "]"
+        The approach here is to use "keyset" pagination, which means each paged result
+        is a greater-than (>) check against a tuple of ordered values from the previous
+        chunk. This is more performant than a LIMIT + OFFSET.
+        """
+        # use duckdb_join_batch_size as the chunk size for keyset pagination
+        chunk_size = self.config.duckdb_join_batch_size
+
+        # init keyset value of zeros to begin with
+        keyset_value = (0, 0, 0)
+
+        total_yielded = 0
+        while True:
+
+            # enforce limit if passed
+            if limit is not None:
+                remaining = limit - total_yielded
+                if remaining <= 0:
+                    break
+                chunk_limit = min(chunk_size, remaining)
+            else:
+                chunk_limit = chunk_size
+
+            # perform chunk query and convert to pyarrow Table
+            meta_query = self.metadata.build_keyset_paginated_metadata_query(
+                table,
+                limit=chunk_limit,  # pass chunk_limit instead of limit
+                where=where,
+                keyset_value=keyset_value,
+                **filters,
+            )
+            meta_chunk = self.metadata.conn.query(meta_query).to_arrow_table()
+
+            # an empty chunk signals end of pagination
+            if meta_chunk.num_rows == 0:
+                break
+
+            # yield this chunk of data
+            total_yielded += meta_chunk.num_rows
+            yield meta_chunk
+
+            # update keyset value using the last row from this chunk
+            keyset_value = (
+                meta_chunk["filename_hash"][-1].as_py(),
+                meta_chunk["run_id_hash"][-1].as_py(),
+                meta_chunk["run_record_offset"][-1].as_py(),
+            )

     def _build_data_query_for_chunk(
         self,
         columns: list[str] | None,
-        meta_chunk_df: pd.DataFrame,
+        meta_chunk: pa.lib.Table,
         registered_metadata_chunk: str = "meta_chunk",
     ) -> str:
-        """Build SQL query used for data retrieval, joining on metadata data."""
-        parquet_list_sql = self._build_parquet_file_list(meta_chunk_df)
-        rro_list_sql = ",".join(
-            str(rro) for rro in meta_chunk_df["run_record_offset"].unique()
-        )
+        """Build SQL query used for data retrieval, joining on passed metadata data."""
+        # build list of explicit parquet files to read from
+        filenames = pc.unique(meta_chunk["filename"]).to_pylist()
+        if self.location_scheme == "s3":
+            filenames = [
+                f"s3://{f.removeprefix('s3://')}" for f in filenames  # type: ignore[union-attr]
+            ]
+        parquet_list_sql = "[" + ",".join((f"'{f}'") for f in filenames) + "]"
+
+        # build select columns
         select_cols = ",".join(
             [f"ds.{col}" for col in (columns or TIMDEX_DATASET_SCHEMA.names)]
         )
+
+        # build run_record_offset WHERE clause to leverage row group pruning
+        rro_values = pc.unique(meta_chunk["run_record_offset"]).to_pylist()
+        rro_values.sort()
+        if len(rro_values) <= 1000:  # noqa: PLR2004
+            rro_clause = (
+                f"and run_record_offset in ({','.join(str(rro) for rro in rro_values)})"
+            )
+        else:
+            rro_clause = (
+                f"and run_record_offset between {rro_values[0]} and {rro_values[-1]}"
            )
+
         return f"""
         select
             {select_cols}
@@ -459,16 +525,10 @@ def _build_data_query_for_chunk(
         inner join {registered_metadata_chunk} mc using (
             timdex_record_id, run_id, run_record_offset
         )
-        where ds.run_record_offset in ({rro_list_sql});
+        where true
+        {rro_clause};
         """

-    def _stream_data_query_batches(self, data_query: str) -> Iterator[pa.RecordBatch]:
-        """Yield pyarrow RecordBatches from a SQL query."""
-        self.conn.execute("set enable_progress_bar = false;")
-        cursor = self.conn.execute(data_query)
-        yield from cursor.fetch_record_batch(rows_per_batch=self.config.read_batch_size)
-        self.conn.execute("set enable_progress_bar = true;")
-
     def read_dataframes_iter(
         self,
         table: str = "records",
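
Because the commit keeps the read methods' input/output signatures unchanged, existing calling code should keep working, only with chunked metadata queries underneath. A hedged usage sketch follows; the dataset location passed to the constructor is an invented example and is not taken from this diff.

# Hypothetical usage sketch -- the constructor argument is an assumption, not from this diff.
from timdex_dataset_api import TIMDEXDataset

td = TIMDEXDataset("s3://example-bucket/timdex-dataset")  # assumed dataset location
td.refresh()  # as in the conftest.py change above, refresh() is now called on the dataset itself

total_rows = 0
for batch in td.read_batches_iter(table="records", limit=100_000):
    # each yielded item is a pyarrow.RecordBatch, exactly as before this commit
    total_rows += batch.num_rows
print(f"read {total_rows} rows")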

timdex_dataset_api/metadata.py

Lines changed: 34 additions & 23 deletions
@@ -12,7 +12,7 @@
 import duckdb
 from duckdb import DuckDBPyConnection
 from duckdb_engine import Dialect as DuckDBDialect
-from sqlalchemy import Table, and_, func, select, text
+from sqlalchemy import Table, func, literal, select, text, tuple_

 from timdex_dataset_api.config import configure_logger
 from timdex_dataset_api.utils import (
@@ -619,42 +619,56 @@ def write_append_delta_duckdb(self, filepath: str) -> None:
             f"Append delta written: {output_path}, {time.perf_counter()-start_time}s"
         )

-    def build_meta_query(
+    def build_keyset_paginated_metadata_query(
         self,
         table: str,
-        limit: int | None,
-        where: str | None,
+        *,
+        limit: int | None = None,
+        where: str | None = None,
+        keyset_value: tuple[int, int, int] = (0, 0, 0),
         **filters: Unpack["DatasetFilters"],
     ) -> str:
         """Build SQL query using SQLAlchemy against metadata schema tables and views."""
         sa_table = self.get_sa_table(table)

-        # build WHERE clause filter expression based on any passed key/value filters
-        # and/or an explicit WHERE string
-        filter_expr = build_filter_expr_sa(sa_table, **filters)
-        if where is not None and where.strip():
-            text_where = text(where)
-            combined = (
-                and_(filter_expr, text_where) if filter_expr is not None else text_where
-            )
-        else:
-            combined = filter_expr
-
         # create SQL statement object
         stmt = select(
             sa_table.c.timdex_record_id,
             sa_table.c.run_id,
+            func.hash(sa_table.c.run_id).label("run_id_hash"),
             sa_table.c.run_record_offset,
             sa_table.c.filename,
+            func.hash(sa_table.c.filename).label("filename_hash"),
         ).select_from(sa_table)
-        if combined is not None:
-            stmt = stmt.where(combined)
+
+        # filter expressions from key/value filters (may return None)
+        filter_expr = build_filter_expr_sa(sa_table, **filters)
+        if filter_expr is not None:
+            stmt = stmt.where(filter_expr)
+
+        # explicit raw WHERE string
+        if where is not None and where.strip():
+            stmt = stmt.where(text(where))
+
+        # keyset pagination
+        filename_hash, run_id_hash, run_record_offset_ = keyset_value
+        stmt = stmt.where(
+            tuple_(
+                func.hash(sa_table.c.filename),
+                func.hash(sa_table.c.run_id),
+                sa_table.c.run_record_offset,
+            )
+            > tuple_(
+                literal(filename_hash),
+                literal(run_id_hash),
+                literal(run_record_offset_),
+            )
+        )

         # order by filename + run_record_offset
-        # NOTE: we use a hash of the filename for ordering for a dramatic speedup, where
-        # we don't really care about the exact order, just that they are ordered
         stmt = stmt.order_by(
             func.hash(sa_table.c.filename),
+            func.hash(sa_table.c.run_id),
             sa_table.c.run_record_offset,
         )

@@ -667,7 +681,4 @@ def build_meta_query(
             dialect=DuckDBDialect(),
             compile_kwargs={"literal_binds": True},
         )
-        compiled_str = str(compiled)
-        logger.debug(compiled_str)
-
-        return compiled_str
+        return str(compiled)
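
As a side note on the tuple_() keyset predicate in the diff above, a small self-contained sketch (with made-up table and column definitions, not the library's metadata schema objects) shows how SQLAlchemy renders that predicate as a single row-value comparison:

# Standalone illustration of the keyset predicate; the table definition is made up.
from sqlalchemy import Column, Integer, MetaData, String, Table, func, literal, select, tuple_

md = MetaData()
records = Table(
    "records",
    md,
    Column("filename", String),
    Column("run_id", String),
    Column("run_record_offset", Integer),
)

keyset_value = (0, 0, 0)  # (filename_hash, run_id_hash, run_record_offset)
stmt = (
    select(records.c.run_record_offset)
    .where(
        tuple_(
            func.hash(records.c.filename),
            func.hash(records.c.run_id),
            records.c.run_record_offset,
        )
        > tuple_(*(literal(v) for v in keyset_value))
    )
    .order_by(
        func.hash(records.c.filename),
        func.hash(records.c.run_id),
        records.c.run_record_offset,
    )
)

print(stmt.compile(compile_kwargs={"literal_binds": True}))
# Expect roughly:
#   SELECT records.run_record_offset FROM records
#   WHERE (hash(records.filename), hash(records.run_id), records.run_record_offset) > (0, 0, 0)
#   ORDER BY hash(records.filename), hash(records.run_id), records.run_record_offset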
