Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions divref/divref/tools/compute_haplotypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,5 +364,6 @@ def compute_haplotypes(
)

htu = window1.union(window2)
htu = htu.annotate_globals(pops=hl.literal(pop_legend))
logger.info("Writing final %s.ht ...", output_base)
htu.key_by("haplotype").naive_coalesce(64).write(f"{str(output_base)}.ht", overwrite=True)
168 changes: 147 additions & 21 deletions divref/divref/tools/create_duckdb_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@ def create_duckdb_index( # noqa: C901
batches `INSERT INTO` it. Sequence IDs are assigned with a running offset so they remain
unique across contigs and batches.

The two input sources may carry different population legends (e.g. when the gnomAD variant
track is drawn from a release with more populations than the HGDP haplotype track). Both per-
source pop-index spaces are remapped into a single `joint_pops_legend` (the union of the two,
with gnomAD pops first) before union, so the `max_pop` and `all_pop_freqs[*].pop` integers in
the index are comparable across sources. The DuckDB output stores three legend tables:
`hgdp_haplotype_pops_legend`, `gnomad_variant_pops_legend`, and `joint_pops_legend`.

Args:
in_table_pairs_tsv: Path to a TSV file with fields 'contig', 'haplotype_table_path', and
'sites_table_path'.
Expand Down Expand Up @@ -107,6 +114,8 @@ def create_duckdb_index( # noqa: C901
assert_path_is_writable(out_duckdb_file)

table_pairs: list[TablePair] = list(TablePair.read(in_table_pairs_tsv))
if not table_pairs:
raise ValueError(f"No table pairs found in {in_table_pairs_tsv}.")

# fail fast on input Hail tables
for table_pair in table_pairs:
Expand All @@ -131,7 +140,37 @@ def create_duckdb_index( # noqa: C901
)
hl.init(tmp_dir=str(tmp_dir))

pops_legend: list[str] = hl.read_table(str(table_pairs[0].sites_table_path)).pops.collect()[0]
hgdp_pops_legend: list[str] = hl.read_table(
str(table_pairs[0].haplotype_table_path)
).pops.collect()[0]
gnomad_pops_legend: list[str] = hl.read_table(
str(table_pairs[0].sites_table_path)
).pops.collect()[0]
# All pairs must share the same pops legends so a single remap into the joint legend is valid
# for every contig; otherwise the exported gnomAD_AF_* columns would be misaligned.
for tp in table_pairs[1:]:
tp_hgdp_pops: list[str] = hl.read_table(str(tp.haplotype_table_path)).pops.collect()[0]
tp_gnomad_pops: list[str] = hl.read_table(str(tp.sites_table_path)).pops.collect()[0]
if tp_hgdp_pops != hgdp_pops_legend or tp_gnomad_pops != gnomad_pops_legend:
raise ValueError(
f"Pops legend mismatch for contig {tp.contig}: "
f"haplotype pops {tp_hgdp_pops} vs {hgdp_pops_legend}, "
f"sites pops {tp_gnomad_pops} vs {gnomad_pops_legend}."
)
# Joint legend: gnomAD pops in their original order, then any HGDP-only pops appended.
joint_pops_legend: list[str] = list(gnomad_pops_legend) + [
p for p in hgdp_pops_legend if p not in gnomad_pops_legend
]
hgdp_to_joint: list[int] = [joint_pops_legend.index(p) for p in hgdp_pops_legend]
gnomad_to_joint: list[int] = [joint_pops_legend.index(p) for p in gnomad_pops_legend]
# Inverse remaps for reshuffling each source's `gnomad_freqs` inner array into joint order.
hgdp_at_joint: list[int] = [
hgdp_pops_legend.index(p) if p in hgdp_pops_legend else -1 for p in joint_pops_legend
]
gnomad_at_joint: list[int] = [
gnomad_pops_legend.index(p) if p in gnomad_pops_legend else -1 for p in joint_pops_legend
]
Comment thread
coderabbitai[bot] marked this conversation as resolved.

hl.get_reference(reference_genome).add_sequence(str(reference_fasta))

with duckdb.connect(str(out_duckdb_file)) as conn:
Expand All @@ -143,15 +182,23 @@ def create_duckdb_index( # noqa: C901
window_size=window_size,
version=version,
sequence_id_offset=sequence_id_offset,
hgdp_to_joint=hgdp_to_joint,
gnomad_to_joint=gnomad_to_joint,
hgdp_at_joint=hgdp_at_joint,
gnomad_at_joint=gnomad_at_joint,
)
contig_tsv: Path = per_contig_tsvs[table_pair.contig]
export_sequences_table_to_tsv(
ht=contig_seq_ht, out_file=contig_tsv, pops_legend=pops_legend
ht=contig_seq_ht,
out_file=contig_tsv,
joint_pops_legend=joint_pops_legend,
)

contig_rows: int = 0
for df in iter_dataframe_chunks(
tsv=contig_tsv, pops_legend=pops_legend, chunk_size=polars_chunk_size
tsv=contig_tsv,
joint_pops_legend=joint_pops_legend,
chunk_size=polars_chunk_size,
):
if not created_table:
conn.execute("CREATE TABLE sequences AS SELECT * FROM df")
Expand All @@ -175,24 +222,42 @@ def create_duckdb_index( # noqa: C901
conn.execute("CREATE INDEX idx_sequence_id ON sequences(sequence_id)")
conn.execute("CREATE TABLE window_size AS SELECT ? AS window_size", [window_size])
conn.execute(
"CREATE TABLE pops_legend AS SELECT ? AS pops_legend", [json.dumps(pops_legend)]
"CREATE TABLE hgdp_haplotype_pops_legend AS SELECT ? AS pops_legend",
[json.dumps(hgdp_pops_legend)],
)
conn.execute(
"CREATE TABLE gnomad_variant_pops_legend AS SELECT ? AS pops_legend",
[json.dumps(gnomad_pops_legend)],
)
conn.execute(
"CREATE TABLE joint_pops_legend AS SELECT ? AS pops_legend",
[json.dumps(joint_pops_legend)],
)
conn.execute("CREATE TABLE VERSION AS SELECT ? AS version", [version])


def build_hgdp_haplotype_table_entries(
haplotypes_table_path: Path,
window_size: int,
hgdp_to_joint: list[int],
hgdp_at_joint: list[int],
) -> hl.Table:
"""
Build HGDP_haplotype entries for the "sequences" table.

Reads the haplotype table, splits the haplotypes by window size, and annotates with source and
population frequencies.
population frequencies. `max_pop` and `all_pop_freqs[*].pop` integer indices are remapped from
the haplotype table's native pop ordering into the joint pop legend, and each row's
`gnomad_freqs` inner array is reshuffled and padded to the joint legend's length so it indexes
positionally by the joint legend (missing pops become a missing struct).

Args:
haplotypes_table_path: Path to the computed haplotypes Hail table.
window_size: Context size for sequence construction and haplotype splitting.
hgdp_to_joint: For each index `i` in the haplotype table's pop legend, the corresponding
index in the joint pop legend.
hgdp_at_joint: For each index `j` in the joint pop legend, the corresponding index in the
haplotype table's pop legend, or `-1` if that pop is not present on the haplotype side.

Returns:
Hail table with added sequences and variant strings.
Expand All @@ -216,25 +281,56 @@ def build_hgdp_haplotype_table_entries(
f"window size {window_size}"
)

# Annotate
# Annotate; remap pop integers from the haplotype-source legend into the joint legend, and
# reshuffle each row's gnomad_freqs inner array into joint legend order with missing-padding.
hgdp_remap = hl.literal(hgdp_to_joint)
hgdp_at_joint_lit = hl.literal(hgdp_at_joint)
inner_struct_type = ht.gnomad_freqs.dtype.element_type.element_type
ht = ht.annotate(
source="HGDP_haplotype",
max_pop=hgdp_remap[ht.max_pop],
all_pop_freqs=ht.all_pop_freqs.map(
lambda x: hl.struct(pop=x.pop, empirical_AC=x.empirical_AC, empirical_AF=x.empirical_AF)
lambda x: hl.struct(
pop=hgdp_remap[x.pop],
empirical_AC=x.empirical_AC,
empirical_AF=x.empirical_AF,
)
),
gnomad_freqs=ht.gnomad_freqs.map(
lambda inner: hl.range(hl.len(hgdp_at_joint_lit)).map(
lambda j: hl.if_else(
hgdp_at_joint_lit[j] >= 0,
inner[hgdp_at_joint_lit[j]],
hl.missing(inner_struct_type),
)
)
),
)

return ht


def build_gnomad_variant_table_entries(sites_table_path: Path) -> hl.Table:
def build_gnomad_variant_table_entries(
sites_table_path: Path,
gnomad_to_joint: list[int],
gnomad_at_joint: list[int],
) -> hl.Table:
"""
Build gnomAD_variant entries for the "sequences" table.

Reads the gnomAD table and annotates entries to match the HGDP_haplotype entries.
Reads the gnomAD table and annotates entries to match the HGDP_haplotype entries. `max_pop`
and `all_pop_freqs[*].pop` integer indices are remapped from the gnomAD-source legend into the
joint pop legend, and the per-variant `gnomad_freqs` inner array is reshuffled and padded to
the joint legend's length so it indexes positionally by the joint legend (missing pops become
a missing struct).

Args:
sites_table_path: Path to the gnomAD variant annotations Hail table.
gnomad_to_joint: For each index `i` in the gnomAD sites table's pop legend, the
corresponding index in the joint pop legend.
gnomad_at_joint: For each index `j` in the joint pop legend, the corresponding index in
the gnomAD sites table's pop legend, or `-1` if that pop is not present on the gnomAD
side.

Returns:
Hail table of gnomAD_variant entries with pop indices remapped to the joint pop legend.
Expand All @@ -246,22 +342,32 @@ def build_gnomad_variant_table_entries(sites_table_path: Path) -> hl.Table:
va = va.rename({"pop_freqs": "gnomad_freqs"})
va = va.key_by()
argmax_pop = hl.argmax(va.gnomad_freqs.map(lambda x: x.AF))
gnomad_remap = hl.literal(gnomad_to_joint)
gnomad_at_joint_lit = hl.literal(gnomad_at_joint)
inner_struct_type = va.gnomad_freqs.dtype.element_type
gnomad_freqs_joint = hl.range(hl.len(gnomad_at_joint_lit)).map(
lambda j: hl.if_else(
gnomad_at_joint_lit[j] >= 0,
va.gnomad_freqs[gnomad_at_joint_lit[j]],
hl.missing(inner_struct_type),
)
)
va = va.select(
max_pop=argmax_pop,
max_pop=gnomad_remap[argmax_pop],
max_empirical_AF=va.gnomad_freqs[argmax_pop].AF,
fraction_phased=1.0,
estimated_gnomad_AF=va.gnomad_freqs[argmax_pop].AF,
max_empirical_AC=va.gnomad_freqs[argmax_pop].AC,
all_pop_freqs=hl.range(hl.len(va.gnomad_freqs)).map(
lambda i: hl.struct(
pop=i,
pop=gnomad_remap[i],
empirical_AC=va.gnomad_freqs[i].AC,
empirical_AF=va.gnomad_freqs[i].AF,
)
),
source="gnomAD_variant",
variants=[hl.struct(locus=va.locus, alleles=va.alleles)],
gnomad_freqs=[va.gnomad_freqs],
gnomad_freqs=[gnomad_freqs_joint],
)
return va

Expand All @@ -272,6 +378,10 @@ def build_contig_sequences_table(
window_size: int,
version: str,
sequence_id_offset: int,
hgdp_to_joint: list[int],
gnomad_to_joint: list[int],
hgdp_at_joint: list[int],
gnomad_at_joint: list[int],
) -> hl.Table:
"""
Build the per-contig sequences hail table with sequences, coordinates, and IDs.
Expand All @@ -286,15 +396,24 @@ def build_contig_sequences_table(
version: Version identifier for sequence IDs.
sequence_id_offset: Number of rows already written for prior contigs; added to this
contig's local index to produce a globally unique sequence ID.
hgdp_to_joint: Remap from the haplotype-source pop legend into the joint pop legend.
gnomad_to_joint: Remap from the gnomAD-source pop legend into the joint pop legend.
hgdp_at_joint: Inverse remap: for each joint index, the haplotype-source index or -1.
gnomad_at_joint: Inverse remap: for each joint index, the gnomAD-source index or -1.

Returns:
Hail table with sequences, coordinates, and variant strings annotated.
"""
hgdp_haplotypes_ht: hl.Table = build_hgdp_haplotype_table_entries(
haplotypes_table_path=table_pair.haplotype_table_path, window_size=window_size
haplotypes_table_path=table_pair.haplotype_table_path,
window_size=window_size,
hgdp_to_joint=hgdp_to_joint,
hgdp_at_joint=hgdp_at_joint,
)
gnomad_variants_ht: hl.Table = build_gnomad_variant_table_entries(
sites_table_path=table_pair.sites_table_path
sites_table_path=table_pair.sites_table_path,
gnomad_to_joint=gnomad_to_joint,
gnomad_at_joint=gnomad_at_joint,
)
seq_ht: hl.Table = hgdp_haplotypes_ht.union(gnomad_variants_ht, unify=True)

Expand Down Expand Up @@ -328,15 +447,21 @@ def build_contig_sequences_table(
def export_sequences_table_to_tsv(
ht: hl.Table,
out_file: Path,
pops_legend: list[str],
joint_pops_legend: list[str],
) -> None:
"""
Export the sequences Hail table to a single bgz-compressed TSV.

One `gnomAD_AF_{pop}` column is emitted per pop in `joint_pops_legend`, in order. Each row's
`gnomad_freqs` inner array is already reshuffled to the joint legend at source-table
construction time (with missing-padding for pops absent from a source), so a uniform
positional lookup is safe regardless of which source the row came from.

Args:
ht: Annotated haplotype/variant table with sequences and variant strings.
out_file: Path for the output TSV file.
pops_legend: Ordered list of population codes for frequency columns.
joint_pops_legend: Ordered list of all population codes across both input sources; used to
resolve `max_pop` integer indices to labels and to name `gnomAD_AF_{pop}` columns.
"""
ht.select(
"sequence",
Expand All @@ -351,21 +476,21 @@ def export_sequences_table_to_tsv(
"estimated_gnomad_AF",
"fraction_phased",
"source",
max_pop=hl.literal(pops_legend)[ht.max_pop],
max_pop=hl.literal(joint_pops_legend)[ht.max_pop],
variants=hl.delimit(ht.variant_strs, ","),
**{
f"gnomAD_AF_{pop}": hl.delimit(
ht.gnomad_freqs.map(lambda x, _i=i: hl.format("%.5f", x[_i].AF)), ","
)
for i, pop in enumerate(pops_legend)
for i, pop in enumerate(joint_pops_legend)
},
).export(str(out_file))


def iter_dataframe_chunks(
*,
tsv: Path,
pops_legend: list[str],
joint_pops_legend: list[str],
chunk_size: int,
) -> Iterator[polars.DataFrame]:
"""
Expand All @@ -376,15 +501,16 @@ def iter_dataframe_chunks(

Args:
tsv: Path to the sequences TSV (bgz-compressed).
pops_legend: Ordered list of population codes used to name `gnomAD_AF_{pop}` columns.
joint_pops_legend: Ordered list of population codes used to name `gnomAD_AF_{pop}`
columns; must match what `export_sequences_table_to_tsv` wrote.
chunk_size: Maximum rows per yielded DataFrame.

Yields:
Polars DataFrame batches read from `tsv`.
"""
schema_overrides: dict[str, type[polars.DataType]] = {
"sequence_id": polars.String,
**{f"gnomAD_AF_{pop}": polars.String for pop in pops_legend},
**{f"gnomAD_AF_{pop}": polars.String for pop in joint_pops_legend},
}
lf = polars.scan_csv(
tsv,
Expand Down
Loading
Loading